diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000000..fef0dde9ba --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,47 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: bug +assignees: '' + +--- + +**Describe the bug** + +A clear and concise description of what the bug is. + +**Steps/Code to reproduce bug** + +Please list *minimal* steps or code snippet for us to be able to reproduce the bug. + +A helpful guide on on how to craft a minimal bug report http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports. + + +**Expected behavior** + +A clear and concise description of what you expected to happen. + +**Environment overview (please complete the following information)** + + - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider - AWS, Azure, GCP, Collab)] + - Method of Transformer Engine install: [pip install or from source]. Please specify exact commands you used to install. + - If method of install is [Docker], provide `docker pull` & `docker run` commands used + +**Environment details** + +If NVIDIA docker image is used you don't need to specify these. +Otherwise, please provide: +- OS version +- PyTorch version +- Python version +- Transformer Engine version +- CUDA version +- CUDNN version + +**Device details** +- GPU model + +**Additional context** + +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000000..355e553939 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,25 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: feature request +assignees: '' + +--- + +**Is your feature request related to a problem? Please describe.** + +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** + +A clear and concise description of what you want to happen. +Provide a code snippet on how new APIs/changes would be used by others. + +**Describe alternatives you've considered** + +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** + +Add any other context or screenshots about the feature request here. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 5fee1b7191..abd8f33ccd 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -11,7 +11,7 @@ Fixes # (issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Infra/Build change -- [ ] Code refractor +- [ ] Code refactoring ## Changes diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml index 260adfc6d3..1402cc091a 100644 --- a/.github/workflows/blossom-ci.yml +++ b/.github/workflows/blossom-ci.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7039d38cf5..6653294c59 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -12,7 +12,7 @@ jobs: name: 'Core' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/cuda:12.0.0-devel-ubuntu22.04 + image: nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04 options: --user root steps: - name: 'Dependencies' @@ -28,6 +28,7 @@ jobs: run: pip install . -v env: NVTE_FRAMEWORK: none + MAX_JOBS: 1 - name: 'Sanity check' run: python3 -c "import transformer_engine" working-directory: / @@ -35,7 +36,7 @@ jobs: name: 'PyTorch' runs-on: ubuntu-latest container: - image: nvcr.io/nvidia/cuda:12.5.0-devel-ubuntu22.04 + image: nvcr.io/nvidia/cuda:12.8.0-devel-ubuntu22.04 options: --user root steps: - name: 'Dependencies' @@ -70,25 +71,6 @@ jobs: run: pip install . -v env: NVTE_FRAMEWORK: jax + MAX_JOBS: 1 - name: 'Sanity check' run: python tests/jax/test_sanity_import.py - paddle: - name: 'PaddlePaddle' - runs-on: ubuntu-latest - container: - image: nvcr.io/nvidia/paddlepaddle:24.10-py3 - options: --user root - steps: - - name: 'Checkout' - uses: actions/checkout@v3 - with: - submodules: recursive - - name: 'Build' - run: | - apt-get update - apt-get install -y libgoogle-glog-dev - pip install . -v - env: - NVTE_FRAMEWORK: paddle - - name: 'Sanity check' - run: python tests/paddle/test_sanity_import.py diff --git a/.github/workflows/deploy_nightly_docs.yml b/.github/workflows/deploy_nightly_docs.yml index cd68019c8f..6470eee838 100644 --- a/.github/workflows/deploy_nightly_docs.yml +++ b/.github/workflows/deploy_nightly_docs.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -16,13 +16,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Download artifact - uses: actions/download-artifact@v4.1.7 + uses: actions/download-artifact@v4 with: name: "te_docs" path: "html" - name: Prepare for pages uses: actions/upload-pages-artifact@v1.0.7 with: + name: github-pages path: "html" deploy: needs: prepare diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 4762cccee6..3c4229a888 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -27,7 +27,7 @@ jobs: cd docs make html - name: 'Upload docs' - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: te_docs path: docs/_build/html diff --git a/.github/workflows/license.yml b/.github/workflows/license.yml index f789a83d1a..d70c7def61 100644 --- a/.github/workflows/license.yml +++ b/.github/workflows/license.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d2bd865a8f..ee6433d484 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -61,30 +61,3 @@ jobs: export PYTHON_ONLY=1 export TE_PATH=. bash ./qa/L0_jax_lint/test.sh - paddle_cpplint: - name: 'PaddlePaddle C++' - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v3 - - name: 'Lint' - run: | - sudo apt-get update - sudo apt-get install pip -y - export CPP_ONLY=1 - export TE_PATH=. - bash ./qa/L0_paddle_lint/test.sh - paddle_pylint: - name: 'PaddlePaddle Python' - runs-on: ubuntu-latest - steps: - - name: 'Checkout' - uses: actions/checkout@v3 - - name: 'Lint' - run: | - sudo apt-get update - sudo apt-get install pip -y - pip install paddlepaddle-gpu - export PYTHON_ONLY=1 - export TE_PATH=. - bash ./qa/L0_paddle_lint/test.sh diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index c2317c6509..681b662036 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -40,6 +40,11 @@ jobs: || github.actor == 'vasunvidia' || github.actor == 'erhoo82' || github.actor == 'kocchop' + || github.actor == 'youngeunkwon0405' + || github.actor == 'KshitijLakhani' + || github.actor == 'jberchtold-nvidia' + || github.actor == 'sanandaraj5597' + || github.actor == 'negvet' ) steps: - name: Check if comment is issued by authorized person diff --git a/.github/workflows/upload-ci-logs.yml b/.github/workflows/upload-ci-logs.yml index b3be2f5c89..c9c7e4ef4d 100644 --- a/.github/workflows/upload-ci-logs.yml +++ b/.github/workflows/upload-ci-logs.yml @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/.gitignore b/.gitignore index 9b61454e21..850b352d31 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,6 @@ *.nsys-rep *.ncu-rep *.sqlite -*.onnx *.eggs build/ *.so @@ -39,3 +38,4 @@ downloads/ .pytest_cache/ compile_commands.json .nfs +tensor_dumps/ \ No newline at end of file diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 936021bfed..20c28ea798 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 936021bfed8c91dc416af1588b2c4eca631a9e45 +Subproject commit 20c28ea798fe99e31d7274e009ee2fbf0e88abfd diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 95767b742f..d92fd95675 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/CPPLINT.cfg b/CPPLINT.cfg index e42ec720b1..ecfbbf3d0b 100644 --- a/CPPLINT.cfg +++ b/CPPLINT.cfg @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/README.rst b/README.rst index 6cc7eeae8a..c4fde5bd11 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. @@ -33,11 +33,12 @@ What is Transformer Engine? .. overview-begin-marker-do-not-remove Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including -using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower -memory utilization in both training and inference. TE provides a collection of highly optimized -building blocks for popular Transformer architectures and an automatic mixed precision-like API that -can be used seamlessly with your framework-specific code. TE also includes a framework agnostic -C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers. +using 8-bit floating point (FP8) precision on Hopper, Ada, and Blackwell GPUs, to provide better +performance with lower memory utilization in both training and inference. TE provides a collection +of highly optimized building blocks for popular Transformer architectures and an automatic mixed +precision-like API that can be used seamlessly with your framework-specific code. TE also includes a +framework agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 +support for Transformers. As the number of parameters in Transformer models continues to grow, training and inference for architectures such as BERT, GPT and T5 become very memory and compute-intensive. Most deep learning @@ -51,16 +52,16 @@ not available natively in frameworks today. TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language Model (LLM) libraries. It provides a Python API consisting of modules to easily build a Transformer -layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8 support. -Modules provided by TE internally maintain scaling factors and other values needed for FP8 training, greatly -simplifying mixed precision training for users. +layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8 +support. Modules provided by TE internally maintain scaling factors and other values needed for FP8 +training, greatly simplifying mixed precision training for users. Highlights ========== * Easy-to-use modules for building Transformer layers with FP8 support * Optimizations (e.g. fused kernels) for Transformer models -* Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs +* Support for FP8 on NVIDIA Hopper, Ada, and Blackwell GPUs * Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later Examples @@ -149,22 +150,22 @@ Installation Pre-requisites ^^^^^^^^^^^^^^^^^^^^ * Linux x86_64 -* CUDA 12.0+ for Hopper and CUDA 12.1+ for Ada -* NVIDIA Driver supporting CUDA 12.0 or later -* cuDNN 8.1 or later -* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later. +* CUDA 12.1+ (CUDA 12.8+ for Blackwell) +* NVIDIA Driver supporting CUDA 12.1 or later +* cuDNN 9.3 or later Docker ^^^^^^^^^^^^^^^^^^^^ The quickest way to get started with Transformer Engine is by using Docker images on -`NVIDIA GPU Cloud (NGC) Catalog `_. For example to use the NGC PyTorch container interactively, +`NVIDIA GPU Cloud (NGC) Catalog `_. +For example to use the NGC PyTorch container interactively, .. code-block:: bash - docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3 + docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.01-py3 -Where 23.10 is the container version. For example, 23.10 for the October 2023 release. +Where 25.01 (corresponding to January 2025 release) is the container version. pip ^^^^^^^^^^^^^^^^^^^^ @@ -172,17 +173,23 @@ To install the latest stable version of Transformer Engine, .. code-block:: bash - pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable + pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable -This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch,paddle). +This will automatically detect if any supported deep learning frameworks are installed and build +Transformer Engine support for them. To explicitly specify frameworks, set the environment variable +NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch). -Alternatively, the package can be directly installed from `Transformer Engine's PyPI `_, e.g. +Alternatively, the package can be directly installed from +`Transformer Engine's PyPI `_, e.g. .. code-block:: bash - pip install transformer_engine[pytorch] + pip3 install transformer_engine[pytorch] -To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch,paddle]). Transformer Engine ships wheels for the core library as well as the PaddlePaddle extensions. Source distributions are shipped for the JAX and PyTorch extensions. +To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be +explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). +Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX +and PyTorch extensions. From source ^^^^^^^^^^^ @@ -190,7 +197,7 @@ From source Compiling with FlashAttention-2 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Transformer Engine release v0.11.0 adds support for FlashAttention-2 in PyTorch for improved performance. +Transformer Engine release v0.11.0 added support for FlashAttention-2 in PyTorch for improved performance. It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug `_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue. @@ -264,10 +271,10 @@ Transformer Engine has been integrated with popular LLM frameworks such as: * `NVIDIA NeMo Framework `_ * `Amazon SageMaker Model Parallel Library `_ * `Levanter `_ +* `GPT-NeoX `_ * `Hugging Face Nanotron `_ - Coming soon! * `Colossal-AI `_ - Coming soon! * `PeriFlow `_ - Coming soon! -* `GPT-NeoX `_ - Coming soon! Contributing diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index cff7c65fbc..dafafdff47 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 28444e84a9..6b959d99e8 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -1.13.0.dev0 +2.2.0.dev0 diff --git a/build_tools/__init__.py b/build_tools/__init__.py index 9bcbd954eb..7669e4cfa6 100644 --- a/build_tools/__init__.py +++ b/build_tools/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index af11ada34c..f0724f617e 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -94,7 +94,7 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: print(f"Time for build_ext: {total_time:.2f} seconds") -def get_build_ext(extension_cls: Type[setuptools.Extension]): +def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel_lib: bool = False): class _CMakeBuildExtension(extension_cls): """Setuptools command with support for CMake extension modules""" @@ -129,81 +129,23 @@ def run(self) -> None: super().run() self.extensions = all_extensions - paddle_ext = None - if "paddle" in get_frameworks(): - for ext in self.extensions: - if "paddle" in ext.name: - paddle_ext = ext - break - - # Manually write stub file for Paddle extension - if paddle_ext is not None: - # Load libtransformer_engine.so to avoid linker errors - if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): - # Source compilation from top-level (--editable) - search_paths = list(Path(__file__).resolve().parent.parent.iterdir()) - # Source compilation from top-level - search_paths.extend(list(Path(self.build_lib).iterdir())) - - # Dynamically load required_libs. - from transformer_engine.common import _load_cudnn, _load_nvrtc - - _load_cudnn() - _load_nvrtc() - else: - # Only during release bdist build for paddlepaddle. - import transformer_engine - - search_paths = list(Path(transformer_engine.__path__[0]).iterdir()) - del transformer_engine - - common_so_path = "" - for path in search_paths: - if path.name.startswith("libtransformer_engine."): - common_so_path = str(path) - assert common_so_path, "Could not find libtransformer_engine" - ctypes.CDLL(common_so_path, mode=ctypes.RTLD_GLOBAL) - - # Figure out stub file path - module_name = paddle_ext.name - assert module_name.endswith( - "_pd_" - ), "Expected Paddle extension module to end with '_pd_'" - stub_name = module_name[:-4] # remove '_pd_' - stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py") - Path(stub_path).parent.mkdir(exist_ok=True, parents=True) - - # Figure out library name - # Note: This library doesn't actually exist. Paddle - # internally reinserts the '_pd_' suffix. - so_path = self.get_ext_fullpath(module_name) - _, so_ext = os.path.splitext(so_path) - lib_name = stub_name + so_ext - - # Write stub file - print(f"Writing Paddle stub for {lib_name} into file {stub_path}") - from paddle.utils.cpp_extension.extension_utils import custom_write_stub - - custom_write_stub(lib_name, stub_path) - # Ensure that binaries are not in global package space. - target_dir = install_dir / "transformer_engine" + lib_dir = ( + "wheel_lib" + if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or install_so_in_wheel_lib + else "" + ) + target_dir = install_dir / "transformer_engine" / lib_dir target_dir.mkdir(exist_ok=True, parents=True) for ext in Path(self.build_lib).glob("*.so"): self.copy_file(ext, target_dir) os.remove(ext) - # For paddle, the stub file needs to be copied to the install location. - if paddle_ext is not None: - stub_path = Path(self.build_lib) / "transformer_engine" - for stub in stub_path.glob("transformer_engine_paddle.py"): - self.copy_file(stub, target_dir) - def build_extensions(self): - # BuildExtensions from PyTorch and PaddlePaddle already handle CUDA files correctly + # BuildExtensions from PyTorch already handle CUDA files correctly # so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed. - if "pytorch" not in get_frameworks() and "paddle" not in get_frameworks(): + if "pytorch" not in get_frameworks(): # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when # extra_compile_args is a dict. for ext in self.extensions: diff --git a/build_tools/jax.py b/build_tools/jax.py index f829230f50..7e0652c629 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/paddle.py b/build_tools/paddle.py deleted file mode 100644 index a68d73956e..0000000000 --- a/build_tools/paddle.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Paddle-paddle related extensions.""" -from pathlib import Path - -import setuptools -import os - -from .utils import cuda_version - -import paddle - -paddle_version = paddle.__version__.replace(".", "") - - -def setup_paddle_extension( - csrc_source_files, - csrc_header_files, - common_header_files, -) -> setuptools.Extension: - """Setup CUDA extension for Paddle support""" - - # Source files - csrc_source_files = Path(csrc_source_files) - sources = [ - csrc_source_files / "extensions.cpp", - csrc_source_files / "common.cpp", - csrc_source_files / "custom_ops.cu", - ] - - # Header files - include_dirs = [ - common_header_files, - common_header_files / "common", - common_header_files / "common" / "include", - csrc_header_files, - ] - - # Compiler flags - cxx_flags = ["-O3"] - nvcc_flags = [ - "-O3", - "-gencode", - "arch=compute_70,code=sm_70", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - f"-DPADDLE_VERSION={paddle_version}", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - ] - - # Version-dependent CUDA options - try: - version = cuda_version() - except FileNotFoundError: - print("Could not determine CUDA Toolkit version") - else: - if version < (12, 0): - raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer") - nvcc_flags.extend( - ( - "--threads", - os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"), - "-gencode", - "arch=compute_80,code=sm_80", - "-gencode", - "arch=compute_90,code=sm_90", - ) - ) - - # Construct Paddle CUDA extension - sources = [str(path) for path in sources] - include_dirs = [str(path) for path in include_dirs] - from paddle.utils.cpp_extension import CUDAExtension - - ext = CUDAExtension( - sources=sources, - include_dirs=include_dirs, - extra_compile_args={ - "cxx": cxx_flags, - "nvcc": nvcc_flags, - }, - ) - ext.name = "transformer_engine_paddle_pd_" - return ext diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 575b7bee79..b8501e1008 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -27,7 +27,6 @@ def setup_pytorch_extension( extensions_dir = csrc_source_files / "extensions" sources = [ csrc_source_files / "common.cpp", - csrc_source_files / "ts_fp8_op.cpp", ] + all_files_in_dir(extensions_dir) # Header files diff --git a/build_tools/te_version.py b/build_tools/te_version.py index b40fb26014..0aee63f647 100644 --- a/build_tools/te_version.py +++ b/build_tools/te_version.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/utils.py b/build_tools/utils.py index d846b87f22..723f2f200c 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -190,7 +190,12 @@ def cuda_path() -> Tuple[str, str]: @functools.lru_cache(maxsize=None) def cuda_archs() -> str: - return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90") + version = cuda_version() + if os.getenv("NVTE_CUDA_ARCHS") is None: + os.environ["NVTE_CUDA_ARCHS"] = ( + "70;80;89;90;100;120" if version >= (12, 8) else "70;80;89;90" + ) + return os.getenv("NVTE_CUDA_ARCHS") def cuda_version() -> Tuple[int, ...]: @@ -211,7 +216,7 @@ def cuda_version() -> Tuple[int, ...]: def get_frameworks() -> List[str]: """DL frameworks to build support for""" _frameworks: List[str] = [] - supported_frameworks = ["pytorch", "jax", "paddle"] + supported_frameworks = ["pytorch", "jax"] # Check environment variable if os.getenv("NVTE_FRAMEWORK"): @@ -237,12 +242,6 @@ def get_frameworks() -> List[str]: pass else: _frameworks.append("jax") - try: - import paddle - except ImportError: - pass - else: - _frameworks.append("paddle") # Special framework names if "all" in _frameworks: @@ -311,7 +310,6 @@ def uninstall_te_wheel_packages(): "-y", "transformer_engine_cu12", "transformer_engine_torch", - "transformer_engine_paddle", "transformer_engine_jax", ] ) diff --git a/build_tools/wheel_utils/Dockerfile.aarch b/build_tools/wheel_utils/Dockerfile.aarch index 7d839958cb..223c4a7f1c 100644 --- a/build_tools/wheel_utils/Dockerfile.aarch +++ b/build_tools/wheel_utils/Dockerfile.aarch @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 index 7dedf2a761..26122eed9b 100644 --- a/build_tools/wheel_utils/Dockerfile.x86 +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 7682a2b6aa..9acb22aee6 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -9,7 +9,6 @@ BUILD_METAPACKAGE=${2:-true} BUILD_COMMON=${3:-true} BUILD_PYTORCH=${4:-true} BUILD_JAX=${5:-true} -BUILD_PADDLE=${6:-true} export NVTE_RELEASE_BUILD=1 export TARGET_BRANCH=${TARGET_BRANCH:-} @@ -63,38 +62,3 @@ if $BUILD_JAX ; then /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt cp dist/* /wheelhouse/ fi - -if $BUILD_PADDLE ; then - if [ "$PLATFORM" == "manylinux_2_28_x86_64" ] ; then - dnf -y remove --allowerasing cudnn9-cuda-12 - dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64 - cd /TransformerEngine/transformer_engine/paddle - - /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt - /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt - /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt - /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt - /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps - /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1 - /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt - /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu - - mv dist/* /wheelhouse/ - fi -fi diff --git a/build_tools/wheel_utils/launch_aarch.sh b/build_tools/wheel_utils/launch_aarch.sh index 9a8d796119..04e3cd6916 100644 --- a/build_tools/wheel_utils/launch_aarch.sh +++ b/build_tools/wheel_utils/launch_aarch.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/wheel_utils/launch_x86.sh b/build_tools/wheel_utils/launch_x86.sh index 7b5649a642..b0d20be3f4 100644 --- a/build_tools/wheel_utils/launch_x86.sh +++ b/build_tools/wheel_utils/launch_x86.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/api/c/activation.rst b/docs/api/c/activation.rst index 1790121236..5b50aa513d 100644 --- a/docs/api/c/activation.rst +++ b/docs/api/c/activation.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/cast.rst b/docs/api/c/cast.rst index ef98441812..2ae05a8456 100644 --- a/docs/api/c/cast.rst +++ b/docs/api/c/cast.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/fused_attn.rst b/docs/api/c/fused_attn.rst index a0b6255ebe..6db67f26fe 100644 --- a/docs/api/c/fused_attn.rst +++ b/docs/api/c/fused_attn.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/fused_rope.rst b/docs/api/c/fused_rope.rst new file mode 100644 index 0000000000..289bb53d9b --- /dev/null +++ b/docs/api/c/fused_rope.rst @@ -0,0 +1,10 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +fused_rope.h +============ + +.. doxygenfile:: fused_rope.h + diff --git a/docs/api/c/gemm.rst b/docs/api/c/gemm.rst index e7a14cab97..711733fc4c 100644 --- a/docs/api/c/gemm.rst +++ b/docs/api/c/gemm.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/index.rst b/docs/api/c/index.rst index ae0b6ddfa1..7bc864dcc8 100644 --- a/docs/api/c/index.rst +++ b/docs/api/c/index.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. @@ -12,12 +12,16 @@ directly from C/C++, without Python. .. toctree:: :caption: Headers + transformer_engine.h activation.h cast.h - gemm.h fused_attn.h - layer_norm.h - rmsnorm.h + fused_rope.h + gemm.h + normalization.h + padding.h + permutation.h + recipe.h softmax.h - transformer_engine.h + swizzle.h transpose.h diff --git a/docs/api/c/layer_norm.rst b/docs/api/c/layer_norm.rst deleted file mode 100644 index 47c0585a42..0000000000 --- a/docs/api/c/layer_norm.rst +++ /dev/null @@ -1,9 +0,0 @@ -.. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - - See LICENSE for license information. - -layer_norm.h -============ - -.. doxygenfile:: layer_norm.h diff --git a/docs/api/c/normalization.rst b/docs/api/c/normalization.rst new file mode 100644 index 0000000000..edbea00ac0 --- /dev/null +++ b/docs/api/c/normalization.rst @@ -0,0 +1,9 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +normalization.h +=============== + +.. doxygenfile:: normalization.h diff --git a/docs/api/c/padding.rst b/docs/api/c/padding.rst new file mode 100644 index 0000000000..2141b874d2 --- /dev/null +++ b/docs/api/c/padding.rst @@ -0,0 +1,10 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +padding.h +========= + +.. doxygenfile:: padding.h + diff --git a/docs/api/c/permutation.rst b/docs/api/c/permutation.rst new file mode 100644 index 0000000000..bad6961621 --- /dev/null +++ b/docs/api/c/permutation.rst @@ -0,0 +1,10 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +permutation.h +============= + +.. doxygenfile:: permutation.h + diff --git a/docs/api/c/recipe.rst b/docs/api/c/recipe.rst new file mode 100644 index 0000000000..7c368f69b6 --- /dev/null +++ b/docs/api/c/recipe.rst @@ -0,0 +1,10 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +recipe.h +======== + +.. doxygenfile:: recipe.h + diff --git a/docs/api/c/rmsnorm.rst b/docs/api/c/rmsnorm.rst deleted file mode 100644 index fba3b97c57..0000000000 --- a/docs/api/c/rmsnorm.rst +++ /dev/null @@ -1,9 +0,0 @@ -.. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - - See LICENSE for license information. - -rmsnorm.h -============ - -.. doxygenfile:: rmsnorm.h diff --git a/docs/api/c/softmax.rst b/docs/api/c/softmax.rst index 69875d603c..55dc5d47de 100644 --- a/docs/api/c/softmax.rst +++ b/docs/api/c/softmax.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/swizzle.rst b/docs/api/c/swizzle.rst new file mode 100644 index 0000000000..b2dd8f5977 --- /dev/null +++ b/docs/api/c/swizzle.rst @@ -0,0 +1,10 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +swizzle.h +========= + +.. doxygenfile:: swizzle.h + diff --git a/docs/api/c/transformer_engine.rst b/docs/api/c/transformer_engine.rst index ec474592c3..b5fd95e005 100644 --- a/docs/api/c/transformer_engine.rst +++ b/docs/api/c/transformer_engine.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/c/transpose.rst b/docs/api/c/transpose.rst index d839f3d3b1..9a3ba9e48b 100644 --- a/docs/api/c/transpose.rst +++ b/docs/api/c/transpose.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/common.rst b/docs/api/common.rst index 40afd88ff3..95d4b50f30 100644 --- a/docs/api/common.rst +++ b/docs/api/common.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. @@ -8,4 +8,6 @@ Common API .. autoapiclass:: transformer_engine.common.recipe.Format -.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None, override_linear_precision=(False, False, False)) +.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None) + +.. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3) diff --git a/docs/api/framework.rst b/docs/api/framework.rst index 88785f941a..0ac1a0e34e 100644 --- a/docs/api/framework.rst +++ b/docs/api/framework.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. @@ -10,4 +10,3 @@ Framework-specific API pytorch jax - paddle diff --git a/docs/api/jax.rst b/docs/api/jax.rst index c7701bd699..d72af37ec5 100644 --- a/docs/api/jax.rst +++ b/docs/api/jax.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/api/paddle.rst b/docs/api/paddle.rst deleted file mode 100644 index ad23031f58..0000000000 --- a/docs/api/paddle.rst +++ /dev/null @@ -1,34 +0,0 @@ -.. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - - See LICENSE for license information. - -paddle -====== - -.. autoapiclass:: transformer_engine.paddle.Linear(in_features, out_features, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.LayerNorm(hidden_size, eps=1e-5, **kwargs) - -.. autoapiclass:: transformer_engine.paddle.LayerNormLinear(in_features, out_features, eps=1e-5, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.LayerNormMLP(hidden_size, ffn_hidden_size, eps=1e-5, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.FusedScaleMaskSoftmax(attn_mask_type, mask_func, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.DotProductAttention(num_attention_heads, kv_channels, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.MultiHeadAttention(hidden_size, num_attention_heads, **kwargs) - :members: forward - -.. autoapiclass:: transformer_engine.paddle.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs) - :members: forward - -.. autoapifunction:: transformer_engine.paddle.fp8_autocast - -.. autoapifunction:: transformer_engine.paddle.recompute diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index ba4e7db352..ca4bd91420 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. @@ -31,7 +31,7 @@ pyTorch .. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs) :members: forward, set_context_parallel_group, set_tensor_parallel_group -.. autoapiclass:: transformer_engine.pytorch.InferenceParams(max_batch_size, max_sequence_length) +.. autoapiclass:: transformer_engine.pytorch.dot_product_attention.inference.InferenceParams(max_batch_size, max_sequence_length) .. autoapiclass:: transformer_engine.pytorch.CudaRNGStatesTracker() :members: reset, get_states, set_states, add, fork @@ -42,16 +42,22 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.checkpoint -.. autoapifunction:: transformer_engine.pytorch.onnx_export - .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context .. autoapifunction:: transformer_engine.pytorch.moe_permute +.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs + .. autoapifunction:: transformer_engine.pytorch.moe_unpermute +.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index + +.. autoapifunction:: transformer_engine.pytorch.parallel_cross_entropy + +.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index_with_probs + .. autoapifunction:: transformer_engine.pytorch.initialize_ub .. autoapifunction:: transformer_engine.pytorch.destroy_ub diff --git a/docs/conf.py b/docs/conf.py index 7d2d4ea7b9..4083bfd242 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/examples/E8M0.png b/docs/examples/E8M0.png new file mode 100644 index 0000000000..841df25e74 Binary files /dev/null and b/docs/examples/E8M0.png differ diff --git a/docs/examples/MXFP8_FP8_comparison_1.png b/docs/examples/MXFP8_FP8_comparison_1.png new file mode 100644 index 0000000000..a2f5a28de8 Binary files /dev/null and b/docs/examples/MXFP8_FP8_comparison_1.png differ diff --git a/docs/examples/MXFP8_FP8_comparison_2.png b/docs/examples/MXFP8_FP8_comparison_2.png new file mode 100644 index 0000000000..f5cbc81bfe Binary files /dev/null and b/docs/examples/MXFP8_FP8_comparison_2.png differ diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py index 85ce01079c..e9eec14d99 100644 --- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py +++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 27017b4773..d20cd5c74e 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -14,11 +14,10 @@ "
Figure 1: Dot product attention.
\n", "\n", "\n", - "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in three frameworks, [PyTorch](https://github.com/pytorch/pytorch), [JAX](https://github.com/google/jax) and [PaddlePaddle](https://github.com/PaddlePaddle/Paddle). The API for each framework is\n", + "[Transformer Engine](https://github.com/NVIDIA/TransformerEngine.git) supports the calculation of dot product attention in two frameworks, [PyTorch](https://github.com/pytorch/pytorch) and [JAX](https://github.com/google/jax). The API for each framework is\n", "\n", "- [transformer_engine.pytorch.DotProductAttention](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention)\n", - "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)\n", - "- [transformer_engine.paddle.DotProductAttention](../../api/paddle.rst#transformer_engine.paddle.DotProductAttention)" + "- [transformer_engine.jax.flax.DotProductAttention](../../api/jax.rst#transformer_engine.jax.flax.DotProductAttention)" ] }, { @@ -56,15 +55,6 @@ " \n", " JAX-native attention (`_UnfusedDotProductAttention`)\n", " \n", - " \n", - " PaddlePaddle\n", - " cuDNN attention (`_te_forward`) \n", - " [transformer_engine.paddle.layer.attention](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/paddle/layer/attention.py)\n", - " \n", - " \n", - " \n", - " PaddlePaddle-native attention (`_pd_forward`)\n", - " \n", " \n", "" ] @@ -87,7 +77,7 @@ "
\n", "Note: \n", " \n", - "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch, JAX and PaddlePaddle, are both based on the flash algorithm.\n", + "Transformer Engine's flash-attention backend, available in PyTorch, and cuDNN attention backend (sub-backends 1 and 2), available in PyTorch and JAX, are both based on the flash algorithm.\n", "
\n" ] }, @@ -102,13 +92,13 @@ "\n", "The flash-attention backend supports `flash-attn`'s features as well as a few extra functionalities to facilitate the use of `flash-attn`, such as converting the `attention_mask` to cumulative sequence lengths `cu_seqlens` for `padding` mask use cases. Please see `transformer_engine.pytorch.attention.FlashAttention` for details.\n", "\n", - "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v1.10, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", + "The `flash-attn` dependency is regularly updated in Transformer Engine. As of v2.0, Transformer Engine supports `flash-attn` 2.0.6+ (see [setup.py](https://github.com/NVIDIA/TransformerEngine/blob/main/setup.py)).\n", "\n", "To understand `flash-attn`'s performance, please refer to their benchmarks [here](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#performance).\n", "\n", "### 1.3 cuDNN Attention\n", "\n", - "The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n", + "The cuDNN attention backend, available in PyTorch and JAX, offers another high-performance solution to the attention calculation. It requires [cuDNN](https://developer.nvidia.com/cudnn) to run, and has several sub-backends to support the different precisions and sequence lengths.\n", "\n", "\n", " \n", @@ -153,9 +143,9 @@ " \n", "
\n", "\n", - "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.10, cuDNN 9.3 and `flash-attn` 2.4.2,\n", + "The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 2.0, cuDNN 9.3 and `flash-attn` 2.4.2,\n", "\n", - "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.\n", + "- flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch and JAX.\n", "- flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).\n", "- flash-attention supports `bshd`, `thd` input formats, without any transposes, and `sbhd` format, with transposes, while cuDNN attention supports all three formats without transposes (see Section 3.1 for more details).\n", "- flash-attention does not support `post_scale_bias`, and cuDNN attention does.\n", @@ -244,10 +234,6 @@ " JAX\n", " cuDNN attention > JAX-native attention\n", " \n", - " \n", - " PaddlePaddle\n", - " cuDNN attention > PaddlePaddle-native attention \n", - " \n", "" ] }, @@ -266,7 +252,7 @@ "
\n", "Note:\n", " \n", - "These flags are supported in PyTorch only as of Transformer Engine 1.10. JAX and PaddlePaddle support is expected to be added in the future.\n", + "These flags are supported in PyTorch only as of Transformer Engine 2.0. JAX support is expected to be added in the future.\n", "
" ] }, @@ -382,7 +368,7 @@ "
\n", "Note\n", " \n", - "Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.\n", + "Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX in the future.\n", "
\n", "\n", "### 2.3 Example Tests\n", @@ -399,7 +385,7 @@ "source": [ "## 3. Backend Support\n", "\n", - "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.10, Transformer Engine's attention backends have the following support matrix.\n", + "Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v2.0, Transformer Engine's attention backends have the following support matrix.\n", "\n", "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n", "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n", @@ -442,7 +428,7 @@ "**qkv_layout=thd_thd_thd:**\n", "`q`, `k`, `v` have variable sequence lengths in a batch. They are all contiguous and have no interleaving.\n", "\n", - "As of v1.10, Transformer Engine has the following support matrix.\n", + "As of v2.0, Transformer Engine has the following support matrix.\n", "\n", "\n", " \n", @@ -462,17 +448,17 @@ " \n", " \n", " \n", " \n", " \n", " \n", " \n", - " \n", + " \n", " \n", "
\n", - " JAX, PaddlePaddle: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n", + " JAX: `bs3hd`, `bshd_bs2hd`, `bshd_bshd_bshd` layouts\n", "
Framework-native attention`bshd`, `sbhd`PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layoutsPyTorch, JAX: 2 formats, i.e. 10 layouts
\n", "\n", - "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", + "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", "\n", "
\n", "Note\n", @@ -492,7 +478,7 @@ "\n", "- `no_mask`, `padding`, `causal`, `causal_bottom_right`, `padding_causal`, `padding_causal_bottom_right`, `arbitrary`\n", "\n", - "Different backends offer different support for attention mask. As of Transformer Engine 1.10,\n", + "Different backends offer different support for attention mask. As of Transformer Engine 2.0,\n", "\n", "\n", " \n", @@ -512,21 +498,21 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", " \n", "
Framework-native attention
  • All (PyTorch)
  • `no_mask`, `causal`, `padding` (Jax, PaddlePaddle)
  • All (PyTorch)
  • `no_mask`, `causal`, `padding` (Jax)
  • \n", "\n", - "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.10, there are two options to do so in PyTorch and one in JAX and PaddlePaddle.\n", + "**Padding masks:** For `padding`, `padding_causal`, `padding_causal_bottom_right` mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 2.0, there are two options to do so in PyTorch and one in JAX.\n", "\n", "* PyTorch: When both options are provided by the user, `cu_seqlens` is preferred as there is no extra conversion needed.\n", " - `cu_seqlens`: Users can provide cumulative sequence length tensors `cu_seqlens_q` and `cu_seqlens_kv` for `q` and `k`/`v` to the flash-attention or cuDNN attention backend. An example of `cu_seqlens` is `[0, 2, 6, 7]` for a batch of 3 `[aa000, bbbb0, c0000]`.\n", " - `attention_mask`: Users can also provide `attention_mask` as an alternative, which will then be converted to `cu_seqlens`. For self-attention, `attention_mask` should be one single tensor in shape `[batch_size, 1, 1, seqlen_q]`, and for cross-attention, `attention_mask` should be a list of two tensors in shapes `[batch_size, 1, 1, seqlen_q]` and `[batch_size, 1, 1, seqlen_kv]`, respectively.\n", "\n", "\n", - "* JAX and PaddlePaddle: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n", + "* JAX: Users should provide the `attention_mask` tensor in shape `[batch_size, 1, seqlen_q, seqlen_kv]`.\n", "\n", "**qkv_format=thd:** Transformer Engine extracts the max sequence length information from `q`, `k`, `v` if `max_seqlen_q` and `max_seqlen_kv` are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set `max_seqlen_q` and `max_seqlen_kv` to their appropriate values for `thd` QKV format.\n", "\n", @@ -566,7 +552,7 @@ "\n", "### 3.3 Attention Bias\n", "\n", - "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 1.10, their support matrix is as follows.\n", + "Transformer Engine supports 4 attention bias types, `no_bias`, `pre_scale_bias`, `post_scale_bias`, and `ALiBi` (with/without custom slopes). As of Transformer Engine 2.0, their support matrix is as follows.\n", "\n", "\n", " \n", @@ -591,7 +577,7 @@ " \n", " \n", " \n", - " \n", + " \n", " \n", " \n", " \n", @@ -620,7 +606,7 @@ "\n", "A unique feature of Transformer Engine is its FP8 support, not only for the `Linear` layers but also for dot product attention. Transformer Engine's FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two `MatMul` operations are performed in FP8 for computational efficiency, and the `SoftMax` operation is performed in FP32 for numerical accuracy.\n", "\n", - "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v1.10. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", + "Transformer Engine supports FP8 attention through its [C APIs](../../api/c/fused_attn.rst), and [PyTorch API](../../api/pytorch.rst#transformer_engine.pytorch.DotProductAttention), as of v2.0. Its PyTorch API offers two options, both controlled through the FP8 recipe definition, `transformer_engine.common.recipe.DelayedScaling`.\n", "\n", "- `DelayedScaling.fp8_dpa=True (default=False)`: This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The `FusedAttention` module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.\n", "\n", diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py index 15022005a1..2c32e8b5f7 100644 --- a/docs/examples/attention/example_attention.py +++ b/docs/examples/attention/example_attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/examples/fp8_primer.ipynb b/docs/examples/fp8_primer.ipynb index b8a63dabff..788d6c37ae 100644 --- a/docs/examples/fp8_primer.ipynb +++ b/docs/examples/fp8_primer.ipynb @@ -18,7 +18,7 @@ "* E4M3 - it consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and `nan`.\n", "* E5M2 - it consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/- `inf` and `nan`. The tradeoff of the increased dynamic range is lower precision of the stored values.\n", "\n", - "
    \n", + "
    \n", "\n", "
    Figure 1: Structure of the floating point datatypes. All of the values shown (in FP16, BF16, FP8 E4M3 and FP8 E5M2) are the closest representations of value 0.3952.
    \n", "
    \n", @@ -56,6 +56,50 @@ "As one can see in Figure 3, delayed scaling strategy requires both storing the history of amaxes, but also choosing a recipe for converting that history into the scaling factor used in the next iteration." ] }, + { + "cell_type": "markdown", + "id": "f03b58ed-71e8-422a-95be-35c1cc60c4e2", + "metadata": {}, + "source": [ + "## MXFP8 and block scaling\n", + "\n", + "NVIDIA Blackwell architecture introduced support for a new variant of the FP8 format: [MXFP8](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). \n", + "\n", + "### MXFP8 vs FP8\n", + "\n", + "The main difference between \"regular\" FP8 and MXFP8 lies in the granularity of the scaling. In FP8, each tensor has a single FP32 scaling factor, so all values in the tensor need to \"fit\" within the dynamic range of the FP8 datatype. This requires using the less precise E5M2 format to represent some tensors in the network (like gradients).\n", + "\n", + "MXFP8 addresses this by assigning a different scaling factor to each block of 32 [consecutive](#handling-transposes) values. This allows all values to be represented with the E4M3 datatype.\n", + "\n", + "
    \n", + "\n", + "
    Figure 4: MXFP8 uses multiple scaling factors for a single tensor. The picture shows only 4 values per block for simplicity, but real MXFP8 has 32 values per block.
    \n", + "
    \n", + "\n", + "
    \n", + "\n", + "
    Figure 5: Due to multiple scaling factors, tensor's dynamic range requirements are reduced and so E4M3 format can be used as far fewer elements get saturated to 0.
    \n", + "
    \n", + "\n", + "The second difference is the datatype used to store the scaling factors. FP8 uses FP32 (E8M23) while MXFP8 uses an 8-bit representation of a power of 2 (E8M0).\n", + "\n", + "
    \n", + "\n", + "
    Figure 6: Structure of the E8M0 datatype used for storing scaling factors in MXFP8.
    \n", + "
    \n", + "\n", + "### Handling transposes\n", + "\n", + "The forward and backward passes of linear layers involve multiple matrix multiplications with different reduction dimensions. Blackwell Tensor Cores require MXFP8 data to be \"consecutive\" over the reduction dimension, so MXFP8 training uses non-transposed and transposed MXFP8 tensors at different points. However, while transposing FP8 data is numerically trivial, transposing MXFP8 data requires requantization.\n", + "\n", + "To avoid loss of precision connected with this double quantization, Transformer Engine creates both regular and transposed copies of the tensor from the original high precision input.\n", + "\n", + "
    \n", + "\n", + "
    Figure 7: Linear layer in MXFP8. Calculating both forward and backward pass requires tensors quantized in both directions.
    \n", + "
    " + ] + }, { "cell_type": "markdown", "id": "cf5e0b0d", @@ -63,11 +107,12 @@ "source": [ "## Using FP8 with Transformer Engine\n", "\n", - "Transformer Engine library provides tools enabling easy to use training with FP8 datatype using delayed scaling strategy.\n", + "Transformer Engine library provides tools enabling easy to use training with FP8 datatype using FP8 delayed scaling and MXFP8 strategies.\n", "\n", "### FP8 recipe\n", "\n", - "[DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from `transformer_engine.common.recipe` module stores all of the required options for FP8 training - length of the amax history to use for scaling factor computation, FP8 data format etc." + "The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from the `transformer_engine.common.recipe` module stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n", + "Similarly, [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) from the same module may be used to enable MXFP8 training." ] }, { @@ -77,10 +122,12 @@ "metadata": {}, "outputs": [], "source": [ - "from transformer_engine.common.recipe import Format, DelayedScaling\n", + "from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling\n", "\n", "fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass\n", - "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")" + "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", + "mxfp8_format = Format.E4M3 # E4M3 used everywhere\n", + "mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)" ] }, { @@ -341,7 +388,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/docs/examples/linear_mxfp8.png b/docs/examples/linear_mxfp8.png new file mode 100644 index 0000000000..3434732835 Binary files /dev/null and b/docs/examples/linear_mxfp8.png differ diff --git a/docs/examples/quickstart_utils.py b/docs/examples/quickstart_utils.py index 0582efd52e..f7a81d4d82 100644 --- a/docs/examples/quickstart_utils.py +++ b/docs/examples/quickstart_utils.py @@ -1,12 +1,11 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import math -from typing import Callable, Optional +from typing import Optional import torch import transformer_engine.pytorch as te -from transformer_engine.pytorch.fp8 import DelayedScaling, dist_group_type def speedometer( @@ -204,16 +203,13 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model): def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"): - import transformer_engine.pytorch.cpp_extensions as texcpp + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine_torch as tex - from transformer_engine.pytorch.constants import TE_DType fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2 - input_type = TE_DType[inp.dtype] - meta = tex.FP8TensorMeta() - meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale - meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale - meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda") - ret = texcpp.cast_to_fp8(inp, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type) - ret = texcpp.cast_from_fp8(ret, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type, input_type) + scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale + amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda") + quantizer = Float8Quantizer(scale=scale, amax=amax_history, fp8_dtype=fp8_type) + ret = quantizer(inp) + ret = ret.dequantize() return ret diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index 4413bdfd00..3ddf7f411a 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,7 +11,7 @@ from torch import nn import transformer_engine as te -from transformer_engine.pytorch.attention import RotaryPositionEmbedding +from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding from transformer_engine.pytorch.fp8 import fp8_model_init import transformers diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index 1aebe13afb..66f05701f5 100644 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/docs/faq.rst b/docs/faq.rst index 50b3a7481e..2f9cbd2720 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/index.rst b/docs/index.rst index 38e095c239..cd9ce41cf5 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. diff --git a/docs/installation.rst b/docs/installation.rst index 9ac0ddf841..10046d6306 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,5 +1,5 @@ .. - Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. See LICENSE for license information. @@ -12,10 +12,9 @@ Prerequisites .. _driver link: https://www.nvidia.com/drivers 1. Linux x86_64 -2. `CUDA 12.0 `__ -3. |driver link|_ supporting CUDA 12.0 or later. -4. `cuDNN 8.1 `__ or later. -5. For FP8/FP16/BF16 fused attention, `CUDA 12.1 `__ or later, |driver link|_ supporting CUDA 12.1 or later, and `cuDNN 8.9.1 `__ or later. +2. `CUDA 12.1+ (12.8+ for Blackwell support) `__ +3. |driver link|_ supporting CUDA 12.1 or later. +4. `cuDNN 9.3 `__ or later. If the CUDA Toolkit headers are not available at runtime in a standard installation path, e.g. within `CUDA_HOME`, set @@ -35,9 +34,9 @@ Transformer Engine can be directly installed from `our PyPI = 80 + + +@lru_cache +def is_fp8_supported(): + """Return if FP8 has hardware supported""" + gpu_arch = get_device_compute_capability(0) + return gpu_arch >= 90 diff --git a/examples/jax/encoder/conftest.py b/examples/jax/encoder/conftest.py new file mode 100644 index 0000000000..b1648892aa --- /dev/null +++ b/examples/jax/encoder/conftest.py @@ -0,0 +1,20 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""config for test_multiprocessing_encoder""" +import pytest + + +def pytest_addoption(parser): + """Pytest hook for test_multiprocessing_encoder""" + parser.addoption("--num-process", action="store", default=0) + parser.addoption("--process-id", action="store", default=0) + + +@pytest.fixture(autouse=True) +def multiprocessing_parses(request): + """Fixture for querying num-process and process-id""" + if request.cls: + request.cls.num_process = int(request.config.getoption("--num-process")) + request.cls.process_id = int(request.config.getoption("--process-id")) diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh new file mode 100644 index 0000000000..6a1dd96739 --- /dev/null +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -0,0 +1,17 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +for i in $(seq 0 $(($NUM_GPUS-1))) +do + pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i & +done +wait + +for i in $(seq 0 $(($NUM_GPUS-1))) +do + pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i & +done +wait diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index bafd9bd2fb..228105d553 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Encoder training on multi-GPU with tesnor parallelism""" @@ -56,7 +56,6 @@ def __call__(self, x, mask, disable_dropout=False): self_attn_mask_type="padding", enable_relative_embedding=False, enable_sequence_parallel=self.enable_seq_paral, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) @@ -72,17 +71,15 @@ def __call__(self, x, mask, disable_dropout=False): features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), bias_axes=(NAMED_TP_AXIS,), - dtype=jnp.bfloat16, )(x) x = te_flax.DenseGeneral( features=256, kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), bias_axes=(NAMED_BROADCAST_AXIS,), - dtype=jnp.bfloat16, )(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x @@ -91,7 +88,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -136,7 +133,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -239,7 +236,7 @@ def to_device_axis(logical_axis): ) params_axes_sharding = flax.core.unfreeze(params_axes_sharding) params_sharding = jax.tree_util.tree_map( - lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY] ) params_sharding = {**params_sharding, **params_axes_sharding} return params_sharding @@ -447,7 +444,7 @@ def test_te_fp8(self): """Test Transformer Engine with FP8""" self.args.use_fp8 = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_sp(self): @@ -462,7 +459,7 @@ def test_te_fp8_sp(self): self.args.enable_sp = True self.args.use_fp8 = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.455 and actual[1] > 0.785 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index a4a19b43c2..0dab636718 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Encoder training on multi-GPU with data parallelism""" @@ -51,17 +51,16 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x @@ -70,7 +69,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -115,7 +114,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -218,7 +217,7 @@ def to_device_axis(logical_axis): ) params_axes_sharding = flax.core.unfreeze(params_axes_sharding) params_sharding = jax.tree_util.tree_map( - lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY] ) params_sharding = {**params_sharding, **params_axes_sharding} return params_sharding diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index f54deff69c..6522ed896a 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -1,12 +1,12 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Encoder training with multi-GPU, multiprocessing, and tensor parallelism""" import argparse -import multiprocessing as mp import os import unittest from functools import partial +import pytest import flax import jax @@ -21,10 +21,10 @@ from jax.experimental import mesh_utils from jax.sharding import PartitionSpec, NamedSharding +from common import is_bf16_supported, is_fp8_supported import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -from common import is_bf16_supported os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" DEVICE_DP_AXIS = "data" @@ -57,7 +57,6 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) @@ -67,17 +66,15 @@ def __call__(self, x, mask, disable_dropout=False): features=256, kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS), bias_axes=(NAMED_TP_AXIS,), - dtype=jnp.bfloat16, )(x) x = te_flax.DenseGeneral( features=256, kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS), bias_axes=(NAMED_BROADCAST_AXIS,), - dtype=jnp.bfloat16, )(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x @@ -252,7 +249,6 @@ def eval_model( def data_preprocess(dataset, vocab, word_id, max_seq_len): """Convert tokens to numbers.""" - nltk.download("punkt_tab") dataset_size = len(dataset["sentence"]) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) @@ -321,7 +317,7 @@ def to_device_axis(logical_axis): ) params_axes_sharding = flax.core.unfreeze(params_axes_sharding) params_sharding = jax.tree_util.tree_map( - lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY] ) params_sharding = {**params_sharding, **params_axes_sharding} return params_sharding @@ -342,6 +338,9 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) + if args.process_id == 0: + nltk.download("punkt_tab") + train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) jax.distributed.initialize( @@ -551,69 +550,41 @@ def encoder_parser(args): return parser.parse_args(args) -def query_gpu(q): - """Query GPU info on the system""" - gpu_has_fp8, reason = te.fp8.is_fp8_available() - gpu_has_bf16 = is_bf16_supported() - num_gpu = len(jax.devices()) - q.put([num_gpu, gpu_has_fp8, gpu_has_bf16, reason]) - - -def unittest_query_gpu(): - r""" - It is only used by TestEncoder. - The `jax.distributed.initialize` must be called before any other JAX or Flax API, - otherwise `jax.local_devices` will be incorrect. - Thus, fork another process to query number of GPUs and FP8 capability. - """ - q = mp.Queue() - p = mp.Process(target=query_gpu, args=(q,)) - p.start() - num_gpu, gpu_has_fp8, gpu_has_bf16, reason = q.get() - p.join() - return num_gpu, gpu_has_fp8, gpu_has_bf16, reason - - +@pytest.mark.usefixtures("multiprocessing_parses") class TestEncoder(unittest.TestCase): """Encoder unittests""" - num_gpu, gpu_has_fp8, gpu_has_bf16, reason = unittest_query_gpu() + gpu_has_fp8 = is_fp8_supported() + gpu_has_bf16 = is_bf16_supported() def exec(self, use_fp8): """Run 3 epochs for testing""" - num_gpu = self.num_gpu + args = encoder_parser([]) + + num_gpu = self.num_process tp_size = 2 if num_gpu > 1 and num_gpu % 2 == 0 else 1 dp_size = num_gpu // tp_size batch_size = 64 // dp_size - arg_list = [] - for i in range(num_gpu): - args = encoder_parser([]) - args.num_process = num_gpu - args.use_fp8 = use_fp8 - args.batch_size = batch_size - args.test_batch_size = batch_size - args.process_id = i - arg_list.append(args) - - with mp.Pool(self.num_gpu) as p: - results = p.map(train_and_evaluate, arg_list) + args.use_fp8 = use_fp8 + args.batch_size = batch_size + args.test_batch_size = batch_size + args.num_process = num_gpu + args.process_id = self.process_id - return results + return train_and_evaluate(args) @unittest.skipIf(not gpu_has_bf16, "Device compute capability 8.0+ is required for BF16") def test_te_bf16(self): """Test Transformer Engine with BF16""" - results = self.exec(False) - actual = results[0] - assert actual[0] < 0.45 and actual[1] > 0.79 + result = self.exec(False) + assert result[0] < 0.45 and result[1] > 0.79 - @unittest.skipIf(not gpu_has_fp8, reason) + @unittest.skipIf(not gpu_has_fp8, "Device compute capability 9.0+ is required for FP8") def test_te_fp8(self): """Test Transformer Engine with FP8""" - results = self.exec(True) - actual = results[0] - assert actual[0] < 0.45 and actual[1] > 0.79 + result = self.exec(True) + assert result[0] < 0.455 and result[1] > 0.79 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index ac71fe4c0e..cfbd30b767 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Encoder training on single GPU""" @@ -46,17 +46,16 @@ def __call__(self, x, mask, disable_dropout=False): layer_type=te_flax.TransformerLayerType.ENCODER, self_attn_mask_type="padding", enable_relative_embedding=False, - dtype=jnp.bfloat16, ) x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = te_flax.DenseGeneral(features=256, dtype=jnp.bfloat16)(x) + x = te_flax.DenseGeneral(features=256)(x) - x = nn.Dense(features=2, dtype=jnp.bfloat16)(x) + x = nn.Dense(features=2)(x) return x @@ -66,7 +65,7 @@ def train_step(state, inputs, masks, labels, var_collect, rngs): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -112,7 +111,7 @@ def eval_step(state, inputs, masks, labels, var_collect): def loss_fn(var_collect, disable_dropout=False): logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) - one_hot = jax.nn.one_hot(labels, 2) + one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -217,6 +216,7 @@ def train_and_evaluate(args): with te.fp8_autocast(enabled=args.use_fp8): encoder = Net(num_embed) + # We use nn.Embed, thus inputs need to be in int inputs = jnp.zeros(input_shape, dtype=jnp.int32) masks = jnp.zeros(mask_shape, dtype=jnp.uint8) var_collect = encoder.init(init_rngs, inputs, masks) @@ -334,7 +334,7 @@ def test_te_fp8(self): """Test Transformer Engine with FP8""" self.args.use_fp8 = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.455 and actual[1] > 0.79 if __name__ == "__main__": diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index b251bb72ca..9d8f51cc16 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """MNIST training on single GPU""" @@ -36,6 +36,8 @@ def __call__(self, x, disable_dropout=False): nn_Dense = te_flax.DenseGeneral else: nn_Dense = nn.Dense + # dtype is used for param init in TE but computation in Linen.nn + dtype = jnp.float32 if self.use_te else jnp.bfloat16 x = nn.Conv(features=32, kernel_size=(3, 3), strides=1, dtype=jnp.bfloat16)(x) x = nn.relu(x) @@ -44,11 +46,13 @@ def __call__(self, x, disable_dropout=False): x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)) x = nn.Dropout(rate=0.25)(x, deterministic=disable_dropout) x = x.reshape(x.shape[0], -1) - x = nn_Dense(features=128, dtype=jnp.bfloat16)(x) + assert x.dtype == jnp.bfloat16 + x = nn_Dense(features=128, dtype=dtype)(x) x = nn.relu(x) x = nn.Dropout(rate=0.5)(x, deterministic=disable_dropout) - x = nn_Dense(features=16, dtype=jnp.bfloat16)(x) - x = nn.Dense(features=10, dtype=jnp.bfloat16)(x) + x = nn_Dense(features=16, dtype=dtype)(x) + x = nn_Dense(features=10, dtype=dtype)(x) + assert x.dtype == jnp.bfloat16 return x diff --git a/examples/paddle/mnist/README.md b/examples/paddle/mnist/README.md deleted file mode 100644 index adb0144779..0000000000 --- a/examples/paddle/mnist/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Basic MNIST Example - -```bash -python test_single_gpu_mnist.py -python test_single_gpu_mnist.py --use-te # Linear layers from TransformerEngine -python test_single_gpu_mnist.py --use-te --use-fp8 # FP8 + TransformerEngine for Linear layers -``` diff --git a/examples/paddle/mnist/test_single_gpu_mnist.py b/examples/paddle/mnist/test_single_gpu_mnist.py deleted file mode 100644 index de5c9e9b6c..0000000000 --- a/examples/paddle/mnist/test_single_gpu_mnist.py +++ /dev/null @@ -1,291 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""MNIST example of Transformer Engine Paddle""" - -import argparse -import os -import unittest - -import paddle -from paddle import nn -import paddle.nn.functional as F - -from paddle.vision.transforms import Normalize -from paddle.io import DataLoader -from paddle.vision.datasets import MNIST -from paddle.metric import Accuracy - -import transformer_engine.paddle as te -from transformer_engine.paddle.fp8 import is_fp8_available - - -class Net(nn.Layer): - """Simple network used to train on MNIST""" - - def __init__(self, use_te=False): - super().__init__() - self.conv1 = nn.Conv2D(1, 32, 3, 1) - self.conv2 = nn.Conv2D(32, 64, 3, 1) - self.dropout1 = nn.Dropout(0.25) - self.dropout2 = nn.Dropout(0.5) - if use_te: - self.fc1 = te.Linear(9216, 128) - self.fc2 = te.Linear(128, 16) - else: - self.fc1 = nn.Linear(9216, 128) - self.fc2 = nn.Linear(128, 16) - self.fc3 = nn.Linear(16, 10) - - def forward(self, x): - """FWD""" - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.relu(x) - x = F.max_pool2d(x, 2) - x = self.dropout1(x) - x = paddle.flatten(x, 1) - x = self.fc1(x) - x = F.relu(x) - x = self.dropout2(x) - x = self.fc2(x) - x = self.fc3(x) - return x - - -def train(args, model, train_loader, optimizer, epoch, use_fp8): - """Training function.""" - model.train() - losses = [] - for batch_id, (data, labels) in enumerate(train_loader): - with paddle.amp.auto_cast( - dtype="bfloat16", level="O2" - ): # pylint: disable=not-context-manager - with te.fp8_autocast(enabled=use_fp8): - outputs = model(data) - loss = F.cross_entropy(outputs, labels) - losses.append(loss.item()) - - loss.backward() - optimizer.step() - optimizer.clear_gradients() - - if batch_id % args.log_interval == 0: - print( - f"Train Epoch: {epoch} " - f"[{batch_id * len(data)}/{len(train_loader.dataset)} " - f"({100. * batch_id / len(train_loader):.0f}%)]\t" - f"Loss: {loss.item():.6f}" - ) - if args.dry_run: - return loss.item() - avg_loss = sum(losses) / len(losses) - print(f"Train Epoch: {epoch}, Average Loss: {avg_loss}") - return avg_loss - - -def evaluate(model, test_loader, epoch, use_fp8): - """Testing function.""" - model.eval() - metric = Accuracy() - metric.reset() - - with paddle.no_grad(): - for data, labels in test_loader: - with paddle.amp.auto_cast( - dtype="bfloat16", level="O2" - ): # pylint: disable=not-context-manager - with te.fp8_autocast(enabled=use_fp8): - outputs = model(data) - acc = metric.compute(outputs, labels) - metric.update(acc) - print(f"Epoch[{epoch}] - accuracy: {metric.accumulate():.6f}") - return metric.accumulate() - - -def calibrate(model, test_loader): - """Calibration function.""" - model.eval() - - with paddle.no_grad(): - for data, _ in test_loader: - with paddle.amp.auto_cast( - dtype="bfloat16", level="O2" - ): # pylint: disable=not-context-manager - with te.fp8_autocast(enabled=False, calibrating=True): - _ = model(data) - - -def mnist_parser(args): - """Parse training settings""" - parser = argparse.ArgumentParser(description="Paddle MNIST Example") - parser.add_argument( - "--batch-size", - type=int, - default=64, - metavar="N", - help="input batch size for training (default: 64)", - ) - parser.add_argument( - "--test-batch-size", - type=int, - default=1000, - metavar="N", - help="input batch size for testing (default: 1000)", - ) - parser.add_argument( - "--epochs", - type=int, - default=14, - metavar="N", - help="number of epochs to train (default: 14)", - ) - parser.add_argument( - "--lr", - type=float, - default=0.001, - metavar="LR", - help="learning rate (default: 0.001)", - ) - parser.add_argument( - "--dry-run", - action="store_true", - default=False, - help="quickly check a single pass", - ) - parser.add_argument( - "--save-model", - action="store_true", - default=False, - help="For Saving the current Model", - ) - parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") - parser.add_argument( - "--log-interval", - type=int, - default=10, - metavar="N", - help="how many batches to wait before logging training status", - ) - parser.add_argument( - "--use-fp8", - action="store_true", - default=False, - help=( - "Use FP8 for inference and training without recalibration. " - "It also enables Transformer Engine implicitly." - ), - ) - parser.add_argument( - "--use-fp8-infer", - action="store_true", - default=False, - help=( - "Use FP8 for inference only. If not using FP8 for training, " - "calibration is performed for FP8 infernece." - ), - ) - parser.add_argument( - "--use-te", action="store_true", default=False, help="Use Transformer Engine" - ) - args = parser.parse_args(args) - return args - - -def train_and_evaluate(args): - """Execute model training and evaluation loop.""" - print(args) - - paddle.seed(args.seed) - - # Load MNIST dataset - transform = Normalize(mean=[127.5], std=[127.5], data_format="CHW") - train_dataset = MNIST(mode="train", transform=transform) - val_dataset = MNIST(mode="test", transform=transform) - - # Define data loaders - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) - val_loader = DataLoader(val_dataset, batch_size=args.test_batch_size) - - # Define model and optimizer - model = Net(use_te=args.use_te) - optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters()) - - # Cast model to BF16 - model = paddle.amp.decorate(models=model, level="O2", dtype="bfloat16") - - for epoch in range(1, args.epochs + 1): - loss = train(args, model, train_loader, optimizer, epoch, args.use_fp8) - acc = evaluate(model, val_loader, epoch, args.use_fp8) - - if args.use_fp8_infer and not args.use_fp8: - calibrate(model, val_loader) - - if args.save_model or args.use_fp8_infer: - paddle.save(model.state_dict(), "mnist_cnn.pdparams") - print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8)) - weights = paddle.load("mnist_cnn.pdparams") - model.set_state_dict(weights) - acc = evaluate(model, val_loader, 0, args.use_fp8) - - return loss, acc - - -class TestMNIST(unittest.TestCase): - """MNIST unittests""" - - gpu_has_fp8, reason = is_fp8_available() - - @classmethod - def setUpClass(cls): - """Run MNIST without Transformer Engine""" - cls.args = mnist_parser(["--epochs", "5"]) - - @staticmethod - def verify(actual): - """Check If loss and accuracy match target""" - desired_traing_loss = 0.1 - desired_test_accuracy = 0.98 - assert actual[0] < desired_traing_loss - assert actual[1] > desired_test_accuracy - - @unittest.skipIf( - paddle.device.cuda.get_device_capability() < (8, 0), - "BF16 MNIST example requires Ampere+ GPU", - ) - def test_te_bf16(self): - """Test Transformer Engine with BF16""" - self.args.use_te = True - self.args.use_fp8 = False - self.args.save_model = True - actual = train_and_evaluate(self.args) - if os.path.exists("mnist_cnn.pdparams"): - os.remove("mnist_cnn.pdparams") - self.verify(actual) - - @unittest.skipIf(not gpu_has_fp8, reason) - def test_te_fp8(self): - """Test Transformer Engine with FP8""" - self.args.use_te = True - self.args.use_fp8 = True - self.args.save_model = True - actual = train_and_evaluate(self.args) - if os.path.exists("mnist_cnn.pdparams"): - os.remove("mnist_cnn.pdparams") - self.verify(actual) - - @unittest.skipIf(not gpu_has_fp8, reason) - def test_te_fp8_calibration(self): - """Test Transformer Engine with FP8 calibration""" - self.args.use_te = True - self.args.use_fp8 = False - self.args.use_fp8_infer = True - actual = train_and_evaluate(self.args) - if os.path.exists("mnist_cnn.pdparams"): - os.remove("mnist_cnn.pdparams") - self.verify(actual) - - -if __name__ == "__main__": - train_and_evaluate(mnist_parser(None)) diff --git a/examples/pytorch/comm_gemm_overlap/README.md b/examples/pytorch/comm_gemm_overlap/README.md index bb3ba209ed..fc8458844b 100644 --- a/examples/pytorch/comm_gemm_overlap/README.md +++ b/examples/pytorch/comm_gemm_overlap/README.md @@ -16,7 +16,7 @@ Forward and backward passes with layer weights distributed over all GPUs in a single node. ```bash -$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_with_overlap.py +$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py # Sample output on 8x H100s: # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3, 4, 5, 6, 7] @@ -70,7 +70,7 @@ Uses `torch.nn.parallel.DistributedDataParallel` for replicatin the model across groups in a single node. ```bash -$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) ln_mlp_overlap.py --num-replicas 2 +$ torchrun --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) te_layer_with_overlap.py --num-replicas 2 # Sample output on 8x H100s: # [rank0:node0] |-- Created tensor-parallel group: [0, 1, 2, 3] diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index ab6b656be9..e510df1761 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -77,7 +77,16 @@ def _parse_args(argv=None, namespace=None): help="Disable the comm+GEMM overlap.", ) parser.add_argument( - "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." + "--num-replicas", + type=int, + default=1, + help="Number of data-parallel model replicas per node.", + ) + parser.add_argument( + "--use-global-replica-count", + action="store_true", + default=False, + help="Treat '--num-replicas' as the total number of replicas.", ) parser.add_argument( "--tcp-init", @@ -173,13 +182,12 @@ def _train(opts): opts.tcp_init = True opts.bind_to_device = True opts.bootstrap_backend = "mpi" - elif "TORCHELASTIC_RUN_ID" in os.environ: + else: # TORCHELASTIC, SLURM, etc... WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - else: - raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + NUM_NODES = WORLD_SIZE // LOCAL_SIZE # Initialize torch.distributed global process group and get DP/TP groups @@ -214,90 +222,24 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") - # Figure out process groups for tensor- and data-parallelism (if any) - if NUM_NODES > 1: - # Create a list of world ranks on this node - hostname = socket.gethostname() - ifname = os.getenv( - "NVTE_UB_SOCKET_IFNAME", - os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")), - ) - - if ifname is not None: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - hostname = socket.inet_ntoa( - fcntl.ioctl( - s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) - )[20:24] - ) - except OSError as err: - raise OSError(f"Invalid network interface: {ifname}") from err - - hostnames = [None for _ in range(WORLD_SIZE)] - dist.all_gather_object(hostnames, hostname) - unique_hosts = [] - for host in hostnames: - if host not in unique_hosts: - unique_hosts.append(host) - assert len(unique_hosts) == NUM_NODES - - ranks_per_node_list = [[] for _ in range(NUM_NODES)] - self_node_idx = -1 - for i, host in enumerate(hostnames): - node_idx = unique_hosts.index(host) - ranks_per_node_list[node_idx].append(i) - if host == hostname: - self_node_idx = node_idx - assert self_node_idx >= 0 - self_node_ranks = ranks_per_node_list[self_node_idx] - - if opts.num_replicas > 1: - # Split node ranks into multiple replicas - assert len(self_node_ranks) % opts.num_replicas == 0 - tp_size = len(self_node_ranks) // opts.num_replicas - ranks_per_replica_list = [] - for node_ranks in ranks_per_node_list: - for i in range(opts.num_replicas): - start = i * tp_size - end = start + tp_size - ranks_per_replica_list.append(node_ranks[start:end]) - - self_replica_idx = -1 - for i, replica_ranks in enumerate(ranks_per_replica_list): - if WORLD_RANK in replica_ranks: - self_replica_idx = i - break - assert self_replica_idx >= 0 + total_replicas = ( + opts.num_replicas if opts.use_global_replica_count else opts.num_replicas * NUM_NODES + ) + tp_size = WORLD_SIZE // total_replicas - else: - # The entire node is the tensor-parallel group - ranks_per_replica_list = ranks_per_node_list - self_replica_idx = self_node_idx + if total_replicas > 1: + ranks_per_replica_list = [ + [i * tp_size + t for t in range(tp_size)] for i in range(total_replicas) + ] tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) dp_group, _ = dist.new_subgroups_by_enumeration( ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" ) - else: - if opts.num_replicas > 1: - # Mixed data- and tensor-parallelism on a single node - # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions - all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu") - ranks_per_replica_tensor = all_ranks.reshape( - (opts.num_replicas, LOCAL_SIZE // opts.num_replicas) - ) - tp_group, _ = dist.new_subgroups_by_enumeration( - ranks_per_replica_tensor.tolist(), backend="nccl" - ) - dp_group, _ = dist.new_subgroups_by_enumeration( - ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" - ) - else: - dp_group = None - tp_group = nccl_world + dp_group = None + tp_group = nccl_world tp_rank = dist.get_rank(tp_group) tp_size = dist.get_world_size(tp_group) diff --git a/examples/pytorch/fsdp/README.md b/examples/pytorch/fsdp/README.md index 5ea1225fa1..d62f68bbda 100644 --- a/examples/pytorch/fsdp/README.md +++ b/examples/pytorch/fsdp/README.md @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index cf0a75c336..622228536c 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/examples/pytorch/mnist/main.py b/examples/pytorch/mnist/main.py index 2a003f0a0d..ff9e2f0785 100644 --- a/examples/pytorch/mnist/main.py +++ b/examples/pytorch/mnist/main.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/pylintrc b/pylintrc index b80679d72c..50f85fad9d 100644 --- a/pylintrc +++ b/pylintrc @@ -2,11 +2,8 @@ extension-pkg-whitelist=flash_attn_2_cuda, torch, transformer_engine_torch, - transformer_engine_paddle, transformer_engine_jax -extension-pkg-allow-list=transformer_engine.transformer_engine_jax - disable=too-many-locals, too-few-public-methods, too-many-public-methods, diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh old mode 100644 new mode 100755 index d68a3a0f41..57034900c6 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -6,7 +6,7 @@ set -e # Find TE : ${TE_PATH:=/opt/transformerengine} -TE_LIB_PATH=`pip show transformer-engine | grep Location | cut -d ' ' -f 2` +TE_LIB_PATH=`pip3 show transformer-engine | grep Location | cut -d ' ' -f 2` export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH cd $TE_PATH/tests/cpp diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh new file mode 100644 index 0000000000..3253861484 --- /dev/null +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -0,0 +1,34 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" + +: ${TE_PATH:=/opt/transformerengine} + +pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install requirements" + +# Make encoder tests to have run-to-run deterministic to have the stable CI results +export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" +. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" + +if [ $RET -ne 0 ]; then + echo "Error: some sub-tests failed: $FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L0_jax_lint/test.sh b/qa/L0_jax_lint/test.sh index 7bc84eef51..dbc1ed0a1d 100644 --- a/qa/L0_jax_lint/test.sh +++ b/qa/L0_jax_lint/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -6,19 +6,19 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -pip install cpplint==1.6.0 pylint==3.3.1 +pip3 install cpplint==1.6.0 pylint==3.3.1 if [ -z "${PYTHON_ONLY}" ] then cd $TE_PATH echo "Checking common API headers" - cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include + python3 -m cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include echo "Checking C++ files" - cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common - cpplint --recursive transformer_engine/jax + python3 -m cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common + python3 -m cpplint --recursive transformer_engine/jax fi if [ -z "${CPP_ONLY}" ] then cd $TE_PATH echo "Checking Python files" - pylint --recursive=y transformer_engine/common transformer_engine/jax + python3 -m pylint --recursive=y transformer_engine/common transformer_engine/jax fi diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index db3aa31951..1f7bb0ebc4 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -1,24 +1,42 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -set -xe +function error_exit() { + echo "Error: $1" + exit 1 +} -pip install "nltk>=3.8.2" -pip install pytest==8.2.1 +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" + +pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk" +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${TE_PATH:=/opt/transformerengine} -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py || test_fail "test_praxis_layers.py" # Test without custom calls -NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py +NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py" -pip install -r $TE_PATH/examples/jax/mnist/requirements.txt -pip install -r $TE_PATH/examples/jax/encoder/requirements.txt +pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" +pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "test_mnist.py" # Make encoder tests to have run-to-run deterministic to have the stable CI results export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" + +if [ $RET -ne 0 ]; then + echo "Error: some sub-tests failed: $FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh index 2c3b832933..e1400b10bd 100644 --- a/qa/L0_jax_wheel/test.sh +++ b/qa/L0_jax_wheel/test.sh @@ -1,35 +1,54 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -set -e +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : "${TE_PATH:=/opt/transformerengine}" -pip install wheel +pip3 install wheel || error_exit "Failed to install wheel" cd $TE_PATH -pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax +pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax || error_exit "Failed to uninstall transformer-engine transformer-engine-cu12 transformer-engine-jax" VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" # Core wheel. -NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel -wheel unpack dist/* +NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" +wheel unpack dist/* || error_exit "Failed to unpack dist/*" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" -mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" -wheel pack ${WHL_BASE} -rm dist/*.whl -mv *.whl dist/ -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" +rm dist/*.whl || error_exit "Failed to remove dist/*.whl" +mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" cd transformer_engine/jax -NVTE_RELEASE_BUILD=1 python setup.py sdist +NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" -pip install dist/* +pip3 install dist/* || error_exit "Failed to install dist/*" cd $TE_PATH -pip install dist/*.whl --no-deps +pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" + +python3 $TE_PATH/tests/jax/test_sanity_import.py || test_fail "test_sanity_import.py" -python $TE_PATH/tests/jax/test_sanity_import.py +if [ $RET -ne 0 ]; then + echo "Error: some sub-tests failed: $FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L0_license/copyright_checker.py b/qa/L0_license/copyright_checker.py index 46a3a6d4fe..a0e137d1ef 100644 --- a/qa/L0_license/copyright_checker.py +++ b/qa/L0_license/copyright_checker.py @@ -1,7 +1,7 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 # coding: utf-8 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -12,7 +12,7 @@ import datetime if len(sys.argv) < 2: - print("Usage: python copyright_checker.py ") + print("Usage: python3 copyright_checker.py ") path = sys.argv[1] diff --git a/qa/L0_license/test.sh b/qa/L0_license/test.sh index 8b9c86b39b..44b9469e55 100644 --- a/qa/L0_license/test.sh +++ b/qa/L0_license/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -6,4 +6,4 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -python $TE_PATH/qa/L0_license/copyright_checker.py $TE_PATH +python3 $TE_PATH/qa/L0_license/copyright_checker.py $TE_PATH diff --git a/qa/L0_paddle_lint/test.sh b/qa/L0_paddle_lint/test.sh deleted file mode 100644 index 5c5379554f..0000000000 --- a/qa/L0_paddle_lint/test.sh +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: "${TE_PATH:=/opt/transformerengine}" - -pip install cpplint==1.6.0 pylint==3.3.1 -if [ -z "${PYTHON_ONLY}" ] -then - cd $TE_PATH - echo "Checking common API headers" - cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include - echo "Checking C++ files" - cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common - cpplint --recursive transformer_engine/paddle -fi -if [ -z "${CPP_ONLY}" ] -then - cd $TE_PATH - echo "Checking Python files" - pylint --recursive=y transformer_engine/common transformer_engine/paddle -fi diff --git a/qa/L0_paddle_unittest/test.sh b/qa/L0_paddle_unittest/test.sh deleted file mode 100644 index 1038923b5a..0000000000 --- a/qa/L0_paddle_unittest/test.sh +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -xe - -pip install pytest==8.2.1 -: ${TE_PATH:=/opt/transformerengine} -pytest -Wignore -v $TE_PATH/tests/paddle -pytest -Wignore -v $TE_PATH/examples/paddle/mnist diff --git a/qa/L0_paddle_wheel/test.sh b/qa/L0_paddle_wheel/test.sh deleted file mode 100644 index 00653877b8..0000000000 --- a/qa/L0_paddle_wheel/test.sh +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: "${TE_PATH:=/opt/transformerengine}" - -# Install dependencies -# Note: Need to install wheel locally since PaddlePaddle container -# already contains APT install. -pip install pydantic -pip install --user wheel==0.44.0 - -cd $TE_PATH -pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle - -VERSION=`cat $TE_PATH/build_tools/VERSION.txt` -WHL_BASE="transformer_engine-${VERSION}" - -# Core wheel. -NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel -python -m wheel unpack dist/* -sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" -sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" -mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" -python -m wheel pack ${WHL_BASE} -rm dist/*.whl -mv *.whl dist/ -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel -pip install dist/*.whl --no-deps - -cd transformer_engine/paddle -NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel -pip install dist/* - -python $TE_PATH/tests/paddle/test_sanity_import.py diff --git a/qa/L0_pytorch_lint/test.sh b/qa/L0_pytorch_lint/test.sh index ac517976c7..81d7822d7f 100644 --- a/qa/L0_pytorch_lint/test.sh +++ b/qa/L0_pytorch_lint/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -6,19 +6,19 @@ set -e : "${TE_PATH:=/opt/transformerengine}" -pip install cpplint==1.6.0 pylint==3.3.1 +pip3 install cpplint==1.6.0 pylint==3.3.1 if [ -z "${PYTHON_ONLY}" ] then cd $TE_PATH echo "Checking common API headers" - cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include + python3 -m cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include echo "Checking C++ files" - cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common - cpplint --recursive transformer_engine/pytorch + python3 -m cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine/common + python3 -m cpplint --recursive transformer_engine/pytorch fi if [ -z "${CPP_ONLY}" ] then cd $TE_PATH echo "Checking Python files" - pylint --recursive=y transformer_engine/common transformer_engine/pytorch + python3 -m pylint --recursive=y transformer_engine/common transformer_engine/pytorch fi diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 17307574a9..732f0a16d1 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -1,24 +1,47 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -set -e +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" + +set -x : ${TE_PATH:=/opt/transformerengine} -pip install pytest==8.2.1 -pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py -pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py -pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py -pytest -v -s $TE_PATH/tests/pytorch/test_jit.py -NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py -pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py -pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py -pytest -v -s $TE_PATH/tests/pytorch/test_torch_save_load.py -pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py -pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py -pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py -pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py -pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" + +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +NVTE_CUDNN_MXFP8_NORM=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" +NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py" +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index fd8457c44b..ffd5ca2909 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -1,35 +1,54 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -set -e +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : "${TE_PATH:=/opt/transformerengine}" -pip install wheel +pip3 install wheel || error_exit "Failed to install wheel" cd $TE_PATH -pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch +pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch || error_exit "Failed to uninstall transformer-engine transformer-engine-cu12 transformer-engine-torch" VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" # Core wheel. -NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel -wheel unpack dist/* +NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" +wheel unpack dist/* || error_exit "Failed to unpack dist/*" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" -mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" -wheel pack ${WHL_BASE} -rm dist/*.whl -mv *.whl dist/ -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel +mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" +wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" +rm dist/*.whl || error_exit "Failed to remove dist/*.whl" +mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" cd transformer_engine/pytorch -NVTE_RELEASE_BUILD=1 python setup.py sdist +NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" -pip install dist/* +pip3 install dist/* || error_exit "Failed to install dist/*" cd $TE_PATH -pip install dist/*.whl --no-deps +pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" + +python3 $TE_PATH/tests/pytorch/test_sanity_import.py || test_fail "test_sanity_import.py" -python $TE_PATH/tests/pytorch/test_sanity_import.py +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index eb09df1a84..96c5949a99 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -6,10 +6,4 @@ set -xe : ${TE_PATH:=/opt/transformerengine} -# Skip ring attention tests since they need fixed environment vars -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn' - -# Test ring attention with and without scan loop -NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn -NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \ - pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 9a11ccc008..5776734c3b 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -1,14 +1,35 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -set -e +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} -pip install pytest==8.2.1 -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py -pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" + +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" +python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" +# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential +python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" + +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L1_pytorch_mcore_integration/.gitignore b/qa/L1_pytorch_mcore_integration/.gitignore new file mode 100644 index 0000000000..46426003ca --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/.gitignore @@ -0,0 +1,2 @@ +Megatron-LM +vocab.json \ No newline at end of file diff --git a/qa/L1_pytorch_mcore_integration/merges.txt b/qa/L1_pytorch_mcore_integration/merges.txt new file mode 100644 index 0000000000..5e7f1fd949 --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/merges.txt @@ -0,0 +1 @@ +#version: 0.2 diff --git a/qa/L1_pytorch_mcore_integration/test.sh b/qa/L1_pytorch_mcore_integration/test.sh new file mode 100644 index 0000000000..2200d11455 --- /dev/null +++ b/qa/L1_pytorch_mcore_integration/test.sh @@ -0,0 +1,72 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +# Paths +: ${TE_PATH:=/opt/transformerengine} +: ${MCORE_PATH:=${TE_PATH}/qa/L1_pytorch_mcore_integration/Megatron-LM} + +# Check whether FP8 is supported +DEVICE_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | sed 's/[^0-9]//g') +if [[ ${DEVICE_ARCH} -ge 89 ]]; then + WITH_FP8=1 +fi + +# Download Megatron-LM if needed +if [ ! -d "${MCORE_PATH}" ]; then + pushd $(dirname ${MCORE_PATH}) + git clone -b core_r0.9.0 https://github.com/NVIDIA/Megatron-LM.git Megatron-LM + popd +fi + +# Create mock vocab +VOCAB_FILE=${TE_PATH}/qa/L1_pytorch_mcore_integration/vocab.json +printf "" > ${VOCAB_FILE} +printf "{" >> ${VOCAB_FILE} +printf "\"<|endoftext|>\": 0" >> ${VOCAB_FILE} +seq 1 4095 | awk '{ printf(", \"%d\": %d", $1, $1) }' >> ${VOCAB_FILE} +printf "}" >> ${VOCAB_FILE} + +# Megatron-LM invocation +COMMAND=" +NVTE_TORCH_COMPILE=0 +NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 +NVTE_FLASH_ATTN=1 +NVTE_FWD_LAYERNORM_SM_MARGIN=0 +NVTE_BWD_LAYERNORM_SM_MARGIN=0 +CUDA_DEVICE_MAX_CONNECTIONS=1 +NVTE_BIAS_GELU_NVFUSION=0 +NVTE_BIAS_DROPOUT_FUSION=0 + +python3 +-m torch.distributed.launch +--use_env +--nnodes=1 +--nproc_per_node=1 + +${MCORE_PATH}/pretrain_gpt.py +--tensor-model-parallel-size 1 +--pipeline-model-parallel-size 1 +--use-cpu-initialization +--num-layers 2 +--hidden-size 128 +--num-attention-heads 8 +--seq-length 128 +--max-position-embeddings 128 +--micro-batch-size 1 +--global-batch-size 8 +--train-iters 10 +--eval-iters 10 +--lr 1e-4 +--mock-data +--vocab-file ${VOCAB_FILE} +--merge-file ${TE_PATH}/qa/L1_pytorch_mcore_integration/merges.txt +--transformer-impl transformer_engine +${WITH_FP8:+--fp8-format hybrid} +" +COMMAND=$(echo "${COMMAND}" | tr '\n' ' ') + +# Launch Megatron-LM +bash -c "${COMMAND}" diff --git a/qa/L1_pytorch_onnx_test/test.sh b/qa/L1_pytorch_onnx_test/test.sh deleted file mode 100644 index 5a01468064..0000000000 --- a/qa/L1_pytorch_onnx_test/test.sh +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: ${TE_PATH:=/opt/transformerengine} - -pip install pytest==8.2.1 onnxruntime==1.19.2 - -# Build custom ONNX Runtime operators -export CUSTOM_ORT_OPS_PATH=$TE_PATH/tests/pytorch/custom_ort_ops -bash $CUSTOM_ORT_OPS_PATH/build.sh - -# Run tests -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/qa/L1_pytorch_thunder_integration/test.sh b/qa/L1_pytorch_thunder_integration/test.sh new file mode 100644 index 0000000000..1737ca9ba1 --- /dev/null +++ b/qa/L1_pytorch_thunder_integration/test.sh @@ -0,0 +1,19 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -x + +: ${THUNDER_PATH:=/opt/pytorch/lightning-thunder} + +pip3 install pytest==8.1.1 pytest-benchmark==5.1.0 +python3 -m pytest -v -s ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py + +# Check return code +# Note: Return code 5 is fine. Lightning tests are skipped on systems +# without FP8 support and Pytest returns 5 if no tests are run. +RC=$? +if [ ${RC} -eq 5 ]; then + RC=0 +fi +exit ${RC} diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 6c23e39a48..3e83ef7f52 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -6,28 +6,37 @@ set -e : ${TE_PATH:=/opt/transformerengine} -pip install pytest==8.2.1 +pip3 install pytest==8.2.1 # Limit parallel build jobs to avoid overwhelming system resources export MAX_JOBS=4 # Iterate over Flash Attention versions -FA_versions=(2.1.1 2.3.0 2.4.1 2.5.7 2.6.3 3.0.0b1) +sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"` +if [ $sm_arch -gt 90 ] +then + FA_versions=(2.7.3) +else + FA_versions=(2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1) +fi + for fa_version in "${FA_versions[@]}" do # Build Flash Attention if [ "${fa_version}" \< "3.0.0" ] then - pip install flash-attn==${fa_version} + pip3 install flash-attn==${fa_version} else - pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" + git clone https://github.com/Dao-AILab/flash-attention.git + cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install python_path=`python -c "import site; print(site.getsitepackages()[0])"` - mkdir -p $python_path/flashattn_hopper - wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py + mkdir -p $python_path/flash_attn_3 + wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py + cd ../../ fi # Run tests - NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py done diff --git a/qa/L3_pytorch_convergence_test/test.sh b/qa/L3_pytorch_convergence_test/test.sh deleted file mode 100644 index fca621f279..0000000000 --- a/qa/L3_pytorch_convergence_test/test.sh +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -e - -: ${TE_PATH:=/opt/transformerengine} - -pip install prettytable -git clone https://github.com/NVIDIA/Megatron-LM.git -cd Megatron-LM -git checkout b3375a0e38c10e2300ef4be031f7dcabab52b448 -pytest -v -s $TE_PATH/tests/pytorch/distributed/test_convergence.py -python $TE_PATH/tests/pytorch/distributed/print_logs.py diff --git a/qa/format.sh b/qa/format.sh index d38b832263..86fd8f1981 100644 --- a/qa/format.sh +++ b/qa/format.sh @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,5 +11,5 @@ set -e cd $TE_PATH -pip install pre-commit -pre-commit run --all-files +pip3 install pre-commit +python3 -m pre_commit run --all-files diff --git a/setup.py b/setup.py index 3bb2fe6b95..13e8b6ee83 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,11 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Installation script.""" import os +import sys import time from pathlib import Path from typing import List, Tuple @@ -35,14 +36,13 @@ if "pytorch" in frameworks: from torch.utils.cpp_extension import BuildExtension -elif "paddle" in frameworks: - from paddle.utils.cpp_extension import BuildExtension elif "jax" in frameworks: install_and_import("pybind11[global]") from pybind11.setup_helpers import build_ext as BuildExtension CMakeBuildExtension = get_build_ext(BuildExtension) +archs = cuda_archs() class TimedBdist(bdist_wheel): @@ -57,13 +57,16 @@ def run(self): def setup_common_extension() -> CMakeExtension: """Setup CMake extension for common library""" - cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())] + cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)] if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))): assert ( os.getenv("MPI_HOME") is not None ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1" cmake_flags.append("-DNVTE_UB_WITH_MPI=ON") + if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))): + cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON") + # Project directory root root_path = Path(__file__).resolve().parent @@ -100,14 +103,14 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: # Framework-specific requirements if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: - install_reqs.extend(["torch"]) - test_reqs.extend(["numpy", "onnxruntime", "torchvision", "prettytable"]) + install_reqs.extend(["torch>=2.1"]) + # Blackwell is not supported as of Triton 3.2.0, need custom internal build + # install_reqs.append("triton") + test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"]) if "jax" in frameworks: install_reqs.extend(["jax", "flax>=0.7.1"]) - test_reqs.extend(["numpy", "praxis"]) - if "paddle" in frameworks: - install_reqs.append("paddlepaddle-gpu") - test_reqs.append("numpy") + # test_reqs.extend(["numpy", "praxis"]) + test_reqs.extend(["numpy"]) return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]] @@ -132,7 +135,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: extras_require = { "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], - "paddle": [f"transformer_engine_paddle=={__version__}"], } else: setup_requires, install_requires, test_requires = setup_requirements() @@ -166,16 +168,6 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: current_file_path / "transformer_engine", ) ) - if "paddle" in frameworks: - from build_tools.paddle import setup_paddle_extension - - ext_modules.append( - setup_paddle_extension( - "transformer_engine/paddle/csrc", - current_file_path / "transformer_engine" / "paddle" / "csrc", - current_file_path / "transformer_engine", - ) - ) # Configure package setuptools.setup( diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 3bef457c43..afc80cba43 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -1,11 +1,15 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. cmake_minimum_required(VERSION 3.18) if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - set(CMAKE_CUDA_ARCHITECTURES 70 80 90) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + endif() endif() @@ -22,7 +26,7 @@ enable_testing() include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) if(NOT DEFINED TE_LIB_PATH) - execute_process(COMMAND bash -c "pip show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'" + execute_process(COMMAND bash -c "pip3 show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'" OUTPUT_VARIABLE TE_LIB_PATH) endif() diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 45806e7022..6785dbf6f4 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -1,26 +1,37 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. add_executable(test_operator + test_cast.cu + test_cast_current_scaling.cu + test_cast_dbias.cu + test_cast_dbias_dgelu.cu + test_cast_gated_swiglu.cu + test_cast_mxfp8_gated_swiglu.cu test_qdq.cu - test_cast_transpose.cu + test_cast_mxfp8.cu + test_dequantize_mxfp8.cu test_transpose.cu + test_cast_transpose.cu + test_cast_transpose_current_scaling.cu test_cast_transpose_dbias.cu test_cast_transpose_dbias_dgelu.cu test_cast_transpose_dgeglu.cu test_act.cu - test_layernorm.cu - test_rmsnorm.cu + test_normalization.cu + test_normalization_mxfp8.cu test_multi_cast_transpose.cu test_multi_padding.cu test_causal_softmax.cu + test_swizzle.cu ../test_common.cu) +find_package(OpenMP REQUIRED) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) -target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS}) -target_compile_options(test_operator PRIVATE -O2) +target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX) +target_compile_options(test_operator PRIVATE -O2 -fopenmp) include(GoogleTest) -gtest_discover_tests(test_operator) +gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu index 7d03e41271..4224f199f4 100644 --- a/tests/cpp/operator/test_act.cu +++ b/tests/cpp/operator/test_act.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -21,58 +21,6 @@ using namespace transformer_engine; -namespace { - -// forward - -float gelu(const float x) { - return 0.5f * x * (1.0f + tanhf(0.79788456F * x * (1.0f + 0.044715f * x * x))); -} - -float silu(const float x) { - return x / (1 + expf(-x)); -} - -float relu(const float x) { - return x > 0 ? x : 0; -} - -float srelu(const float x) { - return x > 0 ? x * x : 0; -} - -float qgelu(const float x) { - return x / (1 + expf(-1.702f * x)); -} - -// backward - -float dgelu(const float x) { - const float tanh_out = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x)); - return 0.5f * x * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) + - 0.5f * (1.f + tanh_out); -} - -float dsilu(const float x) { - const float sigmoid = 1.f / (1 + expf(-x)); - return x * sigmoid * (1.f - sigmoid) + sigmoid; -} - -float drelu(const float x) { - return x > 0.f ? 1.f : 0.f; -} - -float dsrelu(const float x) { - return fmaxf(2.f * x, 0.f); -} - -float dqgelu(const float x) { - const float sigmoid = 1.f / (1 + expf(-1.702f * x)); - return 1.702f * x * sigmoid * (1.f - sigmoid) + sigmoid; -} - -} // namespace - template void compute_ref_act_cast(const IT *input_h, OT *output_h, @@ -82,6 +30,7 @@ void compute_ref_act_cast(const IT *input_h, const size_t H) { CT amax = 0.; + #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT elt = static_cast(input_h[i * H + j]); @@ -101,6 +50,7 @@ void compute_ref_dact_cast(const IT *input_h, const size_t N, const size_t H) { using CT = float; + #pragma omp parallel for schedule(static) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT elt = static_cast(input_h[i * H + j]); @@ -118,6 +68,7 @@ void compute_ref_glu_act_cast(const IT *input_h, OT *output_h, const CT scale, C const int col = H * 2; + #pragma omp parallel for schedule(static) reduction(max: amax) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT gelu_elt = static_cast(input_h[i * col + j]); @@ -139,6 +90,7 @@ void compute_ref_dglu_act_cast(const IT *input_h, const IT *grad_h, OT *output_h const int col = H * 2; using CT = float; + #pragma omp parallel for schedule(static) proc_bind(spread) for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT grad = static_cast(grad_h[i * H + j]); @@ -164,10 +116,10 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N, H }, itype); - Tensor output({ N, H }, otype); - Tensor igrad({ N, H }, itype); - Tensor ograd({ N, H }, itype); + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype); + Tensor igrad("igrad", { N, H }, itype); + Tensor ograd("ograd", { N, H }, itype); fillUniform(&input); fillUniform(&ograd); @@ -179,7 +131,7 @@ void performTest(const size_t N, const size_t H) { nvte_act(input.data(), output.data(), 0); float ref_amax; - compute_ref_act_cast(input.cpu_dptr(), ref_output.get(), + compute_ref_act_cast(input.rowwise_cpu_dptr(), ref_output.get(), output.scale(), &ref_amax, N, H); cudaDeviceSynchronize(); @@ -195,7 +147,7 @@ void performTest(const size_t N, const size_t H) { nvte_dact(ograd.data(), input.data(), igrad.data(), 0); - compute_ref_dact_cast(input.cpu_dptr(), ograd.cpu_dptr(), + compute_ref_dact_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(), ref_igrad.get(), N, H); cudaDeviceSynchronize(); @@ -219,10 +171,10 @@ void performTestGLU(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({N, H * 2}, itype); - Tensor output({N, H}, otype); - Tensor igrad({ N, H * 2 }, itype); - Tensor ograd({ N, H }, itype); + Tensor input("input", {N, H * 2}, itype); + Tensor output("output", {N, H}, otype); + Tensor igrad("igrad", { N, H * 2 }, itype); + Tensor ograd("ograd", { N, H }, itype); fillUniform(&input); fillUniform(&ograd); @@ -234,7 +186,7 @@ void performTestGLU(const size_t N, const size_t H) { nvte_act(input.data(), output.data(), 0); float ref_amax; - compute_ref_glu_act_cast(input.cpu_dptr(), ref_output.get(), + compute_ref_glu_act_cast(input.rowwise_cpu_dptr(), ref_output.get(), output.scale(), &ref_amax, N, H); cudaDeviceSynchronize(); @@ -242,15 +194,19 @@ void performTestGLU(const size_t N, const size_t H) { ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + auto [atol, rtol] = getTolerances(DType::kFloat32); + compareResults("amax", output.amax(), ref_amax, atol, rtol); + if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + const float ref_scale = 1.f / output.scale(); + compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr(), ref_scale, atol, rtol); + } } auto [atol, rtol] = getTolerances(otype); compareResults("output_gelu", output, ref_output.get(), atol, rtol); nvte_dact(ograd.data(), input.data(), igrad.data(), 0); - compute_ref_dglu_act_cast(input.cpu_dptr(), ograd.cpu_dptr(), + compute_ref_dglu_act_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(), ref_igrad.get(), N, H); cudaDeviceSynchronize(); diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu new file mode 100644 index 0000000000..81c975b0a8 --- /dev/null +++ b/tests/cpp/operator/test_cast.cu @@ -0,0 +1,134 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const InputType *data, OutputType *output_c, + const size_t size, + float *amax, float scale) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < size; ++i) { + compute_t current = static_cast(data[i]); + current_max = fmaxf(current_max, fabsf(current)); + output_c[i] = OutputType(scale * current); + } + *amax = current_max; +} + + +// delayed tensor scaling test +template +void performTest(const std::vector& shape) { + using namespace test; + + const size_t full_size = product(shape); + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor input("input", shape, itype); + Tensor output_c("output_c", shape, otype); + + std::unique_ptr ref_output_c = std::make_unique(full_size); + + fillUniform(&input); + setRandomScale(&output_c); + + nvte_quantize(input.data(), output_c.data(), 0); + + float ref_amax; + + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), + full_size, &ref_amax, output_c.scale()); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); +} + +std::vector> test_cases = { + {16}, + {16000}, + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; +} // namespace + +class CastTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastTestSuite, TestCast) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // delayed tensor scaling + performTest(size); + ); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_current_scaling.cu b/tests/cpp/operator/test_cast_current_scaling.cu new file mode 100644 index 0000000000..18325d6daf --- /dev/null +++ b/tests/cpp/operator/test_cast_current_scaling.cu @@ -0,0 +1,214 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const InputType *data, OutputType *output_c, + const size_t size, + float *amax, float scale) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < size; ++i) { + compute_t current = static_cast(data[i]); + current_max = fmaxf(current_max, fabsf(current)); + output_c[i] = OutputType(scale * current); + } +} + + +template +void compute_amax_scale_ref(const InputType *data, + const size_t size, + float *amax_ptr, float *scale_ptr, float* scale_inv_ptr, + float max_fp8, float epsilon) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < size; ++i) { + compute_t current = static_cast(data[i]); + current_max = fmaxf(current_max, fabsf(current)); + } + *amax_ptr = current_max; + + // compute scale from amax + float clamp_amax = current_max; + if (current_max <= epsilon){ + clamp_amax = epsilon; + } + + float scale = 1.f; + float scale_inv = 1.f; + + if (isinf(clamp_amax) || clamp_amax == 0.f) { + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; + return; + } + + // use ieee_div in CPU + scale = max_fp8 / clamp_amax; + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. + if (isinf(scale)) { + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + scale_inv = 1.0f / scale; + + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; +} + +// current tensor scaling test +template +void performTest(const std::vector& shape) { + using namespace test; + + const size_t full_size = product(shape); + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + bool is_out_fp8 = isFp8Type(otype); + + // find out max fp8 value + float max_fp8; + if (is_out_fp8){ + switch (otype) { + case DType::kFloat8E5M2: { + max_fp8 = Quantized_Limits::max(); + } break; + case DType::kFloat8E4M3: { + max_fp8 = Quantized_Limits::max(); + } break; + default: + NVTE_ERROR("Invalid type."); + } + } + + Tensor input("input", shape, itype); + Tensor output_c("output_c", shape, otype, true, false); + + std::unique_ptr ref_output_c = std::make_unique(full_size); + + fillUniform(&input); + + // compute amax + float amax_to_check = 0.0f; + if (is_out_fp8){ + nvte_compute_amax(input.data(), output_c.data(), 0); + QuantizationConfigWrapper config; + nvte_compute_scale_from_amax(output_c.data(), config, 0); + // avoid atomic amax update in cuda cast kernels because of current per-tensor scaling + amax_to_check = output_c.amax(); + output_c.set_tensor_amax_nullptr(); + } + nvte_quantize(input.data(), output_c.data(), 0); + + float ref_amax; + float ref_scale; + float ref_scale_inv; + if (is_out_fp8){ + compute_amax_scale_ref(input.rowwise_cpu_dptr(), + full_size, &ref_amax, &ref_scale, &ref_scale_inv, max_fp8, 0.0f); + } + + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), + full_size, nullptr, is_out_fp8 ? output_c.scale() : 1.0f ); + + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + if (isFp8Type(otype)) { + auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32); + compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32); + compareResults("scale", output_c.scale(), ref_scale, 0.0f, rtol_fp32); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, 0.0f, rtol_fp32); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, 0.0f, rtol); +} + +std::vector> test_cases = { + {16}, + {16000}, + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; +} // namespace + +class CastCSTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastCSTestSuite, TestCastCS) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // current tensor scaling + performTest(size); + ); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastCSTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_dbias.cu b/tests/cpp/operator/test_cast_dbias.cu new file mode 100644 index 0000000000..1f0a9305d8 --- /dev/null +++ b/tests/cpp/operator/test_cast_dbias.cu @@ -0,0 +1,181 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref_cast_dbias(const IT *input_h, + const CT scale, + OT *output_c_h, + CT *amax_h, + IT *dbias_h, + const size_t N, + const size_t H) { + CT amax = 0.; + + std::vector acc_dbias(H, 0.); + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < H; j++) { + CT elt = static_cast(input_h[i * H + j]); + + // update amax + amax = std::abs(elt) > amax ? std::abs(elt) : amax; + + output_c_h[i * H + j] = static_cast(scale * elt); + + // dbias + acc_dbias[j] += elt; + } + } + + *amax_h = amax; + + for (size_t i = 0; i < H; i++) { + dbias_h[i] = static_cast(acc_dbias[i]); + } +} + +template +void performTest(const std::vector& shape) { + using namespace test; + using CType = fp32; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t N = first_dimension(shape); + const size_t H = last_dimension(shape); + + Tensor input("input", shape, itype); + + Tensor output_c("output_c", shape, otype); + // dbias has the same data type with "output grad" + Tensor dbias("dbias", {H}, itype); + + fillUniform(&input); + setRandomScale(&output_c); + + std::unique_ptr ref_output_c = std::make_unique(N*H); + std::unique_ptr ref_output_dbias = std::make_unique(H); + + CType ref_amax; + compute_ref_cast_dbias(input.rowwise_cpu_dptr(), + output_c.scale(), + ref_output_c.get(), + &ref_amax, + ref_output_dbias.get(), + N, H); + + Tensor workspace; + + nvte_quantize_dbias(input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias(input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); + + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + rtol_dbias *= 4; + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); +} + +std::vector> test_cases = { + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; + +} // namespace; + + +class CastDBiasTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastDBiasTestSuite, TestCastDBias) { + using namespace transformer_engine; + using namespace test; + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastDBiasTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_dbias_dgelu.cu b/tests/cpp/operator/test_cast_dbias_dgelu.cu new file mode 100644 index 0000000000..20ea5c31f1 --- /dev/null +++ b/tests/cpp/operator/test_cast_dbias_dgelu.cu @@ -0,0 +1,196 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void compute_ref_cast_dbias_dgelu(const IT *input, + const IT *grad, + const CT scale, + OT *output_c, + CT *amax_h, + IT *dbias, + const size_t N, + const size_t H) { + CT amax = 0.; + + std::vector acc_dbias(H, 0.); + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < H; j++) { + CT in_elt = static_cast(input[i * H + j]); + const CT in_grad = static_cast(grad[i * H + j]); + + const CT elt = in_grad * static_cast(dgelu(static_cast(in_elt))); + const CT elt_abs = std::abs(elt); + + // update amax + if (elt_abs > amax) { + amax = elt_abs; + } + + output_c[i * H + j] = static_cast(scale * elt); + + // dbias + acc_dbias[j] += elt; + } + } + + *amax_h = amax; + + for (size_t i = 0; i < H; i++) { + dbias[i] = static_cast(acc_dbias[i]); + } +} + +template +void performTest(const std::vector& shape) { + using namespace test; + using CType = fp32; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t N = first_dimension(shape); + const size_t H = last_dimension(shape); + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + + Tensor output_c("output_c", shape, otype); + // dbias has the same data type with "output grad" + Tensor dbias("dbias", {H}, itype); + + fillUniform(&input); + fillUniform(&grad); + setRandomScale(&output_c); + + std::unique_ptr ref_output_c = std::make_unique(N*H); + std::unique_ptr ref_output_dbias = std::make_unique(H); + + CType ref_amax; + compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + output_c.scale(), + ref_output_c.get(), + &ref_amax, + ref_output_dbias.get(), + N, H); + + Tensor workspace; + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + dbias.data(), + workspace.data(), + 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); + + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + rtol_dbias *= 4; + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); +} + +std::vector> test_cases = { + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {65536, 160}, + {16384, 1616}, + {1, 128}, + {1, 1296}, + {1, 16}, + {5, 160}, + {5, 4, 3, 160}, + {217, 256}, +}; + +} // namespace; + + +class CastDBiasDGeluTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) { + using namespace transformer_engine; + using namespace test; + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastDBiasDGeluTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_gated_swiglu.cu b/tests/cpp/operator/test_cast_gated_swiglu.cu new file mode 100644 index 0000000000..35ae462106 --- /dev/null +++ b/tests/cpp/operator/test_cast_gated_swiglu.cu @@ -0,0 +1,165 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void compute_ref_cast_dgated_swiglu(const IType * const grad, + const IType * const input, + const float scale, + OType * const output, + float * const amax_ptr, + const size_t rows, + const size_t cols) { + float amax = 0; + const size_t stride = cols * 2; + + #pragma omp parallel for reduction(max: amax) proc_bind(spread) + for (size_t i = 0; i < rows; i++) { + for (size_t j = 0; j < cols; j++) { + float grad_elt = static_cast(grad[i * cols + j]); + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + + float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; + float after_dgate = grad_elt * silu(silu_elt); + + if (abs(after_dsilu) > amax) { amax = abs(after_dsilu); } + if (abs(after_dgate) > amax) { amax = abs(after_dgate); } + + output[i * stride + j] = static_cast(scale * after_dsilu); + output[i * stride + cols + j] = static_cast(scale * after_dgate); + } + } + + *amax_ptr = amax; +} + +template +void performTest(const std::vector& shape) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + std::vector input_shape = shape; + input_shape[input_shape.size() - 1] *= 2; + + const size_t input_size = product(input_shape); + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + Tensor grad("grad", shape, itype); + Tensor input("input", input_shape, itype); + Tensor output_c("output_c", input_shape, otype); + + fillUniform(&grad); + fillUniform(&input); + setRandomScale(&output_c); + + std::unique_ptr ref_output_c = std::make_unique(input_size); + + nvte_dswiglu(grad.data(), input.data(), output_c.data(), 0); + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + float ref_amax; + compute_ref_cast_dgated_swiglu(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + output_c.scale(), + ref_output_c.get(), + &ref_amax, + rows, + cols); + + if (isFp8Type(otype)) { + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output_c.scale(); + compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), true, atol, rtol); +} + +std::vector> test_cases = { + {128, 128}, + {256, 256}, + {768, 1024}, + {256, 65536}, + {2048, 12288}, + {65536, 128}, + {217, 256}, + {1296}, + {5, 4, 3, 160}, +}; + +} // namespace + +class CastSwiGLUTestSuite + : public ::testing::TestWithParam>> {}; + +TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) { + using namespace transformer_engine; + using namespace test; + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + if (size.back() % 32 != 0) { + GTEST_SKIP(); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL( + output_type, OutputType, performTest(size););); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, CastSwiGLUTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo &info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + return name; + }); diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu new file mode 100644 index 0000000000..cb38a5a74a --- /dev/null +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -0,0 +1,636 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ProcessingMethod { + CAST_ONLY, + CAST_DBIAS, + CAST_DBIAS_DACT, + CAST_DACT, + CAST_ACT +}; + +enum ActivationType { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +template +void scale_block(const ProcessingMethod processing_method, + const InputType* input, + const InputType* grad, + OutputType* output_c, + float* dbias, + fp8e8m0* output_scales, + const size_t scale_idx, + const size_t i_min, + const size_t i_max, + const size_t j_min, + const size_t j_max, + const size_t cols) { + float amax = 0.0f; + + // Find the absolute maximum value in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + dbias[j] += elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + amax = std::max(amax, std::abs(elt)); + } + } + + const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + output_scales[scale_idx] = biased_exponent; + + // Quantize elements in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + output_c[idx] = static_cast(elt * scale_reciprocal); + } + } +} + +template +void compute_ref_x1(const ProcessingMethod processing_method, + const InputType* input, + const InputType* grad, + OutputType* output_c, + fp8e8m0* output_scales, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride) +{ + std::vector output_dbias_fp32(cols, 0); + + const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y; + const size_t blocks_X = (cols + block_size_X - 1) / block_size_X; + + for (size_t ii = 0; ii < blocks_Y; ++ii) { + const size_t i_min = ii * block_size_Y; + const size_t i_max = std::min((ii + 1) * block_size_Y, rows); + for (size_t jj = 0; jj < blocks_X; ++jj) { + const size_t j_min = jj * block_size_X; + const size_t j_max = std::min((jj + 1) * block_size_X, cols); + const size_t scale_idx = ii * scales_stride + jj; + scale_block( + processing_method, input, grad, output_c, output_dbias_fp32.data(), + output_scales, scale_idx, i_min, i_max, j_min, j_max, cols); + } + } + for (size_t j = 0; j < cols; ++j) { + output_dbias[j] = static_cast(output_dbias_fp32[j]); + } +} + +template +void compute_ref_x2(const ProcessingMethod processing_method, + const InputType* input, + const InputType* grad, + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* scales_rowwise, + fp8e8m0* scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) { + compute_ref_x1( + processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias, + rows, cols, 1, block_size_X, scales_stride_rowwise); + compute_ref_x1( + processing_method, input, grad, output_colwise, scales_colwise, output_dbias, + rows, cols, block_size_Y, 1, scales_stride_colwise); +} + +/** + * Scaling along single dimension (either rows or columns) + * Produces one set of output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * OR + * 2) Scaled columns + column-wise scaling factors + */ + +template +void performTest_x1(const ProcessingMethod processing_method, + const std::vector& shape, + const bool rowwise, + const bool colwise, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + if (shape.size() < 2 && colwise) { + GTEST_SKIP(); + } + + const size_t block_size_rows = rowwise ? 1 : 32; + const size_t block_size_cols = colwise ? 1 : 32; + + const std::array scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows, + block_size_cols); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + Tensor output_dbias("output_dbias", { cols }, itype); + + std::unique_ptr ref_output_c = std::make_unique(rows * cols); + std::unique_ptr ref_output_dbias = std::make_unique(cols); + std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X); + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + case ProcessingMethod::CAST_DBIAS: { + nvte_quantize_dbias(grad.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias(grad.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DBIAS_DACT: { + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DACT: { + nvte_dgelu(grad.data(), input.data(), output_c.data(), 0); + break; + } + case ProcessingMethod::CAST_ACT: { + nvte_gelu(input.data(), output_c.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compute_ref_x1(processing_method, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c.get(), + ref_output_scales.get(), + ref_output_dbias.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol); + + const uint8_t * const gpu_scales_ptr = rowwise + ? output_c.rowwise_cpu_scale_inv_ptr() + : output_c.columnwise_cpu_scale_inv_ptr(); + + compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + + if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + if (itype == DType::kFloat32) { + atol_dbias = 1e-4; + rtol_dbias *= sqrt(static_cast(rows)) ; + } else { + rtol_dbias *= 4; + } + compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); + } +} + +/** + * Scaling along both dimensions (rows and columns) + * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * AND + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest_x2(const ProcessingMethod processing_method, + const std::vector& shape, + const size_t block_size_rows, + const size_t block_size_cols, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + if (shape.size() < 2) { + GTEST_SKIP(); + } + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + const std::array scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32); + const std::array scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1); + + const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0]; + const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1]; + const size_t blocks_Y_rowwise = scale_dims_rowwise[2]; + const size_t blocks_X_rowwise = scale_dims_rowwise[3]; + const size_t scales_stride_rowwise = blocks_X_rowwise; + + const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0]; + const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1]; + const size_t blocks_Y_colwise = scale_dims_colwise[2]; + const size_t blocks_X_colwise = scale_dims_colwise[3]; + const size_t scales_stride_colwise = blocks_X_colwise; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor output_dbias("output_dbias", { cols }, itype); + + std::unique_ptr ref_output_c_rowwise = std::make_unique(rows * cols); + std::unique_ptr ref_output_c_colwise = std::make_unique(rows * cols); + std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_Y_rowwise * blocks_X_rowwise); + std::unique_ptr ref_scales_colwise = std::make_unique(blocks_Y_colwise * blocks_X_colwise); + std::unique_ptr ref_output_dbias = std::make_unique(cols); + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output.data(), 0); + break; + } + case ProcessingMethod::CAST_DBIAS: { + nvte_quantize_dbias(grad.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias(grad.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DBIAS_DACT: { + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); + break; + } + case ProcessingMethod::CAST_DACT: { + nvte_dgelu(grad.data(), input.data(), output.data(), 0); + break; + } + case ProcessingMethod::CAST_ACT: { + nvte_gelu(input.data(), output.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compute_ref_x2(processing_method, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c_rowwise.get(), + ref_output_c_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_output_dbias.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride_rowwise, + scales_stride_colwise); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); + compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); + compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise); + compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise); + + if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + auto [atol_dbias, rtol_dbias] = getTolerances(itype); + if (itype == DType::kFloat32) { + atol_dbias = 1e-4; + rtol_dbias *= sqrt(static_cast(rows)) ; + } else { + rtol_dbias *= 4; + } + compareResults("output_dbias", output_dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); + } +} + +std::vector> matrix_sizes = { + {1, 16}, + {16, 48}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {256, 65536}, + {2048, 6144}, + {16384, 128}, + {32768, 160}, + {4096, 1632}, + {1024}, + {8, 32, 1024}, + {16, 8, 4, 512}, +}; + +std::vector> block_sizes = { + {1, 32}, + {32, 1}, + {32, 32}, +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, + // InputsFillCase::zeros, + // InputsFillCase::zero_to_minNorm, + // InputsFillCase::minNorm_to_maxNorm, + // InputsFillCase::maxNorm_to_inf +}; + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + ProcessingMethod::CAST_DBIAS, + ProcessingMethod::CAST_DBIAS_DACT, + ProcessingMethod::CAST_DACT, + ProcessingMethod::CAST_ACT, +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + ActivationType::GeLU, + // ActivationType::SiLU, + // ActivationType::ReLU, + // ActivationType::QGeLU, + // ActivationType::SReLU, +}; + +} // namespace + +class FusedCastMXFP8TestSuite : public ::testing::TestWithParam + , + std::pair, + transformer_engine::DType, + transformer_engine::DType, + InputsFillCase>> {}; + +#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ +switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ + case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \ + case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \ + case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \ + case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \ + case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \ +} + +#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ +switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ + case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \ + case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \ + case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \ + case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \ + case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \ +} + +TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const auto block_size = std::get<3>(GetParam()); + const DType input_type = std::get<4>(GetParam()); + const DType output_type = std::get<5>(GetParam()); + const InputsFillCase fill_case = std::get<6>(GetParam()); + + // Skips non Act tests if the Activation type is not an identity + if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + || processing_method == ProcessingMethod::CAST_DACT + || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + GTEST_SKIP(); + } + + const bool rowwise = block_size.second != 1; + const bool colwise = block_size.first != 1; + if (processing_method == ProcessingMethod::CAST_ACT) { + // Forward activations + ACT_FUNC_SWITCH(Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, matrix_size, + block_size.first, block_size.second, fill_case); + } + ); + ); + ); + } else { + DACT_FUNC_SWITCH(Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, matrix_size, + block_size.first, block_size.second, fill_case); + } + ); + ); + ); + } +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + case ProcessingMethod::CAST_ONLY: return "CAST_ONLY"; + case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: return ""; + } +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: return "Identity"; + case ActivationType::GeLU: return "GeLU"; + case ActivationType::SiLU: return "SiLU"; + case ActivationType::ReLU: return "ReLU"; + case ActivationType::QGeLU: return "QGeLU"; + case ActivationType::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + FusedCastMXFP8TestSuite, + ::testing::Combine( + ::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), + ::testing::ValuesIn(matrix_sizes), + ::testing::ValuesIn(block_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)) + "X" + + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + std::to_string(std::get<3>(info.param).first) + + "X" + std::to_string(std::get<3>(info.param).second) + + "X" + test::typeName(std::get<4>(info.param)) + + "X" + test::typeName(std::get<5>(info.param)) + + "X" + test::caseName(std::get<6>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu new file mode 100644 index 0000000000..6acbdefeab --- /dev/null +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -0,0 +1,470 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void scale_block(const IType* grad, + const IType* input, + OType* output, + fp8e8m0* output_scales, + const size_t scale_idx, + const size_t scale_idx_gate, + float& thread_amax, + const size_t i_min, + const size_t i_max, + const size_t j_min, + const size_t j_max, + const size_t cols) { + + float block_amax = 0.0f; + float block_amax_gate = 0.0f; + const size_t stride = cols * 2; + + // Find the absolute maximum value in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + float gated_amax_act = 0; + float gated_amax_gate = 0; + + if constexpr (IS_DGATED) { + const float grad_elt = static_cast(grad[i * cols + j]); + const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; + const float after_dgate = silu(silu_elt) * grad_elt; + gated_amax_act = abs(after_dsilu); + gated_amax_gate = abs(after_dgate); + } else { + const float after_silu = silu(silu_elt) * gate_elt; + gated_amax_act = abs(after_silu); + } + + if (gated_amax_act > block_amax) { block_amax = gated_amax_act; } + if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; } + } + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * + Quantized_Limits::max_reciprocal()); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + output_scales[scale_idx] = biased_exponent; + float scale_reciprocal_gate = 1; + if constexpr (IS_DGATED) { + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate * + Quantized_Limits::max_reciprocal()); + scale_reciprocal_gate = exp2f_rcp(biased_exponent); + output_scales[scale_idx_gate] = biased_exponent; + } + + + // Quantize elements in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + + if constexpr (IS_DGATED) { + const float grad_elt = static_cast(grad[i * cols + j]); + const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; + const float after_dgate = silu(silu_elt) * grad_elt; + output[i * stride + j] = static_cast(after_dsilu * scale_reciprocal); + output[i * stride + cols + j] = static_cast(after_dgate * + scale_reciprocal_gate); + } else { + const float after_silu = silu(silu_elt) * gate_elt; + output[i * cols + j] = static_cast(after_silu * scale_reciprocal); + } + + } + } + thread_amax = std::max(thread_amax, block_amax); + thread_amax = std::max(thread_amax, block_amax_gate); +} + +template +void compute_ref_x1(const IType* grad, + const IType* input, + OType* output, + fp8e8m0* output_scales, + float& ref_amax, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride) { + const size_t tile_size_Y = std::max(32lu, block_size_Y); + const size_t tile_size_X = std::max(64lu, block_size_X); + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; + const size_t blocks_per_tile_X = tile_size_X / block_size_X; + + float amax = 0; + #pragma omp parallel reduction(max: amax) proc_bind(spread) + { + float thread_amax = 0; + #pragma omp for schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { + const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; + const size_t block_offset_Y = ii * block_size_Y; + const size_t i_min = tile_offset_Y + block_offset_Y; + if (i_min >= rows) continue; + const size_t i_max = std::min(i_min + block_size_Y, rows); + + for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { + const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; + const size_t block_offset_X = jj * block_size_X; + const size_t j_min = tile_offset_X + block_offset_X; + if (j_min >= cols) continue; + const size_t j_max = std::min(j_min + block_size_X, cols); + + const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X; + const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X + + cols / block_size_X; + scale_block( + grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate, + thread_amax, i_min, i_max, j_min, j_max, cols); + } + } + } + if (thread_amax > amax) { + amax = thread_amax; + } + } + ref_amax = amax; +} + +template +void compute_ref_x2(const IType* grad, + const IType* input, + OType* output_rowwise, + OType* output_colwise, + fp8e8m0* scales_rowwise, + fp8e8m0* scales_colwise, + float& ref_amax, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) { + compute_ref_x1( + grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise); + compute_ref_x1( + grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise); +} + +/** + * Scaling along single dimension (either rows or columns) + * Produces one set of output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * OR + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest_x1(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const bool rowwise = (block_size_rows == 1) && (block_size_cols == 32); + const bool colwise = (block_size_rows == 32) && (block_size_cols == 1); + NVTE_CHECK(rowwise || colwise); + + // std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl; + // std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl; + // std::cout << "blocks_Y: " << blocks_Y << std::endl; + // std::cout << "blocks_X: " << blocks_X << std::endl; + // std::cout << "scales_stride: " << scales_stride << std::endl; + + Tensor grad("grad", { rows, cols }, itype); + Tensor input("input", { rows, cols * 2 }, itype); + + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const std::array scale_dims = get_scale_tensor_dims(rows, output_cols, block_size_rows, + block_size_cols); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + Tensor output("output", std::vector{ rows, output_cols }, otype, + rowwise, colwise, NVTE_MXFP8_1D_SCALING); + + std::unique_ptr ref_output = std::make_unique(rows * output_cols); + std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X); + + for (size_t i = 0; i < blocks_Y * blocks_X; ++i) { + ref_output_scales[i] = 0; + } + + // fillCase(&grad, fill_case); + if constexpr (IS_DGATED) { + fillUniform(&grad); + } + fillUniform(&input); + + if constexpr (IS_DGATED) { + nvte_dswiglu(grad.data(), input.data(), output.data(), 0); + } else { + nvte_swiglu(input.data(), output.data(), 0); + } + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + float ref_amax = 0; + compute_ref_x1(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_scales.get(), + ref_amax, + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), rowwise, atol, rtol); + + const uint8_t * const gpu_scales_ptr = rowwise + ? output.rowwise_cpu_scale_inv_ptr() + : output.columnwise_cpu_scale_inv_ptr(); + if (rowwise) { + compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + } else { + compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + } +} + +/** + * Scaling along both dimensions (rows and columns) + * Produces two sets of scaled output data and the corresponding data of the fused operation (dbias): + * 1) Scaled rows + row-wise scaling factors + * AND + * 2) Scaled columns + column-wise scaling factors + */ +template +void performTest_x2(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols, + InputsFillCase fill_case) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor grad("grad", { rows, cols }, itype); + Tensor input("input", { rows, cols * 2 }, itype); + + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const std::array scale_dims_rowwise = get_scale_tensor_dims(rows, output_cols, 1, 32); + const std::array scale_dims_colwise = get_scale_tensor_dims(rows, output_cols, 32, 1); + + const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0]; + const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1]; + const size_t blocks_Y_rowwise = scale_dims_rowwise[2]; + const size_t blocks_X_rowwise = scale_dims_rowwise[3]; + const size_t scales_stride_rowwise = blocks_X_rowwise; + + const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0]; + const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1]; + const size_t blocks_Y_colwise = scale_dims_colwise[2]; + const size_t blocks_X_colwise = scale_dims_colwise[3]; + const size_t scales_stride_colwise = blocks_X_colwise; + + Tensor output("output", std::vector{ rows, output_cols }, otype, + true, true, NVTE_MXFP8_1D_SCALING); + + std::unique_ptr ref_output_rowwise = std::make_unique(rows * output_cols); + std::unique_ptr ref_output_colwise = std::make_unique(rows * output_cols); + std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_Y_rowwise * blocks_X_rowwise); + std::unique_ptr ref_scales_colwise = std::make_unique(blocks_Y_colwise * blocks_X_colwise); + + for (size_t i = 0; i < blocks_Y_rowwise * blocks_X_rowwise; ++i) { + ref_scales_rowwise[i] = 0; + } + for (size_t i = 0; i < blocks_Y_colwise * blocks_X_colwise; ++i) { + ref_scales_colwise[i] = 0; + } + + // fillCase(&grad, fill_case); + if constexpr (IS_DGATED) { + fillUniform(&grad); + } + fillUniform(&input); + + if constexpr (IS_DGATED) { + nvte_dswiglu(grad.data(), input.data(), output.data(), 0); + } else { + nvte_swiglu(input.data(), output.data(), 0); + } + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + float ref_amax = 0; + compute_ref_x2(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output_rowwise.get(), + ref_output_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_amax, + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride_rowwise, + scales_stride_colwise); + + auto [atol, rtol] = getTolerances(otype); + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); + compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); + compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise); + compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise); +} + +std::vector> matrix_sizes = { + {1, 32}, + {16, 64}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {768, 1024}, + {65536, 128}, + {16384, 1632}, +}; + +std::vector> block_sizes = { + {1, 32}, + {32, 1}, + {32, 32}, +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, + // InputsFillCase::zeros, + // InputsFillCase::zero_to_minNorm, + // InputsFillCase::minNorm_to_maxNorm, + // InputsFillCase::maxNorm_to_inf +}; + +std::vector is_dgated_op = { + true, + false +}; + +} // namespace + +class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam + , + std::pair, + transformer_engine::DType, + transformer_engine::DType, + InputsFillCase, + bool>> {}; + +TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const auto matrix_size = std::get<0>(GetParam()); + const auto block_size = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const InputsFillCase fill_case = std::get<4>(GetParam()); + const bool IS_DGATED = std::get<5>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType, + if (block_size.first == 1 || block_size.second == 1) { + if (IS_DGATED) { + performTest_x1(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } else { + performTest_x1(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } + } else { + if (IS_DGATED) { + performTest_x2(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } else { + performTest_x2(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case); + } + } + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CastMXFP8_GatedActTestSuite, + ::testing::Combine( + ::testing::ValuesIn(matrix_sizes), + ::testing::ValuesIn(block_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), + ::testing::ValuesIn(is_dgated_op)), + [](const testing::TestParamInfo& info) { + std::string name = std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + std::to_string(std::get<1>(info.param).first) + "X" + + std::to_string(std::get<1>(info.param).second) + "X" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + test::caseName(std::get<4>(info.param)) + "X" + + (std::get<5>(info.param) ? "DGATED" : "GATED"); + return name; + }); diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index 39a6614179..380ae96190 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -14,7 +14,7 @@ #include #include -#include +#include #include "../test_common.h" using namespace transformer_engine; @@ -38,6 +38,8 @@ void compute_ref(const InputType *data, OutputType *output_c, OutputType *output *amax = current_max; } + +// delayed tensor scaling test template void performTest(const size_t N, const size_t H) { using namespace test; @@ -45,38 +47,37 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N, H }, itype); - Tensor output_c({ N, H }, otype); - Tensor output_t({ H, N }, otype); + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype, true, true); std::unique_ptr ref_output_c = std::make_unique(N * H); std::unique_ptr ref_output_t = std::make_unique(N * H); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); - nvte_cast_transpose(input.data(), output_c.data(), output_t.data(), 0); + nvte_quantize(input.data(), output.data(), 0); float ref_amax; - compute_ref(input.cpu_dptr(), ref_output_c.get(), + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), ref_output_t.get(), N, H, &ref_amax, - output_c.scale()); + output.scale()); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); } + std::vector> test_cases = {{2048, 12288}, {768, 1024}, {256, 65536}, @@ -103,6 +104,7 @@ TEST_P(CTTestSuite, TestCastTranspose) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // delayed tensor scaling performTest(size.first, size.second); ); ); diff --git a/tests/cpp/operator/test_cast_transpose_current_scaling.cu b/tests/cpp/operator/test_cast_transpose_current_scaling.cu new file mode 100644 index 0000000000..267970b34f --- /dev/null +++ b/tests/cpp/operator/test_cast_transpose_current_scaling.cu @@ -0,0 +1,210 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const InputType *data, OutputType *output_c, OutputType *output_t, + const size_t N, const size_t H, + float *amax, float scale) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + current_max = fmaxf(current_max, fabsf(current)); + output_c[i * H + j] = OutputType(scale * current); + output_t[j * N + i] = OutputType(scale * current); + } + } +} + +template +void compute_amax_scale_ref(const InputType *data, + const size_t N, const size_t H, + float *amax_ptr, float *scale_ptr, float* scale_inv_ptr, + float max_fp8, float epsilon) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + current_max = fmaxf(current_max, fabsf(current)); + } + } + *amax_ptr = current_max; + + // compute scale from amax + float clamp_amax = current_max; + if (current_max <= epsilon){ + clamp_amax = epsilon; + } + + float scale = 1.f; + float scale_inv = 1.f; + + if (isinf(clamp_amax) || clamp_amax == 0.f) { + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; + return; + } + + // use ieee_div in CPU + scale = max_fp8 / clamp_amax; + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. + if (isinf(scale)) { + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + scale_inv = 1.0f / scale; + + *scale_ptr = scale; + *scale_inv_ptr = scale_inv; +} + +// current tensor scaling test +template +void performTest(const size_t N, const size_t H) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + bool is_out_fp8 = isFp8Type(otype); + + // find out max fp8 value + float max_fp8; + if (is_out_fp8){ + switch (otype) { + case DType::kFloat8E5M2: { + max_fp8 = Quantized_Limits::max(); + } break; + case DType::kFloat8E4M3: { + max_fp8 = Quantized_Limits::max(); + } break; + default: + NVTE_ERROR("Invalid type."); + } + } + + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype, true, true); + + std::unique_ptr ref_output_c = std::make_unique(N * H); + std::unique_ptr ref_output_t = std::make_unique(N * H); + + fillUniform(&input); + + // compute amax + float amax_to_check = 0.0f; + if (is_out_fp8){ + nvte_compute_amax(input.data(), output.data(), 0); + QuantizationConfigWrapper config; + nvte_compute_scale_from_amax(output.data(), config, 0); + // avoid atomic amax update in cuda cast kernels because of current per-tensor scaling + amax_to_check = output.amax(); + output.set_tensor_amax_nullptr(); + } + nvte_quantize(input.data(), output.data(), 0); + + float ref_amax; + float ref_scale; + float ref_scale_inv; + if (is_out_fp8){ + compute_amax_scale_ref(input.rowwise_cpu_dptr(), + N, H, &ref_amax, &ref_scale, &ref_scale_inv, max_fp8, 0.0f); + } + + compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(), + ref_output_t.get(), N, H, nullptr, + is_out_fp8 ? output.scale() : 1.0f ); + + cudaDeviceSynchronize(); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + if (isFp8Type(otype)) { + auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32); + compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32); + compareResults("scale", output.scale(), ref_scale, 0.0f, rtol_fp32); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, 0.0f, rtol_fp32); + compareResults("scale_inv_columnwise", output.columnwise_cpu_scale_inv_ptr()[0], ref_scale_inv, 0.0f, rtol_fp32); + } + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output, ref_output_c.get(), true, 0.0f, rtol); + compareResults("output_t", output, ref_output_t.get(), false, 0.0f, rtol); +} + +std::vector> test_cases = {{2048, 12288}, + {768, 1024}, + {256, 65536}, + {65536, 128}, + {256, 256}, + {120, 2080}, + {8, 8}, + {1, 3221}, // Prime 456 + {2333, 1}, // Prime 345 + {1481, 677}}; // Primes 234, 123 +} // namespace + +class CTCSTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(CTCSTestSuite, TestCastTransposeCS) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + // current tensor scaling + performTest(size.first, size.second); + ); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + CTCSTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::ValuesIn(test::all_fp_types), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)) + "X" + + std::to_string(std::get<2>(info.param).first) + "X" + + std::to_string(std::get<2>(info.param).second); + return name; + }); diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu index 651508c871..53918e2699 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -15,7 +15,7 @@ #include #include -#include +#include #include "../test_common.h" using namespace transformer_engine; @@ -64,26 +64,23 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - DType ctype = TypeInfo::dtype; - Tensor input({N, H}, itype); + Tensor input("input", {N, H}, itype); - Tensor output_c({N, H}, otype); - Tensor output_t({ H, N}, otype); + Tensor output("output", {N, H}, otype, true, true); // dbias has the same data type with "output grad" - Tensor dbias({H}, itype); + Tensor dbias("dbias", {H}, itype); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); std::unique_ptr ref_output_c = std::make_unique(N*H); std::unique_ptr ref_output_t = std::make_unique(N*H); std::unique_ptr ref_output_dbias = std::make_unique(H); CType ref_amax; - compute_ref_cast_transpose_dbias(input.cpu_dptr(), - output_c.scale(), + compute_ref_cast_transpose_dbias(input.rowwise_cpu_dptr(), + output.scale(), ref_output_c.get(), ref_output_t.get(), &ref_amax, @@ -92,22 +89,20 @@ void performTest(const size_t N, const size_t H) { Tensor workspace; - nvte_cast_transpose_dbias(input.data(), - output_c.data(), - output_t.data(), - dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias(input.data(), + output.data(), + dbias.data(), + workspace.data(), + 0); - workspace = Tensor(workspace.shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_cast_transpose_dbias(input.data(), - output_c.data(), - output_t.data(), - dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias(input.data(), + output.data(), + dbias.data(), + workspace.data(), + 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -115,17 +110,17 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); auto [atol_dbias, rtol_dbias] = getTolerances(itype); rtol_dbias *= 4; - compareResults("output_dbias", dbias, ref_output_dbias.get(), atol_dbias, rtol_dbias); + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); } std::vector> test_cases = {{64, 400}, diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index 38ac955bc9..15c7d8d665 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -75,29 +75,26 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - DType ctype = TypeInfo::dtype; - Tensor input({N, H}, itype); - Tensor gelu_input({N, H}, itype); + Tensor input("input", {N, H}, itype); + Tensor gelu_input("gelu_input", {N, H}, itype); - Tensor output_c({N, H}, otype); - Tensor output_t({ H, N}, otype); + Tensor output("output", {N, H}, otype, true, true); // dbias has the same data type with "output grad" - Tensor dbias({H}, itype); + Tensor dbias("dbias", {H}, itype); fillUniform(&input); fillUniform(&gelu_input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); std::unique_ptr ref_output_c = std::make_unique(N*H); std::unique_ptr ref_output_t = std::make_unique(N*H); std::unique_ptr ref_output_dbias = std::make_unique(H); CType ref_amax; - compute_ref_cast_transpose_dbias_dgelu(input.cpu_dptr(), - gelu_input.cpu_dptr(), - output_c.scale(), + compute_ref_cast_transpose_dbias_dgelu(input.rowwise_cpu_dptr(), + gelu_input.rowwise_cpu_dptr(), + output.scale(), ref_output_c.get(), ref_output_t.get(), &ref_amax, @@ -108,19 +105,17 @@ void performTest(const size_t N, const size_t H) { nvte_cast_transpose_dbias_dgelu(input.data(), gelu_input.data(), - output_c.data(), - output_t.data(), + output.data(), dbias.data(), workspace.data(), 0); - workspace = Tensor(workspace.shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_cast_transpose_dbias_dgelu(input.data(), gelu_input.data(), - output_c.data(), - output_t.data(), + output.data(), dbias.data(), workspace.data(), 0); @@ -131,18 +126,18 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); auto [atol_dbias, rtol_dbias] = getTolerances(itype); rtol_dbias *= 4; - compareResults("output_dbias", dbias, ref_output_dbias.get(), atol_dbias, rtol_dbias); + compareResults("output_dbias", dbias, ref_output_dbias.get(), true, atol_dbias, rtol_dbias); } std::vector> test_cases = {{64, 400}, diff --git a/tests/cpp/operator/test_cast_transpose_dgeglu.cu b/tests/cpp/operator/test_cast_transpose_dgeglu.cu index b1881b2a96..ae2da7bad2 100644 --- a/tests/cpp/operator/test_cast_transpose_dgeglu.cu +++ b/tests/cpp/operator/test_cast_transpose_dgeglu.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -74,24 +74,22 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor grad({N, H}, itype); - Tensor input({N, H * 2}, itype); - Tensor output_c({N, H * 2}, otype); - Tensor output_t({H * 2, N}, otype); + Tensor grad("grad", {N, H}, itype); + Tensor input("input", {N, H * 2}, itype); + Tensor output("output", {N, H * 2}, otype, true, true); fillUniform(&grad); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); std::unique_ptr ref_output_c = std::make_unique(N * H * 2); std::unique_ptr ref_output_t = std::make_unique(N * H * 2); - nvte_dgeglu_cast_transpose(grad.data(), input.data(), output_c.data(), output_t.data(), 0); + nvte_dgeglu_cast_transpose(grad.data(), input.data(), output.data(), 0); CType ref_amax; - compute_ref_cast_transpose_dgated_gelu(grad.cpu_dptr(), input.cpu_dptr(), - output_c.scale(), ref_output_c.get(), ref_output_t.get(), + compute_ref_cast_transpose_dgated_gelu(grad.rowwise_cpu_dptr(), input.rowwise_cpu_dptr(), + output.scale(), ref_output_c.get(), ref_output_t.get(), &ref_amax, N, H); cudaDeviceSynchronize(); @@ -100,14 +98,14 @@ void performTest(const size_t N, const size_t H) { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / output_c.scale(); - compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / output.scale(); + compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), atol, rtol); - compareResults("output_t", output_t, ref_output_t.get(), atol, rtol); + compareResults("output_c", output, ref_output_c.get(), true, atol, rtol); + compareResults("output_t", output, ref_output_t.get(), false, atol, rtol); } std::vector> test_cases = {{64, 400}, {4096, 2048}, {768, 2816}, diff --git a/tests/cpp/operator/test_causal_softmax.cu b/tests/cpp/operator/test_causal_softmax.cu index 640434674b..2fdc0a524d 100644 --- a/tests/cpp/operator/test_causal_softmax.cu +++ b/tests/cpp/operator/test_causal_softmax.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -153,11 +153,11 @@ void performTest( DType itype = TypeInfo::dtype; - Tensor data_in({ batches, heads, rows, cols }, itype); - Tensor softmax_out({ batches, heads, rows, cols }, itype); - Tensor softmax_in({ batches, heads, rows, cols }, itype); - Tensor grads_in({ batches, heads, rows, cols }, itype); - Tensor grads_out({ batches, heads, rows, cols }, itype); + Tensor data_in("data_in", { batches, heads, rows, cols }, itype); + Tensor softmax_out("softmax_out", { batches, heads, rows, cols }, itype); + Tensor softmax_in("softmax_in", { batches, heads, rows, cols }, itype); + Tensor grads_in("grads_in", { batches, heads, rows, cols }, itype); + Tensor grads_out("grads_out", { batches, heads, rows, cols }, itype); const size_t elements_total = batches * heads * rows * cols; std::unique_ptr softmax_out_ref = std::make_unique(elements_total); @@ -175,9 +175,9 @@ void performTest( // Reference implementations - compute_fwd_ref(softmax_out_ref.get(), data_in.cpu_dptr(), + compute_fwd_ref(softmax_out_ref.get(), data_in.rowwise_cpu_dptr(), compute_buffer.get(), scaling_factor, batches, heads, rows, cols); - compute_bwd_ref(grads_out_ref.get(), grads_in.cpu_dptr(), softmax_in.cpu_dptr(), + compute_bwd_ref(grads_out_ref.get(), grads_in.rowwise_cpu_dptr(), softmax_in.rowwise_cpu_dptr(), compute_buffer.get(), scaling_factor, batches, heads, rows, cols); cudaDeviceSynchronize(); @@ -187,8 +187,8 @@ void performTest( if(itype == DType::kBFloat16) { atol = 1e-3; } - compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), atol, rtol); - compareResults("softmax_bwd", grads_out, grads_out_ref.get(), atol, rtol); + compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), true, atol, rtol); + compareResults("softmax_bwd", grads_out, grads_out_ref.get(), true, atol, rtol); } // [Batches, Attention Heads, Query Sequence Length, Key Sequence Length, Scaling Factor] diff --git a/tests/cpp/operator/test_dequantize_mxfp8.cu b/tests/cpp/operator/test_dequantize_mxfp8.cu new file mode 100644 index 0000000000..701deb38bb --- /dev/null +++ b/tests/cpp/operator/test_dequantize_mxfp8.cu @@ -0,0 +1,452 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +template +void dequantize_block(const InputType* input, + OutputType* output, + fp8e8m0* scales, + const size_t scale_idx, + const size_t i_min, + const size_t i_max, + const size_t j_min, + const size_t j_max, + const size_t cols) +{ + const fp8e8m0 biased_exponent = scales[scale_idx]; + const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + const float elem_scale = block_scale; + + // Dequantize elements in the block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float elt = static_cast(input[idx]); + output[idx] = static_cast(elt * elem_scale); + } + } +} + +template +void compute_ref_x1(const InputType* input, + OutputType* output, + fp8e8m0* scales, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride) +{ + const size_t blocks_Y = (rows + block_size_Y - 1) / block_size_Y; + const size_t blocks_X = (cols + block_size_X - 1) / block_size_X; + + for (size_t ii = 0; ii < blocks_Y; ++ii) { + const size_t i_min = ii * block_size_Y; + const size_t i_max = std::min((ii + 1) * block_size_Y, rows); + for (size_t jj = 0; jj < blocks_X; ++jj) { + const size_t j_min = jj * block_size_X; + const size_t j_max = std::min((jj + 1) * block_size_X, cols); + const size_t scale_idx = ii * scales_stride + jj; + dequantize_block( + input, output, scales, scale_idx, i_min, i_max, j_min, j_max, cols); + } + } +} + +template +void compute_ref_x2(const InputType* input, + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* scales_rowwise, + fp8e8m0* scales_colwise, + const size_t rows, + const size_t cols, + const size_t block_size_Y, + const size_t block_size_X, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) +{ + compute_ref_x1(input, output_rowwise, scales_rowwise, rows, cols, 1, block_size_X, scales_stride_rowwise); + compute_ref_x1(input, output_colwise, scales_colwise, rows, cols, block_size_Y, 1, scales_stride_colwise); +} + +void generate_scales(fp8e8m0 * const scales_ref, + fp8e8m0 * const scales, + const size_t blocks_num, + std::mt19937& gen, + std::uniform_int_distribution dis) +{ + for (size_t i = 0; i < blocks_num; ++i) { + const fp8e8m0 val = dis(gen); + scales_ref[i] = val; + scales[i] = val; + } +} + +template +void generate_data(InputType * const data, + const size_t rows, + const size_t cols, + std::mt19937& gen, + std::uniform_real_distribution<>& dis, + std::uniform_real_distribution<>& dis_sign) +{ + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const bool is_negative = (dis_sign(gen) < 0.0); + double val = dis(gen); + if (is_negative) { + val = -val; + } + data[idx] = static_cast(val); + } + } +} + +template +void fill_tensor_data(Tensor& input, + fp8e8m0 * const scales_rowwise, + fp8e8m0 * const scales_colwise, + const bool is_rowwise_scaling, + const bool is_colwise_scaling, + const size_t rows, + const size_t cols, + const size_t blocks_num_rowwise, + const size_t blocks_num_colwise) +{ + const double minAbs = Numeric_Traits::minNorm; + const double maxAbs = Numeric_Traits::maxNorm; + static std::mt19937 gen(12345); + std::uniform_real_distribution<> dis(minAbs, maxAbs); + std::uniform_real_distribution<> dis_sign(-1.0, 1.0); + std::uniform_int_distribution int_dis(0, 255); + + if (is_rowwise_scaling) { + generate_scales(scales_rowwise, input.rowwise_cpu_scale_inv_ptr(), blocks_num_rowwise, gen, int_dis); + generate_data(input.rowwise_cpu_dptr(), rows, cols, gen, dis, dis_sign); + } + + if (is_colwise_scaling) { + generate_scales(scales_colwise, input.columnwise_cpu_scale_inv_ptr(), blocks_num_colwise, gen, int_dis); + generate_data(input.columnwise_cpu_dptr(), rows, cols, gen, dis, dis_sign); + } + + input.from_cpu(); +} + +// Dequantize along single dimension (either row- or columnwise) +template +void performTest_x1(const size_t rows, + const size_t cols, + const bool rowwise, + const bool colwise) +{ + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t block_size_rows = rowwise ? 1 : 32; + const size_t block_size_cols = colwise ? 1 : 32; + + const size_t unpadded_blocks_Y_rowwise = rows; + const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols); + const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X_colwise = cols; + + const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise, + scale_tensor_alignment_Y_rowwise); + const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise, + scale_tensor_alignment_X_rowwise); + const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise, + scale_tensor_alignment_Y_colwise); + const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise, + scale_tensor_alignment_X_colwise); + + const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise; + const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise; + + const size_t blocks_num = rowwise ? blocks_num_rowwise : blocks_num_colwise; + const size_t scales_stride = rowwise ? blocks_X_rowwise : blocks_X_colwise; + + Tensor input("input", { rows, cols }, itype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + + // Output data are written to the rowwise ptr regardless of the scaling direction + Tensor output("output", { rows, cols }, otype, true, false); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr scales = std::make_unique(blocks_num); + + fill_tensor_data(input, scales.get(), scales.get(), rowwise, colwise, rows, cols, + blocks_num_rowwise, blocks_num_colwise); + + nvte_dequantize(input.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + InputType * data_ptr = rowwise + ? input.rowwise_cpu_dptr() + : input.columnwise_cpu_dptr(); + + compute_ref_x1(data_ptr, + ref_output.get(), + scales.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), true, atol, rtol); +} + +// Dequantize along single dimension (either row- or columnwise) +template +void performTest_quantize_then_dequantize(const size_t rows, + const size_t cols, + const bool rowwise, + const bool colwise) +{ + using namespace test; + using EncodingType = fp32; + DType in_type = TypeInfo::dtype; + DType intermed_type = TypeInfo::dtype; + DType out_type = TypeInfo::dtype; + + std::unique_ptr input_cpu = std::make_unique(rows * cols); + std::unique_ptr quantized_cpu = std::make_unique(rows * cols); + std::unique_ptr output_cpu = std::make_unique(rows * cols); + + // input --> quantized --> output (dequantized) + // input == output + Tensor input("input", { rows, cols }, in_type); + Tensor quantized("quantized", { rows, cols }, intermed_type, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + + // Output data are written to the rowwise ptr regardless of the scaling direction + Tensor output("output", { rows, cols }, out_type, true, false); + + // fillCase(&input, InputsFillCase::minNorm_to_maxNorm); + fillCase(&input, InputsFillCase::uniform); + + const size_t copy_size = sizeof(InputType) * rows * cols; + cudaMemcpy(input_cpu.get(), input.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost); + + nvte_quantize(input.data(), quantized.data(), 0); + cudaDeviceSynchronize(); + + const size_t copy_size_quantized = sizeof(IntermediateType) * rows * cols; + if (rowwise) { + cudaMemcpy(quantized_cpu.get(), quantized.rowwise_dptr(), copy_size_quantized, cudaMemcpyDeviceToHost); + } + if (colwise) { + cudaMemcpy(quantized_cpu.get(), quantized.columnwise_dptr(), copy_size_quantized, cudaMemcpyDeviceToHost); + } + + nvte_dequantize(quantized.data(), output.data(), 0); + cudaDeviceSynchronize(); + + cudaMemcpy(output_cpu.get(), output.rowwise_dptr(), copy_size, cudaMemcpyDeviceToHost); + + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol, rtol] = getTolerances(intermed_type); + compareResults("Quantize-Dequantize", input, output_cpu.get(), true, atol, rtol); +} + +// Dequantize along both dimensions (row- and columnwise) +template +void performTest_x2(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols) +{ + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t unpadded_blocks_Y_rowwise = rows; + const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols); + const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X_colwise = cols; + + const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise, + scale_tensor_alignment_Y_rowwise); + const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise, + scale_tensor_alignment_X_rowwise); + const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise, + scale_tensor_alignment_Y_colwise); + const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise, + scale_tensor_alignment_X_colwise); + + const size_t scales_stride_rowwise = blocks_X_rowwise; + const size_t scales_stride_colwise = blocks_X_colwise; + const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise; + const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise; + + Tensor input("input", { rows, cols }, itype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor output("output", { rows, cols }, otype); + + std::unique_ptr ref_output_rowwise = std::make_unique(rows * cols); + std::unique_ptr ref_output_colwise = std::make_unique(rows * cols); + std::unique_ptr ref_scales_rowwise = std::make_unique(blocks_num_rowwise); + std::unique_ptr ref_scales_colwise = std::make_unique(blocks_num_colwise); + + constexpr bool rowwise = true; + constexpr bool colwise = true; + fill_tensor_data(input, ref_scales_rowwise.get(), ref_scales_colwise.get(), + rowwise, colwise, rows, cols, blocks_num_rowwise, blocks_num_colwise); + + nvte_dequantize(input.data(), output.data(), 0); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + compute_ref_x2(input.rowwise_cpu_dptr(), + ref_output_rowwise.get(), + ref_output_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + rows, + cols, + block_size_rows, + block_size_cols, + scales_stride_rowwise, + scales_stride_colwise); + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); + compareResults("output_colwise", output, ref_output_colwise.get(), false, atol, rtol); +} + +std::vector> tensor_dims = { + {1, 16}, + {16, 48}, + {65, 96}, + {128, 128}, + {256, 256}, + {993, 512}, + {768, 1024}, + // {2048, 12288}, + // {65536, 128}, + // {16384, 1632}, + // {16384, 6144}, +}; + +std::vector> block_sizes = { + {1, 32}, + {32, 1}, + // {32, 32}, +}; + +} // namespace + +class DequantizeMXFP8TestSuite : public ::testing::TestWithParam + , + std::pair, + transformer_engine::DType, + transformer_engine::DType, + bool>> {}; + +TEST_P(DequantizeMXFP8TestSuite, TestDequantizeMXFP8) +{ + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const auto tensor_size = std::get<0>(GetParam()); + const auto block_size = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const bool quantize_then_dequantize = std::get<4>(GetParam()); + + const bool rowwise = block_size.second != 1; + const bool colwise = block_size.first != 1; + + // Skip tests for dequantization along both dimensions + if (rowwise && colwise) { + GTEST_SKIP(); + } + + // Skip cases with invalid alignment + if (rowwise && tensor_size.second % 32 != 0) { + GTEST_SKIP(); + } + if (colwise && tensor_size.first % 32 != 0) { + GTEST_SKIP(); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, + if (quantize_then_dequantize) { + // Mind the order of the Output/Input template parameters + performTest_quantize_then_dequantize( + tensor_size.first, tensor_size.second, rowwise, colwise); + } else { + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1(tensor_size.first, tensor_size.second, + rowwise, colwise); + } else { + performTest_x2(tensor_size.first, tensor_size.second, + block_size.first, block_size.second); + } + } + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + DequantizeMXFP8TestSuite, + ::testing::Combine( + ::testing::ValuesIn(tensor_dims), + ::testing::ValuesIn(block_sizes), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(false)), + [](const testing::TestParamInfo& info) + { + std::string name = std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "X" + + std::to_string(std::get<1>(info.param).first) + "X" + + std::to_string(std::get<1>(info.param).second) + "X" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + (std::get<4>(info.param) ? "QD" : "D"); + return name; + } +); diff --git a/tests/cpp/operator/test_layernorm.cu b/tests/cpp/operator/test_layernorm.cu deleted file mode 100644 index cdd8e7846c..0000000000 --- a/tests/cpp/operator/test_layernorm.cu +++ /dev/null @@ -1,302 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include "../test_common.h" - -using namespace transformer_engine; -using namespace test; - -namespace { - -template -void compute_ref_stats(const InputType *data, float *mu, float *rsigma, - const size_t N, const size_t H, const double epsilon) { - using compute_t = float; - for (size_t i = 0 ; i < N; ++i) { - compute_t sum = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum += current; - } - mu[i] = sum / H; - compute_t m = mu[i]; - sum = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum += (current - m) * (current - m); - } - sum = sum / H; - compute_t rs = rsqrtf(sum + epsilon); - rsigma[i] = rs; - } -} - -template -void compute_ref_output(const InputType *data, const InputType *gamma, const InputType *beta, - OutputType *output, const float *mu, const float *rsigma, - const size_t N, const size_t H, - float *amax, float scale, const bool zero_centered_gamma) { - using compute_t = float; - compute_t current_max = -1e100; - for (size_t i = 0 ; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - compute_t tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); - output[i * H + j] = static_cast(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); - } - } - *amax = current_max; -} - -template -void compute_ref_backward(const OutputType *output_grad, const InputType *data, - const float *mu, const float *rsigma, - const InputType *gamma, - InputType *data_grad, - InputType *gamma_grad, InputType *beta_grad, - const size_t N, const size_t H, - const bool zero_centered_gamma) { - using compute_t = float; - std::vector dgamma(H, 0.f); - std::vector dbeta(H, 0.f); - - for (size_t i = 0 ; i < N; ++i) { - // Reductions - compute_t mdy = 0, mdyy = 0; - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = (x - mu[i]) * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - dgamma[j] += y * dz; - dbeta[j] += dz; - mdy += dy; - mdyy += dy * y; - } - mdy /= H; - mdyy /= H; - - // Input grads - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = (x - mu[i]) * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); - data_grad[i * H + j] = static_cast(dx); - } - } - - // Weight grads - for (size_t j = 0; j < H; ++j) { - gamma_grad[j] = static_cast(dgamma[j]); - beta_grad[j] = static_cast(dbeta[j]); - } -} - -template -void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) { - if (sizeof(InputType) < sizeof(OutputType)) { - GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; - return; - } - using WeightType = InputType; - DType itype = TypeInfo::dtype; - DType wtype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || - (itype == DType::kFloat16 && otype == DType::kBFloat16)) { - GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16"; - return; - } - - Tensor input({ N, H }, itype); - Tensor z({ N, H }, otype); - Tensor gamma({ H }, wtype); - Tensor beta({ H }, wtype); - Tensor mu({ N }, DType::kFloat32); - Tensor rsigma({ N }, DType::kFloat32); - Tensor dz({ N, H }, wtype); - Tensor dx({ N, H }, itype); - Tensor dgamma({ H }, wtype); - Tensor dbeta({ H }, wtype); - Tensor workspace, barrier, dgamma_part, dbeta_part; - - fillUniform(&input); - fillUniform(&gamma); - fillUniform(&beta); - setRandomScale(&z); - fillUniform(&dz); - - std::unique_ptr ref_output = std::make_unique(N * H); - std::unique_ptr ref_mu = std::make_unique(N); - std::unique_ptr ref_rsigma = std::make_unique(N); - std::unique_ptr ref_dx = std::make_unique(N * H); - std::unique_ptr ref_dgamma = std::make_unique(H); - std::unique_ptr ref_dbeta = std::make_unique(H); - - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - - // Forward kernel - float epsilon = 1e-5; - auto fwd_function = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - fwd_function(input.data(), gamma.data(), beta.data(), epsilon, - z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - fwd_function(input.data(), gamma.data(), beta.data(), epsilon, - z.data(), mu.data(), rsigma.data(), 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - - // Backward kernel - auto bwd_function = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_function(dz.data(), input.data(), - mu.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), dbeta.data(), - dgamma_part.data(), dbeta_part.data(), - 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype()); - dbeta_part = Tensor(dbeta_part.shape(), dbeta_part.dtype()); - bwd_function(dz.data(), input.data(), - mu.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), dbeta.data(), - dgamma_part.data(), dbeta_part.data(), - 0, prop.multiProcessorCount, - workspace.data(), barrier.data()); - - // Reference implementations - // use the GPU stats to tighten the tolerances - mu.to_cpu(); - rsigma.to_cpu(); - float ref_amax; - compute_ref_stats(input.cpu_dptr(), ref_mu.get(), - ref_rsigma.get(), N, H, epsilon); - float ref_scale = isFp8Type(otype) ? z.scale() : 1.f; - compute_ref_output(input.cpu_dptr(), - gamma.cpu_dptr(), - beta.cpu_dptr(), - ref_output.get(), - mu.cpu_dptr(), - rsigma.cpu_dptr(), - N, H, - &ref_amax, - ref_scale, - zero_centered_gamma); - compute_ref_backward(dz.cpu_dptr(), input.cpu_dptr(), - mu.cpu_dptr(), rsigma.cpu_dptr(), - gamma.cpu_dptr(), - ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), - N, H, zero_centered_gamma); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - if (isFp8Type(otype)) { - compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / z.scale(); - compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); - } - - auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); - rtol_stats = 5e-5; - compareResults("mu", mu, ref_mu.get(), atol_stats, rtol_stats); - compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats); - - auto [atol, rtol] = getTolerances(otype); - if (otype == DType::kFloat32) { - atol = 5e-7; - } - compareResults("output", z, ref_output.get(), atol, rtol); - - double atol_bwd = 1e-4; - double rtol_bwd = 1e-4; - compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); - compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); - compareResults("dbeta", dbeta, ref_dbeta.get(), atol_bwd, rtol_bwd); -} - -std::vector> test_cases = {{2048, 12288}, - {768, 1024}, - {256, 65536}, - {128, 6144}, - {64, 2304}, - {229, 541}, // Primes 50, 100 - {71, 3571}, // Primes 20, 500 - {29, 17389}}; // Primes 10, 2000 - -} // namespace - -class LNTestSuite : public ::testing::TestWithParam, - bool>> {}; - -TEST_P(LNTestSuite, TestLN) { - using namespace transformer_engine; - using namespace test; - - const DType input_type = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - const auto size = std::get<2>(GetParam()); - const bool zero_centered_gamma = std::get<3>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma); - ); - ); -} - -INSTANTIATE_TEST_SUITE_P( - OperatorTest, - LNTestSuite, - ::testing::Combine( - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), - ::testing::ValuesIn(test_cases), - ::testing::Values(false, true)), - [](const testing::TestParamInfo& info) { - std::string name = test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second) + "X" + - std::to_string(std::get<3>(info.param)); - return name; - }); diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu index e7fb183217..f07138caca 100644 --- a/tests/cpp/operator/test_multi_cast_transpose.cu +++ b/tests/cpp/operator/test_multi_cast_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -69,7 +69,7 @@ void performTest() { const size_t num_tensors = tensor_dims.size(); // Buffers for Transformer Engine implementation - std::vector input_list, output_c_list, output_t_list; + std::vector input_list, output_list; // Buffers for reference implementation std::vector> ref_input_list; @@ -81,25 +81,23 @@ void performTest() { for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { const size_t height = tensor_dims[tensor_id].first; const size_t width = tensor_dims[tensor_id].second; - input_list.emplace_back(Tensor({ height, width }, itype)); - output_c_list.emplace_back(Tensor({ height, width }, otype)); - output_t_list.emplace_back(Tensor({ width, height }, otype)); + input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype)); + output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id), + { height, width }, otype, true, true)); auto& input = input_list.back(); - auto& output_c = output_c_list.back(); - auto& output_t = output_t_list.back(); + auto& output = output_list.back(); fillUniform(&input); - setRandomScale(&output_c); - output_t.shareFP8Meta(output_c); + setRandomScale(&output); ref_input_list.emplace_back(height*width); ref_output_c_list.emplace_back(height*width); ref_output_t_list.emplace_back(width*height); - std::copy(input.cpu_dptr(), - input.cpu_dptr() + height * width, + std::copy(input.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr() + height * width, ref_input_list.back().begin()); - ref_scale_list[tensor_id] = output_c.scale(); + ref_scale_list[tensor_id] = output.scale(); ref_height_list[tensor_id] = height; ref_width_list[tensor_id] = width; } @@ -115,8 +113,7 @@ void performTest() { }; nvte_multi_cast_transpose(num_tensors, make_nvte_vector(input_list).data(), - make_nvte_vector(output_c_list).data(), - make_nvte_vector(output_t_list).data(), + make_nvte_vector(output_list).data(), 0); // Reference implementation @@ -136,23 +133,23 @@ void performTest() { if (isFp8Type(otype)) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", - output_c_list[tensor_id].amax(), + output_list[tensor_id].amax(), ref_amax_list[tensor_id], atol_amax, rtol_amax); compareResults("scale_inv", - output_c_list[tensor_id].scale_inv(), - 1.f / output_c_list[tensor_id].scale(), + output_list[tensor_id].rowwise_scale_inv(), + 1.f / output_list[tensor_id].scale(), atol_amax, rtol_amax); } auto [atol, rtol] = getTolerances(otype); compareResults("output_c", - output_c_list[tensor_id], + output_list[tensor_id], ref_output_c_list[tensor_id].data(), - atol, rtol); + true, atol, rtol); compareResults("output_t", - output_t_list[tensor_id], + output_list[tensor_id], ref_output_t_list[tensor_id].data(), - atol, rtol); + false, atol, rtol); } } diff --git a/tests/cpp/operator/test_multi_padding.cu b/tests/cpp/operator/test_multi_padding.cu index e9e42725fe..b8475fe561 100644 --- a/tests/cpp/operator/test_multi_padding.cu +++ b/tests/cpp/operator/test_multi_padding.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -84,8 +85,8 @@ void performTest() { const size_t height = tensor_dims[tensor_id].first; const size_t width = tensor_dims[tensor_id].second; const size_t padded_height = (height + align - 1) / align * align; - input_list.emplace_back(Tensor({ height, width }, itype)); - output_list.emplace_back(Tensor({ padded_height, width }, otype)); + input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype)); + output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id), { padded_height, width }, otype)); auto& input = input_list.back(); auto& output = output_list.back(); @@ -95,8 +96,8 @@ void performTest() { ref_input_list.emplace_back(height*width); ref_output_list.emplace_back(padded_height*width); - std::copy(input.cpu_dptr(), - input.cpu_dptr() + height * width, + std::copy(input.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr() + height * width, ref_input_list.back().begin()); ref_height_list[tensor_id] = height; ref_width_list[tensor_id] = width; @@ -134,6 +135,7 @@ void performTest() { compareResults("output", output_list[tensor_id], ref_output_list[tensor_id].data(), + true, atol, rtol); } } diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu new file mode 100644 index 0000000000..0004c2ce74 --- /dev/null +++ b/tests/cpp/operator/test_normalization.cu @@ -0,0 +1,385 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum NormType { + LayerNorm, + RMSNorm +}; + +std::map normToString = { + {NormType::LayerNorm, "LayerNorm"}, + {NormType::RMSNorm, "RmsNorm"} +}; + +template +void compute_ref_stats(NormType norm_type, + const InputType *data, float *mu, float *rsigma, + const size_t N, const size_t H, const double epsilon){ + using compute_t = float; + compute_t current, m; + for (size_t i = 0; i < N; ++i) { + compute_t sum = 0; + for (size_t j = 0; j < H; ++j) { + sum += static_cast(data[i * H + j]); + } + if (norm_type == LayerNorm){ + mu[i] = sum / H; + m = mu[i]; + } else { m = 0;} + + compute_t sum_sq = 0; + for (size_t j = 0; j < H; ++j) { + current = static_cast(data[i * H + j]); + sum_sq += (current - m) * (current - m); + } + rsigma[i] = rsqrtf((sum_sq / H) + epsilon); + } +} + +// For now, cudnn does static_cast(gamma + static_cast(1.0)) +// This will be changed in the future release +template +inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const bool use_cudnn){ + + using compute_t = float; + if constexpr (std::is_same_v || std::is_same_v){ + compute_t g = static_cast(gamma); + if (zero_centered_gamma) { + g += static_cast(1.f); + } + return g; + } else { + if (use_cudnn){ + compute_t g = static_cast(0.f); + InputType gi = gamma; + if (zero_centered_gamma) { + gi = gi + static_cast(1.f); + } + g = static_cast(gi); + return g; + } else { + compute_t g = static_cast(gamma); + if (zero_centered_gamma) { + g += static_cast(1.f); + } + return g; + } + } +} + +template +void compute_ref_output(NormType norm_type, + const InputType *data, const InputType *gamma, const InputType *beta, + OutputType* output, + const float *mu, const float *rsigma, + const size_t N, const size_t H, + float *amax, float scale, const bool zero_centered_gamma, const bool use_cudnn) { + using compute_t = float; + compute_t current_max = -1e100; + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); + + compute_t tmp; + if (norm_type == LayerNorm) { + tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); + } else { // RMSNorm + tmp = current * rsigma[i] * g; + } + + output[i * H + j] = static_cast(tmp * scale); + current_max = fmaxf(current_max, fabsf(tmp)); + } + } + *amax = current_max; +} + + +template +void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data, + const float *mu, const float *rsigma, + const InputType *gamma, + InputType *data_grad, + InputType *gamma_grad, InputType *beta_grad, + const size_t N, const size_t H, + const bool zero_centered_gamma, const bool use_cudnn) { + using compute_t = float; + std::vector dgamma(H, 0.f); + std::vector dbeta(H, 0.f); + + for (size_t i = 0 ; i < N; ++i) { + // Reductions + auto local_mu = (norm_type == LayerNorm) ? mu[i] : 0.; + compute_t mdy = 0, mdyy = 0; + for (size_t j = 0; j < H; ++j) { + const compute_t x = static_cast(data[i * H + j]); + const compute_t y = (x - local_mu) * rsigma[i]; + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); + const compute_t dz = static_cast(output_grad[i * H + j]); + const compute_t dy = g * dz; + dgamma[j] += y * dz; + if (norm_type == LayerNorm) { + dbeta[j] += dz; + mdy += dy; + } + mdyy += dy * y; + } + mdy /= H; + mdyy /= H; + + // Input grads + for (size_t j = 0; j < H; ++j) { + const compute_t x = static_cast(data[i * H + j]); + const compute_t y = (x - local_mu) * rsigma[i]; + compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn); + const compute_t dz = static_cast(output_grad[i * H + j]); + const compute_t dy = g * dz; + const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); + data_grad[i * H + j] = static_cast(dx); + } + } + + // Weight grads + for (size_t j = 0; j < H; ++j) gamma_grad[j] = static_cast(dgamma[j]); + if (norm_type == LayerNorm) for (size_t j = 0; j < H; ++j) beta_grad[j] = static_cast(dbeta[j]); +} + +template +void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, + NormType norm_type, bool use_cudnn) { + if (sizeof(InputType) < sizeof(OutputType)) { + GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; + return; + } + + if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) { + GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!"; + } + + using WeightType = InputType; + DType itype = TypeInfo::dtype; + DType wtype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || + (itype == DType::kFloat16 && otype == DType::kBFloat16)) { + GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16"; + return; + } + + Tensor input("input", { N, H }, itype); + Tensor z("z", { N, H }, otype); + Tensor gamma("gamma", { H }, wtype); + Tensor beta("beta", { H }, wtype); + Tensor mu("mu", { N }, DType::kFloat32); + Tensor rsigma("rsigma", { N }, DType::kFloat32); + Tensor dz("dz", { N, H }, wtype); + Tensor dx("dx", { N, H }, itype); + Tensor dgamma("dgamma", { H }, wtype); + Tensor dbeta("dbeta", { H }, wtype); + Tensor workspace_fwd, workspace_bwd; + + fillUniform(&input); + fillUniform(&gamma); + fillUniform(&beta); + setRandomScale(&z); + fillUniform(&dz); + + std::unique_ptr ref_output = std::make_unique(N * H); + std::unique_ptr ref_mu = std::make_unique(N); + std::unique_ptr ref_rsigma = std::make_unique(N); + std::unique_ptr ref_dx = std::make_unique(N * H); + std::unique_ptr ref_dgamma = std::make_unique(H); + std::unique_ptr ref_dbeta = std::make_unique(H); + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + + if (use_cudnn){ + nvte_enable_cudnn_norm_fwd(true); + nvte_enable_cudnn_norm_bwd(true); + } + + // Forward kernel + float epsilon = 1e-5; + if (norm_type == LayerNorm){ + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + + nvte_layernorm_bwd(dz.data(), input.data(), + mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); + nvte_layernorm_bwd(dz.data(), input.data(), + mu.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), dbeta.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace_fwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), + workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + } + + if (use_cudnn){ + nvte_enable_cudnn_norm_fwd(false); + nvte_enable_cudnn_norm_bwd(false); + } + + // Reference implementations + // use the GPU stats to tighten the tolerances + mu.to_cpu(); + rsigma.to_cpu(); + float ref_amax; + compute_ref_stats(norm_type, input.rowwise_cpu_dptr(), ref_mu.get(), + ref_rsigma.get(), N, H, epsilon); + float ref_scale = isFp8Type(otype) ? z.scale() : 1.f; + compute_ref_output(norm_type, input.rowwise_cpu_dptr(), + gamma.rowwise_cpu_dptr(), + beta.rowwise_cpu_dptr(), + ref_output.get(), + mu.rowwise_cpu_dptr(), + rsigma.rowwise_cpu_dptr(), + N, H, + &ref_amax, + ref_scale, + zero_centered_gamma, + use_cudnn); + compute_ref_backward(norm_type, dz.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + mu.rowwise_cpu_dptr(), rsigma.rowwise_cpu_dptr(), + gamma.rowwise_cpu_dptr(), + ref_dx.get(), ref_dgamma.get(), ref_dbeta.get(), + N, H, zero_centered_gamma, + use_cudnn); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + if (isFp8Type(otype)) { + compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); + float ref_scale_inv = 1.f / z.scale(); + compareResults("scale_inv", z.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax); + } + + auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); + rtol_stats = 5e-5; + compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats); + compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats); + + auto [atol, rtol] = getTolerances(otype); + if (otype == DType::kFloat32) { + atol = 5e-7; + } + compareResults("output", z, ref_output.get(), true, atol, rtol); + + double atol_bwd = 5e-4; + double rtol_bwd = 5e-4; + compareResults("dx", dx, ref_dx.get(), true, atol_bwd, rtol_bwd); + compareResults("dgamma", dgamma, ref_dgamma.get(), true, atol_bwd, rtol_bwd); + compareResults("dbeta", dbeta, ref_dbeta.get(), true, atol_bwd, rtol_bwd); +} + +std::vector> test_cases = { + {71, 229}, + {29, 541}, + {768, 6144}, + {2048, 12288}, +}; + +} // namespace + +class NormTestSuite : public ::testing::TestWithParam, + bool>> {}; + +TEST_P(NormTestSuite, TestNorm) { + using namespace transformer_engine; + using namespace test; + + const bool use_cudnn = std::get<0>(GetParam()); + const NormType norm_type = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const auto size = std::get<4>(GetParam()); + const bool zero_centered_gamma = std::get<5>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + NormTestSuite, + ::testing::Combine( + ::testing::Values(true, false), + ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), + ::testing::ValuesIn(test_cases), + ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { + auto backend = std::get<0>(info.param) == false ? "Te" : "Cudnn"; + std::string name = + backend + + normToString.at(std::get<1>(info.param)) + "_" + + test::typeName(std::get<2>(info.param)) + "X" + + test::typeName(std::get<3>(info.param)) + "X" + + std::to_string(std::get<4>(info.param).first) + "X" + + std::to_string(std::get<4>(info.param).second) + "X" + + std::to_string(std::get<5>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu new file mode 100644 index 0000000000..191c62835b --- /dev/null +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -0,0 +1,340 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include "../test_common.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +using fp8e8m0 = byte; + +enum NormType { + LayerNorm, + RMSNorm +}; + +std::map normToString = { + {NormType::LayerNorm, "LayerNorm"}, + {NormType::RMSNorm, "RMSNorm"} +}; + +template +void dequantize_1x_kernel(InputType* input_ptr, ScaleType* scale_ptr, OutputType* output_ptr, + size_t rows, size_t cols, size_t scaling_mode_x, size_t scaling_mode_y){ + + const size_t block_size_Y = scaling_mode_x; // mind the mapping Y <-- x + const size_t block_size_X = scaling_mode_y; // and X <-- y + const size_t tile_size_Y = std::max(32lu, block_size_Y); + const size_t tile_size_X = std::max(64lu, block_size_X); + const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; + const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; + const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; + const size_t blocks_per_tile_X = tile_size_X / block_size_X; + const size_t blocks_per_row = (cols + block_size_X - 1) / block_size_X; + + #pragma omp parallel for proc_bind(spread) schedule(static) + for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { + const size_t tile_Y = t / tiles_num_X; + const size_t tile_X = t % tiles_num_X; + const size_t tile_offset_Y = tile_Y * tile_size_Y; + const size_t tile_offset_X = tile_X * tile_size_X; + + for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { + const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; + const size_t block_offset_Y = ii * block_size_Y; + const size_t i_min = tile_offset_Y + block_offset_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + + for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { + const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; + const size_t block_offset_X = jj * block_size_X; + const size_t j_min = tile_offset_X + block_offset_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + const size_t mx_scale_idx = block_idx_Y * blocks_per_row + block_idx_X; + + // TODO: padded SFs i.e. (4,128) + const float scale_inv = exp2f(static_cast(scale_ptr[mx_scale_idx]) - FP32_EXPONENT_BIAS); + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float elem = static_cast(input_ptr[idx]); + output_ptr[idx] = static_cast(elem * scale_inv); + } + } + } + } + } +} + +template +void dequantize_2x(Tensor& input, Tensor& output, bool is_training) +{ + input.to_cpu(); + auto scaling_mode = input.scaling_mode(); + assert(input.rowwise_shape().ndim == 2); + + if (is_training) { + assert(input.columnwise_shape().ndim == 2); + } + + dequantize_1x_kernel(input.rowwise_cpu_dptr(), + input.rowwise_cpu_scale_inv_ptr(), + output.rowwise_cpu_dptr(), + input.rowwise_shape().data[0], input.rowwise_shape().data[1], + 1, 32); + if (is_training) + dequantize_1x_kernel(input.columnwise_cpu_dptr(), + input.columnwise_cpu_scale_inv_ptr(), + output.columnwise_cpu_dptr(), + input.columnwise_shape().data[0], input.columnwise_shape().data[1], + 32, 1); +} + +template +void compute_ref_stats(NormType norm_type, + const InputType *data, float *mu, float *rsigma, + const size_t N, const size_t H, const double epsilon){ + using compute_t = float; + + #pragma omp parallel for proc_bind(spread) + for (size_t i = 0; i < N; ++i) { + compute_t sum = 0; + for (size_t j = 0; j < H; ++j) { + sum += static_cast(data[i * H + j]); + } + compute_t m; + if (norm_type == LayerNorm){ + mu[i] = sum / H; + m = mu[i]; + } else { m = 0;} + + compute_t sum_sq = 0; + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + sum_sq += (current - m) * (current - m); + } + rsigma[i] = rsqrtf((sum_sq / H) + epsilon); + } +} + +template +void compute_ref_output(NormType norm_type, + const InputType *data, const InputType *gamma, const InputType *beta, + const float *mu, const float *rsigma, + const size_t N, const size_t H, + OutputType* output, + const bool zero_centered_gamma){ + using compute_t = float; + + #pragma omp parallel for proc_bind(spread) + for (size_t i = 0; i < N; ++i) { + for (size_t j = 0; j < H; ++j) { + compute_t current = static_cast(data[i * H + j]); + compute_t g = static_cast(gamma[j]); + if (zero_centered_gamma) { + g += 1.0; + } + + compute_t tmp; + if (norm_type == LayerNorm) { + tmp = (current - mu[i]) * rsigma[i] * g + static_cast(beta[j]); + } else { // RMSNorm + tmp = current * rsigma[i] * g; + } + + output[i * H + j] = tmp; + } + } +} + +template +void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, NormType norm_type, bool is_training) { + + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using WeightType = InputType; + DType itype = TypeInfo::dtype; + DType wtype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + Tensor input("input", { N, H }, itype); + Tensor z("z", { N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING); + Tensor gamma("gamma", { H }, wtype); + Tensor beta("beta", { H }, wtype); + Tensor mu("mu", { N }, DType::kFloat32); + Tensor rsigma("rsigma", { N }, DType::kFloat32); + Tensor workspace; + + + fillUniform(&input); + fillUniform(&gamma); + fillUniform(&beta); + + // Forward kernel + float epsilon = 1e-5; + if (norm_type == NormType::LayerNorm){ + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, + z.data(), mu.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + } else { + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); + nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, + z.data(), rsigma.data(), workspace.data(), + prop.multiProcessorCount, zero_centered_gamma, + 0); + } + + Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true); + + dequantize_2x(z, dequantized_output, is_training); + + // Reference implementations + std::unique_ptr ref_mu = std::make_unique(N); + std::unique_ptr ref_rsigma = std::make_unique(N); + std::unique_ptr ref_output = std::make_unique(N * H); + + + compute_ref_stats(norm_type, input.rowwise_cpu_dptr(), ref_mu.get(), + ref_rsigma.get(), N, H, epsilon); + // use the GPU stats to tighten the tolerances + float *ref_mu_ptr, *ref_rsigma_ptr; + if (is_training){ + mu.to_cpu(); + rsigma.to_cpu(); + ref_mu_ptr = mu.rowwise_cpu_dptr(); + ref_rsigma_ptr = rsigma.rowwise_cpu_dptr(); + } else { + ref_mu_ptr = ref_mu.get(); + ref_rsigma_ptr = ref_rsigma.get(); + } + compute_ref_output(norm_type, input.rowwise_cpu_dptr(), + gamma.rowwise_cpu_dptr(), + beta.rowwise_cpu_dptr(), + ref_mu_ptr, + ref_rsigma_ptr, + N, H, + ref_output.get(), + zero_centered_gamma); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); + rtol_stats = 5e-5; + if (is_training){ + compareResults("mu", mu, ref_mu.get(), true, atol_stats, rtol_stats); + compareResults("rsigma", rsigma, ref_rsigma.get(), true, atol_stats, rtol_stats); + } + + float atol, rtol; + if (otype == DType::kFloat8E5M2){ + atol = 1.25e-1; + rtol = 1.25e-1; + } else if (otype == DType::kFloat8E4M3){ + if (itype == DType::kBFloat16){ + atol = 7e-2; + rtol = 7e-2; + } else { + atol = 6.25e-2; + rtol = 6.25e-2; + } + } + compareResults("output_rowwise", dequantized_output, ref_output.get(), true, atol, rtol, false); + if (is_training) + compareResults("output_colwise", dequantized_output, ref_output.get(), false, atol, rtol, false); +} + +std::vector> test_cases = { + {32, 32}, + {768, 2304}, + {2048, 12288}, +}; + +std::vector norms = { + NormType::LayerNorm, + NormType::RMSNorm +}; + +} // namespace + +class MxNormTestSuite : public ::testing::TestWithParam< std::tuple, + bool, bool>> {}; + +TEST_P(MxNormTestSuite, TestMxNorm) { + using namespace transformer_engine; + using namespace test; + + const NormType norm_type = std::get<0>(GetParam()); + const DType input_type = std::get<1>(GetParam()); + const DType output_type = std::get<2>(GetParam()); + const auto size = std::get<3>(GetParam()); + const bool zero_centered_gamma = std::get<4>(GetParam()); + const bool is_training = std::get<5>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + performTest(size.first, size.second, zero_centered_gamma, norm_type, is_training); + ); + ); +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + MxNormTestSuite, + ::testing::Combine( + ::testing::Values(NormType::LayerNorm, NormType::RMSNorm), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3), + ::testing::ValuesIn(test_cases), + ::testing::Values(true, false), + ::testing::Values(true, false)), + [](const testing::TestParamInfo& info) { + std::string name = normToString.at(std::get<0>(info.param)) + "_" + + test::typeName(std::get<1>(info.param)) + "X" + + test::typeName(std::get<2>(info.param)) + "X" + + std::to_string(std::get<3>(info.param).first) + "X" + + std::to_string(std::get<3>(info.param).second) + "X" + + std::to_string(std::get<4>(info.param)) + "out" + + std::to_string(int(std::get<5>(info.param)) + 1) + "x"; + return name; + }); diff --git a/tests/cpp/operator/test_qdq.cu b/tests/cpp/operator/test_qdq.cu index 565e3986e6..3c12cef865 100644 --- a/tests/cpp/operator/test_qdq.cu +++ b/tests/cpp/operator/test_qdq.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -58,18 +58,18 @@ void performTestQ(const size_t N) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N }, itype); - Tensor output({ N }, otype); + Tensor input("input", { N }, itype); + Tensor output("output", { N }, otype); std::unique_ptr ref_output = std::make_unique(N); fillUniform(&input); setRandomScale(&output); - nvte_fp8_quantize(input.data(), output.data(), 0); + nvte_quantize(input.data(), output.data(), 0); float ref_amax; - compute_ref_q(input.cpu_dptr(), ref_output.get(), + compute_ref_q(input.rowwise_cpu_dptr(), ref_output.get(), N, &ref_amax, output.scale()); cudaDeviceSynchronize(); @@ -79,7 +79,7 @@ void performTestQ(const size_t N) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax); auto [atol, rtol] = getTolerances(otype); - compareResults("output_q", output, ref_output.get(), atol, rtol); + compareResults("output_q", output, ref_output.get(), true, atol, rtol); } template @@ -89,24 +89,24 @@ void performTestDQ(const size_t N) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N }, itype); - Tensor output({ N }, otype); + Tensor input("input", { N }, itype); + Tensor output("output", { N }, otype); std::unique_ptr ref_output = std::make_unique(N); fillUniform(&input); - nvte_fp8_dequantize(input.data(), output.data(), 0); + nvte_dequantize(input.data(), output.data(), 0); - compute_ref_dq(input.cpu_dptr(), ref_output.get(), - N, input.scale_inv()); + compute_ref_dq(input.rowwise_cpu_dptr(), ref_output.get(), + N, input.rowwise_scale_inv()); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); auto [atol, rtol] = getTolerances(otype); - compareResults("output_dq", output, ref_output.get(), atol, rtol); + compareResults("output_dq", output, ref_output.get(), true, atol, rtol); } std::vector qdq_test_cases = {2048* 12288, diff --git a/tests/cpp/operator/test_rmsnorm.cu b/tests/cpp/operator/test_rmsnorm.cu deleted file mode 100644 index 0ec3a877e5..0000000000 --- a/tests/cpp/operator/test_rmsnorm.cu +++ /dev/null @@ -1,249 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include "../test_common.h" - -using namespace transformer_engine; -using namespace test; - -namespace { - -template -void compute_ref_stats(const InputType *data, float *rsigma, const size_t N, const size_t H, - const double epsilon) { - using compute_t = float; - for (size_t i = 0; i < N; ++i) { - compute_t sum = 0; - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - sum += (current) * (current); - } - sum = sum / H; - compute_t rs = rsqrtf(sum + epsilon); - rsigma[i] = rs; - } -} - -template -void compute_ref_output(const InputType *data, const InputType *gamma, OutputType *output, - const float *rsigma, const size_t N, const size_t H, float *amax, - float scale, const bool zero_centered_gamma) { - using compute_t = float; - compute_t current_max = -1e100; - for (size_t i = 0; i < N; ++i) { - for (size_t j = 0; j < H; ++j) { - compute_t current = static_cast(data[i * H + j]); - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - compute_t tmp = current * rsigma[i] * g; - output[i * H + j] = static_cast(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); - } - } - *amax = current_max; -} - -template -void compute_ref_backward(const OutputType *output_grad, const InputType *data, const float *rsigma, - const InputType *gamma, InputType *data_grad, InputType *gamma_grad, - const size_t N, const size_t H, const bool zero_centered_gamma) { - using compute_t = float; - std::vector dgamma(H, 0.f); - - for (size_t i = 0; i < N; ++i) { - // Reductions - compute_t mdyy = 0; - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = x * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - dgamma[j] += y * dz; - mdyy += dy * y; - } - mdyy /= H; - - // Input grads - for (size_t j = 0; j < H; ++j) { - const compute_t x = static_cast(data[i * H + j]); - const compute_t y = x * rsigma[i]; - compute_t g = static_cast(gamma[j]); - if (zero_centered_gamma) { - g += 1; - } - const compute_t dz = static_cast(output_grad[i * H + j]); - const compute_t dy = g * dz; - const compute_t dx = rsigma[i] * (dy - mdyy * y); - data_grad[i * H + j] = static_cast(dx); - } - } - - // Weight grads - for (size_t j = 0; j < H; ++j) { - gamma_grad[j] = static_cast(dgamma[j]); - } -} - -template -void performTest(const size_t N, const size_t H, const bool zero_centered_gamma) { - if (sizeof(InputType) < sizeof(OutputType)) { - GTEST_SKIP() << "RMSNorm kernel does not support OutputType > InputType"; - return; - } - using WeightType = InputType; - DType itype = TypeInfo::dtype; - DType wtype = TypeInfo::dtype; - DType otype = TypeInfo::dtype; - - if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || - (itype == DType::kFloat16 && otype == DType::kBFloat16)) { - GTEST_SKIP() << "RMSNorm kernel does not support mixing Float16 and BFloat16"; - return; - } - - Tensor input({N, H}, itype); - Tensor z({N, H}, otype); - Tensor gamma({H}, wtype); - Tensor rsigma({N}, DType::kFloat32); - Tensor dz({N, H}, wtype); - Tensor dx({N, H}, itype); - Tensor dgamma({H}, wtype); - Tensor workspace, barrier, dgamma_part; - - fillUniform(&input); - fillUniform(&gamma); - fillUniform(&dz); - setRandomScale(&z); - - std::unique_ptr ref_output = std::make_unique(N * H); - std::unique_ptr ref_rsigma = std::make_unique(N); - std::unique_ptr ref_dx = std::make_unique(N * H); - std::unique_ptr ref_dgamma = std::make_unique(H); - - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - - // Forward kernel - float epsilon = 1e-5; - auto fwd_function = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd; - fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0, - prop.multiProcessorCount, workspace.data(), barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - fwd_function(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), 0, - prop.multiProcessorCount, workspace.data(), barrier.data()); - - // Backward kernel - auto bwd_function = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd; - bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), - dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(), - barrier.data()); - workspace = Tensor(workspace.shape(), workspace.dtype()); - barrier = Tensor(barrier.shape(), barrier.dtype()); - dgamma_part = Tensor(dgamma_part.shape(), dgamma_part.dtype()); - bwd_function(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), - dgamma_part.data(), 0, prop.multiProcessorCount, workspace.data(), - barrier.data()); - - // Reference implementations - // use the GPU stats to tighten the tolerances - rsigma.to_cpu(); - float ref_amax; - compute_ref_stats(input.cpu_dptr(), ref_rsigma.get(), N, H, epsilon); - float ref_scale = isFp8Type(otype) ? z.scale() : 1.f; - compute_ref_output(input.cpu_dptr(), gamma.cpu_dptr(), ref_output.get(), - rsigma.cpu_dptr(), N, H, &ref_amax, ref_scale, - zero_centered_gamma); - compute_ref_backward(dz.cpu_dptr(), input.cpu_dptr(), - rsigma.cpu_dptr(), gamma.cpu_dptr(), ref_dx.get(), - ref_dgamma.get(), N, H, zero_centered_gamma); - - cudaDeviceSynchronize(); - auto err = cudaGetLastError(); - ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - if (isFp8Type(otype)) { - compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax); - float ref_scale_inv = 1.f / z.scale(); - compareResults("scale_inv", z.scale_inv(), ref_scale_inv, atol_amax, rtol_amax); - } - - auto [atol_stats, rtol_stats] = getTolerances(DType::kFloat32); - rtol_stats = 5e-5; - compareResults("rsigma", rsigma, ref_rsigma.get(), atol_stats, rtol_stats); - - auto [atol, rtol] = getTolerances(otype); - atol = 1e-8; - compareResults("output", z, ref_output.get(), atol, rtol); - - double atol_bwd = 5e-6; - double rtol_bwd = 1e-4; - compareResults("dx", dx, ref_dx.get(), atol_bwd, rtol_bwd); - compareResults("dgamma", dgamma, ref_dgamma.get(), atol_bwd, rtol_bwd); -} - -std::vector> test_cases = { - {2048, 4096}, {768, 2048}, {256, 1024}, {128, 768}, {64, 512}, {173, 409}, // Primes 40, 80 - {71, 3571}, // Primes 20, 500 - {29, 17389}}; // Primes 10, 2000 - -} // namespace - -class RMSNormTestSuite : public ::testing::TestWithParam, - bool>> {}; - -TEST_P(RMSNormTestSuite, TestRMSNorm) { - using namespace transformer_engine; - using namespace test; - - const DType input_type = std::get<0>(GetParam()); - const DType output_type = std::get<1>(GetParam()); - const auto size = std::get<2>(GetParam()); - const bool zero_centered_gamma = std::get<3>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma););); -} - -INSTANTIATE_TEST_SUITE_P(OperatorTest, RMSNormTestSuite, - ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, - DType::kFloat16), - ::testing::Values(DType::kFloat32, DType::kBFloat16, - DType::kFloat16, DType::kFloat8E4M3), - ::testing::ValuesIn(test_cases), - ::testing::Values(false, true)), - [](const testing::TestParamInfo &info) { - std::string name = - test::typeName(std::get<0>(info.param)) + "X" + - test::typeName(std::get<1>(info.param)) + "X" + - std::to_string(std::get<2>(info.param).first) + "X" + - std::to_string(std::get<2>(info.param).second) + "X" + - std::to_string(std::get<3>(info.param)); - return name; - }); diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu new file mode 100644 index 0000000000..f6e0da057a --- /dev/null +++ b/tests/cpp/operator/test_swizzle.cu @@ -0,0 +1,165 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include + +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; + +constexpr int MAT_TILE_DIM_M = 128; +constexpr int MAT_TILE_DIM_K = 128; + +template +void compute_ref_swizzle(const uint8_t *h_input, uint8_t *h_output, + const size_t M, const size_t K) { + + constexpr int NEW_SF_TILE_DIM_M = SF_TILE_DIM_M / 4; + constexpr int NEW_SF_TILE_DIM_K = SF_TILE_DIM_K * 4; + constexpr int SF_TILE_SIZE = SF_TILE_DIM_M * SF_TILE_DIM_K; + + for (int m = 0; m < M; m++) { + for (int k = 0; k < K; k++) { + + int tile_id_m = m / SF_TILE_DIM_M; + int tile_id_k = k / SF_TILE_DIM_K; + int m_in_tile = m % SF_TILE_DIM_M; + int k_in_tile = k % SF_TILE_DIM_K; + + int row_in_new_tile = m_in_tile % NEW_SF_TILE_DIM_M; + int col_in_new_tile = m_in_tile / NEW_SF_TILE_DIM_M * SF_TILE_DIM_K + k_in_tile; + + int tile_output_ptr = tile_id_m * SF_TILE_DIM_M * K + tile_id_k * SF_TILE_SIZE; + int out_index = tile_output_ptr + row_in_new_tile * NEW_SF_TILE_DIM_K + col_in_new_tile; + if constexpr(row_scaling) + h_output[out_index] = h_input[k + m * K]; + else + h_output[out_index] = h_input[k * M + m]; + } + } +} + +void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool rowwise, bool columnwise, const bool transa) { + using namespace test; + + int SF_MODE_X, SF_MODE_Y; + if (rowwise) { + SF_MODE_X = 1; + SF_MODE_Y = 32; + } + if (columnwise) { + SF_MODE_X = 32; + SF_MODE_Y = 1; + } + + if ((rowwise && columnwise) || !(rowwise || columnwise)){ + GTEST_SKIP() << "TEST SKIPPED, The scaling mode " + std::to_string(SF_MODE_X) + "x" + + std::to_string(SF_MODE_Y) + "is not implemented."; + } + + DType dtype = DType::kFloat8E4M3; + + const size_t M = num_tiles_M * MAT_TILE_DIM_M; + const size_t K = num_tiles_K * MAT_TILE_DIM_K; + const auto data_shape = transa ? std::vector{M, K} : std::vector{K, M}; + + const auto scale_shape = std::vector{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; + + std::vector scaling_mode = {SF_MODE_X, SF_MODE_Y, 0}; + Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + + fillUniform(&input); + + std::unique_ptr ref_output = std::make_unique(scale_shape[0] * scale_shape[1]); + + nvte_swizzle_scaling_factors(input.data(), output.data(), 0); + + if (rowwise) + compute_ref_swizzle<128, 4, true>(input.rowwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0], scale_shape[1]); + else + compute_ref_swizzle<128, 4, false>(input.columnwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[1], scale_shape[0]); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + output.to_cpu(); + if (rowwise) { + compareResults("output_swizzle", output.rowwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0] * scale_shape[1]); + } else { + compareResults("output_swizzle", output.columnwise_cpu_scale_inv_ptr(), ref_output.get(), scale_shape[0] * scale_shape[1]); + } +} + +class SwizzleTestSuite : public ::testing::TestWithParam, std::pair, bool>> {}; + + +TEST_P(SwizzleTestSuite, TestSwizzle) { + using namespace transformer_engine; + using namespace test; + + const auto num_tiles = std::get<0>(GetParam()); + const auto scaling_mode = std::get<1>(GetParam()); + const auto transa = std::get<2>(GetParam()); + + performTestSwizzle1D(num_tiles.first, num_tiles.second, + scaling_mode.first, scaling_mode.second, + transa); +} + +namespace { + +std::vector> num_tiles = { + {1, 1}, + {1, 132}, + {132, 1}, + {65, 256}, + {65, 257}, + {65, 258}, + {65, 259}, +}; + +std::vector> scaling_mode = { + {true, false}, + {false, true} +}; + +std::vector transa = {true, false}; + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + SwizzleTestSuite, + ::testing::Combine( + ::testing::ValuesIn(num_tiles), + ::testing::ValuesIn(scaling_mode), + ::testing::ValuesIn(transa) + ), + [](const testing::TestParamInfo& info) { + std::string name = "ntiles" + + std::to_string(std::get<0>(info.param).first) + "X" + + std::to_string(std::get<0>(info.param).second) + "smode" + + std::to_string(std::get<1>(info.param).first) + "X"+ + std::to_string(std::get<1>(info.param).second) + "trans" + + std::to_string(std::get<2>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_transpose.cu b/tests/cpp/operator/test_transpose.cu index 844f6801f1..00dd241c92 100644 --- a/tests/cpp/operator/test_transpose.cu +++ b/tests/cpp/operator/test_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -37,8 +37,8 @@ void performTest(const size_t N, const size_t H) { DType dtype = TypeInfo::dtype; - Tensor input({ N, H }, dtype); - Tensor output({ H, N }, dtype); + Tensor input("input", { N, H }, dtype); + Tensor output("output", { H, N }, dtype); std::unique_ptr ref_output = std::make_unique(N * H); @@ -46,13 +46,13 @@ void performTest(const size_t N, const size_t H) { nvte_transpose(input.data(), output.data(), 0); - compute_ref(input.cpu_dptr(), ref_output.get(), N, H); + compute_ref(input.rowwise_cpu_dptr(), ref_output.get(), N, H); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); auto [atol, rtol] = getTolerances(dtype); - compareResults("output", output, ref_output.get(), atol, rtol); + compareResults("output", output, ref_output.get(), true, atol, rtol); } std::vector> test_cases = {{2048, 12288}, diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index b90ea183cb..8565e5d5c6 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -10,14 +10,24 @@ #include #include #include +#include +#include +#include #include +#include #include #include "util/logging.h" namespace test { +size_t create_seed_from_tensor_name(const std::string& tensor_name) { + auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) + + "/" + tensor_name; + return std::hash{}(full_name); +} + std::vector all_fp_types = {DType::kFloat32, DType::kFloat16, DType::kBFloat16, @@ -50,102 +60,386 @@ const std::string &typeName(DType type) { {DType::kFloat16, "float16"}, {DType::kBFloat16, "bfloat16"}, {DType::kFloat8E4M3, "float8e4m3"}, - {DType::kFloat8E5M2, "float8e5m2"}}; + {DType::kFloat8E5M2, "float8e5m2"}, + {DType::kFloat8E8M0, "float8e8m0"}}; return name_map.at(type); } -size_t product(const NVTEShape &shape) { +const std::string& caseName(InputsFillCase type) { + static const std::unordered_map name_map = { + {InputsFillCase::uniform, "uniform"}, + {InputsFillCase::zeros, "zeros"}, + {InputsFillCase::zero_to_minNorm, "zero_to_minNorm"}, + {InputsFillCase::minNorm_to_maxNorm, "minNorm_to_maxNorm"}, + {InputsFillCase::maxNorm_to_inf, "maxNorm_to_inf"}}; + return name_map.at(type); +} + +size_t product(const NVTEShape &shape, size_t begin, size_t end) { size_t ret = 1; - for (size_t i = 0; i < shape.ndim; ++i) { + NVTE_CHECK(end <= shape.ndim); + for (size_t i = begin; i < end; ++i) { ret *= shape.data[i]; } return ret; } -Tensor::Tensor(const NVTEShape &shape, const DType type) { - size_t s = typeToSize(type); - size_t total_size = product(shape) * s; - void *dptr = nullptr; - cpu_data_ = nullptr; - amax_cpu_data_ = nullptr; - scale_cpu_data_ = nullptr; - scale_inv_cpu_data_ = nullptr; - float *amax = nullptr, *scale = nullptr, *scale_inv = nullptr; - if (total_size != 0) { - cudaMalloc((void**)&dptr, total_size); // NOLINT(*) - cudaMemset(dptr, 0, total_size); - cpu_data_ = std::make_unique(total_size); - for (size_t i = 0; i < total_size; ++i) { - cpu_data_[i] = 0; - } +size_t product(const NVTEShape &shape) { + return product(shape, 0, shape.ndim); +} + +size_t product(const std::vector shape, size_t begin, size_t end) { + size_t ret = 1; + NVTE_CHECK(end <= shape.size()); + for (size_t i = begin; i < end; ++i) { + ret *= shape[i]; + } + return ret; +} + +size_t product(const std::vector& shape) { + return product(shape, 0, shape.size()); +} + +size_t DIVUP(const size_t &x, const size_t &y){ + return (((x) + ((y)-1)) / (y)); +} + +struct scale_inv_meta { + std::vector shape; + DType type; + size_t type_size; +}; + +NVTEShape convertShape(const std::vector& shape) { + return {shape.data(), shape.size()}; +} + +std::pair get_scales(const NVTEShape& shape, + const NVTEScalingMode scaling_mode) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + scale_inv_meta ret; + ret.shape = {1}; + ret.type = DType::kFloat32; + ret.type_size = sizeof(float); + return {ret, ret}; + } + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + auto block_alignment = std::vector{128ul,4ul}; + { + auto alignment = block_alignment[0]; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, + static_cast(1)), + alignment) * alignment; + alignment = block_alignment[1]; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, + static_cast(32)), + alignment) * alignment; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto alignment = block_alignment[1]; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, + static_cast(32)), + alignment) * alignment; + alignment = block_alignment[0]; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, + static_cast(1)), + alignment) * alignment; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat8E8M0; + ret_colwise.type = DType::kFloat8E8M0; + ret_rowwise.type_size = sizeof(uint8_t); + ret_colwise.type_size = sizeof(uint8_t); + + return {ret_rowwise, ret_colwise}; + } + + NVTE_ERROR("Invalid scaling mode!"); +} + +Tensor::Tensor(const std::string& name, + const NVTEShape &shape, const DType type, + const bool rowwise, const bool columnwise, + const NVTEScalingMode &scaling_mode) { + name_ = name; + const size_t seed = create_seed_from_tensor_name(name); + gen_.seed(seed); + rowwise_ = rowwise; + columnwise_ = columnwise; + size_t s = typeToSize(type); + size_t total_size = product(shape) * s; + void *dptr_rowwise = nullptr; + void *dptr_columnwise = nullptr; + cpu_data_rowwise_ = nullptr; + cpu_data_columnwise_ = nullptr; + amax_cpu_data_ = nullptr; + scale_cpu_data_ = nullptr; + rowwise_scale_inv_cpu_data_ = nullptr; + columnwise_scale_inv_cpu_data_ = nullptr; + float *amax = nullptr, *scale = nullptr; + float *rowwise_scale_inv = nullptr, *columnwise_scale_inv = nullptr; + if (columnwise) { + NVTE_CHECK(shape.ndim >= 2); + } + std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), + shape.data[shape.ndim - 1]}; + NVTEShape normalized_shape = convertShape(normalized_shape_v); + NVTEShape columnwise_shape{nullptr, 0}; + + std::vector columnwise_shape_vec; + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + // Transpose when tensor scaling + columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); + for (size_t i = 0; i < shape.ndim - 1; ++i) { + columnwise_shape_vec.emplace_back(shape.data[i]); } - if (isFp8Type(type)) { + } else { + // Same shape for MX + for (size_t i = 0; i < shape.ndim; ++i) { + columnwise_shape_vec.emplace_back(shape.data[i]); + } + } + + if (columnwise) { + columnwise_shape.data = columnwise_shape_vec.data(); + columnwise_shape.ndim = columnwise_shape_vec.size(); + } + + tensor_ = TensorWrapper(scaling_mode); + + if (total_size != 0) { + if (rowwise) { + cudaMalloc((void**)&dptr_rowwise, total_size); // NOLINT(*) + cudaMemset(dptr_rowwise, 0, total_size); + cpu_data_rowwise_ = std::make_unique(total_size); + std::fill_n(cpu_data_rowwise_.get(), total_size, 0); + } + if (columnwise) { + cudaMalloc((void**)&dptr_columnwise, total_size); // NOLINT(*) + cudaMemset(dptr_columnwise, 0, total_size); + cpu_data_columnwise_ = std::make_unique(total_size); + std::fill_n(cpu_data_columnwise_.get(), total_size, 0); + } + } + tensor_.set_rowwise_data(dptr_rowwise, type, shape); + tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); + + if (isFp8Type(type)) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) cudaMemset(amax, 0, sizeof(float)); cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) cudaMemset(scale, 0, sizeof(float)); - cudaMalloc((void**)&scale_inv, sizeof(float)); // NOLINT(*) - cudaMemset(scale_inv, 0, sizeof(float)); - amax_cpu_data_ = std::make_shared(); - *amax_cpu_data_ = 0; - scale_cpu_data_ = std::make_shared(); - *scale_cpu_data_ = 0; - scale_inv_cpu_data_ = std::make_shared(); - *scale_inv_cpu_data_ = 0; + amax_cpu_data_ = std::make_shared(0); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_amax(amax, DType::kFloat32, std::vector{1}); + tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); + cudaMalloc((void**)&rowwise_scale_inv, sizeof(float)); // NOLINT(*) + if (rowwise) { + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32, + std::vector{1}); + rowwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); + std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0); + } + if (columnwise) { + tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, + std::vector{1}); + columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); + std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); + } + } else { + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, + tensor_.scaling_mode()); + auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + auto scale_shape = rowwise_scale_meta.shape; + auto columnwise_scale_shape = colwise_scale_meta.shape; + if (rowwise) { + cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) + cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size); + rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); + std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat8E8M0, scale_shape); + } + if (columnwise) { + cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) + cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size); + columnwise_scale_inv_cpu_data_ = std::make_unique(columnwise_scale_size); + std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); + tensor_.set_columnwise_scale_inv(columnwise_scale_inv, DType::kFloat8E8M0, columnwise_scale_shape); + } } - tensor_ = TensorWrapper(dptr, shape, type, amax, scale, scale_inv); + } } void Tensor::to_cpu() const { const NVTEShape s = tensor_.shape(); const size_t size = product(s) * typeToSize(tensor_.dtype()); - cudaMemcpy(cpu_data_.get(), tensor_.dptr(), size, cudaMemcpyDeviceToHost); + if (rowwise_) { + cudaMemcpy(cpu_data_rowwise_.get(), + tensor_.get_rowwise_data().data_ptr, + size, + cudaMemcpyDeviceToHost); + } + if (columnwise_) { + cudaMemcpy(cpu_data_columnwise_.get(), + tensor_.get_columnwise_data().data_ptr, + size, + cudaMemcpyDeviceToHost); + } if (isFp8Type(dtype())) { - cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), sizeof(float), - cudaMemcpyDeviceToHost); - cudaMemcpy(scale_cpu_data_.get(), tensor_.scale(), sizeof(float), - cudaMemcpyDeviceToHost); - cudaMemcpy(scale_inv_cpu_data_.get(), tensor_.scale_inv(), sizeof(float), - cudaMemcpyDeviceToHost); + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (tensor_.amax() != nullptr){ + cudaMemcpy(amax_cpu_data_.get(), + tensor_.amax(), + sizeof(float), + cudaMemcpyDeviceToHost); + } + cudaMemcpy(scale_cpu_data_.get(), + tensor_.scale(), + sizeof(float), + cudaMemcpyDeviceToHost); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + if (rowwise_) { + auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), + tensor_.get_rowwise_scale_inv().data_ptr, + scale_size, + cudaMemcpyDeviceToHost); + } + if (columnwise_) { + auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + cudaMemcpy(columnwise_scale_inv_cpu_data_.get(), + tensor_.get_columnwise_scale_inv().data_ptr, + scale_size, + cudaMemcpyDeviceToHost); + } } } void Tensor::from_cpu() const { const NVTEShape s = tensor_.shape(); const size_t size = product(s) * typeToSize(tensor_.dtype()); - cudaMemcpy(tensor_.dptr(), cpu_data_.get(), size, cudaMemcpyHostToDevice); + if (rowwise_) { + cudaMemcpy(tensor_.get_rowwise_data().data_ptr, + cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice); + } + if (columnwise_) { + cudaMemcpy(tensor_.get_columnwise_data().data_ptr, + cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); + } if (isFp8Type(dtype())) { - cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); - cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); - cudaMemcpy(tensor_.scale_inv(), scale_inv_cpu_data_.get(), sizeof(float), - cudaMemcpyHostToDevice); + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (tensor_.amax() != nullptr){ + cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), + cudaMemcpyHostToDevice); + } + cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), + cudaMemcpyHostToDevice); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + if (rowwise_) { + auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; + cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, + rowwise_scale_inv_cpu_data_.get(), scale_size, + cudaMemcpyHostToDevice); + } + if (columnwise_) { + auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; + cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr, + columnwise_scale_inv_cpu_data_.get(), scale_size, + cudaMemcpyHostToDevice); + } } } void Tensor::set_scale(float scale) { if (isFp8Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - *scale_cpu_data_ = scale; - from_cpu(); + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + *scale_cpu_data_ = scale; + from_cpu(); + } } } void Tensor::set_scale_inv(float scale_inv) { if (isFp8Type(dtype())) { - NVTE_CHECK(scale_inv_cpu_data_); - *scale_inv_cpu_data_ = scale_inv; + if (rowwise_) { + NVTE_CHECK(rowwise_scale_inv_cpu_data_); + } + if (columnwise_) { + NVTE_CHECK(columnwise_scale_inv_cpu_data_); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); + if (rowwise_) { + auto num_scales = product(rowwise_scale_meta.shape); + if (num_scales == 1){ + rowwise_cpu_scale_inv_ptr()[0] = scale_inv; + } else{ + std::uniform_int_distribution dis(0, 127); + auto* scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++){ + scale_inv_ptr[i] = dis(gen_); + } + } + } + if (columnwise_) { + auto num_scales = product(colwise_scale_meta.shape); + if (num_scales == 1){ + columnwise_cpu_scale_inv_ptr()[0] = scale_inv; + } else{ + std::uniform_int_distribution dis(0, 127); + auto* scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++){ + scale_inv_ptr[i] = dis(gen_); + } + } + } from_cpu(); } } void Tensor::shareFP8Meta(const Tensor &other) { if(isFp8Type(dtype()) && isFp8Type(other.dtype())) { - tensor_ = TensorWrapper(dptr(), shape(), dtype(), - other.tensor_.amax(), - other.tensor_.scale(), - other.tensor_.scale_inv()); + auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); + auto my_rowwise_data = tensor_.get_rowwise_data(); + new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, + static_cast(my_rowwise_data.dtype), + my_rowwise_data.shape); + auto my_columnwise_data = tensor_.get_columnwise_data(); + new_tensor.set_columnwise_data(my_columnwise_data.data_ptr, + static_cast(my_columnwise_data.dtype), + my_columnwise_data.shape); + auto other_amax = other.tensor_.get_amax(); + new_tensor.set_amax(other_amax.data_ptr, + static_cast(other_amax.dtype), + other_amax.shape); + auto other_scale = other.tensor_.get_scale(); + new_tensor.set_scale(other_scale.data_ptr, + static_cast(other_scale.dtype), + other_scale.shape); + auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv(); + new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr, + static_cast(other_row_scale_inv.dtype), + other_row_scale_inv.shape); + auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv(); + new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr, + static_cast(other_col_scale_inv.dtype), + other_col_scale_inv.shape); + tensor_ = std::move(new_tensor); to_cpu(); } } @@ -177,12 +471,14 @@ std::vector unravel(const size_t i, const NVTEShape &shape) { return ret; } -void compareResults(const std::string &name, const Tensor &test, const void *ref, - double atol, double rtol) { - test.to_cpu(); - const size_t N = product(test.shape()); +void compareResults_sequential(const std::string &name, const Tensor &test, + const void *ref, const bool rowwise, + double atol, double rtol, bool if_on_gpus) { + if (if_on_gpus) test.to_cpu(); + const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); + const size_t N = product(shape); TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, - const T *test_data = test.cpu_dptr(); + const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); for (size_t i = 0; i < N; ++i) { double t = static_cast(test_data[i]); @@ -200,14 +496,84 @@ void compareResults(const std::string &name, const Tensor &test, const void *ref const double cast_mean_m = static_cast(static_cast(mean_m)); assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); } - ASSERT_FALSE(assertion) << "Error in tensor " << name << std::endl - << "Mismatch at place " << to_string(unravel(i, test.shape())) + std::string direction = rowwise ? "rowwise" : "columnwise"; + ASSERT_FALSE(assertion) << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "Mismatch at place " << to_string(unravel(i, shape)) << " (" << std::to_string(i) << "): " << t << " vs " << r; + } + ); +} +template +static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, + const size_t N, const double atol, const double rtol) { + int first_mismatch_idx = N; + + bool is_mismatch_found = false; + #pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \ + reduction(min: first_mismatch_idx) proc_bind(spread) + for (size_t i = 0; i < N; ++i) { + if (is_mismatch_found) { // early escape of the omp thread + continue; + } + + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); + + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = mismatch && (data_type == DType::kFloat32); + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion && i < first_mismatch_idx) { + first_mismatch_idx = i; + is_mismatch_found = true; + } + } + return first_mismatch_idx; +} + +void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, + const bool rowwise, double atol, double rtol, bool if_on_gpus) { + if (if_on_gpus) test.to_cpu(); + const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); + const size_t N = product(shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, + const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); + const T *ref_data = reinterpret_cast(ref); + + const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol); + if (i != N) { + const double t = static_cast(test_data[i]); + const double r = static_cast(ref_data[i]); + std::string direction = rowwise ? "rowwise" : "columnwise"; + ASSERT_FALSE(true) << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "Mismatch at place " << to_string(unravel(i, shape)) + << " (" << std::to_string(i) << "): " << t << " vs " << r; } ); } +void compareResults(const std::string &name, const Tensor &test, const void *ref, + const bool rowwise, double atol, double rtol, bool if_on_gpus) { + constexpr bool sequential = false; + if constexpr (sequential) { + compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus); + } else { + compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus); + } +} + void compareResults(const std::string &name, const float test, const float ref, double atol, double rtol) { double t = static_cast(test); @@ -218,6 +584,51 @@ void compareResults(const std::string &name, const float test, const float ref, } + +void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, + size_t N, float mismatch_rate_tol) { + size_t max_mismatches = std::ceil(N * mismatch_rate_tol); + size_t n_mismatches = 0; + std::vector mismatch_indices; + for (int i = 0; i < N; i++){ + bool mismatch = test[i] != ref[i]; + if (mismatch){ + n_mismatches++; + mismatch_indices.push_back(i); + } + if (n_mismatches > max_mismatches){ + std::cout << "Error in " << name << std::endl; + for (auto &index : mismatch_indices) + std::cout << "Mismatch at (" << index << "):" << static_cast(test[i]) << " vs " + << static_cast(ref[i]) << std::endl; + GTEST_FAIL() << n_mismatches << " mismatche(s) which is more than mismatch tol."; + } + } +} + +void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride) +{ + for (int i = 0; i < row_blocks; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int idx = i * stride + j; + ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl + << "Mismatch: " << static_cast(test[idx]) << " vs " + << static_cast(ref[idx]) << " at index " << idx; + } + } +} + +void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t N) +{ + for (int i = 0; i < N; i++) { + ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl + << "Mismatch: " << static_cast(test[i]) << " vs " + << static_cast(ref[i]) << " at index " << i; + } +} + std::pair getTolerances(const DType type) { switch(type) { case DType::kFloat32: @@ -228,6 +639,7 @@ std::pair getTolerances(const DType type) { return {1e-5, 1e-2}; case DType::kFloat8E4M3: case DType::kFloat8E5M2: + case DType::kFloat8E8M0: return {1e-2, 1e-2}; default: NVTE_CHECK("Invalid type!"); @@ -235,29 +647,158 @@ std::pair getTolerances(const DType type) { return {0, 0}; } +template +void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { + #pragma omp parallel proc_bind(spread) + { + std::mt19937 gen_local = *gen; + gen_local.discard(omp_get_thread_num() * 599); + std::uniform_real_distribution<> dis(-2.0, 1.0); + #pragma omp for schedule(static) + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast(dis(gen_local)); + } + } + gen->discard(size); +} + void fillUniform(Tensor *t) { - const size_t size = product(t->shape()); - static std::mt19937 gen(12345); + if (t->rowwise()) { + const size_t size = product(t->rowwise_shape()); + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, + { + T *data = t->rowwise_cpu_dptr(); + generate_data_uniformly(data, size, &(t->gen())); + } + ); + } else { + const size_t size = product(t->columnwise_shape()); + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, + { + T *data = t->columnwise_cpu_dptr(); + generate_data_uniformly(data, size, &(t->gen())); + } + ); + } std::uniform_real_distribution<> dis(-2.0, 1.0); - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { - T *data = t->cpu_dptr(); + t->set_scale_inv(dis(t->gen())); + t->from_cpu(); +} + +template +void fillCase_special(Tensor *t) { + const size_t size = product(t->rowwise_shape()); + const size_t rows = t->rowwise_shape().data[0]; + const size_t cols = t->rowwise_shape().data[1]; + + if constexpr (Case == InputsFillCase::zeros) { + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { + InputType *data = t->rowwise_cpu_dptr(); for (size_t i = 0; i < size; ++i) { - data[i] = T(dis(gen)); + data[i] = static_cast(0); } - }); - t->set_scale_inv(dis(gen)); + }); + } else { + double minAbs = -2.0; + double maxAbs = 1.0; + if constexpr (Case != InputsFillCase::uniform) { + minAbs = Quantized_Limits::ranges[Case]; + maxAbs = Quantized_Limits::ranges[Case + 1]; + } + std::uniform_real_distribution<> dis(minAbs, maxAbs); + std::uniform_real_distribution<> dis_sign(-1.0, 1.0); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { + InputType *data = t->rowwise_cpu_dptr(); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const bool is_negative = (dis_sign(t->gen()) < 0.0); + double val = dis(t->gen()); + if (is_negative) { + val = -val; + } + data[idx] = static_cast(val); + } + } + }); + } + t->set_scale_inv(1.0); t->from_cpu(); } +template +void fillCase(Tensor *t, const InputsFillCase fill_case) { + switch (fill_case) { + case InputsFillCase::uniform: + fillCase_special(t); break; + case InputsFillCase::zeros: + fillCase_special(t); break; + case InputsFillCase::zero_to_minNorm: + fillCase_special(t); break; + case InputsFillCase::minNorm_to_maxNorm: + fillCase_special(t); break; + case InputsFillCase::maxNorm_to_inf: + fillCase_special(t); break; + } +} + +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); + void setRandomScale(Tensor *t) { - static std::mt19937 gen(12345); std::uniform_real_distribution<> dis(-2.0, 1.0); - const float scale = dis(gen); + const float scale = dis(t->gen()); t->set_scale(scale); } +void setRandomScaleInv(Tensor *t) { + std::uniform_real_distribution<> dis(-2.0, 1.0); + const float scale_inv = dis(t->gen()); + t->set_scale_inv(scale_inv); +} + bool isFp8Type(DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; + return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; +} + +int32_t getDeviceComputeCapability() +{ + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; +} + +size_t first_dimension(const std::vector &shape) { + if (shape.size() == 0) return 1; + if (shape.size() == 1) return 1; + return product(shape, 0, shape.size() - 1); +} + +size_t last_dimension(const std::vector &shape) { + if (shape.size() == 0) return 1; + return shape[shape.size() - 1]; +} + +std::array get_scale_tensor_dims(const size_t rows, + const size_t cols, + const size_t block_size_rows, + const size_t block_size_cols) { + const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32); + + const size_t alignment_Y = is_rowwise + ? scale_tensor_alignment_Y_rowwise + : scale_tensor_alignment_Y_colwise; + const size_t alignment_X = is_rowwise + ? scale_tensor_alignment_X_rowwise + : scale_tensor_alignment_X_colwise; + + const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_rows); + const size_t unpadded_blocks_X = divide_round_up(cols, block_size_cols); + + const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, alignment_Y); + const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, alignment_X); + return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X}; } } // namespace test diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index a6181256d9..4352056ddb 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -1,14 +1,15 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #pragma once -#include #include #include +#include +#include #include #include @@ -52,6 +53,7 @@ using fp16 = half; using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +using fp8e8m0 = uint8_t; template struct TypeInfo{ @@ -62,7 +64,8 @@ struct TypeInfo{ fp16, bf16, fp8e4m3, - fp8e5m2>; + fp8e5m2, + fp8e8m0>; template struct Helper { @@ -94,10 +97,19 @@ struct TypeInfo{ class Tensor { public: - Tensor(const NVTEShape &shape, const DType type); - - Tensor(const std::vector &shape, const DType type) : - Tensor(NVTEShape{shape.data(), shape.size()}, type) {} + Tensor(const std::string& name, + const NVTEShape &shape, const DType type, + const bool rowwise = true, + const bool columnwise = false, + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); + + Tensor(const std::string& name, + const std::vector &shape, + const DType type, + const bool rowwise = true, + const bool columnwise = false, + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} Tensor() {} @@ -108,30 +120,82 @@ class Tensor { Tensor& operator=(Tensor &&other) = default; ~Tensor() { - if (tensor_.dptr() != nullptr) { - cudaFree(tensor_.dptr()); + void *data_ptr = tensor_.dptr(); + void *scale_inv = tensor_.scale_inv(); + void *columnwise_data_ptr = tensor_.get_columnwise_data().data_ptr; + void *columnwise_scale_inv = tensor_.get_columnwise_scale_inv().data_ptr; + if (columnwise_data_ptr == data_ptr) { + columnwise_data_ptr = nullptr; + } + if (columnwise_scale_inv == scale_inv) { + columnwise_scale_inv = nullptr; + } + if (data_ptr != nullptr) { + cudaFree(data_ptr); + } + if (scale_inv != nullptr) { + cudaFree(scale_inv); + } + if (columnwise_data_ptr != nullptr){ + cudaFree(columnwise_data_ptr); + } + if (columnwise_scale_inv != nullptr){ + cudaFree(columnwise_scale_inv); } } + NVTETensor data() const noexcept { return tensor_.data(); } - const NVTEShape shape() const noexcept { - return tensor_.shape(); + NVTEShape rowwise_shape() const noexcept { + return tensor_.get_rowwise_data().shape; + } + + NVTEShape columnwise_shape() const noexcept { + return tensor_.get_columnwise_data().shape; + } + + NVTEShape rowwise_scale_inv_shape() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_scale_inv().shape; + } + + NVTEShape columnwise_scale_inv_shape() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_scale_inv().shape; + } + + NVTEScalingMode scaling_mode() const noexcept { + return tensor_.scaling_mode(); } DType dtype() const noexcept { return tensor_.dtype(); } - void *dptr() const noexcept { - return tensor_.dptr(); + void *rowwise_dptr() const { + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return tensor_.get_rowwise_data().data_ptr; + } + + void *columnwise_dptr() const { + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return tensor_.get_columnwise_data().data_ptr; } template - T *cpu_dptr() const { + T *rowwise_cpu_dptr() const { NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); - return reinterpret_cast(cpu_data_.get()); + NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); + return reinterpret_cast(cpu_data_rowwise_.get()); + } + + template + T *columnwise_cpu_dptr() const { + NVTE_CHECK(TypeInfo::dtype == tensor_.dtype(), "Invalid type!"); + NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!"); + return reinterpret_cast(cpu_data_columnwise_.get()); } float amax() const { @@ -145,6 +209,7 @@ class Tensor { float scale() const { if(scale_cpu_data_) { + NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); to_cpu(); return *scale_cpu_data_; } else { @@ -152,52 +217,250 @@ class Tensor { } } - float scale_inv() const { - if(scale_inv_cpu_data_) { - to_cpu(); - return *scale_inv_cpu_data_; + template + T *rowwise_cpu_scale_inv_ptr(){ + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else { + NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); + } + to_cpu(); + return reinterpret_cast(rowwise_scale_inv_cpu_data_.get()); + } + + template + T *columnwise_cpu_scale_inv_ptr(){ + if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else { + NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); + } + to_cpu(); + return reinterpret_cast(columnwise_scale_inv_cpu_data_.get()); + } + + float rowwise_scale_inv(){ + if(rowwise_scale_inv_cpu_data_) { + float scale_inv = rowwise_cpu_scale_inv_ptr()[0]; + return scale_inv; } else { return 1; } } + bool rowwise() const { + return rowwise_; + } + + bool columnwise() const { + return columnwise_; + } + + void set_tensor_amax_nullptr(){ + tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); void set_scale_inv(float scale_inv); void shareFP8Meta(const Tensor &other); + std::mt19937& gen() { return gen_; } + private: TensorWrapper tensor_; - std::unique_ptr cpu_data_; + std::unique_ptr cpu_data_rowwise_; + std::unique_ptr cpu_data_columnwise_; std::shared_ptr amax_cpu_data_; std::shared_ptr scale_cpu_data_; - std::shared_ptr scale_inv_cpu_data_; + std::unique_ptr rowwise_scale_inv_cpu_data_; + std::unique_ptr columnwise_scale_inv_cpu_data_; + bool rowwise_; + bool columnwise_; + std::string name_; + std::mt19937 gen_; +}; + +constexpr uint32_t FP32_EXPONENT_BIAS = 127; +constexpr uint32_t FP32_MANTISSA_BITS = 23; + +// [128,4] rowwise and [4,128] colwise alignment requirement +constexpr size_t scale_tensor_alignment_X_rowwise = 4; +constexpr size_t scale_tensor_alignment_Y_rowwise = 128; +constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_Y_colwise = 4; + +inline size_t divide_round_up(const size_t N, const size_t M) { + return (N - 1 + M) / M; +} + +inline size_t round_up_to_nearest_multiple(const size_t N, const size_t M) { + return divide_round_up(N, M) * M; +} + +template +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0; + static constexpr double maxSubnorm = 1.0; + static constexpr double minNorm = 1.0; + static constexpr double maxNorm = 1.0; + static constexpr double artifInf = 1.0; + static constexpr int maxBiasedExponent = 1; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0 / static_cast(1 << 9); // std::pow(2.0, -9.0); + static constexpr double maxSubnorm = 0.875 / static_cast(1 << 6); // std::pow(2.0, -6.0); + static constexpr double minNorm = 1.0 / static_cast(1 << 6); // std::pow(2.0, -6.0); + static constexpr double maxNorm = 448.0; + static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity + static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS; + static constexpr int maxUnbiasedExponentAsFP32 = 8; + static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = 1.0 / static_cast(1 << 16); // std::pow(2.0, -16.0); + static constexpr double maxSubnorm = 0.75 / static_cast(1 << 14); // std::pow(2.0, -14.0); + static constexpr double minNorm = 1.0 / static_cast(1 << 14); // std::pow(2.0, -14.0); + static constexpr double maxNorm = 57344.0; + static constexpr double artifInf = 10.0 * maxNorm; // artificial Infinity + static constexpr int maxBiasedExponentAsFP32 = 15 + FP32_EXPONENT_BIAS; + static constexpr int maxUnbiasedExponentAsFP32 = 15; + static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32; +}; + +template <> +struct Numeric_Traits { + static constexpr double minSubnorm = std::numeric_limits::denorm_min(); // std::pow(2.0, -149.0); + static constexpr double maxSubnorm = std::numeric_limits::min() + - std::numeric_limits::denorm_min(); // minNormalized - minDenormalized + static constexpr double minNorm = std::numeric_limits::min(); // std::pow(2.0, -126.0); + static constexpr double maxNorm = std::numeric_limits::max(); // (1 - pow(2, -24)) * pow(2, 128) + static constexpr double artifInf = std::numeric_limits::infinity(); + static constexpr int maxBiasedExponentAsFP32 = 255; + static constexpr int maxUnbiasedExponentAsFP32 = 128; +}; + +template +struct Quantized_Limits { + static constexpr double ranges[] = { + 0.0, + Numeric_Traits::minNorm, + Numeric_Traits::maxNorm, + Numeric_Traits::artifInf + }; + static constexpr inline fp32 max() { return static_cast(Numeric_Traits::maxNorm); } + static constexpr inline fp32 max_reciprocal() { return static_cast(1.0 / max()); } + static constexpr inline fp32 emax() { return static_cast(Numeric_Traits::maxExpNorm); } + static constexpr inline fp32 emax_reciprocal() { return static_cast(1.0 / emax()); } + static constexpr inline int max_norm_biased_exponent() { return Numeric_Traits::maxBiasedExponentAsFP32; } + static constexpr inline int max_norm_unbiased_exponent() { return Numeric_Traits::maxUnbiasedExponentAsFP32; } +}; + +// Input data filling cases +// Considering normal and subnormal magnitudes of E4M3 and E5M2 formats +// with nearest to even rounding per OFP8 specification +enum InputsFillCase { + zero_to_minNorm = 0, // [0, min_normal) + minNorm_to_maxNorm = 1, // [min_normal, max_normal) + maxNorm_to_inf = 2, // [max_normal, inf) + zeros = 3, // {0} + uniform = 4, // std::uniform_real_distribution<> dis(-2.0, 1.0) }; +inline fp8e8m0 float_to_e8m0(float val) { + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (std::isnan(val)) { + return 0xFF; + } + if (std::isinf(val)) { + return 0xFE; + } + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + fp8e8m0 exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +} + +inline float exp2f_rcp(fp8e8m0 biased_exp) { + return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); +} + +inline float identity(const float x) { return x; } +inline float gelu(const float x) { return x * (0.5f + 0.5f * tanhf(x * (0.79788456f + 0.03567741f * x * x))); } +inline float dgelu(const float x) { + const float tanh_out = tanhf(0.79788456f * x * (1 + 0.044715f * x * x)); + return 0.5f * x * ((1 - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * x * x)) + + 0.5f * (1 + tanh_out); +} +inline float sigmoid(const float x) { return 1 / (1 + expf(-x)); } +inline float dsigmoid(const float x) { return sigmoid(x) * (1 - sigmoid(x)); } +inline float qgelu(const float x) { return x * sigmoid(1.702f * x); } +inline float dqgelu(const float x) { return 1.702f * x * dsigmoid(1.702f * x) + sigmoid(1.702f * x); } +inline float relu(const float x) { return fmaxf(0, x); } +inline float drelu(const float x) { return x > 0 ? 1 : 0; } +inline float silu(const float x) { return x * sigmoid(x); } +inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); } +inline float srelu(const float x) { return x > 0 ? x * x : 0; } +inline float dsrelu(const float x) { return fmaxf(0, 2 * x); } + size_t typeToSize(DType type); size_t product(const NVTEShape &shape); +size_t product(const std::vector &shape); + +size_t first_dimension(const std::vector &shape); +size_t last_dimension(const std::vector &shape); bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); void compareResults(const std::string &name, const Tensor &test, const void *ref, - double atol = 1e-5, double rtol = 1e-8); + bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true); void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); +void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, + size_t N, float mismatch_rate_tol = 0.); +void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride); +void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t N); + +std::array get_scale_tensor_dims(const size_t rows, const size_t cols, + const size_t block_size_rows, const size_t block_size_cols); std::pair getTolerances(const DType type); void fillUniform(Tensor *t); + +template +void fillCase(Tensor *t, const InputsFillCase fill_case); + void setRandomScale(Tensor *t); +void setRandomScaleInv(Tensor *t); constexpr int THREADS_PER_WARP = 32; const std::string &typeName(DType type); +const std::string& caseName(InputsFillCase type); extern std::vector all_fp_types; bool isFp8Type(DType type); +int32_t getDeviceComputeCapability(); +constexpr int32_t blackwellComputeCapability = 100; + } // namespace test #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \ @@ -254,3 +517,47 @@ bool isFp8Type(DType type); default: \ NVTE_ERROR("Invalid type."); \ } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat8E4M3: \ + { \ + using type = fp8e4m3; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat8E5M2: \ + { \ + using type = fp8e5m2; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: \ + { \ + using type = float; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kFloat16: \ + { \ + using type = fp16; \ + {__VA_ARGS__} \ + } \ + break; \ + case DType::kBFloat16: \ + { \ + using type = bf16; \ + {__VA_ARGS__} \ + } \ + break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } diff --git a/tests/cpp/util/CMakeLists.txt b/tests/cpp/util/CMakeLists.txt index d93be956b0..7540687089 100644 --- a/tests/cpp/util/CMakeLists.txt +++ b/tests/cpp/util/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -8,8 +8,9 @@ add_executable(test_util ../test_common.cu) -target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) -target_compile_options(test_util PRIVATE -O2) +find_package(OpenMP REQUIRED) +target_link_libraries(test_util PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX) +target_compile_options(test_util PRIVATE -O2 -fopenmp) include(GoogleTest) -gtest_discover_tests(test_util) +gtest_discover_tests(test_util DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp/util/test_nvrtc.cpp b/tests/cpp/util/test_nvrtc.cpp index 03982deb73..e885140ce1 100644 --- a/tests/cpp/util/test_nvrtc.cpp +++ b/tests/cpp/util/test_nvrtc.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/tests/cpp/util/test_string.cpp b/tests/cpp/util/test_string.cpp index 531994aff8..a2e8bc1410 100644 --- a/tests/cpp/util/test_string.cpp +++ b/tests/cpp/util/test_string.cpp @@ -1,10 +1,11 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include +#include #include @@ -57,6 +58,12 @@ TEST(UtilTest, ToStringLike) { // to_string_like EXPECT_EQ(std::stof(to_string_like(-2.5f)), -2.5f); EXPECT_EQ(std::stod(to_string_like(2.25)), 2.25); EXPECT_EQ(std::stod(to_string_like(-4.5)), -4.5); + + // Container types + EXPECT_EQ(to_string_like(std::vector{-3,1,-4}), "(-3,1,-4)"); + EXPECT_EQ(to_string_like(std::vector{"Accept", "no", "substitutes", ".", + "Buy", "N", "V", "IDIA"}), + "(Accept,no,substitutes,.,Buy,N,V,IDIA)"); } TEST(UtilTest, ConcatStringsTest) { // concat_strings @@ -88,6 +95,9 @@ TEST(UtilTest, ConcatStringsTest) { // concat_strings EXPECT_EQ(std::stof(concat_strings(6.5f)), 6.5f); EXPECT_EQ(std::stod(concat_strings("-", 4.25)), -4.25); EXPECT_EQ(std::stod(concat_strings(8.5)), 8.5); + + // Container types + EXPECT_EQ(concat_strings("vector ", std::vector{1,-2,3}), "vector (1,-2,3)"); } TEST(UtilTest, RegexReplaceTest) { // regex_replace diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index ccb6690a87..663a954184 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """conftest for tests/jax""" @@ -6,7 +6,9 @@ import jax import pytest -from transformer_engine.transformer_engine_jax import get_device_compute_capability + +import transformer_engine.jax +from transformer_engine_jax import get_device_compute_capability @pytest.fixture(autouse=True, scope="function") @@ -20,16 +22,13 @@ def clear_live_arrays(): @pytest.fixture(autouse=True, scope="module") -def enable_fused_attn(): +def enable_fused_attn_after_hopper(): """ Enable fused attn for hopper+ arch. Fused attn kernels on pre-hopper arch are not deterministic. """ if get_device_compute_capability(0) >= 90: os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" yield if "NVTE_FUSED_ATTN" in os.environ: del os.environ["NVTE_FUSED_ATTN"] - if "NVTE_ALLOW_NONDETERMINISTIC_ALGO" in os.environ: - del os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index bbd54ecce5..d0ace8263f 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import operator @@ -18,14 +18,22 @@ def generate_configs(): configs = [] if is_devices_enough(2): - configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")]) - configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")]) + configs.append( + pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") + ) + configs.append( + pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2") + ) if is_devices_enough(4): - TP_size = 2 - DP_size = 2 configs.append( - [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")] + pytest.param( + 4, + (2, 2), + ("dp", "tp"), + MeshResource(dp_resource="dp", tp_resource="tp"), + id=f"n4_dp2_tp2", + ) ) return configs @@ -33,7 +41,8 @@ def generate_configs(): def generate_context_parallel_configs(): configs = [] - + mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp") + axes = ("dp", "cp", "tp") DP_sizes = (1, 2) CP_sizes = (1, 2, 4, 8) TP_sizes = (1, 2) @@ -41,13 +50,7 @@ def generate_context_parallel_configs(): ndev = cp * tp * dp if is_devices_enough(ndev): configs.append( - pytest.param( - ndev, - (dp, cp, tp), - ("dp", "cp", "tp"), - MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp"), - id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}", - ) + pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") ) return configs diff --git a/tests/jax/pytest.ini b/tests/jax/pytest.ini index 2cbbe2ac67..1e835b2187 100644 --- a/tests/jax/pytest.ini +++ b/tests/jax/pytest.ini @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -24,3 +24,4 @@ filterwarnings= ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning ignore:The host_callback APIs are deprecated .*:DeprecationWarning + ignore:Scan loop is disabled for fused ring attention.*:UserWarning diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 20b16c2809..4e4be7569f 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 7ef0d68474..2abcb28dec 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -1,48 +1,33 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +import os import pytest -from functools import partial - import jax import jax.numpy as jnp import numpy as np -from flax.linen import dot_product_attention from jax import random -from jax.sharding import Mesh, NamedSharding, PartitionSpec from distributed_test_base import ( generate_configs, generate_context_parallel_configs, generate_collectives_count, - compare_ops, -) -from utils import ( - make_causal_mask, - make_self_mask, - assert_tree_like_allclose, - assert_allclose, - print_debug_tensor_stats, ) -from transformer_engine.jax import fp8_autocast +from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, - fused_attn, AttnBiasType, AttnMaskType, QKVLayout, QKVFormat, - get_qkv_format, reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, CPStrategy, + ReorderStrategy, ) -from transformer_engine.jax.sharding import MeshResource -# We will use the golden reference model from our non distributed attention test fixture. -from test_fused_attn import general_dot_product_attention, make_mask -DTYPES = [jnp.float16, jnp.bfloat16] +DTYPES = [jnp.bfloat16] class TestDistributedSelfAttn: @@ -51,7 +36,7 @@ def generate_collectives_count_ref( self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype ): jax_dtype = jax.dtypes.canonicalize_dtype(dtype) - _, seqlen, _, heads, _ = shape + _, seqlen, heads, _ = shape is_dp_enabled = mesh_resource.dp_resource is not None tp_size = 1 if mesh_resource.tp_resource is not None: @@ -64,45 +49,28 @@ def generate_collectives_count_ref( # for loss and dbias return generate_collectives_count(allreduce=allreduce_total_bytes, allgather=0, other=0) - def generate_inputs(self, shape, mesh_resource, with_bias, attn_mask_type, dtype): - batch, seqlen, _, heads, _ = shape - - qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype) - - bias = ( - random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) - if with_bias - else None - ) - - mask = None - if attn_mask_type == AttnMaskType.PADDING_MASK: - mask = make_causal_mask(batch, seqlen) - elif attn_mask_type == AttnMaskType.CAUSAL_MASK: - mask = make_self_mask(batch, seqlen) - - qkv_pspec = PartitionSpec( - mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None - ) - bias_pspec = ( - PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None - ) - mask_pspec = ( - PartitionSpec(mesh_resource.dp_resource, None, None, None) - if attn_mask_type != AttnMaskType.NO_MASK - else None - ) - - return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]]) @pytest.mark.parametrize( - "attn_bias_type", - [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS], + "data_shape", + [ + pytest.param((32, 512, 12, 64), id="32-512-12-64"), + pytest.param((32, 1024, 16, 128), id="32-1024-16-128"), + ], ) @pytest.mark.parametrize( - "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK] + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) + @pytest.mark.parametrize( + "attn_mask_type", + [ + pytest.param(AttnMaskType.PADDING_MASK, id="PADDING_MASK"), + pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), + ], ) @pytest.mark.parametrize("dtype", DTYPES) def test_self_attn( @@ -113,14 +81,14 @@ def test_self_attn( mesh_resource, data_shape, attn_bias_type, + bias_shape, attn_mask_type, dtype, ): dropout_prob = 0.0 is_training = True - scaling_factor = 1.0 - _, seqlen, _, num_head, hidden = data_shape + batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( dtype, @@ -138,74 +106,37 @@ def test_self_attn( ): pytest.skip(f"No FusedAttn backend found") - def target_func(qkv, bias, mask): - return jnp.mean( - fused_attn( - (qkv,), - bias, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=QKVLayout.BS3HD, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - ) - ) - - def ref_func(qkv, bias, mask): - query, key, value = jnp.split(qkv, [1, 2], axis=-3) - query = jnp.squeeze(query) - key = jnp.squeeze(key) - value = jnp.squeeze(value) - - output = dot_product_attention( - query, - key, - value, - bias=bias, - mask=mask, - deterministic=is_training, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - - return jnp.mean(output).astype(dtype) - - with_bias = attn_bias_type != AttnBiasType.NO_BIAS - (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs( - data_shape, mesh_resource, with_bias, attn_mask_type, dtype + col_ref = self.generate_collectives_count_ref( + mesh_shape, + mesh_axes, + mesh_resource, + attn_bias_type != AttnBiasType.NO_BIAS, + data_shape, + dtype, ) - collective_count_ref = self.generate_collectives_count_ref( - mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_head, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + QKVLayout.BS3HD, + bias_shape, + None, + SeqDescFormat.Seqlens, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + coll_count_ref=col_ref, ) - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): - qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec)) - bias_ = ( - jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias - ) - mask_ = ( - jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask - ) - - grad_args = (0, 1) if with_bias else (0,) - out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,) - - compare_ops( - target_func, - ref_func, - [qkv_, bias_, mask_], - collective_count_ref, - grad_args=grad_args, - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(qkv_pspec, bias_pspec, mask_pspec), - out_shardings=(None, out_grad_shardings), - ) + runner.test_backward() class TestDistributedCrossAttn: @@ -215,31 +146,6 @@ def generate_collectives_count_ref(self): all_reduce_loss_bytes = 4 # 1 * FP32 return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) - def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype): - batch, seqlen, heads, hidden = shape - - q = random.normal(random.PRNGKey(1124), shape, dtype=dtype) - kv = random.normal(random.PRNGKey(1125), (batch, seqlen, 2, heads, hidden), dtype=dtype) - - mask = None - if attn_mask_type == AttnMaskType.PADDING_MASK: - mask = make_causal_mask(batch, seqlen) - elif attn_mask_type == AttnMaskType.CAUSAL_MASK: - mask = make_self_mask(batch, seqlen) - - q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None) - - kv_pspec = PartitionSpec( - mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None - ) - mask_pspec = ( - PartitionSpec(mesh_resource.dp_resource, None, None, None) - if attn_mask_type != AttnMaskType.NO_MASK - else None - ) - - return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]]) @pytest.mark.parametrize( @@ -250,11 +156,11 @@ def test_cross_attn( self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype ): attn_bias_type = AttnBiasType.NO_BIAS + bias_shape = None dropout_prob = 0.0 is_training = True - scaling_factor = 1.0 - _, seqlen, num_head, hidden = data_shape + batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( dtype, @@ -272,67 +178,30 @@ def test_cross_attn( ): pytest.skip(f"No FusedAttn backend found") - def target_func(q, kv, mask): - return jnp.mean( - fused_attn( - (q, kv), - None, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=QKVLayout.BSHD_BS2HD, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - ), - dtype=jnp.float32, - ) - - def ref_func(query, kv, mask): - key, value = jnp.split(kv, [1], axis=-3) - query = jnp.squeeze(query) - key = jnp.squeeze(key) - value = jnp.squeeze(value) - - output = dot_product_attention( - query, - key, - value, - bias=None, - mask=mask, - deterministic=is_training, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - - return jnp.mean(output, dtype=jnp.float32) - - (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs( - data_shape, mesh_resource, attn_mask_type, dtype + col_ref = self.generate_collectives_count_ref() + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_head, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + QKVLayout.BSHD_BS2HD, + bias_shape, + None, + SeqDescFormat.Seqlens, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + coll_count_ref=col_ref, ) - collective_count_ref = self.generate_collectives_count_ref() - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): - q_ = jax.device_put(q, NamedSharding(mesh, q_pspec)) - kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec)) - mask_ = ( - jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask - ) - - compare_ops( - target_func, - ref_func, - [q_, kv_, mask_], - collective_count_ref, - grad_args=(0, 1), - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(q_pspec, kv_pspec, mask_pspec), - out_shardings=(None, (q_pspec, kv_pspec)), - ) + runner.test_backward() @pytest.mark.parametrize( @@ -341,68 +210,34 @@ def ref_func(query, kv, mask): @pytest.mark.parametrize( "data_shape", [ - pytest.param([2, 512, 12, 128], id="2-512-12-128"), - pytest.param([4, 1024, 16, 64], id="4-1024-16-64"), - ], -) -@pytest.mark.parametrize("kv_groups", [1, 4, 8, 12, 16]) -@pytest.mark.parametrize( - "attn_mask_type", - [ - pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL_MASK"), - pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"), + # Sequence lengths will be scaled by CP so that we don't run with tiny sizes. + pytest.param([2, 128, 8, 128], id="2-128xCP-8-128"), + pytest.param([4, 256, 16, 64], id="4-256xCP-16-64"), ], ) -@pytest.mark.parametrize("dtype", [jnp.bfloat16]) +@pytest.mark.parametrize("kv_groups", [1, 8]) +@pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @pytest.mark.parametrize( - "qkv_layout", + "qkv_layout, attn_mask_type", [ - pytest.param(QKVLayout.BSHD_BS2HD, id="COMBINED_KV"), - pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"), + pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.CAUSAL_MASK, id="BSHD_KVPACKED-CAUSAL"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.CAUSAL_MASK, id="BSHD_SEPARATE-CAUSAL"), + pytest.param(QKVLayout.BSHD_BS2HD, AttnMaskType.NO_MASK, id="HD_KVPACKED-NO_MASK"), + pytest.param(QKVLayout.BSHD_BSHD_BSHD, AttnMaskType.NO_MASK, id="BSHD_SEPARATE-NO_MASK"), + pytest.param( + QKVLayout.THD_THD_THD, + AttnMaskType.PADDING_CAUSAL_MASK, + id="THD_SEPARATE-PADDING_CAUSAL", + ), ], ) @pytest.mark.parametrize( "load_balanced", - [pytest.param(False, id="UNBALANCED"), pytest.param(True, id="BALANCED")], + [pytest.param(True, id="BALANCED"), pytest.param(False, id="UNBALANCED")], ) class TestDistributedContextParallelSelfAttn: - def generate_inputs(self, shape, kv_groups: int, attn_mask_type: AttnMaskType, dtype): - batch, seqlen, heads, hidden = shape - kv_shape = (batch, seqlen, heads // kv_groups, hidden) - qkey, kkey, vkey = random.split(random.PRNGKey(1124), 3) - q = random.normal(qkey, shape, dtype=dtype) - k = random.normal(kkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) - v = random.normal(vkey, (batch, seqlen, heads // kv_groups, hidden), dtype=dtype) - - def gen_valid(bs, max_seqlen, pad_ratio): - pad_len = int(max_seqlen * pad_ratio) - valid_len = max_seqlen - pad_len - tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1) - return tokens, jnp.logical_not(tokens) - - from test_fused_attn import make_mask - - q_idx, _ = gen_valid(batch, seqlen, 0.0) - kv_idx, _ = gen_valid(batch, seqlen, 0.0) - mask = make_mask(q_idx, kv_idx, None, None, attn_mask_type) - - return q, k, v, mask - - def qkv_to_layout(self, q, k, v, qkv_layout): - qkv_args = () - match qkv_layout: - case QKVLayout.BSHD_BS2HD: - k, v = map(partial(jnp.expand_dims, axis=-3), [k, v]) - kv = jnp.concatenate((k, v), axis=-3) - qkv_args = (q, kv) - case QKVLayout.BSHD_BSHD_BSHD: - qkv_args = (q, k, v) - case _: - raise ValueError(f"Unsupported {qkv_layout=}") - return qkv_args - - def impl_test_contex_parallel_attn( + def impl_test_context_parallel_attn( self, device_count, mesh_shape, @@ -417,14 +252,44 @@ def impl_test_contex_parallel_attn( cp_strategy, ): attn_bias_type = AttnBiasType.NO_BIAS + bias_shape = None dropout_prob = 0.0 is_training = True dp_size, cp_size, tp_size = mesh_shape - qkv_format = get_qkv_format(qkv_layout) + qkv_format = qkv_layout.get_qkv_format() batch, seqlen, num_head, hidden = data_shape + + # Scale the sequence length by 2*CP so its never too small as we scale up test. + # 2*CP is used since we split into two CP groups for load balancing. + seqlen = seqlen * cp_size * 2 + data_shape = batch, seqlen, num_head, hidden + num_kv_heads = num_head // kv_groups - scaling_factor = 1.0 / np.sqrt(num_head) + + runner = FusedAttnRunner( + batch, + seqlen, + seqlen, + num_head, + num_kv_heads, + hidden, + attn_bias_type, + attn_mask_type, + dropout_prob, + dtype, + is_training, + qkv_layout, + bias_shape, + None, + SeqDescFormat.SegmentIDs, + number_of_devices=device_count, + mesh_shape=mesh_shape, + mesh_axes=mesh_axes, + mesh_resource=mesh_resource, + cp_strategy=cp_strategy, + cp_load_balanced=load_balanced, + ) def check_has_backend_for_mask(mask_type): return is_fused_attn_kernel_available( @@ -432,7 +297,7 @@ def check_has_backend_for_mask(mask_type): dtype, qkv_layout, attn_bias_type, - attn_mask_type, + mask_type, dropout_prob, num_head, num_kv_heads, @@ -460,125 +325,9 @@ def check_has_backend_for_mask(mask_type): if num_head % kv_groups != 0 or (num_head // kv_groups) % tp_size != 0: pytest.skip(f"Skipping {kv_groups=} not multiple of {data_shape=} or {tp_size=}") - def target_func(q, k, v, mask): - return fused_attn( - self.qkv_to_layout(q, k, v, qkv_layout), - None, # bias - mask, - None, # seed - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - qkv_layout=qkv_layout, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training, - context_parallel_strategy=cp_strategy, - context_parallel_causal_load_balanced=load_balanced, - context_parallel_axis="cp", - ).astype(dtype) - - def ref_func(q, k, v, mask): - output = general_dot_product_attention( - q, - k, - v, - bias=None, - mask=mask, - deterministic=not is_training, - scale_factor=scaling_factor, - dropout_rate=dropout_prob, - dropout_rng=None, - dtype=jnp.float32, - ) - return output.astype(dtype) - - def grad_func(func, *args, **kwargs): - # Gradient is small, use a gradient multiplier to amplify the gradient - _, max_seq_len, num_heads, _ = data_shape - gradient_multiplier = max_seq_len * num_heads - if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]: - gradient_multiplier /= 10 - ret_valid = func(*args, **kwargs) - return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype) - - q, k, v, mask = self.generate_inputs(data_shape, kv_groups, attn_mask_type, dtype) - - diff_argnums = (0, 1, 2) - - # Single GPU (reference) - ref_func_jit = jax.jit( - jax.value_and_grad( - lambda q, k, v, mask: grad_func(ref_func, q, k, v, mask), argnums=diff_argnums - ) - ) - ref_fwd, ref_grads = ref_func_jit(q, k, v, mask) - - # Multi GPU (function under test) - devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) - mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource, enabled=False): - qkv_ps = PartitionSpec( - mesh_resource.dp_resource, - mesh_resource.cp_resource, - mesh_resource.tp_resource, - None, - ) - qkv_sharding = NamedSharding(mesh, qkv_ps) - - mask_ps = PartitionSpec( - mesh_resource.dp_resource, None, mesh_resource.cp_resource, None - ) - mask_sharding = NamedSharding(mesh, mask_ps) - - reorder = partial( - reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format - ) - inverse_reorder = partial( - inverse_reorder_causal_load_balancing, cp_size=cp_size, tensor_format=qkv_format - ) - - if load_balanced: - q, k, v = jax.tree.map(reorder, (q, k, v)) - - q_, k_, v_ = map(partial(jax.device_put, device=qkv_sharding), [q, k, v]) - mask_ = jax.device_put(mask, device=mask_sharding) - - target_func_jit = jax.jit( - jax.value_and_grad( - lambda q, k, v, mask: grad_func(target_func, q, k, v, mask), - argnums=diff_argnums, - ), - in_shardings=[qkv_sharding, qkv_sharding, qkv_sharding, mask_sharding], - out_shardings=(None, (qkv_sharding, qkv_sharding, qkv_sharding)), - ) - - target_fwd, target_grads = target_func_jit(q_, k_, v_, mask_) - - if load_balanced: - target_dq, target_dk, target_dv = jax.tree.map(inverse_reorder, target_grads[0:3]) - target_grads = (target_dq, target_dk, target_dv, *target_grads[3:]) - - has_diffs = False - - print_debug_tensor_stats("target", target_fwd) - print_debug_tensor_stats("ref", ref_fwd) - print_debug_tensor_stats("diff", jnp.abs(target_fwd - ref_fwd)) - assert_allclose(target_fwd, ref_fwd, dtype=dtype) - - for i in range(len(target_grads)): - if ref_grads[i] is None or target_grads[i] is None: - # expect both none if one is - assert target_grads[i] is None and ref_grads[i] is None - else: - print_debug_tensor_stats(f"target_grad[{i}]", target_grads[i]) - print_debug_tensor_stats(f"ref_grad[{i}]", ref_grads[i]) - print_debug_tensor_stats( - f"diff_grad[{i}]", jnp.abs(target_grads[i] - ref_grads[i]) - ) - - assert_allclose(target_grads[i], ref_grads[i], dtype=dtype) - - def test_contex_parallel_allgather_attn( + runner.test_backward() + + def test_context_parallel_allgather_attn( self, device_count, mesh_shape, @@ -591,7 +340,9 @@ def test_contex_parallel_allgather_attn( qkv_layout, load_balanced, ): - return self.impl_test_contex_parallel_attn( + if qkv_layout.is_thd(): + pytest.skip("THD doesn't support all gather context parallelism.") + return self.impl_test_context_parallel_attn( device_count, mesh_shape, mesh_axes, @@ -605,6 +356,10 @@ def test_contex_parallel_allgather_attn( CPStrategy.ALL_GATHER, ) + @pytest.mark.parametrize( + "use_scan", + [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")], + ) def test_context_parallel_ring_attn( self, device_count, @@ -617,8 +372,17 @@ def test_context_parallel_ring_attn( dtype, qkv_layout, load_balanced, + use_scan, ): - return self.impl_test_contex_parallel_attn( + if use_scan: + os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" + else: + os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" + + if qkv_layout.is_thd() and not load_balanced: + pytest.skip("THD + ring doesn't support unbalanced context parallelism.") + + return self.impl_test_context_parallel_attn( device_count, mesh_shape, mesh_axes, @@ -631,6 +395,7 @@ def test_context_parallel_ring_attn( load_balanced, CPStrategy.RING, ) + del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] class TestReorderCausalLoadBalancing: @@ -644,17 +409,26 @@ class TestReorderCausalLoadBalancing: ], ) @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) - def test(self, cp_size, shape, qkv_format): + @pytest.mark.parametrize( + "reorder_strategy", + [ + pytest.param(ReorderStrategy.DualChunkSwap, id="DualChunkSwap"), + pytest.param(ReorderStrategy.Striped, id="Striped"), + ], + ) + def test(self, cp_size, shape, qkv_format, reorder_strategy): tensor = random.normal(random.PRNGKey(1124), shape, dtype=jnp.bfloat16) + seq_dim = 1 if qkv_format == QKVFormat.SBHD: tensor = tensor.swapaxes(0, 1) + seq_dim = 0 ref = tensor.copy() - reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2]) - inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2]) + reorder = jax.jit(reorder_causal_load_balancing, static_argnums=[1, 2, 3]) + inverse = jax.jit(inverse_reorder_causal_load_balancing, static_argnums=[1, 2, 3]) - reordered = reorder(tensor, cp_size, qkv_format) - inversed = inverse(reordered, cp_size, qkv_format) + reordered = reorder(tensor, reorder_strategy, cp_size, seq_dim) + inversed = inverse(reordered, reorder_strategy, cp_size, seq_dim) assert jnp.array_equal(inversed, ref) diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index f0dd56feaa..cc59ecfb34 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 38f7ec0d49..77b299e5bf 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import pytest @@ -271,7 +271,6 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, # input: [batch, seqlen, hidden] intermediate_dim=INTERMEDIATE, activations=activation_type, - dtype=dtype, use_bias=use_bias, ) params_single = ln_mlp_single.init(init_rngs, x) @@ -289,7 +288,6 @@ def _test_layernorm_mlp( transpose_batch_sequence=False, intermediate_dim=INTERMEDIATE, activations=activation_type, - dtype=dtype, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 0ed6b84fd5..8f48bc77dd 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_functions.py b/tests/jax/test_functions.py index d6da307fd3..48a2fb4f88 100644 --- a/tests/jax/test_functions.py +++ b/tests/jax/test_functions.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index af05538ef5..745f1cc633 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1,12 +1,12 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Tests for fused attention""" -from enum import Enum -from dataclasses import dataclass +from enum import Enum, auto +from dataclasses import dataclass, field from functools import partial from math import sqrt -from typing import Tuple, Optional +from typing import Tuple, Optional, Dict import random import jax @@ -19,25 +19,32 @@ from flax.linen.dtypes import promote_dtype from jax import Array from jax import value_and_grad, jit +from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike +from transformer_engine.jax import fp8_autocast +from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( AttnBiasType, AttnMaskType, QKVLayout, QKVFormat, + reorder_causal_load_balancing, + inverse_reorder_causal_load_balancing, fused_attn, - fused_attn_thd, - get_qkv_format, make_swa_mask, + SequenceDescriptor, + CPStrategy, + ReorderStrategy, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper -from transformer_engine.transformer_engine_jax import ( +from transformer_engine_jax import ( NVTE_Fused_Attn_Backend, get_cudnn_version, ) -from utils import assert_allclose +from distributed_test_base import assert_equal_collectives +from utils import assert_allclose, print_debug_tensor_stats @pytest.fixture(autouse=True, scope="module") @@ -50,6 +57,7 @@ def init(): yield +@partial(jax.jit, static_argnums=(5, 6, 7, 9)) def general_dot_product_attention( query: ArrayLike, key: ArrayLike, @@ -102,29 +110,36 @@ def general_dot_product_attention( return context -def is_causal_mask(mask: AttnMaskType): - """ - Check if the mask is a causal mask - """ - return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK] - - -def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array: +@jax.jit +def make_causal_mask( + segment_ids_q: ArrayLike, + segment_ids_kv: ArrayLike, + segment_pos_q: ArrayLike = None, + segment_pos_kv: ArrayLike = None, +) -> Array: """ Create inverse padded causal mask where `True` means allowing the corresponding position to participate in attention and `False` means masking out that position. + If segment_pos is not provided, aragne of the segment_ids will be applied. """ - q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape) - kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape) - inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal) + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal) return inv_causal_mask +@partial(jax.jit, static_argnums=(4, 5)) def make_mask( - q_token: ArrayLike, - kv_token: ArrayLike, - segment_pad_q: ArrayLike, - segment_pad_kv: ArrayLike, + segment_ids_q: ArrayLike, + segment_ids_kv: ArrayLike, + segment_pos_q: ArrayLike, + segment_pos_kv: ArrayLike, attn_mask_type: AttnMaskType, window_size: Optional[Tuple[int, int]] = None, ) -> Array: @@ -132,32 +147,45 @@ def make_mask( Create attention mask based on mask type. A `True` value in the mask means masking out the corresponding position and a `False` value means allowing that position to participate in attention. + + - segment_ids should start with 1, and using 0s for the paddings. + Expected that each segment starts without paddings. + - segment_pos marks the token position in the segments. + + A example pair of segments_ids and segment_pos: + segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5] + segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] """ + # segment masks inv_mask = make_attention_mask( - q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) + segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) ) - if is_causal_mask(attn_mask_type): - inv_causal_mask = make_causal_mask(q_token, kv_token) - inv_mask = combine_masks(inv_causal_mask, inv_mask) - if segment_pad_q is not None and segment_pad_kv is not None: - inv_pad_mask = make_attention_mask( - segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1) + + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape ) - inv_mask = combine_masks(inv_pad_mask, inv_mask) - if window_size is not None: - max_seqlen_q = inv_mask.shape[-2] - max_seqlen_kv = inv_mask.shape[-1] - inv_swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, window_size, attn_mask_type) - inv_swa_mask = jnp.broadcast_to(inv_swa_mask, inv_mask.shape) - # In inv_swa_mask and inv_mask 0 is masked out - inv_mask = jnp.where(inv_mask != 0, inv_swa_mask, inv_mask) + # causal mask + if attn_mask_type.is_causal(): + inv_causal_mask = make_attention_mask( + segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y) + ) + inv_mask = combine_masks(inv_causal_mask, inv_mask) + # sliding window mask + inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_) + inv_mask = combine_masks(inv_mask, inv_swa_mask) mask = jnp.logical_not(inv_mask) return mask -def get_seqlens_and_offsets(segment_ids, segment_pad): +@jax.jit +def get_seqlens_and_offsets(segment_ids): batch, max_seqlen = segment_ids.shape bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen)) seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) @@ -165,21 +193,17 @@ def get_seqlens_and_offsets(segment_ids, segment_pad): def _find_offsets(x): same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) - first_column = jnp.ones((x.shape[0], 1), dtype=bool) + first_column = x[..., :1] != 0 same_as_previous = jnp.hstack((first_column, same_as_previous)) return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))( same_as_previous ).squeeze(-1) offsets = _find_offsets(segment_ids) - offsets = jnp.insert(offsets, -1, values=-1, axis=-1) - if segment_pad is not None: - segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids) - padding_aware_seqlen = bincount_vmap(segment_id_with_paddings) - output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1) - else: - output = jnp.insert(seqlens, -1, values=0, axis=-1) - return output, offsets + offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1) + seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1) + seqlens = jnp.where(seqlens, seqlens, -1) + return seqlens, offsets @jax.jit @@ -200,8 +224,8 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs): query, key, value, - bias=bias, - mask=mask, + bias, + mask, deterministic=not kwargs["is_training"], scale_factor=kwargs["scaling_factor"], dropout_rate=kwargs["dropout_probability"], @@ -216,11 +240,7 @@ def customcall_fused_dpa( key, value, bias, - mask, - seqlens_q, - seqlens_kv, - offsets_q, - offsets_kv, + sequence_descriptor, dropout_rng, **kwargs, ): @@ -228,7 +248,6 @@ def customcall_fused_dpa( TE customcall dot product attention implementation """ qkv_layout = kwargs["qkv_layout"] - is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD match qkv_layout: case QKVLayout.BS3HD | QKVLayout.T3HD: query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value]) @@ -242,19 +261,9 @@ def customcall_fused_dpa( qkv_args = (query, key, value) case _: raise ValueError(f"Unsupported {qkv_layout=}") - if not is_thd: - kwargs.pop("max_segments_per_seq") - return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype) - return fused_attn_thd( - qkv_args, - bias, - seqlens_q, - seqlens_kv, - offsets_q, - offsets_kv, - dropout_rng, - **kwargs, - ).astype(query.dtype) + return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype( + query.dtype + ) class BiasShape(Enum): @@ -262,10 +271,16 @@ class BiasShape(Enum): Enum class to represent the different bias shapes used in the fused attention. """ - BIAS_1HSS = "1HSS" - BIAS_B1SS = "B1SS" - BIAS_BHSS = "BHSS" - BIAS_11SS = "11SS" + _1HSS = "1HSS" + _B1SS = "B1SS" + _BHSS = "BHSS" + _11SS = "11SS" + + +class SeqDescFormat(Enum): + Mask = auto() + Seqlens = auto() + SegmentIDs = auto() @dataclass @@ -287,31 +302,42 @@ class FusedAttnRunner: is_training: bool qkv_layout: QKVLayout bias_shape: BiasShape - window_size: Optional[Tuple[int, int]] = None + window_size: Tuple[int, int] + seq_desc_format: SeqDescFormat + + # Specifies sharding resources for distributed tests + number_of_devices: int = 1 + mesh_shape: tuple[int, ...] = (1, 1, 1) + mesh_axes: tuple[str, ...] = ("dp", "cp", "tp") + mesh_resource: MeshResource = field(default_factory=partial(MeshResource, "dp", "cp", "tp")) + + # Context parallel aux arguments + cp_strategy: CPStrategy = CPStrategy.DEFAULT + cp_load_balanced: bool = True + + # dictionary of expected collective comm bytes + coll_count_ref: Optional[Dict[str, int]] = None # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. def _get_max_segments_per_sequence(self): - if 90400 <= get_cudnn_version() < 90500: - return self.num_segments_per_seq + if self.qkv_layout.is_thd(): + if 90400 <= get_cudnn_version() < 90500: + return self.num_segments_per_seq + else: + # +1 for testing runtime_segments < max_segments + return self.num_segments_per_seq + 1 else: - # +1 for testing runtime_segments < max_segments - return self.num_segments_per_seq + 1 + return 1 def _check_configs(self): # TODO(rewang): probably adds this in is_fused_attn_available - if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [ - AttnMaskType.PADDING_MASK, - AttnMaskType.PADDING_CAUSAL_MASK, - ]: + if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") - qkv_format = get_qkv_format(self.qkv_layout) - if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD: + if self.qkv_layout.is_qkvpacked(): if self.max_seqlen_q != self.max_seqlen_kv: pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv") - - if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD: if self.num_heads_q != self.num_heads_kv: pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv") @@ -323,9 +349,9 @@ def _check_configs(self): self.backend = FusedAttnHelper( self.dtype, self.dtype, - self.qkv_layout.value, - self.attn_bias_type.value, - self.attn_mask_type.value, + self.qkv_layout, + self.attn_bias_type, + self.attn_mask_type, self.dropout_prob, self.num_heads_q, self.num_heads_kv, @@ -339,15 +365,11 @@ def _check_configs(self): if ( self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS - and self.bias_shape != BiasShape.BIAS_1HSS + and self.bias_shape != BiasShape._1HSS ): - if self.attn_mask_type not in [ - AttnMaskType.NO_MASK, - AttnMaskType.CAUSAL_MASK, - ]: + if self.attn_mask_type.is_padding(): pytest.skip( - "B1SS, BHSS and 11SS bias shapes are only supported for " - "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK." + "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask" ) elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: pytest.skip( @@ -357,6 +379,14 @@ def _check_configs(self): def _setup_inputs(self): self._check_configs() + + # Create a mesh for distributed tests + self.devices = np.asarray(jax.devices()[: self.number_of_devices]).reshape(*self.mesh_shape) + self.mesh = Mesh(self.devices, self.mesh_axes) + self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1) + self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1) + self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1) + key = jax.random.PRNGKey(0) q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5) @@ -370,18 +400,18 @@ def _setup_inputs(self): if self.attn_bias_type == AttnBiasType.NO_BIAS: bias_shape = None - elif self.bias_shape == BiasShape.BIAS_1HSS: + elif self.bias_shape == BiasShape._1HSS: bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv) - elif self.bias_shape == BiasShape.BIAS_B1SS: + elif self.bias_shape == BiasShape._B1SS: bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) - elif self.bias_shape == BiasShape.BIAS_BHSS: + elif self.bias_shape == BiasShape._BHSS: bias_shape = ( self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv, ) - elif self.bias_shape == BiasShape.BIAS_11SS: + elif self.bias_shape == BiasShape._11SS: bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv) else: pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") @@ -391,7 +421,7 @@ def _setup_inputs(self): self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0) if self.attn_bias_type != AttnBiasType.NO_BIAS: - if self.bias_shape == BiasShape.BIAS_1HSS: + if self.bias_shape == BiasShape._1HSS: self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0) else: # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for @@ -408,10 +438,10 @@ def _setup_inputs(self): else: self.bias = None - if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: - pad_ratio = 0.0 - else: + if self.attn_mask_type.is_padding(): pad_ratio = 0.3 + else: + pad_ratio = 0.0 def gen_valid(bs, max_seqlen, pad_ratio): pad_len = int(max_seqlen * pad_ratio) @@ -420,13 +450,20 @@ def gen_valid(bs, max_seqlen, pad_ratio): return tokens, jnp.logical_not(tokens) def generate_random_segment_ids( - batch_size, sequence_length, num_segments, seed, with_segment_pad=True + batch_size, + sequence_length, + num_segments, + seed, + with_segment_pad=True, + min_segment_len=None, ): rng = np.random.default_rng(seed=seed) # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad - segment_ids = np.zeros((batch_size, sequence_length), dtype=int) + segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32) + segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32) + # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0] # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad - segment_pad = np.zeros((batch_size, sequence_length), dtype=int) + segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32) # Not include paddings max_segment_size = sequence_length // num_segments @@ -434,69 +471,210 @@ def generate_random_segment_ids( current_pos = 0 segment_id = 1 - for _ in range(num_segments): - segment_size = rng.integers(1, max_segment_size + 1) + for seg_id in range(num_segments): + # min_segment_len is to force kv_len >= q_len because cuDNN kernels failed + # TODO(rewang): Remove this constrain after cuDNN supports + min_segment_size = 1 + if min_segment_len is not None: + min_segment_size = min_segment_len[i][seg_id] + segment_size = rng.integers(min_segment_size, max_segment_size + 1) if current_pos + segment_size > sequence_length: break segment_end = current_pos + segment_size segment_ids[i, current_pos:segment_end] = segment_id + segment_pos[i, current_pos:segment_end] = np.arange(segment_size) if with_segment_pad: - num_valid = rng.integers(1, segment_size + 1) + num_valid = rng.integers(min_segment_size, segment_size + 1) segment_pad[i, current_pos + num_valid : segment_end] = 1 current_pos = segment_end segment_id += 1 segment_pad[i, current_pos:sequence_length] = 1 - return segment_ids, segment_pad - if get_qkv_format(self.qkv_layout) == QKVFormat.THD: + segment_ids, segment_pos, segment_pad = map( + jnp.asarray, [segment_ids, segment_pos, segment_pad] + ) + segment_ids = jnp.where(segment_pad, 0, segment_ids) + return segment_ids, segment_pos, segment_pad + + if self.qkv_layout.is_thd(): self.num_segments_per_seq = 2 - self.token_q, self.segment_pad_q = generate_random_segment_ids( + self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) - # TODO(rewang): Check if qkvpacked supported different q/kv - # TODO(rewang): Causal with different q/kv segment_id fails - if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type): - self.token_kv = self.token_q - self.segment_pad_kv = self.segment_pad_q + self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) + # TODO(rewang): record only self attention and find the reason of cross attention + if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv: + self.segment_ids_kv = self.segment_ids_q + self.segment_pos_kv = self.segment_pos_q + self.pad_kv = self.pad_q else: - self.token_kv, self.segment_pad_kv = generate_random_segment_ids( + # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support + min_segment_len = None if self.window_size is None else self.seqlens_q + self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024, + min_segment_len=min_segment_len, ) - self.pad_q = self.segment_pad_q - self.pad_kv = self.segment_pad_kv + self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) else: self.num_segments_per_seq = 1 - self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio) - self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio) - self.segment_pad_q = self.segment_pad_kv = None + self.segment_ids_q, self.pad_q = gen_valid( + self.batch_size, self.max_seqlen_q, pad_ratio + ) + self.segment_ids_kv, self.pad_kv = gen_valid( + self.batch_size, self.max_seqlen_kv, pad_ratio + ) + self.segment_pos_q = self.segment_pos_kv = None + self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None + # For reference code self.mask = make_mask( - self.token_q, - self.token_kv, - self.segment_pad_q, - self.segment_pad_kv, + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, self.attn_mask_type, self.window_size, ) - if get_qkv_format(self.qkv_layout) == QKVFormat.THD: - self.seqlens_q, self.offsets_q = get_seqlens_and_offsets( - self.token_q, self.segment_pad_q + if self.cp_size > 1 and self.cp_load_balanced: + if self.qkv_layout.is_thd(): + reorder_strategy = ReorderStrategy.Striped + else: + reorder_strategy = ReorderStrategy.DualChunkSwap + + seq_dim = 0 if self.qkv_layout.get_qkv_format() == QKVFormat.SBHD else 1 + self.cp_reorder_fn = partial( + reorder_causal_load_balancing, + strategy=reorder_strategy, + cp_size=self.cp_size, + seq_dim=seq_dim, ) - self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets( - self.token_kv, self.segment_pad_kv + self.cp_inverse_reorder_fn = partial( + inverse_reorder_causal_load_balancing, + strategy=reorder_strategy, + cp_size=self.cp_size, + seq_dim=seq_dim, ) - self.mask_for_customcall = None # THD format doesn't support mask else: - self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None - self.mask_for_customcall = self.mask + # no-ops for non cp or non load balanced + self.cp_reorder_fn = lambda x: x + self.cp_inverse_reorder_fn = lambda x: x + + # Test different input formats + if self.qkv_layout.is_thd(): + match self.seq_desc_format: + case SeqDescFormat.Mask: + pytest.skip("THD doesn't support mask input") + case SeqDescFormat.Seqlens: + self.sequence_desciptor = SequenceDescriptor.from_seqlens_and_offsets( + (self.seqlens_q, self.seqlens_kv), + (self.offsets_q, self.offsets_kv), + ) + case SeqDescFormat.SegmentIDs: + self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( + ( + self.cp_reorder_fn(self.segment_ids_q), + self.cp_reorder_fn(self.segment_ids_kv), + ), + ( + self.cp_reorder_fn(self.segment_pos_q), + self.cp_reorder_fn(self.segment_pos_kv), + ), + ) + case _: + raise ValueError(f"Unknown {self.seq_desc_format=}") + else: + match self.seq_desc_format: + case SeqDescFormat.Mask: + if self.attn_mask_type == AttnMaskType.NO_MASK: + self.sequence_desciptor = None + else: + self.sequence_desciptor = make_mask( + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, + self.attn_mask_type, + ) + case SeqDescFormat.Seqlens: + self.sequence_desciptor = SequenceDescriptor.from_seqlens( + ( + self.segment_ids_q.sum(axis=-1).astype(jnp.int32), + self.segment_ids_kv.sum(axis=-1).astype(jnp.int32), + ), + ) + case SeqDescFormat.SegmentIDs: + self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos( + (self.segment_ids_q, self.segment_ids_kv), + None, + ) + case _: + raise ValueError(f"Unknown {self.seq_desc_format=}") self.dropout_rng = dropout_key if self.dropout_prob > 0 else None self.scaling_factor = 1.0 / sqrt(self.head_dim) + # Setup distributed sharding specs + # Setup shardings for distributed tests + self.qkvo_psec = PartitionSpec( + self.mesh_resource.dp_resource, + self.mesh_resource.cp_resource, + self.mesh_resource.tp_resource, + None, + ) + self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec) + + mask_pspec = PartitionSpec( + self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None + ) + self.mask_sharding = NamedSharding(self.mesh, mask_pspec) + + match self.seq_desc_format: + case SeqDescFormat.Mask: + self.seq_desc_sharding = self.mask_sharding + case _: + + def to_dp_shardings(x): + if x.ndim == 1: + pspec = PartitionSpec(self.mesh_resource.dp_resource) + else: + pspec = PartitionSpec( + self.mesh_resource.dp_resource, self.mesh_resource.cp_resource + ) + return NamedSharding(self.mesh, pspec) + + self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_desciptor) + + if self.bias_shape == BiasShape._1HSS: + self.bias_pspec = PartitionSpec( + None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None + ) + elif self.bias_shape == BiasShape._B1SS: + self.bias_pspec = PartitionSpec( + self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None + ) + elif self.bias_shape == BiasShape._11SS: + self.bias_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None) + else: + self.bias_pspec = PartitionSpec() + self.bias_sharding = NamedSharding(self.mesh, self.bias_pspec) + + self.dropout_rng_pspec = PartitionSpec( + None, + ) + self.dropout_rng_sharding = NamedSharding(self.mesh, self.dropout_rng_pspec) + + self.logit_scale_pspec = PartitionSpec(None, None, self.mesh_resource.cp_resource, None) + self.logit_scale_sharding = NamedSharding(self.mesh, self.logit_scale_pspec) + + # [batch][max_segments_per_batch] + # TODO(mgoldfarb-nvidia): Will need to handle CP cases of replicated or distributed length/offset. + self.seq_length_offset_pspec = PartitionSpec(self.mesh_resource.dp_resource, None) + self.seq_length_offset_sharding = NamedSharding(self.mesh, self.seq_length_offset_pspec) + def test_forward(self): """ Test forward without JIT @@ -504,17 +682,17 @@ def test_forward(self): self._setup_inputs() args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng] + customcall_args = [ - self.q, - self.k, - self.v, - self.bias, - self.mask_for_customcall, - self.seqlens_q, - self.seqlens_kv, - self.offsets_q, - self.offsets_kv, - self.dropout_rng, + # Put test data onto each GPU for distributed. + # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and + # THD params once we support those features on CP. + jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), + jax.device_put(self.bias, self.bias_sharding), + jax.device_put(self.sequence_desciptor, self.seq_desc_sharding), + jax.device_put(self.dropout_rng, self.dropout_rng_sharding), ] kwargs = { "attn_bias_type": self.attn_bias_type, @@ -525,10 +703,27 @@ def test_forward(self): "qkv_layout": self.qkv_layout, "max_segments_per_seq": self._get_max_segments_per_sequence(), "window_size": self.window_size, + "context_parallel_strategy": self.cp_strategy, + "context_parallel_causal_load_balanced": self.cp_load_balanced, } - # Convert the outputs to float32 for the elementwise comparison - primitive_out = customcall_fused_dpa(*customcall_args, **kwargs) + customcall_fused_dpa_jit = jit( + partial(customcall_fused_dpa, **kwargs), + static_argnames=kwargs.keys(), + in_shardings=[ + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + self.seq_desc_sharding, + self.dropout_rng_sharding, + ], + ) + + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + primitive_out = customcall_fused_dpa_jit(*customcall_args) + primitive_out = self.cp_inverse_reorder_fn(primitive_out) + reference_out = jax_dpa(*args, **kwargs) if self.is_training and self.dropout_prob > 0.0: @@ -541,38 +736,55 @@ def test_forward(self): assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype) assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) + if self.coll_count_ref is not None: + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + target_hlo = ( + customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text() + ) + assert_equal_collectives(target_hlo, self.coll_count_ref) + def test_backward(self): """ - Test value_and_grad with JIT, which includes both forward and backward + Test value_and_grad with JIT, which includes both forward and backward. + + If coll_count_ref is not None then the HLO of the backwrds function + HLO will be examined for the expected comms. """ self._setup_inputs() - if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS: - pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.") - def grad_func(func, *args, **kwargs): + def grad_func(func, *args, cp_reverse_out=False, **kwargs): # Gradient is small, use a gradient multiplier to amplify the gradient gradient_multiplier = self.max_seqlen_q * self.num_heads_q - if is_causal_mask(self.attn_mask_type): + if self.attn_mask_type.is_causal(): gradient_multiplier /= 10 # Keep only valid result for the gradient - ret_valid = jnp.where( - self.pad_q[..., jnp.newaxis, jnp.newaxis], 0, func(*args, **kwargs) - ) - return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(self.dtype) + if not cp_reverse_out: + ret_valid = jnp.where( + self.pad_q[..., jnp.newaxis, jnp.newaxis], + 0, + func(*args, **kwargs), + ) + else: + ret_valid = jnp.where( + self.pad_q[..., jnp.newaxis, jnp.newaxis], + 0, + self.cp_inverse_reorder_fn(func(*args, **kwargs)), + ) + return ( + jnp.mean(ret_valid.astype(jnp.float32), dtype=jnp.float32) * gradient_multiplier + ).astype(self.dtype) args = [self.q, self.k, self.v, self.bias, self.mask, self.dropout_rng] customcall_args = [ - self.q, - self.k, - self.v, - self.bias, - self.mask_for_customcall, - self.seqlens_q, - self.seqlens_kv, - self.offsets_q, - self.offsets_kv, - self.dropout_rng, + # TODO(mgoldfarb-nvidia): We will need to add reordering for bias, mas and + # THD params once we support those features on CP. + jax.device_put(self.cp_reorder_fn(self.q), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding), + jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding), + jax.device_put(self.bias, self.bias_sharding), + jax.device_put(self.sequence_desciptor, self.seq_desc_sharding), + jax.device_put(self.dropout_rng, self.dropout_rng_sharding), ] kwargs = { "attn_bias_type": self.attn_bias_type, @@ -583,19 +795,40 @@ def grad_func(func, *args, **kwargs): "qkv_layout": self.qkv_layout, "max_segments_per_seq": self._get_max_segments_per_sequence(), "window_size": self.window_size, + "context_parallel_strategy": self.cp_strategy, + "context_parallel_causal_load_balanced": self.cp_load_balanced, } # We can compute dBias only for the [1, h, s, s] layout - arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2) + if self.bias_shape == BiasShape._1HSS: + arg_nums = (0, 1, 2, 3) + grad_shardings = ( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + ) + else: + arg_nums = (0, 1, 2) + grad_shardings = (self.qkvo_sharding, self.qkvo_sharding, self.qkvo_sharding) # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( value_and_grad( lambda q, k, v, bias, *args: grad_func( - customcall_fused_dpa, q, k, v, bias, *args, **kwargs + customcall_fused_dpa, q, k, v, bias, *args, cp_reverse_out=True, **kwargs ), arg_nums, - ) + ), + in_shardings=( + self.qkvo_sharding, + self.qkvo_sharding, + self.qkvo_sharding, + self.bias_sharding, + self.seq_desc_sharding, + self.dropout_rng_sharding, + ), + out_shardings=(None, grad_shardings), ) jitted_reference = jit( value_and_grad( @@ -604,20 +837,31 @@ def grad_func(func, *args, **kwargs): ) ) - primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) + reference_out, reference_dgrad = jitted_reference(*args) # Skip elementwise comparison when dropout enabled if self.dropout_prob > 0.0: return + print_debug_tensor_stats(f"primitive_out", primitive_out) + print_debug_tensor_stats(f"reference_grad_valid", reference_out) + print_debug_tensor_stats(f"diff_grad", jnp.abs(primitive_out - reference_out)) assert_allclose(primitive_out, reference_out, dtype=self.dtype) - def check_dqkv(primitive, reference, pad): + def check_dqkv(primitive, reference, pad, idx): primitive_valid, primitive_invalid, reference_valid, reference_invalid = ( _split_valid_and_invalid(primitive, reference, pad) ) + print_debug_tensor_stats(f"primitive_grad_valid[{idx}]", primitive_valid[idx]) + print_debug_tensor_stats(f"reference_grad_valid[{idx}]", reference_valid[idx]) + print_debug_tensor_stats( + f"diff_grad[{idx}]", jnp.abs(primitive_valid[idx] - reference_valid[idx]) + ) + assert_allclose(primitive_invalid, jnp.zeros_like(primitive_invalid), dtype=self.dtype) assert_allclose(primitive_invalid, reference_invalid, dtype=self.dtype) assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) @@ -625,11 +869,17 @@ def check_dqkv(primitive, reference, pad): primitive_dq, primitive_dk, primitive_dv = primitive_dgrad[:3] reference_dq, reference_dk, reference_dv = reference_dgrad[:3] - check_dqkv(primitive_dq, reference_dq, self.pad_q) - check_dqkv(primitive_dk, reference_dk, self.pad_kv) - check_dqkv(primitive_dv, reference_dv, self.pad_kv) + primitive_dq = self.cp_inverse_reorder_fn(primitive_dq) + primitive_dk = self.cp_inverse_reorder_fn(primitive_dk) + primitive_dv = self.cp_inverse_reorder_fn(primitive_dv) + + check_dqkv(primitive_dq, reference_dq, self.pad_q, 0) + check_dqkv(primitive_dk, reference_dk, self.pad_kv, 1) + check_dqkv(primitive_dv, reference_dv, self.pad_kv, 2) + + if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: + # TODO(mgoldfarb-nvidia): Inverse reorder bias once supported by a CP implementation. - if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS: primitive_dbias = primitive_dgrad[3] reference_dbias = reference_dgrad[3] @@ -657,17 +907,12 @@ def check_dqkv(primitive, reference, pad): dtype=self.dtype, ) + if self.coll_count_ref is not None: + with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() + assert_equal_collectives(target_hlo, self.coll_count_ref) + -@pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"), - ], -) @pytest.mark.parametrize( "attn_mask_type", [ @@ -691,10 +936,7 @@ def check_dqkv(primitive, reference, pad): @pytest.mark.parametrize( "b, s_q, s_kv, h_q, h_kv, d, dtype", [ - pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"), - pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"), pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"), - pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"), pytest.param( 2, 2048, @@ -705,8 +947,8 @@ def check_dqkv(primitive, reference, pad): jnp.bfloat16, id="2-2048-1024-12-12-64-BF16-CROSS", ), - pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"), pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"), + pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"), ], ) @pytest.mark.parametrize( @@ -723,6 +965,14 @@ def check_dqkv(primitive, reference, pad): pytest.param(True, id="SWA"), ], ) +@pytest.mark.parametrize( + "seq_desc_format", + [ + pytest.param(SeqDescFormat.Mask, id="Mask"), + pytest.param(SeqDescFormat.Seqlens, id="Seqlens"), + pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"), + ], +) class TestFusedAttn: """ Fused attention tester @@ -736,6 +986,16 @@ class TestFusedAttn: pytest.param(False, id="INFERENCE"), ], ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"), + ], + ) def _test_forward( b, s_q, @@ -751,6 +1011,7 @@ def _test_forward( qkv_layout, bias_shape, swa, + seq_desc_format, ): """ Test forward with parameterized configs @@ -775,10 +1036,18 @@ def _test_forward( qkv_layout, bias_shape, window_size, + seq_desc_format, ) runner.test_forward() @staticmethod + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) def test_backward( b, s_q, @@ -793,6 +1062,7 @@ def test_backward( qkv_layout, bias_shape, swa, + seq_desc_format, ): """ Test backward with parameterized configs @@ -815,5 +1085,6 @@ def test_backward( qkv_layout, bias_shape, window_size, + seq_desc_format, ) runner.test_backward() diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index 3a0add0a38..e906a37414 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 55c09b4562..ed15913f38 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -1,17 +1,22 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Test transformer_engine.jax.flax.TransformerLayer""" import os from functools import partial -from typing import Dict, Tuple +from typing import Dict, Tuple, Optional import flax import jax import jax.numpy as jnp import pytest -from utils import assert_allclose, assert_tree_like_allclose, sync_params_values +from utils import ( + assert_allclose, + assert_tree_like_allclose, + dtype_tols, + sync_params_values, +) from utils import DecoderLayer as RefDecoderLayer from utils import EncoderLayer as RefEncoderLayer @@ -250,12 +255,18 @@ def _sync_params(self, ref, target): target = sync_params_values(target, ref, self.transformations) return ref, target - def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): + def test_forward( + self, + data_shape: Tuple[int], + dtype: jnp.dtype, + rtol: Optional[float] = None, + atol: Optional[float] = None, + ) -> None: """Test only the forward""" inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) - ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs) - layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs) + ref_layer_cls = partial(self.reference_layer, **self.attrs) + layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs) ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks) test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks) @@ -264,14 +275,21 @@ def test_forward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): ref_out = self._loss_fn(inputs, ref_masks, ref_params, ref_others, ref_layer) test_out = self._loss_fn(inputs, test_masks, test_params, test_others, test_layer) - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) + tols = dtype_tols(dtype, rtol=rtol, atol=atol) + assert_allclose(ref_out, test_out, **tols) - def test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): + def test_backward( + self, + data_shape: Tuple[int], + dtype: jnp.dtype, + rtol: Optional[float] = None, + atol: Optional[float] = None, + ) -> None: """Test forward and backward through value_and_grad()""" inputs, (ref_masks, test_masks) = self.generate_inputs(data_shape, dtype) - ref_layer_cls = partial(self.reference_layer, dtype=dtype, **self.attrs) - layer_cls = partial(TransformerLayer, layer_type=self.layer_type, dtype=dtype, **self.attrs) + ref_layer_cls = partial(self.reference_layer, **self.attrs) + layer_cls = partial(TransformerLayer, layer_type=self.layer_type, **self.attrs) ref_layer, ref_params, ref_others = self._generate_layer(ref_layer_cls, inputs, ref_masks) test_layer, test_params, test_others = self._generate_layer(layer_cls, inputs, test_masks) @@ -302,11 +320,12 @@ def test_backward(self, data_shape, dtype, rtol=1e-05, atol=1e-08): inputs, test_masks, test_params, test_others, test_layer ) - assert_allclose(ref_out, test_out, rtol=rtol, atol=atol) - assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol) + tols = dtype_tols(dtype, rtol=rtol, atol=atol) + assert_allclose(ref_out, test_out, **tols) + assert_tree_like_allclose(ref_dgrads, test_dgrads, **tols) _, restructed_ref_wgrads = self._sync_params(ref_wgrads, test_wgrads) - assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, rtol=rtol, atol=atol) + assert_tree_like_allclose(restructed_ref_wgrads, test_wgrads, **tols) class EncoderRunner(BaseRunner): @@ -418,12 +437,12 @@ class BaseTester: def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" FP8Helper.finalize() # Ensure FP8 disabled. - self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5) + self.runner(attrs).test_forward(data_shape, dtype) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" FP8Helper.finalize() # Ensure FP8 disabled. - self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5) + self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_format", FP8_FORMATS) diff --git a/tests/jax/test_misc.py b/tests/jax/test_misc.py index 67145daf63..6db492921d 100644 --- a/tests/jax/test_misc.py +++ b/tests/jax/test_misc.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 8ac8ecbe79..935eb290e4 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_sanity_import.py b/tests/jax/test_sanity_import.py index f47c2eb411..5e1bca2c9c 100644 --- a/tests/jax/test_sanity_import.py +++ b/tests/jax/test_sanity_import.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_sharding.py b/tests/jax/test_sharding.py index 4581cdc39e..0d50b73451 100644 --- a/tests/jax/test_sharding.py +++ b/tests/jax/test_sharding.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_softmax.py b/tests/jax/test_softmax.py index 49e32e503c..8cc8448979 100644 --- a/tests/jax/test_softmax.py +++ b/tests/jax/test_softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Tests for the softmax primitives""" diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 78a6225e1f..dba7cb64fc 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Utility for the TE layer tests""" @@ -19,7 +19,11 @@ from jax import nn as jax_nn from jax import random as jax_random -from transformer_engine.jax.attention import AttnMaskType, make_swa_mask +from transformer_engine.jax.attention import ( + AttnMaskType, + canonicalize_attn_mask_type, + make_swa_mask, +) from transformer_engine.jax.fp8 import DType as TEDType PRNGKey = Any @@ -106,7 +110,7 @@ class DotProductAttention(nn.Module): Args: dropout_rate: dropout rate - dtype: the dtype of the computation (default: float32) + dtype: the data type used to allocate the initial parameters (default: float32). float32_logits: bool, if True then compute logits in float32 to avoid numerical issues with bfloat16. """ @@ -191,6 +195,7 @@ def __call__( attn_weights = attn_weights * multiplier attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) + attn_weights = attn_weights.astype(value.dtype) # Take the linear combination of `value`. if self.transpose_batch_sequence: @@ -205,7 +210,7 @@ class DenseGeneral(nn.Module): Attributes: features: tuple with numbers of output features. axis: tuple with axes to apply the transformation on. - dtype: the dtype of the computation (default: float32). + dtype: the data type used to allocate the initial parameters (default: float32). kernel_init: initializer function for the weight matrix. use_bias: whether to add a bias to the output (default: False). bias_init: initializer function for the bias vector. @@ -222,7 +227,9 @@ class DenseGeneral(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) super().__post_init__() @nn.compact @@ -235,6 +242,7 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ + input_dtype = inputs.dtype features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) @@ -244,23 +252,24 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features)) kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes ) - kernel = jnp.asarray(kernel, self.dtype) + kernel = jnp.asarray(kernel, input_dtype) kernel = jnp.reshape(kernel, kernel_shape) if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes + "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes ) - bias = bias.astype(self.dtype) + bias = bias.astype(input_dtype) else: bias = None contract_ind = tuple(range(0, len(axis))) y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ()))) + y = y.astype(input_dtype) if bias is not None: y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) @@ -277,7 +286,7 @@ class MlpBlock(nn.Module): kernel_init: Kernel function, passed to the dense layers. deterministic: Whether the dropout layers should be deterministic. intermediate_dropout_rate: Dropout rate used after the intermediate layers. - dtype: Type for the dense layer. + dtype: the data type used to allocate the initial parameters (default: float32). """ transpose_batch_sequence: bool @@ -292,7 +301,9 @@ class MlpBlock(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) super().__post_init__() @nn.compact @@ -354,6 +365,9 @@ def __call__(self, inputs, deterministic: bool = False): bias_axes="embed", name="wo", )(x) + assert ( + output.dtype == inputs.dtype + ), f"input.dtype={input.dtype}, output.dtype={output.dtype}" return output @@ -425,7 +439,7 @@ class MultiHeadAttention(nn.Module): should be divisible by the number of heads. num_gqa_groups: number of kv attention heads head_dim: dimension of each head. - dtype: the dtype of the computation. + dtype: the data type used to allocate the initial parameters (default: float32). dropout_rate: dropout rate kernel_init: initializer for the kernel of the Dense layers. float32_logits: bool, if True then compute logits in float32 to avoid @@ -449,7 +463,9 @@ class MultiHeadAttention(nn.Module): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", dtype=self.dtype + ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads super().__post_init__() @@ -734,6 +750,9 @@ def qkv_init(key, shape, dtype): dtype=self.dtype, name="out", )(x) + assert ( + inputs_q.dtype == inputs_kv.dtype == out.dtype + ), f"q.dtype={inputs_q.dtype}, kv.dtype={inputs_kv.dtype}, out.dtype={out.dtype}" return out @@ -759,13 +778,13 @@ def __post_init__(self): def __call__(self, x: jnp.ndarray) -> jnp.ndarray: """Applies layer normalization on the input.""" - x = jnp.asarray(x, jnp.float32) + input_dtype = x.dtype features = x.shape[-1] scale = nn_partitioning.param_with_axes( - "scale", self.scale_init, (features,), jnp.float32, axes=("embed",) + "scale", self.scale_init, (features,), self.dtype, axes=("embed",) ) - scale = jnp.asarray(scale, self.dtype) + scale = jnp.asarray(scale, input_dtype) if self.layernorm_type == "layernorm": mean = jnp.mean(x, axis=-1, keepdims=True) @@ -773,9 +792,9 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: y = (x - mean) * lax.rsqrt(var + self.epsilon) bias = nn_partitioning.param_with_axes( - "ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",) + "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",) ) - bias = jnp.asarray(bias, self.dtype) + bias = jnp.asarray(bias, input_dtype) if not self.zero_centered_gamma: z = y * scale + bias @@ -788,7 +807,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: y = x * lax.rsqrt(mean2 + self.epsilon) z = y * scale - return jnp.asarray(z, self.dtype) + assert z.dtype == x.dtype, f"output_dtype={z.dtype}, input_dtype={x.dtype}" + return z class RelativePositionBiases(nn.Module): @@ -801,7 +821,7 @@ class RelativePositionBiases(nn.Module): distance bucket. num_heads: Number of heads in the attention layer. Each head will get a different relative position weighting. - dtype: Type of arrays through this module. + dtype: the data type used to allocate the initial parameters (default: float32). embedding_init: initializer for relative embedding table. """ @@ -913,24 +933,16 @@ def apply_swa_mask( window_size: Tuple[int, int] = (-1, -1), ) -> Array: """Apply the sliding window mask to a given mask""" - mask_map = { - "no_mask": AttnMaskType.NO_MASK, - "padding": AttnMaskType.PADDING_MASK, - "causal": AttnMaskType.CAUSAL_MASK, - "padding_causal": AttnMaskType.PADDING_CAUSAL_MASK, - "causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - "padding_causal_bottom_right": AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, - } - _attn_mask_type = mask_map.get(attn_mask_type, None) + _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type) assert _attn_mask_type is not None + batch = original_mask.shape[0] max_seqlen_q = original_mask.shape[-2] max_seqlen_kv = original_mask.shape[-1] - swa_mask = make_swa_mask( - max_seqlen_q, max_seqlen_kv, window_size, _attn_mask_type, dtype=original_mask.dtype - ) + pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q)) + pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv)) + swa_mask = make_swa_mask(pos_q, pos_kv, window_size, original_mask.dtype) # In swa_mask and original_mask 0 is masked out - swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape) - new_mask = jnp.where(original_mask == 1, swa_mask_bcast, original_mask) + new_mask = jnp.where(original_mask == 1, swa_mask, original_mask) return new_mask @@ -1091,6 +1103,7 @@ def __call__(self, inputs, encoder_mask=None, deterministic=False): dtype=self.dtype, name="output_layernorm", )(y) + assert y.dtype == inputs.dtype, f"output_dtype={y.dtype}, input_dtype={inputs.dtype}" return y @@ -1297,6 +1310,7 @@ def __call__( name="output_layernorm", )(z) + assert z.dtype == inputs.dtype, f"output_dtype={z.dtype}, input_dtype={inputs.dtype}" return z @@ -1391,18 +1405,26 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08): def dtype_tols( dtype: Union[DType, TEDType, np.dtype], reference_value: float = 1.0, + rtol: Optional[float] = None, + atol: Optional[float] = None, ) -> Dict[str, float]: """Expected numerical tolerance for a data type. Args: dtype: data type. reference_value: reference value (default: 1). + rtol: override for relative tolerance estimate + atol: override for absolute tolerance estimate Returns: Dictionary with "rtol" and "atol" as keys """ + # Return immediately if tolerances are fully specified + if rtol is not None and atol is not None: + return {"rtol": rtol, "atol": atol} + # Convert to JAX dtype if needed if isinstance(dtype, TEDType): dtype = { @@ -1420,7 +1442,11 @@ def dtype_tols( # Expect bit-wise accuracy for integer dtypes if not jnp.issubdtype(dtype, jnp.floating): - return dict(rtol=0, atol=0) + if rtol is None: + rtol = 0.0 + if atol is None: + atol = 0.0 + return {"rtol": rtol, "atol": atol} # Estimate floating-point error finfo = jnp.finfo(dtype) @@ -1433,10 +1459,11 @@ def dtype_tols( spacing_high = jnp.nextafter(reference_value, finfo.max) - reference_value spacing_low = reference_value - jnp.nextafter(reference_value, finfo.min) ulp = max(spacing_high.item(), spacing_low.item()) - return dict( - rtol=eps_relaxed, - atol=max(ulp, eps_relaxed), - ) + if rtol is None: + rtol = eps_relaxed + if atol is None: + atol = max(ulp, eps_relaxed) + return {"rtol": rtol, "atol": atol} def sync_params_values(dst, src, transformations, sep="/"): diff --git a/tests/paddle/dist_launcher.py b/tests/paddle/dist_launcher.py deleted file mode 100644 index 8c417b1930..0000000000 --- a/tests/paddle/dist_launcher.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Helper functions to launch distributed tests""" - -import copy -import os -from pathlib import Path -import subprocess -import time -import unittest - -try: - from paddle.base import core -except ImportError: - from paddle.fluid import core -from paddle.distributed.utils.launch_utils import ( - TrainerProc, - find_free_ports, - get_cluster, - watch_local_trainers, -) - -__all__ = ["TestDistributed"] - - -def get_cluster_from_args(selected_gpus): - """Get node information from selected GPUs""" - cluster_node_ips = "127.0.0.1" - node_ip = "127.0.0.1" - - node_ips = [x.strip() for x in cluster_node_ips.split(",")] - - node_ips.index(node_ip) - - free_ports = None - - free_ports = find_free_ports(len(selected_gpus)) - if free_ports is not None: - free_ports = list(free_ports) - - trainer_endpoints = [] - for ip in node_ips: - trainer_endpoints.append([f"{ip}:{port}" for port in free_ports]) - return get_cluster(node_ips, node_ip, trainer_endpoints, selected_gpus) - - -def get_gpus(selected_gpus): - """Get selected GPU string""" - selected_gpus = [x.strip() for x in selected_gpus.split(",")] - return selected_gpus - - -def start_local_trainers( - cluster, - pod, - training_script, - training_script_args, - allocator_strategy="auto_growth", -): - """Launch trainers""" - current_env = copy.copy(os.environ.copy()) - # paddle broadcast ncclUniqueId use socket, and - # proxy maybe make trainers unreachable, so delete them. - # if we set them to "", grpc will log error message "bad uri" - # so just delete them. - current_env.pop("http_proxy", None) - current_env.pop("https_proxy", None) - - procs = [] - for t in pod.trainers: - proc_env = { - "FLAGS_selected_gpus": ",".join([str(g) for g in t.gpus]), - "PADDLE_TRAINER_ID": f"{t.rank}", - "PADDLE_CURRENT_ENDPOINT": f"{t.endpoint}", - "PADDLE_TRAINERS_NUM": f"{cluster.trainers_nranks()}", - "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints()), - "PYTHONPATH": str(Path(__file__).resolve().parent), - } - - proc_env["FLAGS_allocator_strategy"] = allocator_strategy - if allocator_strategy == "auto_growth": - proc_env["FLAGS_fraction_of_gpu_memory_to_use"] = "0.1" - - current_env.update(proc_env) - - print(f"trainer proc env:{current_env}") - - if os.getenv("WITH_COVERAGE", "OFF") == "ON": - cmd = "python -m coverage run --branch -p " + training_script - else: - cmd = "python -u " + training_script - - print(f"start trainer proc:{cmd} env:{proc_env}") - - fn = None - - proc = subprocess.Popen( - cmd.split(" ") + training_script_args, env=current_env - ) # pylint: disable=consider-using-with - - tp = TrainerProc() - tp.proc = proc - tp.rank = t.rank - tp.log_fn = fn - tp.cmd = cmd - - procs.append(tp) - - return procs - - -class TestDistributed(unittest.TestCase): - """Base class for distributed test""" - - @staticmethod - def run_2gpu( - target_file_name, - allocator_strategy="auto_growth", - ): - """Run target file in subprocesses""" - if not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0: - return - - selected_gpus = get_gpus("0,1") - cluster = None - pod = None - - cluster, pod = get_cluster_from_args(selected_gpus) - - procs = start_local_trainers( - cluster, - pod, - allocator_strategy=allocator_strategy, - training_script=target_file_name, - training_script_args=[], - ) - - while True: - alive = watch_local_trainers(procs, cluster.trainers_endpoints()) - - if not alive: - print(f"Local procs complete, POD info:{pod}") - break - time.sleep(3) diff --git a/tests/paddle/parallel_tests/amax_reduction.py b/tests/paddle/parallel_tests/amax_reduction.py deleted file mode 100644 index c4605f121e..0000000000 --- a/tests/paddle/parallel_tests/amax_reduction.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for Linear layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet - -from utils import assert_allclose, set_random_seed -import transformer_engine.paddle as te - - -def assert_allclose_across_ranks(tensor, group=None): - """Assert tensor is identical in all ranks""" - gathered_list = [] - paddle.distributed.all_gather(gathered_list, tensor, group=group) - assert len(gathered_list) > 1 - for gathered_tensor in gathered_list: - assert_allclose(tensor, gathered_tensor) - - -class TestAmaxReduction(unittest.TestCase): - """Tests Amax reduction""" - - def setUp(self): - self.data_parallel_size = 2 - self.init_dist_env() - self.global_dtype = "bfloat16" - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - strategy.hybrid_configs = { - "dp_degree": self.data_parallel_size, - "mp_degree": 1, - "pp_degree": 1, - } - fleet.init(is_collective=True, strategy=strategy) - - def test_amax_reduction(self): - """Tests column parallel linear""" - set_random_seed(1024) - layer1 = te.Linear(16, 16) - layer2 = te.Linear(16, 16) - model = paddle.nn.Sequential(layer1, layer2) - model = fleet.distributed_model(model) - - rank_id = paddle.distributed.get_rank() - set_random_seed(rank_id) - - optimizer = paddle.optimizer.SGD(learning_rate=10.0, parameters=model.parameters()) - optimizer = fleet.distributed_optimizer(optimizer) - - def train_one_step(layer, inp, optimizer): - inp = paddle.to_tensor(inp) - inp.stop_gradient = False - out = layer(inp) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([16, 16], self.global_dtype) - with te.fp8_autocast(enabled=True): - train_one_step(model, inp, optimizer) - - assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].amax_history[-1]) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_fwd"].scale_inv) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].amax_history[-1]) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_fwd"].scale_inv) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].amax_history[-1]) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale) - assert_allclose_across_ranks(layer1.fp8_meta["scaling_bwd"].scale_inv) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].amax_history[-1]) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale) - assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/attention_tp.py b/tests/paddle/parallel_tests/attention_tp.py deleted file mode 100644 index e145f20b39..0000000000 --- a/tests/paddle/parallel_tests/attention_tp.py +++ /dev/null @@ -1,234 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for Transformer layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks -import transformer_engine.paddle as te - - -class TestAttentionTp(unittest.TestCase): - """Tests MultiHeadAttention layer with model parallel in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-3 - self.atol = 5e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False): - inp, mask = inp_list - if sequence_parallel: - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - with te.fp8_autocast(enabled=fp8_enabled): - out = layer(input_parallel, mask) - if sequence_parallel: - total_out = mp_ops._c_concat(out, group=self.tp_group) - total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss, total_out - - def test_parallel_layer(self): - """Tests parallel Transformer""" - set_random_seed(1024) - common_args = ( - self.hidden_size, - self.num_heads, - ) - common_kwargs = { - "layernorm_epsilon": self.eps, - "attention_dropout": 0.0, - "attn_mask_type": self.mask_type, - "attention_type": "self", - "tp_group": self.tp_group, - "input_layernorm": True, - } - - layer_tp = te.MultiHeadAttention( - *common_args, - **common_kwargs, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel, - ) - layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False) - - def _get_total_weight(local_weight, tp_group, axis, interleave=False): - total_weight = [] - partial_weight = local_weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) - if interleave: - # Due to the interleaved qkv layout, need to concat on num_head - # dimension for column parallel linear in MultiHeadAttention layer - assert axis == 0 - assert [ - 3 * self.hidden_size // self.world_size, - self.hidden_size, - ] == partial_weight.shape - local_num_head = self.num_heads // self.world_size - for idx, _ in enumerate(total_weight): - total_weight[idx] = total_weight[idx].reshape( - [3, local_num_head, -1, self.hidden_size] - ) - total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size]) - else: - total_weight = paddle.concat(total_weight, axis=axis) - return total_weight - - def _get_weight(obj, weight_names): - for name in weight_names: - obj = getattr(obj, name) - return obj - - def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False): - weight_src = _get_weight(layer_src, weight_names) - weight_dst = _get_weight(layer_dst, weight_names) - if partition_mode is None: - total_weight = weight_src - elif partition_mode == "column": - total_weight = _get_total_weight( - weight_src, tp_group=self.tp_group, axis=0, interleave=interleave - ) - elif partition_mode == "row": - total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) - else: - raise ValueError(f"Partition Mode {partition_mode} is not supported.") - assert ( - weight_dst.shape == total_weight.shape - ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." - weight_dst.copy_(total_weight, True) - - copy_weight(layer_tp, layer_single, None, ["layernorm_qkv", "ln_weight"]) - copy_weight(layer_tp, layer_single, "column", ["layernorm_qkv", "weight"], interleave=True) - copy_weight(layer_tp, layer_single, "row", ["proj", "weight"]) - - if self.sequence_parallel: - register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1) - - optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters()) - optimizer_single = paddle.optimizer.SGD( - learning_rate=0.01, parameters=layer_single.parameters() - ) - - layer_tp = fleet.distributed_model(layer_tp) - optimizer_tp = fleet.distributed_optimizer(optimizer_tp) - - for _ in range(5): - inp = paddle.uniform( - [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype - ) - mask = paddle.zeros( - shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool" - ) - loss_tp, out_tp = self._train_one_step( - layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel - ) - loss_single, out_single = self._train_one_step( - layer_single, [inp, mask], optimizer_single, self.fp8 - ) - assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol) - assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) - - -class TestAttentionTpFp8(TestAttentionTp): - """Tests MultiHeadAttention layer with model parallel in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 5e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestAttentionSp(TestAttentionTp): - """Tests MultiHeadAttention layer with sequence parallel in BF16""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-3 - self.atol = 5e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestAttentionSpFp8(TestAttentionTp): - """Tests MultiHeadAttention layer with sequence parallel in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 1e-1 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/group_sharding.py b/tests/paddle/parallel_tests/group_sharding.py deleted file mode 100644 index 11060be38e..0000000000 --- a/tests/paddle/parallel_tests/group_sharding.py +++ /dev/null @@ -1,188 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for group sharding""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import ( - DygraphShardingOptimizer, -) - -from utils import assert_allclose, set_random_seed -import transformer_engine.paddle as te - - -class TestGroupSharding(unittest.TestCase): - """Tests group sharding""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def set_attr(self): - """Set test configs""" - self.sharding_degree = 2 - self.global_dtype = "float32" - self.rtol = 1e-5 - self.atol = 1e-5 - self.batch_size = 16 - self.in_channels = 16 - self.out_channels = 32 - self.fp8 = False - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": 1, - "pp_degree": 1, - "sharding_degree": self.sharding_degree, - } - self.strategy = strategy - fleet.init(is_collective=True, strategy=strategy) - - def _get_model_and_optimizer(self, model, stage): - if stage == 1: - optimizer = DygraphShardingOptimizer( - paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()), - fleet.get_hybrid_communicate_group(), - ) - model = fleet.distributed_model(model) - optimizer = fleet.distributed_optimizer(optimizer) - elif stage in [2, 3]: - optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters()) - group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group() - - class ShardingLevel: # pylint: disable=too-few-public-methods, - """Paddle sharding options""" - - kStage1 = "os" - kStage2 = "os_g" - kStage3 = "p_g_os" - - level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2 - model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel( - model=model, - optimizer=optimizer, - level=level, - group=group, - segment_size=256, - ) - else: - raise ValueError(f"Stage {stage} not supported") - return model, optimizer - - def test_group_sharding_stage1(self): - """Tests group sharding training""" - set_random_seed(1024) - model_te = te.Linear(self.in_channels, self.out_channels) - model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) - model_pd.weight.copy_(model_te.weight.T, True) - model_pd.bias.copy_(model_te.bias, True) - - model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=1) - model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=1) - - rank_id = paddle.distributed.get_rank() - paddle.seed(rank_id) - - def train_one_step(model, inp, optimizer): - out = model(inp) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) - with te.fp8_autocast(enabled=False): - loss_te = train_one_step(model_te, inp, optimizer_te) - loss_pd = train_one_step(model_pd, inp, optimizer_pd) - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - assert ( - len(optimizer_te.state_dict()) == 4 - ), "Expect each rank to hold 4 optimizer state entries." - - def test_group_sharding_stage2(self): - """Tests group sharding training""" - set_random_seed(1024) - model_te = te.Linear(self.in_channels, self.out_channels) - model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) - model_pd.weight.copy_(model_te.weight.T, True) - model_pd.bias.copy_(model_te.bias, True) - - model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=2) - model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=2) - - rank_id = paddle.distributed.get_rank() - paddle.seed(rank_id) - - def train_one_step(model, inp, optimizer): - out = model(inp) - loss = out.mean() - loss.backward() - # Check gradients are split to different trainers - if rank_id == 0: - assert model.bias.grad is None and model.weight.grad is not None - elif rank_id == 1: - assert model.weight.grad is None and model.bias.grad is not None - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) - with te.fp8_autocast(enabled=False): - loss_te = train_one_step(model_te, inp, optimizer_te) - loss_pd = train_one_step(model_pd, inp, optimizer_pd) - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - assert ( - len(optimizer_te.state_dict()) == 4 - ), "Expect each rank to hold 4 optimizer state entries." - - def test_group_sharding_stage3(self): - """Tests group sharding training""" - set_random_seed(1024) - model_te = te.Linear(self.in_channels, self.out_channels) - model_pd = paddle.nn.Linear(self.in_channels, self.out_channels) - model_pd.weight.copy_(model_te.weight.T, True) - model_pd.bias.copy_(model_te.bias, True) - - model_te, optimizer_te = self._get_model_and_optimizer(model_te, stage=3) - model_pd, optimizer_pd = self._get_model_and_optimizer(model_pd, stage=3) - - rank_id = paddle.distributed.get_rank() - paddle.seed(rank_id) - - def train_one_step(model, inp, optimizer): - out = model(inp) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_channels], self.global_dtype) - with te.fp8_autocast(enabled=False): - loss_te = train_one_step(model_te, inp, optimizer_te) - loss_pd = train_one_step(model_pd, inp, optimizer_pd) - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - for name, value in optimizer_te.state_dict().items(): - if name.endswith("w_0_moment1_0"): - assert ( - value.numel() == self.in_channels * self.out_channels // self.sharding_degree - ), "Expect optimizer state to be sharded across trainers." - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_linear_tp.py b/tests/paddle/parallel_tests/layernorm_linear_tp.py deleted file mode 100644 index 02295a71da..0000000000 --- a/tests/paddle/parallel_tests/layernorm_linear_tp.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for LayerNormLinear layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, assert_shape, set_random_seed -import transformer_engine.paddle as te - - -class TestLayerNormLinearTp(unittest.TestCase): - """Tests LayerNormLinear layer with column/row parallelism in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): - inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ["none", "column", "row"] - if split_input == "column": - split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] - elif split_input == "row": - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - input_parallel.stop_gradient = False - out = layer(input_parallel) - if gather_output: - total_out = mp_ops._c_concat(out, group=self.tp_group) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - if split_input != "none": - grad_input = [] - paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == "column": - grad_input = paddle.concat(grad_input, axis=1) - elif split_input == "row": - grad_input = paddle.concat(grad_input, axis=0) - else: - grad_input = input_parallel.grad - return loss, grad_input - - def test_column_parallel_layer(self): - """Tests column parallel LayerNormLinear""" - set_random_seed(1024) - layer_te = te.LayerNormLinear( - self.in_features, - self.out_features, - eps=self.eps, - parallel_mode="column", - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.LayerNormLinear( - self.in_features, - self.out_features, - eps=self.eps, - backend="paddle", - ) - # Get total weight - total_weight = [] - partial_weight = layer_te.weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) - total_weight = paddle.concat(total_weight, axis=0) - layer_pd.weight.copy_(total_weight.T, True) - - assert_shape( - layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features] - ) - assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="row" if self.sequence_parallel else "none", - gather_output=True, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - -class TestLayerNormLinearTpFp8(TestLayerNormLinearTp): - """Tests LayernormLinear layer with column/row parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestLayerNormLinearSp(TestLayerNormLinearTp): - """Tests LayernormLinear layer with sequence parallelism""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestLayerNormLinearSpFp8(TestLayerNormLinearTp): - """Tests LayernormLinear layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/layernorm_mlp_tp.py b/tests/paddle/parallel_tests/layernorm_mlp_tp.py deleted file mode 100644 index f23cfb9e3f..0000000000 --- a/tests/paddle/parallel_tests/layernorm_mlp_tp.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for LayerNormMLP layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, assert_shape, set_random_seed -import transformer_engine.paddle as te - - -class TestLayerNormMLPTp(unittest.TestCase): - """Tests LayerNormMLP layer with model parallel in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): - inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ["none", "column", "row"] - if split_input == "column": - split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] - elif split_input == "row": - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - input_parallel.stop_gradient = False - out = layer(input_parallel) - if gather_output: - # Need to concat on the first dim, while _c_concat concats on the last dim - total_out = mp_ops._c_concat(out.T, group=self.tp_group).T - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - if split_input != "none": - grad_input = [] - paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == "column": - grad_input = paddle.concat(grad_input, axis=1) - elif split_input == "row": - grad_input = paddle.concat(grad_input, axis=0) - else: - grad_input = input_parallel.grad - return loss, grad_input - - def test_parallel_layer(self): - """Tests parallel LayerNormMLP""" - set_random_seed(1024) - layer_te = te.LayerNormMLP( - hidden_size=self.hidden_size, - ffn_hidden_size=self.ffn_hidden_size, - eps=self.eps, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.LayerNormMLP( - hidden_size=self.hidden_size, - ffn_hidden_size=self.ffn_hidden_size, - eps=self.eps, - set_parallel_mode=False, - backend="paddle", - ) - - def _get_total_weight(local_weight, tp_group, axis): - total_weight = [] - partial_weight = local_weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) - total_weight = paddle.concat(total_weight, axis=axis) - return total_weight - - # Get total weight - total_fc1_weight = _get_total_weight(layer_te.fc1_weight, tp_group=self.tp_group, axis=0) - total_fc2_weight = _get_total_weight(layer_te.fc2_weight, tp_group=self.tp_group, axis=1) - layer_pd.fc1_weight.copy_(total_fc1_weight.T, True) - layer_pd.fc2_weight.copy_(total_fc2_weight.T, True) - - assert_shape( - layer_te.fc1_weight, - [self.ffn_hidden_size // self.model_parallel_size, self.hidden_size], - ) - assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size]) - assert_shape( - layer_te.fc2_weight, - [self.hidden_size, self.ffn_hidden_size // self.model_parallel_size], - ) - assert_shape(layer_te.fc2_bias, [self.hidden_size]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.hidden_size], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="row" if self.sequence_parallel else "none", - gather_output=self.sequence_parallel, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - -class TestLayerNormMLPTpFp8(TestLayerNormMLPTp): - """Tests LayerNormMLP layer with tensor parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestLayerNormMLPSp(TestLayerNormMLPTp): - """Tests LayerNormMLP layer with sequence parallel in BF16""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestLayerNormMLPSpFp8(TestLayerNormMLPTp): - """Tests LayerNormMLP layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 32 - self.ffn_hidden_size = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/linear_pp.py b/tests/paddle/parallel_tests/linear_pp.py deleted file mode 100644 index 0e7e90611e..0000000000 --- a/tests/paddle/parallel_tests/linear_pp.py +++ /dev/null @@ -1,235 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for Linear layer in pipeline parallel""" - -import unittest - -import numpy as np - -import paddle -from paddle.distributed import fleet - -from paddle.distributed.fleet.meta_parallel import ( - LayerDesc, - PipelineLayer, -) - -from utils import assert_allclose, set_random_seed -import transformer_engine.paddle as te - - -class TELinear(te.Linear): - """To pass is_first_microbatch""" - - def __init__(self, *args, **kwargs): - assert "accumulate_steps" in kwargs - self.accumulate_steps = kwargs["accumulate_steps"] - del kwargs["accumulate_steps"] - self._micro_batch_id = 0 - super().__init__(*args, **kwargs) - - def forward(self, *args, **kwargs): - kwargs["is_first_microbatch"] = (self._micro_batch_id % self.accumulate_steps) == 0 - if paddle.is_grad_enabled() and self.training: - self._micro_batch_id += 1 - return super().forward(*args, **kwargs) - - -class TEPipelineModel(PipelineLayer): - """Model for pipeline parallel test""" - - def __init__( - self, - in_features, - hidden_features, - weight_attrs, - use_te=True, - use_fp8=False, - accumulate_steps=1, - **kwargs, - ): - self.in_features = in_features - self.hidden_features = hidden_features - self.fp8 = use_fp8 - hcg = fleet.get_hybrid_communicate_group() - self.dp_group = hcg.get_data_parallel_group() - - Linear = TELinear if use_te else paddle.nn.Linear - extra_kwargs = {} - if use_te: - extra_kwargs["accumulate_steps"] = accumulate_steps - - model_desc = [ - LayerDesc( - Linear, - self.in_features, - self.hidden_features, - weight_attr=weight_attrs[0], - **extra_kwargs, - ), - LayerDesc( - Linear, - self.hidden_features, - self.in_features, - weight_attr=weight_attrs[1], - **extra_kwargs, - ), - ] - super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs) - - def forward(self, *args, **kwargs): - with te.fp8_autocast(enabled=self.fp8, fp8_group=self.dp_group): - return super().forward(*args, **kwargs) - - -class StandaloneModel(paddle.nn.Layer): - """Model for pipeline parallel test""" - - def __init__(self, in_features, hidden_features, weight_attrs): - super().__init__() - self.in_features = in_features - self.hidden_features = hidden_features - Linear = paddle.nn.Linear - self.layer = paddle.nn.Sequential( - Linear(self.in_features, self.hidden_features, weight_attr=weight_attrs[0]), - Linear(self.hidden_features, self.in_features, weight_attr=weight_attrs[1]), - ) - self.loss = paddle.nn.CrossEntropyLoss() - - def forward(self, inp): - out = self.layer(inp[0]) - loss = self.loss(out, inp[1]) - return loss - - -class TestLinearPipelineParallel(unittest.TestCase): - """Tests Linear layer with pipeline parallel""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.pipeline_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": 1, - "pp_degree": self.pipeline_parallel_size, - } - self.accumulate_steps = self.batch_size // self.micro_batch_size - strategy.pipeline_configs = { - "accumulate_steps": self.accumulate_steps, - "micro_batch_size": self.micro_batch_size, - } - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - - def set_attr(self): - """Set test configs""" - self.batch_size = 32 - self.micro_batch_size = 16 - self.in_features = 32 - self.hidden_features = 64 - self.global_dtype = "float32" - self.rtol = 1e-5 - self.atol = 1e-5 - self.iter = 10 - self.fp8 = False - - def test_pipeline_train(self): - """Test pipeline parallel training""" - set_random_seed(1024) - np.random.seed(1024) - - weight1_np = np.random.normal(size=[self.in_features, self.hidden_features]) - weight2_np = np.random.normal(size=[self.hidden_features, self.in_features]) - weight_attrs = [ - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np)), - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np)), - ] - weight_attrs_transposed = [ - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight1_np.T)), - paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(weight2_np.T)), - ] - - pipe_model = TEPipelineModel( - self.in_features, - self.hidden_features, - weight_attrs_transposed, - use_te=True, - use_fp8=self.fp8, - seg_method="layer:Linear", - num_stages=self.pipeline_parallel_size, - accumulate_steps=self.accumulate_steps, - ) - - # Check if model is split across ranks as expected - for name, sublayer in pipe_model.named_sublayers(): - if name in ("_loss_fn", "shared_layers"): - continue - if self.rank == 0: - assert tuple(sublayer.weight.shape) == weight1_np.T.shape, ( - f"Shape does not match, expect: {weight1_np.T.shape}, " - f"actual: {tuple(sublayer.weight.shape)}" - ) - elif self.rank == 1: - assert tuple(sublayer.weight.shape) == weight2_np.T.shape, ( - f"Shape does not match, expect: {weight2_np.T.shape}, " - f"actual: {tuple(sublayer.weight.shape)}" - ) - - standalone_model = StandaloneModel( - self.in_features, - self.hidden_features, - weight_attrs, - ) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters()) - optimizer_pd = paddle.optimizer.SGD( - learning_rate=0.1, parameters=standalone_model.parameters() - ) - - pipe_model = fleet.distributed_model(pipe_model) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - def train_one_step(layer, inp, optimizer): - loss = layer(inp) - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss - - for i in range(self.iter): - inp = paddle.to_tensor( - np.random.normal(size=[self.batch_size, self.in_features]), dtype=self.global_dtype - ) - label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1])) - loss_te = pipe_model.train_batch([inp, label], optimizer_te) - loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd) - print(f"Iter: {i}, loss_te: {loss_te.item()}, loss_pd: {loss_pd.item()}") - assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol) - - -class TestLinearPipelineParallelFP8(TestLinearPipelineParallel): - """Tests Linear layer with column/row parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 32 - self.micro_batch_size = 16 - self.in_features = 32 - self.hidden_features = 64 - self.global_dtype = "float32" - self.rtol = 5e-2 - self.atol = 5e-2 - self.iter = 10 - self.fp8 = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/linear_tp.py b/tests/paddle/parallel_tests/linear_tp.py deleted file mode 100644 index 4a49474a37..0000000000 --- a/tests/paddle/parallel_tests/linear_tp.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for Linear layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, assert_shape, set_random_seed -import transformer_engine.paddle as te - - -class TestLinearTp(unittest.TestCase): - """Tests Linear layer with column/row parallelism in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False): - inp = paddle.to_tensor(inp, stop_gradient=True) - assert split_input in ["none", "column", "row"] - if split_input == "column": - split_size = inp.shape[1] // self.world_size - input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)] - elif split_input == "row": - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - input_parallel.stop_gradient = False - out = layer(input_parallel) - if gather_output: - total_out = mp_ops._c_concat(out, group=self.tp_group) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - if split_input != "none": - grad_input = [] - paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group) - if split_input == "column": - grad_input = paddle.concat(grad_input, axis=1) - elif split_input == "row": - grad_input = paddle.concat(grad_input, axis=0) - else: - grad_input = input_parallel.grad - return loss, grad_input - - def test_column_parallel_layer(self): - """Tests column parallel linear""" - set_random_seed(1024) - layer_te = te.Linear( - self.in_features, - self.out_features, - parallel_mode="column", - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.Linear( - self.in_features, - self.out_features, - backend="paddle", - ) - # Get total weight - total_weight = [] - partial_weight = layer_te.weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) - total_weight = paddle.concat(total_weight, axis=0) - layer_pd.weight.copy_(total_weight.T, True) - - assert_shape( - layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features] - ) - assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="row" if self.sequence_parallel else "none", - gather_output=True, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - def test_row_parallel_layer(self): - """Tests row parallel linear""" - set_random_seed(1024) - layer_te = te.Linear( - self.in_features, - self.out_features, - parallel_mode="row", - sequence_parallel=self.sequence_parallel, - ) - layer_pd = te.Linear( - self.in_features, - self.out_features, - backend="paddle", - ) - # Get total weight - total_weight = [] - partial_weight = layer_te.weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=self.tp_group) - total_weight = paddle.concat(total_weight, axis=1) - layer_pd.weight.copy_(total_weight.T, True) - - assert_shape( - layer_te.weight, [self.out_features, self.in_features // self.model_parallel_size] - ) - assert_shape(layer_te.bias, [self.out_features]) - - optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters()) - optimizer_pd = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_pd.parameters()) - - layer_te = fleet.distributed_model(layer_te) - optimizer_te = fleet.distributed_optimizer(optimizer_te) - - for _ in range(5): - inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype) - with te.fp8_autocast(enabled=self.fp8): - loss_tp, grad_input = self._train_one_step( - layer_te, - inp, - optimizer_te, - split_input="column", - gather_output=self.sequence_parallel, - ) - loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd) - assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol) - assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol) - - -class TestLinearTpFP8(TestLinearTp): - """Tests Linear layer with column/row parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.fp8 = True - self.sequence_parallel = False - - -class TestLinearSp(TestLinearTp): - """Tests Linear layer with sequence parallelism""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-3 - self.atol = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestLinearSpFP8(TestLinearTp): - """Tests Linear layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.in_features = 32 - self.out_features = 64 - self.global_dtype = "bfloat16" - self.rtol = 1e-2 - self.atol = 1e-2 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/parallel_tests/transformer_tp.py b/tests/paddle/parallel_tests/transformer_tp.py deleted file mode 100644 index 5506be042f..0000000000 --- a/tests/paddle/parallel_tests/transformer_tp.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Unittest for Transformer layer in tensor parallel""" - -import unittest - -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.layers.mpu import mp_ops - -from utils import assert_allclose, set_random_seed, register_sequence_parallel_allreduce_hooks -import transformer_engine.paddle as te - - -class TestTransformerTp(unittest.TestCase): - """Tests Transformer layer with model parallel in BF16""" - - def setUp(self): - self.set_attr() - self.init_dist_env() - paddle.set_default_dtype(self.global_dtype) - - def init_dist_env(self): - """Init Paddle Fleet environment""" - strategy = fleet.DistributedStrategy() - self.model_parallel_size = 2 - strategy.hybrid_configs = { - "dp_degree": 1, - "mp_degree": self.model_parallel_size, - "pp_degree": 1, - } - strategy.hybrid_configs["mp_configs"].need_broadcast_data = False - fleet.init(is_collective=True, strategy=strategy) - self.rank = fleet.worker_index() - self.hcg = fleet.get_hybrid_communicate_group() - self.tp_group = self.hcg.get_model_parallel_group() - self.world_size = self.hcg.get_model_parallel_world_size() - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 5e-2 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = False - - def _train_one_step(self, layer, inp_list, optimizer, fp8_enabled, sequence_parallel=False): - inp, mask = inp_list - if sequence_parallel: - split_size = inp.shape[0] // self.world_size - input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :] - else: - input_parallel = inp - with te.fp8_autocast(enabled=fp8_enabled): - out = layer(input_parallel, mask) - if sequence_parallel: - total_out = mp_ops._c_concat(out, group=self.tp_group) - total_out = paddle.concat(paddle.split(total_out, self.world_size, axis=-1), axis=0) - else: - total_out = out - loss = total_out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - return loss, total_out - - def test_parallel_layer(self): - """Tests parallel Transformer""" - set_random_seed(1024) - common_args = [ - self.hidden_size, - self.ffn_hidden_size, - self.num_heads, - ] - common_kwargs = { - "layernorm_epsilon": self.eps, - "hidden_dropout": 0.0, - "attention_dropout": 0.0, - "self_attn_mask_type": self.mask_type, - "layer_type": self.layer_type, - } - layer_tp = te.TransformerLayer( - *common_args, - **common_kwargs, - set_parallel_mode=True, - sequence_parallel=self.sequence_parallel, - ) - layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False) - - def _get_total_weight(local_weight, tp_group, axis, interleave=False): - total_weight = [] - partial_weight = local_weight.clone().detach() - paddle.distributed.all_gather(total_weight, partial_weight, group=tp_group) - if interleave: - # Due to the interleaved qkv layout, need to concat on num_head - # dimension for column parallel linear in MultiHeadAttention layer - assert axis == 0 - assert [ - 3 * self.hidden_size // self.world_size, - self.hidden_size, - ] == partial_weight.shape - local_num_head = self.num_heads // self.world_size - for idx, _ in enumerate(total_weight): - total_weight[idx] = total_weight[idx].reshape( - [3, local_num_head, -1, self.hidden_size] - ) - total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size]) - else: - total_weight = paddle.concat(total_weight, axis=axis) - return total_weight - - def _get_weight(obj, weight_names): - for name in weight_names: - obj = getattr(obj, name) - return obj - - def copy_weight(layer_src, layer_dst, partition_mode, weight_names, interleave=False): - weight_src = _get_weight(layer_src, weight_names) - weight_dst = _get_weight(layer_dst, weight_names) - if partition_mode is None: - total_weight = weight_src - elif partition_mode == "column": - total_weight = _get_total_weight( - weight_src, tp_group=self.tp_group, axis=0, interleave=interleave - ) - elif partition_mode == "row": - total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1) - else: - raise ValueError(f"Partition Mode {partition_mode} is not supported.") - assert ( - weight_dst.shape == total_weight.shape - ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match." - weight_dst.copy_(total_weight, True) - - copy_weight(layer_tp, layer_single, None, ["self_attention", "layernorm_qkv", "ln_weight"]) - copy_weight( - layer_tp, - layer_single, - "column", - ["self_attention", "layernorm_qkv", "weight"], - interleave=True, - ) - copy_weight(layer_tp, layer_single, "row", ["self_attention", "proj", "weight"]) - copy_weight(layer_tp, layer_single, None, ["layernorm_mlp", "ln_weight"]) - copy_weight(layer_tp, layer_single, "column", ["layernorm_mlp", "fc1_weight"]) - copy_weight(layer_tp, layer_single, "row", ["layernorm_mlp", "fc2_weight"]) - - if self.sequence_parallel: - register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1) - - optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters()) - optimizer_single = paddle.optimizer.SGD( - learning_rate=0.01, parameters=layer_single.parameters() - ) - - layer_tp = fleet.distributed_model(layer_tp) - optimizer_tp = fleet.distributed_optimizer(optimizer_tp) - - for _ in range(5): - inp = paddle.uniform( - [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype - ) - mask = paddle.zeros( - shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool" - ) - loss_tp, out_tp = self._train_one_step( - layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel - ) - loss_single, out_single = self._train_one_step( - layer_single, [inp, mask], optimizer_single, self.fp8 - ) - assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol) - assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol) - - -class TestTransformerTpFp8(TestTransformerTp): - """Tests Transformer layer with tensor parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 0.5 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = False - - -class TestTransformerSp(TestTransformerTp): - """Tests Transformer layer with sequence parallel in BF16""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 5e-2 - self.eps = 1e-3 - self.fp8 = False - self.sequence_parallel = True - - -class TestTransformerSpFp8(TestTransformerSp): - """Tests Transformer layer with sequence parallelism in FP8""" - - def set_attr(self): - """Set test configs""" - self.batch_size = 16 - self.hidden_size = 1024 - self.num_heads = 16 - self.ffn_hidden_size = 4096 - self.q_seqlen = 128 - self.kv_seqlen = 128 - self.mask_type = "padding" - self.layer_type = "encoder" - self.global_dtype = "bfloat16" - self.rtol = 5e-2 - self.atol = 0.5 - self.eps = 1e-3 - self.fp8 = True - self.sequence_parallel = True - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/recompute_tests/recompute_transformer_encoder.py b/tests/paddle/recompute_tests/recompute_transformer_encoder.py deleted file mode 100644 index 56d0c24535..0000000000 --- a/tests/paddle/recompute_tests/recompute_transformer_encoder.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TransformerLayer encoder recompute""" - -import sys -import paddle -import transformer_engine.paddle as te - - -class Net(paddle.nn.Layer): - """Network use for recompute testing""" - - def __init__(self, layers): - super().__init__() - self.layers = layers - - def forward(self, inp, mask, enable_recompute, use_reentrant): - for layer in self.layers: - if enable_recompute: - out = te.recompute(layer, inp, mask, use_reentrant=use_reentrant) - else: - out = layer(inp, mask) - return out - - -def main(): - """Main function""" - paddle.seed(10) - batch_size = 16 - hidden_size = 4096 - num_heads = 32 - ffn_hidden_size = 16384 - q_seqlen = 512 - kv_seqlen = 512 - num_layers = 4 - enable_recompute = int(sys.argv[1]) - use_reentrant = int(sys.argv[2]) - - layers = paddle.nn.LayerList( - [ - te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - layer_type="encoder", - ) - for _ in range(num_layers) - ] - ) - model = Net(layers) - - optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters()) - - for _ in range(10): - inp = paddle.uniform([batch_size, q_seqlen, hidden_size]) - inp.stop_gradient = False - mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype="bool") - with te.fp8_autocast(enabled=True): - out = model(inp, mask, enable_recompute, use_reentrant) - loss = out.mean() - loss.backward() - optimizer.step() - optimizer.clear_grad() - - print("Loss: ", float(loss)) - print("Peak memory: ", paddle.device.cuda.max_memory_allocated(0)) - - -if __name__ == "__main__": - main() diff --git a/tests/paddle/test_install.py b/tests/paddle/test_install.py deleted file mode 100644 index 686771ec09..0000000000 --- a/tests/paddle/test_install.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test basic installation of Paddle extensions""" - - -def test_import(): - """ - Test if Paddle extension can be imported normally - """ - import transformer_engine.paddle # pylint: disable=unused-import diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py deleted file mode 100644 index b519fc0a0f..0000000000 --- a/tests/paddle/test_layers.py +++ /dev/null @@ -1,1663 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TE Paddle Layer-level APIs""" - -import os -from utils import assert_allclose, is_fused_attention_supported - -import paddle -import pytest - -from transformer_engine.common.recipe import DelayedScaling -import transformer_engine.paddle as te -from transformer_engine.paddle.fp8 import is_fp8_available, fp8_autocast - -is_fp8_supported, reason = is_fp8_available() -LINEAR_CASES = [(16, 16, 32), (32, 32, 64)] -NORM_CASES = [(16, 32), (256, 1024)] - - -@pytest.fixture(autouse=True) -def setup(): - """Setup random seed before each test""" - paddle.seed(10) - yield - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize("use_fp8", [True, False]) -def test_checkpoint(use_fp8): - """Test checkpoint save / load""" - bs = 16 - in_features = 16 - out_features = 32 - file_name = "model.pdparams" - input_tensor = paddle.uniform(shape=(bs, in_features), dtype="float32") - model = te.Linear(in_features, out_features) - model_loaded = te.Linear(in_features, out_features) - # Populate amax_history - with fp8_autocast(enabled=False, calibrating=True): - _ = model(input_tensor) - # Save model - paddle.save(model.state_dict(), file_name) - # Get ref output - with fp8_autocast(enabled=use_fp8): - out_ref = model(input_tensor) - # Load model - model_loaded.set_state_dict(paddle.load(file_name)) - if os.path.exists(file_name): - os.remove(file_name) - # Get actual output - with fp8_autocast(enabled=use_fp8): - out = model_loaded(input_tensor) - - assert_allclose(out, out_ref) - - -def calc_output_and_grad(layer, x, dy): - """ - Calculate forward and backward pass - """ - inp = paddle.to_tensor(x) - inp.stop_gradient = x.stop_gradient - y = layer(inp) - y.backward(dy) - - return y, inp.grad if not inp.stop_gradient else None - - -@staticmethod -def calc_output_and_grad_ln_out(layer, x, dy, return_ln_out=False): - """ - Calculate forward and backward pass for layernorm - """ - inp = paddle.to_tensor(x) - inp.stop_gradient = x.stop_gradient - outputs = layer(inp) - ln_out = None - if return_ln_out: - y, ln_out = outputs - else: - y = outputs - y.backward(dy) - - return y, ln_out, inp.grad if not inp.stop_gradient else None - - -class TestLinear: - """ - Tests for Linear layer - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU", - ) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - def test_linear_bf16( - bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype - ): - """ - Test BF16 Linear - """ - rtol = 5e-2 - atol = 5e-2 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - paddle.set_default_dtype(activation_dtype) - layer_te = te.Linear(in_features, out_features, bias_attr=None if has_bias else False) - layer_pd = te.Linear( - in_features, out_features, bias_attr=None if has_bias else False, backend="paddle" - ) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out) - out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - if has_bias and not no_dbias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("fp8_wgrad", [True, False]) - @pytest.mark.parametrize("do_calibration", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - def test_linear_fp8( - bs, - in_features, - out_features, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - fp8_wgrad, - do_calibration, - activation_dtype, - ): - """ - Test FP8 Linear - """ - rtol = 0.1 - atol = 0.5 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) - - paddle.set_default_dtype(activation_dtype) - layer_te = te.Linear( - in_features=in_features, - out_features=out_features, - bias_attr=None if has_bias else False, - ) - layer_pd = te.Linear( - in_features=in_features, - out_features=out_features, - bias_attr=None if has_bias else False, - backend="paddle", - ) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - with fp8_autocast( - enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe - ): - out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out) - out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - if has_bias and not no_dbias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - if do_calibration: - assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) - @pytest.mark.parametrize("num_microbatch", [8]) - def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch): - """ - Test FP8 Linear - """ - rtol = 0.1 - atol = 0.1 - - recipe = DelayedScaling() - - paddle.set_default_dtype(activation_dtype) - layer_cached = te.Linear( - in_features=in_features, - out_features=out_features, - ) - layer_normal = te.Linear( - in_features=in_features, - out_features=out_features, - ) - layer_cached.weight.copy_(layer_normal.weight, True) - layer_cached.bias.copy_(layer_normal.bias, True) - - for iteration in range(num_microbatch): - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(input_tensor) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol - ) - - -@pytest.mark.parametrize("bs,hidden_size", NORM_CASES) -@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) -@pytest.mark.parametrize("no_dgrad", [True, False]) -@pytest.mark.parametrize("no_wgrad", [True, False]) -@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) -def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype): - """ - Test BF16 LayerNorm - """ - eps = 1e-3 - rtol = 1e-2 - atol = 1e-2 - - x = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - x.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - - paddle.set_default_dtype(activation_dtype) - layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False) - layer_pd = te.LayerNorm( - hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False, backend="paddle" - ) - layer_pd.weight.copy_(layer_te.weight, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - out_ref, grad_input_ref = calc_output_and_grad(layer_pd, x, grad_out) - out, grad_input = calc_output_and_grad(layer_te, x, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad, rtol=rtol, atol=atol) - if has_bias and not no_dbias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - - -class TestLayerNormLinear: - """ - Tests for LayerNormLinear layer - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU", - ) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - def test_layernorm_linear_bf16( - bs, - in_features, - out_features, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - return_ln_out, - activation_dtype, - normalization, - ): - """ - Test BF16 LayerNormLinear Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 5e-2 - atol = 5e-2 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - layer_te = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - - layer_pd = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("fp8_wgrad", [True, False]) - @pytest.mark.parametrize("do_calibration", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - def test_layernorm_linear_fp8( - bs, - in_features, - out_features, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - fp8_wgrad, - do_calibration, - return_ln_out, - activation_dtype, - normalization, - ): - """ - Test FP8 LayerNormLinear Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 0.1 - atol = 0.75 - - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) - - layer_te = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - - layer_pd = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - normalization=normalization, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.weight.copy_(layer_te.weight.T, True) - if has_bias: - layer_pd.bias.copy_(layer_te.bias, True) - - layer_te.weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.bias.stop_gradient = no_dbias - layer_pd.bias.stop_gradient = no_dbias - - with fp8_autocast( - enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe - ): - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.weight.grad, layer_pd.weight.grad.T, rtol=rtol, atol=atol) - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose(layer_te.bias.grad, layer_pd.bias.grad, rtol=rtol, atol=atol) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - if do_calibration: - assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES) - @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) - @pytest.mark.parametrize("num_microbatch", [8]) - def test_layernorm_linear_fp8_microbatch( - bs, in_features, out_features, activation_dtype, num_microbatch - ): - """ - Test FP8 LayerNormLinear Layer - """ - paddle.set_default_dtype(activation_dtype) - eps = 1e-3 - rtol = 0.5 - atol = 0.5 - - recipe = DelayedScaling() - - layer_cached = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - ) - - layer_normal = te.LayerNormLinear( - in_features=in_features, - out_features=out_features, - eps=eps, - ) - - layer_cached.ln_weight.copy_(layer_normal.ln_weight, True) - layer_cached.ln_bias.copy_(layer_normal.ln_bias, True) - layer_cached.weight.copy_(layer_normal.weight, True) - layer_cached.bias.copy_(layer_normal.bias, True) - - for iteration in range(num_microbatch): - input_tensor = paddle.uniform(shape=(bs, in_features), dtype=activation_dtype) - grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(input_tensor) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol - ) - - -class TestLayerNormMLP: - """ - Test LayerNormMLP Layer - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), - reason="BF16 Linear requires Ampere+ GPU", - ) - @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - @pytest.mark.parametrize("activation", ["gelu", "swiglu"]) - def test_layernorm_mlp_bf16( - bs, - hidden_size, - ffn_hidden_size, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - return_ln_out, - activation_dtype, - normalization, - activation, - ): - """ - Tests for TestLayerNormMLP layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 5e-2 - atol = 5e-2 - - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - layer_te = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - layer_pd = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True) - layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True) - if has_bias: - layer_pd.fc1_bias.copy_(layer_te.fc1_bias, True) - layer_pd.fc2_bias.copy_(layer_te.fc2_bias, True) - - layer_te.fc1_weight.stop_gradient = no_wgrad - layer_te.fc2_weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.fc1_weight.stop_gradient = no_wgrad - layer_pd.fc2_weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.fc1_bias.stop_gradient = no_dbias - layer_te.fc2_bias.stop_gradient = no_dbias - layer_pd.fc1_bias.stop_gradient = no_dbias - layer_pd.fc2_bias.stop_gradient = no_dbias - - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - assert_allclose( - layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol - ) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose( - layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol - ) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) - @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]]) - @pytest.mark.parametrize("no_dgrad", [True, False]) - @pytest.mark.parametrize("no_wgrad", [True, False]) - @pytest.mark.parametrize("fp8_wgrad", [True, False]) - @pytest.mark.parametrize("do_calibration", [True, False]) - @pytest.mark.parametrize("return_ln_out", [True, False]) - @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"]) - @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) - @pytest.mark.parametrize("activation", ["gelu", "swiglu"]) - def test_layernorm_mlp_fp8( - bs, - hidden_size, - ffn_hidden_size, - has_bias, - no_dbias, - no_dgrad, - no_wgrad, - fp8_wgrad, - do_calibration, - return_ln_out, - activation_dtype, - normalization, - activation, - ): - """ - Test FP8 LayerNormMLP Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 0.1 - atol = 0.75 - - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - input_tensor.stop_gradient = no_dgrad - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad)) - - layer_te = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - ) - - layer_pd = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - normalization=normalization, - activation=activation, - bias_attr=None if has_bias else False, - return_layernorm_output=return_ln_out, - backend="paddle", - ) - layer_pd.ln_weight.copy_(layer_te.ln_weight, True) - if has_ln_bias: - layer_pd.ln_bias.copy_(layer_te.ln_bias, True) - layer_pd.fc1_weight.copy_(layer_te.fc1_weight.T, True) - layer_pd.fc2_weight.copy_(layer_te.fc2_weight.T, True) - if has_bias: - layer_pd.fc1_bias.copy_(layer_te.fc1_bias, True) - layer_pd.fc2_bias.copy_(layer_te.fc2_bias, True) - - layer_te.fc1_weight.stop_gradient = no_wgrad - layer_te.fc2_weight.stop_gradient = no_wgrad - layer_te.ln_weight.stop_gradient = no_wgrad - layer_pd.fc1_weight.stop_gradient = no_wgrad - layer_pd.fc2_weight.stop_gradient = no_wgrad - layer_pd.ln_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_te.ln_bias.stop_gradient = no_dbias - layer_pd.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_te.fc1_bias.stop_gradient = no_dbias - layer_te.fc2_bias.stop_gradient = no_dbias - layer_pd.fc1_bias.stop_gradient = no_dbias - layer_pd.fc2_bias.stop_gradient = no_dbias - - with fp8_autocast( - enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe - ): - out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out( - layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out - ) - out, ln_out, grad_input = calc_output_and_grad_ln_out( - layer_te, input_tensor, grad_out, return_ln_out=return_ln_out - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - if not no_dgrad: - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol) - assert_allclose( - layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol - ) - if not no_dbias: - if has_ln_bias: - assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol) - if has_bias: - assert_allclose( - layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol - ) - if return_ln_out: - assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol) - - if do_calibration: - assert paddle.count_nonzero(layer_te.fp8_meta["scaling_fwd"].amax_history).item() > 0 - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES) - @pytest.mark.parametrize("activation_dtype", ["bfloat16"]) - @pytest.mark.parametrize("num_microbatch", [8]) - def test_layernorm_mlp_fp8_microbatch( - bs, hidden_size, ffn_hidden_size, activation_dtype, num_microbatch - ): - """ - Test FP8 LayerNormMLP Layer - """ - paddle.set_default_dtype(activation_dtype) - rtol = 1e-5 - atol = 1e-5 - eps = 1e-3 - - recipe = DelayedScaling() - - layer_cached = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - ) - - layer_normal = te.LayerNormMLP( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - eps=eps, - ) - layer_normal.ln_weight.copy_(layer_cached.ln_weight, True) - layer_normal.ln_bias.copy_(layer_cached.ln_bias, True) - layer_normal.fc1_weight.copy_(layer_cached.fc1_weight, True) - layer_normal.fc2_weight.copy_(layer_cached.fc2_weight, True) - layer_normal.fc1_bias.copy_(layer_cached.fc1_bias, True) - layer_normal.fc2_bias.copy_(layer_cached.fc2_bias, True) - - # Calibration to make sure weight scale is the same - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_cached(input_tensor) - - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_normal(input_tensor) - - for iteration in range(num_microbatch): - input_tensor = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(input_tensor, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(input_tensor) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_cached.fc1_weight.grad, layer_normal.fc1_weight.grad, rtol=rtol, atol=atol - ) - assert_allclose( - layer_cached.fc2_weight.grad, layer_normal.fc2_weight.grad, rtol=rtol, atol=atol - ) - - -@pytest.mark.parametrize("bs", [1, 2]) -@pytest.mark.parametrize("hidden_size, num_heads", [[1024, 16]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) -@pytest.mark.parametrize("attn_type", ["self", "cross"]) -@pytest.mark.parametrize("mask_type", ["causal", "padding"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) -@pytest.mark.parametrize("deterministic", [True, False]) -def test_dot_product_attention( - bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype, deterministic -): - """ - Test DotProductAttention Layer - """ - paddle.set_default_dtype(math_dtype) - rtol = 1e-4 - atol = 2e-2 - head_size = hidden_size // num_heads - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_heads, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=head_size, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - attn_q_input = paddle.normal( - mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size) - ).astype(math_dtype) - attn_k_input = paddle.normal( - mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size) - ).astype(math_dtype) - attn_v_input = paddle.normal( - mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size) - ).astype(math_dtype) - - q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype="int32") - kv_actual_seqlen = ( - paddle.randint(low=20, high=kv_seqlen, shape=(bs,), dtype="int32") - if attn_type == "cross" - else q_actual_seqlen - ) - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)).astype( - "float32" - ) - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - head_size = hidden_size // num_heads - - if deterministic: - os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" - - layer_te = te.DotProductAttention( - num_heads, - head_size, - attention_dropout=0.0, - attn_mask_type=mask_type, - attention_type=attn_type, - backend="transformer_engine", - ) - layer_pd = te.DotProductAttention( - num_heads, - head_size, - attention_dropout=0.0, - attn_mask_type=mask_type, - attention_type=attn_type, - backend="paddle", - ) - - def calc_attn_output_and_grad(layer, q, k, v, mask, dout): - _q = paddle.to_tensor(q, stop_gradient=False) - _k = paddle.to_tensor(k, stop_gradient=False) - _v = paddle.to_tensor(v, stop_gradient=False) - - out = layer(_q, _k, _v, mask) - out.backward(dout) - return out, _q.grad, _k.grad, _v.grad - - out, q_grad, k_grad, v_grad = calc_attn_output_and_grad( - layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out - ) - out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad( - layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out - ) - valid_out_ref = paddle.full_like(out_ref, 0) - for i in range(0, bs): - valid_out_ref[i, 0 : q_actual_seqlen[i], :, :] = out_ref[i, 0 : q_actual_seqlen[i], :, :] - - valid_q_grad_ref = paddle.full_like(q_grad_ref, 0) - valid_k_grad_ref = paddle.full_like(k_grad_ref, 0) - valid_v_grad_ref = paddle.full_like(v_grad_ref, 0) - for i in range(0, bs): - valid_q_grad_ref[i, 0 : q_actual_seqlen[i], :, :] = q_grad_ref[ - i, 0 : q_actual_seqlen[i], :, : - ] - valid_k_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = k_grad_ref[ - i, 0 : kv_actual_seqlen[i], :, : - ] - valid_v_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = v_grad_ref[ - i, 0 : kv_actual_seqlen[i], :, : - ] - - assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol) - assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol) - assert_allclose(k_grad, valid_k_grad_ref, rtol=rtol, atol=atol) - assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol) - if deterministic: - out2, q_grad2, k_grad2, v_grad2 = calc_attn_output_and_grad( - layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out - ) - assert_allclose(out, out2, rtol=1e-12, atol=1e-12) - assert_allclose(q_grad, q_grad2, rtol=1e-12, atol=1e-12) - assert_allclose(k_grad, k_grad2, rtol=1e-12, atol=1e-12) - assert_allclose(v_grad, v_grad2, rtol=1e-12, atol=1e-12) - os.environ.pop("NVTE_ALLOW_NONDETERMINISTIC_ALGO", None) - - -@pytest.mark.parametrize("bs", [1, 2]) -@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4]) -@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) -@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]]) -@pytest.mark.parametrize("no_wgrad", [True, False]) -@pytest.mark.parametrize("mask_type", ["causal", "padding"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) -@pytest.mark.parametrize("output_layernorm", [True, False]) -@pytest.mark.parametrize("return_layernorm_output", [True, False]) -@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) -def test_transformer_encoder_layer( - bs, - hidden_size, - num_heads, - num_gqa_groups, - ffn_hidden_size, - has_bias, - no_dbias, - no_wgrad, - q_seqlen, - kv_seqlen, - mask_type, - math_dtype, - output_layernorm, - return_layernorm_output, - normalization, -): - """ - Test Transformer Encoder Layer - """ - paddle.set_default_dtype(math_dtype) - rtol = 5e-2 - atol = 5e-2 - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) - - q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen - kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype( - "float32" - ) - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - layer_te = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="encoder", - normalization=normalization, - backend="transformer_engine", - ) - layer_pd = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="encoder", - normalization=normalization, - backend="paddle", - ) - - # MultiHeadAttention params - if output_layernorm: - layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True) - layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True) - layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.qkv.bias.stop_gradient = no_dbias - else: - layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( - layer_te.self_attention.layernorm_qkv.ln_weight, True - ) - layer_pd.self_attention.layernorm_qkv.weight.copy_( - layer_te.self_attention.layernorm_qkv.weight.T, True - ) - layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( - layer_te.self_attention.layernorm_qkv.ln_bias, True - ) - layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.self_attention.layernorm_qkv.bias.copy_( - layer_te.self_attention.layernorm_qkv.bias, True - ) - layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - - layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True) - layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad - layer_te.self_attention.proj.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True) - layer_pd.self_attention.proj.bias.stop_gradient = no_dbias - layer_te.self_attention.proj.bias.stop_gradient = no_dbias - - # LayerNorm MLP params - layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True) - layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True) - layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True) - layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True) - layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True) - layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True) - layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - - if output_layernorm: - layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True) - layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True) - layer_pd.layernorm.weight.stop_gradient = no_wgrad - layer_pd.layernorm.bias.stop_gradient = no_dbias - layer_te.layernorm.weight.stop_gradient = no_wgrad - layer_te.layernorm.bias.stop_gradient = no_dbias - - def calc_transformer_output_and_grad(layer, encoder_input, mask, dout): - _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False) - out = layer(_encoder_input, mask) - out.backward(dout) - return out, _encoder_input.grad - - out_ref, grad_input_ref = calc_transformer_output_and_grad( - layer_pd, encoder_input, attn_mask, grad_out - ) - out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol) - if not no_wgrad: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.weight.grad, - layer_pd.self_attention.qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.weight.grad, - layer_pd.self_attention.layernorm_qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - if not no_dbias: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.bias.grad, - layer_pd.self_attention.qkv.bias.grad, - rtol=0.01, - atol=0.5, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.bias.grad, - layer_pd.self_attention.layernorm_qkv.bias.grad, - rtol=0.01, - atol=0.5, - ) - - -@pytest.mark.parametrize("bs", [1, 2]) -@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4]) -@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]]) -@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]]) -@pytest.mark.parametrize("no_wgrad", [True, False]) -@pytest.mark.parametrize("mask_type", ["causal", "padding"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"]) -@pytest.mark.parametrize("output_layernorm", [True, False]) -@pytest.mark.parametrize("return_layernorm_output", [True, False]) -@pytest.mark.parametrize("recompute_core_attention", [True, False]) -@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"]) -def test_transformer_decoder_layer( - bs, - hidden_size, - num_heads, - num_gqa_groups, - ffn_hidden_size, - has_bias, - no_dbias, - no_wgrad, - q_seqlen, - kv_seqlen, - mask_type, - math_dtype, - output_layernorm, - return_layernorm_output, - recompute_core_attention, - normalization, -): - """ - Test Transformer Decoder Layer - """ - paddle.set_default_dtype(math_dtype) - rtol = 5e-2 - atol = 6e-2 - eps = 1e-3 - has_ln_bias = normalization == "LayerNorm" - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - encoder_input = paddle.normal(mean=0.0, std=0.1, shape=(bs, q_seqlen, hidden_size)).astype( - math_dtype - ) - encoder_output = paddle.normal(mean=0.0, std=0.1, shape=(bs, kv_seqlen, hidden_size)).astype( - math_dtype - ) - - q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen - kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.01, shape=(bs, q_seqlen, hidden_size)).astype( - "float32" - ) - - # rounding to avoid numerical issues - encoder_input = paddle.round(encoder_input * 1000) / 1000 - encoder_output = paddle.round(encoder_output * 1000) / 1000 - grad_out = paddle.round(grad_out * 1000) / 1000 - - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - layer_te = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="decoder", - normalization=normalization, - backend="transformer_engine", - ) - layer_pd = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - num_gqa_groups=num_gqa_groups, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None if has_bias else False, - self_attn_mask_type=mask_type, - apply_residual_connection_post_layernorm=return_layernorm_output, - output_layernorm=output_layernorm, - layer_type="decoder", - normalization=normalization, - backend="paddle", - ) - - # MultiHeadAttention params - self attn - if output_layernorm: - layer_pd.self_attention.qkv.weight.copy_(layer_te.self_attention.qkv.weight.T, True) - layer_pd.self_attention.qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.qkv.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.qkv.bias.copy_(layer_te.self_attention.qkv.bias, True) - layer_pd.self_attention.qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.qkv.bias.stop_gradient = no_dbias - else: - layer_pd.self_attention.layernorm_qkv.ln_weight.copy_( - layer_te.self_attention.layernorm_qkv.ln_weight, True - ) - layer_pd.self_attention.layernorm_qkv.weight.copy_( - layer_te.self_attention.layernorm_qkv.weight.T, True - ) - layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad - layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.self_attention.layernorm_qkv.ln_bias.copy_( - layer_te.self_attention.layernorm_qkv.ln_bias, True - ) - layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.self_attention.layernorm_qkv.bias.copy_( - layer_te.self_attention.layernorm_qkv.bias, True - ) - layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias - - layer_pd.self_attention.proj.weight.copy_(layer_te.self_attention.proj.weight.T, True) - layer_pd.self_attention.proj.weight.stop_gradient = no_wgrad - layer_te.self_attention.proj.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.self_attention.proj.bias.copy_(layer_te.self_attention.proj.bias, True) - layer_pd.self_attention.proj.bias.stop_gradient = no_dbias - layer_te.self_attention.proj.bias.stop_gradient = no_dbias - - # MultiHeadAttention params - cross attn - layer_pd.inter_attention.layernorm_query.ln_weight.copy_( - layer_te.inter_attention.layernorm_query.ln_weight, True - ) - layer_pd.inter_attention.layernorm_query.weight.copy_( - layer_te.inter_attention.layernorm_query.weight.T, True - ) - layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad - layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad - layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad - layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.inter_attention.layernorm_query.ln_bias.copy_( - layer_te.inter_attention.layernorm_query.ln_bias, True - ) - layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias - layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.inter_attention.layernorm_query.bias.copy_( - layer_te.inter_attention.layernorm_query.bias, True - ) - layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias - layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias - - layer_pd.inter_attention.key_value.weight.copy_( - layer_te.inter_attention.key_value.weight.T, True - ) - layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad - layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad - layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True) - layer_pd.inter_attention.proj.weight.stop_gradient = no_wgrad - layer_te.inter_attention.proj.weight.stop_gradient = no_wgrad - if has_bias: - layer_pd.inter_attention.key_value.bias.copy_(layer_te.inter_attention.key_value.bias, True) - layer_pd.inter_attention.key_value.bias.stop_gradient = no_dbias - layer_te.inter_attention.key_value.bias.stop_gradient = no_dbias - layer_pd.inter_attention.proj.bias.copy_(layer_te.inter_attention.proj.bias, True) - layer_pd.inter_attention.proj.bias.stop_gradient = no_dbias - layer_te.inter_attention.proj.bias.stop_gradient = no_dbias - - # LayerNorm MLP params - layer_pd.layernorm_mlp.ln_weight.copy_(layer_te.layernorm_mlp.ln_weight, True) - layer_pd.layernorm_mlp.fc1_weight.copy_(layer_te.layernorm_mlp.fc1_weight.T, True) - layer_pd.layernorm_mlp.fc2_weight.copy_(layer_te.layernorm_mlp.fc2_weight.T, True) - layer_pd.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_pd.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.ln_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc1_weight.stop_gradient = no_wgrad - layer_te.layernorm_mlp.fc2_weight.stop_gradient = no_wgrad - if has_ln_bias: - layer_pd.layernorm_mlp.ln_bias.copy_(layer_te.layernorm_mlp.ln_bias, True) - layer_pd.layernorm_mlp.ln_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.ln_bias.stop_gradient = no_dbias - if has_bias: - layer_pd.layernorm_mlp.fc1_bias.copy_(layer_te.layernorm_mlp.fc1_bias, True) - layer_pd.layernorm_mlp.fc2_bias.copy_(layer_te.layernorm_mlp.fc2_bias, True) - layer_pd.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_pd.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc1_bias.stop_gradient = no_dbias - layer_te.layernorm_mlp.fc2_bias.stop_gradient = no_dbias - - if output_layernorm: - layer_pd.layernorm.weight.copy_(layer_te.layernorm.weight, True) - layer_pd.layernorm.bias.copy_(layer_te.layernorm.bias, True) - layer_pd.layernorm.weight.stop_gradient = no_wgrad - layer_pd.layernorm.bias.stop_gradient = no_dbias - layer_te.layernorm.weight.stop_gradient = no_wgrad - layer_te.layernorm.bias.stop_gradient = no_dbias - - def calc_transformer_output_and_grad( - layer, - encoder_input, - mask, - encoder_output, - enc_dec_attn_mask, - dout, - recompute_core_attention=False, - ): - _encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False) - _encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False) - out = layer( - _encoder_input, - mask, - _encoder_output, - enc_dec_attn_mask, - recompute_core_attention=recompute_core_attention, - ) - out.backward(dout) - return out, _encoder_input.grad, _encoder_output.grad - - out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad( - layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out - ) - out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad( - layer_te, - encoder_input, - attn_mask, - encoder_output, - attn_mask, - grad_out, - recompute_core_attention=recompute_core_attention, - ) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol) - assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol) - if not no_wgrad: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.weight.grad, - layer_pd.self_attention.qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.weight.grad, - layer_pd.self_attention.layernorm_qkv.weight.grad.T, - rtol=rtol, - atol=atol, - ) - assert_allclose( - layer_te.inter_attention.layernorm_query.weight.grad, - layer_pd.inter_attention.layernorm_query.weight.grad.T, - rtol=rtol, - atol=atol, - ) - if not no_dbias: - if output_layernorm: - assert_allclose( - layer_te.self_attention.qkv.bias.grad, - layer_pd.self_attention.qkv.bias.grad, - rtol=0.5, - atol=0.6, - ) - else: - assert_allclose( - layer_te.self_attention.layernorm_qkv.bias.grad, - layer_pd.self_attention.layernorm_qkv.bias.grad, - rtol=0.01, - atol=0.5, - ) - assert_allclose( - layer_te.inter_attention.layernorm_query.bias.grad, - layer_pd.inter_attention.layernorm_query.bias.grad, - rtol=rtol, - atol=atol, - ) - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize("bs", [8]) -@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[1024, 16, 4096]]) -@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[128, 128]]) -@pytest.mark.parametrize("mask_type", ["causal"]) -@pytest.mark.parametrize("math_dtype", ["bfloat16"]) -@pytest.mark.parametrize("num_microbatch", [8]) -def test_transformer_encoder_layer_microbatch( - bs, - hidden_size, - num_heads, - ffn_hidden_size, - q_seqlen, - kv_seqlen, - mask_type, - math_dtype, - num_microbatch, -): - """ - Test Transformer Encoder Layer with FP8 weight caching - """ - paddle.set_default_dtype(math_dtype) - rtol = 1e-5 - atol = 1e-5 - eps = 1e-3 - - # Skip if cuDNN fused attention is not supported - if not is_fused_attention_supported( - num_heads=num_heads, - num_gqa_groups=num_heads, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=hidden_size // num_heads, - dtype=math_dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type=mask_type, - ): - pytest.skip("cuDNN fused attention is not supported") - - layer_cached = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None, - self_attn_mask_type=mask_type, - layer_type="encoder", - ) - layer_normal = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_heads, - layernorm_epsilon=eps, - hidden_dropout=0.0, - attention_dropout=0.0, - weight_attr=None, - bias_attr=None, - self_attn_mask_type=mask_type, - layer_type="encoder", - ) - - layer_normal.self_attention.layernorm_qkv.ln_weight.copy_( - layer_cached.self_attention.layernorm_qkv.ln_weight, True - ) - layer_normal.self_attention.layernorm_qkv.ln_bias.copy_( - layer_cached.self_attention.layernorm_qkv.ln_bias, True - ) - layer_normal.self_attention.layernorm_qkv.weight.copy_( - layer_cached.self_attention.layernorm_qkv.weight, True - ) - layer_normal.self_attention.layernorm_qkv.bias.copy_( - layer_cached.self_attention.layernorm_qkv.bias, True - ) - - layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True) - layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True) - - # LayerNorm MLP params - layer_normal.layernorm_mlp.ln_weight.copy_(layer_cached.layernorm_mlp.ln_weight, True) - layer_normal.layernorm_mlp.ln_bias.copy_(layer_cached.layernorm_mlp.ln_bias, True) - layer_normal.layernorm_mlp.fc1_weight.copy_(layer_cached.layernorm_mlp.fc1_weight, True) - layer_normal.layernorm_mlp.fc2_weight.copy_(layer_cached.layernorm_mlp.fc2_weight, True) - layer_normal.layernorm_mlp.fc1_bias.copy_(layer_cached.layernorm_mlp.fc1_bias, True) - layer_normal.layernorm_mlp.fc2_bias.copy_(layer_cached.layernorm_mlp.fc2_bias, True) - - recipe = DelayedScaling() - - def generate_input(): - encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype) - - q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen - kv_actual_seqlen = q_actual_seqlen - attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool") - - grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype( - "float32" - ) - for i in range(0, bs): - grad_out[i, q_actual_seqlen[i] :, :] = 0 - grad_out = grad_out.astype(math_dtype) - - for i in range(0, bs): - attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False - - return encoder_input, attn_mask, grad_out - - # Calibration to make sure weight scale is the same - encoder_input, mask, _ = generate_input() - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_cached(encoder_input, mask) - - with fp8_autocast(enabled=False, calibrating=True, fp8_recipe=recipe): - _ = layer_normal(encoder_input, mask) - - for iteration in range(num_microbatch): - encoder_input, mask, grad_out = generate_input() - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out = layer_cached(encoder_input, mask, is_first_microbatch=(iteration == 0)) - out.backward(grad_out) - - with fp8_autocast(enabled=True, fp8_recipe=recipe): - out_ref = layer_normal(encoder_input, mask) - out_ref.backward(grad_out) - - assert_allclose(out, out_ref, rtol=rtol, atol=atol) - assert_allclose( - layer_cached.self_attention.layernorm_qkv.weight.grad, - layer_normal.self_attention.layernorm_qkv.weight.grad, - rtol=rtol, - atol=atol, - ) diff --git a/tests/paddle/test_master_grad.py b/tests/paddle/test_master_grad.py deleted file mode 100644 index 4e029cf8dd..0000000000 --- a/tests/paddle/test_master_grad.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TransformerLayer encoder main_grad""" - -import numpy as np -import pytest - -import paddle -from paddle.distributed.fleet.utils import mix_precision_utils - -import transformer_engine.paddle as te -from transformer_engine.paddle.fp8 import is_fp8_available - -is_fp8_supported, reason = is_fp8_available() - - -def create_optimizer(model, use_pure_bf16, use_main_grad): - """Create optimizer""" - if use_main_grad: - assert use_pure_bf16 - model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") - optimizer = paddle.optimizer.AdamW( - parameters=model.parameters(), - learning_rate=0.0001, - multi_precision=use_pure_bf16, - ) - if use_main_grad: - optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) - - return optimizer - - -class Net(paddle.nn.Layer): - """Network use for main_grad testing""" - - def __init__(self, fuse_wgrad_accumulation): - super().__init__() - self.layer = te.TransformerLayer( - 4096, - 16384, - 32, - layer_type="encoder", - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - ) - - def forward(self, inp): - out = self.layer(inp) - return out - - -def train(enable_master_grad, fuse_wgrad_accumulation=False): - """Train function""" - paddle.seed(10) - - accumulate_steps = 4 - - if fuse_wgrad_accumulation: - assert enable_master_grad, "fuse_wgrad_accumulation requires enable_master_grad" - - model = Net(fuse_wgrad_accumulation) - - optimizer = create_optimizer(model, use_pure_bf16=True, use_main_grad=enable_master_grad) - - loss_list = [] - for step_id in range(16): - inp = paddle.uniform([2, 1024, 4096], dtype="float32") - inp.stop_gradient = False - with te.fp8_autocast(enabled=True): - out = model(inp) - loss = out.mean() - loss_list.append(loss) - loss.backward() - - # gradient accumulation - if (step_id + 1) % accumulate_steps == 0: - optimizer.step() - optimizer.clear_grad() - - return loss_list - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -def test_master_grad(): - """Test main_grad""" - paddle.set_default_dtype("float32") - loss1 = train(enable_master_grad=False) - loss2 = train(enable_master_grad=True) - loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True) - - np.testing.assert_allclose(loss1, loss2, rtol=1e-5, atol=1e-5) - np.testing.assert_allclose(loss1, loss3, rtol=1e-5, atol=1e-5) diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py deleted file mode 100644 index b3b8560775..0000000000 --- a/tests/paddle/test_operators.py +++ /dev/null @@ -1,1201 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TE operators""" - -import struct - -import numpy as np -import paddle -import paddle.nn.functional as F -import pytest - -from utils import ( - assert_allclose, - create_fp8_meta, - get_fused_attention_backend, - is_fused_attention_supported, -) - -from transformer_engine import transformer_engine_paddle as tex -from transformer_engine.paddle.cpp_extensions import ( - cast_to_fp8, - cast_from_fp8, - gemm, - fp8_gemm, - transpose, - cast_transpose, - cast_transpose_bgrad, - te_gelu, - gelu_fp8, - swiglu, - swiglu_fp8, - swiglu_pd, - dswiglu, - dgelu_cast_transpose_bgrad_fp8, - layernorm_fwd_fp8, - layernorm_fwd, - layernorm_bwd, - rmsnorm_fwd_fp8, - rmsnorm_fwd, - rmsnorm_bwd, - fused_attn_fwd_qkvpacked, - fused_attn_bwd_qkvpacked, - fused_attn_fwd_kvpacked, - fused_attn_bwd_kvpacked, - fused_attn_fwd, - fused_attn_bwd, - scaled_softmax_forward, - scaled_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_backward, - scaled_upper_triang_masked_softmax_forward, - scaled_upper_triang_masked_softmax_backward, -) -from transformer_engine.paddle.fp8 import is_fp8_available -from transformer_engine.paddle.constants import FP8FwdTensors -from transformer_engine.common.recipe import DelayedScaling - -GEMM_CASES = [ - (256, 256, 512), - (32, 32, 32), - (16384, 1024, 2816), - (16384, 2816, 1024), - (16384, 1024, 1024), -] -is_fp8_supported, reason = is_fp8_available() - -SELF_ATTN_CASES = [(2, 512, 12, 64)] -CROSS_ATTN_CASES = [(2, 128, 512, 12, 64)] -FLASH_ATTN_CASES = [(2, 1024, 16, 64), (2, 2048, 16, 128)] -ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16] - - -@pytest.fixture(autouse=True) -def setup(): - """Setup random seed before each test""" - np.random.seed(10) - paddle.seed(11) - yield - - -@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("inplace", [True, False]) -def test_quantize_dequantize(fp8_dtype, inplace): - """ - Test cast_to_fp8 and cast_from_fp8 - """ - a = paddle.rand(shape=(32, 32), dtype="float32") - # Init fp8_meta - fp8_meta = create_fp8_meta() - a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None - a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_OUTPUT, otype=fp8_dtype, out=a_fp8) - b = cast_from_fp8( - a_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_OUTPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - assert_allclose(a, b, rtol=5e-2, atol=5e-2) - - -def copy_bits_from_float_to_uint16(f): - """ - Copy bits - """ - return struct.unpack("> 16 - - -def convert_float_to_uint16(float_list): - """ - convert float to uint16 - """ - new_output = [] - for x in np.nditer(float_list): - new_output.append(np.uint16(copy_bits_from_float_to_uint16(x))) - new_output = np.reshape(new_output, float_list.shape).view(np.uint16) - - return new_output - - -class TestTranspose: - """ - Test transpose operators - """ - - @staticmethod - def test_transpose_bf16(): - """ - Test BF16 transpose - """ - a = paddle.rand(shape=(16, 32), dtype="bfloat16") - a_transposed = transpose(a, otype=tex.DType.kBFloat16) - assert_allclose(a_transposed, a.T) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_transpose_fp8(fp8_dtype): - """ - Test FP8 transpose - """ - min_val = -8 - max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") - fp8_meta = create_fp8_meta() - a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype) - a_transposed = cast_from_fp8( - a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - assert_allclose(a_transposed, a.T) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - @pytest.mark.parametrize("inplace", [True, False]) - def test_cast_transpose(fp8_dtype, inplace): - """ - Test cast_transpose - """ - min_val = -8 - max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") - fp8_meta = create_fp8_meta() - a_fp8_casted, a_fp8_transposed = None, None - if inplace: - a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8) - a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8) - a_fp8_casted, a_fp8_transposed = cast_transpose( - a, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - otype=fp8_dtype, - cast_out=a_fp8_casted, - transpose_out=a_fp8_transposed, - ) - - a_transposed = cast_from_fp8( - a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - a_casted = cast_from_fp8( - a_fp8_casted, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - assert_allclose(a_casted, a) - assert_allclose(a_transposed, a.T) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_cast_transpose_bgrad(fp8_dtype): - """ - Test cast_transpose_bgrad - """ - min_val = -8 - max_val = 8 - a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32") - fp8_meta = create_fp8_meta() - bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad( - a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype - ) - - a_transposed = cast_from_fp8( - a_fp8_transposed, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - a_casted = cast_from_fp8( - a_fp8_casted, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - assert_allclose(a_casted, a) - assert_allclose(a_transposed, a.T) - assert_allclose(bgrad, a.sum(axis=0)) - - -class TestActivation: - """ - Test activation operators - """ - - @staticmethod - def test_gelu_bf16(): - """ - Test BF16 GELU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 - gelu_out = te_gelu(a, otype=tex.DType.kBFloat16) - gelu_ref = paddle.nn.GELU()(a) - - assert_allclose(gelu_out, gelu_ref, rtol=1e-2) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_gelu_fp8(fp8_dtype): - """ - Test FP8 GELU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - fp8_meta = create_fp8_meta() - - gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - - gelu_out = cast_from_fp8( - gelu_out_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - gelu_ref = paddle.nn.GELU()(a) - - assert_allclose(gelu_out, gelu_ref, rtol=0.1, atol=0.01) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_gelu_bwd_fp8(fp8_dtype): - """ - Test FP8 GELU Backward - """ - # y = GELU(x), calculate ref - x = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - x.stop_gradient = False - y = paddle.nn.GELU()(x) - y_grad = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - paddle.autograd.backward([y], [y_grad], True) - # calculate fp8 - fp8_meta = create_fp8_meta() - x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8( - y_grad, x, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype - ) - - x_grad = cast_from_fp8( - x_grad_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - x_grad_t = cast_from_fp8( - x_grad_t_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01) - assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01) - assert_allclose(dbias, x.grad.sum(axis=0), rtol=0.1, atol=0.01) - - @staticmethod - def test_swiglu_bf16(): - """ - Test BF16 SwiGLU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 - swiglu_out = swiglu(a, otype=tex.DType.kBFloat16) - swiglu_ref = swiglu_pd(a) - - assert_allclose(swiglu_out, swiglu_ref, rtol=1e-2) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) - def test_swiglu_fp8(fp8_dtype): - """ - Test FP8 SwiGLU Forward - """ - a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1 - fp8_meta = create_fp8_meta() - - swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - - swiglu_out = cast_from_fp8( - swiglu_out_fp8, - fp8_meta, - FP8FwdTensors.GEMM1_INPUT, - itype=fp8_dtype, - otype=tex.DType.kFloat32, - ) - - swiglu_ref = swiglu_pd(a) - - assert_allclose(swiglu_out, swiglu_ref, rtol=0.1, atol=0.01) - - @staticmethod - def test_swiglu_bwd(): - """ - Test SwiGLU Backward - """ - # y = SwiGLU(x), calculate ref - x = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1 - x.stop_gradient = False - y = swiglu_pd(x) - y_grad = paddle.rand(shape=(16, 16), dtype="bfloat16") * 2 - 1 - paddle.autograd.backward([y], [y_grad], True) - # calculate fp8 - x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16) - - assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01) - - -class TestGemm: - """ - Tests for gemm(cuBLASLt) operator - """ - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU" - ) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - def test_bf16(m, n, k): - """ - Test "TN" BF16 GEMM - """ - a = paddle.rand(shape=(m, k), dtype="bfloat16") - b = paddle.rand(shape=(n, k), dtype="bfloat16") - - workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") - - ref_out = paddle.matmul(a, b.T) - # CublasLt inside tex.te_gemm assumes inputs are column major. - # Mathematically, A@B=C is equivalent to B^T@A^T=C^T, where X^T is the - # transpose of X. - # Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T, - # which is equivalent to a@b^T = C in row major. - actual_out, _, _ = gemm( - b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", None, None, False - ) - - assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5) - - @staticmethod - @pytest.mark.skipif( - paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU" - ) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - def test_bf16_inplace(m, n, k): - """ - Test "TN" BF16 GEMM, with accumulate=True - """ - min_val = -16 - max_val = 16 - a = paddle.rand(shape=(m, k), dtype="bfloat16") - b = paddle.rand(shape=(n, k), dtype="bfloat16") - c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), "bfloat16") - workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") - - ref_out = c + paddle.matmul(a, b.T) - - actual_out = paddle.clone(c) - _, _, _ = gemm( - b, - a, - paddle.bfloat16, - workspace, - False, - None, - False, - True, - "TN", - actual_out, - None, - False, - ) - - assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2) - - @staticmethod - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize("m,n,k", GEMM_CASES) - def test_fp8_randint(m, n, k): - """ - Test "TN" FP8 GEMM - """ - min_val = -4 - max_val = 4 - fp8_dtype = tex.DType.kFloat8E4M3 - out_dtype = paddle.float32 - fp8_meta = create_fp8_meta(num_gemms=1) - - a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), "float32") - - a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype) - b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), "float32") - b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype) - workspace = paddle.zeros(shape=[33_554_432], dtype="uint8") - - ref_out = paddle.matmul(a, b.T) - actual_out, _ = fp8_gemm( - b_casted, - fp8_meta.scale_inv, - FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype, - a_casted, - fp8_meta.scale_inv, - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype, - out_dtype, - workspace, - ) - - assert_allclose(actual_out, ref_out) - - -class TestLayerNorm: - """ - Test layernorm operators - """ - - @staticmethod - def calc_fwd_ref(x, eps, gamma, beta): - """ - Calculate reference using paddle layer_norm op - """ - y = paddle.nn.functional.layer_norm( - x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps - ) - mean = paddle.mean(x, axis=-1) - var = paddle.var(x, axis=-1) - inv_var = paddle.sqrt(1.0 / var) - return y, mean, inv_var - - @staticmethod - def calc_bwd_ref(x, eps, gamma, beta, dy): - """ - Calculate reference using paddle layer_norm op - """ - x.stop_gradient = False - gamma.stop_gradient = False - beta.stop_gradient = False - - y = paddle.nn.functional.layer_norm( - x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps - ) - - paddle.autograd.backward([y], [dy], True) - - return x.grad, gamma.grad, beta.grad - - def test_layernorm_fwd(self): - """ - Test BF16 LayerNorm Forward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - beta = paddle.uniform(shape=(H,), dtype="bfloat16") - - y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16) - - y_ref, mu_ref, rsigma_ref = self.calc_fwd_ref(x, eps, gamma, beta) - - assert_allclose(y, y_ref, rtol=1e-4, atol=1e-4) - assert_allclose(mu, mu_ref, rtol=1e-3, atol=1e-3) - assert_allclose(rsigma, rsigma_ref, rtol=5e-2, atol=5e-2) - - @staticmethod - def test_layernorm_fwd_fp8(): - """ - Test FP8 LayerNorm Forward - """ - fp8_dtype = tex.DType.kFloat8E4M3 - N, H = (16, 32) - eps = 1e-3 - - x = paddle.uniform(shape=(N, H), dtype="float32") - gamma = paddle.uniform(shape=(H,), dtype="float32") - beta = paddle.uniform(shape=(H,), dtype="float32") - - fp8_tensor = FP8FwdTensors.GEMM1_INPUT - fp8_meta = create_fp8_meta() - - y_ref, mu_ref, rsigma_ref = layernorm_fwd(x, gamma, beta, eps, tex.DType.kFloat32) - - y_fp8, mu, rsigma = layernorm_fwd_fp8(x, gamma, beta, eps, fp8_meta, fp8_tensor, fp8_dtype) - - y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32) - - assert_allclose(y, y_ref, rtol=0.1, atol=0.01) - assert_allclose(mu, mu_ref) - assert_allclose(rsigma, rsigma_ref) - - def test_layernorm_bwd(self): - """ - Test BF16 LayerNorm Backward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - dy = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - beta = paddle.uniform(shape=(H,), dtype="bfloat16") - - dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy) - - _, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16) - dx, dgamma, dbeta = layernorm_bwd(dy, x, mu, rsigma, gamma) - - assert_allclose(dx, dx_ref, rtol=1e-5, atol=1e-5) - assert_allclose(dgamma, dgamma_ref, rtol=1e-5, atol=1e-5) - assert_allclose(dbeta, dbeta_ref, rtol=1e-5, atol=1e-5) - - -class TestRMSNorm: - """ - Test rmsnorm operators - """ - - @staticmethod - def calc_fwd_ref(x, eps, gamma): - """ - Calculate rmsnorm reference using paddle op - """ - - norm = paddle.rsqrt(paddle.mean(x**2, axis=-1, keepdim=True) + eps) - y = x * norm * gamma - - return y - - def calc_bwd_ref(self, x, eps, gamma, dy): - """ - Calculate rmsnorm bwd reference using paddle op - """ - x.stop_gradient = False - gamma.stop_gradient = False - - y = self.calc_fwd_ref(x, eps, gamma) - - paddle.autograd.backward([y], [dy], True) - - return x.grad, gamma.grad - - def test_rmsnorm_fwd(self): - """ - Test BF16 RMSNorm Forward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - - y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16) - - y_ref = self.calc_fwd_ref(x, eps, gamma) - - assert_allclose(y, y_ref, rtol=1e-2, atol=1e-2) - - @staticmethod - def test_rmsnorm_fwd_fp8(): - """ - Test FP8 RMSNorm Forward - """ - fp8_dtype = tex.DType.kFloat8E4M3 - N, H = (16, 32) - eps = 1e-3 - - x = paddle.uniform(shape=(N, H), dtype="float32") - gamma = paddle.uniform(shape=(H,), dtype="float32") - - fp8_tensor = FP8FwdTensors.GEMM1_INPUT - fp8_meta = create_fp8_meta() - - y_ref, rsigma_ref = rmsnorm_fwd(x, gamma, eps, tex.DType.kFloat32) - - y_fp8, rsigma = rmsnorm_fwd_fp8(x, gamma, eps, fp8_meta, fp8_tensor, fp8_dtype) - - y = cast_from_fp8(y_fp8, fp8_meta, fp8_tensor, itype=fp8_dtype, otype=tex.DType.kFloat32) - - assert_allclose(y, y_ref, rtol=0.1, atol=0.01) - assert_allclose(rsigma, rsigma_ref) - - def test_rmsnorm_bwd(self): - """ - Test BF16 RMSNorm Backward - """ - N, H = (16, 32) - eps = 1e-3 - x = paddle.uniform(shape=(N, H), dtype="bfloat16") - dy = paddle.uniform(shape=(N, H), dtype="bfloat16") - gamma = paddle.uniform(shape=(H,), dtype="bfloat16") - - dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy) - - _, rsigma = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16) - dx, dgamma = rmsnorm_bwd(dy, x, rsigma, gamma) - - assert_allclose(dx, dx_ref, rtol=1e-2, atol=1e-2) - assert_allclose(dgamma, dgamma_ref, rtol=1e-2, atol=5e-2) - - -class TestFusedAttn: - """ - Test fused attention operators - """ - - def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode="self_attn", is_causal_masking=False): - """ - set test input - """ - - def _random(shape): - if self.dtype == "bfloat16": - data = np.random.normal(loc=0.0, scale=0.02, size=shape).astype("float32") - return convert_float_to_uint16(data) - return np.random.normal(loc=0.0, scale=0.02, size=shape).astype(self.dtype) - - self.batch_size = b - self.q_seqlen = s_q - self.kv_seqlen = s_kv - self.num_heads = h - self.head_size = d - self.dropout_prob = 0.0 - self.scaling_factor = 1.0 / np.sqrt(d) - self.q_shape = (b, s_q, h, d) - self.kv_shape = (b, s_kv, h, d) - self.fuse_qkv_shape = (b, s_q, 3, h, d) - self.fuse_kv_shape = (b, s_kv, 2, h, d) - self.bias_shape = (1, h, s_q, s_kv) - self.attn_mode = attn_mode - self.dtype = dtype - self.is_causal_masking = is_causal_masking - - self.q = _random(self.q_shape) - if self.attn_mode == "self_attn": - assert self.q_seqlen == self.kv_seqlen, "self attention requires q_seqlen == kv_seqlen" - self.kv = self.q - else: - self.kv = _random(self.kv_shape) - - self.q_actual_seqlen = None - if self.is_causal_masking: - self.q_actual_seqlen = np.full( - self.batch_size, - self.q_seqlen, - dtype=np.int32, - ) - else: - self.q_actual_seqlen = np.random.randint( - low=20, - high=self.q_seqlen, - size=(self.batch_size,), - dtype=np.int32, - ) - self.kv_actual_seqlen = self.q_actual_seqlen - - self.q_cu_seqlen = np.cumsum(self.q_actual_seqlen) - self.q_cu_seqlen = np.insert(self.q_cu_seqlen, 0, 0) - self.kv_cu_seqlen = np.cumsum(self.kv_actual_seqlen) - self.kv_cu_seqlen = np.insert(self.kv_cu_seqlen, 0, 0) - self.attn_mask = np.ones( - shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), - dtype=np.int32, - ) - if self.is_causal_masking: - assert attn_mode == "self_attn", "only support causal masking for self attention" - for i in range(0, self.batch_size): - for j in range(self.q_actual_seqlen[i]): - self.attn_mask[i, :, j, : j + 1] = 0 - else: - for i in range(0, self.batch_size): - self.attn_mask[i, :, : self.q_actual_seqlen[i], : self.kv_actual_seqlen[i]] = 0 - - dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size)) - self.dout = paddle.to_tensor(dout, dtype=self.dtype) - - def _get_reference_out(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - q_tensor = paddle.to_tensor(self.q, stop_gradient=False) - k_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - v_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - - q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d] - - qk_out = paddle.matmul( - x=q_out * self.scaling_factor, - y=k_out, - transpose_x=False, - transpose_y=True, - ) - - attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast("bool") - attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype) - attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out) - attn_mask_out = paddle.cast(attn_mask_out, "float32") - softmax_out = F.softmax(attn_mask_out) - softmax_out = paddle.cast(softmax_out, self.dtype) - - if self.dropout_prob: - dropout_out = F.dropout( - softmax_out, - self.dropout_prob, - training=self.training, - mode="upscale_in_train", - ) - qkv_out = paddle.matmul(dropout_out, v_out) - else: - qkv_out = paddle.matmul(softmax_out, v_out) - - out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d] - - paddle.autograd.backward( - [out], - [self.dout], - retain_graph=True, - ) - return out, q_tensor.grad, k_tensor.grad, v_tensor.grad - - def _get_fused_attention_out(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - - if self.attn_mode == "self_attn": - qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d] - qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False) - else: - q_tensor = paddle.to_tensor(self.q, stop_gradient=False) - kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d] - kv_tensor = paddle.to_tensor(kv, stop_gradient=False) - - q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True) - kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True) - - qkv_layout = "bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd" - fused_attention_backend = get_fused_attention_backend( - num_heads=self.num_heads, - num_gqa_groups=self.num_heads, - q_seqlen=self.q_seqlen, - kv_seqlen=self.kv_seqlen, - head_size=self.head_size, - dtype=self.dtype, - dropout=self.dropout_prob, - qkv_layout=qkv_layout, - bias_type="no_bias", - mask_type="causal" if self.is_causal_masking else "padding", - ) - - qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16 - out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None - if self.attn_mode == "self_attn": - out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked( - qkv_tensor, - q_cu_seqlen_tensor, - is_training=True, - max_seqlen=self.q_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - Bias=None, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - dqkv, _ = fused_attn_bwd_qkvpacked( - qkv_tensor, - q_cu_seqlen_tensor, - rng_state, - out, - self.dout, - softmax_aux_tensor, - max_seqlen=self.q_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - q_grad = dqkv[:, :, 0, :, :] - k_grad = dqkv[:, :, 1, :, :] - v_grad = dqkv[:, :, 2, :, :] - else: # attn_mode == 'cross_attn' - out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked( - q_tensor, - kv_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - is_training=True, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - Bias=None, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - ) - dq, dkv, _ = fused_attn_bwd_kvpacked( - q_tensor, - kv_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - rng_state, - out, - self.dout, - softmax_aux_tensor, - fused_attention_backend=fused_attention_backend, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - ) - q_grad = dq - k_grad = dkv[:, :, 0, :, :] - v_grad = dkv[:, :, 1, :, :] - - return out, q_grad, k_grad, v_grad - - def _get_fused_attention_with_separate_qkv(self): - paddle.disable_static(place=paddle.CUDAPlace(0)) - - q_tensor = paddle.to_tensor(self.q, stop_gradient=False) - k_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - v_tensor = paddle.to_tensor(self.kv, stop_gradient=False) - - q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True) - kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True) - - qkv_layout = "bshd_bshd_bshd" - fused_attention_backend = get_fused_attention_backend( - num_heads=self.num_heads, - num_gqa_groups=self.num_heads, - q_seqlen=self.q_seqlen, - kv_seqlen=self.kv_seqlen, - head_size=self.head_size, - dtype=self.dtype, - dropout=self.dropout_prob, - qkv_layout=qkv_layout, - bias_type="no_bias", - mask_type="causal" if self.is_causal_masking else "padding", - ) - - qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16 - out, softmax_aux_tensor, rng_state = fused_attn_fwd( - q_tensor, - k_tensor, - v_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - is_training=True, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - fused_attention_backend=fused_attention_backend, - Bias=None, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - qkv_layout=qkv_layout, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - dq, dk, dv, _ = fused_attn_bwd( - q_tensor, - k_tensor, - v_tensor, - q_cu_seqlen_tensor, - kv_cu_seqlen_tensor, - rng_state, - out, - self.dout, - softmax_aux_tensor, - fused_attention_backend=fused_attention_backend, - max_seqlen_q=self.q_seqlen, - max_seqlen_kv=self.kv_seqlen, - qkv_dtype=qkv_dtype, - attn_scale=self.scaling_factor, - dropout=self.dropout_prob, - set_zero=False, - qkv_layout=qkv_layout, - attn_mask_type="causal" if self.is_causal_masking else "padding", - ) - - return out, dq, dk, dv - - @pytest.mark.parametrize("b, s, h, d", SELF_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - @pytest.mark.parametrize("is_causal_masking", [True, False]) - def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): - """ - test self attention forward + backward - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - @pytest.mark.parametrize("b, s_q, s_kv, h, d", CROSS_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype): - """ - test cross attention forward + backward - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s_q, - kv_seqlen=s_kv, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bshd_bs2hd", - bias_type="no_bias", - mask_type="padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn") - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - @pytest.mark.parametrize("is_causal_masking", [True]) - def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking): - """ - test flash attention forward + backward - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bs3hd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_out() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES) - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - @pytest.mark.parametrize("is_causal_masking", [False, True]) - def test_fused_attn_with_separate_qkv_forward_backward( - self, b, s, h, d, dtype, is_causal_masking - ): - """ - test flash attention forward + backward with separate qkv inputs - """ - if not is_fused_attention_supported( - num_heads=h, - num_gqa_groups=h, - q_seqlen=s, - kv_seqlen=s, - head_size=d, - dtype=dtype, - dropout=0.0, - qkv_layout="bshd_bshd_bshd", - bias_type="no_bias", - mask_type="causal" if is_causal_masking else "padding", - ): - pytest.skip("cuDNN fused attention is not supported") - self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking) - reference_out, q_grad_ref, k_grad_ref, v_grad_ref = self._get_reference_out() - fused_attention_out, q_grad, k_grad, v_grad = self._get_fused_attention_with_separate_qkv() - assert_allclose(reference_out, fused_attention_out, rtol=1e-3, atol=1e-2) - assert_allclose(q_grad_ref, q_grad, rtol=1e-3, atol=1e-2) - assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2) - assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2) - - -class TestSoftmax: - """ - Test softmax operators - """ - - @staticmethod - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_scaled_softmax_fwd_bwd(dtype): - """test scaled softmax""" - B, H, S = (16, 4, 32) - scale = 0.8 - - x = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - x.stop_gradient = False - dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - - y_ref = F.softmax(scale * x) - y = scaled_softmax_forward(x, scale) - - paddle.autograd.backward([y_ref], [dy], True) - dx_ref = x.grad - dx = scaled_softmax_backward(dy, y, scale) - - assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3) - assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3) - - @staticmethod - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_scaled_masked_softmax_fwd_bwd(dtype): - """test scaled masked softmax""" - B, H, S = (16, 4, 32) - scale = 0.8 - - x = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - x.stop_gradient = False - dy = paddle.uniform(shape=(B, H, S, S), dtype=dtype) - mask = paddle.reshape(x[0, 0] > 0.3, shape=(1, 1, S, S)) - mask_flipped = x[0, 0] <= 0.3 - mask_ref = (mask_flipped.astype(dtype) - 1.0) * 1e4 - - y_ref = F.softmax(scale * x + mask_ref) - y = scaled_masked_softmax_forward(x, mask, scale) - - paddle.autograd.backward([y_ref], [dy], True) - dx_ref = x.grad - dx = scaled_masked_softmax_backward(dy, y, scale) - - assert_allclose(y_ref, y, rtol=1e-4, atol=1e-3) - assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3) - - @staticmethod - @pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) - def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype): - """test scaled upper triang masked softmax""" - B, S = (16, 32) - scale = 0.8 - - x = paddle.uniform(shape=(B, S, S), dtype=dtype) - x.stop_gradient = False - dy = paddle.uniform(shape=(B, S, S), dtype=dtype) - - mask = paddle.ones((S, S), dtype="int32") - col_beg, col_end = 1, S - for row in range(0, S): - mask[row, col_beg:col_end] = 0 - col_beg += 1 - - mask_ref = (mask.astype(dtype) - 1.0) * 1e4 - - y_ref = F.softmax(scale * x + mask_ref) - y = scaled_upper_triang_masked_softmax_forward(x, scale) - - paddle.autograd.backward([y_ref], [dy], True) - dx_ref = x.grad - dx = scaled_upper_triang_masked_softmax_backward(dy, y, scale) - - assert_allclose(y_ref, y, rtol=1e-4, atol=5e-3) - assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3) - - -@pytest.mark.parametrize("update_weight_scale_inv", [True, False]) -def test_amax_and_scale_update(update_weight_scale_inv): - """Test update_scale""" - num_gemm = 6 - history_len = 1024 - recipe = DelayedScaling() - fp8_dtype = tex.DType.kFloat8E4M3 - fp8_max = recipe.fp8_format.value.max_fwd - non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2)) - - amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32") - rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0) - rolled_history_ref[0] = 0.0 - amax_tensor = paddle.max(amax_history_tensor, axis=0) - scale_tensor = paddle.ones(shape=[num_gemm], dtype="float32") - - def calc_ref(amax, scale, fp8_max, margin=0): - """Calculate reference scale""" - sf = (fp8_max / amax) / (2**margin) - sf = paddle.where(amax > 0.0, sf, scale) - sf = paddle.where(paddle.isfinite(amax), sf, scale) - return sf - - scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.0) - if update_weight_scale_inv: - scale_inv_ref = 1.0 / scale_ref - else: - scale_inv_ref = paddle.zeros_like(scale_tensor) - scale_inv_ref = paddle.where(non_weight_mask, 1.0 / scale_ref, scale_inv_ref) - - # Placeholder - scale_actual = paddle.zeros_like(scale_tensor) - scale_inv_actual = paddle.zeros_like(scale_tensor) - - if update_weight_scale_inv: - non_weight_mask = paddle.empty([0]) - tex.amax_and_scale_update_inplace( - _amax_history=amax_history_tensor, - _scale=scale_actual, - _scale_inv=scale_inv_actual, - non_weight_mask=non_weight_mask, - fp8_dtype=int(fp8_dtype), - margin=0.0, - amax_compute="max", - ) - - assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7) - assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7) - assert_allclose(amax_history_tensor, rolled_history_ref, rtol=1e-7, atol=1e-7) - - -def test_update_latest_history(): - """Test update_latest_history""" - num_gemm = 6 - history_len = 1024 - - amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32") - amax = paddle.rand(shape=[num_gemm], dtype="float32") - - tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax) - - assert_allclose(amax_history_tensor[0], amax, rtol=1e-7, atol=1e-7) diff --git a/tests/paddle/test_parallel.py b/tests/paddle/test_parallel.py deleted file mode 100644 index f07d56d44b..0000000000 --- a/tests/paddle/test_parallel.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TE Paddle Parallel""" - -from pathlib import Path -import unittest - -from dist_launcher import TestDistributed -from utils import is_devices_enough - -from transformer_engine.paddle.fp8 import is_fp8_available - -test_root = Path(__file__).resolve().parent -gpu_has_fp8, reason = is_fp8_available() - - -class TestParallelLinear(TestDistributed): - """Test Linear in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelLinear needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_linear_tp(self): - """Tests linear with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "linear_tp.py")) - - -class TestParallelLayerNormLinear(TestDistributed): - """Test LayerNormLinear in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormLinear needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_layernorm_linear_tp(self): - """Tests layernorm_linear with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_linear_tp.py")) - - -class TestParallelLayerNormMLP(TestDistributed): - """Test LayerNormMLP in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelLayerNormMLP needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_layernorm_mlp_tp(self): - """Tests layernorm_mlp with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_mlp_tp.py")) - - -class TestAmaxReduction(TestDistributed): - """Test amax reduction in dp mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestAmaxReduction needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_amax_reduction(self): - """Tests amax reduction""" - self.run_2gpu(str(test_root / "parallel_tests" / "amax_reduction.py")) - - -class TestPipelineParallel(TestDistributed): - """Test pipeline parallel""" - - @unittest.skipIf(not is_devices_enough(2), "TestPipelineParallel needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_pipeline_parallel(self): - """Tests pipeline parallel""" - self.run_2gpu(str(test_root / "parallel_tests" / "linear_pp.py")) - - -class TestGroupSharding(TestDistributed): - """Test group sharding""" - - @unittest.skipIf(not is_devices_enough(2), "TestGroupSharding needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_group_sharding(self): - """Tests group sharding""" - self.run_2gpu(str(test_root / "parallel_tests" / "group_sharding.py")) - - -class TestParallelAttention(TestDistributed): - """Test MultiHeadAttention Layer in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelAttention needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_attention_tp(self): - """Tests TransMultiHeadAttentionformer Layer with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "attention_tp.py")) - - -class TestParallelTransformerLayer(TestDistributed): - """Test Transformer Layer in Parallel mode""" - - @unittest.skipIf(not is_devices_enough(2), "TestParallelTransformerLayer needs 2 GPUs") - @unittest.skipIf(not gpu_has_fp8, reason) - def test_transformer_tp(self): - """Tests Transformer Layer with tensor parallel in BF16""" - self.run_2gpu(str(test_root / "parallel_tests" / "transformer_tp.py")) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/paddle/test_recompute.py b/tests/paddle/test_recompute.py deleted file mode 100644 index 02dddad210..0000000000 --- a/tests/paddle/test_recompute.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Test TE Paddle Recompute""" - -from pathlib import Path -import re -import subprocess - -import numpy as np -import pytest - -from transformer_engine.paddle.fp8 import is_fp8_available - -test_root = Path(__file__).resolve().parent -is_fp8_supported, reason = is_fp8_available() - - -@pytest.mark.skipif(not is_fp8_supported, reason=reason) -@pytest.mark.parametrize("use_reentrant", [False, True]) -def test_transformer_encoder_recompute(use_reentrant): - """ - Test TransformerLayer encoder recompute - """ - rtol = 1e-5 - atol = 1e-5 - - def launch_subprocess_and_check_output(enable_recompute): - """Launch training in subprocess and check output""" - try: - cmd = [ - "python", - str(test_root / "recompute_tests" / "recompute_transformer_encoder.py"), - str(int(enable_recompute)), - str(int(use_reentrant)), - ] - result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True) - - print(result) - - loss_match = re.search(r"Loss:\s+(-?\d+\.\d+)", result) - memory_match = re.search(r"Peak memory:\s+(\d+)", result) - - loss_value = float(loss_match.group(1)) - memory_value = int(memory_match.group(1)) - - return loss_value, memory_value - - except subprocess.CalledProcessError as e: - raise ValueError(f"Subprocess failed with error: {e}") from e - - loss_recompute, peak_memory_recompute = launch_subprocess_and_check_output(True) - loss_ref, peak_memory_ref = launch_subprocess_and_check_output(False) - - assert peak_memory_recompute < peak_memory_ref - np.testing.assert_allclose(loss_recompute, loss_ref, rtol=rtol, atol=atol) diff --git a/tests/paddle/test_sanity_import.py b/tests/paddle/test_sanity_import.py deleted file mode 100644 index 9b38d543da..0000000000 --- a/tests/paddle/test_sanity_import.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import transformer_engine.paddle - -print("OK") diff --git a/tests/paddle/utils.py b/tests/paddle/utils.py deleted file mode 100644 index 572af66ff9..0000000000 --- a/tests/paddle/utils.py +++ /dev/null @@ -1,221 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Utils for testing""" - -import random -from typing import Union - -import numpy as np -import paddle -from paddle.distributed import fleet -from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker - -import transformer_engine # pylint: disable=unused-import -from transformer_engine.paddle.constants import ( - TE_DType, - AttnBiasType, - AttnMaskType, - FusedAttnBackend, -) -from transformer_engine.paddle.fp8 import FP8TensorMeta -from transformer_engine import ( - transformer_engine_paddle as tex, -) # pylint: disable=wrong-import-order - - -def create_fp8_meta(num_gemms=1, amax_history_len=10): - """ - Create and initialize FP8TensorMeta - """ - fp8_meta = FP8TensorMeta(is_forward=True) - fp8_meta.prepare(num_gemms, amax_history_len) - return fp8_meta - - -def assert_allclose( - actual, desired, rtol=1e-05, atol=1e-08, equal_nan=True, err_msg="", verbose=True -): - """Compare two input paddle tensors""" - if isinstance(actual, paddle.Tensor): - actual = paddle.cast(actual, "float32") - if isinstance(desired, paddle.Tensor): - desired = paddle.cast(desired, "float32") - if len(actual.shape) == 0: - actual = actual.item() - desired = desired.item() - else: - actual = actual.numpy() - desired = desired.numpy() - np.testing.assert_allclose(actual, desired, rtol, atol, equal_nan, err_msg, verbose) - - -def assert_shape(inp, expected_shape): - """Assert the shape of input tensor equals to expected shape""" - assert ( - inp.shape == expected_shape - ), f"Expected tensor shape: {expected_shape} != actual tensor shape: {inp.shape}" - - -def is_devices_enough(required): - """If the number of device is enough""" - return paddle.device.cuda.device_count() >= required - - -def set_random_seed(seed): - """Set random seed for reproducability.""" - fleet.meta_parallel.model_parallel_random_seed(seed) - - hcg = fleet.get_hybrid_communicate_group() - if paddle.distributed.get_world_size() > 1: - # obtain rank message of hybrid parallel - - mp_rank = hcg.get_model_parallel_rank() - mp_size = hcg.get_model_parallel_world_size() - - pp_rank = hcg.get_stage_id() - pp_size = hcg.get_pipe_parallel_world_size() - - dp_rank = hcg.get_data_parallel_rank() - dp_size = hcg.get_data_parallel_world_size() - - sharding_rank = hcg.get_sharding_parallel_rank() - else: - mp_rank, mp_size = 0, 1 - pp_rank, pp_size = 0, 1 - dp_rank, dp_size = 0, 1 - sharding_rank, _ = 0, 1 - - random.seed(seed + 100 * pp_rank) - np.random.seed(seed + 100 * pp_rank) - - seed_offset = seed + 1024 + paddle.distributed.get_world_size() - global_seed = ( - seed_offset - + pp_rank * (mp_size) - + dp_rank * (mp_size * pp_size) - + sharding_rank * (mp_size * pp_size * dp_size) - ) - - seed_offset += paddle.distributed.get_world_size() - local_seed = ( - seed_offset - + mp_rank - + pp_rank * (mp_size) - + dp_rank * (mp_size * pp_size) - + sharding_rank * (mp_size * pp_size * dp_size) - ) - - tracker = get_rng_state_tracker() - # tracker.reset() - if "global_seed" not in tracker.states_: - tracker.add("global_seed", global_seed) - if "local_seed" not in tracker.states_: - tracker.add("local_seed", local_seed) - - paddle.seed(global_seed) - - -def get_fused_attention_backend( - num_heads: int, - num_gqa_groups: int, - q_seqlen: int, - kv_seqlen: int, - head_size: int, - dtype: Union[paddle.dtype, str], - dropout: float, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - mask_type: str = "causal", -) -> tex.NVTE_Fused_Attn_Backend: - """Get cuDNN fused attention backend for attention config""" - if isinstance(dtype, str): - dtype = dict( - float32=paddle.float32, - bfloat16=paddle.bfloat16, - float16=paddle.float16, - )[dtype] - return tex.get_fused_attn_backend( - TE_DType[dtype], - TE_DType[dtype], - tex.get_nvte_qkv_layout(qkv_layout), - AttnBiasType[bias_type], - AttnMaskType[mask_type], - dropout, - num_heads, - num_gqa_groups, - q_seqlen, - kv_seqlen, - head_size, - ) - - -def is_fused_attention_supported( - num_heads: int, - num_gqa_groups: int, - q_seqlen: int, - kv_seqlen: int, - head_size: int, - dtype: Union[paddle.dtype, str], - dropout: float, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - mask_type: str = "causal", -) -> bool: - """Check if cuDNN fused attention is supported for attention config""" - backend = get_fused_attention_backend( - num_heads=num_heads, - num_gqa_groups=num_gqa_groups, - q_seqlen=q_seqlen, - kv_seqlen=kv_seqlen, - head_size=head_size, - dtype=dtype, - dropout=dropout, - qkv_layout=qkv_layout, - bias_type=bias_type, - mask_type=mask_type, - ) - return backend != FusedAttnBackend["No_Backend"] - - -def register_sequence_parallel_allreduce_hooks(model, accumulation_steps) -> None: - """Register allreduce hooks for sequence parallel tensors""" - - def is_sequence_parallel_parameter(parameter): - """If input tensor is marked as sequence parallel tensor""" - out = getattr(parameter, "sequence_parallel", False) - return out - - def create_allreduce_gradient_hook(param, accumulation_steps): - """Create allreduce gradient hook""" - hcg = fleet.get_hybrid_communicate_group() - pg = hcg.get_model_parallel_group().process_group - step = [0] - - @paddle.autograd.no_grad() - def __impl__(): - step[0] += 1 - if (step[0] % accumulation_steps) == 0: - if hasattr(param, "main_grad"): - pg.allreduce(param.main_grad).wait() - else: - pg.allreduce(param.grad).wait() - - return __impl__ - - if accumulation_steps <= 0 or not paddle.distributed.is_initialized(): - return - - hcg = fleet.get_hybrid_communicate_group() - mp_group = hcg.get_model_parallel_group() - if mp_group.nranks <= 1: - return - - params = [] - for p in model.parameters(): - if is_sequence_parallel_parameter(p): - params.append(p) - - for p in params: - hook = create_allreduce_gradient_hook(p, accumulation_steps) - p._register_backward_hook(hook) diff --git a/tests/pytorch/custom_ort_ops/.gitignore b/tests/pytorch/custom_ort_ops/.gitignore deleted file mode 100644 index d491fb774c..0000000000 --- a/tests/pytorch/custom_ort_ops/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -build -onnxruntime -libcustom_ort_ops.so diff --git a/tests/pytorch/custom_ort_ops/CMakeLists.txt b/tests/pytorch/custom_ort_ops/CMakeLists.txt deleted file mode 100644 index 90fb3624c1..0000000000 --- a/tests/pytorch/custom_ort_ops/CMakeLists.txt +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -cmake_minimum_required(VERSION 3.21) -project(custom_ort_ops LANGUAGES CXX) - -# Dependencies -find_package(CUDAToolkit REQUIRED) -set(ONNX_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/onnxruntime/include) -if(NOT EXISTS "${ONNX_INCLUDE_DIR}") - message(FATAL_ERROR - "Could not find ONNX Runtime headers. " - "Please clone https://github.com/microsoft/onnxruntime " - "into TransformerEngine/tests/pytorch/onnx.") -endif() -include_directories(${ONNX_INCLUDE_DIR}) - -# Configure library -add_library(custom_ort_ops SHARED custom_op_library.cc) -target_link_libraries(custom_ort_ops PUBLIC CUDA::cudart) -target_include_directories(custom_ort_ops PUBLIC - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -target_include_directories(custom_ort_ops PRIVATE - ${ONNX_INCLUDE_DIR}/onnxruntime - ${ONNX_INCLUDE_DIR}/onnxruntime/core/session) - -# Install library -install(TARGETS custom_ort_ops DESTINATION .) diff --git a/tests/pytorch/custom_ort_ops/README.md b/tests/pytorch/custom_ort_ops/README.md deleted file mode 100644 index ca392805be..0000000000 --- a/tests/pytorch/custom_ort_ops/README.md +++ /dev/null @@ -1,22 +0,0 @@ -# Custom ONNX Runtime operators for Transformer Engine tests - -This directory contains code that builds custom ONNX operators for use -in Transformer Engine tests. It includes basic, non-performant -implementations of the FP8 quantization and dequantization operators -that are used when exporting Transformer Engine models to ONNX. - -For more information, see [the ONNX Runtime reference for custom -operators](https://onnxruntime.ai/docs/reference/operators/add-custom-op.html). -Much of the code has been adapted from [an ONNX Runtime -test](https://github.com/microsoft/onnxruntime/blob/de93f40240459953a6e3bbb86b6ad83eaeab681f/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc). - -## Usage - -* Build the custom operators: -```bash -$ bash TransformerEngine/tests/pytorch/custom_ort_ops/build.sh -``` -* Run the ONNX export tests with pytest: -```bash -$ python -m pytest TransformerEngine/tests/pytorch/test_onnx_export.py -``` \ No newline at end of file diff --git a/tests/pytorch/custom_ort_ops/build.sh b/tests/pytorch/custom_ort_ops/build.sh deleted file mode 100644 index 989da2f4ef..0000000000 --- a/tests/pytorch/custom_ort_ops/build.sh +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -set -ex - -: ${CUSTOM_ORT_OPS_PATH=$(dirname $(realpath $0))} -cd ${CUSTOM_ORT_OPS_PATH} - -# Download ONNX Runtime source -git clone --depth=1 -b rel-1.19.2 --single-branch https://github.com/microsoft/onnxruntime.git || true - -# Configure and build with CMake -mkdir -p build -cmake -S . -B build -DCMAKE_INSTALL_PREFIX=. -cmake --build build --verbose -cmake --install build --verbose diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.cc b/tests/pytorch/custom_ort_ops/custom_op_library.cc deleted file mode 100755 index f46e897152..0000000000 --- a/tests/pytorch/custom_ort_ops/custom_op_library.cc +++ /dev/null @@ -1,102 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "custom_op_library.h" - -#define ORT_API_MANUAL_INIT -#include "onnxruntime_c_api.h" -#include "onnxruntime_cxx_api.h" -#undef ORT_API_MANUAL_INIT - -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/session/onnxruntime_lite_custom_op.h" -#include - -namespace { - -template -void Quantize(OrtKernelContext* context, - const Ort::Custom::Tensor& input, - const Ort::Custom::Tensor& scale_inv, - Ort::Custom::Tensor& output) { - auto raw_input = input.Data(); - auto raw_scale_inv = scale_inv.Data(); - auto raw_output = reinterpret_cast(output.Allocate(input.Shape())); - const auto rs = static_cast(raw_scale_inv[0]); - const size_t N = input.NumberOfElement(); - for (size_t i = 0; i < N; ++i) { - const auto x = static_cast(raw_input[i]); - raw_output[i] = static_cast(x / rs); - } -} - -template -void Dequantize(OrtKernelContext* context, - const Ort::Custom::Tensor& input, - const Ort::Custom::Tensor& scale_inv, - Ort::Custom::Tensor& output) { - auto raw_input = reinterpret_cast(input.Data()); - auto raw_scale_inv = scale_inv.Data(); - auto raw_output = output.Allocate(input.Shape()); - const auto rs = static_cast(raw_scale_inv[0]); - const size_t N = input.NumberOfElement(); - for (size_t i = 0; i < N; ++i) { - const auto x = rs * static_cast(raw_input[i]); - raw_output[i] = static_cast(x); - } -} - -static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) { - static std::vector ort_custom_op_domain_container; - static std::mutex ort_custom_op_domain_mutex; - std::lock_guard lock(ort_custom_op_domain_mutex); - ort_custom_op_domain_container.push_back(std::move(domain)); -} - -} // namespace - -OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { - Ort::Global::api_ = api->GetApi(ORT_API_VERSION); - - // Namespace for custom ops - static const char* c_OpDomain = "trt"; - - // Construct custom ops - static const std::unique_ptr c_Quantize{ - Ort::Custom::CreateLiteCustomOp("TRT_FP8QuantizeLinear", - "CPUExecutionProvider", - Quantize) - }; - static const std::unique_ptr c_Dequantize{ - Ort::Custom::CreateLiteCustomOp("TRT_FP8DequantizeLinear", - "CPUExecutionProvider", - Dequantize<__nv_fp8_e4m3, float, float>) - }; - - // Register custom ops - OrtStatus* result = nullptr; - ORT_TRY { - Ort::CustomOpDomain domain{c_OpDomain}; - domain.Add(c_Quantize.get()); - domain.Add(c_Dequantize.get()); - Ort::UnownedSessionOptions session_options(options); - session_options.Add(domain); - AddOrtCustomOpDomainToContainer(std::move(domain)); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - Ort::Status status{e}; - result = status.release(); - }); - } - return result; -} diff --git a/tests/pytorch/custom_ort_ops/custom_op_library.h b/tests/pytorch/custom_ort_ops/custom_op_library.h deleted file mode 100755 index 7e4b8256bc..0000000000 --- a/tests/pytorch/custom_ort_ops/custom_op_library.h +++ /dev/null @@ -1,18 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#pragma once -#include "onnxruntime/core/session/onnxruntime_c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api); - -#ifdef __cplusplus -} -#endif diff --git a/tests/pytorch/distributed/print_logs.py b/tests/pytorch/distributed/print_logs.py deleted file mode 100644 index 6c25db4945..0000000000 --- a/tests/pytorch/distributed/print_logs.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import os -import re -import glob -import datetime -from prettytable import PrettyTable -from matplotlib import pyplot as plt - -NUM_MOST_RECENT_RUNS = 100 - - -te_path = os.getenv("TE_PATH", "/opt/transformerengine") -mlm_log_dir = os.path.join(te_path, "ci_logs") -te_ci_log_dir = "/data/transformer_engine_ci_logs" -te_ci_plot_dir = os.path.join(te_ci_log_dir, "plots") - - -convergence_pattern = ( - "validation loss at iteration \d* on validation set | lm loss" - " value: ([\d.]*)E\+(\d*) | lm loss PPL: ([\d.]*)E\+(\d*)" -) - - -perf_pattern = "elapsed time per iteration \(ms\): ([\d.]*)" - - -def get_output_file(): - now = datetime.datetime.now() - default_fname = f"unknown_pipeline_id_{now.month}_{now.day}_{now.year}_{now.hour}_{now.minute}" - fname = f"{os.getenv('CI_PIPELINE_ID', default_fname)}.txt" - return os.path.join(te_ci_log_dir, fname) - - -def get_run_metrics(filename): - """Return the loss, perplexity, and step time for a given megatron-LM logfile.""" - - with open(filename, "r") as f: - data = f.read() - - # Loss and PPL - convergence_matches = re.findall(convergence_pattern, data) - loss = round(float(convergence_matches[1][0]) * (10 ** int(convergence_matches[1][1])), 2) - ppl = round(float(convergence_matches[2][2]) * (10 ** int(convergence_matches[2][3])), 2) - - step_times_str = re.findall(perf_pattern, data) - step_times = [float(x) for x in step_times_str] - avg_step_time = round(sum(step_times) / len(step_times), 2) - return loss, ppl, avg_step_time - - -def print_run_logs(): - tables = [] - raw_logs = [] - for model_config in os.listdir(mlm_log_dir): - model_config_dir = os.path.join(mlm_log_dir, model_config) - table = PrettyTable() - table.title = model_config - table.field_names = ["Config", "Loss", "Perplexity", "Avg time per step (ms)"] - for exp in os.listdir(model_config_dir): - filename = os.path.join(model_config_dir, exp) - loss, ppl, time_per_step = get_run_metrics(filename) - exp_name = exp[:-4] - table.add_row([exp_name, loss, ppl, time_per_step]) - raw_logs.append(f"{model_config} {exp_name} {loss} {ppl} {time_per_step}\n") - tables.append(table) - - with open(get_output_file(), "w") as f: - for raw_log in raw_logs: - f.write(raw_log) - for table in tables: - print(table) - - -def save_plot(title, legend, data, filename, ylabel): - x = list(range(1, len(data[0]) + 1)) - plt.figure() - for label, y in zip(legend, data): - plt.plot(x, y, "-o", label=label) - plt.title(title) - plt.legend() - plt.xlabel(f"Last {NUM_MOST_RECENT_RUNS} runs") - plt.ylabel(ylabel) - plt.savefig(os.path.join(te_ci_plot_dir, filename)) - - -def perf_and_loss_plots(): - files = glob.glob(os.path.join(te_ci_log_dir, "*.txt")) - files.sort(key=os.path.getctime) - files = files[-NUM_MOST_RECENT_RUNS:] - data = {} - for filename in files: - with open(filename) as file: - for line in file: - line = line.strip() - model_config, exp_name, loss, _, time_per_step = line.split(" ") - if model_config not in data: - data[model_config] = {} - if exp_name not in data[model_config]: - data[model_config][exp_name] = {"loss": [], "perf": []} - data[model_config][exp_name]["loss"].append(float(loss)) - data[model_config][exp_name]["perf"].append(float(time_per_step)) - - for model_config, experiments in data.items(): - lm_loss_data = [] - lm_perf_data = [] - legend = [] - for exp_name, lm_data in experiments.items(): - legend.append(exp_name) - lm_loss_data.append(lm_data["loss"]) - lm_perf_data.append(lm_data["perf"]) - save_plot( - model_config + " loss", - legend, - lm_loss_data, - model_config + "_loss.png", - "LM-Loss", - ) - save_plot( - model_config + " perf", - legend, - lm_perf_data, - model_config + "_perf.png", - "Time per step (ms)", - ) - - -if __name__ == "__main__": - print_run_logs() - perf_and_loss_plots() diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py new file mode 100644 index 0000000000..e32f64cf1c --- /dev/null +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -0,0 +1,181 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import sys +import argparse + +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Format, DelayedScaling + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn, optim +from torch.distributed import DeviceMesh +from torch.distributed._composable.fsdp import fully_shard +from torch.distributed.device_mesh import init_device_mesh +from contextlib import nullcontext + + +class SimpleNet(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(SimpleNet, self).__init__() + self.fc1 = te.Linear(input_size, hidden_size) + self.fc2 = te.Linear(hidden_size, output_size) + + def forward(self, x): + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x + + +def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + attrs = vars(param) + custom_attrs[name] = {k: v for k, v in attrs.items()} + return custom_attrs + + +def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + +def _parse_args(argv=None, namespace=None): + parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()") + parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model") + parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size") + parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model") + parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model") + parser.add_argument( + "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." + ) + parser.add_argument( + "--iter", type=int, default=10, help="Number of iterations for forward pass" + ) + parser.add_argument("--seed", type=int, default=42, help="RNG seed.") + # Adding hsdp_dim as a list argument, comma-separated + parser.add_argument( + "--sharding-dims", + type=int, + nargs="+", + help='FSDP/HSDP sharding dimensions ("replicate", "shard")', + ) + args = parser.parse_args(argv, namespace) + if args.sharding_dims: + assert len(args.sharding_dims) <= 2 + return args + + +sub_modules_to_wrap = [te.Linear] + + +def _train(args): + assert "TORCHELASTIC_RUN_ID" in os.environ + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + assert LOCAL_SIZE == WORLD_SIZE + + # Set device and initialize RNG states + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + # Initialize torch.distributed global process group and get DP/TP groups + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + } + assert dist.is_nccl_available() + dist.init_process_group(**dist_init_kwargs) + nccl_world = dist.new_group(backend="nccl") + device = torch.device(f"cuda:{LOCAL_RANK}") + + # FP8 Configuration + fp8_format = Format.HYBRID + fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + + if not args.fp8_init: + # Build model context (FP8 init) + build_model_context = nullcontext + build_model_context_args = {} + + from transformer_engine.pytorch import fp8_model_init + + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + + # Build the model with the specified context + with build_model_context(**build_model_context_args): + model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + else: + model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + # Move the model to the correct device + + model.to(device) + + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...") + # Creating a DeviceMesh for fully_shard + world_size = int(WORLD_SIZE) + device_ids = list(range(world_size)) + if LOCAL_RANK == 0: + print(f"sharding-dims:{args.sharding_dims}") + # Setup the sharding mesh for FSDP/HSDP + if args.sharding_dims == None: # FSDP + mesh = DeviceMesh("cuda", device_ids) + elif len(args.sharding_dims) == 1: + assert args.sharding_dims[0] == device_ids[-1] + 1 + mesh = DeviceMesh("cuda", device_ids) + elif len(args.sharding_dims) == 2: # HSDP + assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1 + mesh = init_device_mesh( + "cuda", + (args.sharding_dims[0], args.sharding_dims[1]), + mesh_dim_names=("replicate", "shard"), + ) + else: + assert False + + # Apply FSDP/HSDP + custom_attrs = save_custom_attrs(model) + for sub_module in model.modules(): + if any( + isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap + ): + fully_shard(sub_module, mesh=mesh) + fully_shard(model, mesh=mesh) + restore_custom_attrs(model, custom_attrs) + + optimizer = optim.Adam(model.parameters(), lr=1e-3) + + for iteration in range(args.iter): + # Zero the parameter gradients + optimizer.zero_grad() + input_data = torch.randn(args.batch_size, args.input_size).to(device) + output = model(input_data) + target = torch.randn(args.batch_size, args.output_size).to(device) + loss = F.mse_loss(output, target) + loss.backward() + optimizer.step() + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.") + + dist.destroy_process_group() + if LOCAL_RANK == 0: + print(f"Rank {LOCAL_RANK}: Done...") + return 0 + + +if __name__ == "__main__": + sys.exit(_train(_parse_args())) diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index b00b8cc042..4bbdd23fd6 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -19,8 +19,8 @@ import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.common.recipe import Format -from transformer_engine.pytorch.fp8 import _default_sf_compute +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.module.base import get_cublas_workspace_size_bytes warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) @@ -47,14 +47,14 @@ def _mapped_argtype(opt, typemap): def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser(description="Test comm+GEMM overlap with Userbuffers.") parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") - parser.add_argument("-s", "--seq-length", type=int, default=512, help="Input sequence length.") + parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.") parser.add_argument( - "-n", "--num-heads", type=int, default=12, help="Number of attention heads." + "-n", "--num-heads", type=int, default=16, help="Number of attention heads." ) parser.add_argument( - "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head." + "-d", "--head-dim", type=int, default=48, help="Dimension of each attention head." ) - parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") + parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument( "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." ) @@ -180,15 +180,22 @@ def _main(opts): LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) opts.tcp_init = True opts.bootstrap_backend = "mpi" - elif "TORCHELASTIC_RUN_ID" in os.environ: + else: # TORCHELASTIC, SLURM, etc... WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - else: - raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") - assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node - assert LOCAL_SIZE <= torch.cuda.device_count() + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + + result = subprocess.run( + "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'", + capture_output=True, + text=True, + shell=True, + ) + + if result.stdout == "0": # Extra checks for non-MNNVL platforms + assert WORLD_SIZE == LOCAL_SIZE + assert LOCAL_SIZE <= torch.cuda.device_count() # Fix clock speed torch.cuda.set_device(LOCAL_RANK) @@ -288,33 +295,6 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None else tex.CommOverlapHelper(bootstrap_pg) ) - if opts.comm_type == tex.CommOverlapType.RS: - if opts.bulk_overlap: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_RS - elif opts.p2p: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - ) - else: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ) - elif opts.comm_type == tex.CommOverlapType.AG: - if opts.bulk_overlap: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - else: - ub_algo = ( - tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - if opts.atomic - else tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - ) - else: - raise TypeError("Invalid comm+GEMM overlap type!") - # Initialize userbuffers with (M, N) buffer # M = sequence * batch # N = hidden size @@ -322,11 +302,7 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None inp_shape = (opts.seq_length, opts.batch_size, hidden_size) outer_size = reduce(operator.mul, inp_shape[:-1], 1) buffer_dtype = torch.bfloat16 - if ( - opts.fp8 - and not opts.bulk_overlap - and (opts.comm_type == tex.CommOverlapType.AG or opts.fp8_output) - ): + if opts.fp8 and not opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.AG: buffer_dtype = torch.uint8 ub_obj = ( tex.CommOverlapP2P( @@ -421,6 +397,10 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None std=opts.std, ) + # Allocate cuBLAS workspace + workspace_size = 3 * get_cublas_workspace_size_bytes() + workspace = torch.empty(workspace_size, dtype=torch.uint8, device="cuda") + # Gather global tensors and calculate reference result (need these first for Fp8 scales) if opts.bulk_overlap: ker_g = torch.transpose(kernel_t, 0, 1) @@ -467,120 +447,123 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None inp2_g = torch.nn.functional.gelu(ref_g) # pylint: disable=not-callable ref2_g = torch.matmul(inp2_g, ker2_g) + inp_quantizer = None + ker_quantizer = None + out_quantizer = None + bulk_inp_quantizer = None + inp2_quantizer = None + ker2_quantizer = None + out2_quantizer = None if opts.fp8: - fp8_formats = { - tex.DType.kFloat8E4M3: Format.E4M3, - tex.DType.kFloat8E5M2: Format.E5M2, - } - # Structure to maintain amax and scale/scale_inv information for the kernel and input - fp8_dtype = tex.DType.kFloat8E4M3 - fp8_meta = tex.FP8TensorMeta() num_gemms = 6 if ub_obj2 is not None else 3 - fp8_meta.amax_history = torch.zeros((2, num_gemms), dtype=torch.float, device="cuda") - fp8_meta.scale = torch.ones(num_gemms, dtype=torch.float, device="cuda") - fp8_meta.scale_inv = torch.ones(num_gemms, dtype=torch.float, device="cuda") + fp8_dtype = tex.DType.kFloat8E4M3 + fp8_scales = torch.ones(num_gemms, dtype=torch.float, device="cuda") + fp8_amaxes = torch.zeros(num_gemms, dtype=torch.float, device="cuda") # Compute initial amaxes and scales inp_amax = torch.max(torch.abs(inp_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_INPUT].copy_(inp_amax) + fp8_amaxes[0].copy_(inp_amax) ker_amax = torch.max(torch.abs(ker_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_WEIGHT].copy_(ker_amax) + fp8_amaxes[1].copy_(ker_amax) ref_amax = torch.max(torch.abs(ref_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM1_OUTPUT].copy_(ref_amax) + fp8_amaxes[2].copy_(ref_amax) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: bulk_amax = torch.max(torch.abs(bulk_inp)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(bulk_amax) + fp8_amaxes[5].copy_(bulk_amax) elif ub_obj2 is not None: inp2_amax = torch.max(torch.abs(inp2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_INPUT].copy_(inp2_amax) + fp8_amaxes[3].copy_(inp2_amax) ker2_amax = torch.max(torch.abs(ker2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_WEIGHT].copy_(ker2_amax) + fp8_amaxes[4].copy_(ker2_amax) ref2_amax = torch.max(torch.abs(ref2_g)) - fp8_meta.amax_history[1][tex.FP8FwdTensors.GEMM2_OUTPUT].copy_(ref2_amax) - fp8_meta.scale = _default_sf_compute( - fp8_meta.amax_history[1], fp8_meta.scale, fp8_formats[fp8_dtype].value.max_fwd, 1 - ) - fp8_meta.scale_inv = torch.reciprocal(fp8_meta.scale) + fp8_amaxes[5].copy_(ref2_amax) - # Cast input to Float8Tensor - inp_fp8 = tex.cast_to_fp8(inp, fp8_meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype) + inp_quantizer = Float8Quantizer(fp8_scales[0].clone(), fp8_amaxes[0].clone(), fp8_dtype) + ker_quantizer = Float8Quantizer(fp8_scales[1].clone(), fp8_amaxes[1].clone(), fp8_dtype) + if opts.fp8_output: + out_quantizer = Float8Quantizer(fp8_scales[2].clone(), fp8_amaxes[2].clone(), fp8_dtype) - # Cast kernel to Float8Tensor - kernel_t_fp8 = tex.cast_to_fp8( - kernel_t, fp8_meta, tex.FP8FwdTensors.GEMM1_WEIGHT, fp8_dtype - ) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: - bulk_inp_fp8 = tex.cast_to_fp8( - bulk_inp, fp8_meta, tex.FP8Tensors.GEMM2_OUTPUT, fp8_dtype + bulk_inp_quantizer = Float8Quantizer( + fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype ) elif ub_obj2 is not None: - kernel2_t_fp8 = tex.cast_to_fp8( - kernel2_t, fp8_meta, tex.FP8FwdTensors.GEMM2_WEIGHT, fp8_dtype + inp2_quantizer = Float8Quantizer( + fp8_scales[3].clone(), fp8_amaxes[3].clone(), fp8_dtype + ) + ker2_quantizer = Float8Quantizer( + fp8_scales[4].clone(), fp8_amaxes[4].clone(), fp8_dtype ) + if opts.fp8_output: + out2_quantizer = Float8Quantizer( + fp8_scales[5].clone(), fp8_amaxes[5].clone(), fp8_dtype + ) + + # Cast input to Float8Tensor + inp_fp8 = inp_quantizer(inp) + + # Cast kernel to Float8Tensor + kernel_t_fp8 = ker_quantizer(kernel_t) + if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: + bulk_inp_fp8 = bulk_inp_quantizer(bulk_inp) + elif ub_obj2 is not None: + kernel2_t_fp8 = ker2_quantizer(kernel2_t) # Make sure the inputs are cast correctly if opts.check_numerics: torch.allclose( inp.to(dtype=torch.float32), - inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT], + inp_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) torch.allclose( kernel_t.to(dtype=torch.float32), - kernel_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_WEIGHT], + kernel_t_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) if opts.bulk_overlap and opts.comm_type == tex.CommOverlapType.RS: torch.allclose( bulk_inp.to(dtype=torch.float32), - bulk_inp_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT], + bulk_inp_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) elif ub_obj2 is not None: torch.allclose( kernel2_t.to(dtype=torch.float32), - kernel2_t_fp8 * fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_WEIGHT], + kernel2_t_fp8.dequantize(dtype=torch.float32), rtol=0.125, atol=0.0675, ) - # Set Fp8 scales for userbuffers - if opts.comm_type == tex.CommOverlapType.AG: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_INPUT]) - if ub_obj2 is not None: - ub_obj2.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) - elif opts.bulk_overlap: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM2_OUTPUT]) - else: - ub_obj.set_ubuf_scale_inv(fp8_meta.scale_inv[tex.FP8FwdTensors.GEMM1_OUTPUT]) - # Set up comm/compute buffers - ubuf_out2 = None + rs_out = None rs_out2 = None if opts.comm_type == tex.CommOverlapType.AG: if opts.bulk_overlap: - ub_obj.copy_input_to_ubuf(bulk_inp, 1) + ub_obj.copy_into_buffer(bulk_inp, bulk_inp_quantizer, True) gemm_inp = inp else: - ub_obj.copy_input_to_ubuf(inp_fp8 if opts.fp8 else inp, 1) - gemm_inp = ub_obj.get_ubuf_output(1) - ubuf_out = None - rs_out = None + ub_obj.copy_into_buffer(inp_fp8 if opts.fp8 else inp, inp_quantizer, True) + gemm_inp = ub_obj.get_buffer(inp_quantizer, False, inp_g.size()) if ub_obj2 is not None: - ubuf_out2 = ub_obj2.get_ubuf_output(1) + if opts.fp8 and opts.fp8_output: + ub_obj2.set_buffer_params(out_quantizer) rs_out2 = torch.empty( (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" ) else: if opts.bulk_overlap: - ub_obj.copy_input_to_ubuf(bulk_inp_fp8 if opts.fp8 else bulk_inp, 0) - ubuf_out = None - else: - ubuf_out = ub_obj.get_ubuf_output(1) + ub_obj.copy_into_buffer( + bulk_inp_fp8 if opts.fp8 else bulk_inp, bulk_inp_quantizer, False + ) + if opts.fp8: + ub_obj.set_buffer_params(bulk_inp_quantizer) + elif opts.fp8 and opts.fp8_output: + ub_obj.set_buffer_params(out_quantizer) gemm_inp = inp_fp8 if opts.fp8 else inp rs_out = torch.empty( (outer_size // tp_size, hidden_size), dtype=torch.bfloat16, device="cuda" @@ -588,88 +571,47 @@ def dist_print(msg, src=None, info=False, error=False, section=False, group=None # Wrap GEMM ops in condensed functions to make CUDA Graphs easier to use def _fp8_gemm(): - return tex.fp8_gemm( + return tex.general_gemm( kernel_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype, gemm_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype, - torch.uint8 if opts.fp8_output else torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, + workspace, + out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, + quantization_params=out_quantizer, use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=ub_algo, ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, - D_dtype=fp8_dtype if opts.fp8_output else None, - fp8_meta_tensor=fp8_meta if opts.fp8_output else None, - out_index=tex.FP8FwdTensors.GEMM1_OUTPUT if opts.fp8_output else None, + ub_type=opts.comm_type, + extra_output=rs_out, + bulk_overlap=opts.bulk_overlap, ) def _fp8_gemm2(gemm1_out): gemm2_inp = tex.gelu( - ( - tex.cast_from_fp8( - gemm1_out, - fp8_meta, - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype, - tex.DType.kFloat32, - ) - if opts.fp8_output - else gemm1_out - ), - fp8_meta, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, + (gemm1_out.dequantize() if opts.fp8_output else gemm1_out), + inp2_quantizer, ) - return tex.fp8_gemm( + return tex.general_gemm( kernel2_t_fp8, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_WEIGHT, - fp8_dtype, gemm2_inp, - fp8_meta.scale_inv, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype, - torch.uint8 if opts.fp8_output else torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, + workspace, + out_dtype=torch.float8_e4m3fn if opts.fp8_output else torch.bfloat16, + quantization_params=out2_quantizer, use_split_accumulator=te.module.base._2X_ACC_FPROP, - ub_algo=( - tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - if opts.atomic_rs_p2p - else tex.CommOverlapAlgo.ATOMIC_GEMM_RS - ), ub=ub_obj2, - extra_output_tensor=rs_out2, - out=ubuf_out2, - D_dtype=fp8_dtype if opts.fp8_output else None, - fp8_meta_tensor=fp8_meta if opts.fp8_output else None, - out_index=tex.FP8FwdTensors.GEMM2_OUTPUT if opts.fp8_output else None, + ub_type=tex.CommOverlapType.AG, + extra_output=rs_out2, ) def _gemm(): - return tex.gemm( + return tex.general_gemm( kernel_t, gemm_inp, - torch.bfloat16, - te.module.base.get_workspace(), - bias=None, - use_bias=False, - gelu=False, - ub_algo=ub_algo, + workspace, + out_dtype=torch.bfloat16, + use_split_accumulator=te.module.base._2X_ACC_FPROP, ub=ub_obj, - extra_output_tensor=rs_out, - out=ubuf_out, + ub_type=opts.comm_type, + extra_output=rs_out, + bulk_overlap=opts.bulk_overlap, ) # Trigger GEMM @@ -746,10 +688,10 @@ def _gemm(): output_info = "" if opts.comm_type == tex.CommOverlapType.AG: # Bulk overlap AG output is already gathered - test_out = ub_obj.get_ubuf_output(1) + test_out = ub_obj.get_buffer(bulk_inp_quantizer, False) else: # Bulk overlap RS output needs to be gathered - out_local = ub_obj.get_ubuf_output(0) + out_local = ub_obj.get_buffer(bulk_inp_quantizer, True) output_info += f"rs_output: {list(out_local.shape)} | " test_out = te.distributed.gather_along_first_dim(out_local, tp_group)[0] @@ -775,17 +717,7 @@ def _gemm(): test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] else: # AG Output: (M, K/P) -> T -> (K/P, M) -> gather -> (K, M) -> T -> (M, K) - output = ( - tex.cast_from_fp8( - all_outputs[0], - fp8_meta, - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype, - tex.DType.kFloat32, - ) - if opts.fp8_output - else all_outputs[0] - ) + output = all_outputs[0].dequantize() if opts.fp8_output else all_outputs[0] test_out = torch.transpose( te.distributed.gather_along_first_dim( torch.transpose(output, 0, 1), tp_group @@ -798,25 +730,6 @@ def _gemm(): output = rs_out.to(dtype=torch.float32) test_out = te.distributed.gather_along_first_dim(output, tp_group)[0] - if opts.fp8: - dist_print("GEMM1 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) - fp8_meta_info = ( - f"amax_reference = {fp8_meta.amax_history[1][:3].tolist()}\n" - + f"amax_history = {fp8_meta.amax_history[0][:3].tolist()}\n" - + f"scale = {fp8_meta.scale[:3].tolist()}\n" - + f"scale_inv = {fp8_meta.scale_inv[:3].tolist()}" - ) - dist_print(fp8_meta_info, src=0, group=tp_group) - if ub_obj2 is not None: - dist_print("GEMM2 FP8 metas = [INPUT, WEIGHT, OUTPUT]", src=0, section=True) - fp8_meta_info = ( - f"amax_reference = {fp8_meta.amax_history[1][3:].tolist()}\n" - + f"amax_history = {fp8_meta.amax_history[0][3:].tolist()}\n" - + f"scale = {fp8_meta.scale[3:].tolist()}\n" - + f"scale_inv = {fp8_meta.scale_inv[3:].tolist()}" - ) - dist_print(fp8_meta_info, src=0, group=tp_group) - ref_out = ref2_g if ub_obj2 is not None else ref_g test_nonzeros = torch.count_nonzero(test_out) ref_nonzeros = torch.count_nonzero(ref_out) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index e32a7ccb12..526876edf3 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -1,26 +1,41 @@ #!/usr/bin/python3 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import os import sys import socket +import subprocess import argparse import warnings +import pprint +import yaml import torch import torch.distributed as dist import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.common.recipe import Format, DelayedScaling, Float8CurrentScaling warnings.filterwarnings("ignore", category=DeprecationWarning) warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) +class multi_module_model(torch.nn.Module): + def __init__(self, module, num_layers, *args, **kwargs): + super().__init__() + self.num_layers = num_layers + self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)]) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + def _te_layer_argtype(name): te_layers = [ te.Linear, @@ -37,49 +52,66 @@ def _te_layer_argtype(name): return layer_map[name.lower()] -def _get_layer_args(config, tp_group, tp_size, reference=False): +def _get_layer_args(config, tp_group, tp_size, num_layers, reference=False): hidden_size = config.num_heads * config.head_dim + ffn_hidden_size = 4 * hidden_size + qkv_size = 3 * hidden_size + if num_layers > 1 and config.layer_type != te.TransformerLayer: + raise ValueError("Stacked layers are only supported for te.TransformerLayer!") input_shape = [config.seq_length, config.batch_size, hidden_size] args = [hidden_size] kwargs = { - "params_dtype": torch.float32, + "params_dtype": torch.float32 if not config.use_bf16_params else torch.bfloat16, "device": "cuda", "tp_group": tp_group, "tp_size": tp_size, "sequence_parallel": True, + "ub_overlap_ag": not reference, + "ub_overlap_rs": not reference, } - kwargs["ub_overlap_ag"] = not reference - - if config.layer_type is te.Linear: - input_shape[2] = hidden_size // tp_size - args.append(hidden_size) - kwargs["parallel_mode"] = "row" - kwargs["ub_overlap_rs"] = not reference - kwargs["ub_name"] = "proj" + + if config.layer_type in [te.Linear, te.LayerNormLinear]: + if config.linear_parallel_mode == "row": + input_shape[-1] = ffn_hidden_size // tp_size + args = [ffn_hidden_size, hidden_size] + if config.in_features is not None: + input_shape[-1] = config.in_features // tp_size + args = [config.in_features, hidden_size] + kwargs["ub_name"] = "proj" if config.layer_type == te.Linear else "fc2" + kwargs["ub_name"] = kwargs["ub_name"] if config.ub_name is None else config.ub_name + elif config.linear_parallel_mode == "column": + input_shape[0] = config.seq_length // tp_size + if config.out_features is not None: + args.append(config.out_features) + else: + args.append(qkv_size) + kwargs["ub_name"] = "qkv" if config.ub_name is None else config.ub_name + kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["parallel_mode"] = config.linear_parallel_mode else: input_shape[0] = config.seq_length // tp_size - kwargs["ub_bulk_wgrad"] = not reference - kwargs["ub_bulk_dgrad"] = not reference - if config.layer_type is te.LayerNormLinear: - args.append(3 * hidden_size) - kwargs["parallel_mode"] = "column" - kwargs["ub_name"] = "qkv" - else: - kwargs["set_parallel_mode"] = True - kwargs["ub_overlap_rs"] = not reference - if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: - args.append(4 * hidden_size) - kwargs["seq_length"] = config.seq_length - if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: - args.append(config.num_heads) - kwargs["attention_dropout"] = 0.0 - kwargs["fuse_qkv_params"] = True - if config.layer_type is te.MultiheadAttention: - kwargs["input_layernorm"] = True - else: - kwargs["ub_tp_comm_overlap"] = not reference - kwargs["hidden_dropout"] = 0.0 - + if config.layer_type in [te.LayerNormMLP, te.TransformerLayer]: + args.append(ffn_hidden_size) + kwargs["seq_length"] = config.seq_length + if config.layer_type in [te.MultiheadAttention, te.TransformerLayer]: + args.append(config.num_heads) + kwargs["attention_dropout"] = 0.0 + kwargs["fuse_qkv_params"] = True + if config.layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + else: + kwargs["ub_tp_comm_overlap"] = not reference + kwargs["hidden_dropout"] = 0.0 + kwargs["set_parallel_mode"] = True + kwargs["ub_overlap_rs_dgrad"] = config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_dgrad"] = not config.overlap_rs_dgrad and not reference + kwargs["ub_bulk_wgrad"] = not config.overlap_rs_dgrad and not reference + + if config.ub_cfg is not None and isinstance(config.ub_cfg, str): + with open(config.ub_cfg, "r") as stream: + config.ub_cfg = yaml.safe_load(stream) return args, kwargs, input_shape @@ -88,18 +120,52 @@ def _parse_args(argv=None, namespace=None): description="Test a Transformer Engine layer with GEMM+comm overlap via Userbuffers." ) parser.add_argument("-l", "--layer-type", type=_te_layer_argtype, default=te.LayerNormMLP) + parser.add_argument( + "--num-layers", type=int, default=1, help="Number of identical layers to stack." + ) parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.") - parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.") + parser.add_argument("-s", "--seq-length", type=int, default=1024, help="Input sequence length.") + parser.add_argument( + "-n", "--num-heads", type=int, default=16, help="Number of attention heads." + ) + parser.add_argument( + "-d", "--head-dim", type=int, default=48, help="Dimension of each attention head." + ) + parser.add_argument( + "--in-features", + type=int, + default=None, + help="Optional input feature size for weight. Only used for Linear layer.", + ) parser.add_argument( - "-n", "--num-heads", type=int, default=12, help="Number of attention heads." + "--out-features", + type=int, + default=None, + help="Optional output feature size for weight. Only used for LayerNormLinear layer.", ) parser.add_argument( - "-d", "--head-dim", type=int, default=64, help="Dimension of each attention head." + "--tp", + type=int, + default=None, + help="Optional tensor_model_parallel_size used to initialize UB.", + ) + parser.add_argument( + "--use-bf16-params", + action="store_true", + default=False, + help="Use BF16 params instead of FP32.", ) parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument( "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." ) + parser.add_argument( + "--quantization", + type=str.lower, + default="none", + choices=["none", "fp8_delayed_scaling", "fp8_current_scaling"], + help="Quantization recipe", + ) parser.add_argument( "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." ) @@ -125,6 +191,41 @@ def _parse_args(argv=None, namespace=None): parser.add_argument( "--use-cuda-graphs", action="store_true", default=False, help="Use CUDA Graphs." ) + parser.add_argument( + "--ub-cfg", type=str, default=None, help="Optional TP config yaml file input." + ) + parser.add_argument("--ub-name", type=str, default=None, help="Optional TP layer name.") + parser.add_argument( + "--skip-verify", + action="store_true", + default=False, + help="Skip numerics check.", + ) + parser.add_argument( + "--benchmark", + action="store_true", + default=False, + help="Benchmark comm-gemm overlap perf.", + ) + parser.add_argument( + "--benchmark-iter", + type=int, + default=100, + help="Number of iterations for benchmarking perf.", + ) + parser.add_argument( + "--linear-parallel-mode", + type=str.lower, + default="row", + choices=["row", "column"], + help="Parallel mode for te.Linear.", + ) + parser.add_argument( + "--overlap-rs-dgrad", + action="store_true", + default=False, + help="Replace bulk DGRAD/WGRAD overlaps with DGRAD+RS in the backward pass for AG+GEMM.", + ) parser.add_argument( "--debug", action="store_true", @@ -154,7 +255,7 @@ def _compare_tensors(name, test, ref, rtol, atol): ) return 1, numerics_info - diff = torch.abs(test - ref).flatten() + diff = torch.abs(test.flatten() - ref.flatten()) m = torch.argmax(diff) abs_err = diff[m].item() rel_err = abs_err / max(abs(ref.flatten()[m].item()), 1e-5) @@ -190,14 +291,48 @@ def _train(opts): opts.tcp_init = True opts.bind_to_device = True opts.bootstrap_backend = "mpi" - elif "TORCHELASTIC_RUN_ID" in os.environ: + else: WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count()))) + + result = subprocess.run( + "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'", + capture_output=True, + text=True, + shell=True, + ) + + if result.stdout == "0" and opts.tp is None: # Extra checks for non-MNNVL platforms + assert WORLD_SIZE == LOCAL_SIZE + + # Initialize torch.distributed tp process group + new_group_kwargs = { + "backend": "nccl", + } + if opts.tp is not None: + LOCAL_SIZE = opts.tp + tp_base_rank = (WORLD_RANK // LOCAL_SIZE) * LOCAL_SIZE + tp_rank_list = list(range(tp_base_rank, tp_base_rank + LOCAL_SIZE)) + new_group_kwargs = { + "backend": "nccl", + "ranks": tp_rank_list, + } else: - raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") - assert LOCAL_SIZE == WORLD_SIZE + opts.tp = WORLD_SIZE + + # Tensor dim overrides for tensors that do not require TP communication + if opts.in_features is not None: + assert opts.layer_type is te.Linear and opts.linear_parallel_mode == "row", ( + "--in-features is only used to configure row-tensor-parallel Linear layers. Use" + " --num-heads or --head-dim for other cases." + ) + if opts.out_features is not None: + assert opts.layer_type is te.LayerNormLinear and opts.linear_parallel_mode == "column", ( + "--out-features is only used to configure column-tensor-parallel LayerNormLinear" + " layers. Use --num-heads or --head-dim for other cases." + ) def dist_print(msg, src=None, end="\n", debug=False, error=False): if debug and not opts.debug: @@ -208,7 +343,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist.barrier() # Set device and initialize RNG states - torch.cuda.set_device(WORLD_RANK) + torch.cuda.set_device(LOCAL_RANK) torch.manual_seed(opts.seed) torch.cuda.manual_seed(opts.seed) @@ -226,28 +361,43 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") assert dist.is_nccl_available() dist.init_process_group(**dist_init_kwargs) - nccl_world = dist.new_group(backend="nccl") + nccl_world = dist.new_group(**new_group_kwargs) dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") + # Initialize the Transformer Engine layer with overlap + args, kwargs, input_shape = _get_layer_args( + opts, nccl_world, opts.tp, num_layers=opts.num_layers + ) # Intialize userbuffers + ub_cfgs = None + if opts.overlap_rs_dgrad: + ub_cfgs = { + "qkv_dgrad": {"method": "ring_exchange"}, + "fc1_dgrad": {"method": "ring_exchange"}, + } te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], - WORLD_SIZE, + opts.tp, use_fp8=opts.fp8, dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, + ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, ) - # Initialize the Transformer Engine layer with overlap - args, kwargs, input_shape = _get_layer_args(opts, nccl_world, WORLD_SIZE) with te.fp8_model_init(enabled=opts.fp8_init): - test_model = opts.layer_type(*args, **kwargs) + test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs) dist_print("Initialized test model...", debug=True) + if WORLD_RANK == 0: + pprint.pprint(kwargs) + sys.stdout.write("\n") + dist.barrier() # Initialize the reference model and copy all parameters - ref_args, ref_kwargs, _ = _get_layer_args(opts, nccl_world, WORLD_SIZE, reference=True) + ref_args, ref_kwargs, _ = _get_layer_args( + opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True + ) with te.fp8_model_init(enabled=opts.fp8_init): - ref_model = opts.layer_type(*ref_args, **ref_kwargs) + ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs) dist_print("Initialized reference model...", debug=True) for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()): with torch.no_grad(): @@ -257,7 +407,13 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): # Fp8 recipe setup fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + fp8_recipe = None + if opts.quantization == "fp8_delayed_scaling": + fp8_recipe = DelayedScaling( + fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max" + ) + elif opts.quantization == "fp8_current_scaling": + fp8_recipe = Float8CurrentScaling(fp8_format=fp8_format) # Prepare random input tensors test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True) @@ -277,18 +433,19 @@ def run_fwd_bwd(model, x): out, *_ = y else: out = y - loss = out.sum() - loss.backward() + loss = out.sum() + loss.backward() return out torch_rng_state = torch.get_rng_state() - cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}")) + cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{LOCAL_RANK}")) if opts.use_cuda_graphs: test_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(test_graph): test_out = run_fwd_bwd(test_model, test_x) test_graph.replay() - del test_graph + if not opts.benchmark: + del test_graph else: test_out = run_fwd_bwd(test_model, test_x) test_grads = [test_out, test_x.grad] @@ -299,7 +456,7 @@ def run_fwd_bwd(model, x): names.append(test_name + ".grad") torch.set_rng_state(torch_rng_state) - torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}")) + torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{LOCAL_RANK}")) if opts.use_cuda_graphs: ref_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(ref_graph): @@ -313,28 +470,46 @@ def run_fwd_bwd(model, x): if ref_param.requires_grad and "layer_norm" not in ref_name: ref_grads.append(ref_param.grad) - # Make sure we have the same number of gradients numerics_failed = torch.tensor([0], dtype=torch.uint8, device="cuda") - if len(test_grads) != len(ref_grads): - numerics_failed[0] = 1 - numerics_info = ( - "NUMERICAL CHECK FAILED: Incorrect number of gradients, " - + f"expected {len(ref_grads)} but got {len(test_grads)}." - ) - dist_print(numerics_info, src=WORLD_RANK, error=True) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) - - # Now validate accuracy - if not bool(numerics_failed.item()): - for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): - rtol = 0.125 if opts.fp8 else 0.025 - atol = 0.0625 if opts.fp8 else 0.00125 - grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) - dist_print(grad_info, src=WORLD_RANK, error=grad_failed) - numerics_failed[0] = int(grad_failed) - dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) - if bool(numerics_failed.item()): - break + if not opts.skip_verify: + # Make sure we have the same number of gradients + if len(test_grads) != len(ref_grads): + numerics_failed[0] = 1 + numerics_info = ( + "NUMERICAL CHECK FAILED: Incorrect number of gradients, " + + f"expected {len(ref_grads)} but got {len(test_grads)}." + ) + dist_print(numerics_info, src=WORLD_RANK, error=True) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + + # Now validate accuracy + if not bool(numerics_failed.item()): + for i, (test_g, ref_g) in enumerate(zip(test_grads, ref_grads)): + rtol = 0.125 if opts.fp8 else 0.025 + atol = 0.0625 if opts.fp8 else 0.00125 + grad_failed, grad_info = _compare_tensors(names[i], test_g, ref_g, rtol, atol) + dist_print(grad_info, src=WORLD_RANK, error=grad_failed) + numerics_failed[0] = int(grad_failed) + dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, nccl_world) + if bool(numerics_failed.item()) and not opts.debug: + break + + if opts.benchmark: + # Warmup to not profile CPU overhead + for _ in range(100): + if opts.use_cuda_graphs: + test_graph.replay() + else: + test_out = run_fwd_bwd(test_model, test_x) + torch.cuda.cudart().cudaProfilerStart() + for _ in range(opts.benchmark_iter): + if opts.use_cuda_graphs: + test_graph.replay() + else: + test_out = run_fwd_bwd(test_model, test_x) + torch.cuda.cudart().cudaProfilerStop() + if opts.use_cuda_graphs: + del test_graph te.module.base.destroy_ub() dist_print("Destroying Userbuffers objects...", debug=True) diff --git a/tests/pytorch/distributed/run_megatron_lm_gpt.sh b/tests/pytorch/distributed/run_megatron_lm_gpt.sh deleted file mode 100755 index 855f0c3030..0000000000 --- a/tests/pytorch/distributed/run_megatron_lm_gpt.sh +++ /dev/null @@ -1,120 +0,0 @@ -#!/bin/bash - -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -# This script allows flexibly running various sizes of -# GPT3 models with named hyperparameters. - -# Trick to get kwargs. -for ARGUMENT in "$@" -do - KEY=$(echo $ARGUMENT | cut -f1 -d=) - - KEY_LENGTH=${#KEY} - VALUE="${ARGUMENT:$KEY_LENGTH+1}" - - export "$KEY"="$VALUE" -done - -# Set defaults for all arguments. -: ${DP_SIZE:="1"} -: ${TP_SIZE:="1"} -: ${PP_SIZE:="1"} -: ${NUM_LAYERS:="12"} -: ${HIDDEN_SIZE:="768"} -: ${NHEADS:="12"} -: ${SEQLEN:="2048"} -: ${MAX_POSITION_EMBEDDINGS:="2048"} -: ${MBS:="8"} -: ${GBS:="32"} -: ${STEPS:="400"} -: ${LR:="6.0e-4"} -: ${MIN_LR:="6.0e-5"} -: ${SAVE_INTERVAL:="1000"} -: ${SPLIT:="98,2,0"} -: ${CLIP_GRAD:="1.0"} -: ${WEIGHT_DECAY:="0.1"} -: ${ADAM_BETA1:="0.9"} -: ${ADAM_BETA2:="0.95"} -: ${INIT_METHOD_STD:="0.023"} -: ${SP:="False"} -: ${DTYPE:="bf16"} -: ${WGRAD_FUSION:="True"} -: ${FP8:="False"} -: ${FP8_AMAX_HISTORY_LEN:="32"} -: ${TRANSFORMER_IMPL:="transformer_engine"} -: ${FILENAME:="log.txt"} - -# Logging. -DIR=`pwd` -TENSORBOARD_DIR="${DIR}/tensorboard" -CHECKPOINT_DIR="${DIR}/checkpoints" -mkdir -p ${TENSORBOARD_DIR} -mkdir -p ${CHECKPOINT_DIR} - -# Dataset. -. /data/gpt3/pile-cc1-cc2-shuf/gpt3_blend.sh - -# Set GP3 options. -options=" \ - --exit-duration-in-mins 230 \ - --tensor-model-parallel-size ${TP_SIZE} \ - --pipeline-model-parallel-size ${PP_SIZE} \ - --num-layers ${NUM_LAYERS} \ - --hidden-size ${HIDDEN_SIZE} \ - --num-attention-heads ${NHEADS} \ - --seq-length ${SEQLEN} \ - --max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \ - --micro-batch-size ${MBS} \ - --global-batch-size ${GBS} \ - --train-iters ${STEPS} \ - --lr ${LR} \ - --min-lr ${MIN_LR} \ - --lr-decay-style cosine \ - --log-interval 1 \ - --eval-iters 50 \ - --eval-interval 2000 \ - --data-path ${DATA_BLEND} \ - --vocab-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-vocab.json \ - --merge-file /data/gpt3/pile-cc1-cc2-shuf/bpe/gpt2-merges.txt \ - --save-interval ${SAVE_INTERVAL} \ - --save ${CHECKPOINT_DIR} \ - --split ${SPLIT} \ - --clip-grad ${CLIP_GRAD} \ - --weight-decay ${WEIGHT_DECAY} \ - --adam-beta1 ${ADAM_BETA1} \ - --adam-beta2 ${ADAM_BETA2} \ - --init-method-std ${INIT_METHOD_STD} \ - --log-params-norm \ - --log-num-zeros-in-grad \ - --transformer-impl ${TRANSFORMER_IMPL} \ - --tensorboard-dir ${TENSORBOARD_DIR} \ - --fp8-margin 0 \ - --fp8-interval 1 \ - --fp8-amax-history-len ${FP8_AMAX_HISTORY_LEN} \ - --fp8-amax-compute-algo max" - -if [[ "$SP" == "True" ]]; then - options+=" --sequence-parallel" -fi - -if [[ "$WGRAD_FUSION" == "False" ]]; then - options+=" --no-gradient-accumulation-fusion" -fi - -if [[ "$FP8" != "False" ]]; then - options+=" --fp8-format ${FP8}" -fi - -if [[ "$DTYPE" != "fp32" ]]; then - options+=" --${DTYPE}" -fi - -# Run GPT3. -NUM_GPUS=$((${DP_SIZE}*${TP_SIZE}*${PP_SIZE})) -NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FLASH_ATTN=1 NVTE_FWD_LAYERNORM_SM_MARGIN=0 NVTE_BWD_LAYERNORM_SM_MARGIN=0 CUDA_DEVICE_MAX_CONNECTIONS=1 NVTE_BIAS_GELU_NVFUSION=0 NVTE_BIAS_DROPOUT_FUSION=0 python -m torch.distributed.launch --use_env --nnodes=1 --nproc_per_node=${NUM_GPUS} ${DIR}/pretrain_gpt.py ${options} 2>&1 | tee $FILENAME - -# Remove checkpoints. -rm -rf ${CHECKPOINT_DIR}/* diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 5d2828454c..e2e78b72b1 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -1,20 +1,28 @@ #!/usr/bin/python3 -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -import sys -import os import argparse +import datetime +import os +import sys from functools import wraps import transformer_engine.pytorch as te import torch from torch import nn import torch.distributed as dist - -from transformer_engine.common.recipe import Format, DelayedScaling +import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + MXFP8BlockScaling, + DelayedScaling, + Float8CurrentScaling, + Format, + Recipe, +) +from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer from run_layer_with_overlap import _compare_tensors SEQ_LEN, BATCH_SIZE = 16, 16 @@ -23,15 +31,29 @@ WORLD_RANK, WORLD_SIZE = None, None NCCL_WORLD = None LOSS_FN = nn.MSELoss() -FP8 = False +QUANTIZATION = None + -# Fp8 recipe setup -fp8_format = Format.HYBRID -fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") +# Disable TF32 +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +# Quantization recipe setup +def quantization_recipe() -> Recipe: + if QUANTIZATION == "fp8": + return DelayedScaling( + fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max" + ) + if QUANTIZATION == "mxfp8": + return MXFP8BlockScaling() + if QUANTIZATION == "fp8_cs": + return Float8CurrentScaling() + return te.fp8.get_default_fp8_recipe() def main(argv=None, namespace=None): - global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, FP8 + global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) @@ -44,6 +66,7 @@ def main(argv=None, namespace=None): "backend": "nccl", "rank": WORLD_RANK, "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), } dist_init_kwargs["init_method"] = "env://" dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") @@ -57,10 +80,19 @@ def main(argv=None, namespace=None): parser = argparse.ArgumentParser() parser.add_argument("-l", "--layer-type", type=str) - parser.add_argument("--fp8", action="store_true", default=False) + parser.add_argument("--quantization", type=str, default=None) args = parser.parse_args(argv, namespace) + # Quantization scheme + QUANTIZATION = args.quantization + if QUANTIZATION in ("fp8", "mxfp8"): + global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE + SEQ_LEN = 32 + BATCH_SIZE = 32 + HIDDEN_SIZE = 128 + test_dict = [ + test_quantizer, test_linear, test_layernorm, test_layernorm_linear, @@ -68,8 +100,6 @@ def main(argv=None, namespace=None): test_transformer_layer, ] - FP8 = args.fp8 - for test in test_dict: test() dist.destroy_process_group() @@ -124,11 +154,15 @@ def dist_print(msg, src=None, end="\n", error=False): stream = sys.stderr if error else sys.stdout if WORLD_RANK == (0 if src is None else src): stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n") - dist.barrier() def _get_tolerances(dtype): - if FP8: + # loose tolerances for fp8_cs because of sequence parallel & amax reduction + # so that each rank has a different scale_inv for computing Y when we have + # row parallel & sequence parallel, because we do the all_gather in backward pass + if QUANTIZATION == "fp8_cs": + return {"rtol": 0.4, "atol": 0.25} + elif QUANTIZATION is not None: return {"rtol": 0.125, "atol": 0.0625} if dtype == torch.float16: @@ -153,8 +187,7 @@ def _check_outputs(output_single_node, output_distributed): dist_print(output_info, src=WORLD_RANK, error=output_failed) numerics_failed[0] = int(output_failed) dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD) - if bool(numerics_failed.item()): - sys.exit(1) + assert not bool(numerics_failed.item()) def _match_param_sizes(dist_param, single_param): @@ -213,13 +246,12 @@ def _check_gradients(model_distributed, model_single, main_grad_check=False): ) if grad_failed: - dist_print(i) - dist_print(name) + dist_print(i, src=WORLD_RANK) + dist_print(name, src=WORLD_RANK) dist_print(grad_info, src=WORLD_RANK, error=grad_failed) numerics_failed[0] = int(grad_failed) dist.all_reduce(numerics_failed, dist.ReduceOp.MAX, NCCL_WORLD) - if bool(numerics_failed.item()): - sys.exit(1) + assert not bool(numerics_failed.item()) def _copy_params(model_distributed, model_single): @@ -243,9 +275,18 @@ def _apply_models( model_single_node, model_distributed, input_single_node, input_distributed, **kwargs ): _alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True - with te.fp8_autocast(enabled=FP8, fp8_recipe=fp8_recipe): + input_single_node.requires_grad_() + input_distributed.requires_grad_() + with te.fp8_autocast( + enabled=QUANTIZATION is not None, + fp8_recipe=quantization_recipe(), + ): output_single_node = model_single_node(input_single_node, **kwargs) - with te.fp8_autocast(enabled=FP8, fp8_recipe=fp8_recipe, fp8_group=NCCL_WORLD): + with te.fp8_autocast( + enabled=QUANTIZATION is not None, + fp8_recipe=quantization_recipe(), + fp8_group=NCCL_WORLD, + ): output_distributed = model_distributed(input_distributed, **kwargs) return output_single_node, output_distributed @@ -262,6 +303,98 @@ def _alloc_main_grad(model_single_node, model_distributed): param.main_grad = torch.zeros_like(param, dtype=torch.float32) +############################################### +# Quantizer # +############################################### +def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size): + """ + quantizer is the reference quantizer on a single GPU. + quantizer_dist is the distributed quantizer to be tested on multiple GPUs. + """ + if quantizer_class == Float8CurrentScalingQuantizer: + quantizer_dist = quantizer_class( + fp8_dtype=fp8_dtype, + device=device, + with_amax_reduction=True, + amax_reduction_group=tp_group, + amax_reduction_size=tp_size, + ) + quantizer = quantizer_class( + fp8_dtype=fp8_dtype, + device=device, + with_amax_reduction=False, + ) + return quantizer, quantizer_dist + else: + raise ValueError(f"Unsupported quantizer class: {quantizer_class}") + + +def _shard_tensor(x, world_size, axis): + split_size = x.size()[axis] // world_size + split_tensor = torch.split(x, split_size, axis) + out = [] + for tensor in split_tensor: + out.append(tensor.detach().clone().requires_grad_(x.requires_grad).cuda()) + return out + + +@run_distributed_test() +def _test_quantizer(input_dtype, fp8_dtype): + """Test the quantizer under distributed settings. + + Args: + input_dtype (torch.dtype): The data type of the input. + fp8_dtype (tex.DType): The data type of the fp8. + """ + + M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE + + # high precision input + x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype) + # set one element of the input to a very large value, which doesn't live in rank 0 after the split + # to test the amax reduction on purpose + x_hp_cpu[M - 1, N - 1] = 1e4 + # rank 0 takes the full copy and quantize with GPU 0 for verification + if WORLD_RANK == 0: + x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda") + x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK] + + # Create quantizers + quantizer, quantizer_dist = _construct_quantizer( + Float8CurrentScalingQuantizer, fp8_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE + ) + + # quantize the input + if WORLD_RANK == 0: + x_fp8_single = quantizer(x_hp_rank0) + + # multi-GPU quantizer + x_fp8_dist = quantizer_dist(x_hp_local_rank) + + # check scale_inv with zero tolerance + if WORLD_RANK == 0: + torch.testing.assert_close( + x_fp8_single._scale_inv, x_fp8_dist._scale_inv, rtol=0.0, atol=0.0 + ) + + +def test_quantizer(): + """ + Run quantizer tests with various configurations. + Currently only check fp8_cs because it needs to do amax reduction in the quantizer. + """ + # skip this test for other quantization schemes + if QUANTIZATION != "fp8_cs": + return + + input_dtypes = [torch.float32, torch.bfloat16] + fp8_dtypes = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + + for input_dtype in input_dtypes: + for fp8_dtype in fp8_dtypes: + _test_quantizer(input_dtype, fp8_dtype) + + ############################################ # Linear # ############################################ @@ -308,6 +441,11 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): torch.empty((WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype) ) input_distributed = torch.randn((BATCH_SIZE, HIDDEN_SIZE)).cuda().to(params_dtype) + # when quantization is fp8_cs, we need to trigger corner cases to see if amax reduction is working + if QUANTIZATION == "fp8_cs": + input_distributed = torch.clamp(input_distributed, min=-10, max=10) + if WORLD_RANK == WORLD_SIZE - 1: + input_distributed[BATCH_SIZE - 1, HIDDEN_SIZE - 1] = 11 input_single_node = _gather(input_distributed, dim=0).detach() else: input_distributed = input_single_node.clone() @@ -470,6 +608,12 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs # Duplicate input for sequence parallelism input_single_node = torch.empty((WORLD_SIZE * SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) + # make the last element of the input a large value to test the amax reduction on purpose + # when quantization is fp8_cs, we need to trigger corner cases to see if amax reduction is working + if QUANTIZATION == "fp8_cs": + input_distributed = torch.clamp(input_distributed, min=-10, max=10) + if WORLD_RANK == WORLD_SIZE - 1: + input_distributed[SEQ_LEN - 1, HIDDEN_SIZE - 1] = 11 input_single_node = _gather(input_distributed).detach() else: input_distributed = input_single_node.clone() @@ -544,9 +688,7 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg """ # Set parameter data type params_dtype = kwargs.get("params_dtype", torch.float32) - FFN_HIDDEN_SIZE = ( - 64 if FP8 else 32 - ) # larger tensors lead to numerical failures with thight atol and rtol + FFN_HIDDEN_SIZE = 32 if QUANTIZATION is None else 128 # Create models model_single_node = te.LayerNormMLP(HIDDEN_SIZE, FFN_HIDDEN_SIZE, **kwargs) @@ -570,6 +712,12 @@ def _test_layernorm_mlp(set_parallel_mode=None, sequence_parallel=False, **kwarg # Duplicate input for sequence parallelism input_single_node = torch.empty((WORLD_SIZE * SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) input_distributed = torch.randn((SEQ_LEN, HIDDEN_SIZE)).cuda().to(params_dtype) + # make the last element of the input a large value to test the amax reduction on purpose + # when quantization is fp8_cs, we need to trigger corner cases to see if amax reduction is working + if QUANTIZATION == "fp8_cs": + input_distributed = torch.clamp(input_distributed, min=-10, max=10) + if WORLD_RANK == WORLD_SIZE - 1: + input_distributed[SEQ_LEN - 1, HIDDEN_SIZE - 1] = 11 input_single_node = _gather(input_distributed).detach() else: input_distributed = input_single_node.clone() @@ -622,6 +770,7 @@ def test_layernorm_mlp(): {"return_bias": True}, {"return_layernorm_output": True}, ] + for kwargs in kwargs_list: for set_parallel_mode in [True]: for sequence_parallel in [False, True]: @@ -636,9 +785,7 @@ def test_layernorm_mlp(): @run_distributed_test() def _test_transformer_layer_parallel(sequence_parallel=False, **kwargs): params_dtype = kwargs.get("params_dtype", torch.float32) - FFN_HIDDEN_SIZE = ( - 64 if FP8 else 32 - ) # larger tensors lead to numerical failures with thight atol and rtol + FFN_HIDDEN_SIZE = 32 if QUANTIZATION is None else 128 model_single_node = te.TransformerLayer( HIDDEN_SIZE, FFN_HIDDEN_SIZE, NR_HEADS, attention_dropout=0, hidden_dropout=0, **kwargs @@ -718,6 +865,7 @@ def test_transformer_layer(): {"fuse_qkv_params": True}, {"activation": "relu"}, ] + for kwargs in kwargs_list: for sequence_parallel in [False, True]: _test_transformer_layer_parallel(sequence_parallel, **kwargs) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index ce46a72189..01400bba6b 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import os @@ -16,11 +16,11 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -RNG_SEED: int = 1234 -SEQ_LENGTH: int = 512 +RNG_SEED: int = 42 +SEQ_LENGTH: int = 1024 BATCH_SIZE: int = 2 -NUM_HEADS: int = 12 -HEAD_DIM: int = 64 +NUM_HEADS: int = 16 +HEAD_DIM: int = 48 TE_LAYERS = [ te.Linear, te.LayerNormLinear, @@ -28,12 +28,16 @@ te.MultiheadAttention, te.TransformerLayer, ] +MAX_LAYER_NAME_LENGTH = max([len(layer.__name__) for layer in TE_LAYERS]) + +# to avoid numerical tolerance issues of doing comm gemm overlap, limit the number of GPUs used +MAX_GPUS_TO_USE = 4 TEST_ROOT = Path(__file__).parent.resolve() -NUM_PROCS: int = min(torch.cuda.device_count(), 4) +NUM_PROCS: int = min(torch.cuda.device_count(), MAX_GPUS_TO_USE) LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] if tex.ubuf_built_with_mpi(): - LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python"] + LAUNCH_CMD = ["mpirun", "-np", str(NUM_PROCS), "--oversubscribe", "--quiet", "python3"] # Fall back on CUDA IPC if the platform does not support CUDA multicast if not tex.device_supports_multicast(): @@ -46,7 +50,7 @@ torch._dynamo.reset() -def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggregate): +def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8): test_path = TEST_ROOT / "run_gemm_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -62,19 +66,15 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg if bulk: test_cmd.append("--bulk-overlap") else: - if fp8_in: + if fp8: if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") - if fp8_out: - test_cmd.append("--fp8-output") if p2p: test_cmd.append("--p2p") - if aggregate: - test_cmd.append("--aggregate") if atomic: - if torch.cuda.get_device_properties(0).major < 9: - pytest.skip("Device compute capability 9.0 or higher required for Atomic GEMM.") + if torch.cuda.get_device_properties(0).major != 9: + pytest.skip("Atomic GEMM is requires device compute capability 9.x (Hopper).") test_cmd.append("--atomic") result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) @@ -86,7 +86,9 @@ def _run_gemm_with_overlap(comm_type, bulk, p2p, atomic, fp8_in, fp8_out, aggreg raise AssertionError(result.stderr.decode()) -def _run_layer_with_overlap(layer_type, fp8, fp8_init): +def _run_layer_with_overlap( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers=1 +): test_path = TEST_ROOT / "run_layer_with_overlap.py" test_cmd = LAUNCH_CMD + [ str(test_path), @@ -96,14 +98,19 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init): f"--num-heads={NUM_HEADS}", f"--head-dim={HEAD_DIM}", f"--layer-type={layer_type}", + f"--num-layers={num_layers}", ] + if layer_type in [te.Linear.__name__, te.LayerNormLinear.__name__]: + test_cmd.append(f"--linear-parallel-mode={linear_parallel_mode}") + + if overlap_rs_dgrad: + test_cmd.append("--overlap-rs-dgrad") if fp8: if not fp8_available: pytest.skip(reason_for_no_fp8) test_cmd.append("--fp8") - if fp8_init: - test_cmd.append("--fp8-init") + test_cmd.append(f"--quantization={quantization}") os.environ["PYTORCH_JIT"] = "0" os.environ["NVTE_TORCH_COMPILE"] = "0" @@ -124,7 +131,20 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init): @pytest.mark.parametrize( - "fp8,aggregate", + "fp8", + (False, True), + ids=[" BF16 - RING-EXCHANGE ", " FP8 - RING-EXCHANGE "], +) +def test_split_all_gather_overlaps(fp8): + """ + Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or + te.cpp_extensions.fp8_gemm. + """ + _run_gemm_with_overlap("AG", False, True, False, fp8) + + +@pytest.mark.parametrize( + "fp8,p2p", [ (False, False), (False, True), @@ -132,118 +152,241 @@ def _run_layer_with_overlap(layer_type, fp8, fp8_init): (True, True), ], ids=[ - " BF16 IN - RING-EXCHANGE ", - " BF16 IN - RING-EXCHANGE - 2x AGGREGATED ", - " FP8 IN - RING-EXCHANGE ", - " FP8 IN - RING-EXCHANGE - 2x AGGREGATED ", + " BF16 - PIPELINE ", + " BF16 - RING-EXCHANGE ", + " FP8 - PIPELINE ", + " FP8 - RING-EXCHANGE ", ], ) -def test_split_all_gather_overlaps(fp8, aggregate): +def test_split_reduce_scatter_overlaps(fp8, p2p): """ - Test (split GEMM -> all-gather) overlaps with direct calls to te.cpp_extensions.gemm or + Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("AG", False, True, False, fp8, False, aggregate) + _run_gemm_with_overlap("RS", False, p2p, False, fp8) @pytest.mark.parametrize( - "fp8_in,fp8_out,p2p", + "comm_type, fp8, connections", [ - (False, False, False), - (False, False, True), - (True, False, False), - (True, False, True), - (True, True, False), - (True, True, True), + ("AG", False, 1), + ("RS", False, 1), + ("RS", True, 1), + ("AG", False, 8), + ("RS", False, 8), + ("RS", True, 8), ], ids=[ - " BF16 IN - BF16 OUT - PIPELINE ", - " BF16 IN - BF16 OUT - RING-EXCHANGE ", - " FP8 IN - BF16 OUT - PIPELINE ", - " FP8 IN - BF16 OUT - RING-EXCHANGE ", - " FP8 IN - FP8 OUT - PIPELINE ", - " FP8 IN - FP8 OUT - RING-EXCHANGE ", + "ALL-GATHER - BF16 - 1 connections", + "REDUCE-SCATTER - BF16 - 1 connections", + "REDUCE-SCATTER - FP8 - 1 connections", + "ALL-GATHER - BF16 - 8 connections", + "REDUCE-SCATTER - BF16 - 8 connections", + "REDUCE-SCATTER - FP8 - 8 connections", ], ) -def test_split_reduce_scatter_overlaps(fp8_in, fp8_out, p2p): +def test_bulk_overlaps(comm_type, fp8, connections): """ - Test (reduce-scatter -> split GEMM) overlaps with direct calls to te.cpp_extensions.gemm or - te.cpp_extensions.fp8_gemm. + Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. """ - _run_gemm_with_overlap("RS", False, p2p, False, fp8_in, fp8_out, False) + if connections == 8: + if torch.cuda.get_device_properties(0).major != 9: + pytest.skip( + "CUDA_DEVICE_MAX_CONNECTIONS=8 test only applies to devices with compute capability" + " 9.0 (HOPPER ARCH)." + ) + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8" + _run_gemm_with_overlap(comm_type, True, False, False, fp8) + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" + else: + _run_gemm_with_overlap(comm_type, True, False, False, fp8) @pytest.mark.parametrize( - "ag_type,rs_type,p2p,fp8_out", - [ - (0, 0, False, False), - (0, 1, False, False), - (0, 1, False, True), - (0, 2, False, False), - (0, 2, False, True), - (0, 0, True, False), - (0, 0, True, True), - (1, 0, True, False), - (1, 0, True, True), + "fp8", + (False,), + ids=[ + " BF16 ", ], +) +@pytest.mark.parametrize( + "layer_type,linear_parallel_mode,overlap_rs_dgrad", + [ + (te.Linear.__name__, "row", False), + (te.Linear.__name__, "column", False), + (te.Linear.__name__, "column", True), + (te.LayerNormLinear.__name__, "row", False), + (te.LayerNormLinear.__name__, "column", False), + (te.LayerNormLinear.__name__, "column", True), + ] + + list( + zip( + [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + [None] * len(TE_LAYERS[2:]) * 2, + [False, True] * len(TE_LAYERS[2:]), + ) + ), ids=[ - " NON-ATOMIC AG - NON-ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - ATOMIC RS - PIPELINE - FP8 OUT ", - " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - BF16 OUT ", - " NON-ATOMIC AG - MULTI-ATOMIC RS - PIPELINE - FP8 OUT ", - " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", - " NON-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", - " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - BF16 OUT ", - " MULTI-ATOMIC AG - NON-ATOMIC RS - RING-EXCHANGE - FP8 OUT ", + f" {te.Linear.__name__} - ROW-PARALLEL ", + f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ", + f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ", + ] + + [ + " " + " - ".join(test_name_parts) + " " + for test_name_parts in zip( + [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]), + ) ], ) -def test_atomic_gemm_overlaps(ag_type, rs_type, p2p, fp8_out): +def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8): """ - Test paired (all-gather -> atomic GEMM) and (atomic GEMM -> reduce-scatter) overlaps with - direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. + Test Transformer Engine layers with comm+GEMM overlap. """ - os.environ["NVTE_AG_P2P_MULTI_ATOMIC"] = str(ag_type) - os.environ["NVTE_RS_STRIDED_ATOMIC"] = str(rs_type) - _run_gemm_with_overlap("AG", False, p2p, True, True, fp8_out, False) + _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None) @pytest.mark.parametrize( - "comm_type,fp8", + "quantization", + ["fp8_delayed_scaling", "fp8_current_scaling"], + ids=[" DELAYED SCALING ", " CURRENT SCALING "], +) +@pytest.mark.parametrize( + "fp8", + (True,), + ids=[ + " FP8 ", + ], +) +@pytest.mark.parametrize( + "layer_type,linear_parallel_mode,overlap_rs_dgrad", [ - ("AG", False), - ("RS", False), - ("RS", True), + (te.Linear.__name__, "row", False), + (te.Linear.__name__, "column", False), + (te.Linear.__name__, "column", True), + (te.LayerNormLinear.__name__, "row", False), + (te.LayerNormLinear.__name__, "column", False), + (te.LayerNormLinear.__name__, "column", True), + ] + + list( + zip( + [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + [None] * len(TE_LAYERS[2:]) * 2, + [False, True] * len(TE_LAYERS[2:]), + ) + ), + ids=[ + f" {te.Linear.__name__} - ROW-PARALLEL ", + f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ", + f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", + f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ", + ] + + [ + " " + " - ".join(test_name_parts) + " " + for test_name_parts in zip( + [layer.__name__ for layer in TE_LAYERS[2:] for _ in range(2)], + ["BULK DGRAD/WGRAD", "DGRAD+RS"] * len(TE_LAYERS[2:]), + ) ], - ids=[" ALL-GATHER - BF16 ", " REDUCE-SCATTER - BF16 ", " REDUCE-SCATTER - FP8 "], ) -def test_bulk_overlaps(comm_type, fp8): +def test_layers_with_overlap_fp8( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization +): """ - Test bulk overlaps with direct calls to te.cpp_extensions.gemm or te.cpp_extensions.fp8_gemm. + Test Transformer Engine layers with comm+GEMM overlap. """ - _run_gemm_with_overlap(comm_type, True, False, False, fp8, False, False) + _run_layer_with_overlap(layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization) @pytest.mark.parametrize( - "layer_type", - [layer.__name__ for layer in TE_LAYERS], - ids=[(" " + layer.__name__ + " ") for layer in TE_LAYERS], + "fp8", + (False,), + ids=[ + " BF16 ", + ], ) @pytest.mark.parametrize( - "fp8,fp8_init", - [ - (False, False), - (True, False), - (True, True), + "num_layers", + (2,), + ids=[ + " 2 layers ", + ], +) +@pytest.mark.parametrize( + "layer_type,linear_parallel_mode,overlap_rs_dgrad", + list( + zip( + [te.TransformerLayer.__name__ for _ in range(2)], + [None] * 2, + [False, True], + ) + ), + ids=[ + " " + " - ".join(test_name_parts) + " " + for test_name_parts in zip( + [te.TransformerLayer.__name__ for _ in range(2)], + ["BULK DGRAD/WGRAD", "DGRAD+RS"], + ) ], +) +def test_multi_layer_with_overlap_bf16( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, num_layers +): + """ + Test Transformer Engine layers with comm+GEMM overlap. + """ + _run_layer_with_overlap( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, None, num_layers + ) + + +@pytest.mark.parametrize( + "quantization", + ["fp8_delayed_scaling", "fp8_current_scaling"], + ids=[" DELAYED SCALING ", " CURRENT SCALING "], +) +@pytest.mark.parametrize( + "fp8", + (True,), + ids=[ + " FP8 ", + ], +) +@pytest.mark.parametrize( + "num_layers", + (2,), + ids=[ + " 2 layers ", + ], +) +@pytest.mark.parametrize( + "layer_type,linear_parallel_mode,overlap_rs_dgrad", + list( + zip( + [te.TransformerLayer.__name__ for _ in range(2)], + [None] * 2, + [False, True], + ) + ), ids=[ - " BF16 GEMM - BF16 PARAMS ", - " FP8 GEMM - BF16 PARAMS ", - " FP8 GEMM - FP8 PARAMS ", + " " + " - ".join(test_name_parts) + " " + for test_name_parts in zip( + [te.TransformerLayer.__name__ for _ in range(2)], + ["BULK DGRAD/WGRAD", "DGRAD+RS"], + ) ], ) -def test_layers_with_overlap(layer_type, fp8, fp8_init): +def test_multi_layer_with_overlap_fp8( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers +): """ Test Transformer Engine layers with comm+GEMM overlap. """ - _run_layer_with_overlap(layer_type, fp8, fp8_init) + _run_layer_with_overlap( + layer_type, linear_parallel_mode, overlap_rs_dgrad, fp8, quantization, num_layers + ) diff --git a/tests/pytorch/distributed/test_convergence.py b/tests/pytorch/distributed/test_convergence.py deleted file mode 100644 index 5a267cb25e..0000000000 --- a/tests/pytorch/distributed/test_convergence.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import functools -import os -import pytest -import subprocess -from dataclasses import asdict, dataclass -from typing import List, Tuple, Union - -import torch - - -@dataclass() -class ModelConfigGPT: - NUM_LAYERS: int = 12 - HIDDEN_SIZE: int = 768 - NHEADS: int = 12 - SEQLEN: int = 2048 - MAX_POSITION_EMBEDDINGS: int = 2048 - LR: float = 6.0e-4 - MIN_LR: float = 6.0e-5 - SPLIT: str = "98,2,0" - CLIP_GRAD: float = 1.0 - WEIGHT_DECAY: float = 0.1 - ADAM_BETA1: float = 0.9 - ADAM_BETA2: float = 0.95 - INIT_METHOD_STD: float = 0.023 - - -model_configs = { - "126m": ModelConfigGPT(), -} - -dtypes = ["bf16"] - - -fp8_recipes = [False, "hybrid"] - - -all_boolean = [True, False] - - -te_path = os.getenv("TE_PATH", "/opt/transformerengine") -mlm_log_dir = os.path.join(te_path, "ci_logs") - - -@functools.lru_cache(maxsize=None) -def get_parallel_configs() -> List[Tuple[int, int]]: - """Returns valid combinations of (tp, pp).""" - sizes = [1, 2, 4] - num_devices = torch.cuda.device_count() - parallel_configs = [] - if num_devices > 1: - for dp in sizes: - for tp in sizes: - for pp in sizes: - if dp * tp * pp == num_devices: - parallel_configs.append((dp, tp, pp)) - return parallel_configs - - -def get_filename( - model: str, dp: int, tp: int, pp: int, sp: bool, use_te: bool, fp8_recipe: Union[bool, str] -) -> str: - sp = tp if sp else 1 - config = f"gpt3_{model}_dp{dp}_tp{tp}_pp{pp}_sp{sp}" - config_dir = os.path.join(mlm_log_dir, config) - os.makedirs(config_dir, exist_ok=True) - fname = ( - f"{'te' if use_te else 'megatron'}" + (f"_fp8_{fp8_recipe}" if fp8_recipe else "") + ".txt" - ) - return os.path.join(config_dir, fname) - - -def get_bash_arguments(filename: str, **kwargs) -> List[str]: - args = [] - script_path = os.path.join(te_path, "tests/pytorch/distributed/run_megatron_lm_gpt.sh") - args.append(script_path) - - for k, v in kwargs.items(): - args.append(f"{k}={str(v)}") - args.append(f"FILENAME={filename}") - return args - - -@pytest.mark.parametrize("sp", all_boolean) -@pytest.mark.parametrize("use_te", all_boolean) -@pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("dp, tp, pp", get_parallel_configs()) -@pytest.mark.parametrize("model", model_configs.keys()) -def test_distributed(dtype, fp8_recipe, dp, tp, pp, sp, use_te, model): - if sp and tp == 1: - pytest.skip("No tensor parallel.") - if fp8_recipe and not use_te: - pytest.skip("TransformerEngine needed for FP8.") - subprocess.run( - get_bash_arguments( - get_filename(model, dp, tp, pp, sp, use_te, fp8_recipe), - DTYPE=dtype, - FP8=fp8_recipe, - SP=sp, - DP_SIZE=dp, - TP_SIZE=tp, - PP_SIZE=pp, - TRANSFORMER_IMPL="transformer_engine" if use_te else "local", - **asdict(model_configs[model]), - ), - check=True, - ) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index d8a018761b..c8ef7687fa 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -1,31 +1,42 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. from __future__ import annotations import argparse +from collections.abc import Iterable import functools import itertools import os import pathlib import subprocess import sys +from typing import Optional import pytest import torch import transformer_engine +import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex -# Check if FP8 is supported + +# Check what quantization schemes are supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +quantization_list: list[Optional[str]] = [None] +if fp8_available: + quantization_list.append("fp8") +if mxfp8_available: + quantization_list.append("mxfp8") @functools.cache @@ -66,22 +77,18 @@ def make_reference_and_test_tensors( in Transformer Engine operations. """ - - # Random data ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) - - # Make copy of tensor + test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(ref) - else: - test = ref.to(device=test_device, dtype=test_dtype) - if test.data_ptr() == ref.data_ptr(): - test = test.clone() - - # Make sure reference and test tensors represent exact same values + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device=test_device), + amax=torch.zeros(1, dtype=torch.float32, device=test_device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + test = quantizer(test) + elif test.data_ptr() == ref.data_ptr(): + test = test.clone() ref.copy_(test) - - # Return reference and test tensors ref.requires_grad_(requires_grad) test.requires_grad_(requires_grad) return ref, test @@ -120,6 +127,21 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: raise ValueError(f"Unsupported dtype ({dtype})") +def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: + """Make recipe for quantization scheme""" + if name is None: + return None + if name == "fp8": + return transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "mxfp8": + return transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + raise ValueError(f"Unsupported quantization scheme ({name})") + + def _test_all_reduce( *, local_size: int = 17, @@ -293,17 +315,16 @@ def _test_reduce_scatter( def _test_basic_linear( *, - local_weight_shape: tuple[int, int] = (16, 16), - batch_size: int = 16, + local_weight_shape: tuple[int, int] = (32, 32), + local_batch_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool = False, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_grad_output: bool = False, + quantization: Optional[str] = None, + quantized_weight: bool = False, tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + quantized_compute = quantization is not None # Distributed process group process_group = world_group() @@ -313,10 +334,13 @@ def _test_basic_linear( # Tensor dimensions local_out_features, local_in_features = local_weight_shape out_features, in_features = local_out_features, local_in_features + batch_size = local_batch_size if tensor_parallel_mode == "column": out_features *= world_size elif tensor_parallel_mode == "row": in_features *= world_size + if sequence_parallel: + batch_size *= world_size in_shape = [batch_size, in_features] out_shape = [batch_size, out_features] @@ -326,21 +350,28 @@ def _test_basic_linear( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) + if isinstance(w_test, QuantizedTensor): + w_test = w_test.dequantize() dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_grad_output), + test_is_fp8=quantized_compute, requires_grad=False, ) + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -391,7 +422,8 @@ def _test_basic_linear( x_test.requires_grad_() # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.BasicLinear( in_features, out_features, @@ -404,7 +436,7 @@ def _test_basic_linear( with torch.no_grad(): op.weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -412,10 +444,8 @@ def _test_basic_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -429,17 +459,16 @@ def _test_basic_linear( def _test_linear( *, bias: bool = True, - local_weight_shape: tuple[int, int] = (16, 16), - batch_size: int = 16, + local_weight_shape: tuple[int, int] = (32, 32), + local_batch_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool = False, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_grad_output: bool = False, + quantization: Optional[str] = None, + quantized_weight: bool = False, tensor_parallel_mode: str = "column", sequence_parallel: bool = False, ) -> None: + quantized_compute = quantization is not None # Distributed process group process_group = world_group() @@ -449,10 +478,13 @@ def _test_linear( # Tensor dimensions local_out_features, local_in_features = local_weight_shape out_features, in_features = local_out_features, local_in_features + batch_size = local_batch_size if tensor_parallel_mode == "column": out_features *= world_size elif tensor_parallel_mode == "row": in_features *= world_size + if sequence_parallel: + batch_size *= world_size in_shape = [batch_size, in_features] out_shape = [batch_size, out_features] @@ -462,14 +494,19 @@ def _test_linear( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) + if isinstance(w_test, QuantizedTensor): + w_test = w_test.dequantize() b_ref, b_test = None, None if bias: if tensor_parallel_mode == "row": @@ -485,9 +522,11 @@ def _test_linear( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_grad_output), + test_is_fp8=quantized_compute, requires_grad=False, ) + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) @@ -552,7 +591,8 @@ def _test_linear( x_test.requires_grad_() # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -571,7 +611,7 @@ def _test_linear( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -579,12 +619,8 @@ def _test_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[0].weight._fp8_dtype - if is_float8_tensor(model[0].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -603,8 +639,8 @@ def _test_fp8_scale_update( amax_history_len: int = 31, amax_compute_algo: str = "max", margin: float = 2, - local_weight_shape: tuple[int, int] = (16, 16), - batch_size: int = 16, + local_weight_shape: tuple[int, int] = (32, 32), + batch_size: int = 32, dtype: torch.dtype = torch.float32, device: torch.device = "cuda", tensor_parallel_mode: str = "column", @@ -715,20 +751,12 @@ def ref_amax_and_scale( y_test.backward(dy_test) # Check results - forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - x_fp8_meta = op.get_fp8_meta("input")[forward_key] - w_fp8_meta = op.get_fp8_meta("param")[forward_key] - dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key] - x_amax_test = x_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") - w_amax_test = w_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") - dy_amax_test = dy_fp8_meta.amax_history[-1, 0].to(dtype=torch.float32, device="cpu") - x_scale_test = x_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") - w_scale_test = w_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") - dy_scale_test = dy_fp8_meta.scale[0].to(dtype=torch.float32, device="cpu") - torch.testing.assert_close(x_amax_test, x_amax_ref) - torch.testing.assert_close(w_amax_test, w_amax_ref) - torch.testing.assert_close(dy_amax_test, dy_amax_ref) + x_quantizer = op.get_quantizer("forward", 0) + w_quantizer = op.get_quantizer("forward", 1) + dy_quantizer = op.get_quantizer("backward", 0) + x_scale_test = x_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([]) + w_scale_test = w_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([]) + dy_scale_test = dy_quantizer.scale.to(dtype=torch.float32, device="cpu").reshape([]) torch.testing.assert_close(x_scale_test, x_scale_ref) torch.testing.assert_close(w_scale_test, w_scale_ref) torch.testing.assert_close(dy_scale_test, dy_scale_ref) @@ -755,38 +783,32 @@ def run_parallel_tests() -> None: # Basic linear op for config in itertools.product( - (False, True) if fp8_available else (False,), + quantization_list, ("column", "row"), (False, True), ): if rank == 0: print(f"Running _test_basic_linear with {config=}") - fp8, tensor_parallel_mode, sequence_parallel = config + quantization, tensor_parallel_mode, sequence_parallel = config _test_basic_linear( - fp8_compute=fp8, - fp8_input=fp8, - fp8_weight=fp8, - fp8_grad_output=fp8, + quantization=quantization, tensor_parallel_mode=tensor_parallel_mode, sequence_parallel=sequence_parallel, ) # Linear op for config in itertools.product( - (False, True) if fp8_available else (False,), + quantization_list, ("column", "row"), ): if rank == 0: print(f"Running _test_linear with {config=}") - fp8, tensor_parallel_mode = config + quantization, tensor_parallel_mode = config dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 _test_linear( bias=True, # bias=False is tested in _test_basic_linear dtype=dtype, - fp8_compute=fp8, - fp8_input=fp8, - fp8_weight=fp8, - fp8_grad_output=fp8, + quantization=quantization, tensor_parallel_mode=tensor_parallel_mode, ) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index ead121f314..b61f519c99 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index d0b445a505..b4e2b680b3 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -27,29 +27,33 @@ pytest.skip("Distributed training needs at least 2 GPUs.") fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] -def _run_test(fp8): +def _run_test(quantization): test_path = TEST_ROOT / "run_numerics.py" test_cmd = LAUNCH_CMD + [str(test_path)] - if fp8: - test_cmd += ["--fp8"] + if quantization is not None: + test_cmd += ["--quantization", quantization] - result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) - if result.returncode != 0 or "NUMERICAL CHECK FAILED" in result.stderr.decode(): - raise AssertionError(result.stderr.decode()) + result = subprocess.run(test_cmd, env=os.environ, check=False) + assert result.returncode == 0 all_boolean = [True, False] -@pytest.mark.parametrize("fp8", all_boolean) -def test_distributed(fp8): - if fp8 and not fp8_available: +@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"]) +def test_distributed(quantization): + if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) - _run_test(fp8) + if quantization == "fp8_cs" and not fp8_available: + pytest.skip(fp8_available) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + _run_test(quantization) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py new file mode 100644 index 0000000000..f5c186a3bc --- /dev/null +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -0,0 +1,58 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import pytest +import subprocess +from pathlib import Path +from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +import torch + + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +NUM_PROCS: int = torch.cuda.device_count() + + +def _run_test(fp_init, sharding_dims): + test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py" + test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)] + + if fp_init: + test_cmd += ["--fp8-init"] + if len(sharding_dims) == 1: + test_cmd += ["--sharding-dims", str(sharding_dims[0])] + elif len(sharding_dims) == 2: + test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] + else: + assert False + result = subprocess.run(test_cmd, env=os.environ, check=True) + + +@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") +@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") +@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") +@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) +@pytest.mark.parametrize("fp8_init", (False, True)) +def test_distributed(fp8_init, sharding_dims): + + # Skip invalid configurations + if torch.cuda.device_count() < 4: + pytest.skip("FSDP2 test requires at least 4 GPUs") + + if fp8_init and not fp8_available: + pytest.skip(reason_for_no_fp8) + + _run_test(fp8_init, sharding_dims) + + +def test_dummy() -> None: + """Dummy test + + pytest returns exit code 5 if all tests are skipped. + + """ + pass diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py index 2d863b3bba..4a1fd17be7 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,7 +11,7 @@ import transformer_engine_torch as tex from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.common.recipe import DelayedScaling dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -42,7 +42,7 @@ def run_dpa_with_cp( "causal", "no_mask", ], f"{config.attn_mask_type} is an unsupported attention mask type!" - if kernel_backend == "FusedAttention" and qkv_format == "thd": + if qkv_format == "thd": if "causal" in config.attn_mask_type: config.attn_mask_type = "padding_causal" else: @@ -163,12 +163,10 @@ def run_dpa_with_cp( torch.tensor([q_input_shape[0]], dtype=torch.int32), ] ).cuda() - if kernel_backend == "FlashAttention": - cu_seqlens_q = cu_seqlens_q_padded[:-1] - else: - cu_seqlens_q = torch.cat( - [torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)] - ).cuda() + cu_seqlens_q = torch.clone(cu_seqlens_q_padded) + if kernel_backend == "FusedAttention": + cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda() + cu_seqlens_q[-1] = cu_seqlens_q[-2] cu_seqlens_kv = cu_seqlens_q cu_seqlens_kv_padded = cu_seqlens_q_padded else: @@ -178,6 +176,11 @@ def run_dpa_with_cp( k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda() dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) # create flash attention bias if config.attn_bias_type not in ["no_bias", "alibi"]: @@ -204,13 +207,11 @@ def run_dpa_with_cp( core_attention_bias=bias, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] - ), + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if fp8_mha: - dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2) + dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: out.backward(dout) @@ -276,13 +277,11 @@ def run_dpa_with_cp( core_attention_bias=bias_, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] - ), + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if fp8_mha: - dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2) + dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: out_.backward(dout_) @@ -311,7 +310,7 @@ def run_dpa_with_cp( dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]] dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size + cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) @@ -327,7 +326,7 @@ def run_dpa_with_cp( ).item() == 0 ) - cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size + cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size cu_seqlens_kv = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True ) diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index 4e995dabb1..bbdf8f22f2 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -18,14 +18,16 @@ from transformer_engine.pytorch.attention import ( DotProductAttention, MultiheadAttention, - RotaryPositionEmbedding, + _attention_backends, +) +from transformer_engine.pytorch.dot_product_attention.utils import ( + FlashAttentionUtils, get_attention_backend, - _flash_attn_2_3_plus, - _flash_attn_3_is_installed, check_set_window_size, AttentionParams, - _attention_backends, ) +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams +from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding from transformer_engine.pytorch.constants import TE_DType import transformer_engine.pytorch.cpp_extensions as ext from transformer_engine.pytorch.cpp_extensions.fused_attn import ( @@ -48,6 +50,12 @@ from transformer_engine.pytorch.utils import get_cudnn_version import transformer_engine_torch as tex from transformer_engine_torch import NVTE_Fused_Attn_Backend +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) # Only run FP8 tests on H100 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() @@ -89,6 +97,8 @@ def __init__( num_layers: int = 1, bias_shape: str = "1hss", window_size: Tuple[int, int] = (-1, -1), + total_requests: int = None, + max_ctx_len: int = None, ): self.batch_size = batch_size self.num_heads = num_heads @@ -107,6 +117,8 @@ def __init__( self.num_layers = num_layers self.bias_shape = bias_shape self.window_size = window_size + self.total_requests = total_requests + self.max_ctx_len = max_ctx_len @contextmanager @@ -129,6 +141,8 @@ def _get_attention_backends( deterministic: bool = False, fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, + is_training: bool = True, + inference_params: Optional[InferenceParams] = None, ) -> Tuple[List, List]: """Check if what attention backends support a model configuration""" @@ -158,6 +172,7 @@ def _get_attention_backends( fused_attn_backends = [] available_backends = None + flash_attention_backend = None fused_attention_backend = None def test(): @@ -183,21 +198,36 @@ def test(): deterministic=deterministic, fp8=fp8, fp8_meta=fp8_meta, + is_training=is_training, + inference_params=inference_params, ) - _, _, fused_attention_backend, _, available_backends = get_attention_backend( - attention_params - ) - return available_backends, fused_attention_backend + ( + use_flash_attention, + use_fused_attention, + flash_attention_backend, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) = get_attention_backend(attention_params) + # Set attention.py _attention_backends var using return value + # from get_attention_backend() + _attention_backends["use_flash_attention"] = use_flash_attention + _attention_backends["use_fused_attention"] = use_fused_attention + _attention_backends["flash_attention_backend"] = flash_attention_backend + _attention_backends["fused_attention_backend"] = fused_attention_backend + _attention_backends["use_unfused_attention"] = use_unfused_attention + _attention_backends["backend_selection_requires_update"] = False + return available_backends, flash_attention_backend, fused_attention_backend backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} with logging_context(): for i in range(3): os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) _attention_backends["backend_selection_requires_update"] = True - available_backends, fused_attention_backend = test() + available_backends, flash_attention_backend, fused_attention_backend = test() if fused_attention_backend == FusedAttnBackend[backends[i]]: fused_attn_backends.append(fused_attention_backend) - return available_backends, fused_attn_backends + return available_backends, flash_attention_backend, fused_attn_backends model_configs_base = { @@ -237,20 +267,19 @@ def test_dot_product_attention( tols = dict(atol=1.5e-2, rtol=1.5e-2) config = model_configs[model] is_mla = config.head_dim_qk != config.head_dim_v + is_mqa_gqa = config.num_heads != config.num_gqa_groups if qkv_layout is None: if config.attn_type == "self": - qkv_layout = "sb3hd" if not is_mla else "sbhd_sbhd_sbhd" + qkv_layout = "sb3hd" if not is_mla and not is_mqa_gqa else "sbhd_sbhd_sbhd" else: - qkv_layout = "bshd_bs2hd" if not is_mla else "bshd_bshd_bshd" + qkv_layout = "bshd_bs2hd" if not is_mla and not is_mqa_gqa else "bshd_bshd_bshd" if "3" in qkv_layout and config.attn_type == "cross": pytest.skip("No need to test this layout for cross attention") - # Test backend availability - window_size = (-1, -1) - if swa: - window_size = [2, 2] - config.window_size = check_set_window_size(config.attn_mask_type, window_size) - available_backends, fused_attn_backends = _get_attention_backends( + if config.window_size == (-1, -1) and swa: + config.window_size = [2, 2] + config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) + available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -258,11 +287,17 @@ def test_dot_product_attention( pad_between_seqs=pad_between_seqs, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # mannually pads and unpads the input and output of FlashAttention for testing purposes - if pad_between_seqs and not ( - config.max_seqlen_q != config.max_seqlen_kv - and config.attn_mask_type in ["causal", "padding_causal"] + if ( + pad_between_seqs + and FlashAttentionUtils.is_installed + and not ( + config.max_seqlen_q != config.max_seqlen_kv + and config.attn_mask_type in ["causal", "padding_causal"] + ) + and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus) ): flash_attn_supported = True @@ -334,16 +369,16 @@ def test_dot_product_attention( is_training, ) - if unfused_attn_supported and fused_attn_supported: - logging.info("[test_dot_product_attention]: unfused attn vs fused attn") - torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) - for i, _ in enumerate(unfused_attn_bwd): - torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols) + if unfused_attn_supported and fused_attn_supported: + logging.info("[test_dot_product_attention]: unfused attn vs fused attn") + torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) + for i, _ in enumerate(unfused_attn_bwd): + torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if fused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: fused attn vs flash attn") torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) @@ -399,30 +434,41 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias - "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), - "mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"), - "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"), - "mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"), - "mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "mask_7_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_7_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "mask_8_0": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), + "mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "mask_5_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_5_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_5_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" ), - "mask_8_1": ModelConfig( - 1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), + "mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), + "mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), + "mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"), + "mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"), + "mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_10_0": ModelConfig( + 2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "mask_10_1": ModelConfig( + 2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" ), - "mask_9_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), - "mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), - "mask_10_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_10_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), } @@ -531,22 +577,34 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), - "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), + "swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"), - "swa_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"), - "swa_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), + "swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"), "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "swa_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), + "swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), + "swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "swa_6_1": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_2": ModelConfig( + 2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "swa_6_3": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), } -@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.") +@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_swa]) @pytest.mark.parametrize("model", model_configs_swa.keys()) @@ -568,7 +626,7 @@ def test_dpa_sliding_window(dtype, model_configs, model): } -@pytest.mark.skipif(not _flash_attn_2_3_plus, reason="Flash-attn 2.3+ is required.") +@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.") @pytest.mark.parametrize("dtype", param_types_lean) @pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes]) @pytest.mark.parametrize("model", model_configs_alibi_slopes.keys()) @@ -619,18 +677,57 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_1": ModelConfig(3, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_2": ModelConfig(8, 16, 4, 64, 128, 128, 0.0, "padding", "no_bias"), - "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"), - "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_1_3": ModelConfig(1, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_4": ModelConfig(8, 16, 1, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"), - "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), + "layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), + "layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), + "layout_2_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_2_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_2_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + ), + "layout_3_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_3_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_3_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4) + ), + "layout_4_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_4_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_4_2": ModelConfig( + 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + ), + "layout_5_0": ModelConfig( + 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + ), + "layout_5_1": ModelConfig( + 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + ), + "layout_5_2": ModelConfig( + 2, + 24, + 24, + 128, + 2048, + 4096, + 0.0, + "padding_causal_bottom_right", + "no_bias", + window_size=(4, 0), + ), } @@ -647,11 +744,13 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout): config = model_configs[model] if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: pytest.skip("qkv_layout not applicable for MQA/GQA") + logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = True") pad_between_seqs = True test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs ) if get_cudnn_version() >= (9, 3, 0): + logging.info("[test_dpa_qkv_layout_thd]: pad_between_seqs = False") # cuDNN 9.3.0+ is required to run pad_between_seqs = False/True in the same run pad_between_seqs = False test_dot_product_attention( @@ -691,9 +790,12 @@ def _run_dot_product_attention( ) seqlens_kv = seqlens_q if config.attn_type == "cross": - seqlens_q = torch.randint( - 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" - ) + if config.max_seqlen_q > 1: + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda") seqlens_kv = torch.randint( 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" ) @@ -1041,7 +1143,7 @@ def test_transformer_layer( workspace_opt = True # Test backend availability - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, qkv_layout="sbh3d" if fused_qkv_params else "sb3hd", @@ -1299,13 +1401,18 @@ def _run_transformer_layer( model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_9": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), + "fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), + "fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), - "fp8_15": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"), + "fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"), + "fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), + "fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] @@ -1354,8 +1461,18 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] + if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( + 9, + 7, + 0, + ): + pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if _flash_attn_3_is_installed and not is_training: + if ( + FlashAttentionUtils.v3_is_installed + and not is_training + and "padding" not in config.attn_mask_type + ): os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1381,7 +1498,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rtol = 5e-1 rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - if _flash_attn_3_is_installed and not is_training: + if ( + FlashAttentionUtils.v3_is_installed + and not is_training + and "padding" not in config.attn_mask_type + ): _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1433,7 +1554,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: fp8_mha=fp8_mha, ) - with fp8_model_init(enabled=fp8_mha): + with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): rotary_pos_emb = None if RoPE: PE = RotaryPositionEmbedding(dim=config.head_dim_qk) @@ -1457,12 +1578,26 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if not is_training: mha = mha.eval() - seqlens_q = torch.full( - [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" - ) - seqlens_kv = torch.full( - [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" - ) + if "padding" in config.attn_mask_type or qkv_format == "thd": + if config.attn_type == "self": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = seqlens_q + if config.attn_type == "cross": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.randint( + 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) @@ -1499,6 +1634,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: core_attention_bias_type=config.attn_bias_type, is_first_microbatch=None, rotary_pos_emb=rotary_pos_emb, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, ) if is_training: out.backward(out_grad) @@ -1528,13 +1665,33 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): config = model_configs_fp8_vs_f16[model] + # TODO(cyang): think of another way to verify dropout results + # test cuDNN FP8 dropout + # 1. we modify the config here to not affect mha_fp8_vs_f16 tests + # 2. there is no other backend that implements dropout the same way as cuDNN FP8, and as an + # indirect verification method, we create Q/K/V as all 1s and check if O is all 1s + # 3. we avoid running FP16/BF16 kernels as they do not have dropout support on Blackwell + # if "padding" not in config.attn_mask_type and "causal" not in config.attn_mask_type: + # if get_device_compute_capability() >= (10, 0): + # config.dropout_p = 0.1 + + if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( + 9, + 7, + 0, + ): + pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: pytest.skip("qkv_layout not applicable for MQA/GQA") os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - if _flash_attn_3_is_installed and not is_training: + if ( + FlashAttentionUtils.v3_is_installed + and not is_training + and "padding" not in config.attn_mask_type + ): os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1551,17 +1708,23 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): dtype, config, True, qkv_layout, is_training ) - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") - fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout, is_training - ) + if config.dropout_p == 0.0: + # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") + fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( + dtype, config, False, qkv_layout, is_training + ) atol = 5e-1 rtol = 5e-2 - rmse_tol = 0.1 + rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) - if _flash_attn_3_is_installed and not is_training: + if ( + FlashAttentionUtils.v3_is_installed + and not is_training + and "padding" not in config.attn_mask_type + ): _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1571,27 +1734,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rtol, rmse_tol, ) - _error( - fused_attn_fwd_fp8, - fused_attn_fwd_f16, - "fused_attn_fwd_fp8", - "fused_attn_fwd_f16", - atol, - rtol, - rmse_tol, - ) - if is_training: - for i, _ in enumerate(fused_attn_bwd_f16): - logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - _error( - fused_attn_bwd_fp8[i], - fused_attn_bwd_f16[i], - f"fused_attn_bwd_fp8[{i}]", - f"fused_attn_bwd_f16[{i}]", - atol, - rtol, - rmse_tol, - ) + if config.dropout_p != 0.0: + # test cuDNN FP8 dropout + assert torch.all( + fused_attn_fwd_fp8 == 1 + ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." + else: + _error( + fused_attn_fwd_fp8, + fused_attn_fwd_f16, + "fused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, + ) + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + _error( + fused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"fused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, + ) def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): @@ -1630,12 +1799,26 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if not is_training: dpa = dpa.eval() - seqlens_q = torch.full( - [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" - ) - seqlens_kv = torch.full( - [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" - ) + if "padding" in config.attn_mask_type or qkv_format == "thd": + if config.attn_type == "self": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = seqlens_q + if config.attn_type == "cross": + seqlens_q = torch.randint( + 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.randint( + 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda" + ) + else: + seqlens_q = torch.full( + [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda" + ) + seqlens_kv = torch.full( + [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda" + ) cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda") cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) @@ -1664,7 +1847,11 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layout = layout.replace("h", "hg") layout = layout.replace("t", "tg") tensor_shape = [dim_to_num[j] for j in layout.split("_")] - tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") + if config.dropout_p == 0.0: + tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda") + else: + # test cuDNN FP8 dropout + tensor = torch.ones(tensor_shape, dtype=dtype, device="cuda") tensor_count = 1 split_dim = 0 for dim, l in enumerate(layout.split("_")): @@ -1700,7 +1887,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, - is_first_microbatch=True, ) if is_training: out.backward(out_grad) @@ -1753,7 +1939,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol = 5e-1 rtol = 5e-1 - rmse_tol = 0.1 + rmse_tol = 0.13 _error( fused_attn_fwd_fp8, unfused_attn_fwd_f16, @@ -1907,7 +2093,9 @@ def forward( workspace: torch.Tensor, is_training: bool, mask_type: str, + quantizers: list[Quantizer], ) -> torch.Tensor: + qkv_dtype = inp.dtype assert inp.dim() == 2 in_features = qkv_weight.shape[-1] @@ -1915,83 +2103,53 @@ def forward( d = in_features // h b = cu_seqlens.numel() - 1 - fp8_dtype_forward = fp8.get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] + qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2] + dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3] - inp_fp8, inp_t_fp8 = ext.fp8_cast_transpose_fused( - inp, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + inp_fp8 = input_quantizer(inp) - qkv_weight_fp8, qkv_weight_t_fp8 = ext.fp8_cast_transpose_fused( - qkv_weight, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, - ) + qkv_weight_fp8 = qkv_weight_quantizer(qkv_weight) - M = None - ZInv = None - philox_unpacked = None - - qkv, _ = ext.fp8_gemm( + qkv, *_ = ext.general_gemm( qkv_weight_fp8, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, inp_fp8, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - torch.uint8, workspace, bias=qkv_bias, - use_bias=True, - out_index=META_QKV, - fp8_meta_tensor=fp8_meta["scaling_fwd"], + out_dtype=qkv_weight_fp8.dtype, + quantization_params=qkv_quantizer, use_split_accumulator=_2X_ACC_FPROP, - D_dtype=fp8_dtype_forward, ) qkv = qkv.view(-1, 3, h, d) - qkv_fp16 = ( - ext.cast_from_fp8( - qkv, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, tex.DType.kFloat16 - ) - .view(b, max_s, 3, h, d) - .contiguous() - ) + qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous() torch.save(qkv_fp16, "qkv.pt") if cudnn_frontend_version == 1: qkv = qkv.view(b, max_s, 3, h, d) # bs3hd # FMHA - out, aux_ctx_tensors, *rest = fused_attn_fwd( + q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :] + k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :] + v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :] + q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape) + k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape) + v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape) + + out, aux_ctx_tensors = fused_attn_fwd( is_training, max_s, max_s, cu_seqlens, cu_seqlens, - qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :], - qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :], - qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :], - fp8_dtype_forward, + q, + k, + v, + qkv_dtype, FusedAttnBackend["FP8"], - None, - None, - None, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset attn_scale=None, dropout=p_dropout, fast_zero_fill=fast_zero_fill, @@ -1999,20 +2157,18 @@ def forward( attn_bias_type="no_bias", attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding", rng_gen=None, + o_quantizer=o_quantizer, + s_quantizer=s_quantizer, ) - M, ZInv, philox_unpacked = aux_ctx_tensors - - ctx.save_for_backward( - inp_t_fp8, - qkv_weight_t_fp8, - workspace, - qkv, - out, - fp8_meta["scaling_fwd"].scale, - fp8_meta["scaling_fwd"].scale_inv, + tensors_to_save, tensor_objects = prepare_for_saving( + q, k, v, inp_fp8, qkv_weight_fp8, workspace, out ) + + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects ctx.aux_ctx_tensors = aux_ctx_tensors + ctx.qkv_dtype = qkv_dtype ctx.fp8_meta = fp8_meta ctx.cu_seqlens = cu_seqlens ctx.p_dropout = p_dropout @@ -2023,58 +2179,46 @@ def forward( ctx.mask_type = mask_type ctx.dtype = inp.dtype + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.S_quantizer = s_quantizer + out = out.view(-1, in_features) # (bs)(hd) - out_fp16 = ext.cast_from_fp8( - out, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward, tex.DType.kFloat16 - ) + out_fp16 = out.dequantize() torch.save(out_fp16, "out.pt") # (bs)(hd) return out_fp16 @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: with torch.cuda.nvtx.range("_DPA"): - ( - inp_t_fp8, - qkv_weight_t_fp8, - workspace, - qkv, - out, - fwd_scales, - fwd_scale_inverses, - ) = ctx.saved_tensors - fp8_dtype_forward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + saved_tensors = ctx.saved_tensors + (q, k, v, inp_fp8, qkv_weight_fp8, workspace, out) = restore_from_saved( + ctx.tensor_objects, saved_tensors + ) - proj_dgrad = ext.cast_to_fp8( - grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) # (bs)(hd) + proj_dgrad = ctx.dO_quantizer(grad_output) + fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_s, ctx.max_s, ctx.cu_seqlens, ctx.cu_seqlens, - qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :], - qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :], - qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :], + q, + k, + v, out, proj_dgrad.view_as(out), - fp8_dtype_forward, + ctx.qkv_dtype, fp8_dtype_backward, ctx.aux_ctx_tensors, FusedAttnBackend["FP8"], None, None, - fwd_scale_inverses[META_QKV], # d_scale_qkv, - fwd_scale_inverses[META_S], # d_scale_s, - fwd_scale_inverses[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + ctx.S_quantizer, + ctx.dP_quantizer, + ctx.dQKV_quantizer, attn_scale=None, dropout=ctx.p_dropout, fast_zero_fill=ctx.fast_zero_fill, @@ -2083,58 +2227,42 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding", ) dim = 2 if cudnn_frontend_version == 1 else 1 - dqkv = torch.Tensor().to(device=dq.device, dtype=dq.dtype) - dqkv_shape = list(dq.shape) + dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype) + dqkv_shape = list(dq._data.shape) dqkv_shape.insert(dim, 3) - dqkv_stride = list(dq.stride()) + dqkv_stride = list(dq._data.stride()) dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3)) - dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd + dqkv.set_( + dq._data.untyped_storage(), dq._data.storage_offset(), dqkv_shape, dqkv_stride + ) # bs3hd dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size) - dqkv_c_fp16 = ext.cast_from_fp8( - dqkv_c, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - tex.DType.kFloat16, - ) + dqkv_c = dq.make_like(tensor=dq, data=dqkv_c, shape=dqkv_c.shape) + dqkv_c_fp16 = dqkv_c.dequantize() torch.save(dqkv_c_fp16, "dqkv.pt") - qkv_bgrad, dqkv_t = ext.fp8_transpose_bgrad_fused( - dqkv_c, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.dtype, - ) + qkv_bgrad, dqkv = ext.bgrad_quantize(dqkv_c_fp16, ctx.dQKV_quantizer) + dqkv_c._transpose = None + dqkv_c._create_transpose() # QKV DGRAD - qkv_dgrad, _ = ext.fp8_gemm( - qkv_weight_t_fp8, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_WEIGHT, - fp8_dtype_forward, + qkv_dgrad, *_ = ext.general_gemm( + qkv_weight_fp8, dqkv_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - META_DQKV, - fp8_dtype_backward, - ctx.dtype, workspace, + ctx.dtype, use_split_accumulator=_2X_ACC_DGRAD, + layout="NN", ) + # QKV WGRAD - qkv_wgrad, _ = ext.fp8_gemm( - inp_t_fp8, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - dqkv_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - META_DQKV, - fp8_dtype_backward, - ctx.dtype, + qkv_wgrad, *_ = ext.general_gemm( + inp_fp8, + dqkv, workspace, + ctx.dtype, use_split_accumulator=_2X_ACC_WGRAD, + layout="NT", ) return ( @@ -2192,7 +2320,7 @@ def forward( cu_seqlens, max_s, ) -> torch.Tensor: - with self.prepare_forward(inp, None, num_gemms=3) as inp: + with self.prepare_forward(inp, num_gemms=3) as inp: out = _custom_mha_fp8.apply( inp, self.qkv_weight, @@ -2206,5 +2334,6 @@ def forward( self.workspace, self.training, self.mask_type, + self.quantizers, ) return out diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py index 1007d6aa34..303c39e6c0 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py @@ -1,19 +1,18 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import os -import pytest import subprocess -from test_fused_attn import ModelConfig -from transformer_engine.pytorch.attention import ( - _flash_attn_2_plus, - _flash_attn_2_3_plus, -) + +import pytest +import torch from transformer_engine.pytorch.utils import ( get_device_compute_capability, get_cudnn_version, ) +from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils +from test_fused_attn import ModelConfig model_configs_flash_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias @@ -38,7 +37,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): args = [ - "python", + "python3", "-m", "torch.distributed.launch", "--nproc-per-node=" + str(num_gpus_per_node), @@ -51,13 +50,17 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): return args -@pytest.mark.skipif(not _flash_attn_2_plus, reason="Flash-attn 2.0+ is required.") +@pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 + if num_gpus > torch.cuda.device_count(): + pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + config = model_configs_flash_attn[model] if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") @@ -77,7 +80,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): subprocess.run( get_bash_arguments( - num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2, + num_gpus_per_node=num_gpus, dtype=dtype, model=model, qkv_format=qkv_format, @@ -115,6 +118,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) @pytest.mark.parametrize("fp8_mha", [False, True]) def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha): + num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 + if num_gpus > torch.cuda.device_count(): + pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") + if qkv_format == "thd" and get_device_compute_capability() < (9, 0): pytest.skip("THD format is only supported on sm90+!") if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0): @@ -123,18 +130,12 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("FP8 attention is only supported on sm90+!") config = model_configs_fused_attn[model] - if qkv_format == "thd" and config.num_heads != config.num_gqa_groups: - pytest.skip("THD format does not support QGA/MQA yet!") if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": pytest.skip("CP implementation with KV all-gather does not support THD format yet!") if qkv_format == "thd" and "a2a" in cp_comm_type: pytest.skip("CP implementation with QKVO A2A does not support THD format yet!") - if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a": - pytest.skip( - "Sliding window attention only can be supported with the implementation of QKVO A2A!" - ) if dtype == "fp8" and cp_comm_type == "all_gather": pytest.skip( "CP implementation with KV all-gather does not support FP8 + context parallelism yet!" @@ -145,6 +146,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("FP8 attention cannot work with bias yet!") if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("FP8 attention cannot work with sliding window yet!") + if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): + pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias": pytest.skip("CP implementation with KV all-gather does not support bias yet!") if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias": @@ -159,7 +162,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha subprocess.run( get_bash_arguments( - num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2, + num_gpus_per_node=num_gpus, dtype=dtype, model=model, qkv_format=qkv_format, diff --git a/tests/pytorch/fused_attn/test_paged_attn.py b/tests/pytorch/fused_attn/test_paged_attn.py new file mode 100644 index 0000000000..f810f11195 --- /dev/null +++ b/tests/pytorch/fused_attn/test_paged_attn.py @@ -0,0 +1,699 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from collections import OrderedDict +from typing import List +import os +import logging +import math + +import pytest +import torch + +from torch.distributions import Exponential +from transformer_engine.pytorch import make_graphed_callables +from transformer_engine.common import recipe +from transformer_engine.pytorch import fp8_autocast, fp8_model_init +from transformer_engine.pytorch.transformer import ( + TransformerLayer, +) +from transformer_engine.pytorch.attention import DotProductAttention +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams +from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils +from transformer_engine.pytorch.utils import ( + get_device_compute_capability, + init_method_normal, + scaled_init_method_normal, + is_bf16_compatible, +) +from test_fused_attn import ( + ModelConfig, + reset_rng_states, + _get_attention_backends, +) + +# Initialize RNG state +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +_cpu_rng_state = torch.get_rng_state() +_cuda_rng_state = torch.cuda.get_rng_state() + + +param_types = [torch.float16] +if is_bf16_compatible(): + param_types.append(torch.bfloat16) + +model_configs_infer = { + # test: b, h, hg, d, sq, skv, p, mask, bias + "infer_0": ModelConfig( + 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 + ), + "infer_1": ModelConfig( + 2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 + ), +} + +qkv_formats = ["bshd", "sbhd", "thd"] + + +def to_pretty_string(x: torch.Tensor): + return "[" + ",".join(["{:>3s}".format(str(i)) for i in x.tolist()]) + "]" + + +def round_up(a: int, b: int): + return b * math.ceil(a / b) + + +class Simulation: + def __init__( + self, + total_requests: int = 10, + max_seq_len: int = 1024, + max_ctx_len: int = 128, + max_batch_size: int = 5, + poisson_rate: float = 1, + ): + self.total_requests = total_requests + self.max_seq_len = max_seq_len + self.max_batch_size = max_batch_size + self.poisson_rate = poisson_rate + + # calculate maximum context/generation length + self.max_ctx_len = max_ctx_len + self.max_gen_len = max_seq_len - self.max_ctx_len + + # simulate sequence ids in monotonically increasing fashion + self.seq_ids = torch.range(0, total_requests - 1, dtype=torch.int32, device="cpu") + + # simulate context lengths in Uniform distribution + self.context_lens = torch.randint( + 1, self.max_ctx_len, [total_requests], dtype=torch.int32, device="cpu" + ) + + # simulate gen lengths in Exponential distribution + gen_dist = Exponential(1 / self.max_gen_len) + gen_lens = gen_dist.sample((total_requests,)) + gen_lens = torch.where(gen_lens > self.max_gen_len, self.max_gen_len, gen_lens).to( + dtype=torch.int32, device="cpu" + ) + self.gen_lens = torch.where(gen_lens == 0, 1, gen_lens).to(dtype=torch.int32, device="cpu") + + # simulate arrival times in Poisson distribution + if poisson_rate is None: + self.poisson_rate = torch.randint(1, max_batch_size, [1]).item() + interval_dist = Exponential(self.poisson_rate) + arrival_intervals = interval_dist.sample((total_requests,)) + self.arrival_times = torch.cumsum(arrival_intervals, dim=0).to( + dtype=torch.int32, device="cpu" + ) + self.last_arrival = self.arrival_times.max().item() + + # initialize tensors + self.reset() + + def reset(self): + self.t = 0 + self.request_delays = torch.zeros([self.total_requests], dtype=torch.int32, device="cpu") + self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32, device="cpu") + self.serving_times = self.arrival_times + self.complete_times = self.arrival_times + + # batch info at step t + self.t_seq_ids = torch.Tensor([]).to(dtype=torch.bool, device="cpu") + self.t_ctx_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") + self.t_gen_lens = torch.Tensor([]).to(dtype=torch.bool, device="cpu") + self.t_total_lens = self.t_ctx_lens + self.t_gen_lens + self.t_batch_size = 0 + + # step info from step t-1 to t + self.step_lens = torch.Tensor([]).to(dtype=torch.int32, device="cpu") + + def print_setup(self, logger): + logger.info("Simulation:") + logger.info(" {:<31s}: {}".format("total number of requests", self.total_requests)) + logger.info(" {:<31s}: {}".format("max sequence length per request", self.max_seq_len)) + logger.info(" {:<31s}: {}".format("max context length", self.max_ctx_len)) + logger.info(" {:<31s}: {}".format("max generation length", self.max_gen_len)) + logger.info(" {:<31s}: {}".format("max batch size per iteration", self.max_batch_size)) + logger.info(" {:<31s}: {}".format("Poisson rate", self.poisson_rate)) + logger.info(" {:<17s}: {}".format("sequence ids", to_pretty_string(self.seq_ids))) + logger.info(" {:<17s}: {}".format("arrival times", to_pretty_string(self.arrival_times))) + logger.info(" {:<17s}: {}".format("context lengths", to_pretty_string(self.context_lens))) + logger.info(" {:<17s}: {}".format("generation lengths", to_pretty_string(self.gen_lens))) + + def print_step(self, logger): + logger.info(f"Step t = {self.t}:") + logger.info(" {:<15s}: {}".format("t_batch_size", self.t_batch_size)) + logger.info(" {:<15s}: {}".format("t_seq_ids", self.t_seq_ids.tolist())) + logger.info(" {:<15s}: {}".format("t_ctx_lens", self.t_ctx_lens.tolist())) + logger.info(" {:<15s}: {}".format("t_gen_lens", self.t_gen_lens.tolist())) + logger.info(" {:<15s}: {}".format("t_total_lens", self.t_total_lens.tolist())) + logger.info(" {:<15s}: {}".format("step_lens", self.step_lens.tolist())) + + def print_summary(self, logger): + logger.info("Summary:") + logger.info(" {:<18s}: {}".format("total steps taken", self.t)) + logger.info(" {:<18s}: {}".format("arrival_times", to_pretty_string(self.arrival_times))) + logger.info(" {:<18s}: {}".format("serving_times", to_pretty_string(self.serving_times))) + logger.info(" {:<18s}: {}".format("total_gen_lens", to_pretty_string(self.gen_lens))) + logger.info(" {:<18s}: {}".format("complete_times", to_pretty_string(self.complete_times))) + + def add_new_seqs(self, new_seq_ids): + # get ctx_lens for new seqs + self.t_seq_ids = torch.cat([self.t_seq_ids, new_seq_ids], dim=0) + self.t_ctx_lens = torch.cat([self.t_ctx_lens, self.context_lens[new_seq_ids]], dim=0) + gen_lens = torch.Tensor([0] * len(new_seq_ids)).to(dtype=torch.int32, device="cpu") + self.t_gen_lens = torch.cat([self.t_gen_lens, gen_lens], dim=0) + + # append new seqs' ctx_lens to step_lens + self.step_lens = torch.cat([self.step_lens, self.context_lens[new_seq_ids]], dim=0) + + def remove_finished(self): + # figure out which seqs have finished + finished = torch.where(self.t_gen_lens - self.gen_lens[self.t_seq_ids] < 0, False, True).to( + dtype=torch.bool, device="cpu" + ) + self.t_seq_ids = self.t_seq_ids[~finished] + self.t_ctx_lens = self.t_ctx_lens[~finished] + self.t_gen_lens = self.t_gen_lens[~finished] + + # add ones for unfinished seqs to step_lens + self.step_lens = torch.ones([len(self.t_seq_ids)], dtype=torch.int32, device="cpu") + + def step(self, dynamic_fill: bool = True): + # remove finished seqs + if self.t != 0: + self.remove_finished() + + # get allowed new seqs + arrived_seq_ids = torch.where(self.arrival_times == self.t, True, False).nonzero().view(-1) + queuing_seq_ids = torch.cat([self.delayed_seq_ids, arrived_seq_ids], dim=0) + if dynamic_fill: + allowed_num_new_seqs = self.max_batch_size - len(self.t_seq_ids) + else: + allowed_num_new_seqs = 0 if len(self.t_seq_ids) else self.max_batch_size + if len(queuing_seq_ids) > allowed_num_new_seqs: + new_seq_ids = queuing_seq_ids[:allowed_num_new_seqs] + self.delayed_seq_ids = queuing_seq_ids[allowed_num_new_seqs:] + self.request_delays[self.delayed_seq_ids.tolist()] += 1 + else: + new_seq_ids = queuing_seq_ids + self.delayed_seq_ids = torch.Tensor().to(dtype=torch.int32) + + # add new seqs to batch + self.add_new_seqs(new_seq_ids) + + # update batch variables + self.t_batch_size = len(self.t_seq_ids) + self.t_total_lens = self.t_ctx_lens + self.t_gen_lens + + +def get_model( + module: torch.nn.Module, + config: ModelConfig, + dtype: torch.dtype, + backend: str = "FusedAttention", + qkv_format: str = "bshd", + num_layers: int = 1, + mode: str = "reference", + is_fp8: bool = False, +): + reset_rng_states() + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, num_layers) + + if mode == "reference": + attn_mask_type = "causal" + qkv_format = "bshd" + if mode == "inference": + attn_mask_type = "padding_causal" if backend != "FusedAttention" else "padding" + + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=is_fp8, + fp8_mha=False, + ) + + if module == "TransformerLayer": + hidden_size = config.head_dim_qk * config.num_heads + with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe): + model = [ + TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=4 * hidden_size, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + hidden_dropout=0.0, + attention_dropout=config.dropout_p, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + layer_number=layer_number, + kv_channels=config.head_dim_qk, + self_attn_mask_type=attn_mask_type, + fuse_qkv_params=False, + params_dtype=dtype, + attn_input_format=qkv_format, + ) + .cuda() + .eval() + for layer_number in range(1, num_layers + 1) + ] + if module == "DotProductAttention": + with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe): + model = [ + DotProductAttention( + kv_channels=config.head_dim_qk, + num_attention_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + layer_number=layer_number, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=attn_mask_type, + ) + .cuda() + .eval() + for layer_number in range(1, num_layers + 1) + ] + return model + + +def generate_args( + module: torch.nn.Module, + config: ModelConfig, + dtype: torch.dtype, + qkv_format: str = "bshd", + mode: str = "full_inputs", +): + # full inputs used as reference + if mode == "full_inputs": + warmup = False + shapes = [] + if module == "TransformerLayer": + shapes.append( + [config.total_requests, config.max_seqlen_kv, config.num_heads * config.head_dim_qk] + ) + if module == "DotProductAttention": + shapes.append( + [config.total_requests, config.max_seqlen_kv, config.num_heads, config.head_dim_qk] + ) + shapes.append( + [ + config.total_requests, + config.max_seqlen_kv, + config.num_gqa_groups, + config.head_dim_qk, + ] + ) + shapes.append( + [ + config.total_requests, + config.max_seqlen_kv, + config.num_gqa_groups, + config.head_dim_v, + ] + ) + # sample args used for cuda graph warmup + elif mode == "sample_args": + warmup = True + shapes = [] + if qkv_format == "bshd": + shape = [config.batch_size, config.max_ctx_len] + if qkv_format == "sbhd": + shape = [config.max_ctx_len, config.batch_size] + if qkv_format == "thd": + shape = [config.batch_size * config.max_ctx_len] + if module == "TransformerLayer": + shapes.append([*shape, config.num_heads * config.head_dim_qk]) + if module == "DotProductAttention": + shapes.append([*shape, config.num_heads, config.head_dim_qk]) + shapes.append([*shape, config.num_gqa_groups, config.head_dim_qk]) + shapes.append([*shape, config.num_gqa_groups, config.head_dim_v]) + + num_tensors = len(shapes) + if warmup: + return [ + torch.ones( + *shapes[i], + device="cuda", + dtype=dtype, + ) + for i in range(num_tensors) + ] + elif module == "TransformerLayer": + return [ + 0.01 + * torch.randint( + -100, + 100, + shapes[i], + device="cuda", + dtype=dtype, + ) + for i in range(num_tensors) + ] + elif module == "DotProductAttention": + return [ + 0.1 + * torch.randn( + *shapes[i], + device="cuda", + dtype=dtype, + ) + for i in range(num_tensors) + ] + + +def get_tols(module, backend, dtype): + if module == "TransformerLayer": + tols = { + torch.half: (5e-3, 5e-3), + torch.bfloat16: (3.5e-2, 3.5e-2), + } + if module == "DotProductAttention": + tols = { + torch.half: (1e-3, 1e-3), + torch.bfloat16: (1e-2, 1e-3), + torch.float8_e4m3fn: (2e-2, 3e-2), + } + return tols[dtype] + + +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model", model_configs_infer.keys()) +@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("is_paged", [False, True]) +@pytest.mark.parametrize("backend", ["FusedAttention", "FlashAttention", "UnfusedAttention"]) +@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention"]) +@pytest.mark.parametrize("is_cuda_graph", [False, True]) +@pytest.mark.parametrize("is_fp8", [False, True]) +def test_paged_attn(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8): + reset_rng_states() + logger = logging.getLogger("test_paged_attn") + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=is_fp8, + fp8_mha=False, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe + + config = model_configs_infer[model] + num_layers = 2 if module == "TransformerLayer" and backend != "FusedAttention" else 1 + # flash-attn v2 requires page_size >= 256 + if backend == "FlashAttention" and not fa_utils.v3_is_installed: + config_max_seqlen_q = config.max_seqlen_q + config_max_seqlen_kv = config.max_seqlen_kv + config.max_seqlen_q = 256 + config.max_seqlen_kv = 256 + + # create a real-life simulation + max_batch_size = config.batch_size + page_size = None + total_num_pages = None + if is_paged: + page_size = 256 if backend == "FlashAttention" and not fa_utils.v3_is_installed else 1 + config.max_seqlen_kv = round_up(config.max_seqlen_kv, page_size) + total_num_pages = int(max_batch_size * config.max_seqlen_kv / page_size) + else: + config.max_seqlen_kv = round_up(config.max_seqlen_kv, 64) + sim = Simulation( + total_requests=config.total_requests, + max_seq_len=config.max_seqlen_kv, + max_ctx_len=config.max_ctx_len, + max_batch_size=max_batch_size, + poisson_rate=2, + ) + sim.print_setup(logger) + + # initialize inference_params + inference_params = InferenceParams( + max_batch_size=max_batch_size, + max_seqlen_kv=config.max_seqlen_kv, + num_heads_kv=config.num_gqa_groups, + head_dim_k=config.head_dim_qk, + head_dim_v=config.head_dim_v, + dtype=dtype, + is_paged=is_paged, + page_size=page_size, + total_num_pages=total_num_pages, + max_ctx_len=config.max_ctx_len, + qkv_format=qkv_format, + ) + if module == "DotProductAttention": + for layer_number in range(1, num_layers + 1): + inference_params.allocate_memory(layer_number) + + # figure out supported backends + inference_params_qkv_format = "bshd" + qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2) + if is_paged: + qkv_layout = "paged_kv_" + qkv_layout + available_backends, _, fused_attn_backends = _get_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=False, + is_training=False, + fp8=is_fp8, + fp8_meta=fp8_meta, + inference_params=inference_params, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if backend == "FlashAttention" and not flash_attn_supported: + pytest.skip("FlashAttention backend is not supported") + if backend == "FusedAttention" and not fused_attn_supported: + pytest.skip("FusedAttention backend is not supported") + if backend == "UnfusedAttention" and not unfused_attn_supported: + pytest.skip("UnfusedAttention backend is not supported") + os.environ["NVTE_FLASH_ATTN"] = str(int(backend == "FlashAttention")) + os.environ["NVTE_FUSED_ATTN"] = str(int(backend == "FusedAttention")) + os.environ["NVTE_UNFUSED_ATTN"] = str(int(backend == "UnfusedAttention")) + if backend == "UnfusedAttention" and is_cuda_graph: + pytest.skip("CUDA graph is not supported for UnfusedAttention backend") + # TransformerLayer FP8 TN Gemm currently requires %8=0 + if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention"): + pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported") + + # create full model + logger.info("=== Generating all tokens at once ===") + model = get_model(module, config, dtype, backend, qkv_format, num_layers, mode="reference") + + # generate data for all requests + full_inputs = generate_args(module, config, dtype, qkv_format="bshd", mode="full_inputs") + + # generate reference results + if module == "DotProductAttention": + full_output = full_inputs + for m in model: + full_output = m( + *full_output if isinstance(full_output, List) else full_output, + ) + if module == "TransformerLayer": + full_output = full_inputs + for m in model: + full_output = m( + full_output[0] if isinstance(full_output, List) else full_output, + ) + + # create inference model + logger.info("=== Generating one token at a time ===") + model = get_model( + module, + config, + dtype, + backend, + qkv_format, + num_layers, + mode="inference", + is_fp8=is_fp8, + ) + + # graph the model if necessary + if is_cuda_graph: + t_seq_ids = torch.range(0, max_batch_size, dtype=torch.int32, device="cpu") + step_lens = config.max_ctx_len * torch.ones(max_batch_size, dtype=torch.int32, device="cpu") + step_dict = OrderedDict(zip(t_seq_ids.tolist(), step_lens.tolist())) + inference_params.pre_step(step_dict) + + sample_args = generate_args( + module, config, dtype, qkv_format=qkv_format, mode="sample_args" + ) + sample_kwargs = {} + sample_kwargs["cu_seqlens_q"] = torch.linspace( + 0, + config.batch_size * config.max_ctx_len, + steps=config.batch_size + 1, + device="cuda", + dtype=torch.int32, + ) + sample_kwargs["cu_seqlens_kv"] = torch.linspace( + 0, + config.batch_size * config.max_ctx_len, + steps=config.batch_size + 1, + device="cuda", + dtype=torch.int32, + ) + sample_kwargs["inference_params"] = inference_params + sample_kwargs["max_seqlen_q"] = config.max_ctx_len + sample_kwargs["max_seqlen_kv"] = config.max_seqlen_kv + + model = [ + make_graphed_callables( + model[i], + sample_args, + num_warmup_iters=10, + fp8_enabled=is_fp8, + sample_kwargs=sample_kwargs, + fp8_recipe=fp8_recipe, + ) + for i in range(num_layers) + ] + + sim.reset() + inference_params.reset() + step_dict = OrderedDict() + + # simulate step by step + # t-1: ... + # compute for seq_ids = [0, 1, 2], ctx_lens = [5, 2, 3], gen_lens = [2, 9, 4], + # batch_size = 3, step_lens = [1, 1, 1] + # increase counter for gen_lens = [3, 10, 5] + # t: detect seq 1 is finished since expected_gen_lens = [12, 10, 15] + # add two new seqs 3 and 4, with ctx lens 10 and 11 + # compute for seq_ids = [0, 2, 3, 4], ctx_lens = [5, 3, 10, 11], gen_lens = [3, 5, 0, 0], + # batch_size = 4, step_lens = [1, 1, 10, 11] + # increase counter for gen_lens = [3, 5, 1, 1] + max_tokens = config.batch_size * config.max_ctx_len + while True: + # prepare batch for the current step + dynamic_fill = True # inference_params.is_paged + sim.step(dynamic_fill=dynamic_fill) + sim.print_step(logger) + + if sim.t_batch_size == 0: + # all sequences are finished + if sim.t > sim.last_arrival: + sim.serving_times = sim.arrival_times + sim.request_delays + sim.complete_times = sim.serving_times + sim.gen_lens + break + # not finished; run next iteration + else: + sim.t += 1 + continue + + # create incremental input + batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size + max_seqlen_q = sim.max_ctx_len if is_cuda_graph else max(sim.step_lens).item() + num_tensors = len(full_inputs) + if qkv_format == "thd": + incremental_inputs = [] + for i in range(num_tensors): + inp = full_inputs[i] + inc_inp = torch.Tensor().to(dtype=dtype, device="cuda") + for i, seq in enumerate(sim.t_seq_ids): + start = (sim.t_total_lens[i] - sim.step_lens[i]).item() + end = sim.t_total_lens[i].item() + inc_inp = torch.cat([inc_inp, inp[seq, start:end]], dim=0) + if is_cuda_graph: + inc_inp = torch.cat( + [ + inc_inp, + torch.zeros( + max_tokens - sum(sim.step_lens), + *inp.shape[2:], + dtype=dtype, + device=inc_inp.device, + ), + ], + dim=0, + ) + incremental_inputs.append(inc_inp) + else: + incremental_inputs = [] + for i in range(num_tensors): + inp = full_inputs[i] + inc_inp = torch.zeros( + batch_size, + max_seqlen_q, + *inp.shape[2:], + dtype=dtype, + device="cuda", + ) + for i, seq in enumerate(sim.t_seq_ids): + start = (sim.t_total_lens[i] - sim.step_lens[i]).item() + end = sim.t_total_lens[i].item() + inc_inp[i, : sim.step_lens[i], :] = inp[seq, start:end] + if qkv_format == "sbhd": + inc_inp = inc_inp.transpose(0, 1).contiguous() + incremental_inputs.append(inc_inp) + + # run step + batch_size = max_batch_size if is_cuda_graph else sim.t_batch_size + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1 : sim.t_batch_size + 1] = torch.cumsum(sim.step_lens, dim=0) + cu_seqlens_kv = cu_seqlens_q.clone() + step_dict = OrderedDict(zip(sim.t_seq_ids.tolist(), sim.step_lens.tolist())) + inference_params.pre_step(step_dict) + if inference_params.is_paged: + inference_params.cache_manager.print_cache() + incremental_output = incremental_inputs + with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): + for m in model: + incremental_output = m( + *incremental_output, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + inference_params=inference_params, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + ) + incremental_output = [incremental_output] + incremental_output = incremental_output[0] + + # compare results + atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn) + for i, seq in enumerate(sim.t_seq_ids): + token_index = sim.step_lens[i] - 1 + if qkv_format == "bshd": + torch.testing.assert_close( + full_output[seq, sim.t_total_lens[i] - 1, :], + incremental_output[i, sim.step_lens[i] - 1, :], + atol=atol, + rtol=rtol, + ) + if qkv_format == "sbhd": + torch.testing.assert_close( + full_output[seq, sim.t_total_lens[i] - 1, :], + incremental_output[sim.step_lens[i] - 1, i, :], + atol=atol, + rtol=rtol, + ) + if qkv_format == "thd": + torch.testing.assert_close( + full_output[seq, sim.t_total_lens[i] - 1, :], + incremental_output[cu_seqlens_q[i + 1] - 1, :], + atol=atol, + rtol=rtol, + ) + + sim.t += 1 + sim.t_gen_lens = sim.t_gen_lens + 1 + + # last value in complete_times should be equal to sim.t + sim.serving_times = sim.arrival_times + sim.request_delays + sim.complete_times = sim.serving_times + sim.gen_lens + sim.print_summary(logger) + + if backend == "FlashAttention" and not fa_utils.v3_is_installed: + config.max_seqlen_q = config_max_seqlen_q + config.max_seqlen_kv = config_max_seqlen_kv diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py new file mode 100644 index 0000000000..1895b31d78 --- /dev/null +++ b/tests/pytorch/references/ref_per_tensor_cs.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import torch +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType_To_Torch + + +# compute amax and scale +def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales): + x_fp32 = x.to(torch.float32) + amax = torch.amax(torch.abs(x_fp32)).view(1) + assert amax.dtype == torch.float, "amax must be a float tensor." + fp8_max = torch.finfo(quant_dtype).max + # Clamping amax to avoid division by small numbers + amax = torch.max(amax, torch.tensor(eps)) + + # Compute scale factor + scale = torch.div(fp8_max, amax) + # Note frexp doesn't give back inf for exponent with an inf input + # We take care of inf before pow_2_scales + # option1: set scale to fp32 max when scale is inf + scale = torch.where(scale == torch.inf, torch.finfo(torch.float32).max, scale) + # option2: when scale is inf, set scale to 1 + scale = torch.where(scale == torch.inf, 1.0, scale) + if pow_2_scales: + # Calculate rounded down exponent + _, exp = torch.frexp(scale) + # Positive numbers are always returned as mant, exp with + # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with + # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because + # of the shift. Subnormal and zero cases need not be considered because + # the smallest possible result of fp8_max / amax is still normal. + exp = exp - 1 + # No subnormals and zero. + assert (exp > -127).all() + # TODO: If/when adding a URM option an option is to cap to 126 + # rather than allowing the full range of FP32 (2 - 2^23) x 2^127 + # addresses cases where adding a mantissa overflows into inf scales. + # Not necessary currently without additional scale smudging options. + unity = torch.tensor([1.0], device=exp.device) + torch.ldexp(unity, exp, out=scale) + # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales + # Return 0.0 for 0.0 scale for consistency with non-pow2 scale + # calculation. + scale = torch.where(amax == float("inf"), 0.0, scale) + + # Handle overflow cases for amax zero causing NaN + scale = torch.where(amax == 0, 1.0, scale) + # Compute scale_inv + scale_inv = torch.reciprocal(scale) + + return scale, scale_inv, amax + + +def _multi_dim_transpose(tensor): + # Get the number of dimensions + dims = list(range(len(tensor.shape))) + + if len(dims) <= 1: + return tensor + + # circular shift of shapes + new_order = [] + new_order.append(dims[-1]) + for i in range(len(dims) - 1): + new_order.append(dims[i]) + + # Permute the tensor according to the new order + output_tensor = tensor.permute(new_order).contiguous() + + return output_tensor + + +# current scaling reference quantization +def ref_per_tensor_cs_cast( + tensor: torch.Tensor, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + return_transpose: bool = False, + force_pow_2_scales: bool = False, + amax_epsilon: float = 0.0, +) -> torch.Tensor: + + quant_dtype_torch = TE_DType_To_Torch[fp8_dtype] + scale, scale_inv, _ = _ref_compute_amax_scale( + tensor, + quant_dtype_torch, + amax_epsilon, + force_pow_2_scales, + ) + + qx = (tensor.float() * scale).to(quant_dtype_torch) + sx = scale_inv + qx_t = None + sx_t = None + + if tensor.shape == torch.Size([]): + qx = qx.view([]) + + if return_transpose: + qx_t = _multi_dim_transpose(qx) + sx_t = sx + return qx, sx, qx_t, sx_t diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py new file mode 100644 index 0000000000..ed7cdda85b --- /dev/null +++ b/tests/pytorch/test_cpu_offloading.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +from contextlib import nullcontext + +import transformer_engine.pytorch as te +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +# Check if FP8 supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + +SIZE = 512 + +models = { + "linear": te.Linear, + "layernorm_mlp": te.LayerNormMLP, + "layernorm_linear": te.LayerNormLinear, +} + + +def _get_input(): + return torch.empty((128, SIZE, SIZE)).cuda() + + +def _measure_memory_between_forward_and_backward(model_cls, fp8, cpu_offload): + + input_layer = model_cls(SIZE, SIZE) + hidden_layer = model_cls(SIZE, SIZE) + output_layer = model_cls(SIZE, SIZE) + + input = _get_input() + if cpu_offload: + offload_context, sync_function = te.get_cpu_offload_context( + enabled=True, + num_layers=2, + model_layers=3, + offload_activations=True, + offload_weights=False, + ) + else: + offload_context = nullcontext() + sync_function = lambda x: x + + with te.fp8_autocast(enabled=fp8), offload_context: + out = input_layer(input) + out = sync_function(out) + with te.fp8_autocast(enabled=fp8), offload_context: + out = hidden_layer(out) + out = sync_function(out) + with te.fp8_autocast(enabled=fp8), offload_context: + out = output_layer(out) + out = sync_function(out) + + max_mem_used = torch.cuda.memory_allocated() / 1024**2 + + out.sum().backward() + + del input_layer + del hidden_layer + del output_layer + del input + del out + + torch.cuda.synchronize() + + return max_mem_used + + +@pytest.mark.parametrize("fp8", [True, False]) +@pytest.mark.parametrize("model_key", models.keys()) +def test_cpu_offload(fp8, model_key) -> None: + + if fp8 and not fp8_available: + pytest.skip(reason_for_no_fp8) + + model_cls = models[model_key] + + without_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, False) + + with_offloading = _measure_memory_between_forward_and_backward(model_cls, fp8, True) + + assert with_offloading < without_offloading diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 010050baea..dcdfa771c8 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -22,10 +22,12 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common import recipe # Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() # Record initial RNG state. @@ -49,6 +51,11 @@ class ModelConfig: model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} +fp8_recipes = [ + recipe.DelayedScaling(), + recipe.MXFP8BlockScaling(), +] + # Supported data types dtypes: List[torch.dtype] = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher @@ -152,6 +159,7 @@ def _test_cuda_graphs( fp8: bool, fp8_params: bool, fp8_weight_caching: bool, + fp8_recipe: recipe.Recipe, ) -> List[torch.Tensor]: """Helper function for CUDA graph test.""" reset_rng_states() @@ -162,7 +170,7 @@ def _test_cuda_graphs( fp8_weight_caching = False # Create modules. - with fp8_model_init(enabled=fp8_params): + with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe): if module == "transformer": modules = [ TransformerLayer( @@ -244,6 +252,7 @@ def _test_cuda_graphs( num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, + fp8_recipe=fp8_recipe, ) elif graph_mode == "individual": # Graph individual modules. @@ -254,6 +263,7 @@ def _test_cuda_graphs( num_warmup_iters=10, fp8_enabled=fp8, fp8_weight_caching=fp8_weight_caching, + fp8_recipe=fp8_recipe, ) for module in modules ] @@ -270,7 +280,7 @@ def _test_cuda_graphs( for grad_accumulation_step in range(2): input_ = generate_data(model_config, dtype) grad_output = generate_data(model_config, dtype, requires_grad=False) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): kwargs = {} if fp8_weight_caching: kwargs["is_first_microbatch"] = grad_accumulation_step == 0 @@ -285,6 +295,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8", (False, True)) @pytest.mark.parametrize("fp8_params", (False, True)) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) def test_make_graphed_callables( *, module: str, @@ -293,6 +304,7 @@ def test_make_graphed_callables( dtype: torch.dtype, fp8: bool, fp8_params: bool, + fp8_recipe: recipe.Recipe, fp8_weight_caching: bool = False, ) -> None: @@ -303,6 +315,8 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) # Run model with different CUDA graph settings. model_config = model_configs[model_config] @@ -314,6 +328,7 @@ def test_make_graphed_callables( fp8=fp8, fp8_params=fp8_params, fp8_weight_caching=fp8_weight_caching, + fp8_recipe=fp8_recipe, ) outputs = _test_cuda_graphs(graph_mode="none", **kwargs) graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs) @@ -339,16 +354,19 @@ def test_make_graphed_callables( _test_make_graphed_callables_with_fp8_weight_caching_modules, ) @pytest.mark.parametrize("fp8_params", (False, True)) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) def test_make_graphed_callables_with_fp8_weight_caching( *, module: str, fp8_params: bool, + fp8_recipe: recipe.Recipe, ) -> None: test_make_graphed_callables( module=module, dtype=torch.float32, fp8=True, fp8_params=fp8_params, + fp8_recipe=fp8_recipe, fp8_weight_caching=True, ) diff --git a/tests/pytorch/test_deferred_init.py b/tests/pytorch/test_deferred_init.py index 0469a01c5f..7d6d523622 100644 --- a/tests/pytorch/test_deferred_init.py +++ b/tests/pytorch/test_deferred_init.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py new file mode 100644 index 0000000000..9741b1258c --- /dev/null +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -0,0 +1,802 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pathlib +import os +import torch +import pytest + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex + +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import Float8CurrentScaling +from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype + + +# read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory +TENSOR_DUMP_DIR = pathlib.Path(__file__).resolve().parent.parent.parent / "tensor_dumps" +tensor_dump_dir_env = os.getenv("NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR") +if tensor_dump_dir_env is not None: + TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) + + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +class GetRecipes: + + @staticmethod + def none(): + return None + + @staticmethod + def fp8_per_tensor_current_scaling_default(): + # return default configs + return Float8CurrentScaling() + + +# base class for validating current_scaling x linear layer +class TestFP8RecipeLinearBase: + @staticmethod + def _prepare_data( + batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32 + ): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda") + w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda") + bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None + gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda") + + return x, w, bias, gradient + + @staticmethod + def _shard_tensor(x, world_size, axis): + split_size = x.size()[axis] // world_size + split_tensor = torch.split(x, split_size, axis) + out = [] + for tensor in split_tensor: + out.append(tensor.detach().clone().requires_grad_(x.requires_grad)) + return out + + @staticmethod + def _gather_tensor(local, world_size, tp_group, concat_dim): + out_list = [torch.zeros_like(local) for _ in range(world_size)] + torch.distributed.all_gather(out_list, local, tp_group) + return torch.cat(out_list, dim=concat_dim) + + @staticmethod + def _all_reduce_tensor(local, world_size, tp_group): + if world_size == 1: + return local + handle = torch.distributed.all_reduce(local, group=tp_group, async_op=False) + return local + + @staticmethod + def _get_sum_abs_error(a, b): + return torch.sum(torch.abs(a - b)) + + @staticmethod + def _get_mean_abs_relative_error(a, b): + return torch.mean(torch.abs((a - b) / b)) + + @staticmethod + def _load_golden_tensor_values(a, b): + return torch.sum(torch.abs(a - b)) + + @staticmethod + def _check_golden_tensor_dumps(dump_dir, get_recipe, dims, input_dtype, use_bias): + recipe = get_recipe() + batch_size, hidden_size, out_size = dims + fp8_type_x = get_fp8_torch_dtype(recipe, fprop_tensor=True) + fp8_type_w = get_fp8_torch_dtype(recipe, fprop_tensor=True) + fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) + + # Expected tensor names based on the naming template + scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example + "ScalingType.PER_TENSOR" + ) + current_seed = torch.initial_seed() # Get the current seed + + expected_tensor_names = { + "y": f"y_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "dgrad": f"dgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "wgrad": f"wgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "bgrad": f"bgrad_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + } + + if not use_bias: + expected_tensor_names.pop("bgrad") + + # Check if all expected tensors are in the tensor dumps directory + tensor_map = {} + for tensor_key, tensor_name in expected_tensor_names.items(): + tensor_path = dump_dir / tensor_name + if not os.path.exists(tensor_path): + print(f"Missing tensor: {tensor_name}") + return None + + # Load the tensor + tensor_map[tensor_key] = torch.load(tensor_path) + return tensor_map + + @classmethod + def run_linear_preprocess_parallel( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_size=1, + rank=0, + ): + if tp_size > 1: + if parallel_mode == "column": + # split w in N dim, which should be axis 0 + w = cls._shard_tensor(w, tp_size, 0)[rank] + bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None + # split gradient in N dim, which should be axis 1 + gradient = cls._shard_tensor(gradient, tp_size, 1)[rank] + if sequence_parallel: + # split x in M dim, which should be axis 0 + x = cls._shard_tensor(x, tp_size, 0)[rank] + # row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1 + if parallel_mode == "row": + # split x in K dim, which should be axis 1 + x = cls._shard_tensor(x, tp_size, 1)[rank] + # split w in K dim, which should be axis 1 + w = cls._shard_tensor(w, tp_size, 1)[rank] + if sequence_parallel: + # split gradient in M dim, which should be axis 0 + gradient = cls._shard_tensor(gradient, tp_size, 0)[rank] + return x, w, bias, gradient + + @classmethod + def run_linear_postprocess_parallel( + cls, + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ): + if tp_size > 1: + if parallel_mode == "column": + # gather y_q in N dim, which should be axis 1 + y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1) + # gather wgrad in N dim, which should be axis 0 + wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0) + # gather bgrad in N dim, which should be axis 0 + bgrad = ( + cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None + ) + if sequence_parallel: + # gather dgrad in M dim, which should be axis 0 + dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0) + if parallel_mode == "row": + # gather dgrad in K dim, which should be axis 1 + dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1) + # gather wgrad in K dim, which should be axis 1 + wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1) + if sequence_parallel: + # gather y_q in M dim, which should be axis 0 + y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0) + # we need to sum bias gradient when using TP + SP + bgrad = ( + cls._all_reduce_tensor(bgrad, tp_size, tp_group) + if bgrad is not None + else None + ) + + return y_q, dgrad, wgrad, bgrad + + @classmethod + def run_linear_one_step( + cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False + ): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + if isinstance(layer, te.Linear): + # Kitchen Linear + y_q = layer.forward(x, is_first_microbatch=is_first_microbatch) + else: + # the default torch.nn.Linear + y_q = layer(x) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + bgrad = ( + layer._parameters["bias"].grad + if layer._parameters.get("bias", None) is not None + else None + ) + assert "weight" in layer._parameters + if fuse_wgrad_accumulation: + wgrad = layer._parameters["weight"].main_grad + assert layer._parameters["weight"].grad is None + else: + wgrad = layer._parameters["weight"].grad + + return y_q, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, + layer, + x, + gradient, + run_num_steps, + enable_weight_cache, + fuse_wgrad_accumulation=False, + ): + """ + Run multiple steps of linear layer and collect results. + """ + + y_q_list, dgrad_list, wgrad_list = [], [], [] + bgrad_list = [] if layer._parameters.get("bias", None) is not None else None + + for i in range(run_num_steps): + x_i = (x + i).clone().detach().requires_grad_(True) + # run_linear_one_step + y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step( + layer, + x_i, + gradient, + is_first_microbatch=(i == 0) if enable_weight_cache else None, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ) + + # Collect results + y_q_list.append(y_q.detach().clone()) + dgrad_list.append(dgrad.detach().clone()) + wgrad_list.append(wgrad.detach().clone()) + if bgrad_list is not None and bgrad is not None: + bgrad_list.append(bgrad.detach().clone()) + + @classmethod + def run_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + fuse_wgrad_accumulation=False, + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. + """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = te.Linear( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ) + + layer = layer.to("cuda") + + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + if fuse_wgrad_accumulation: + assert ( + run_num_steps > 1 + ), "Fused weight gradient accumulation requires run_num_steps > 1" + layer.weight.main_grad = torch.zeros_like(layer.weight) + + # Run one step or multiple steps + if run_num_steps == 1: + y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + else: + y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps( + layer, + x, + gradient, + run_num_steps, + enable_weight_cache, + fuse_wgrad_accumulation, + ) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, dgrad, wgrad, bgrad + + def compare_recipe( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed, + dtype, + y_error=0.0, + dgrad_error=0.0, + wgrad_error=0.0, + bgrad_error=0.0, + recipe1_golden_tensors=None, + recipe2_golden_tensors=None, + ): + x, w, bias, gradient = self._prepare_data( + batch_size, hidden_size, out_size, use_bias, seed=seed, dtype=dtype + ) + + # recipe1 + using_fp8_recipe = recipe1 != GetRecipes.none + if using_fp8_recipe: + with fp8_autocast(enabled=True, fp8_recipe=recipe1()): + y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient) + else: + y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient) + + # recipe2 + using_fp8_recipe = recipe2 != GetRecipes.none + if using_fp8_recipe: + with fp8_autocast(enabled=True, fp8_recipe=recipe2()): + y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient) + else: + y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient) + + # Compare results (mean abs relative error) + assert ( + self._get_mean_abs_relative_error(y_q, y_q_ref).item() < y_error + ), "y and y_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(dgrad, dgrad_ref) < dgrad_error + ), "dgrad and dgrad_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(wgrad, wgrad_ref).item() < wgrad_error + ), "wgrad and wgrad_ref has too large mean abs relative error" + if use_bias: + assert ( + self._get_mean_abs_relative_error(bgrad, bgrad_ref).item() < bgrad_error + ), "bgrad and bgrad_ref has too large mean abs relative error" + + # enforce zero tolerance check when we can find golden tensor value dump + if recipe2_golden_tensors is not None: + torch.testing.assert_close( + y_q.float(), recipe2_golden_tensors["y"].float(), atol=0, rtol=0.0 + ) + torch.testing.assert_close(dgrad, recipe2_golden_tensors["dgrad"], atol=0.0, rtol=0.0) + torch.testing.assert_close(wgrad, recipe2_golden_tensors["wgrad"], atol=0.0, rtol=0.0) + if use_bias: + torch.testing.assert_close( + bgrad, recipe2_golden_tensors["bgrad"], atol=0.0, rtol=0.0 + ) + + +class TestFP8RecipeLayerNormLinearBase(TestFP8RecipeLinearBase): + + @staticmethod + def _check_golden_tensor_dumps( + dump_dir, get_recipe, dims, input_dtype, use_bias, normalization + ): + recipe = get_recipe() + batch_size, hidden_size, out_size = dims + fp8_type_x = get_fp8_torch_dtype(recipe, fprop_tensor=True) + fp8_type_w = get_fp8_torch_dtype(recipe, fprop_tensor=True) + fp8_type_g = get_fp8_torch_dtype(recipe, fprop_tensor=False) + + # Expected tensor names based on the naming template + scaling_type = ( # Assuming the scaling type is PER_TENSOR for this example + "ScalingType.PER_TENSOR" + ) + current_seed = torch.initial_seed() # Get the current seed + + expected_tensor_names = { + "y": f"y_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "ln_out": f"ln_out_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "dgrad": f"dgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "wgrad": f"wgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + "bgrad": f"bgrad_{normalization}_{scaling_type}_{batch_size}_{hidden_size}_{out_size}_{current_seed}_{input_dtype}_{fp8_type_x}_{fp8_type_w}_{fp8_type_g}.pt", + } + + if not use_bias: + expected_tensor_names.pop("bgrad") + + # Check if all expected tensors are in the tensor dumps directory + tensor_map = {} + for tensor_key, tensor_name in expected_tensor_names.items(): + tensor_path = dump_dir / tensor_name + if not os.path.exists(tensor_path): + print(f"Missing tensor: {tensor_name}") + return None + + # Load the tensor + tensor_map[tensor_key] = torch.load(tensor_path) + return tensor_map + + @classmethod + def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + + parameters = layer._parameters + + # bias and weight gradients + bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None + assert "weight" in parameters + wgrad = parameters["weight"].grad + + return y_q, ln_out, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False + ): + # raise error, no test case for multiple steps for now + raise NotImplementedError("LayerNormLinear does not support test multiple steps for now") + + @classmethod + def run_layernorm_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + LayerNormLinearClass=te.LayerNormLinear, + normalization="LayerNorm", + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. + """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = LayerNormLinearClass( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + normalization=normalization, + return_layernorm_output=True, + ) + + layer = layer.to("cuda") + + # Copy weights + # kitchen_linear has different parameter names + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + # Run one step + y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, ln_out, dgrad, wgrad, bgrad + + def compare_recipe( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed, + dtype, + y_error=0.0, + ln_out_error=0.0, + dgrad_error=0.0, + wgrad_error=0.0, + bgrad_error=0.0, + normalization="LayerNorm", + LayerNormLinearClass1=te.LayerNormLinear, + LayerNormLinearClass2=te.LayerNormLinear, + recipe1_golden_tensors=None, + recipe2_golden_tensors=None, + ): + x, w, bias, gradient = self._prepare_data( + batch_size, hidden_size, out_size, use_bias, seed=seed, dtype=dtype + ) + + # recipe1 + using_fp8_recipe = recipe1 != GetRecipes.none + if using_fp8_recipe: + with fp8_autocast(enabled=True, fp8_recipe=recipe1()): + y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass1, + ) + else: + y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass1, + ) + + # recipe2 + using_fp8_recipe = recipe2 != GetRecipes.none + if using_fp8_recipe: + with fp8_autocast(enabled=True, fp8_recipe=recipe2()): + y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass2, + ) + else: + y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear( + x, + w, + bias, + gradient, + normalization=normalization, + LayerNormLinearClass=LayerNormLinearClass2, + ) + + # Compare results (mean abs relative error) + assert ( + self._get_mean_abs_relative_error(y_q, y_q_ref).item() < y_error + ), "y and y_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(ln_out, ln_out_ref).item() < ln_out_error + ), "ln_out and ln_out_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(dgrad, dgrad_ref) < dgrad_error + ), "dgrad and dgrad_ref has too large mean abs relative error" + assert ( + self._get_mean_abs_relative_error(wgrad, wgrad_ref).item() < wgrad_error + ), "wgrad and wgrad_ref has too large mean abs relative error" + if use_bias: + assert ( + self._get_mean_abs_relative_error(bgrad, bgrad_ref).item() < bgrad_error + ), "bgrad and bgrad_ref has too large mean abs relative error" + + # enforce zero tolerance check when we can find golden tensor value dump + if recipe2_golden_tensors is not None: + torch.testing.assert_close( + y_q.float(), recipe2_golden_tensors["y"].float(), atol=0, rtol=0.0 + ) + torch.testing.assert_close(ln_out, recipe2_golden_tensors["ln_out"], atol=0.0, rtol=0.0) + torch.testing.assert_close(dgrad, recipe2_golden_tensors["dgrad"], atol=0.0, rtol=0.0) + torch.testing.assert_close(wgrad, recipe2_golden_tensors["wgrad"], atol=0.0, rtol=0.0) + if use_bias: + torch.testing.assert_close( + bgrad, recipe2_golden_tensors["bgrad"], atol=0.0, rtol=0.0 + ) + + +# FP8 per tesnor current scaling +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFP8CurrentScalingRecipeLinear(TestFP8RecipeLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_per_tensor_current_scaling_default), + ], + ) + def test_fp8_current_scaling_with_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, recipe2, (batch_size, hidden_size, out_size), dtype, use_bias + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFP8CurrentScalingRecipeLayerNormLinear(TestFP8RecipeLayerNormLinearBase): + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (16, 256, 128), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) + @pytest.mark.parametrize( + "recipe1, recipe2", + [ + (GetRecipes.none, GetRecipes.fp8_per_tensor_current_scaling_default), + ], + ) + def test_fp8_current_scaling_with_layernorm_linear_module( + self, + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + dtype, + use_bias=True, + ): + fp8_zero_tolerance_tensor_dumps_recipe2 = None + # check tensor dumps dir, if the dir exists, then read files to get y, dgrad, wgrad, bgrad + # if we cannot get all four tensors, then still set the tensor dump to None + tensor_map = self._check_golden_tensor_dumps( + TENSOR_DUMP_DIR, + recipe2, + (batch_size, hidden_size, out_size), + dtype, + use_bias, + "LayerNorm", + ) + if tensor_map is not None: + fp8_zero_tolerance_tensor_dumps_recipe2 = tensor_map + + self.compare_recipe( + recipe1, + recipe2, + batch_size, + hidden_size, + out_size, + use_bias, + seed=torch.initial_seed(), + dtype=dtype, + y_error=0.5, + ln_out_error=0.5, + dgrad_error=1, + wgrad_error=1, + bgrad_error=0.5, + recipe1_golden_tensors=None, + recipe2_golden_tensors=fp8_zero_tolerance_tensor_dumps_recipe2, + ) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 51f4c695dc..42600e3099 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,10 +11,18 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8Tensor, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch +from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported import transformer_engine_torch as tex +from references.ref_per_tensor_cs import ref_per_tensor_cs_cast + # PyTorch tensor dtypes _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] # TE FP8 dtypes @@ -42,6 +50,44 @@ def _to_list(x: Union[Iterable, Any]) -> List: fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +# delayed scaling +def to_float8( + tensor: torch.Tensor, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 1.0, +) -> Float8Tensor: + """Cast tensor to FP8""" + quantizer = Float8Quantizer( + scale=torch.full([1], scale, dtype=torch.float32, device="cuda"), + amax=torch.empty([1], dtype=torch.float32, device="cuda"), + fp8_dtype=fp8_dtype, + ) + return quantizer(tensor.cuda()) + + +# current scaling +def to_float8_CS( + tensor: torch.Tensor, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + return_transpose: bool = False, + force_pow_2_scales: bool = False, + amax_epsilon: float = 0.0, +) -> Float8Tensor: + """Cast tensor to FP8""" + tensor = tensor.cuda() + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=tensor.device, + force_pow_2_scales=force_pow_2_scales, + amax_epsilon=amax_epsilon, + ) + if return_transpose: + quantizer.set_usage(rowwise=True, columnwise=True) + else: + quantizer.set_usage(rowwise=True, columnwise=False) + return quantizer(tensor) + + @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) class TestFloat8Tensor: @@ -62,10 +108,11 @@ def test_constructor( """Call constructor and perform sanity checks""" dims = _to_list(dims) tensor = Float8Tensor( + shape=dims, + dtype=dtype, data=torch.zeros(dims, device="cuda", dtype=torch.uint8), fp8_dtype=fp8_dtype, fp8_scale_inv=torch.full([1], scale_inv), - dtype=dtype, ) assert list(tensor.size()) == dims, "Incorrect dims" assert tensor.dtype == dtype, "Incorrect nominal dtype" @@ -84,11 +131,7 @@ def _test_quantize_dequantize( x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1 # Cast to FP8 and back - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) x_fp8 = x_fp8.dequantize().cpu() # Check results @@ -115,62 +158,6 @@ def test_quantize_dequantize_scales(self, scale: float) -> None: def test_quantize_dequantize_dims(self, dims: DimsType) -> None: self._test_quantize_dequantize(dims=dims) - def test_fp8_meta( - self, - dtype: torch.dtype = torch.float32, - dims: DimsType = 23, - ) -> None: - """Construct Float8Tensor using FP8 metadata and perform basic checks""" - - # Get FP8 metadata from linear module - fp8_dtype = tex.DType.kFloat8E4M3 - recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): - module = te.Linear(32, 32) - _ = module(torch.zeros([8, 32], device="cuda")) - fp8_meta = module.fp8_meta - fp8_meta_index = tex.FP8FwdTensors.GEMM1_WEIGHT - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - - # Initialize random data - dims = _to_list(dims) - x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - - # Make Float8Tensor - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_meta=fp8_meta, - fp8_meta_index=fp8_meta_index, - ) - x_ref = x_fp8.dequantize() - assert list(x_fp8.size()) == dims, "Incorrect dims" - assert x_fp8.dtype == dtype, "Incorrect nominal dtype" - assert x_fp8.is_cuda, "Incorrect device" - assert x_fp8._fp8_dtype == fp8_dtype, "Incorrect FP8 dtype" - - # Change FP8 metadata scale - fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 2 - fp8_meta[fp8_meta_key].scale_inv.fill_(123) - - # Check results - torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) - with pytest.raises(AssertionError): - # Make sure we are not trivially passing the test - torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype]) - - # Check if scaling factor is updated after in-place ops - x_fp8 += 0 - fp8_meta[fp8_meta_key].scale[fp8_meta_index] = 4 - fp8_meta[fp8_meta_key].scale_inv.fill_(321) - assert x_fp8._scale_inv.item() == 0.5, "Incorrect FP8 scale_inv" - torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) - y = x_fp8.detach() - y += 0 - assert x_fp8._scale_inv.item() == 0.25, "Incorrect FP8 scale_inv" - torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype]) - def test_basic_ops( self, dims: DimsType = 23, @@ -184,16 +171,8 @@ def test_basic_ops( dims = _to_list(dims) x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) - y_fp8 = Float8Tensor.to_float8( - y_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale) x_ref = x_fp8.dequantize() y_ref = y_fp8.dequantize() @@ -214,6 +193,36 @@ def test_basic_ops( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols) + @pytest.mark.parametrize("dims", [2, [4, 4], [8, 5, 3, 3]]) + def test_chunk_op( + self, + dims: DimsType, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test for ops for which shape of inputs and outputs differ.""" + + # Initialize random data + dims = _to_list(dims) + x_ref = torch.randn(dims, dtype=dtype, device="cpu") + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=1.0) + + # Get chunks. + chunk1, chunk2 = x_fp8.chunk(2, dim=0) + + # Test chunks. + torch.testing.assert_close(x_fp8[0 : dims[0] // 2,], chunk1, atol=0, rtol=0) + torch.testing.assert_close(x_fp8[dims[0] // 2 :,], chunk2, atol=0, rtol=0) + + # Check shapes. + assert ( + chunk1.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:] + ), "Wrong shape for chunk1" + assert ( + chunk2.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:] + ), "Wrong shape for chunk2" + def test_inplace_ops( self, dims: DimsType = 23, @@ -227,16 +236,8 @@ def test_inplace_ops( dims = _to_list(dims) x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) - y_fp8 = Float8Tensor.to_float8( - y_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) + y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale) x_ref = x_fp8.dequantize() y_ref = y_fp8.dequantize() @@ -260,56 +261,6 @@ def test_inplace_ops( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8, x_ref, **tols) - @pytest.mark.parametrize("dims", [[33, 41], [7, 11]]) - def test_transpose( - self, - dims: DimsType, - fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, - scale: float = 0.5, - dtype: torch.dtype = torch.float32, - ) -> None: - """Test transpose""" - - # Initialize random data - dims = _to_list(dims) - x = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) - x = x_fp8.dequantize() - - # Perform transpose - x_fp8_t = x_fp8.transpose_2d() - x_t = x.transpose(0, 1) - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8_t) - - # Check results - tols = dict(rtol=0, atol=0) - torch.testing.assert_close(x_fp8_t, x_t, **tols) - - # Make sure we are not trivially passing the test - with pytest.raises(AssertionError): - torch.testing.assert_close(x_fp8_t, x, **tols) - - # Caching test - assert x_fp8._transpose_invalid, "Transpose cache must be invalid when not caching." - x_fp8 += 0.5 - x = x_fp8.dequantize() - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8.transpose_2d(fill_cache=True)) - x_t = x.transpose(0, 1) - torch.testing.assert_close(x_fp8_t, x_t, **tols) - assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." - - # Inplace update test - x_fp8 += 0.5 - assert not x_fp8._transpose_invalid, "Transpose cache reset incorrectly." - x = x_fp8.dequantize() - x_fp8_t = Float8Tensor.make_like(x_fp8, data=x_fp8._transpose) - x_t = x.transpose(0, 1) - torch.testing.assert_close(x_fp8_t, x_t, **tols) - def test_serialization( self, dims: DimsType = [2, 3, 5], @@ -321,11 +272,7 @@ def test_serialization( # Initialize random data dims = _to_list(dims) x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 - x_fp8 = Float8Tensor.to_float8( - x_ref, - fp8_dtype=fp8_dtype, - scale=torch.full([1], scale), - ) + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale) x_ref = x_fp8.dequantize() # Serialize tensor @@ -339,7 +286,7 @@ def test_serialization( del x_fp8, byte_stream # Deserialize tensor - x_fp8 = torch.load(io.BytesIO(x_bytes)) + x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False) del x_bytes # Check results @@ -357,7 +304,7 @@ def test_set_data(self): # Initialize Float8Tensor x0 = torch.zeros(4, dtype=torch.float32) - x = Float8Tensor.to_float8(x0) + x = to_float8(x0) assert isinstance(x, Float8Tensor) assert x0.size() == x.size() == x._data.size() assert x.dtype == torch.float32 @@ -382,7 +329,7 @@ def test_set_data(self): assert x.device == y.device # Set data to Float8Tensor - x0 = Float8Tensor.to_float8(torch.zeros((4, 3, 1), dtype=torch.float32)) + x0 = to_float8(torch.zeros((4, 3, 1), dtype=torch.float32)) x.data = x0 assert isinstance(x, Float8Tensor) assert x0.size() == x.size() == x._data.size() @@ -395,3 +342,89 @@ def test_set_data(self): assert x.size() == y.size() assert x.dtype == y.dtype assert x.device == y.device + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestCurrentScalingFloat8Tensor: + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize( + "dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3], [128, 128], [611, 782]] + ) + @pytest.mark.parametrize("return_transpose", [True, False], ids=str) + @pytest.mark.parametrize("force_pow_2_scales", [True, False], ids=str) + @pytest.mark.parametrize("amax_epsilon", [0.0, 1e-6], ids=str) + def test_quantize( + self, + fp8_dtype: tex.DType, + dtype: torch.dtype, + dims: DimsType, + return_transpose: bool, + force_pow_2_scales: bool, + amax_epsilon: float, + ) -> None: + """Check numerical error when casting to FP8""" + + # Skip invalid configurations + if non_tn_fp8_gemm_supported() and return_transpose: + pytest.skip("FP8 transpose is neither needed nor supported on current system") + + # Initialize random high precision data + device = "cuda" + x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device=device) - 1 + + # Cast to FP8 and back + x_fp8 = to_float8_CS( + x_hp, + fp8_dtype=fp8_dtype, + return_transpose=return_transpose, + force_pow_2_scales=force_pow_2_scales, + amax_epsilon=amax_epsilon, + ) + + # get reference implementation of current scaling + x_fp8_ref, sx_ref, x_fp8_t_ref, _ = ref_per_tensor_cs_cast( + x_hp, + fp8_dtype=fp8_dtype, + return_transpose=return_transpose, + force_pow_2_scales=force_pow_2_scales, + amax_epsilon=amax_epsilon, + ) + + torch.testing.assert_close(x_fp8._data, x_fp8_ref.view(torch.uint8), atol=0.0, rtol=0.0) + torch.testing.assert_close(x_fp8._scale_inv, sx_ref, atol=0.0, rtol=0.0) + if return_transpose: + torch.testing.assert_close( + x_fp8._transpose, x_fp8_t_ref.view(torch.uint8), atol=0.0, rtol=0.0 + ) + + @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str) + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str) + @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]]) + def test_quantize_dequantize( + self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType + ) -> None: + """Check numerical error when casting to FP8 and back""" + + # Initialize random high precision data + device = "cuda" + x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device=device) - 1 + + # Cast to FP8 and back + x_fp8 = to_float8_CS(x_hp, fp8_dtype=fp8_dtype) + x_fp8_dequantized = x_fp8.dequantize() + + # Check results + torch.testing.assert_close(x_fp8_dequantized, x_hp, **_tols[fp8_dtype]) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8_dequantized, -x_hp, **_tols[fp8_dtype]) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 4d4eb38342..507fd3f350 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,6 +11,7 @@ from torch import nn from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te +from transformer_engine.common.recipe import DelayedScaling from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch import fp8_model_init from transformer_engine.pytorch.utils import is_bf16_compatible @@ -184,6 +185,7 @@ def gen_precision_aware_test( grad_dtype, exp_avg_dtype, exp_avg_sq_dtype, + store_param_remainders=False, model_rtol=None, model_atol=None, master_rtol=None, @@ -220,6 +222,7 @@ def gen_precision_aware_test( "weight_decay": 0, "amsgrad": False, } + ref_optim = torch.optim.Adam(ref_params, **options) tst_optim = te.optimizers.FusedAdam( model_params, @@ -228,6 +231,7 @@ def gen_precision_aware_test( exp_avg_dtype=exp_avg_dtype, exp_avg_sq_dtype=exp_avg_sq_dtype, use_decoupled_grad=True, + store_param_remainders=store_param_remainders, **options, ) @@ -237,7 +241,7 @@ def test_one_iteration(ref_optimizer, tst_optimizer): p.decoupled_grad = p_ref.grad.clone().to(grad_dtype) ref_optimizer.step() tst_optimizer.step() - if use_master_weights: + if use_master_weights and not store_param_remainders: master_weights_to_fp32 = [ tst_optim.get_unscaled_state(p, "master_param") for p in model_params ] @@ -270,6 +274,7 @@ def test_one_iteration(ref_optimizer, tst_optimizer): exp_avg_dtype=exp_avg_dtype, exp_avg_sq_dtype=exp_avg_sq_dtype, use_decoupled_grad=True, + store_param_remainders=store_param_remainders, **options, ) tst_optim.load_state_dict(state_dict) @@ -300,6 +305,19 @@ def test_fp32_master(self): exp_avg_sq_dtype=torch.float32, ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + def test_fp32_master_store_param_remainders(self): + self.gen_precision_aware_test( + use_fp8_params=False, + param_dtype=torch.bfloat16, + use_master_weights=True, + master_weight_dtype=torch.float32, + grad_dtype=torch.float32, + exp_avg_dtype=torch.float32, + exp_avg_sq_dtype=torch.float32, + store_param_remainders=True, + ) + @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") def test_fp16_master(self): self.gen_precision_aware_test( @@ -429,7 +447,7 @@ def test_bf16_model_weight_cast(self): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_model_weight_cast(self): dtype = torch.bfloat16 - with fp8_model_init(enabled=True): + with fp8_model_init(enabled=True, recipe=DelayedScaling()): model = MultiheadAttention( hidden_size=1024, num_attention_heads=16, diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 81c4973756..e236a29a9d 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -1,11 +1,11 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. import math import pytest import torch from typing import Callable, Tuple, Union -from transformer_engine.pytorch.attention import ( +from transformer_engine.pytorch.dot_product_attention.rope import ( RotaryPositionEmbedding, apply_rotary_pos_emb, ) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 29829ac4ac..97d48e2aa3 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1,10 +1,12 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. from __future__ import annotations +from collections.abc import Iterable import math +from typing import Optional import pytest import torch @@ -12,7 +14,6 @@ import transformer_engine import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops._common import is_float8_tensor @@ -21,11 +22,14 @@ ForwardLinearBiasActivation, ForwardLinearBiasAdd, ) +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() # Supported data types _dtypes: list[torch.dtype] = [torch.float32, torch.float16] @@ -36,6 +40,38 @@ _devices: list[torch.device] = [torch.device("cpu"), torch.device("cuda")] +def maybe_skip_quantization( + quantization: Optional[str], + *, + dims: Optional[Iterable[int] | int] = None, + device: Optional[torch.device | str] = None, +) -> None: + + # Don't skip if there is no quantization + if quantization is None: + return + + # Check if quantization scheme is supported + if quantization == "fp8" and not fp8_available: + pytest.skip(reason_for_no_fp8) + if quantization == "mxfp8" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + + if dims is not None: + if not isinstance(dims, Iterable): + dims = (dims,) + if quantization == "fp8": + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("FP8 GEMMs require dims that are divisible by 16") + elif quantization == "mxfp8": + if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: + pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") + + # Check if device is supported + if device is not None and torch.device(device).type != "cuda": + pytest.skip("Quantization is only supported on CUDA devices") + + def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: """Estimated numerical error for a datatype @@ -89,7 +125,12 @@ def make_reference_and_test_tensors( ref = torch.rand(shape, dtype=ref_dtype, device=ref_device) test = ref.to(device=test_device, dtype=test_dtype) if test_is_fp8: - test = Float8Tensor.to_float8(test, with_transpose_cache=True) + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device=test_device).squeeze(), + amax=torch.zeros(1, dtype=torch.float32, device=test_device), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + test = quantizer(test) elif test.data_ptr() == ref.data_ptr(): test = test.clone() ref.copy_(test) @@ -98,6 +139,21 @@ def make_reference_and_test_tensors( return ref, test +def make_recipe(name: Optional[str] = None) -> Optional[Recipe]: + """Make recipe for quantization scheme""" + if name is None: + return None + if name == "fp8": + return transformer_engine.common.recipe.DelayedScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + if name == "mxfp8": + return transformer_engine.common.recipe.MXFP8BlockScaling( + fp8_format=transformer_engine.common.recipe.Format.E4M3, + ) + raise ValueError(f"Unsupported quantization scheme ({name})") + + class TestSequential: """Tests for sequential container""" @@ -239,7 +295,7 @@ def test_fp8_scale_update( ) # Construct model - with te.fp8_model_init(): + with te.fp8_model_init(recipe=recipe): model = te_ops.basic.BasicLinear( size, size, @@ -293,41 +349,36 @@ def test_fp8_scale_update( ) # Check that scaling factors match expected - w_amax_ref = max(w_vals[: step + 2]) + w_amax_ref = max(w_vals[: step + 1]) x_amax_ref = max(x_vals[: step + 1]) dy_amax_ref = max(dy_vals[: step + 1]) w_scale_ref = (fp8_format.value.max_fwd / w_amax_ref) / (2**margin) x_scale_ref = (fp8_format.value.max_fwd / x_amax_ref) / (2**margin) dy_scale_ref = (fp8_format.value.max_bwd / dy_amax_ref) / (2**margin) - forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - w_scale = model.get_fp8_meta("param")[forward_key].scale - x_scale = model.get_fp8_meta("input")[forward_key].scale - dy_scale = model.get_fp8_meta("grad_output")[backward_key].scale + w_scale = model.get_quantizer("forward", 1).scale + x_scale = model.get_quantizer("forward", 0).scale + dy_scale = model.get_quantizer("backward", 0).scale torch.testing.assert_close(w_scale, torch.full_like(w_scale, w_scale_ref)) torch.testing.assert_close(x_scale, torch.full_like(x_scale, x_scale_ref)) torch.testing.assert_close(dy_scale, torch.full_like(dy_scale, dy_scale_ref)) @pytest.mark.parametrize("init_dtype", _dtypes) @pytest.mark.parametrize("final_dtype", _dtypes) - @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_dtype_cast( self, *, - size: int = 16, + size: int = 32, init_dtype: torch.dtype, final_dtype: torch.dtype, device: torch.device = "cuda", - fp8_weight: bool, + quantization: Optional[str], ) -> None: """Check dtype cast functions""" # Skip invalid configurations - if fp8_weight: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, device=device) + with_quantization = quantization is not None # Random data dtype = torch.float32 @@ -339,11 +390,11 @@ def test_dtype_cast( (size, size), test_dtype=dtype, test_device=device, - test_is_fp8=fp8_weight, + test_is_fp8=with_quantization, ) # Construct operation - with te.fp8_model_init(enabled=fp8_weight): + with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)): op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype) with torch.no_grad(): op.weight.copy_(w_test) @@ -358,7 +409,7 @@ def test_dtype_cast( op.bfloat16() # Check weights - assert isinstance(op.weight, Float8Tensor) == fp8_weight + assert isinstance(op.weight, QuantizedTensor) == with_quantization assert op.weight.dtype == final_dtype w_test = op.weight.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(w_test, w_ref, rtol=0, atol=0) @@ -378,29 +429,27 @@ def test_dtype_cast( @pytest.mark.parametrize("model_dtype", _dtypes) @pytest.mark.parametrize("autocast_dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_pyt_autocast( self, *, - size: int = 16, + size: int = 32, model_dtype: torch.dtype, autocast_dtype: torch.dtype, device: torch.device = "cuda", - fp8_weight: bool = False, - fp8_compute: bool, + quantization: Optional[str], + quantized_weights: bool = False, ) -> None: """Test with PyTorch autocast""" device = torch.device(device) # Skip invalid configurations - if fp8_weight or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization) # Construct operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weights, recipe=recipe): op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype) # Check forward and backward pass @@ -410,7 +459,7 @@ def test_pyt_autocast( device=device, requires_grad=True, ) - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): with torch.autocast(device_type=device.type, dtype=autocast_dtype): y = op(x) y.backward(torch.zeros_like(y)) @@ -419,11 +468,11 @@ def test_pyt_autocast( assert op.weight.grad.dtype == model_dtype # Check forward and backward pass (swapped context order) - if fp8_compute: + if quantized_compute: x.grad = None op.weight.grad = None with torch.autocast(device_type=device.type, dtype=autocast_dtype): - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y = op(x) y.backward(torch.zeros_like(y)) assert y.dtype == autocast_dtype @@ -505,19 +554,14 @@ def test_identity( ), ) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("device", ("cuda", "cpu")) - @pytest.mark.parametrize( - "memory_format", - (torch.contiguous_format, torch.channels_last), - ) @pytest.mark.parametrize("fp8", (False, True)) def test_reshape( self, *, shapes: tuple[Iterable[int], Iterable[int]], dtype: torch.dtype, - device: torch.device, - memory_format: torch.memory_format, + device: torch.device = "cuda", + memory_format: torch.memory_format = torch.contiguous_format, fp8: bool, ) -> None: in_shape, out_shape = shapes @@ -634,19 +678,23 @@ def test_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) + @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) @pytest.mark.parametrize("cast_forward", (False, True)) @pytest.mark.parametrize("cast_backward", (False, True)) - def test_cast_float8( + def test_quantize( self, *, - in_shape: Iterable[int] = (1,), + in_shape: Iterable[int] = (32, 32), dtype: torch.dtype = torch.bfloat16, device: torch.device = "cuda", + quantization: str, cast_forward: bool, cast_backward: bool, ) -> None: - """FP8 cast""" + """Quantize""" + + # Skip invalid configurations + maybe_skip_quantization(quantization) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -656,7 +704,7 @@ def test_cast_float8( requires_grad=False, test_is_fp8=True, ) - x_test = x_test.from_float8().requires_grad_() + x_test = x_test.dequantize().requires_grad_() dy_ref, dy_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, @@ -664,7 +712,7 @@ def test_cast_float8( requires_grad=False, test_is_fp8=True, ) - dy_test = dy_test.from_float8() + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = x_ref @@ -672,16 +720,14 @@ def test_cast_float8( # Implementation with fusible operation op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) - recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) + recipe = make_recipe(quantization) with te.fp8_autocast(fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) # Check tensor types - assert is_float8_tensor(y_test) == cast_forward - assert is_float8_tensor(x_test.grad) == cast_backward + assert isinstance(y_test, QuantizedTensor) == cast_forward + assert isinstance(x_test.grad, QuantizedTensor) == cast_backward # Check values tols = dict(rtol=0, atol=0) @@ -697,12 +743,13 @@ def _test_basic_linear( in_shape: Iterable[int] = (32, -1), dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool = False, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_output: bool = False, - fp8_grad_output: bool = False, - fp8_grad_input: bool = False, + quantization: Optional[str] = None, + quantized_compute: bool = False, + quantized_input: bool = False, + quantized_weight: bool = False, + quantized_output: bool = False, + quantized_grad_output: bool = False, + quantized_grad_input: bool = False, accumulate_into_main_grad: bool = False, ) -> None: """Helper function for tests with GEMM""" @@ -713,50 +760,50 @@ def _test_basic_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_compute or fp8_input or fp8_weight or fp8_output or fp8_grad_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") - if fp8_output and not fp8_compute: + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantization == "fp8" and quantized_output and not quantized_compute: pytest.skip("FP8 output is only supported with FP8 GEMMs") - if fp8_grad_input and not fp8_compute: + if quantization == "fp8" and quantized_grad_input and not quantized_compute: pytest.skip("FP8 grad input is only supported with FP8 GEMMs") + if quantization == "mxfp8" and quantized_output: + pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") + if quantization == "mxfp8" and quantized_grad_input: + pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs") # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=(quantized_compute or quantized_input), ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_grad_output), + test_is_fp8=(quantized_compute or quantized_grad_output), requires_grad=False, ) + if isinstance(dy_test, QuantizedTensor): + dy_test = dy_test.dequantize() # Plain PyTorch implementation y_ref = torch.nn.functional.linear(x_ref, w_ref) y_ref.backward(dy_ref) # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.BasicLinear( in_features, out_features, @@ -769,14 +816,11 @@ def _test_basic_linear( del w_test op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32) forward = te_ops.Sequential( - te_ops.Quantize(forward=fp8_input, backward=fp8_grad_input), + te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input), op, - te_ops.Quantize(forward=fp8_output, backward=fp8_grad_output), + te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output), ) - recipe = transformer_engine.common.recipe.DelayedScaling( - fp8_format=transformer_engine.common.recipe.Format.E4M3, - ) - with te.fp8_autocast(enabled=fp8_compute, fp8_recipe=recipe): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -784,10 +828,8 @@ def _test_basic_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute or fp8_output or fp8_grad_input: - tols = dtype_tols( - op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 - ) + if quantized_compute or quantized_output or quantized_grad_input: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -813,10 +855,10 @@ def _test_basic_linear( ) torch.testing.assert_close(dw_test, w_ref.grad, **tols) - @pytest.mark.parametrize("weight_shape", ((48, 16), (3, 5))) - @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (2, 2, 4, -1))) + @pytest.mark.parametrize("weight_shape", ((64, 32), (3, 5))) + @pytest.mark.parametrize("in_shape", ((-1,), (5, 1, -1), (4, 2, 4, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) def test_basic_linear( self, @@ -824,7 +866,7 @@ def test_basic_linear( weight_shape: tuple[int, int], in_shape: Iterable[int], dtype: torch.dtype, - fp8_compute: bool, + quantization: Optional[str], accumulate_into_main_grad: bool, ) -> None: """GEMM""" @@ -832,52 +874,55 @@ def test_basic_linear( weight_shape=weight_shape, in_shape=in_shape, dtype=dtype, - fp8_compute=fp8_compute, + quantization=quantization, + quantized_compute=quantization is not None, accumulate_into_main_grad=accumulate_into_main_grad, ) @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) - @pytest.mark.parametrize("fp8_grad_output", (False, True)) - @pytest.mark.parametrize("fp8_grad_input", (False, True)) - def test_basic_linear_fp8( + @pytest.mark.parametrize("quantization", ("fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_compute", (False, True)) + @pytest.mark.parametrize("quantized_input", (False, True)) + @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("quantized_output", (False, True)) + @pytest.mark.parametrize("quantized_grad_output", (False, True)) + @pytest.mark.parametrize("quantized_grad_input", (False, True)) + def test_basic_linear_quantized( self, *, - fp8_compute: bool, - fp8_input: bool, - fp8_weight: bool, - fp8_output: bool, - fp8_grad_output: bool, - fp8_grad_input: bool, + quantization: str, + quantized_compute: bool, + quantized_input: bool, + quantized_weight: bool, + quantized_output: bool, + quantized_grad_output: bool, + quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" self._test_basic_linear( dtype=torch.bfloat16, - fp8_compute=fp8_compute, - fp8_input=fp8_input, - fp8_weight=fp8_weight, - fp8_output=fp8_output, - fp8_grad_output=fp8_grad_output, - fp8_grad_input=fp8_grad_input, + quantization=quantization, + quantized_compute=quantized_compute, + quantized_input=quantized_input, + quantized_weight=quantized_weight, + quantized_output=quantized_output, + quantized_grad_output=quantized_grad_output, + quantized_grad_input=quantized_grad_input, ) @pytest.mark.parametrize("bias", (False, True)) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) def test_linear( self, *, bias: bool, - weight_shape: tuple[int, int] = (16, 16), - in_shape: Iterable[int] = (16, -1), + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), dtype: torch.dtype = torch.float32, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool = False, - fp8_weight: bool, + quantization: Optional[str], + quantized_weight: bool, ) -> None: """GEMM + bias""" @@ -887,31 +932,25 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -932,7 +971,8 @@ def test_linear( y_ref.backward(dy_ref) # Implementation with fusible operation - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.Linear( in_features, out_features, @@ -946,7 +986,7 @@ def test_linear( op.bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -954,10 +994,8 @@ def test_linear( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - op.weight._fp8_dtype if is_float8_tensor(op.weight) else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -970,12 +1008,11 @@ def test_linear( db_test = op.bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) - @pytest.mark.parametrize("weight_shape", ((19,), (16, 4))) - @pytest.mark.parametrize("in_shape", ((-1,), (6, 8, -1))) + @pytest.mark.parametrize("weight_shape", ((7, 2), (32,))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_layer_norm( self, *, @@ -985,8 +1022,7 @@ def test_layer_norm( device: torch.device = "cuda", eps: float = 0.3, zero_centered_gamma: bool, - fp8_input: bool, - fp8_output: bool, + quantization: Optional[str], ) -> None: """Layer norm""" @@ -994,18 +1030,13 @@ def test_layer_norm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - if fp8_input or fp8_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_input, ) w_ref, w_test = make_reference_and_test_tensors( weight_shape, @@ -1047,17 +1078,19 @@ def test_layer_norm( op.bias.copy_(b_test) del w_test del b_test + quantized_compute = quantization is not None + recipe = make_recipe(quantization) forward = te_ops.Sequential( op, - te_ops.Quantize(forward=fp8_output, backward=False), + te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=fp8_output): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) - if fp8_output: + if quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results @@ -1145,12 +1178,11 @@ def test_layer_norm_autocast( torch.testing.assert_close(dw_test, w_ref.grad, **dtype_tols(dtype)) torch.testing.assert_close(db_test, b_ref.grad, **dtype_tols(dtype)) - @pytest.mark.parametrize("weight_shape", ((19,), (16, 4))) - @pytest.mark.parametrize("in_shape", ((-1,), (6, 8, -1))) + @pytest.mark.parametrize("weight_shape", ((19,), (64,))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("zero_centered_gamma", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_output", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_rmsnorm( self, *, @@ -1160,8 +1192,7 @@ def test_rmsnorm( device: torch.device = "cuda", eps: float = 0.3, zero_centered_gamma: bool, - fp8_input: bool, - fp8_output: bool, + quantization: Optional[str], ) -> None: """Layer norm""" @@ -1169,18 +1200,13 @@ def test_rmsnorm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - if fp8_input or fp8_output: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") + maybe_skip_quantization(quantization, dims=in_shape, device=device) # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_input, ) w_ref, w_test = make_reference_and_test_tensors( weight_shape, @@ -1214,17 +1240,19 @@ def test_rmsnorm( with torch.no_grad(): op.weight.copy_(w_test) del w_test + quantized_compute = quantization is not None + recipe = make_recipe(quantization) forward = te_ops.Sequential( op, - te_ops.Quantize(forward=fp8_output, backward=False), + te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=fp8_output): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) - if fp8_output: + if quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results @@ -1362,6 +1390,161 @@ def test_make_extra_output( torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) + @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + def test_activation( + self, + *, + activation: str, + out_shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + ) -> None: + """Activation functions""" + + # Tensor dimensions + in_shape = list(out_shape) + if activation in ("geglu", "reglu", "swiglu"): + in_shape[-1] *= 2 + + # Skip invalid configurations + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + test_is_fp8=quantized_compute, + ) + if quantized_compute: + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref: torch.Tensor + if activation == "gelu": + y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") + elif activation == "relu": + y_ref = torch.nn.functional.relu(x_ref) + elif activation == "geglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2 + elif activation == "reglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "swiglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + else: + raise ValueError(f"Unexpected activation function ({activation})") + y_ref.backward(dy_ref) + + # Implementation with fusible operation + recipe = make_recipe(quantization) + make_op = dict( + gelu=te_ops.GELU, + relu=te_ops.ReLU, + geglu=te_ops.GEGLU, + reglu=te_ops.ReGLU, + swiglu=te_ops.SwiGLU, + )[activation] + forward = te_ops.Sequential( + make_op(), + te_ops.Quantize(forward=quantized_compute, backward=False), + ) + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) + if activation == "relu": + tols = {"atol": 0, "rtol": 0} + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantize_forward", (False, True)) + @pytest.mark.parametrize("quantize_backward", (False, True)) + def test_swiglu( + self, + *, + out_shape: Iterable[int] = (32, 32), + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantize_forward: bool, + quantize_backward: bool, + ): + + # Tensor dimensions + in_shape = list(out_shape) + in_shape[-1] *= 2 + + # Skip invalid configurations + quantized_compute = quantization is not None + if not quantized_compute and (quantize_forward or quantize_backward): + pytest.skip("Quantization scheme has not been provided") + maybe_skip_quantization(quantization, dims=in_shape, device=device) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.silu(x1) * x2 + y_ref.backward(dy_ref) + + # Implementation with fusible operation + recipe = make_recipe(quantization) + forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=quantize_backward), + te_ops.SwiGLU(), + te_ops.Quantize(forward=quantize_forward, backward=False), + ) + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = forward(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + class TestFusedOps: """Tests for fused operations""" @@ -1373,12 +1556,11 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) - @pytest.mark.parametrize("weight_shape", ((32, 48), (3, 5))) - @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (4, 2, 10, -1))) + @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) + @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) - @pytest.mark.parametrize("fp8_input", (False, True)) - @pytest.mark.parametrize("fp8_weight", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) + @pytest.mark.parametrize("quantized_weight", (False, True)) def test_forward_linear_bias_activation( self, *, @@ -1387,9 +1569,8 @@ def test_forward_linear_bias_activation( in_shape: Iterable[int], dtype: torch.dtype, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool, - fp8_weight: bool, + quantization: Optional[str], + quantized_weight: bool, ) -> None: """Forward GEMM + bias + activation""" @@ -1399,18 +1580,9 @@ def test_forward_linear_bias_activation( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) if dtype not in (torch.float16, torch.bfloat16): pytest.skip( "FP8 fused linear-bias-activation is only supported with FP16 or BF16 output" @@ -1421,13 +1593,16 @@ def test_forward_linear_bias_activation( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if quantized_compute: + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1448,7 +1623,8 @@ def test_forward_linear_bias_activation( y_ref.backward(dy_ref) # Implementation with fusible operations - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_compute, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -1464,7 +1640,7 @@ def test_forward_linear_bias_activation( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -1477,12 +1653,8 @@ def test_forward_linear_bias_activation( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[0].weight._fp8_dtype - if is_float8_tensor(model[0].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1497,19 +1669,17 @@ def test_forward_linear_bias_activation( @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_forward_linear_bias_add( self, *, bias: bool, - weight_shape: tuple[int, int] = (16, 16), - in_shape: Iterable[int] = (16, -1), + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), dtype: torch.dtype, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_output: bool = False, + quantization: Optional[str], + quantized_weight: bool = False, ) -> None: """Forward GEMM + bias + add""" @@ -1519,21 +1689,10 @@ def test_forward_linear_bias_add( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_output or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") - if fp8_output and not fp8_compute: - pytest.skip("FP8 output requires FP8 compute") - if fp8_compute and dtype not in (torch.float16, torch.bfloat16): + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") # Random data @@ -1541,13 +1700,16 @@ def test_forward_linear_bias_add( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x1_test, QuantizedTensor): + with torch.no_grad(): + x1_test = x1_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) b_ref, b_test = None, None if bias: @@ -1560,7 +1722,6 @@ def test_forward_linear_bias_add( out_shape, test_dtype=dtype, test_device=device, - test_is_fp8=fp8_output, ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -1574,7 +1735,8 @@ def test_forward_linear_bias_add( y_ref.backward(dy_ref) # Implementation with fusible operations - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -1591,7 +1753,7 @@ def test_forward_linear_bias_add( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y_test = model(x1_test, x2_test) y_test.backward(dy_test) @@ -1604,12 +1766,8 @@ def test_forward_linear_bias_add( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[0].weight._fp8_dtype - if is_float8_tensor(model[0].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1625,18 +1783,16 @@ def test_forward_linear_bias_add( torch.testing.assert_close(db_test, b_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) - @pytest.mark.parametrize("fp8_compute", (False, True)) + @pytest.mark.parametrize("quantization", (None, "fp8", "mxfp8")) def test_backward_linear_add( self, *, - weight_shape: tuple[int, int] = (16, 16), - in_shape: Iterable[int] = (16, -1), + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), dtype: torch.dtype, device: torch.device = "cuda", - fp8_compute: bool, - fp8_input: bool = False, - fp8_weight: bool = False, - fp8_output: bool = False, + quantization: Optional[str], + quantized_weight: bool = False, ) -> None: """Backward dgrad GEMM + add""" @@ -1646,21 +1802,10 @@ def test_backward_linear_add( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - if fp8_input or fp8_weight or fp8_output or fp8_compute: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - if fp8_compute: - if ( - math.prod(in_shape[:-1]) % 16 != 0 - or in_features % 16 != 0 - or out_features % 16 != 0 - ): - pytest.skip("FP8 GEMMs require dims that are divisible by 16") - if fp8_output and not fp8_compute: - pytest.skip("FP8 output requires FP8 compute") - if fp8_compute and dtype not in (torch.float16, torch.bfloat16): + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") # Random data @@ -1668,13 +1813,16 @@ def test_backward_linear_add( in_shape, test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_input), + test_is_fp8=quantized_compute, ) + if isinstance(x_test, QuantizedTensor): + with torch.no_grad(): + x_test = x_test.dequantize().requires_grad_() w_ref, w_test = make_reference_and_test_tensors( (out_features, in_features), test_dtype=dtype, test_device=device, - test_is_fp8=(fp8_compute or fp8_weight), + test_is_fp8=(quantized_compute or quantized_weight), ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, @@ -1695,7 +1843,8 @@ def test_backward_linear_add( (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() # Implementation with fusible operations - with te.fp8_model_init(enabled=fp8_weight): + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight): model = te_ops.Sequential( te_ops.MakeExtraOutput(), te_ops.Linear( @@ -1709,7 +1858,7 @@ def test_backward_linear_add( with torch.no_grad(): model[1].weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=fp8_compute): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): y1_test, y2_test = model(x_test) (y1_test * dy1_test + y2_test * dy2_test).sum().backward() @@ -1722,12 +1871,8 @@ def test_backward_linear_add( tols = dtype_tols(dtype) if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM - if fp8_compute: - tols = dtype_tols( - model[1].weight._fp8_dtype - if is_float8_tensor(model[1].weight) - else tex.DType.kFloat8E4M3 - ) + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results y1_test = y1_test.to(dtype=torch.float64, device="cpu") diff --git a/tests/pytorch/test_gqa.py b/tests/pytorch/test_gqa.py index 9f9098891f..3ef4806182 100644 --- a/tests/pytorch/test_gqa.py +++ b/tests/pytorch/test_gqa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_jit.py b/tests/pytorch/test_jit.py index 7d69e03712..ec62fba9d9 100644 --- a/tests/pytorch/test_jit.py +++ b/tests/pytorch/test_jit.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index 216b200e09..ecc06c3ace 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index c237dbaeb6..b364b01140 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1,7 +1,8 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +from collections import OrderedDict import math import os from typing import Dict, List, Optional @@ -13,7 +14,11 @@ import torch.nn as nn from torch.nn import Parameter -from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager, fp8_model_init +from transformer_engine.pytorch.fp8 import ( + FP8GlobalStateManager, + fp8_autocast, + fp8_model_init, +) from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, @@ -30,18 +35,21 @@ RMSNorm, TransformerLayer, LayerNorm, - InferenceParams, Fp8Padding, Fp8Unpadding, ) +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, fp8_grouped_gemm, gemm, grouped_gemm +from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.common import recipe import transformer_engine_torch as tex -# Only run FP8 tests on H100. +# Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -52,6 +60,8 @@ _cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() +torch._dynamo.config.recompile_limit = 16 + class ModelConfig: def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len): @@ -70,9 +80,9 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq model_configs_inference = { # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 16), + "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256), } -backends_inference = ["FlashAttention", "UnfusedAttention"] +backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"] module_inference = ["TransformerLayer", "MultiheadAttention"] input_formats_inference = ["sbhd", "bshd"] @@ -90,6 +100,12 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq mask_types = ["causal", "no_mask"] +fp8_recipes = [ + recipe.MXFP8BlockScaling(), + recipe.DelayedScaling(), + recipe.Float8CurrentScaling(), +] + def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -450,7 +466,8 @@ def __init__( self.fc2 = nn.Linear(ffn_hidden_size, hidden_size) def forward(self, x): - return self.fc2(self.gelu(self.fc1(self.ln(x)))) + t = self.gelu(self.fc1(self.ln(x))) + return self.fc2(t) class TorchGPT(nn.Module): @@ -480,7 +497,9 @@ def forward( return x -def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False): +def _test_e2e_selective_recompute( + bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False +): reset_rng_states() FP8GlobalStateManager.reset() @@ -488,7 +507,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -515,7 +534,7 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): te_out = block( te_inp_hidden_states, attention_mask=te_inp_attn_mask, @@ -536,18 +555,21 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) -def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_params): +def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) config = model_configs[model] outputs = _test_e2e_selective_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=False + bs, dtype, config, fp8, recipe, fp8_model_params, recompute=False ) outputs_recompute = _test_e2e_selective_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=True + bs, dtype, config, fp8, recipe, fp8_model_params, recompute=True ) # Check that results match @@ -556,6 +578,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par tols["atol"] = 1e-4 if fp8 or fp8_model_params: tols.update(dict(rtol=0.125, atol=0.0675)) + for i, (ref, test) in enumerate(zip(outputs, outputs_recompute)): torch.testing.assert_close( test, @@ -566,7 +589,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par def _test_e2e_full_recompute( - bs, dtype, config, fp8, fp8_model_params=False, recompute=False, use_reentrant=True + bs, dtype, config, fp8, recipe, fp8_model_params=False, recompute=False, use_reentrant=True ): reset_rng_states() FP8GlobalStateManager.reset() @@ -575,7 +598,7 @@ def _test_e2e_full_recompute( init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -603,7 +626,7 @@ def _test_e2e_full_recompute( te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if recompute: te_out = te_checkpoint( block, @@ -641,11 +664,18 @@ def _test_e2e_full_recompute( @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean) -def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, use_reentrant): +def test_gpt_full_activation_recompute( + dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant +): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8 and recipe.float8_current_scaling(): + pytest.skip("Float8 Current Scaling unsupported for full recompute.") config = model_configs[model] @@ -654,10 +684,24 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params, os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0" outputs, names = _test_e2e_full_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=False, use_reentrant=use_reentrant + bs, + dtype, + config, + fp8, + recipe, + fp8_model_params, + recompute=False, + use_reentrant=use_reentrant, ) outputs_recompute, _ = _test_e2e_full_recompute( - bs, dtype, config, fp8, fp8_model_params, recompute=True, use_reentrant=use_reentrant + bs, + dtype, + config, + fp8, + recipe, + fp8_model_params, + recompute=True, + use_reentrant=use_reentrant, ) if not use_reentrant: @@ -741,7 +785,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= del block block = _test_e2e_checkpointing_get_model(config, dtype) - block.load_state_dict(torch.load(path)) + block.load_state_dict(torch.load(path, weights_only=False)) reset_rng_states() for p in block.parameters(): @@ -1267,9 +1311,14 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere torch.half: 2e-3, torch.bfloat16: 2e-2, } + rtol = { + torch.float32: 1e-3, + torch.half: 4e-2, + torch.bfloat16: 4e-2, + } # Check output. - assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype]) if model == "small": atol = { @@ -1335,8 +1384,14 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): torch.bfloat16: 5e-2, } + rtol = { + torch.float32: 1e-3, + torch.half: 4e-2, + torch.bfloat16: 4e-2, + } + # Check output. - assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype]) + assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype], rtol[dtype]) # Check gradients, only for small model rtol = { @@ -1351,7 +1406,9 @@ def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization): assert_allclose(te_output, torch_output, atol[dtype], rtol[dtype]) -def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): +def _test_grouped_linear_accuracy( + block, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation +): reset_rng_states() if fp8: FP8GlobalStateManager.reset() @@ -1365,16 +1422,22 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False inp_hidden_states.retain_grad() if num_gemms > 1: - m = config.seq_len // 16 + split_size = 1 + if fp8: + if recipe.delayed(): + split_size = 16 + if recipe.mxfp8(): + split_size = 128 + m = config.seq_len // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add a zero m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - m_splits = m_splits * 16 + m_splits = m_splits * split_size assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms else: m_splits = torch.tensor([config.seq_len]) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, GroupedLinear): m_splits = m_splits * bs out = block(inp_hidden_states, m_splits.tolist()) @@ -1392,7 +1455,11 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False outputs = [out, inp_hidden_states.grad] for p in block.parameters(): if p.requires_grad: - outputs.append(p.grad) + if getattr(p, "main_grad", None) is not None: + outputs.append(p.main_grad) + assert p.grad is None # grad should be None if fuse_wgrad_accumulation is True + else: + outputs.append(p.grad) return outputs @@ -1401,18 +1468,34 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) def test_grouped_linear_accuracy( - dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None + dtype, + num_gemms, + bs, + model, + fp8, + recipe, + fp8_model_params, + fuse_wgrad_accumulation, + parallel_mode=None, ): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches + pytest.skip("MXFP8 unsupported for grouped linear.") + if fp8 and recipe.float8_current_scaling(): + pytest.skip("Float8 Current Scaling unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -1421,6 +1504,7 @@ def test_grouped_linear_accuracy( params_dtype=dtype, parallel_mode=parallel_mode, device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() sequential_linear = torch.nn.ModuleList( [ @@ -1431,6 +1515,7 @@ def test_grouped_linear_accuracy( params_dtype=dtype, parallel_mode=parallel_mode, device="cuda", + fuse_wgrad_accumulation=fuse_wgrad_accumulation, ).eval() for _ in range(num_gemms) ] @@ -1441,10 +1526,16 @@ def test_grouped_linear_accuracy( for i in range(num_gemms): sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone()) sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone()) + if fuse_wgrad_accumulation: + weight_i = getattr(grouped_linear, f"weight{i}") + weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32) + sequential_linear[i].weight.main_grad = weight_i.main_grad.clone() - outputs = _test_grouped_linear_accuracy(grouped_linear, num_gemms, bs, dtype, config, fp8) outputs_ref = _test_grouped_linear_accuracy( - sequential_linear, num_gemms, bs, dtype, config, fp8 + sequential_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation + ) + outputs = _test_grouped_linear_accuracy( + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8, fuse_wgrad_accumulation ) # Shoule be bit-wise match @@ -1453,7 +1544,8 @@ def test_grouped_linear_accuracy( @pytest.mark.parametrize("parallel_mode", ["column", "row"]) -def test_grouped_linear_accuracy_parallel_mode(parallel_mode): +@pytest.mark.parametrize("recipe", fp8_recipes) +def test_grouped_linear_accuracy_parallel_mode(parallel_mode, recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( dtype=torch.float32, @@ -1461,12 +1553,15 @@ def test_grouped_linear_accuracy_parallel_mode(parallel_mode): bs=2, model="126m", fp8=True, + recipe=recipe, fp8_model_params=True, parallel_mode=parallel_mode, + fuse_wgrad_accumulation=True, ) -def test_grouped_linear_accuracy_single_gemm(): +@pytest.mark.parametrize("recipe", fp8_recipes) +def test_grouped_linear_accuracy_single_gemm(recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( dtype=torch.float32, @@ -1474,11 +1569,13 @@ def test_grouped_linear_accuracy_single_gemm(): bs=2, model="126m", fp8=True, + recipe=recipe, fp8_model_params=True, + fuse_wgrad_accumulation=True, ) -def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, fp8=False): +def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, fp8=False): def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): """Padding tensor shapes to multiples of 16.""" @@ -1546,7 +1643,7 @@ def _generate_random_numbers(n, total_sum): m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) - with fp8_autocast(enabled=fp8): + with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, TorchGroupedLinearWithPadding): out = block(inp_hidden_states, m_splits) else: @@ -1575,18 +1672,25 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) +@pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( - dtype, num_gemms, bs, model, fp8, fp8_model_params, parallel_mode=None + dtype, num_gemms, bs, model, fp8, recipe, fp8_model_params, parallel_mode=None ): if fp8 and not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches + pytest.skip("MXFP8 unsupported for grouped linear.") + if fp8 and recipe.float8_current_scaling(): + pytest.skip("Float8 Current Scaling unsupported for grouped linear.") config = model_configs[model] if config.seq_len % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, config.hidden_size, @@ -1597,7 +1701,7 @@ def test_padding_grouped_linear_accuracy( fp8=fp8, ).eval() - with fp8_model_init(enabled=fp8 and fp8_model_params): + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): ref_grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -1619,10 +1723,10 @@ def test_padding_grouped_linear_accuracy( ) outputs = _test_padding_grouped_linear_accuracy( - grouped_linear, num_gemms, bs, dtype, config, fp8 + grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 ) outputs_ref = _test_padding_grouped_linear_accuracy( - ref_grouped_linear, num_gemms, bs, dtype, config, fp8 + ref_grouped_linear, num_gemms, bs, dtype, config, recipe, fp8 ) # Shoule be bit-wise match @@ -1734,7 +1838,7 @@ def test_gpt_cuda_graph(dtype, bs, model): assert_allclose(grads, graphed_grads, 1e-3) -def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): +def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): reset_rng_states() FP8GlobalStateManager.reset() @@ -1742,7 +1846,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8_model_params): + with fp8_model_init(enabled=fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -1769,7 +1873,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.seq_len) - with fp8_autocast(enabled=True): + with fp8_autocast(enabled=True, fp8_recipe=recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() @@ -1785,14 +1889,17 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -def test_gpt_fp8_parameters(dtype, bs, model): +@pytest.mark.parametrize("recipe", fp8_recipes) +def test_gpt_fp8_parameters(dtype, bs, model, recipe): if not fp8_available: pytest.skip(reason_for_no_fp8) + if recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) config = model_configs[model] - outputs = _test_gpt_fp8_parameters(bs, dtype, config, False) - outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True) + outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe) + outputs_fp8_params = _test_gpt_fp8_parameters(bs, dtype, config, True, recipe) # Check that results match tols = dict(rtol=0.125, atol=0.0675) @@ -1935,14 +2042,25 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): @pytest.mark.parametrize("input_format", input_formats_inference) @pytest.mark.parametrize("module", module_inference) @pytest.mark.parametrize("backend", backends_inference) -def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend): +@pytest.mark.parametrize("is_paged", [False, True]) +def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged): + reset_rng_states() + + if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32: + pytest.skip("FusedAttention and FlashAttention do not support FP32") + if use_RoPE: + pytest.skip("KV cache does not support starting positions for RoPE") + os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" if backend == "FlashAttention": os.environ["NVTE_FLASH_ATTN"] = "1" elif backend == "FusedAttention": os.environ["NVTE_FUSED_ATTN"] = "1" + elif backend == "UnfusedAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" config = model_configs_inference[model_key] @@ -1955,7 +2073,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, # Limits the max size of KV-cache B_max = B - S_max = S + 2 + S_max = S if module == "TransformerLayer": model = TransformerLayer( @@ -1985,7 +2103,17 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, .eval() ) - inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max) + inference_params = InferenceParams( + max_batch_size=B_max, + max_seqlen_kv=S_max, + num_heads_kv=H, + head_dim_k=head_size, + dtype=dtype, + is_paged=is_paged, + total_num_pages=int(B_max * S_max / 256), + page_size=256, + ) + rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda") input = torch.randn((S, B, D), dtype=dtype, device="cuda") @@ -1998,22 +2126,39 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None) # Incrementaly generate outputs using KV-cache + step_dict = OrderedDict(zip(list(range(B)), [1] * B)) for i in range(S): + inference_params.pre_step(step_dict) + if input_format == "sbhd": incremental_input = input[i].view(1, B, D) else: incremental_input = input[:, i, :].view(B, 1, D) + seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda") + cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda") + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_kv = cu_seqlens_q.clone() + + mask_type = "padding" + kwargs = {} + if module == "TransformerLayer": + kwargs["self_attn_mask_type"] = mask_type + else: + kwargs["attn_mask_type"] = mask_type line_output = model( hidden_states=incremental_input, inference_params=inference_params, rotary_pos_emb=rotary_freqs if use_RoPE else None, + **kwargs, + max_seqlen_q=1, + max_seqlen_kv=S, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, ) - inference_params.sequence_len_offset += 1 - if input_format == "sbhd": - incremental_output[i] = line_output.view(B, D) + incremental_output[i, :, :] = line_output.view(B, D) else: incremental_output[:, i, :] = line_output.view(B, D) @@ -2057,42 +2202,55 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): if layout == "TN": A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input - out = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # output + B = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input + out = [torch.randn(m, n, dtype=dtype, device="cuda")] # output + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] grad = False + single_output = True elif layout == "NN": A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight - B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # grad_output - out = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # dgrad + B = list( + torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) + ) # grad_output + out = [torch.randn(m, k, dtype=dtype, device="cuda")] # dgrad + out_ref = [o.clone() for o in torch.split(out[0], m_splits)] grad = True + single_output = True else: # layout == "NT" - A = torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits) # input - B = torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) # grad_output + A = list(torch.split(torch.randn(m, k, dtype=dtype, device="cuda"), m_splits)) # input + B = list( + torch.split(torch.randn(m, n, dtype=dtype, device="cuda"), m_splits) + ) # grad_output out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + out_ref = [o.clone() for o in out] grad = True + single_output = False - out_ref = [o.clone() for o in out] for i in range(z): - gemm( + general_gemm( A[i], B[i], - dtype, get_workspace(), + dtype, grad=grad, accumulate=accumulate, layout=layout, out=out_ref[i], ) + if single_output: + out_ref = [torch.cat(out_ref)] - grouped_gemm( + general_grouped_gemm( A, B, out, dtype, get_multi_stream_cublas_workspace(), + m_splits=m_splits, grad=grad, accumulate=accumulate, layout=layout, + single_output=single_output, ) # should be bit-wise match @@ -2115,7 +2273,7 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): pytest.skip(reason_for_no_fp8) z, m, k, n = shape - m_splits = m // z + m_splits = [m // z] * z dtype = torch.bfloat16 A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight @@ -2124,64 +2282,52 @@ def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): out_ref = [o.clone() for o in out] # fp8 should be robust enough to this fake scale - scale = 1 + torch.rand(z * 3, dtype=torch.float32, device="cuda") - scale_inv = 1 / scale - amax = torch.zeros(1024, z * 3, dtype=torch.float32, device="cuda") + scale = 1 + torch.rand(1, dtype=torch.float32, device="cuda").squeeze() + amax = torch.zeros(1, 1, dtype=torch.float32, device="cuda") - A_fp8 = [ - torch.ops.tex_ts.cast_to_fp8_ts( - A[i], - scale, - amax, - scale_inv, - i, # fp8 meta tensor index + a_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), tex.DType.kFloat8E4M3, ) - for i in range(z) + for _ in range(z) ] - B_fp8 = [ - torch.ops.tex_ts.cast_to_fp8_ts( - B[i], - scale, - amax, - scale_inv, - z + i, # fp8 meta tensor index - fp8_dtype, + b_quantizers = [ + Float8Quantizer( + scale.clone(), + amax.clone(), + tex.DType.kFloat8E4M3, ) - for i in range(z) + for _ in range(z) ] - fp8_grouped_gemm( - A_fp8, - [scale_inv], - 0, # A_offset - tex.DType.kFloat8E4M3, - B_fp8, - scale_inv, - z, # B_offset - fp8_dtype, - out, - dtype, - get_multi_stream_cublas_workspace(), - accumulate=accumulate, - ) + A_fp8 = [] + B_fp8 = [] + + for i in range(z): + A_fp8.append(a_quantizers[i](A[i])) + B_fp8.append(b_quantizers[i](B[i])) # baseline for i in range(z): - fp8_gemm( + general_gemm( A_fp8[i], - scale_inv, - i, - tex.DType.kFloat8E4M3, B_fp8[i], - scale_inv, - z + i, - fp8_dtype, - dtype, get_workspace(), + dtype, out=out_ref[i], accumulate=accumulate, ) + general_grouped_gemm( + A_fp8, + B_fp8, + out, + dtype, + get_multi_stream_cublas_workspace(), + m_splits=m_splits, + accumulate=accumulate, + ) # should be bit-wise match for o, o_ref in zip(out, out_ref): diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py deleted file mode 100644 index 6a463b556a..0000000000 --- a/tests/pytorch/test_onnx_export.py +++ /dev/null @@ -1,1562 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -""" -This file contains tests for exporting TransformerEngine models to ONNX. - -The purpose of these tests is validation that TE models are converted to their correct ONNX -representation. Toward this end, each test captures the output of a TE module forward pass, -converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and -validate the output against TE's output. - -Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented -using custom ORT operations. - -To run many repetitive tests use pytest-loop: - $ python3 -m pip install pytest-loop - $ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm - -For reproducability use: torch.manual_seed(0) -""" - -import os -import tempfile -import pytest -import warnings -import numpy as np -import onnxruntime as ort -import torch -from torch import nn as nn -from typing import Optional, Union, Tuple, List -import transformer_engine.pytorch as te -from transformer_engine.common import recipe -import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import ( - gemm, - fp8_gemm, - gelu, - cast_to_fp8, - cast_from_fp8, -) -from transformer_engine.pytorch.module.base import get_workspace -import transformer_engine.pytorch.cpp_extensions as texcpp -import transformer_engine.pytorch.softmax as softmax_defs -from transformer_engine.pytorch.utils import get_default_init_method -from transformer_engine.pytorch.export import is_in_onnx_export_mode -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager - -# Global test configuration knobs. - -# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance). -SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0"))) - -if SAVE_TEST_IO: - from polygraphy.json import save_json - from polygraphy.comparator import RunResults - -# The directory where generated ONNX test models are stored. -NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR") -NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join( - tempfile.gettempdir(), "./gen_onnx_models" -) - - -# The directory where this file is stored. -TESTS_DIR = os.path.dirname(os.path.abspath(__file__)) - -# ScaledUpperTriangMaskedSoftmax is exported via ONNX::Trilu which was introduced in opset 14. -TRILU_OPSET = 14 -# Opset used in the ONNX files generated by the tests. -OPSET = 17 -assert OPSET >= TRILU_OPSET - -# Shared library implementing custom FP8 Q/DQ operators for ONNX Runtime (ORT). -ORT_CUSTOM_OPS_LIB = os.path.join(TESTS_DIR, "custom_ort_ops", "libcustom_ort_ops.so") - -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) - -supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] - -all_normalizations = ["LayerNorm", "RMSNorm"] - - -@pytest.fixture() -def seed_default_rng(): - """Reseed the PRNG for test reproducibility""" - torch.manual_seed(1234) - - -@pytest.fixture() -def set_max_seq_len(max_seq_len=128): - """Set the maximum sequence length that can be used for attention masking""" - os.environ["NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"] = f"{max_seq_len}" - - -@pytest.fixture(autouse=True) -def reset_global_fp8_state(): - yield - FP8GlobalStateManager.reset() - - -def create_fp8_recipe(): - return recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3) - - -def do_export( - model: torch.nn.Module, - inp: torch.Tensor, - fname: str, - use_fp8: bool = True, - opset: int = OPSET, - input_names: List[str] = None, - output_names: List[str] = None, - dynamic_axes: List[str] = None, -): - """Export to ONNX""" - fp8_recipe = create_fp8_recipe() - input_names = input_names or ["input"] - output_names = output_names or ["output"] - - with torch.inference_mode(), te.fp8_autocast( - enabled=use_fp8, fp8_recipe=fp8_recipe - ), warnings.catch_warnings(): - warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*") - - model.cuda().eval() - os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True) - fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) - - inps = inp if isinstance(inp, list) or isinstance(inp, tuple) else (inp,) - assert len(inps) == len(input_names) - inds_to_del = [i for i in range(len(inps)) if inps[i] is None] - input_names = [input_names[i] for i in range(len(inps)) if i not in inds_to_del] - - with te.onnx_export(True): - torch.onnx.export( - model, - inps, - fname, - verbose=True, - dynamic_axes=dynamic_axes, - opset_version=opset, - input_names=input_names, - output_names=output_names, - do_constant_folding=True, - operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH, - ) - - -def to_numpy(tensor): - if isinstance(tensor, torch.Tensor): - if tensor.dtype == torch.bfloat16: - tensor = tensor.type(torch.float32) - tensor = tensor.detach().cpu().numpy() - return tensor - - -def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int): - """Initialize the FP8 quantization scales in module""" - NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors. - nb_total_scales = num_gemms * NB_SCALES_PER_GEMM - module.init_fp8_metadata(num_gemms) - module.fp8_meta["scaling_fwd"].scale = ( - torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") / scale - ) - module.fp8_meta["scaling_fwd"].scale_inv = ( - torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") * scale - ) - - -def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool): - """Transformer Engine forward propagation.""" - fp8_recipe = create_fp8_recipe() - with torch.inference_mode(), te.fp8_autocast( - enabled=is_fp8, fp8_recipe=fp8_recipe - ), warnings.catch_warnings(): - te_outputs = model(*inps if isinstance(inps, tuple) else (inps,)) - if not isinstance(te_outputs, tuple): - te_outputs = (te_outputs,) - return te_outputs - - -def compare_outputs( - onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname -): - """Compare ORT and TE outputs.""" - assert len(onnx_outputs) == len(te_outputs) - # Compare ORT and PyTorch outputs. - for onnx_output, te_output in zip(onnx_outputs, te_outputs): - # np.isclose: abs(a - b) <= (atol + rtol * abs(b)) - te_output = to_numpy(te_output) - onnx_output = to_numpy(onnx_output) - ac = ~np.isclose(onnx_output, te_output, atol=atol, rtol=rtol) - mismatches = ac.nonzero() - mismatched_ids = [loc for loc in zip(*mismatches)] - if mismatched_ids: - # Log some information in case of error. - print("*" * 100) - nb_errors = len(mismatched_ids) - nb_vals = min(nb_errors, max_errors_printed) - print(f"Detected {nb_errors} diverging values (output shape={onnx_output.shape})") - print(f"Showing first {nb_vals} errors (ONNX -- TE):") - abs_err = np.abs(onnx_output - te_output) - errors = abs_err[mismatches] - for loc in mismatched_ids[:nb_vals]: - ref = te_output[loc] - print( - f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >" - f" {atol + rtol * abs(ref)}" - ) - print(f"Max error: {np.max(errors)}") - if nb_errors > allow_cnt_errors: - raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors") - - -def serialize_inputs_outputs( - fname: str, - inputs: Union[Tuple[torch.Tensor], torch.Tensor], - te_outputs: List[torch.Tensor], - input_names: Optional[List[str]] = None, - output_names: Optional[List[str]] = None, -): - if not SAVE_TEST_IO: - return - - fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) - - input_names = input_names or ["input"] - output_names = output_names or ["output"] - inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,) - named_inputs = zip(input_names, inputs) - input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}] - json_fname = fname[: -len(".onnx")] + "_inputs.json" - save_json(input_data, json_fname, description="custom input data") - - json_fname = fname[: -len(".onnx")] + "_output.json" - named_outputs = zip(output_names, te_outputs) - output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None} - custom_outputs = RunResults() - custom_outputs.add([output_data], runner_name="custom_runner") - custom_outputs.save(json_fname) - - -def validate_result( - fname: str, - inps: Union[Tuple[torch.Tensor], torch.Tensor], - model: torch.nn.Module, - atol: float = 1.0e-8, # np.isclose default atol - rtol: float = 1.0e-5, # np.isclose default rtol - max_errors_printed: int = 10, - is_fp8: bool = False, - allow_cnt_errors: int = 0, - input_names: List[str] = None, - output_names: List[str] = None, - te_outputs: List[torch.Tensor] = None, -): - """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX - representation using ONNX Runtime (ORT) and ensure they are close. - - The purpose of the output comparison is to validate that TE models are converted to - their correct ONNX representation by testing that TE and ORT outputs match within some - small threshold (allowing for finite precision errors). - - Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring, - a very small number (0-3) of outliers. This is fine to do because these outliers are due to - small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX - representation (the tests assume both ORT or TE kernels are correct). - - Argument `te_outputs` can be used to provide pre-computed TE outputs. - """ - - def create_ort_session(fname: str, is_fp8: bool): - def load_custom_ops(session_opts: ort.SessionOptions): - """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension.""" - if not os.path.exists(ORT_CUSTOM_OPS_LIB): - raise FileNotFoundError(f"Unable to find {ORT_CUSTOM_OPS_LIB}") - session_opts.register_custom_ops_library(ORT_CUSTOM_OPS_LIB) - print("registered custom FP8 Q/DQ ops!") - - """Create an ONNX Runtime session for validation.""" - kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]} - if is_fp8: - sess_options = ort.SessionOptions() - load_custom_ops(sess_options) - kwargs["sess_options"] = sess_options - - s = ort.InferenceSession(fname, **kwargs) - return s - - def create_ort_input_dict(session, inputs): - inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,) - input_names = [x.name for x in session.get_inputs()] - inps = [to_numpy(x) for x in inputs if x is not None] - inp_dict = dict(zip(input_names, inps)) - return inp_dict - - input_names = input_names or ["input"] - output_names = output_names or ["output"] - - # Run ORT session and TE model. - fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname) - if not te_outputs: - te_outputs = te_infer(model, inps, is_fp8) - ort_s = create_ort_session(fname, is_fp8) - input_feed = create_ort_input_dict(ort_s, inps) - onnx_outputs = ort_s.run(None, input_feed=input_feed) - compare_outputs( - onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname - ) - - -def create_meta(scale_factor: float, size: int = 1): - meta = tex.FP8TensorMeta() - meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") - meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor - meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor - return meta - - -def dtype2str(dtype: torch.dtype, fake_bf16_io=False): - if fake_bf16_io: - assert dtype == torch.bfloat16 - return "_fake_bf16" - return { - torch.float32: "_fp32", - torch.float16: "_fp16", - torch.bfloat16: "_bf16", - }[dtype] - - -def as_te_type(dtype: torch.dtype): - return { - torch.float32: tex.DType.kFloat32, - torch.float16: tex.DType.kFloat16, - torch.bfloat16: tex.DType.kBFloat16, - }[dtype] - - -def get_attn_mask_str(use_mask, attn_mask_type): - # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names. - if attn_mask_type is None: - return "_mask" if use_mask else "_no-mask" - attn_mask_str = "_arbitrary-no-mask" - attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str - attn_mask_str = ( - "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str - ) - return attn_mask_str - - -class FP8GemmModule(nn.Module): - def __init__(self, precision, use_bias, gelu, scale_factors, hidden_size, out_features): - super().__init__() - self.use_bias = use_bias - self.gelu = gelu - self.precision = precision - - self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT - self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT - nb_inp_scales, nb_weight_scales = 1, out_features - act_scale_factor, weight_scale_factor = scale_factors - self.meta_inp = create_meta(act_scale_factor, nb_inp_scales) - self.meta_weight = create_meta(weight_scale_factor, nb_weight_scales) - - bias_size = nb_weight_scales - self.bias = torch.randn(bias_size, dtype=precision, device="cuda") - self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda") - - self.inp_type = tex.DType.kFloat8E4M3 - self.weights_type = tex.DType.kFloat8E4M3 - self.outp_type = precision - - def forward(self, inp, weight): - inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type) - - weight_fp8 = cast_to_fp8( - weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type - ) - - ret, _ = fp8_gemm( - weight_fp8, - self.meta_weight.scale_inv, - self.fp8_tensor_weight, - self.inp_type, - inp_fp8, - self.meta_inp.scale_inv, - self.fp8_tensor_inp, - self.weights_type, - self.outp_type, - get_workspace(), - bias=self.bias, - use_bias=self.use_bias, - use_split_accumulator=False, - ) - return ret - - -""" -Tests cases begin here. -""" - - -@skip_FP8 -@pytest.mark.parametrize("scale_factor", [1, 224]) -@pytest.mark.parametrize( - "precision, atol", - [ - [torch.float32, 1e-7], - [torch.float16, 1e-7], - [torch.bfloat16, 5e-3], - ["fake-torch.bfloat16", 5e-3], - ], -) -def test_export_cast_ops( - seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype -): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - class TestFP8_QDQ(nn.Module): - def __init__(self, fake_bf16_io): - super().__init__() - self.fp8_tensor = 0 - self.meta = create_meta(scale_factor) - self.highprec_type = as_te_type(precision) - self.fp8_type = tex.DType.kFloat8E4M3 - self.fake_bf16_io = fake_bf16_io - - def forward(self, inp): - ret = cast_to_fp8(inp, self.meta, self.fp8_tensor, self.fp8_type) - - ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type) - if self.fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - # Set dimensions (these are arbitrary). - in_features = 64 - hidden_size = 256 - inp = torch.randn( - hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision - ) - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx" - model = TestFP8_QDQ(fake_bf16_io) - - do_export(model, inp, fname) - te_outputs = te_infer(model, inp, is_fp8=True) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs) - - -@skip_FP8 -@pytest.mark.parametrize("scale_factor", [448]) -@pytest.mark.parametrize( - "precision, atol", - [ - [torch.float32, 1e-5], - [torch.float16, 1e-5], - [torch.bfloat16, 5e-3], - ["fake-torch.bfloat16", 5e-3], - ], -) -def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - class TestFP8_Gelu(nn.Module): - def __init__(self, fake_bf16_io): - super().__init__() - self.fp8_tensor = 0 - self.meta = create_meta(scale_factor) - self.highprec_type = as_te_type(precision) - self.fp8_type = tex.DType.kFloat8E4M3 - self.fake_bf16_io = fake_bf16_io - - def forward(self, inp): - ret = gelu(inp, self.meta, self.fp8_tensor, self.fp8_type) - ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type) - if self.fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - # Set dimensions (these are arbitrary). - in_features = 64 - hidden_size = 256 - inp = torch.randn( - hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision - ) - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx" - model = TestFP8_Gelu(fake_bf16_io) - do_export(model, inp, fname) - te_outputs = te_infer(model, inp, is_fp8=True) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result( - fname, - inp, - model, - rtol=0, - atol=atol, - is_fp8=True, - allow_cnt_errors=2, - te_outputs=te_outputs, - ) - - -@pytest.mark.parametrize( - "scale_factors", - [ - ( - 224, - 224, - ), - ], -) -@pytest.mark.parametrize( - "precision, use_fp8, use_bias, use_gelu", - [ - (torch.float32, False, False, False), - (torch.float16, False, False, False), - (torch.bfloat16, False, False, False), - (torch.float32, False, True, False), - (torch.float16, False, True, False), - (torch.bfloat16, False, True, False), - (torch.float32, False, True, True), - (torch.float16, False, True, True), - (torch.bfloat16, False, True, True), - # For FP8 GEMM GeLU is not used. - (torch.float32, True, False, False), - (torch.float16, True, False, False), - (torch.bfloat16, True, False, False), - # When enabling bias we must use float16 or bfloat16 (because of kernel limitations) - (torch.float16, True, True, False), - (torch.bfloat16, True, True, False), - ], -) -def test_export_gemm( - seed_default_rng, - precision, # Precision of inputs, weights, output and bias - use_fp8, - use_bias, - use_gelu, - scale_factors, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - class Test_GEMM(nn.Module): - def __init__(self, precision, use_bias=False, gelu=False): - super().__init__() - self.use_bias = use_bias - self.gelu = gelu - self.precision = precision - bias_size = out_features - self.bias = torch.randn(bias_size, dtype=precision, device="cuda") - self.gelu_input = torch.randn(hidden_size, out_features, dtype=precision, device="cuda") - - def forward(self, inp, weight): - outp_type = self.precision - - # note: due to logic in lines 104:116 and L129 in cpp_extensions.py - # it appears either bias OR gelu can be activated, not both - ret, _, _ = gemm( - weight, - inp, - outp_type, - get_workspace(), - # test bias - bias=self.bias, - use_bias=self.use_bias, - # test gelu - gelu=self.gelu, - gelu_input=self.gelu_input, - grad=False, # only True for backward pass - accumulate=False, - ) - return ret - - # If gelu is applied then bias must be added, as defined by TE kernel. - if use_gelu: - assert use_bias - # Set dimensions (these are arbitrary). - out_features = 128 - hidden_size = 256 - in_features = 64 - inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision) - weight = torch.randn(out_features, in_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - gelu_str = "_gelu" if use_gelu else "" - high_prec_str = dtype2str(precision) - fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx" - input_names = ["input", "weight"] - if use_fp8: - model = FP8GemmModule( - precision, use_bias, use_gelu, scale_factors, hidden_size, out_features - ) - do_export(model, (inp, weight), fname, use_fp8, input_names=input_names) - te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) - serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) - if precision != torch.bfloat16: - validate_result( - fname, - (inp, weight), - model, - rtol=1e-2, - atol=2e-2, - is_fp8=True, - input_names=input_names, - te_outputs=te_outputs, - ) - else: - model = Test_GEMM(precision, use_bias, use_gelu) - do_export(model, (inp, weight), fname, use_fp8, input_names=input_names) - te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) - serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) - if precision != torch.bfloat16: - validate_result( - fname, - (inp, weight), - model, - rtol=1e-2, - atol=2e-2, - input_names=input_names, - te_outputs=te_outputs, - ) - - -@pytest.mark.parametrize("scale_factor", [448, 112]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize( - "use_fp8, precision, atol", - [ - [False, torch.float32, 1e-7], - [False, torch.float16, 1e-7], - [False, torch.bfloat16, 1e-7], - [False, "fake-torch.bfloat16", 1e-7], - [True, torch.float32, 1e-7], - [True, torch.float16, 1e-7], - [True, torch.bfloat16, 1e-2], - [True, "fake-torch.bfloat16", 1e-2], - ], -) -def test_export_layernorm( - seed_default_rng, - use_fp8: bool, - scale_factor: float, - precision: torch.dtype, - zero_centered_gamma: bool, - atol: float, -): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). - inp_shape = [64, 32] - - class Test_Layernorm(nn.Module): - def __init__(self) -> None: - super().__init__() - eps = 1e-6 # An arbitrary small value - dtype = torch.float if fake_bf16_io else precision - self.ln = ( - te.LayerNorm( - inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma - ) - .eval() - .cuda() - ) - - def forward(self, inp): - ret = self.ln(inp) - return ret - - class TestFP8_Layernorm(nn.Module): - def __init__(self) -> None: - super().__init__() - normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn( - *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision - ) - self.bias = torch.zeros( - *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision - ) - self.eps = 1e-6 # An arbitrary small value - - self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT - self.meta = create_meta(scale_factor) - self.fp8_type = tex.DType.kFloat8E4M3 - - def forward(self, inp): - ret = texcpp.layernorm_fwd_fp8_inf( - inp, - self.weight, - self.bias, - self.eps, - self.meta, - self.fp8_tensor, - self.fp8_type, - 0, - zero_centered_gamma, - ) - - ret = cast_from_fp8( - ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision) - ) - if fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - inp = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision) - model = TestFP8_Layernorm() if use_fp8 else Test_Layernorm() - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fp8_str = f"_fp8-{scale_factor}" if use_fp8 else "" - fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx" - do_export(model, inp, fname, use_fp8=use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs - ) - - -@pytest.mark.parametrize("scale_factor", [448, 112]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize( - "use_fp8, precision, atol", - [ - [False, torch.float32, 1e-7], - [False, torch.float16, 1e-7], - [False, torch.bfloat16, 1e-7], - [False, "fake-torch.bfloat16", 1e-7], - [True, torch.float32, 1e-7], - [True, torch.float16, 1e-7], - [True, torch.bfloat16, 1e-2], - [True, "fake-torch.bfloat16", 1e-2], - ], -) -def test_export_rmsnorm( - seed_default_rng, - use_fp8: bool, - scale_factor: float, - precision: torch.dtype, - zero_centered_gamma: bool, - atol: float, -): - fake_bf16_io = precision == "fake-torch.bfloat16" - # reset precision to torch.bfloat16 after capturing fake BF16 mode - precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). - inp_shape = [64, 32] - - class Test_RMSnorm(nn.Module): - def __init__(self) -> None: - super().__init__() - eps = 1e-6 # An arbitrary small value - dtype = torch.float if fake_bf16_io else precision - self.ln = ( - te.RMSNorm( - inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma - ) - .eval() - .cuda() - ) - - def forward(self, inp): - ret = self.ln(inp) - return ret - - class TestFP8_RMSnorm(nn.Module): - def __init__(self) -> None: - super().__init__() - normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn( - *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision - ) - self.eps = 1e-6 # An arbitrary small value - - self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT - self.meta = create_meta(scale_factor) - self.fp8_type = tex.DType.kFloat8E4M3 - - def forward(self, inp): - ret = texcpp.rmsnorm_fwd_fp8_inf( - inp, - self.weight, - self.eps, - self.meta, - self.fp8_tensor, - self.fp8_type, - 0, - zero_centered_gamma, - ) - - ret = cast_from_fp8( - ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision) - ) - if fake_bf16_io: - ret = ret.type(torch.float32) - return ret - - inp = torch.randn(*inp_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision) - model = TestFP8_RMSnorm() if use_fp8 else Test_RMSnorm() - high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io) - fp8_str = f"_fp8-{scale_factor}" if use_fp8 else "" - fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx" - do_export(model, inp, fname, use_fp8=use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if fake_bf16_io or precision != torch.bfloat16: - validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs - ) - - -@pytest.mark.parametrize("scale_factor", [1]) -@pytest.mark.parametrize("use_fp8", [False, True]) -# Returning the bias is a TE fusion optimization we don't care about. -@pytest.mark.parametrize("return_bias", [False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, False), - (torch.float16, True), - # Todo: cannot configure BF16 when bias is disabled (ORT issue?) - (torch.bfloat16, False), - # Todo: cannot configure BF16 when bias is enabled (ORT issue?) - (torch.bfloat16, True), - ], -) -def test_export_linear( - seed_default_rng, - scale_factor: float, - use_fp8: bool, - use_bias: bool, - return_bias: bool, - precision: torch.dtype, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). - in_features = 64 - out_features = 256 - hidden_size = 256 - - class Test_Linear(nn.Module): - def __init__(self, in_features, out_features, use_bias, return_bias, precision): - super().__init__() - self.linear = te.Linear( - in_features, - out_features, - bias=use_bias, - return_bias=return_bias, - params_dtype=precision, - ) - - def forward(self, inp): - ret = self.linear(inp) - return ret - - inp = torch.randn(hidden_size, in_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - high_prec_str = dtype2str(precision) - fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx" - with te.fp8_autocast(enabled=use_fp8): - model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to( - device="cuda" - ) - if use_fp8: - set_layer_scale(model.linear, scale_factor, num_gemms=1) - do_export(model, inp, fname, use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - - if precision in (torch.bfloat16,): - return - if not use_fp8: - validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) - else: - validate_result(fname, inp, model, atol=1e-3, is_fp8=use_fp8, te_outputs=te_outputs) - - -@pytest.mark.parametrize("scale_factor", [112]) -@pytest.mark.parametrize("use_fp8", [False, True]) -# Returning the bias is a TE fusion optimization we don't care about. -@pytest.mark.parametrize("return_bias", [False]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm_linear( - seed_default_rng, - scale_factor: float, - use_fp8: bool, - use_bias: bool, - return_bias: bool, - return_layernorm_output: bool, - precision: torch.dtype, - zero_centered_gamma: bool, - normalization: str, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). - in_features = 64 - out_features = 256 - hidden_size = 256 - - inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - high_prec_str = dtype2str(precision) - fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx" - - with te.fp8_autocast(enabled=use_fp8): - model = te.LayerNormLinear( - hidden_size, - 3 * hidden_size, - bias=use_bias, - return_bias=return_bias, - return_layernorm_output=return_layernorm_output, - params_dtype=precision, - zero_centered_gamma=zero_centered_gamma, - normalization=normalization, - ).to(device="cuda") - if use_fp8: - set_layer_scale(model, scale_factor, num_gemms=1) - do_export(model, inp, fname, use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if precision in (torch.bfloat16,): - return - if not use_fp8: - validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs) - elif precision != torch.bfloat16: - validate_result(fname, inp, model, atol=1e-6, is_fp8=use_fp8, te_outputs=te_outputs) - - -@pytest.mark.parametrize("scale_factor", [112]) -@pytest.mark.parametrize("use_fp8", [False, True]) -# Returning the bias is a TE fusion optimization we don't care about. -@pytest.mark.parametrize("return_bias", [False]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("activation", supported_activations) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm_mlp( - seed_default_rng, - scale_factor: float, - use_fp8: bool, - use_bias: bool, - return_bias: bool, - return_layernorm_output: bool, - precision: torch.dtype, - zero_centered_gamma: bool, - activation: str, - normalization: str, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Set dimensions (these are arbitrary). - in_features = 64 - out_features = 256 - hidden_size = 256 - ffn_hidden_size = 256 - - inp = torch.randn(in_features, out_features, device="cuda", dtype=precision) - fp8_str = "_fp8" if use_fp8 else "" - bias_str = "_bias" if use_bias else "" - high_prec_str = dtype2str(precision) - fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx" - with te.fp8_autocast(enabled=use_fp8): - model = te.LayerNormMLP( - hidden_size, - ffn_hidden_size, - bias=use_bias, - return_bias=return_bias, - return_layernorm_output=return_layernorm_output, - params_dtype=precision, - zero_centered_gamma=zero_centered_gamma, - activation=activation, - normalization=normalization, - ).to(device="cuda") - if use_fp8: - set_layer_scale(model, scale_factor, num_gemms=2) - do_export(model, inp, fname, use_fp8) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs) - if precision in (torch.bfloat16,): - return - atol = 1e-6 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3) - validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs) - - -@skip_FP8 -@pytest.mark.parametrize( - "precision, use_mask, attn_mask_type", - [ - (torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) - (torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask) - (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask) - (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask) - ], -) -def test_export_core_attention( - seed_default_rng, - set_max_seq_len, - precision: torch.dtype, - use_mask: bool, - attn_mask_type: str, -): - # Set dimensions (these are arbitrary). - seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64) - qkv_size = (seq_len, batch_size, num_attention_heads, kv_channels) - qkv_format = "sbhd" - - query_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - key_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - value_layer = torch.randn(qkv_size, dtype=precision, device="cuda") - input_names = ["query", "key", "value", "attention_mask"] - attention_mask = None - if use_mask: - # Generate a random mask with 50% probability for 0 or 1. - probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision) - attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (query_layer, key_layer, value_layer, attention_mask) - - mask_str = get_attn_mask_str(use_mask, attn_mask_type) - high_prec_str = dtype2str(precision) - fname = f"te.core_attention{mask_str}{high_prec_str}.onnx" - - model = te.attention.DotProductAttention( - num_attention_heads=num_attention_heads, - kv_channels=kv_channels, - attention_dropout=0.5, - qkv_format=qkv_format, - attn_mask_type=attn_mask_type, - ).to(device="cuda") - do_export(model, inp, fname, input_names=input_names, use_fp8=True) - te_outputs = te_infer(model, inp, is_fp8=True) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision in (torch.bfloat16,): - return - validate_result( - fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs - ) - - -test_configs_multihead_attention = [ - # "use_mask, attn_mask_type" - (False, "no_mask"), # calls ScaledSoftmax - (True, "arbitrary"), # calls ScaledMaskedSoftmax -] -test_configs_attention_type = [ - # "input_layernorm, attention_type, fuse_qkv_params" - (True, "self", True), - (False, "self", True), - (True, "self", False), - (False, "self", False), - (True, "cross", True), - (False, "cross", True), - (True, "cross", False), - (False, "cross", False), -] - - -@pytest.mark.parametrize("use_fp8", [False, True]) -@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type -) -def test_export_multihead_attention( - seed_default_rng, - set_max_seq_len, - use_fp8: bool, - use_mask: bool, - attn_mask_type: str, - precision: torch.dtype, - return_layernorm_output: bool, - input_layernorm: bool, - attention_type: str, - fuse_qkv_params: bool, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - hidden_size = 256 - sequence_length = 128 - batch_size = 4 - num_attention_heads = 32 - kv_channels = 8 - attention_dropout = 0.1 - layernorm_epsilon = 1e-5 - init_method = output_layer_init_method = get_default_init_method() - attention_args = ( - hidden_size, - num_attention_heads, - kv_channels, - attention_dropout, - layernorm_epsilon, - init_method, - output_layer_init_method, - ) - - hidden_states_context = torch.randn( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - attention_mask = None - if use_mask and attn_mask_type != "causal": - # Generate a random mask with 50% probability for 0 or 1. - probs = 0.5 * torch.ones( - batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision - ) - attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - - encoder_output = None - - if attention_type == "cross": - encoder_output = torch.randn( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - - fp8_str = "_fp8" if use_fp8 else "" - dtype_str = dtype2str(precision) - attn_type_str = "_self-attention" if attention_type == "self" else "_cross-attention" - fuse_qkv_str = "_fused-qkv" if fuse_qkv_params else "" - attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) - input_ln_str = "_input-ln" if input_layernorm else "" - fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx" - - model = te.MultiheadAttention( - *attention_args, - attn_mask_type=attn_mask_type, - params_dtype=precision, - return_layernorm_output=return_layernorm_output, - input_layernorm=input_layernorm, - attention_type=attention_type, - fuse_qkv_params=fuse_qkv_params, - return_bias=True, - ).to(device="cuda") - - inp_context = (hidden_states_context, attention_mask, encoder_output) - input_names = ["hidden_states", "attention_mask", "encoder_output"] - output_names = ["attention_output", "attention_bias"] - do_export( - model, - inp_context, - fname, - use_fp8, - input_names=input_names, - output_names=output_names, - dynamic_axes={ - "hidden_states": {0: "seq", 1: "bs"}, - "attention_output": {0: "seq", 1: "bs"}, - }, - ) - te_outputs = te_infer(model, inp_context, is_fp8=use_fp8) - serialize_inputs_outputs( - fname, inp_context, te_outputs, input_names=input_names, output_names=output_names - ) - if precision in (torch.bfloat16,): - return - - if not use_fp8: - validate_result( - fname, - inp_context, - model, - atol=1e-3, - input_names=input_names, - output_names=output_names, - te_outputs=te_outputs, - ) - else: - validate_result( - fname, - inp_context, - model, - atol=1e-2, - is_fp8=use_fp8, - input_names=input_names, - output_names=output_names, - allow_cnt_errors=3, - te_outputs=te_outputs, - ) - - # In GPT generative phase (inference) the input sequence is smaller than the maximum - # allowed sequence length and we want to test this condition. - # Pretend that we're in generative phase when it makes sense (causal mask and self-attention). - is_generative_phase = attn_mask_type == "causal" and attention_type == "self" - if is_generative_phase: - seq_len_offset = 8 - hidden_states_generative = torch.randn( - sequence_length - seq_len_offset, - batch_size, - hidden_size, - dtype=precision, - device="cuda", - ) - inp_generative = (hidden_states_generative, attention_mask, encoder_output) - if not use_fp8: - validate_result( - fname, - inp_generative, - model, - atol=1e-3, - input_names=input_names, - output_names=output_names, - ) - else: - validate_result( - fname, - inp_generative, - model, - atol=1e-2, - is_fp8=use_fp8, - input_names=input_names, - output_names=output_names, - allow_cnt_errors=3, - ) - - -@pytest.mark.parametrize("use_fp8", [False, True]) -@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize( - "output_layernorm", - [ - # True, # TO DO: handle this - False - ], -) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("fuse_qkv_params", [False, True]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("activation", supported_activations) -def test_export_transformer_layer( - seed_default_rng, - set_max_seq_len, - use_fp8: bool, - use_mask: bool, - attn_mask_type: str, - output_layernorm: bool, - precision: torch.dtype, - fuse_qkv_params: bool, - zero_centered_gamma: bool, - activation: str, -): - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Layer configuration - hidden_size = 64 - sequence_length = 128 - batch_size = 1 - ffn_hidden_size = 256 - num_attention_heads = 4 - - input_tensor = torch.rand( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - input_names = ["input", "attention_mask"] - attention_mask = None - if use_mask and attn_mask_type != "causal": - # Generate a random mask with 50% probability for 0 or 1. - probs = 0.5 * torch.ones( - batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision - ) - attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool) - inp = (input_tensor, attention_mask) - - fp8_str = "_fp8" if use_fp8 else "" - fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" - high_prec_str = dtype2str(precision) - attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) - fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}_{activation}.onnx" - - model = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_attention_heads, - self_attn_mask_type=attn_mask_type, - output_layernorm=output_layernorm, - params_dtype=precision, - fuse_qkv_params=fuse_qkv_params, - zero_centered_gamma=zero_centered_gamma, - activation=activation, - ).to(device="cuda") - do_export(model, inp, fname, use_fp8, input_names=input_names) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision in (torch.bfloat16,): - return - atol = 5e-1 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3) - validate_result( - fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs - ) - - -@pytest.mark.parametrize("use_fp8", [True]) -@pytest.mark.parametrize("ln_scale_factor", [448 * 2]) -@pytest.mark.parametrize( - "gemm_scale_factors", - [ - ( - 224, - 224, - ), - ], -) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -def test_export_gemm_layernorm( - seed_default_rng, - use_fp8: bool, - ln_scale_factor: float, - gemm_scale_factors: Tuple[float, float], - precision: torch.dtype, - zero_centered_gamma: bool, -): - """This is a regression test for testing that all LN inputs have the same type. - - The test sets up GEMM with FP32 output which feeds into an LN that is configured - with FP16 or BF16 weights and bias. - """ - out_features = 128 - hidden_size = 128 - in_features = 128 - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - class TestFP8_GemmLayernorm(nn.Module): - def __init__(self) -> None: - super().__init__() - normalized_shape = torch.Size(inp.shape[1:]) - self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda") - self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda") - self.eps = 1e-6 # An arbitrary small value - - self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT - self.meta = create_meta(ln_scale_factor) - self.fp8_type = tex.DType.kFloat8E4M3 - self.gemm = FP8GemmModule( - precision, - use_bias=False, - gelu=False, - scale_factors=gemm_scale_factors, - hidden_size=hidden_size, - out_features=out_features, - ) - - def forward(self, inp, weight): - x = self.gemm(inp, weight) - x = texcpp.layernorm_fwd_fp8_inf( - x, - self.weight, - self.bias, - self.eps, - self.meta, - self.fp8_tensor, - self.fp8_type, - 0, - zero_centered_gamma, - ) - - x = cast_from_fp8( - x, - self.meta, - self.fp8_tensor, - self.fp8_type, - tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16, - ) - return x - - inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda") - weight = torch.randn(out_features, in_features, dtype=precision, device="cuda") - model = TestFP8_GemmLayernorm() - high_prec_str = dtype2str(precision) - fp8_str = f"_fp8" if use_fp8 else "" - fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx" - input_names = ["input", "weight"] - do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names) - te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8) - serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names) - if precision not in (torch.bfloat16,): - validate_result( - fname, - (inp, weight), - model, - atol=5e-2, - is_fp8=use_fp8, - allow_cnt_errors=2, - input_names=input_names, - te_outputs=te_outputs, - ) - - -@skip_FP8 -@pytest.mark.parametrize("use_fp8", [True, False]) -@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("zero_centered_gamma", [True]) -def test_export_gpt_generation( - seed_default_rng, - set_max_seq_len, - use_fp8: bool, - precision: torch.dtype, - zero_centered_gamma: bool, -): - """Test that the ONNX model can correctly handle inputs with different shapes and that - the attention mask it adjusted on-the-fly to different sequence lengths. - """ - - # Skip FP8 tests on non-hopper devices - if use_fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - - # Layer configuration - hidden_size = 64 - sequence_length = 128 - batch_size = 1 - ffn_hidden_size = 256 - num_attention_heads = 4 - attention_mask = None - use_mask = True - attn_mask_type = "causal" - fuse_qkv_params = True - output_layernorm = False - - fp8_str = "_fp8" if use_fp8 else "" - fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else "" - high_prec_str = dtype2str(precision) - attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type) - fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx" - - model = te.TransformerLayer( - hidden_size, - ffn_hidden_size, - num_attention_heads, - self_attn_mask_type=attn_mask_type, - output_layernorm=output_layernorm, - params_dtype=precision, - fuse_qkv_params=fuse_qkv_params, - zero_centered_gamma=zero_centered_gamma, - ).to(device="cuda") - - # "Context phase": use full input sequence length - input_names = ["input"] - output_names = ["output"] - input_tensor = torch.rand( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - inp = (input_tensor,) - do_export( - model, - inp, - fname, - use_fp8, - input_names=input_names, - output_names=output_names, - dynamic_axes={ - "input": {0: "seq", 1: "bs"}, - "output": {0: "seq", 1: "bs"}, - }, - ) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs( - fname, inp, te_outputs, input_names=input_names, output_names=output_names - ) - if precision not in (torch.bfloat16,): - validate_result( - fname, - inp, - model, - atol=6e-3, - is_fp8=use_fp8, - input_names=input_names, - te_outputs=te_outputs, - ) - - # "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8. - sequence_length = 1 if not use_fp8 else 8 - input_tensor = torch.rand( - sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" - ) - inp = (input_tensor, attention_mask) - te_outputs = te_infer(model, inp, is_fp8=use_fp8) - serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names) - if precision not in (torch.bfloat16,): - validate_result( - fname, - inp, - model, - atol=6e-3, - is_fp8=use_fp8, - input_names=input_names, - te_outputs=te_outputs, - ) - - -@pytest.mark.parametrize("enabled", [True, False]) -def test_export_ctx_manager(enabled): - assert is_in_onnx_export_mode() == False - with te.onnx_export(enabled): - assert is_in_onnx_export_mode() == enabled - assert is_in_onnx_export_mode() == False diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py new file mode 100644 index 0000000000..5e355dc989 --- /dev/null +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -0,0 +1,108 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import random +import pytest +import torch +from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy + + +class TestParallelCrossEntropy: + + def generate_iters(self, iters: int): + self.iters = iters + + def generate_infra(self, reduce_loss: bool, label_smoothing: float): + self.test_loss_func = parallel_cross_entropy + self.ref_loss_func = torch.nn.CrossEntropyLoss( + label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none" + ) + + def generate_input(self, dtype: torch.dtype, swap_dim: bool): + + SQ = random.choice([64, 128]) + batch = random.choice([1, 2]) + vocab = random.choice([64000, 128000]) + + if swap_dim: + self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda() + self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda() + else: + self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda() + self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda() + + self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab)) + self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,)) + + def one_iteration_test( + self, dtype: torch.dtype, swap_dim: bool, label_smoothing: float, reduce_loss: bool + ): + + self.generate_input(dtype, swap_dim) + + self.input_test.requires_grad_(True) + self.input_ref.requires_grad_(True) + + test_loss = self.test_loss_func( + self.input_test, self.tar_test, label_smoothing, reduce_loss, None + ) + if reduce_loss: + test_loss.backward() + + ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) + if reduce_loss: + ref_loss.backward() + + test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss + + torch.testing.assert_close(test_loss, ref_loss, check_dtype=False) + if reduce_loss: + torch.testing.assert_close( + torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad + ) + + self.input_test = None + self.input_ref = None + self.tar_test = None + self.tar_ref = None + + def test_float32_input(self): + self.generate_iters(5) + self.generate_infra(True, 0) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=True + ) + + def test_bfloat16_input(self): + self.generate_iters(5) + self.generate_infra(True, 0) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.bfloat16, swap_dim=False, label_smoothing=0, reduce_loss=True + ) + + def test_swapped_input(self): + self.generate_iters(5) + self.generate_infra(True, 0) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, swap_dim=True, label_smoothing=0, reduce_loss=True + ) + + def test_label_smoothing(self): + self.generate_iters(3) + self.generate_infra(True, 0.1) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, swap_dim=False, label_smoothing=0.1, reduce_loss=True + ) + + def test_non_reduced_loss(self): + self.generate_iters(1) + self.generate_infra(False, 0) + for i in range(self.iters): + self.one_iteration_test( + dtype=torch.float32, swap_dim=False, label_smoothing=0, reduce_loss=False + ) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index ed25b96955..0dc183e298 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -1,15 +1,23 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +import random + import torch import pytest from typing import Dict, List -from transformer_engine.pytorch import moe_permute as te_permute, moe_unpermute as te_unpermute +from transformer_engine.pytorch import ( + moe_permute as te_permute, + moe_permute_with_probs as te_permute_with_probs, + moe_unpermute as te_unpermute, + moe_sort_chunks_by_index as te_sort_chunks_by_index, + moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs, +) from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine_torch as tex @@ -18,7 +26,7 @@ torch.cuda.manual_seed(seed) -def pytorch_permute(tokens, indices, num_out_tokens: int = None): +def pytorch_permute_index_map(tokens, indices, num_out_tokens: int = None): """ Permute the tokens based on the indices. Token with the same index will be grouped together. The input indices shape is [tokens, top_k], it indicates which experts were selected by each token separately. @@ -50,7 +58,7 @@ def pytorch_permute(tokens, indices, num_out_tokens: int = None): return permuted_tokens, sorted_indices -def pytorch_unpermute( +def pytorch_unpermute_index_map( permuted_tokens: torch.Tensor, sorted_indices: torch.Tensor, probs: torch.Tensor = None, @@ -95,6 +103,86 @@ def pytorch_unpermute( return unpermuted_tokens +def pytorch_permute_mask_map(tokens, routing_map): + """Permute the tokens and probs based on the mask. + Tokens with the same designated expert will be grouped together. + The shape of mask is [tokens, num_experts], it indicates which experts were selected + by each token. + + Args: + tokens (torch.Tensor): The input token tensor, [num_tokens, hidden]. + routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts]. + """ + num_tokens, _ = tokens.shape + num_experts = routing_map.shape[1] + + # mask [num_tokens, num_experts] -> [num_experts, num_tokens] + routing_map = routing_map.bool().T.contiguous() + + # Create a dense expert-to-token mapping from the sparse token-to-expert mapping + token_indices = ( + torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1) + ) + sorted_indices = token_indices.masked_select(routing_map) + + # use the mapping to permute the tokens + permuted_input = tokens.index_select(0, sorted_indices) + + return permuted_input, sorted_indices + + +def pytorch_unpermute_mask_map( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + restore_shape: torch.Size, + probs: torch.Tensor = None, + routing_map: torch.Tensor = None, +): + """ + Restore the original order of tokens after permutation. If probs are provided, it + will also apply them to the tokens before restoring the order. + + Args: + permuted_tokens (torch.Tensor): The permuted token tensor. + sorted_indices (torch.Tensor): The indices used to sort the tokens. + restore_shape (torch.Size): The shape of the unpermuted tensor. + probs (torch.Tensor, optional): The unpermuted probs tensor, + routing_map (torch.Tensor, optional): Token to expert mapping, shape + [num_tokens, num_experts]. + + Returns: + torch.Tensor: The tokens restored to their original order. + """ + _, hidden = restore_shape + + if probs is not None: + assert routing_map is not None, "Mask must be provided to permute the probs." + permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous()) + permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1) + + # Create an output tensor filled with zeros + output_tokens = torch.zeros( + restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype + ) + # Scatter add the permuted_input back to the original positions + output_tokens.scatter_add_(0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens) + return output_tokens + + +def pytorch_sort_chunks_by_index( + input: torch.Tensor, + split_sizes: torch.Tensor, + sorted_idxs: torch.Tensor, +): + """ + Split and sort the input tensor based on the split_sizes and sorted indices. + return a tuple of (output, row_id_map). row_id_map is only used when fused=True. + """ + input = torch.split(input, split_sizes.tolist(), dim=0) + output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0) + return output + + def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: """Estimated tolerances for a datatype @@ -112,7 +200,17 @@ def dtype_tols(te_dtype: tex.DType) -> Dict[str, float]: raise ValueError(f"Unsuppored dtype ({te_dtype})") -def _test_permutation( +def backward_wrapper( + act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False +): + # Set forward_input.grad to None to avoid grad accumulation. + if accumulate_grad == False: + for i in forward_input: + i.grad = None + return act.backward(backward_input, retain_graph=retain_graph) + + +def _test_permutation_index_map( te_dtype, num_tokens, num_expert, @@ -132,7 +230,8 @@ def _test_permutation( num_out_tokens = num_tokens * topK print( - f"token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" + "index map:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" ) fp8 = False @@ -159,20 +258,28 @@ def _test_permutation( unpermute_bwd_input = torch.rand( size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" ) - - permute_fwd_input = Float8Tensor.to_float8( - permute_fwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _permute_fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - permute_bwd_input = Float8Tensor.to_float8( - permute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _permute_bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) - unpermute_bwd_input = Float8Tensor.to_float8( - unpermute_bwd_input, fp8_dtype=te_dtype, scale=torch.full([1], 1.0) + _unpermute_bwd_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, ) + permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input) + permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) + unpermute_bwd_input = _unpermute_bwd_quantizer(unpermute_bwd_input) - pytorch_permute_fwd_input = permute_fwd_input.from_float8(torch.float16) - pytorch_permute_bwd_input = permute_bwd_input.from_float8(torch.float16) - pytorch_unpermute_bwd_input = unpermute_bwd_input.from_float8(torch.float16) + pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) else: pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() @@ -198,7 +305,7 @@ def _test_permutation( # PyTorch Permutation # ################################################################################################################################### - pytorch_permute_output, sorted_indices = pytorch_permute( + pytorch_permute_output, sorted_indices = pytorch_permute_index_map( pytorch_permute_fwd_input, indices, num_out_tokens ) pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True) @@ -206,7 +313,7 @@ def _test_permutation( pytorch_unpermute_fwd_input = pytorch_permute_output.detach() pytorch_unpermute_fwd_input.requires_grad_(True) - pytorch_unpermute_output = pytorch_unpermute( + pytorch_unpermute_output = pytorch_unpermute_index_map( pytorch_unpermute_fwd_input, sorted_indices, probs=probs ) pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) @@ -220,7 +327,9 @@ def _test_permutation( te_permute_fwd_input.requires_grad_(True) te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() - te_permute_output, row_id_map = te_permute(te_permute_fwd_input, indices, num_out_tokens) + te_permute_output, row_id_map = te_permute( + te_permute_fwd_input, indices, num_out_tokens, map_type="index" + ) te_permute_output.backward(te_permute_bwd_input, retain_graph=True) te_probs = None @@ -231,7 +340,9 @@ def _test_permutation( te_unpermute_fwd_input.requires_grad_(True) te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() - te_unpermute_output = te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs) + te_unpermute_output = te_unpermute( + te_unpermute_fwd_input, row_id_map, te_probs, map_type="index" + ) te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) ################################################################################################################################### @@ -242,10 +353,10 @@ def _test_permutation( tols = dtype_tols(te_dtype) if fp8: - te_permute_output_ = te_permute_output.from_float8(torch.float32) - te_permute_fwd_input_grad = te_permute_fwd_input.grad.from_float8(torch.float32) - te_unpermute_output_ = te_unpermute_output.from_float8(torch.float32) - te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.from_float8(torch.float32) + te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32) + te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32) else: te_permute_output_ = te_permute_output.float() te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() @@ -289,21 +400,12 @@ def _test_permutation( # Benchmark # ################################################################################################################################### - def backward_wrapper( - act, backward_input, forward_input=[], retain_graph=True, accumulate_grad=False - ): - # Set forward_input.grad to None to avoid grad accumulation. - if accumulate_grad == False: - for i in forward_input: - i.grad = None - return act.backward(backward_input, retain_graph=retain_graph) - if BENCHMARK: t1 = perf_test_cuda_kernel( - lambda: pytorch_permute(pytorch_permute_fwd_input, indices, num_out_tokens) + lambda: pytorch_permute_index_map(pytorch_permute_fwd_input, indices, num_out_tokens) ) t2 = perf_test_cuda_kernel( - lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens) + lambda: te_permute(te_permute_fwd_input, indices, num_out_tokens, map_type="index") ) print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") @@ -328,10 +430,12 @@ def backward_wrapper( print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") t1 = perf_test_cuda_kernel( - lambda: pytorch_unpermute(pytorch_unpermute_fwd_input, sorted_indices, probs=probs) + lambda: pytorch_unpermute_index_map( + pytorch_unpermute_fwd_input, sorted_indices, probs=probs + ) ) t2 = perf_test_cuda_kernel( - lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs) + lambda: te_unpermute(te_unpermute_fwd_input, row_id_map, te_probs, map_type="index") ) print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") @@ -362,150 +466,1006 @@ def backward_wrapper( print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") -def perf_test_cuda_kernel(cuda_kernel_fn): - if torch.cuda.is_available(): - # create CUDA event - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - # warmup - for _ in range(50): - cuda_kernel_fn() - - start_event.record() - for _ in range(100): - cuda_kernel_fn() - end_event.record() - torch.cuda.synchronize() - - elapsed_time_ms = start_event.elapsed_time(end_event) - return elapsed_time_ms / 100 - else: - pytest.skip("CUDA is not available.") - - -# TE tensor dtypes -_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] -if is_bf16_compatible(): - _te_dtypes.append(tex.DType.kBFloat16) - - -@pytest.mark.parametrize("te_dtype", _te_dtypes) -@pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) -@pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) -@pytest.mark.parametrize("num_out_tokens", [None, 2039]) -def test_permutation( +def _test_permutation_mask_map( te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, + with_probs, + BENCHMARK=False, ): - with_probs = True - BENCHMARK = False + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") - _test_permutation( - te_dtype=te_dtype, - num_tokens=num_tokens, - num_expert=num_expert, - hidden_size=hidden_size, - topK=topK, - num_out_tokens=num_out_tokens, - with_probs=with_probs, - BENCHMARK=BENCHMARK, + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + print( + "mask map:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" ) + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") -# Only run FP8 tests on H100. -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + if fp8: + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + permute_bwd_input = torch.rand( + size=(num_out_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + unpermute_bwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + _permute_fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + _permute_bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + _unpermute_bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + permute_fwd_input = _permute_fwd_input_quantizer(permute_fwd_input) + permute_bwd_input = _permute_bwd_input_quantizer(permute_bwd_input) + unpermute_bwd_input = _unpermute_bwd_input_quantizer(unpermute_bwd_input) -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) -@pytest.mark.parametrize("num_tokens", [2048]) -@pytest.mark.parametrize("num_expert", [8, 16]) -@pytest.mark.parametrize("hidden_size", [4096]) -@pytest.mark.parametrize("topK", [1, 2, 5]) -@pytest.mark.parametrize("num_out_tokens", [None, 2039]) -def test_permutation_fp8( - te_dtype, - num_tokens, - num_expert, - hidden_size, - topK, - num_out_tokens, -): - with_probs = True - BENCHMARK = False + pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) + pytorch_permute_bwd_input = permute_bwd_input.dequantize(dtype=torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) + else: + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() - _test_permutation( - te_dtype=te_dtype, - num_tokens=num_tokens, - num_expert=num_expert, - hidden_size=hidden_size, - topK=topK, - num_out_tokens=num_out_tokens, - with_probs=with_probs, - BENCHMARK=BENCHMARK, + pytorch_permute_fwd_input.requires_grad_(True) + + restore_shape = pytorch_permute_fwd_input.shape + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + + probs = None + if with_probs: + probs = torch.rand(num_tokens, num_expert).cuda() * routing_map + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + if fp8: + probs = probs.to(torch.float16) + else: + probs = probs.to(dtype) + probs.requires_grad_(True) + + ################################################################################################################################### + # + # PyTorch Permutation + # + ################################################################################################################################### + pytorch_permute_output, sorted_indices = pytorch_permute_mask_map( + pytorch_permute_fwd_input, routing_map ) + pytorch_permute_output.backward(pytorch_permute_bwd_input, retain_graph=True) + pytorch_unpermute_fwd_input = pytorch_permute_output.detach() + pytorch_unpermute_fwd_input.requires_grad_(True) -@pytest.mark.parametrize("te_dtype", _te_dtypes) -@pytest.mark.parametrize("num_tokens", [4096]) -@pytest.mark.parametrize("num_expert", [8, 16]) -@pytest.mark.parametrize("hidden_size", [4096]) -def test_permutation_topk1_no_probs( - te_dtype, - num_tokens, - num_expert, - hidden_size, -): - topK = 1 - num_out_tokens = None - with_probs = False - BENCHMARK = False + pytorch_unpermute_output = pytorch_unpermute_mask_map( + pytorch_unpermute_fwd_input, sorted_indices, restore_shape, probs, routing_map + ) + pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) - _test_permutation( - te_dtype=te_dtype, - num_tokens=num_tokens, - num_expert=num_expert, - hidden_size=hidden_size, - topK=topK, - num_out_tokens=num_out_tokens, - with_probs=with_probs, - BENCHMARK=BENCHMARK, + ################################################################################################################################### + # + # TE Permutation + # + ################################################################################################################################### + te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input.requires_grad_(True) + te_permute_bwd_input = permute_bwd_input if fp8 else pytorch_permute_bwd_input.detach() + + te_permute_output, row_id_map = te_permute( + te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask" ) + te_permute_output.backward(te_permute_bwd_input, retain_graph=True) + te_probs = None + if with_probs: + te_probs = probs.detach() + te_probs.requires_grad_(True) + te_unpermute_fwd_input = te_permute_output.detach() + te_unpermute_fwd_input.requires_grad_(True) + te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() -def test_permutation_single_case(): - print("GPU:", torch.cuda.get_device_name(0)) + te_unpermute_output = te_unpermute( + te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask" + ) + te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) - # te_dtype = tex.DType.kFloat32 - # te_dtype = tex.DType.kFloat16 - # te_dtype = tex.DType.kBFloat16 - te_dtype = tex.DType.kFloat8E5M2 - # te_dtype = tex.DType.kFloat8E4M3 + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) - num_tokens = 10 - num_expert = 4 - hidden_size = 16 - topK = 2 - num_out_tokens = num_tokens * topK - 1 - with_probs = True - Benchmark = True + if fp8: + te_permute_output_ = te_permute_output.dequantize(dtype=torch.float32) + te_permute_fwd_input_grad = te_permute_fwd_input.grad.dequantize(dtype=torch.float32) + te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.dequantize(dtype=torch.float32) + else: + te_permute_output_ = te_permute_output.float() + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float() - _test_permutation( - te_dtype=te_dtype, - num_tokens=num_tokens, - num_expert=num_expert, - hidden_size=hidden_size, - topK=topK, - num_out_tokens=num_out_tokens, - with_probs=with_probs, - BENCHMARK=Benchmark, + torch.testing.assert_close( + pytorch_permute_output.float(), + te_permute_output_, + msg=f"Mismatch in te_permute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in te_permute bwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in te_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_unpermute_fwd_input.grad.float(), + te_unpermute_fwd_input_grad, + msg=f"Mismatch in te_unpermute bwd", + **tols, + ) + if with_probs: + torch.testing.assert_close( + probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols + ) + + if not pytorch_permute_fwd_input.numel(): + print("Empty pytorch_permute_fwd_input activation test passed.") + return + + ################################################################################################################################### + # + # Benchmark + # + ################################################################################################################################### + if BENCHMARK: + t1 = perf_test_cuda_kernel( + lambda: pytorch_permute_mask_map(pytorch_permute_fwd_input, routing_map) + ) + t2 = perf_test_cuda_kernel( + lambda: te_permute( + te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask" + ) + ) + print(f"permute\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_permute_output, + pytorch_permute_bwd_input, + forward_input=[pytorch_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_permute_output, + te_permute_bwd_input, + forward_input=[te_permute_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"permute\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: pytorch_unpermute_mask_map( + pytorch_unpermute_fwd_input, sorted_indices, restore_shape, probs, routing_map + ) + ) + t2 = perf_test_cuda_kernel( + lambda: te_unpermute( + te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask" + ) + ) + print(f"unpermute\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_unpermute_output, + pytorch_unpermute_bwd_input, + forward_input=( + [pytorch_unpermute_fwd_input, probs] + if with_probs + else [pytorch_unpermute_fwd_input] + ), + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_unpermute_output, + te_unpermute_bwd_input, + forward_input=( + [te_unpermute_fwd_input, te_probs] if with_probs else [te_unpermute_fwd_input] + ), + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"unpermute\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + +def _test_moe_chunk_sort( + te_dtype, + num_tokens, + num_expert, + tp_size, + hidden_size, + BENCHMARK=False, +): + print( + "chunk permute:" + f" token:{num_tokens} hidden_size:{hidden_size} num_expert:{num_expert} tp_size:{tp_size} {te_dtype}" + ) + + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") + + if fp8: + fwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") + bwd_input = torch.rand(size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda") + + _fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + _bwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + fwd_input = _fwd_input_quantizer.quantize(fwd_input) + bwd_input = _bwd_input_quantizer.quantize(bwd_input) + + pytorch_fwd_input = fwd_input.dequantize(dtype=torch.float16) + pytorch_bwd_input = bwd_input.dequantize(dtype=torch.float16) + else: + pytorch_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + + pytorch_fwd_input.requires_grad_(True) + + _split_sizes = [0] * (num_expert * tp_size) + for _ in range(num_tokens): + idx = random.randint(0, num_expert * tp_size - 1) + _split_sizes[idx] += 1 + split_sizes = torch.tensor(_split_sizes, dtype=torch.int32).ravel() + split_sizes_cuda = split_sizes.to(device="cuda") + + _sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32) + sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel() + sorted_idxs_cuda = sorted_idxs.to(device="cuda") + + ################################################################################################################################### + # + # PyTorch Permutation + # + ################################################################################################################################### + pytorch_output = pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs) + pytorch_output.backward(pytorch_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # TE Permutation + # + ################################################################################################################################### + te_fwd_input = fwd_input if fp8 else pytorch_fwd_input.detach() + te_fwd_input.requires_grad_(True) + te_bwd_input = bwd_input if fp8 else pytorch_bwd_input.detach() + + te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda) + te_output.backward(te_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # Results Check + # + ################################################################################################################################### + tols = dtype_tols(te_dtype) + + if fp8: + te_output_ = te_output.dequantize(dtype=torch.float32) + te_fwd_input_grad = te_fwd_input.grad.dequantize(dtype=torch.float32) + else: + te_output_ = te_output.float() + te_fwd_input_grad = te_fwd_input.grad.float() + + torch.testing.assert_close( + pytorch_output.float(), + te_output_, + msg=f"Mismatch in te_permute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_fwd_input.grad.float(), + te_fwd_input_grad, + msg=f"Mismatch in te_permute bwd", + **tols, + ) + + if not pytorch_fwd_input.numel(): + print("Empty pytorch_fwd_input activation test passed.") + return + + ################################################################################################################################### + # + # Benchmark + # + ################################################################################################################################### + if BENCHMARK: + t1 = perf_test_cuda_kernel( + lambda: pytorch_sort_chunks_by_index(pytorch_fwd_input, split_sizes, sorted_idxs) + ) + t2 = perf_test_cuda_kernel( + lambda: te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda) + ) + print(f"chunk sort\t\tfwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + t1 = perf_test_cuda_kernel( + lambda: backward_wrapper( + pytorch_output, + pytorch_bwd_input, + forward_input=[pytorch_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + t2 = perf_test_cuda_kernel( + lambda: backward_wrapper( + te_output, + te_bwd_input, + forward_input=[te_fwd_input], + retain_graph=True, + accumulate_grad=False, + ) + ) + print(f"chunk sort\t\tbwd: pytorch: {t1:.3f} ms, TE: {t2:.3f} ms") + + +def _test_permutation_mask_map_alongside_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + tp_size, +): + if topK > num_expert: + pytest.skip("topK should be smaller than the number of experts.") + + if num_out_tokens == None: + num_out_tokens = num_tokens * topK + + print( + "mask map alongside probs:" + f" token:{num_tokens} hidden_size:{hidden_size} expert:{num_expert} topK:{topK} {te_dtype}" + ) + + fp8 = False + # Convert TE dtypes to PyTorch dtypes + if te_dtype == tex.DType.kFloat32: + dtype = torch.float32 + elif te_dtype == tex.DType.kFloat16: + dtype = torch.float16 + elif te_dtype == tex.DType.kBFloat16: + dtype = torch.bfloat16 + elif fp8_available and (te_dtype == tex.DType.kFloat8E5M2 or te_dtype == tex.DType.kFloat8E4M3): + dtype = torch.uint8 + fp8 = True + else: + pytest.skip("Invalid dtype.") + + if fp8: + permute_fwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + unpermute_bwd_input = torch.rand( + size=(num_tokens, hidden_size), dtype=torch.float32, device="cuda" + ) + + _permute_fwd_input_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + _unpermute_bwd_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + permute_fwd_input = _permute_fwd_input_quantizer.quantize(permute_fwd_input) + unpermute_bwd_input = _unpermute_bwd_quantizer.quantize(unpermute_bwd_input) + + pytorch_permute_fwd_input = permute_fwd_input.dequantize(dtype=torch.float16) + pytorch_unpermute_bwd_input = unpermute_bwd_input.dequantize(dtype=torch.float16) + else: + pytorch_permute_fwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + pytorch_unpermute_bwd_input = torch.rand((num_tokens, hidden_size), dtype=dtype).cuda() + + pytorch_permute_fwd_input.requires_grad_(True) + + restore_shape = pytorch_permute_fwd_input.shape + + _tmp_tensor = torch.zeros((num_tokens * num_expert,)) + _tmp_tensor[: int(num_out_tokens)] = 1.0 + _tmp_idx = torch.randperm(num_tokens * num_expert) + routing_map = torch.reshape(_tmp_tensor[_tmp_idx], (num_tokens, num_expert)).bool().cuda() + + probs = torch.rand(num_tokens, num_expert).cuda() * routing_map + row_sums = probs.sum(dim=1, keepdim=True) + probs = probs / row_sums + if fp8: + probs = probs.to(torch.float16) + else: + probs = probs.to(dtype) + probs.requires_grad_(True) + + split_sizes = [0] * (num_expert * tp_size) + for i in range(num_out_tokens): + idx = random.randint(0, num_expert * tp_size - 1) + split_sizes[idx] += 1 + split_sizes = torch.tensor(split_sizes, dtype=torch.int32) + split_sizes_cuda = split_sizes.to(device="cuda") + + _sorted_idxs = torch.arange(num_expert * tp_size, dtype=torch.int32) + sorted_idxs = _sorted_idxs.reshape(tp_size, num_expert).T.ravel() + sorted_idxs_cuda = sorted_idxs.to(device="cuda") + + split_sizes_2 = [split_sizes[i] for i in sorted_idxs.tolist()] + split_sizes_2 = torch.tensor(split_sizes_2, dtype=torch.int32) + split_sizes_2_cuda = split_sizes_2.to(device="cuda") + + sorted_idxs_2 = [0] * (num_expert * tp_size) + for i in range(num_expert * tp_size): + sorted_idxs_2[sorted_idxs[i]] = i + sorted_idxs_2 = torch.tensor(sorted_idxs_2, dtype=torch.int32) + sorted_idxs_2_cuda = sorted_idxs_2.to(device="cuda") + + ################################################################################################################################### + # + # PyTorch Permutation + # + ################################################################################################################################### + pytorch_permute_output, sorted_indices = pytorch_permute_mask_map( + pytorch_permute_fwd_input, routing_map + ) + + pytorch_permute_output = pytorch_sort_chunks_by_index( + pytorch_permute_output, split_sizes, sorted_idxs + ) + + pytorch_permute_output = pytorch_sort_chunks_by_index( + pytorch_permute_output, split_sizes_2, sorted_idxs_2 + ) + + pytorch_unpermute_output = pytorch_unpermute_mask_map( + pytorch_permute_output, sorted_indices, restore_shape, probs, routing_map + ) + pytorch_unpermute_output.backward(pytorch_unpermute_bwd_input, retain_graph=True) + + ################################################################################################################################### + # + # TE Permutation + # + ################################################################################################################################### + te_permute_fwd_input = permute_fwd_input if fp8 else pytorch_permute_fwd_input.detach() + te_permute_fwd_input.requires_grad_(True) + + te_unpermute_bwd_input = unpermute_bwd_input if fp8 else pytorch_unpermute_bwd_input.detach() + te_probs = probs.detach() + te_probs.requires_grad_(True) + print(te_probs.shape) + + te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs( + te_permute_fwd_input, + te_probs, + routing_map, + num_out_tokens=num_out_tokens, + ) + print(te_permuted_probs.shape) + + te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs( + te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda + ) + + if fp8: + _permute_output_quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda().squeeze(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=te_dtype, + ) + te_permute_output = te_permute_output.dequantize(dtype=torch.float32) + te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) + te_permute_output = _permute_output_quantizer.quantize(te_permute_output) + else: + te_permute_output_dtype = te_permute_output.dtype + print(te_permute_output.shape) + print(te_permuted_probs.shape) + te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) + te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype) + + te_permute_output = te_sort_chunks_by_index( + te_permute_output, split_sizes_2_cuda, sorted_idxs_2_cuda + ) + + te_unpermute_output = te_unpermute( + te_permute_output, + row_id_map, + restore_shape=restore_shape, + map_type="mask", + ) + te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) + + ############################################################################################### + + tols = dtype_tols(te_dtype) + + if fp8: + # backward of dequantize is in high precision + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.dequantize(dtype=torch.float32) + else: + te_permute_fwd_input_grad = te_permute_fwd_input.grad.float() + te_unpermute_output_ = te_unpermute_output.float() + + torch.testing.assert_close( + pytorch_unpermute_output.float(), + te_unpermute_output_, + msg=f"Mismatch in fused_unpermute fwd", + **tols, + ) + torch.testing.assert_close( + pytorch_permute_fwd_input.grad.float(), + te_permute_fwd_input_grad, + msg=f"Mismatch in fused_permute bwd", + **tols, + ) + torch.testing.assert_close( + probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols + ) + + +def perf_test_cuda_kernel(cuda_kernel_fn): + if torch.cuda.is_available(): + # create CUDA event + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + # warmup + for _ in range(50): + cuda_kernel_fn() + + start_event.record() + for _ in range(100): + cuda_kernel_fn() + end_event.record() + torch.cuda.synchronize() + + elapsed_time_ms = start_event.elapsed_time(end_event) + return elapsed_time_ms / 100 + else: + pytest.skip("CUDA is not available.") + + +# TE tensor dtypes +_te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] +if is_bf16_compatible(): + _te_dtypes.append(tex.DType.kBFloat16) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation_index_map( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation_index_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation_mask_map( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +def test_permutation_mask_map_empty_input(te_dtype): + with_probs = True + BENCHMARK = False + + _test_permutation_mask_map( + te_dtype=te_dtype, + num_tokens=0, + num_expert=8, + hidden_size=4096, + topK=2, + num_out_tokens=0, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +@pytest.mark.parametrize("tp_size", [1, 2, 8]) +def test_permutation_mask_map_alongside_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + tp_size, +): + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + tp_size=tp_size, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=0, + num_expert=8, + hidden_size=4096, + topK=2, + num_out_tokens=0, + tp_size=2, + ) + + +# Only run FP8 tests on H100. +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation_index_map_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation_index_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +def test_permutation_mask_map_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, +): + with_probs = True + BENCHMARK = False + + _test_permutation_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) +@pytest.mark.parametrize("num_tokens", [2048]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("topK", [1, 2, 5]) +@pytest.mark.parametrize("num_out_tokens", [None, 2039]) +@pytest.mark.parametrize("tp_size", [1, 2, 8]) +def test_permutation_mask_map_alongside_probs_fp8( + te_dtype, + num_tokens, + num_expert, + hidden_size, + topK, + num_out_tokens, + tp_size, +): + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + tp_size=tp_size, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +def test_permutation_index_map_topk1_no_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, +): + topK = 1 + num_out_tokens = None + with_probs = False + BENCHMARK = False + + _test_permutation_index_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("hidden_size", [4096]) +def test_permutation_mask_map_topk1_no_probs( + te_dtype, + num_tokens, + num_expert, + hidden_size, +): + topK = 1 + num_out_tokens = None + with_probs = False + BENCHMARK = False + + _test_permutation_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +@pytest.mark.parametrize("num_tokens", [4096]) +@pytest.mark.parametrize("num_expert", [8, 16]) +@pytest.mark.parametrize("tp_size", [1, 2, 8]) +@pytest.mark.parametrize("hidden_size", [4096]) +def test_chunk_permutation( + te_dtype, + num_tokens, + num_expert, + tp_size, + hidden_size, +): + BENCHMARK = False + + _test_moe_chunk_sort( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + tp_size=tp_size, + hidden_size=hidden_size, + BENCHMARK=BENCHMARK, + ) + + +@pytest.mark.parametrize("te_dtype", _te_dtypes) +def test_chunk_permutation_empty_input(te_dtype): + BENCHMARK = False + + _test_moe_chunk_sort( + te_dtype=te_dtype, + num_tokens=0, + num_expert=8, + tp_size=2, + hidden_size=4096, + BENCHMARK=BENCHMARK, + ) + + +def test_permutation_single_case(): + print("GPU:", torch.cuda.get_device_name(0)) + + # te_dtype = tex.DType.kFloat32 + # te_dtype = tex.DType.kFloat16 + # te_dtype = tex.DType.kBFloat16 + te_dtype = tex.DType.kFloat8E5M2 + # te_dtype = tex.DType.kFloat8E4M3 + + num_tokens = 10 + num_expert = 4 + hidden_size = 16 + topK = 2 + num_out_tokens = num_tokens * topK - 1 + with_probs = True + Benchmark = True + + _test_permutation_index_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=Benchmark, + ) + + _test_permutation_mask_map( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + with_probs=with_probs, + BENCHMARK=Benchmark, + ) + + _test_moe_chunk_sort( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + tp_size=4, + hidden_size=hidden_size, + BENCHMARK=Benchmark, + ) + + _test_permutation_mask_map_alongside_probs( + te_dtype=te_dtype, + num_tokens=num_tokens, + num_expert=num_expert, + hidden_size=hidden_size, + topK=topK, + num_out_tokens=num_out_tokens, + tp_size=4, ) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 0c2118718c..30989bec61 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -15,6 +15,7 @@ _amax_and_scale_update, get_default_fp8_recipe, ) +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops import transformer_engine_torch as tex @@ -22,6 +23,7 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +# FP8 per tensor delayed scaling @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) class TestFP8Recipe: @@ -64,17 +66,17 @@ def test_fp8_scale_update_with_linear_module( forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) amax_history_forward = fp8_meta[forward_key].amax_history scale_forward = fp8_meta[forward_key].scale - scale_inv_forward = fp8_meta[forward_key].scale_inv + # scale_inv_forward = fp8_meta[forward_key].scale_inv backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) amax_history_backward = fp8_meta[backward_key].amax_history scale_backward = fp8_meta[backward_key].scale - scale_inv_backward = fp8_meta[backward_key].scale_inv + # scale_inv_backward = fp8_meta[backward_key].scale_inv # Tweak amax history and scaling factors amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5) amax_history_forward[0, :].zero_() scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5) - scale_inv_forward.copy_(torch.reciprocal(scale_forward)) + # scale_inv_forward.copy_(torch.reciprocal(scale_forward)) amax_history_backward[0, :].zero_() # Expected amax history after update @@ -100,11 +102,11 @@ def test_fp8_scale_update_with_linear_module( raise ValueError(f"{amax_compute_algo=} is not supported") ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin) ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin) - ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) + # ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) update_weight_amax = is_first_microbatch is None or is_first_microbatch - if not update_weight_amax: - ref_scale_inv_forward[1].copy_(scale_inv_forward[1]) - ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) + # if not update_weight_amax: + # ref_scale_inv_forward[1].copy_(scale_inv_forward[1]) + # ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) # Perform forward, backward, and optimizer steps to update fp8_meta with te.fp8_autocast(enabled=True, fp8_recipe=recipe): @@ -133,8 +135,8 @@ def test_fp8_scale_update_with_linear_module( raise ValueError(f"{amax_compute_algo=} is not supported") ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin) ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin) - ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) - ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) + # ref_scale_inv_forward = torch.reciprocal(ref_scale_forward) + # ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) # Check that scale and scale inverse match expected values # Note: scale and scale inverse are only updated when amax is updated @@ -142,27 +144,15 @@ def test_fp8_scale_update_with_linear_module( scale_forward[0], ref_scale_forward[0], ) - torch.testing.assert_close( - scale_inv_forward[0], - ref_scale_inv_forward[0], - ) if update_weight_amax: torch.testing.assert_close( scale_forward[1], ref_scale_forward[1], ) - torch.testing.assert_close( - scale_inv_forward[1], - ref_scale_inv_forward[1], - ) torch.testing.assert_close( scale_backward[0], ref_scale_backward[0], ) - torch.testing.assert_close( - scale_inv_backward[0], - ref_scale_inv_backward[0], - ) @pytest.mark.parametrize("amax_history_len", [31, 1024]) @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"]) @@ -180,12 +170,23 @@ def test_fp8_scale_update_with_linear_fuser_op( # Construct linear op op = te_ops.BasicLinear(in_shape[-1], in_shape[-1]) - # Get FP8 meta tensors + # FP8 recipe forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False) - x_fp8_meta = op.get_fp8_meta("input")[forward_key] - w_fp8_meta = op.get_fp8_meta("param")[forward_key] - dy_fp8_meta = op.get_fp8_meta("grad_output")[backward_key] + fp8_format = transformer_engine.common.recipe.Format.HYBRID + recipe = transformer_engine.common.recipe.DelayedScaling( + margin=margin, + interval=1, + fp8_format=fp8_format, + amax_history_len=amax_history_len, + amax_compute_algo=amax_compute_algo, + ) + + # Get FP8 meta tensors + with te.fp8_autocast(fp8_recipe=recipe): + x_fp8_meta = op.get_quantizer("forward", 0) + w_fp8_meta = op.get_quantizer("forward", 1) + dy_fp8_meta = op.get_quantizer("backward", 0) # Perform training steps x_history = [] @@ -214,14 +215,6 @@ def test_fp8_scale_update_with_linear_fuser_op( op.weight.fill_(w_history[-1]) # Forward and backward pass - fp8_format = transformer_engine.common.recipe.Format.HYBRID - recipe = transformer_engine.common.recipe.DelayedScaling( - margin=margin, - interval=1, - fp8_format=fp8_format, - amax_history_len=amax_history_len, - amax_compute_algo=amax_compute_algo, - ) with te.fp8_autocast(fp8_recipe=recipe): y = op(x) y.backward(dy) @@ -247,7 +240,7 @@ def check_amax_history( ) def check_scale( - fp8_meta: dict, + quantizer: Float8Quantizer, ref_amax_history: Iterable[float], stage: str, ): @@ -272,18 +265,11 @@ def check_scale( # Check values in FP8 meta tensors torch.testing.assert_close( - fp8_meta.scale.item(), + quantizer.scale.item(), ref_scale, ) - torch.testing.assert_close( - fp8_meta.scale_inv.item(), - 1 / ref_scale, - ) # Check that results match expected values - check_amax_history(x_fp8_meta, x_history) - check_amax_history(w_fp8_meta, w_history) - check_amax_history(dy_fp8_meta, dy_history) check_scale(x_fp8_meta, x_history, "forward") check_scale(w_fp8_meta, w_history, "forward") check_scale(dy_fp8_meta, dy_history, "backward") @@ -369,7 +355,6 @@ def setup_fp8_meta(): fp8_meta[forward_key].amax_history.clone().view(-1), [fp8_meta[forward_key].amax_history], [fp8_meta[forward_key].scale], - [fp8_meta[forward_key].scale_inv], recipe.amax_compute_algo, fp8_dtype, recipe.margin, @@ -378,12 +363,8 @@ def setup_fp8_meta(): _amax_and_scale_update( fp8_meta[forward_key].amax_history, fp8_meta[forward_key].scale, - fp8_meta[forward_key].scale_inv, fp8_max, recipe, ) torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale) - torch.testing.assert_close( - fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale) - ) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 4f057c12fe..1e6250f26f 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -8,7 +8,6 @@ import torch import pytest -import io import os from transformer_engine.pytorch.fp8 import ( @@ -26,6 +25,7 @@ from transformer_engine.pytorch import ( LayerNormLinear, Linear, + GroupedLinear, LayerNormMLP, TransformerLayer, RMSNorm, @@ -34,19 +34,22 @@ ) from transformer_engine.common import recipe import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import ( - gemm, - fp8_gemm, - gelu, - cast_to_fp8, - cast_from_fp8, -) +from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.module.base import get_workspace -from test_onnx_export import create_meta +from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from test_numerics import reset_rng_states, dtype_tols -# Only run FP8 tests on H100. +# Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() + + +def create_meta(scale_factor: float, size: int = 1): + meta = tex.FP8TensorMeta() + meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") + meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor + meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor + return meta def custom_amax_to_scale( @@ -96,13 +99,9 @@ def is_fp8_supported(self): fp8_recipes = [ None, # Handles non-FP8 case + recipe.MXFP8BlockScaling(), recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3), recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID), - recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.E4M3, - override_linear_precision=(False, False, True), - ), recipe.DelayedScaling( margin=0, fp8_format=recipe.Format.E4M3, @@ -136,7 +135,7 @@ def is_fp8_supported(self): all_boolean = [True, False] batch_sizes_with_zero = [0, 1, 2] -all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"] +all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"] all_normalizations = ["LayerNorm", "RMSNorm"] @@ -236,6 +235,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): torch.cuda.synchronize() assert te_out.dtype == dtype, "AMP wrong output type." + assert te_inp_hidden_states.grad is not None, "Gradient should not be empty" assert te_inp_hidden_states.grad.dtype == torch.float32, "AMP wrong dgrad type." for name, p in block.named_parameters(): if p.requires_grad: @@ -272,11 +272,14 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci loss.backward() torch.cuda.synchronize() + failed_grads = [] for name, p in block.named_parameters(): if "layer_norm_weight" in name: continue elif "weight" in name and p.requires_grad: - assert torch.count_nonzero(p.main_grad) > 0, "Gradient not accumulated." + if not torch.count_nonzero(p.main_grad) > 0: + failed_grads.append(name) + assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}." def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): @@ -411,6 +414,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) torch.cuda.synchronize() assert te_out.dtype == dtype, "AMP wrong output type." + assert te_inp.grad is not None, "Gradient should not be empty" assert te_inp.grad.dtype == torch.float32, "AMP wrong dgrad type." for name, p in block.named_parameters(): if p.requires_grad: @@ -445,6 +449,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -474,6 +480,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -504,11 +512,13 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") use_fp8 = fp8_recipe is not None - with fp8_model_init(enabled=use_fp8 and fp8_model_params): + with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_linear = Linear( config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype ).cuda() @@ -523,6 +533,55 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ assert out.shape == (num_tokens, ffn_hidden_size) +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("bs", batch_sizes_with_zero) +@pytest.mark.parametrize("model", ["small", "weird"]) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_model_params", all_boolean) +@pytest.mark.parametrize("use_bias", all_boolean) +@pytest.mark.parametrize("empty_split", ["first", "last", "middle"]) +@pytest.mark.parametrize("num_gemms", [4]) +def test_sanity_grouped_linear( + dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split +): + config = model_configs[model] + ffn_hidden_size = 4 * config.hidden_size + # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. + bs = bs * 16 + num_tokens = bs * config.seq_len * (num_gemms - 1) + + if fp8_recipe is not None: + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8(): + pytest.skip("Grouped linear does not support MXFP8") + if not config.is_fp8_supported(): + pytest.skip("Model config does not support FP8") + + use_fp8 = fp8_recipe is not None + with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): + te_grouped_linear = GroupedLinear( + num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype + ).cuda() + + inp_hidden_states = torch.randn( + num_tokens, config.hidden_size, dtype=dtype, requires_grad=True + ).cuda() + m_splits = [bs * config.seq_len] * num_gemms + if empty_split == "first": + m_splits[0] = 0 + elif empty_split == "last": + m_splits[-1] = 0 + elif empty_split == "middle": + m_splits[num_gemms // 2] = 0 + + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + out = te_grouped_linear(inp_hidden_states, m_splits) + loss = out.sum() + loss.backward() + assert out.shape == (num_tokens, ffn_hidden_size) + + @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small", "weird"]) @@ -539,6 +598,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -587,6 +648,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -652,6 +715,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -709,6 +774,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -764,6 +831,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -797,6 +866,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -833,6 +904,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -872,6 +945,8 @@ def test_sanity_gradient_accumulation_fusion( if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -912,6 +987,8 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) + if fp8_recipe.mxfp8() and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) if not config.is_fp8_supported(): pytest.skip("Model config does not support FP8") @@ -962,7 +1039,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype): inp = torch.reshape(scratchpad[offset:-offset], (N, N)) weight = torch.reshape(scratchpad[offset * 2 :], (N, N)) - _, _, _ = gemm(A=weight, B=inp, dtype=datatype, workspace=get_workspace()) + _ = general_gemm(A=weight, B=inp, workspace=get_workspace()) torch.cuda.synchronize() @@ -971,35 +1048,24 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype): @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) def test_sanity_fp8_gemm_with_unalignment(N, datatype): offset = 16 - scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype) + scratchpad = torch.randn(N, N * N + offset, device="cuda", dtype=datatype) - fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT - fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT + scales = torch.ones(1).cuda().squeeze() + amaxes = torch.ones(1).cuda().squeeze() + dtype = tex.DType.kFloat8E4M3 + fp8_quantizer = Float8Quantizer(scales, amaxes, dtype) - nb_inp_scales, nb_weight_scales = 1, N - scale_factor = 1.0 - meta_inp = create_meta(scale_factor, nb_inp_scales) - meta_weight = create_meta(scale_factor, nb_weight_scales) - inp_type = tex.DType.kFloat8E4M3 - weights_type = tex.DType.kFloat8E4M3 outp_type = datatype - scratchpad_fp8 = cast_to_fp8(scratchpad, meta_weight, fp8_tensor_inp, inp_type) - inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N)) - weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N)) - _, _ = fp8_gemm( + scratchpad_fp8 = fp8_quantizer(scratchpad) + inp_fp8 = torch.reshape(scratchpad_fp8[0][:-offset], (N, N)) + weight_fp8 = torch.reshape(scratchpad_fp8[0][offset:], (N, N)) + general_gemm( weight_fp8, - meta_weight.scale_inv, - fp8_tensor_weight, - inp_type, inp_fp8, - meta_inp.scale_inv, - fp8_tensor_inp, - weights_type, - outp_type, get_workspace(), + outp_type, bias=None, - use_bias=False, use_split_accumulator=False, ) torch.cuda.synchronize() @@ -1062,13 +1128,15 @@ def get_model(dtype, config): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8_enabled): + with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, config.num_attention_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, + hidden_dropout=0.0, + attention_dropout=0.0, fuse_qkv_params=True, params_dtype=dtype, device="cuda", @@ -1101,7 +1169,7 @@ def get_model(dtype, config): del block block = get_model(dtype, config) - block.load_state_dict(torch.load(path)) + block.load_state_dict(torch.load(path, weights_only=False)) torch.set_rng_state(_cpu_rng_state_new) torch.cuda.set_rng_state(_cuda_rng_state_new) diff --git a/tests/pytorch/test_sanity_import.py b/tests/pytorch/test_sanity_import.py index 954d807b7d..5657cf0d85 100644 --- a/tests/pytorch/test_sanity_import.py +++ b/tests/pytorch/test_sanity_import.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py deleted file mode 100644 index 7bf8fb99d5..0000000000 --- a/tests/pytorch/test_torch_save_load.py +++ /dev/null @@ -1,474 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -""" -This file contains tests for saving and loading TransformerEngine torch checkpoints. - -The purpose of this test is to validate the TransformerEngine hooks for saving FP8 metadata -in torch checkpoints, which are called as part of torch.save() and torch.load(). -The test verifies the values of FP8 metadata object after saving and loading a checkpoint -are identical to the original values. -""" - -import io -import tempfile -from typing import Iterable, Union - -import pytest -import torch -import transformer_engine.common -import transformer_engine.pytorch as te -import transformer_engine.pytorch.ops as te_ops -import transformer_engine_torch as tex -from transformer_engine.pytorch.cpp_extensions import fp8_gemm, cast_to_fp8 -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.module.base import get_workspace -from transformer_engine.pytorch.module.base import TransformerEngineBaseModule - -# Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() - - -def init_meta(size: int = 1): - meta = tex.FP8TensorMeta() - meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") - meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") - meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") - return meta - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("scale_fwd", [224, 112, 66]) -@pytest.mark.parametrize("scale_bwd", [448, 33]) -@pytest.mark.parametrize("history_fwd", [1.23, 4.56]) -@pytest.mark.parametrize("history_bwd", [2.34, 5.67]) -def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd): - - tmp_filename = tempfile.NamedTemporaryFile().name - - precision = torch.float32 - - class Test_TE_Export(TransformerEngineBaseModule): - def __init__(self, precision, use_bias): - super().__init__() - self.use_bias = use_bias - self.precision = precision - - self.fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT - self.fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT - nb_inp_scales = nb_weight_scales = 1 - self.meta_inp = init_meta(nb_inp_scales) - self.meta_weight = init_meta(nb_weight_scales) - - bias_size = nb_weight_scales - self.bias = torch.randn(bias_size, dtype=precision, device="cuda") - - self.inp_type = tex.DType.kFloat8E4M3 - self.weights_type = tex.DType.kFloat8E4M3 - self.outp_type = precision - - def get_fp8_weights_scratchpad(self, is_first_microbatch): - raise RuntimeError( - "Method get_fp8_weights_scratchpad is dummy and should not be invoked." - ) - - def forward(self, inp, weight): - inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type) - - weight_fp8 = cast_to_fp8( - weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type - ) - - ret = fp8_gemm( - weight_fp8, - self.meta_weight.scale_inv, - self.fp8_tensor_weight, - self.inp_type, - inp_fp8, - self.meta_inp.scale_inv, - self.fp8_tensor_inp, - self.weights_type, - self.outp_type, - get_workspace(), - bias=self.bias, - use_bias=self.use_bias, - use_split_accumulator=False, - ) - return ret - - model_in = Test_TE_Export(precision, True) - with te.fp8_autocast(enabled=True): - model_in.init_fp8_metadata() - # scaling fwd - model_in.fp8_meta["scaling_fwd"].scale = ( - torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd - ) - model_in.fp8_meta["scaling_fwd"].scale_inv = ( - torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd - ) - model_in.fp8_meta["scaling_fwd"].amax_history = ( - torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd - ) - # scaling bwd - model_in.fp8_meta["scaling_bwd"].scale = ( - torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd - ) - model_in.fp8_meta["scaling_bwd"].scale_inv = ( - torch.ones(2, dtype=torch.float32, device="cuda") / scale_bwd - ) - model_in.fp8_meta["scaling_bwd"].amax_history = ( - torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd - ) - - torch.save(model_in.state_dict(), tmp_filename) - - model_out = Test_TE_Export(precision, True) - model_out.load_state_dict(torch.load(tmp_filename)) - model_out.eval() - - # scaling fwd - assert torch.allclose( - model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale - ) - assert torch.allclose( - model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv - ) - assert torch.allclose( - model_in.fp8_meta["scaling_fwd"].amax_history, - model_out.fp8_meta["scaling_fwd"].amax_history, - ) - # scaling bwd - assert torch.allclose( - model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale - ) - assert torch.allclose( - model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv - ) - assert torch.allclose( - model_in.fp8_meta["scaling_bwd"].amax_history, - model_out.fp8_meta["scaling_bwd"].amax_history, - ) - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("save_fp8_model", [True, False]) -@pytest.mark.parametrize("load_fp8_model", [True, False]) -def test_fp8_model_checkpoint( - save_fp8_model: bool, - load_fp8_model: bool, - dims: Iterable[int] = [32, 32], - dtype: torch.dtype = torch.float32, - device: Union[torch.device, str] = "cuda", -): - - # Construct model - dims = list(dims) - hidden_dim = dims[-1] - with te.fp8_model_init(enabled=save_fp8_model): - model = te.Linear( - hidden_dim, - hidden_dim, - bias=False, - params_dtype=dtype, - device=device, - ) - # Keep track of model output - x = torch.randn(dims, dtype=dtype, device=device) - with te.fp8_autocast(): - y_ref = model(x.detach().clone()).detach().clone() - - fp8_meta_ref = {"scaling_fwd": {}, "scaling_bwd": {}} - with te.fp8_autocast(), torch.no_grad(): - fp8_meta_fwd = model.fp8_meta["scaling_fwd"] - fp8_meta_bwd = model.fp8_meta["scaling_bwd"] - fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] - fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] - fp8_meta_fwd_ref["scale"] = torch.rand_like(fp8_meta_fwd.scale) + 0.5 - fp8_meta_fwd_ref["scale_inv"] = fp8_meta_fwd_ref["scale"].reciprocal() - fp8_meta_bwd_ref["scale"] = torch.rand_like(fp8_meta_bwd.scale) + 0.5 - fp8_meta_bwd_ref["scale_inv"] = fp8_meta_bwd_ref["scale"].reciprocal() - fp8_meta_fwd.scale.copy_(fp8_meta_fwd_ref["scale"]) - fp8_meta_fwd.scale_inv.copy_(fp8_meta_fwd_ref["scale_inv"]) - fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"]) - fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"]) - del fp8_meta_fwd, fp8_meta_bwd - - # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] - # This line copies the fp8 scale_inv from the model metadata to the weight fp8 tensor. - # The sole purpose of the following lines is to set the scale_inv of the weight tensor, which is the simplest method. - # It is essential for these values to be equal, so setting scale_inv only in the model metadata is insufficient. - model.weight.data.copy_(model.weight.float().cuda()) - # After copying, the tensor computes the meta scale_inv based on the amax history; we then reset these values. - model.fp8_meta["scaling_fwd"].scale = fp8_meta_fwd_ref["scale"] - model.fp8_meta["scaling_fwd"].scale_inv = fp8_meta_fwd_ref["scale_inv"] - - # Keep track of weights and FP8 scaling factors - weight_ref = model.weight.float().detach().clone() - - # Save checkpoint - byte_stream = io.BytesIO() - torch.save(model.state_dict(), byte_stream) - model_bytes = byte_stream.getvalue() - del byte_stream - - # Disturb and destroy model - with torch.no_grad(): - model.weight.zero_() - model.fp8_meta = {"This": "is", "filled": "with", "nonsense": 1234} - del model - - # Construct new model - with te.fp8_model_init(enabled=load_fp8_model): - model = te.Linear( - hidden_dim, - hidden_dim, - bias=False, - params_dtype=dtype, - device=device, - ) - - # Make sure new model does not match saved model - tols = dict(rtol=0.125, atol=0.0675) # fp8e4me3 epsilon = 0.0625 - with pytest.raises(AssertionError): - torch.testing.assert_close(model.weight, weight_ref, **tols) - with te.fp8_autocast(): - model.init_fp8_metadata() - fp8_meta_fwd = model.fp8_meta["scaling_fwd"] - fp8_meta_bwd = model.fp8_meta["scaling_bwd"] - fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] - fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"]) - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"]) - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"]) - with pytest.raises(AssertionError): - torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"]) - with te.fp8_autocast(): - y = model(x.detach().clone()) - with pytest.raises(AssertionError): - torch.testing.assert_close(y, y_ref, **tols) - - # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] - # When save_fp8_model=True, we load a model with weights in high precision, - # which does not include _scale_inv, - # but has the fp8 scaling factor in the meta data. This scenario can occur - # when using te.fp8_autocast(enabled=False, calibrating=True). - # - # In such cases, the default behavior of load_state_dict is incorrect - it loads tensors first, - # followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior - # is corrected by overriding the _load_state_dict method from PyTorch in TransformerEngineBaseModule, - # to load the fp8 metadata before loading tensors. - # - # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes))) - del model_bytes - - # Check that loaded model matches saved model - torch.testing.assert_close(model.weight, weight_ref, **tols) - with te.fp8_autocast(): - fp8_meta_fwd = model.fp8_meta["scaling_fwd"] - fp8_meta_bwd = model.fp8_meta["scaling_bwd"] - fp8_meta_fwd_ref = fp8_meta_ref["scaling_fwd"] - fp8_meta_bwd_ref = fp8_meta_ref["scaling_bwd"] - torch.testing.assert_close(fp8_meta_fwd.scale, fp8_meta_fwd_ref["scale"]) - torch.testing.assert_close(fp8_meta_fwd.scale_inv, fp8_meta_fwd_ref["scale_inv"]) - torch.testing.assert_close(fp8_meta_bwd.scale, fp8_meta_bwd_ref["scale"]) - torch.testing.assert_close(fp8_meta_bwd.scale_inv, fp8_meta_bwd_ref["scale_inv"]) - with te.fp8_autocast(): - y = model(x.detach().clone()) - torch.testing.assert_close(y, y_ref, **tols) - - if load_fp8_model: - # [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ] - # We need to ensure that the tensor's scale_inv parameter matches its meta data. - # This is crucial to avoid confusion about which value is correct. - meta_index = model.weight._fp8_meta_index - torch.testing.assert_close( - model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item() - ) - - -@pytest.mark.parametrize("fp8", (False, True)) -@pytest.mark.parametrize("save_fp8_model", (False, True)) -@pytest.mark.parametrize("load_fp8_model", (False, True)) -def test_sequential_model( - *, - in_shape: Iterable[int] = (16, 16), - dtype: torch.dtype = torch.float32, - device: torch.device = "cuda", - save_steps: int = 2, - load_steps: int = 2, - fp8: bool, - save_fp8_model: bool, - load_fp8_model: bool, -) -> None: - - # Skip invalid configurations - if fp8 or save_fp8_model or load_fp8_model: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if torch.device(device).type != "cuda": - pytest.skip("FP8 is only supported on CUDA devices") - - # FP8 recipe - margin = 2 - fp8_format = transformer_engine.common.recipe.Format.E4M3 - recipe = transformer_engine.common.recipe.DelayedScaling( - margin=margin, - fp8_format=fp8_format, - amax_history_len=8, - amax_compute_algo="max", - ) - - # Construct model to save to checkpoint - with te.fp8_model_init(enabled=save_fp8_model): - model = te_ops.Sequential( - te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype), - ) - with torch.no_grad(): - torch.rand(model[0].weight.size(), out=model[0].weight) - torch.rand(model[0].bias.size(), out=model[0].bias) - - # Synthetic data - xs_ref = [ - torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps) - ] - dys_ref = [ - torch.rand(in_shape, dtype=dtype, device=device) for _ in range(save_steps + load_steps) - ] - - def train_step( - model: te_ops.Sequential, - x: torch.Tensor, - dy: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Helper function to perform training step""" - x = x.detach().clone().requires_grad_() - dy = dy.detach().clone() - with te.fp8_autocast(enabled=fp8, fp8_recipe=recipe): - y = model(x) - y.backward(dy) - with torch.no_grad(): - for param in model.parameters(): - param += 0.125 - return ( - y.detach().clone(), - x.grad.detach().clone(), - model[0].weight.detach().float().clone(), - ) - - # Initial training steps with saved model - ys_ref = [] - dxs_ref = [] - ws_ref = [] - for step in range(save_steps): - y, dx, w = train_step(model, xs_ref[step], dys_ref[step]) - ys_ref.append(y) - dxs_ref.append(dx) - ws_ref.append(w) - - # Keep track of FP8 metadata if needed - fp8_meta_ref = dict(input={}, param={}, grad_output={}) - if fp8: - for fp8_meta_type, fp8_meta_key in ( - ("input", "scaling_fwd"), - ("param", "scaling_fwd"), - ("grad_output", "scaling_bwd"), - ): - m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] - m_ref = fp8_meta_ref[fp8_meta_type] - m_ref["amax"] = m_model.amax_history.detach().clone() - m_ref["scale"] = m_model.scale.detach().clone() - m_ref["scale_inv"] = m_model.scale_inv.detach().clone() - del m_model, m_ref - - # Save checkpoint - byte_stream = io.BytesIO() - torch.save(model.state_dict(), byte_stream) - model_bytes = byte_stream.getvalue() - del byte_stream - - # More training steps with saved model - for step in range(save_steps, save_steps + load_steps): - y, dx, w = train_step(model, xs_ref[step], dys_ref[step]) - ys_ref.append(y) - dxs_ref.append(dx) - ws_ref.append(w) - - # Disturb and destroy model - with torch.no_grad(): - for param in model.parameters(): - param.zero_() - model[0].basic_ops[0]._fp8_metas = None - del model - - # Construct new model to load from checkpoint - with te.fp8_model_init(enabled=load_fp8_model): - model = te_ops.Sequential( - te_ops.Linear(in_shape[-1], in_shape[-1], device=device, dtype=dtype), - ) - - # Tolerances for numerical checks - tols = {} - if fp8 or save_fp8_model or load_fp8_model: - tols = dict(rtol=0.125, atol=0.0675) # fp8e4me3 epsilon = 0.0625 - exact_tols = dict(rtol=0, atol=0) - - # Training steps with dummy data - for step in range(save_steps): - y, dx, w = train_step( - model, - torch.zeros_like(xs_ref[step]), - torch.zeros_like(dys_ref[step]), - ) - - # Make sure results don't match saved model - with pytest.raises(AssertionError): - torch.testing.assert_close(y, ys_ref[step], **tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(dx, dxs_ref[step], **tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(w, ws_ref[step], **tols) - - # Make sure new model's FP8 metadata doesn't match saved model - if fp8: - for fp8_meta_type, fp8_meta_key in ( - ("input", "scaling_fwd"), - ("param", "scaling_fwd"), - ("grad_output", "scaling_bwd"), - ): - m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] - m_ref = fp8_meta_ref[fp8_meta_type] - with pytest.raises(AssertionError): - torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols) - with pytest.raises(AssertionError): - torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) - - # Load checkpoint - model.load_state_dict(torch.load(io.BytesIO(model_bytes))) - del model_bytes - - # Check that new model's FP8 metadata matches saved model - if fp8: - for fp8_meta_type, fp8_meta_key in ( - ("input", "scaling_fwd"), - ("param", "scaling_fwd"), - ("grad_output", "scaling_bwd"), - ): - m_model = model[0].basic_ops[0].get_fp8_meta(fp8_meta_type)[fp8_meta_key] - m_ref = fp8_meta_ref[fp8_meta_type] - torch.testing.assert_close(m_model.amax_history, m_ref["amax"], **exact_tols) - torch.testing.assert_close(m_model.scale, m_ref["scale"], **exact_tols) - torch.testing.assert_close(m_model.scale_inv, m_ref["scale_inv"], **exact_tols) - - # More training steps with loaded model - for step in range(save_steps, save_steps + load_steps): - y, dx, w = train_step(model, xs_ref[step], dys_ref[step]) - torch.testing.assert_close(y, ys_ref[step], **tols) - torch.testing.assert_close(dx, dxs_ref[step], **tols) - torch.testing.assert_close(w, ws_ref[step], **tols) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index a8b181a187..450c24da33 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index b18ded9775..8b80364a3d 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -19,19 +19,9 @@ except (ImportError, StopIteration) as e: pass -try: - from . import paddle -except (ImportError, StopIteration) as e: - pass - try: import transformer_engine_jax except ImportError: pass -try: - import transformer_engine_paddle -except ImportError: - pass - __version__ = str(metadata.version("transformer_engine")) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3784689f9a..007618ad57 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -6,13 +6,17 @@ cmake_minimum_required(VERSION 3.21) # Language options if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) + endif() endif() set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) if (CMAKE_BUILD_TYPE STREQUAL "Debug") - set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") + set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G") endif() # Hide non-necessary symbols in shared object. @@ -30,14 +34,14 @@ endif() # cuDNN frontend API set(CUDNN_FRONTEND_INCLUDE_DIR - "${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include") if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") message(FATAL_ERROR - "Could not find cuDNN frontend API. " + "Could not find cuDNN frontend API at ${CUDNN_FRONTEND_INCLUDE_DIR}. " "Try running 'git submodule update --init --recursive' " "within the Transformer Engine source.") endif() -include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) @@ -63,23 +67,26 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn.cpp fused_attn/utils.cu gemm/cublaslt_gemm.cu - layer_norm/ln_api.cpp - layer_norm/ln_bwd_semi_cuda_kernel.cu - layer_norm/ln_fwd_cuda_kernel.cu + normalization/common.cpp + normalization/layernorm/ln_api.cpp + normalization/layernorm/ln_bwd_semi_cuda_kernel.cu + normalization/layernorm/ln_fwd_cuda_kernel.cu + normalization/rmsnorm/rmsnorm_api.cpp + normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu + normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu - rmsnorm/rmsnorm_api.cpp - rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu - rmsnorm/rmsnorm_fwd_cuda_kernel.cu util/cast.cu util/padding.cu util/cuda_driver.cpp + util/cuda_nvml.cpp util/cuda_runtime.cpp util/rtc.cpp - util/system.cpp + swizzle/swizzle.cu fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu fused_rope/fused_rope.cu + recipe/current_scaling.cu recipe/delayed_scaling.cu comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -145,6 +152,14 @@ set_source_files_properties(fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") +option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) +if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) + set_source_files_properties(activation/gelu.cu + activation/relu.cu + activation/swiglu.cu + PROPERTIES + COMPILE_OPTIONS "--use_fast_math") +endif() set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 4bcd1f8e27..a8c845efd8 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -83,6 +83,13 @@ def _load_library(): """Load shared library with Transformer Engine C extensions""" so_path = get_te_path() / "transformer_engine" / f"libtransformer_engine.{_get_sys_extension()}" + if not so_path.exists(): + so_path = ( + get_te_path() + / "transformer_engine" + / "wheel_lib" + / f"libtransformer_engine.{_get_sys_extension()}" + ) if not so_path.exists(): so_path = get_te_path() / f"libtransformer_engine.{_get_sys_extension()}" assert so_path.exists(), f"Could not find libtransformer_engine.{_get_sys_extension()}" diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 6184e235bd..708403f911 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -1,114 +1,74 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +/*! \file activation_template.h + * \brief Activation functions template. + */ + +#ifndef TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ +#define TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ + #include #include #include "../common.h" +#include "../util/cast_gated_kernels.cuh" +#include "../util/cast_kernels.cuh" +#include "../util/math.h" #include "../util/vectorized_pointwise.h" namespace transformer_engine { template -void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "act_lu_input"); - CheckOutputTensor(*output, "act_lu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - const size_t tot_elts = product(input.data.shape); +void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = true; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_helper(input, grad, nullptr, output, dbias, + workspace, stream); } template -void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "dact_lu_input"); - CheckInputTensor(grad, "dact_lu_input_grad"); - CheckOutputTensor(*output, "dact_lu_output"); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); - NVTE_CHECK(input.data.dtype == grad.data.dtype, "Input and incoming gradient types must match."); - const size_t tot_elts = product(input.data.shape); +void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), tot_elts, {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_helper(input, grad, nullptr, output, dbias, + workspace, stream); } -template -void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "gated_act_input"); - CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(input.data.shape[0] == output->data.shape[0], - "Input shape[0] must be equal to output shape[0]."); - NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, - "Input shape[1] must be 2x larger than output shape[1]."); +template +void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DGATED = false; + constexpr NVTETensor grad = nullptr; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], - output->data.shape[1], {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_gated_helper(grad, input, output, stream); } -template -void dgated_act_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(grad, "dgated_act_grad"); - CheckInputTensor(input, "dgated_act_input"); - CheckOutputTensor(*output, "dgated_act_output"); - NVTE_CHECK(grad.data.shape.size() == 2, "Grad must have 2 dimensions."); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); - NVTE_CHECK(output->data.shape[0] == grad.data.shape[0], - "Output shape[0] must be equal to grad shape[0]."); - NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2, - "Output shape[1] must be 2x larger than grad shape[1]."); - NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match."); +template +void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + using namespace detail; + constexpr bool IS_DGATED = true; - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), grad.data.shape[0], grad.data.shape[1], - {}, - stream);); // NOLINT(*) - ); // NOLINT(*) + quantize_gated_helper(grad, input, output, stream); } } // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_ACTIVATION_TEMPLATE_H_ diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index f9cd7b845a..0cf43007a7 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -1,71 +1,60 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ + #include "../util/math.h" #include "./activation_template.h" void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_gelu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu); using namespace transformer_engine; - dgated_act_fn, dgelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dgelu>(grad, input, output, stream); } void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgelu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; - dgated_act_fn, dqgelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dqgelu>(grad, input, output, stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index c18d018a8e..a794b7315f 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -10,63 +10,51 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_relu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_drelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu); using namespace transformer_engine; - dgated_act_fn, drelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, drelu>(grad, input, output, stream); } void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_srelu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsrelu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; - dgated_act_fn, dsrelu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dsrelu>(grad, input, output, stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index c745ffeeb4..8194964745 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -10,31 +10,25 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_silu); using namespace transformer_engine; - act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + act_fn>(input, output, stream); } void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsilu); using namespace transformer_engine; - dact_fn>(*reinterpret_cast(grad), - *reinterpret_cast(input), - reinterpret_cast(output), stream); + dact_fn>(grad, input, output, stream); } void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; - gated_act_fn>(*reinterpret_cast(input), - reinterpret_cast(output), stream); + gated_act_fn>(input, output, stream); } void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu); using namespace transformer_engine; - dgated_act_fn, dsilu>( - *reinterpret_cast(grad), *reinterpret_cast(input), - reinterpret_cast(output), stream); + dgated_act_fn, dsilu>(grad, input, output, stream); } diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index a663385b68..3dd5f7228b 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -21,6 +21,8 @@ #define HALF_BYTES 2 #define UB_MAX_SM 32 +#define AS_VECTOR(shape) std::vector(shape.data, shape.data + shape.ndim) + using namespace std::placeholders; namespace transformer_engine { @@ -40,8 +42,9 @@ bool ubuf_built_with_mpi() { CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool use_ce, bool atomic_gemm) { + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { // Initialize userbuf communicator if (!_comm_created) { if (myrank == 0) { @@ -59,9 +62,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; + if (gemm_priority == 0 && comm_priority == 0) { + transformer_engine::cuda::stream_priority_range(&_gemm_priority, &_comm_priority); + } else { + _gemm_priority = gemm_priority; + _comm_priority = comm_priority; + } for (int i = 0; i < std::min(num_max_streams, num_splits); i++) { cudaStream_t stream; - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _gemm_priority)); _stream_compute.push_back(std::move(stream)); } @@ -90,6 +99,23 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl cudaEventCreateWithFlags(&_stop_compute, 0); cudaEventCreateWithFlags(&_start_comm, 0); cudaEventCreateWithFlags(&_stop_comm, 0); + + /* + Defining the launcher order between the communication and GEMM kernels + using Fast Dependent Launch when CUDA_DEVICE_MAX_CONNECTIONS>1. + The event is used to schedule the communication kernel before the GEMM. + This is needed only for Hopper, which uses persistent CTA execution. + */ + int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); + int runtime_version = 0; + cudaRuntimeGetVersion(&runtime_version); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) { + cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); + } else { + _comm_launch_event = 0; + } } CommOverlapCore::~CommOverlapCore() { @@ -97,6 +123,7 @@ CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_start_comm); cudaEventDestroy(_stop_compute); cudaEventDestroy(_start_compute); + if (_comm_launch_event) cudaEventDestroy(_comm_launch_event); if (_atomic_gemm) cudaFree(_counter.dptr()); @@ -112,6 +139,73 @@ CommOverlapCore::~CommOverlapCore() { } } +TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, size_t chunk_offset, + const std::vector &chunk_shape) { + TensorWrapper chunk; + for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) { + auto param_type = static_cast(param_id); + auto param = source.get_parameter(param_type); + auto param_dptr = reinterpret_cast(param.data_ptr); + auto param_dtype = static_cast(param.dtype); + auto param_shape = AS_VECTOR(param.shape); + + if (param_dptr != nullptr) { + if (param_type == NVTETensorParam::kNVTERowwiseData || + param_type == NVTETensorParam::kNVTEColumnwiseData) { + // Offset data pointer + param_dptr += chunk_offset * typeToSize(param_dtype); + param_shape = chunk_shape; + + if (param_type == NVTETensorParam::kNVTEColumnwiseData && + source.scaling_mode() != NVTEScalingMode::NVTE_MXFP8_1D_SCALING) { + // Columnwise shape for non-block scaled tensors shifts the last dimension to the front + auto last_dim = param_shape.back(); + param_shape.pop_back(); + param_shape.insert(param_shape.begin(), last_dim); + } + } else if (source.scaling_mode() == NVTEScalingMode::NVTE_MXFP8_1D_SCALING && + (param_type == NVTETensorParam::kNVTERowwiseScaleInv || + param_type == NVTETensorParam::kNVTEColumnwiseScaleInv)) { + // Calculate block scaling offset and size + auto scaled_tensor_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv) + ? source.shape().data[0] + : source.columnwise_shape().data[0]; + auto scaled_chunk_dim_size = (param_type == NVTETensorParam::kNVTERowwiseScaleInv) + ? chunk_shape.front() + : chunk_shape.back(); + auto chunk_scale_start = chunk_offset / 32; + auto chunk_scale_end = (chunk_offset + scaled_chunk_dim_size) / 32; + auto chunk_scale_size = chunk_scale_end - chunk_scale_start; + param_dptr += chunk_scale_start * typeToSize(param_dtype); + param_shape = std::vector{chunk_scale_size}; + } + + // Set chunked source parameters into the chunked tensor output + chunk.set_parameter(param_type, reinterpret_cast(param_dptr), param_dtype, + param_shape); + } + } + return chunk; +} + +TensorWrapper CommOverlapCore::get_buffer_chunk_like(const TensorWrapper &source, + size_t chunk_offset, + const std::vector &chunk_shape) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, chunk_offset, chunk_shape); + + // Update chunk with offset data pointers from the communication buffer + auto ubuf_ptr = reinterpret_cast(_ubuf.dptr()) + (chunk_offset * _ubuf.element_size()); + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(reinterpret_cast(ubuf_ptr), chunk.dtype(), + chunk.columnwise_shape()); + } + return chunk; +} + /*************************************************************************************************** * Comm+GEMM Overlap Base (Pipelined / Collective) **************************************************************************************************/ @@ -120,11 +214,14 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool atomic_gemm) + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, - num_comm_sm, set_sm_margin, false, atomic_gemm) { + gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, + atomic_gemm) { + _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, "Invalid choice for NVTE_RS_STRIDED_ATOMIC: Must be 0 (non-atomic), 1 (atomic) ", @@ -137,7 +234,8 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, -1)); + NVTE_CHECK_CUDA( + cudaStreamCreateWithPriority(&_stream_comm, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_d2dcopy, 0)); } @@ -150,8 +248,8 @@ CommOverlapBase::~CommOverlapBase() { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ -void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, +void CommOverlapBase::bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, @@ -164,11 +262,13 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper // Catch up the default torch stream NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _start_comm, 0)); // Communication: AG and RS int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size if (comm_type == CommOverlapType::AG) { - allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } else { if (_ubuf.element_size() == 1) { assert(_ubuf_scale_inv_initialized); @@ -177,40 +277,48 @@ void CommOverlapBase::bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper assert(rs_output.size(0) == _ubuf.size(0) / _tp_size); assert(rs_output.element_size() == 2); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf_scale_inv, _ub_reg, 0, - comm_elements, _ub_comm, _stream_comm); + reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(rs_output_ptr, _ubuf.scale_inv(), _ub_reg, 0, + comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } else { - reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm); + reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, _stream_comm, + (cudaEvent_t)_comm_launch_event); } } assert(pre_gelu_out.numel() == 0); + // When the kernel launch order is defined, enforce the GEMM kernel launch to wait for the communication kernel launch + if (_comm_launch_event) + NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _comm_launch_event, 0)); nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, _math_sms, - stream_main); + _stream_compute[0]); _ub_comm->sms = ori_sms; NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_compute[0])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); + } // CommOverlapBase::bulk_overlap /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, +void CommOverlapBase::atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; // Get GEMM dimensions - size_t m = A.size(0); - size_t k = A.size(1); - size_t n = B.size(0); + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n = _ubuf.size(0); size_t m_chunk = m / _num_splits; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); @@ -231,9 +339,8 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens assert(pre_gelu_out.numel() == 0); - auto output_d = TensorWrapper(_ubuf.dptr(), {n, m}, D.dtype(), D.amax(), D.scale(), nullptr); - auto workspace_chunk = - TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto output_d = get_buffer_chunk_like(D, 0, {n, m}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _num_splits, 0, true, _counter.data(), @@ -245,11 +352,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens _ub_comm->sms = UB_MAX_SM; } if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_strided_atomic_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, + rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, m, _num_splits, &counter_ptr[i], _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_strided_atomic(rs_output_ptr, _ub_reg, i * m_chunk, m_chunk, n, m, @@ -258,11 +364,10 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens } } else if (_rs_kernel_type == 2) { if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_strided_multiatomic_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, + rs_output_ptr, D.scale_inv(), _ub_reg, m_chunk, m_chunk, n, m, m, _num_splits, counter_ptr, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_strided_multiatomic(rs_output_ptr, _ub_reg, m_chunk, m_chunk, n, m, @@ -275,7 +380,7 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens if (_ubuf.element_size() == 1) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, _ubuf_scale_inv, + reducescatter2_userbuff_stridedoutput_fp8(rs_output_ptr, D.scale_inv(), _ub_reg, i * m_chunk, m_chunk, n, m, _ub_comm, _stream_comm);); } else { @@ -297,34 +402,24 @@ void CommOverlapBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, Tens /* ** Split FPROP GEMM + ReduceScatter */ -void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, +void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main) { + TensorWrapper &rs_output, cudaStream_t stream_main) { // Get GEMM dimensions int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; - size_t m = A.size(0); - size_t k = A.size(1); - size_t n = B.size(0); + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n = _ubuf.size(0); size_t m_chunk = m / _num_splits; size_t input_a_chunk_size = m_chunk * k; size_t output_chunk_size = n * m_chunk; - size_t bias_chunk_size = m_chunk; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - // Get input, output, and workspace data pointers - char *input_a_chunk_ptr = reinterpret_cast(A.dptr()); - char *output_buf_chunk_ptr = reinterpret_cast(_ubuf.dptr()); - char *bias_chunk_ptr = reinterpret_cast(bias.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); - - char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); - // Catch up the default torch stream NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); for (size_t i = 0; i < _stream_compute.size(); i++) { @@ -334,39 +429,23 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap assert(pre_gelu_out.numel() == 0); - if (gemm_overlap) { - auto input_a_chunk = - TensorWrapper(A.dptr(), {m_chunk, k}, A.dtype(), nullptr, nullptr, A.scale_inv()); - auto output_chunk = - TensorWrapper(_ubuf.dptr(), {m, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); - auto bias_chunk = - TensorWrapper(bias.dptr(), {m_chunk}, bias.dtype(), nullptr, nullptr, nullptr); - auto workspace_chunk = TensorWrapper( - workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); + if (_rs_overlap_first_gemm) { + auto input_a_chunk = get_tensor_chunk(A, 0, {m_chunk, k}); + auto output_chunk = get_buffer_chunk_like(D, 0, {m, m_chunk}); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[0]); for (int i = 1; i < _num_splits; i++) { - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * D.element_size(); - if (bias_chunk_ptr != nullptr) { - bias_chunk_ptr += bias_chunk_size * bias.element_size(); - } - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - - input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, - A.dtype(), nullptr, nullptr, A.scale_inv()); - output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), {n, m_chunk}, - D.dtype(), D.amax(), D.scale(), nullptr); - bias_chunk = TensorWrapper(reinterpret_cast(bias_chunk_ptr), {m_chunk}, bias.dtype(), - nullptr, nullptr, nullptr); - workspace_chunk = TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); + output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk}); + workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); @@ -377,11 +456,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap // Communication chunk if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, + rs_output_ptr, D.scale_inv(), _ub_reg, (i - 1) * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size, @@ -398,12 +476,11 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap // Last communication chunk with max SM _ub_comm->sms = UB_MAX_SM; if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, (_num_splits - 1) * output_chunk_size, - m_chunk, n, m, _ub_comm, _stream_comm);); + rs_output_ptr, D.scale_inv(), _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, + n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (_num_splits - 1) * output_chunk_size, m_chunk, n, m, @@ -411,20 +488,12 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap } } else { for (int i = 0; i < _num_splits; i++) { - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - - auto input_a_chunk = TensorWrapper(reinterpret_cast(input_a_chunk_ptr), {m_chunk, k}, - A.dtype(), nullptr, nullptr, A.scale_inv()); - auto output_chunk = TensorWrapper(reinterpret_cast(output_buf_chunk_ptr), - {n, m_chunk}, D.dtype(), D.amax(), D.scale(), nullptr); - auto bias_chunk = TensorWrapper(reinterpret_cast(bias_chunk_ptr), {m_chunk}, - bias.dtype(), nullptr, nullptr, nullptr); - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); - - nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias_chunk.data(), + auto input_a_chunk = get_tensor_chunk(A, i * input_a_chunk_size, {m_chunk, k}); + auto output_chunk = get_buffer_chunk_like(D, i * output_chunk_size, {n, m_chunk}); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); + + nvte_cublas_gemm(input_a_chunk.data(), B.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); @@ -437,11 +506,10 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap _ub_comm->sms = UB_MAX_SM; } if (_ubuf.element_size() == 1) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, reducescatter2_userbuff_stridedoutput_fp8( - rs_output_ptr, _ubuf_scale_inv, _ub_reg, i * output_chunk_size, m_chunk, n, m, + rs_output_ptr, D.scale_inv(), _ub_reg, i * output_chunk_size, m_chunk, n, m, _ub_comm, _stream_comm);); } else { reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size, @@ -449,11 +517,6 @@ void CommOverlapBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrap } rs_output_ptr += m_chunk * rs_output.element_size(); - input_a_chunk_ptr += input_a_chunk_size * B.element_size(); - output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size(); - if (bias_chunk_ptr != nullptr) { - bias_chunk_ptr += bias_chunk_size * bias.element_size(); - } } } @@ -475,11 +538,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool use_ce, bool atomic_gemm, bool aggregate) + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm, bool aggregate) : CommOverlapCore(myrank, numranks, mylocal, numlocal, mynode, numnodes, tp_size, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, - num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -528,8 +593,13 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_send, cudaStreamNonBlocking, -1)); - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, -1)); + for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + cudaStream_t stream; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); + _stream_send.push_back(std::move(stream)); + } + NVTE_CHECK_CUDA( + cudaStreamCreateWithPriority(&_stream_recv, cudaStreamNonBlocking, _comm_priority)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_send, 0)); NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_recv, 0)); } @@ -538,7 +608,22 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { cudaEventDestroy(_stop_recv); cudaEventDestroy(_stop_send); cudaStreamDestroy(_stream_recv); - cudaStreamDestroy(_stream_send); + for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]); +} + +TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, + size_t chunk_id) { + // Start with a chunk of the source tensor + auto chunk = get_tensor_chunk(source, 0, AS_VECTOR(_ubufs[chunk_id].shape())); + + // Update chunk with offset data pointers from the communication buffer + if (chunk.dptr() != nullptr) { + chunk.set_rowwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.shape()); + } + if (chunk.columnwise_dptr() != nullptr) { + chunk.set_columnwise_data(_ubufs[chunk_id].dptr(), chunk.dtype(), chunk.columnwise_shape()); + } + return chunk; } /* @@ -546,12 +631,10 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ -void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { +void CommOverlapP2PBase::atomic_gemm_overlap_ag( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -559,8 +642,7 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T // Get GEMM dimensions between TN and NN input layouts const size_t m = (transa) ? A.size(0) : A.size(1); - const size_t n = _ubuf.size(0); - const size_t n_chunk = n / _tp_size; + const size_t n_chunk = _ubufs[0].size(0); assert(pre_gelu_out.numel() == 0); // Get communication and GEMM output chunk sizes @@ -570,7 +652,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T void *D_buffer_ptr; int D_chunk_bytes = n_chunk * m * D.element_size(); NVTE_CHECK_CUDA(cudaMallocAsync(&D_buffer_ptr, (_tp_size + 1) * D_chunk_bytes, stream_main)); - auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); + auto D_buffer = TensorWrapper(D_buffer_ptr, D.shape(), D.dtype(), D.amax(), D.scale(), + D.scale_inv(), D.scale_inv_shape(), D.scaling_mode()); // Reset atomic counters int *counter_ptr = reinterpret_cast(_counter.dptr()); @@ -578,13 +661,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T // Catch up the default torch stream NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); - auto input_b = TensorWrapper(_ubuf.dptr(), B.shape(), B.dtype(), nullptr, nullptr, B.scale_inv()); + auto input_b = get_buffer_chunk_like(B, 0, AS_VECTOR(B.shape())); size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = - TensorWrapper(workspace.dptr(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto workspace_chunk = get_tensor_chunk(workspace, 0, {workspace_size_chunk}); for (int i = 0; i < _tp_size - 1; i++) { // Set the userbuffer id. Buffer under send is the input for the current @@ -625,8 +707,8 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T NVTE_CHECK_CUDA( cudaMemcpyAsync(B_copy.dptr(), _ubufs[_self_chunk_id].dptr(), _ubufs[_self_chunk_id].numel() * _ubufs[_self_chunk_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send)); - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + cudaMemcpyDeviceToDevice, _stream_send[0])); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); } @@ -650,11 +732,12 @@ void CommOverlapP2PBase::atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, T ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ -void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &B_copy, cudaStream_t stream_main) { +void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -667,24 +750,20 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW // Get communication and GEMM output chunk sizes const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); const bool do_gelu = pre_gelu_out.numel() > 0; - const int output_chunk_bytes = (n_chunk * m) * D.element_size(); - const int aux_chunk_bytes = do_gelu ? (n_chunk * m) * pre_gelu_out.element_size() : 0; - - // Get output and workspace data pointers - char *output_ptr = reinterpret_cast(D.dptr()); - char *pre_gelu_out_ptr = reinterpret_cast(pre_gelu_out.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); } if (_aggregate) { const int num_steps = _tp_size / 2; - char *input_b_ptr = reinterpret_cast(_ubuf.dptr()); + input_chunk_size *= 2; + output_chunk_size *= 2; // Initial 1X input chunk exchange between neighboring peers int send_chunk_id = _tp_id; @@ -693,11 +772,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW int recv_offset = comm_bytes * recv_chunk_id; int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank; userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank, - _stream_send); + _stream_send[0]); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[0], _stop_recv, 0)); int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1; @@ -712,27 +791,15 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW recv_offset = comm_bytes * recv_chunk_id; // GEMM - char *input_b_chunk_ptr = input_b_ptr + send_offset; auto input_b_chunk = - TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk * 2, k}, B.dtype(), - nullptr, nullptr, B.scale_inv()); - - char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); - auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), - {n_chunk * 2, m}, D.dtype(), D.amax(), D.scale(), nullptr); - - char *aux_chunk_ptr = - (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; - auto aux_chunk_shape = - (do_gelu) ? std::vector{n_chunk * 2, m} : std::vector{0}; - auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, - pre_gelu_out.dtype()); - - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk * 2, k}); + auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk * 2, m}); + auto aux_chunk = + (do_gelu) + ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk * 2, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, @@ -742,11 +809,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW if (i < num_steps - 1) { // P2P communication userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm, - next_rank, _stream_send); + next_rank, _stream_send[0]); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm, prev_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { @@ -754,7 +821,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send)); + cudaMemcpyDeviceToDevice, _stream_send[0])); } } } else { @@ -769,24 +836,14 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW int recv_offset = comm_bytes * recv_chunk_id; // GEMM - auto input_b_chunk = TensorWrapper(_ubufs[send_chunk_id].dptr(), {n_chunk, k}, B.dtype(), - nullptr, nullptr, B.scale_inv()); - - char *output_chunk_ptr = output_ptr + (send_chunk_id * output_chunk_bytes); - auto output_chunk = TensorWrapper(reinterpret_cast(output_chunk_ptr), {n_chunk, m}, - D.dtype(), D.amax(), D.scale(), nullptr); - - char *aux_chunk_ptr = - (do_gelu) ? pre_gelu_out_ptr + (send_chunk_id * aux_chunk_bytes) : nullptr; - auto aux_chunk_shape = (do_gelu) ? std::vector{n_chunk, m} : std::vector{0}; - auto aux_chunk = TensorWrapper(reinterpret_cast(aux_chunk_ptr), aux_chunk_shape, - pre_gelu_out.dtype()); - - char *workspace_chunk_ptr = - workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; - auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + auto input_b_chunk = get_buffer_chunk_like(B, input_chunk_size * send_chunk_id, {n_chunk, k}); + auto output_chunk = get_tensor_chunk(D, output_chunk_size * send_chunk_id, {n_chunk, m}); + auto aux_chunk = + (do_gelu) + ? get_tensor_chunk(pre_gelu_out, output_chunk_size * send_chunk_id, {n_chunk, k}) + : TensorWrapper(nullptr, std::vector{0}, pre_gelu_out.dtype()); + auto workspace_chunk = get_tensor_chunk( + workspace, (i % _stream_compute.size()) * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), aux_chunk.data(), transa, transb, grad, workspace_chunk.data(), accumulate, @@ -796,11 +853,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW if (i < _tp_size - 1) { // P2P communication userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, - _next_rank, _stream_send); + _next_rank, _stream_send[0]); userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, _prev_rank, _stream_recv); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _stop_recv, 0)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); } else if (B_copy.numel() > 0) { @@ -808,7 +865,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), _ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(), - cudaMemcpyDeviceToDevice, _stream_send)); + cudaMemcpyDeviceToDevice, _stream_send[0])); } } } @@ -818,7 +875,7 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[0])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); @@ -827,13 +884,11 @@ void CommOverlapP2PBase::split_overlap_ag(TensorWrapper &A, bool transa, TensorW /* ** Split ReduceScatter + GEMM using P2P communication */ -void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, - bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, - cudaStream_t stream_main) { +void CommOverlapP2PBase::atomic_gemm_overlap_rs( + const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; @@ -852,14 +907,10 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T // Atomic GEMM // Process GEMM chunks in the order that AG+GEMM places the output chunks. - auto output_d = TensorWrapper(_ubuf.dptr(), D.shape(), D.dtype(), D.amax(), D.scale(), nullptr); - size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); - auto workspace_chunk = - TensorWrapper(workspace.data(), std::vector{workspace_size_chunk}, workspace.dtype()); + auto output_d = get_buffer_chunk_like(D, 0, AS_VECTOR(D.shape())); nvte_cublas_atomic_gemm(A.data(), B.data(), output_d.data(), bias.data(), pre_gelu_out.data(), - transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, 0, _tp_size, true, _counter.data(), - stream_main); + transa, transb, grad, workspace.data(), accumulate, use_split_accumulator, + _math_sms, 0, _tp_size, true, _counter.data(), stream_main); // P2P communication chunk for (int i = 1; i < _tp_size; i++) { @@ -883,10 +934,9 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, _ubufs[0].numel(), stream_main);); } else { reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); @@ -897,31 +947,33 @@ void CommOverlapP2PBase::atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, T /* ** Split ReduceScatter + GEMM using P2P communication */ -void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, - bool transb, TensorWrapper &D, TensorWrapper &bias, - TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - TensorWrapper &rs_output, cudaStream_t stream_main) { +void CommOverlapP2PBase::split_overlap_rs(const TensorWrapper &A, bool transa, + const TensorWrapper &B, bool transb, TensorWrapper &D, + TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) { int ori_sms = _ub_comm->sms; _ub_comm->use_ce = _use_ce; _ub_comm->sms = _num_comm_sm; _ub_comm->cga_size = _cga_size; - size_t k = A.size(1); - size_t n = B.size(0); // Get communication and GEMM input chunk sizes - size_t n_chunk = n / _tp_size; + size_t m = transa ? A.size(0) : A.size(1); + size_t k = transa ? A.size(1) : A.size(0); + size_t n_chunk = _ubufs[0].size(0); const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); - const int input_b_chunk_bytes = n_chunk * k * B.element_size(); // Get input and workspace data pointers - char *input_b_ptr = reinterpret_cast(B.dptr()); - char *workspace_ptr = reinterpret_cast(workspace.dptr()); + size_t input_chunk_size = n_chunk * k; + size_t output_chunk_size = n_chunk * m; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); // Catch up the main stream NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_compute, 0)); + for (size_t i = 0; i < _stream_send.size(); i++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[i], _start_compute, 0)); + } NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_compute[i], _start_compute, 0)); @@ -930,36 +982,30 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW // GEMM and send/recv chunks for (int i = 0; i < _tp_size; i++) { // GEMM chunk + int stream_id = i % _stream_compute.size(); int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; - char *input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); - - auto input_b_chunk = TensorWrapper(reinterpret_cast(input_b_chunk_ptr), {n_chunk, k}, - B.dtype(), nullptr, nullptr, B.scale_inv()); - auto output_chunk = - TensorWrapper(_ubufs[i].dptr(), _ubufs[i].shape(), D.dtype(), D.amax(), D.scale(), nullptr); - - char *workspace_chunk_ptr = workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk; + auto input_b_chunk = get_tensor_chunk(B, input_b_chunk_id * input_chunk_size, {n_chunk, k}); + auto output_chunk = get_buffer_chunk_by_id(D, i); auto workspace_chunk = - TensorWrapper(reinterpret_cast(workspace_chunk_ptr), - std::vector{workspace_size_chunk}, workspace.dtype()); + get_tensor_chunk(workspace, stream_id * workspace_size_chunk, {workspace_size_chunk}); nvte_cublas_gemm(A.data(), input_b_chunk.data(), output_chunk.data(), bias.data(), pre_gelu_out.data(), transa, transb, grad, workspace_chunk.data(), accumulate, - use_split_accumulator, _math_sms, _stream_compute[i % _stream_compute.size()]); + use_split_accumulator, _math_sms, _stream_compute[stream_id]); if (i > 0) { // P2P communication chunk + int prev_stream_id = (i - 1) % _stream_compute.size(); int send_offset = comm_bytes * (i - 1); int recv_offset = comm_bytes * (i - 1 + _tp_size); int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; - NVTE_CHECK_CUDA( - cudaEventRecord(_start_comm, _stream_compute[(i - 1) % _stream_compute.size()])); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send, _start_comm, 0)); + NVTE_CHECK_CUDA(cudaEventRecord(_start_comm, _stream_compute[prev_stream_id])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[prev_stream_id], _start_comm, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_comm, 0)); userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, send_rank, - _stream_send); + _stream_send[prev_stream_id]); userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, recv_rank, _stream_recv); } @@ -969,8 +1015,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_compute, 0)); } - NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + for (size_t i = 0; i < _stream_compute.size(); i++) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_send, _stream_send[i])); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_send, 0)); + } NVTE_CHECK_CUDA(cudaEventRecord(_stop_recv, _stream_recv)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_recv, 0)); @@ -978,11 +1026,10 @@ void CommOverlapP2PBase::split_overlap_rs(TensorWrapper &A, bool transa, TensorW char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].dptr()); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { - assert(_ubuf_scale_inv_initialized); char *rs_output_ptr = reinterpret_cast(rs_output.dptr()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( D.dtype(), fp8_type, - reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, _ubuf_scale_inv, _tp_size, + reduce_fp8_in_bf16_out(reduce_buf_ptr, rs_output_ptr, D.scale_inv(), _tp_size, _ubufs[0].numel(), stream_main);); } else { reduce_bf16(reduce_buf_ptr, rs_output_ptr, _tp_size, _ubufs[0].numel(), stream_main); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc index 2fc6ffbdf9..71ea00de3a 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.cc @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h index 979df384a8..aa6021a190 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/ipcsocket.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 6f3eef3d28..e52cdd8a1f 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -20,6 +20,7 @@ #include #include "common/util/cuda_driver.h" +#include "common/util/cuda_nvml.h" #include "common/util/cuda_runtime.h" #include "common/util/logging.h" #include "common/util/system.h" @@ -29,7 +30,6 @@ #ifdef NVTE_UB_WITH_MPI static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD; static MPI_Comm EXT_COMM_INTRA; -static MPI_Comm EXT_COMM_INTER; #define UB_MPI_CHECK(expr) \ do { \ @@ -58,11 +58,20 @@ void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); } #else #define EXT_COMM_WORLD "world" #define EXT_COMM_INTRA "intra" -#define EXT_COMM_INTER "inter" #endif #define MULTICAST_GB_TOTAL 512 +#if CUDART_VERSION < 12030 +// MNNVL: FABRIC handle support lifted from CUDA 12.3 +#define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL) +#define CU_IPC_HANDLE_SIZE 64 +typedef struct CUmemFabricHandle_st { + unsigned char data[CU_IPC_HANDLE_SIZE]; +} CUmemFabricHandle_v1; +typedef CUmemFabricHandle_v1 CUmemFabricHandle; +#endif + int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); } #define IPCCHECK(cmd) \ @@ -82,18 +91,43 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co } \ } while (0); -int pipe_rank(communicator *comm, int step) { - int mynode = comm->myrank / comm->nvsize; - int mylocal = comm->nvrank; - int numlocal = comm->nvsize; - - int newlocal1 = mylocal + step * comm->ar_nvsize * comm->ar2_nvsize; - int newlocal = (numlocal + (newlocal1 % numlocal)) % numlocal; - int newnode = mynode; - newnode += (newlocal1 - newlocal) / numlocal * comm->num_nodes * comm->num2_nodes; - int allnodes = comm->nranks / comm->nvsize; - newnode = (allnodes + (newnode % allnodes)) % allnodes; - return newnode * numlocal + newlocal; +bool has_mnnvl_fabric(int device_id) { +#if CUDA_VERSION < 12040 + if (getenv("NVTE_UBDEBUG")) { + printf( + "TransformerEngine does not support multi-node NVLINK " + "since it was not built with CUDA version >= 12.4.\n"); + } + return false; +#else + bool mnnvl_fabric_support = false; + CUdevice dev; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id); + int fabric_handle_supported = 0; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &fabric_handle_supported, + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev); + if (fabric_handle_supported) { + NVTE_CALL_CHECK_CUDA_NVML(nvmlInit_v2); + nvmlDevice_t local_device; + NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetHandleByIndex_v2, device_id, &local_device); + nvmlGpuFabricInfoV_t fabricInfo = {}; + fabricInfo.version = nvmlGpuFabricInfo_v2; + fabricInfo.clusterUuid[0] = '\0'; + NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetGpuFabricInfoV, local_device, &fabricInfo); + NVTE_CALL_CHECK_CUDA_NVML(nvmlShutdown); + if (fabricInfo.state >= NVML_GPU_FABRIC_STATE_COMPLETED && fabricInfo.clusterUuid[0] != '\0') { + mnnvl_fabric_support = true; + } + } + if (getenv("NVTE_UBDEBUG")) { + if (mnnvl_fabric_support) { + printf("MNNVL NVLINK is supported on this platform.\n"); + } else { + printf("MNNVL NVLINK is not supported on this platform.\n"); + } + } + return mnnvl_fabric_support; +#endif } int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, @@ -122,10 +156,6 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, (*comm)->use_ce = 0; (*comm)->cga_size = 2; for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0; - (*comm)->head = 0; - (*comm)->tail = 0; - (*comm)->active_nreqs = 0; - for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1; int device_clock = 0; // 110 sec wait time by default @@ -182,29 +212,14 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, // ar2 has step equal to ar_nvsize int allnodes = numranks / numlocal; int nodeid = myrank / numlocal; - int datanodes = allnodes / pipenodes / tensornodes; - int pipenodegroup_id = myrank / numlocal / (datanodes * tensornodes); - (*comm)->pipe_id = pipegpus * pipenodegroup_id + mylocal / (datagpus * tensorgpus); - - (*comm)->comm_inter = EXT_COMM_INTER; - (*comm)->first_node = nodeid - mynode; (*comm)->num_nodes = numnodes; (*comm)->my_node = mynode; - (*comm)->num2_nodes = tensornodes; - (*comm)->my2_node = (mynode / datanodes) % tensornodes; - (*comm)->first2_node = mynode - (*comm)->my2_node * datanodes; - - (*comm)->fifo = reinterpret_cast(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS)); - (*comm)->nblocks = 8; - (*comm)->alignblock = 1024 * 512; - (*comm)->minblock = 1024 * 2 * 1024; - (*comm)->asyncblocks = 16; - #define NBUF 2 #if CUDART_VERSION >= 12010 + bool mnnvl_fabric = has_mnnvl_fabric(cur_dev); if (!transformer_engine::getenv("UB_SKIPMC") && transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) { // multicast init only for TP ops (____2 operations) @@ -215,7 +230,8 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, CUmulticastObjectProp mcProp = {}; mcProp.numDevices = (*comm)->ar2_nvsize; mcProp.size = (*comm)->mc_maxsize; - mcProp.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + mcProp.handleTypes = + mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; NVTE_CALL_CHECK_CUDA_DRIVER( cuMulticastGetGranularity, &gran, &mcProp, @@ -223,46 +239,78 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, mc_maxsize = ((mc_maxsize + gran - 1) / gran) * gran; mcProp.size = mc_maxsize; (*comm)->mc_maxsize = mc_maxsize; - - // Broadcast the a POSIX file descriptor from the local root rank to other local ranks. - // NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the - // file descriptor and prevent cuMemImportFromShareableHandle() from correctly - // interpreting the file. Instead, we use Unix domain sockets for the kernel to - // recreate the correct file descriptor on every receiving rank. - int fd; - volatile uint32_t abortFlag = 0; - IpcSocketHandle ipcSock = {0}; - uint64_t opId = 0xdeadcafeb000 + (*comm)->ar2_firstgpu; - ipcSocketResult_t ret = ipcSocketSuccess; - IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag)); - (*comm)->_barrier((*comm)->comm_world); - - if ((*comm)->ar2_nvrank == 0) { + if ((*comm)->ar2_nvrank == 0) NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastCreate, &(*comm)->mc_handle, &mcProp); - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemExportToShareableHandle, reinterpret_cast(&fd), (*comm)->mc_handle, - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), - (uint64_t)0); - for (int p = 1; p < (*comm)->ar2_nvsize; p++) { - (*comm)->_barrier((*comm)->comm_intra); - IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); + if (mnnvl_fabric) { + CUmemFabricHandle *exphndl = + reinterpret_cast(malloc(sizeof(CUmemFabricHandle))); + CUmemFabricHandle *tmphndl = + reinterpret_cast(malloc(sizeof(CUmemFabricHandle))); + CUmemFabricHandle *exphndls; + NVTE_CHECK_CUDA(cudaMallocHost(&exphndls, (*comm)->nvsize * sizeof(CUmemFabricHandle))); + if ((*comm)->ar2_nvrank == 0) + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, static_cast(tmphndl), + (*comm)->mc_handle, CU_MEM_HANDLE_TYPE_FABRIC, 0); + for (int grp = 0; grp < (*comm)->ar_nvsize; + grp++) { // we do N broadcasts for N TP groups in NVL domain + int root = grp * (*comm)->ar2_nvsize; + + // It just needs to be a bcast but reuse existing allgather comm + (*comm)->_allgather( + reinterpret_cast(exphndls), (*comm)->nvsize * sizeof(CUmemFabricHandle), + reinterpret_cast(tmphndl), sizeof(CUmemFabricHandle), (*comm)->comm_intra); + + //save data if brodcast was from rank 0 in our group + if ((*comm)->ar2_firstgpu == root) + memcpy(exphndl, exphndls + root, sizeof(CUmemFabricHandle)); } + if ((*comm)->ar2_nvrank != 0) + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &(*comm)->mc_handle, + reinterpret_cast(exphndl), CU_MEM_HANDLE_TYPE_FABRIC); + free(exphndl); + free(tmphndl); + NVTE_CHECK_CUDA(cudaFreeHost(exphndls)); } else { - for (int p = 1; p < (*comm)->ar2_nvsize; p++) { - (*comm)->_barrier((*comm)->comm_intra); - if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error); + // Broadcast the a POSIX file descriptor from the local root rank to other local ranks. + // NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the + // file descriptor and prevent cuMemImportFromShareableHandle() from correctly + // interpreting the file. Instead, we use Unix domain sockets for the kernel to + // recreate the correct file descriptor on every receiving rank. + int fd; + volatile uint32_t abortFlag = 0; + IpcSocketHandle ipcSock = {0}; + uint64_t opId = 0xdeadcafe0000 + (*comm)->my_node; + ipcSocketResult_t ret = ipcSocketSuccess; + IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag)); + (*comm)->_barrier((*comm)->comm_world); + + if ((*comm)->ar2_nvrank == 0) { + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemExportToShareableHandle, reinterpret_cast(&fd), (*comm)->mc_handle, + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), + (uint64_t)0); + + for (int p = 1; p < (*comm)->ar2_nvsize; p++) { + (*comm)->_barrier((*comm)->comm_intra); + IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); + } + } else { + for (int p = 1; p < (*comm)->ar2_nvsize; p++) { + (*comm)->_barrier((*comm)->comm_intra); + if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error); + } } - } - error: - if ((*comm)->ar2_nvrank != 0) { - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast(fd), - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + error: + if ((*comm)->ar2_nvrank != 0) { + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast(fd), + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + } + IPCCHECK(ipcSocketClose(&ipcSock)); + close(fd); } - IPCCHECK(ipcSocketClose(&ipcSock)); - close(fd); NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle, (CUdeviceptr)(*comm)->mydev); @@ -327,12 +375,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, if (getenv("NVTE_UBDEBUG")) printf( - "%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP " - "%dx%d PIPE_ID %d/%d\n", + "%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP x%d TPGROUP " + "%dx%d\n", myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node, - (*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes, - (*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id, - pipegpus * pipenodes); + (*comm)->ar_nvrank, (*comm)->my_node, (*comm)->ar2_nvrank, (*comm)->ar_nvsize, + (*comm)->num_nodes, (*comm)->ar2_nvsize); fflush(NULL); return 0; @@ -361,43 +408,16 @@ int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipe UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_WORLD, &myrank)); UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_WORLD, &numranks)); - // find intranode numbers and make internode communicator - char hostname[MPI_MAX_PROCESSOR_NAME]; - int namelen; - UB_MPI_CHECK(MPI_Get_processor_name(hostname, &namelen)); - - char(*hostnames)[MPI_MAX_PROCESSOR_NAME] = - static_cast(malloc(numranks * MPI_MAX_PROCESSOR_NAME)); - strcpy(hostnames[myrank], hostname); // NOLINT(*) - for (int n = 0; n < numranks; n++) - UB_MPI_CHECK(MPI_Bcast(&(hostnames[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD)); - qsort(hostnames, numranks, MPI_MAX_PROCESSOR_NAME, stringCmp); - - int color = 0; - for (int n = 0; n < numranks; n++) { - if (n > 0 && strcmp(hostnames[n - 1], hostnames[n])) color++; - if (strcmp(hostname, hostnames[n]) == 0) break; - } - free(hostnames); - int mylocal, numlocal; - UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, color, myrank, &EXT_COMM_INTRA)); + UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, myrank / tensorgpus, myrank, &EXT_COMM_INTRA)); UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTRA, &mylocal)); UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTRA, &numlocal)); // find internode numbers and make internode communicator NVTE_CHECK_CUDA(cudaFree(0)); - int allnodes = numranks / numlocal; - int datanodes = allnodes / pipenodes / tensornodes; - // data reduction group node belongs, equals 0 for all if both pipenodes=1 and tensornodes=1 - int datanodegroup_id = myrank / numlocal / datanodes; - // mpi communicator only needed for SHARP which is always allreduce1/data-parallel - UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, mylocal + numlocal * datanodegroup_id, myrank, - &EXT_COMM_INTER)); - // different rails from same group are in different subcommunicators int mynode, numnodes; - UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTER, &numnodes)); - UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTER, &mynode)); + mynode = myrank / numlocal; + numnodes = numranks / numlocal; // finally call the abstracted constructor with MPI info return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, @@ -447,13 +467,11 @@ void destroy_communicator(communicator *comm) { if (comm->use_mc) { NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle); } - free(comm->fifo); delete comm; } void destroy_communicator_mpi(communicator *comm) { #ifdef NVTE_UB_WITH_MPI - MPI_Comm_free(static_cast(&(comm->comm_inter))); MPI_Comm_free(static_cast(&(comm->comm_intra))); destroy_communicator(comm); #else @@ -472,6 +490,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * #if CUDART_VERSION >= 12010 if (comm->use_mc && alloc) { + bool mnnvl_fabric = has_mnnvl_fabric(comm->mydev); int nranks = comm->nvsize; // total GPUs in NVLINK domain int myrank = comm->nvrank; void **remptrs = reinterpret_cast(malloc(nranks * sizeof(void *))); @@ -481,7 +500,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; prop.location.id = comm->mydev; prop.requestedHandleTypes = - CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; // CU_MEM_HANDLE_TYPE_FABRIC; + mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; size_t granularity = 0; NVTE_CALL_CHECK_CUDA_DRIVER( @@ -507,41 +526,58 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * NVTE_CALL_CHECK_CUDA_DRIVER(cuMemCreate, &(comm->uchandles[hndl][myrank]), aligned_size, &prop, (uint64_t)0); - int *peerfd = reinterpret_cast(malloc(nranks * sizeof(int))); - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemExportToShareableHandle, reinterpret_cast(&peerfd[myrank]), - comm->uchandles[hndl][myrank], - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), - (uint64_t)0); - - volatile uint32_t abortFlag = 0; - IpcSocketHandle ipcSock = {0}; - uint64_t opId = 0xdeadcafebeef; - ipcSocketResult_t ret = ipcSocketSuccess; - - // All-gather POSIX file descriptors across local ranks - IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag)); - for (int p = 1; p < nranks; p++) { - int send_to = (myrank + p) % nranks; - int recv_from = (myrank + nranks - p) % nranks; - comm->_barrier(comm->comm_intra); - IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret, error); - IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error); - } + if (mnnvl_fabric) { + CUmemFabricHandle *exphndl; + CUmemFabricHandle myhndl; + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, &myhndl, + comm->uchandles[hndl][myrank], CU_MEM_HANDLE_TYPE_FABRIC, 0); + NVTE_CHECK_CUDA(cudaMallocHost(&exphndl, comm->nvsize * sizeof(CUmemFabricHandle))); + comm->_allgather(reinterpret_cast(exphndl), comm->nvsize * sizeof(CUmemFabricHandle), + reinterpret_cast(&myhndl), sizeof(CUmemFabricHandle), + comm->comm_intra); + for (int p = 0; p < nranks; p++) + if (p != myrank) + NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &comm->uchandles[hndl][p], + reinterpret_cast(&exphndl[p]), + CU_MEM_HANDLE_TYPE_FABRIC); + NVTE_CHECK_CUDA(cudaFreeHost(exphndl)); + } else { + int *peerfd = reinterpret_cast(malloc(nranks * sizeof(int))); + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemExportToShareableHandle, reinterpret_cast(&peerfd[myrank]), + comm->uchandles[hndl][myrank], + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), + (uint64_t)0); - error: - IPCCHECK(ipcSocketClose(&ipcSock)); + volatile uint32_t abortFlag = 0; + IpcSocketHandle ipcSock = {0}; + uint64_t opId = 0xdeadcafe0000 + comm->my_node; + ipcSocketResult_t ret = ipcSocketSuccess; + + // All-gather POSIX file descriptors across local ranks + IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag)); + for (int p = 1; p < nranks; p++) { + int send_to = (myrank + p) % nranks; + int recv_from = (myrank + nranks - p) % nranks; + comm->_barrier(comm->comm_intra); + IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret, + error); + IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error); + } - for (int p = 0; p < nranks; p++) { - if (p != myrank) - NVTE_CALL_CHECK_CUDA_DRIVER( - cuMemImportFromShareableHandle, &comm->uchandles[hndl][p], - reinterpret_cast(peerfd[p]), - static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); - close(peerfd[p]); - } - free(peerfd); + error: + IPCCHECK(ipcSocketClose(&ipcSock)); + for (int p = 0; p < nranks; p++) { + if (p != myrank) + NVTE_CALL_CHECK_CUDA_DRIVER( + cuMemImportFromShareableHandle, &comm->uchandles[hndl][p], + reinterpret_cast(peerfd[p]), + static_cast(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); + close(peerfd[p]); + } + free(peerfd); + } CUdeviceptr ptr; NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &ptr, (size_t)(aligned_size * nranks), (size_t)0, (CUdeviceptr)0, (uint64_t)0); @@ -571,13 +607,13 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * cudaMemcpy((reinterpret_cast(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)), remptrs, nranks * sizeof(void *), cudaMemcpyHostToDevice)); free(remptrs); - comm->memflags[hndl] = UB_MEM_UC_CONTIG | UB_MEM_ALLOCATED; + comm->memflags[hndl] = NVTE_UB_MEM_UC_CONTIG | NVTE_UB_MEM_ALLOCATED; if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) { NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastBindMem, comm->mc_handle, comm->mc_offset, comm->uchandles[hndl][myrank], (size_t)0 /*memOffset*/, aligned_size, (uint64_t)0); - comm->memflags[hndl] |= UB_MEM_MC_CREATED; + comm->memflags[hndl] |= NVTE_UB_MEM_MC_CREATED; comm->mc_ptr[hndl] = reinterpret_cast(comm->mc_baseptr) + comm->mc_offset; comm->mc_offset += aligned_size; } else if (!comm->myrank) { diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 26843d8107..1211392e40 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -9,16 +9,17 @@ #include #if __CUDA_ARCH__ >= 800 -#include -#define half nv_bfloat16 +#define half_dtype nv_bfloat16 #else -#include +#define half_dtype half #endif #include #include #include +#include "common/util/system.h" +#include "common/util/vectorized_pointwise.h" #include "userbuffers.h" #define MAX_THREADS 1024 @@ -115,11 +116,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -199,11 +200,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -310,11 +311,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -377,11 +378,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -779,7 +780,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ int physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; int lastSM = 0; - half hscale = (half)*scale; + half_dtype hscale = (half_dtype)*scale; if (threadIdx.x < RANKS) { physgpu = myrank * gpustep + firstrank; @@ -822,13 +823,13 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ } int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}}; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 0; i < RANKS; i++) { fp8type *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]); + for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half_dtype)(x[j]); } int hline = 2 * line; (reinterpret_cast(outbuf))[(hline / rowlines) * skiplines + (hline % rowlines)] = @@ -854,7 +855,7 @@ __global__ void __launch_bounds__(MAX_THREADS) int physgpu, targetgpu, *myptr; int *reduceidptr, reduce_id; int lastSM = 0; - half hscale = (half)*scale; + half_dtype hscale = (half_dtype)*scale; if (threadIdx.x < RANKS) { physgpu = myrank * gpustep + firstrank; @@ -918,13 +919,14 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum[2] = {{0, 0, 0, 0}, {0, 0, 0, 0}}; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 0; i < RANKS; i++) { fp8type *x = reinterpret_cast(&val[i]); #pragma unroll - for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) s[j] += hscale * (half)(x[j]); + for (int j = 0; j < sizeof(int4) / sizeof(fp8type); j++) + s[j] += hscale * (half_dtype)(x[j]); } (reinterpret_cast(outbuf))[index1_out] = sum[0]; (reinterpret_cast(outbuf))[index2_out] = sum[1]; @@ -987,11 +989,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_ } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -1077,11 +1079,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -1168,11 +1170,11 @@ __global__ void __launch_bounds__(MAX_THREADS) } int4 sum = val[0]; - half *s = reinterpret_cast(&sum); + half_dtype *s = reinterpret_cast(&sum); #pragma unroll for (int i = 1; i < RANKS; i++) { - half *x = reinterpret_cast(&val[i]); + half_dtype *x = reinterpret_cast(&val[i]); #pragma unroll for (int j = 0; j < 8; j++) s[j] += x[j]; } @@ -1366,6 +1368,28 @@ __global__ void __launch_bounds__(MAX_THREADS) cfg.attrs = attribute_ub; \ cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1; +#if (CUDART_VERSION >= 12030) +#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ + attribute_ub[2].id = cudaLaunchAttributeLaunchCompletionEvent; \ + attribute_ub[2].val.launchCompletionEvent.event = comm_launch_event; +#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 3 +#else +#define ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) +#define NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH 2 +#endif + +#define SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, threads, stream, comm_launch_event) \ + cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \ + cudaLaunchAttribute attribute_ub[NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH] = {}; \ + ADD_LAUNCH_COMPLETION_EVENT(attribute_ub, comm_launch_event) \ + attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \ + attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \ + attribute_ub[1].val.clusterDim.y = 1; \ + attribute_ub[1].val.clusterDim.z = 1; \ + attribute_ub[0].id = cudaLaunchAttributeCooperative; \ + cfg.attrs = attribute_ub; \ + cfg.numAttrs = NUM_LAUNCH_ATTRIBUTE_FOR_FDL_LAUNCH; + #define callranks_ag(x) \ if (ar_nvsize == x) { \ int arg1 = op - NVTE_MAX_OPS, \ @@ -1658,6 +1682,7 @@ void reducescatter2_userbuff_strided(void *output, const int handler, const int SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8) + callranks_rs_oop_stride(16) callranks_rs_oop_stride(32) } void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset, const int rowelements, const int colelements, @@ -1679,7 +1704,8 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4) - callranks_rs_oop_stride_atomic(8) + callranks_rs_oop_stride_atomic(8) callranks_rs_oop_stride_atomic(16) + callranks_rs_oop_stride_atomic(32) } template @@ -1705,6 +1731,7 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8) + callranks_rs_oop_atomic_fp8(16) callranks_rs_oop_atomic_fp8(32) } template @@ -1749,11 +1776,13 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4) - callranks_rs_oop_stride_multiatomic(8) + callranks_rs_oop_stride_multiatomic(8) callranks_rs_oop_stride_multiatomic(16) + callranks_rs_oop_stride_multiatomic(32) } void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; @@ -1766,11 +1795,20 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32) + } } else { - callranks_ag(2) callranks_ag(4) callranks_ag(8) + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32) + } else { + callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32) + } } } @@ -1790,7 +1828,8 @@ void allgather2_userbuff_inplace_sliced(const int handler, const int offset, con } void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu; @@ -1803,17 +1842,26 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32) + } else { + callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32) + } } else { - callranks_rs(2) callranks_rs(4) callranks_rs(8) + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32) + } else { + callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32) + } } } void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, communicator *comm, - cudaStream_t stream) { + cudaStream_t stream, cudaEvent_t comm_launch_event) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = @@ -1827,23 +1875,39 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { - callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16) + callranks_rs_oopMC(32) + } else { + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16) + callranks_rs_oop(32) + } } else { - callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) { + callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16) + callranks_rs_oopMC(32) + } else { + callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16) + callranks_rs_oop(32) + } } } void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream) { - reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream); + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { + reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream, + comm_launch_event); } template void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator *comm, cudaStream_t stream) { + communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { const int elements = rowelements * colelements; const int op = userbuffers_allreduceop_nonsharp2; const int ar_firstgpu = @@ -1857,33 +1921,45 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int warps = comm->threads / 32; if (warps < ar_nvsize) warps = ar_nvsize; - SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); - callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) + if (comm_launch_event) { + SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16) + callranks_rs_oop_fp8(32) + } else { + SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); + callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16) + callranks_rs_oop_fp8(32) + } } template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e5m2>( void *output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_stridedoutput_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, - const int colelements, const int strideelements, communicator *comm, cudaStream_t stream); + const int colelements, const int strideelements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, - const int elements, communicator *comm, cudaStream_t stream) { + const int elements, communicator *comm, cudaStream_t stream, + cudaEvent_t comm_launch_event) { reducescatter2_userbuff_stridedoutput_fp8(output, scale, handler, offset, elements, 1, 0, - comm, stream); + comm, stream, comm_launch_event); } template void reducescatter2_userbuff_fp8<__nv_fp8_e5m2>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream); + cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_fp8<__nv_fp8_e4m3>(void *output, float *scale, const int handler, const int offset, const int elements, communicator *comm, - cudaStream_t stream); + cudaStream_t stream, + cudaEvent_t comm_launch_event); template void reducescatter2_userbuff_strided_atomic_fp8<__nv_fp8_e4m3>( void *output, float *scale, const int handler, const int offset, const int rowelements, @@ -2532,30 +2608,57 @@ void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream reset_counters_kernel<<>>(atomic_ptr, num_chunks, allgather); } -template +template __global__ void __launch_bounds__(MAX_THREADS / 4) reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale, - const int num_inputs, const int input_size) { - const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; + const int num_inputs, const int input_size, + const int num_aligned_elements_per_input, + const int tot_input_size) { fp8type *inputs_fp8 = reinterpret_cast(inputs); - float accum_buf = static_cast(inputs_fp8[tid]) * (*scale); + half_dtype *output_half = reinterpret_cast(output); + + transformer_engine::VectorizedLoader loader(inputs_fp8, tot_input_size); + transformer_engine::VectorizedStorer storer(output_half, input_size); + + const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; + if (tid >= num_aligned_elements_per_input) { + return; + } + float accum_buf[nvec]; // NOLINT(*) + + loader.load(tid, tot_input_size); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + accum_buf[i] = static_cast(loader.separate()[i]) * (*scale); + } + for (int input_id = 1; input_id < num_inputs; ++input_id) { + loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + accum_buf[i] += static_cast(loader.separate()[i]) * (*scale); + } + } #pragma unroll - for (int i = 1; i < num_inputs; i++) { - accum_buf += static_cast(inputs_fp8[tid + input_size * i]) * (*scale); + for (int i = 0; i < nvec; ++i) { + storer.separate()[i] = static_cast(accum_buf[i]); } - half *output_half = reinterpret_cast(output); - output_half[tid] = (half)accum_buf; + storer.store(tid, input_size); } template void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream) { + constexpr int nvec = 32; + assert(input_size % nvec == 0); + const int num_aligned_elements_per_input = input_size / nvec; + const int tot_input_size = input_size * num_inputs; size_t num_threads = MAX_THREADS / 4; - size_t num_blocks = (input_size + num_threads - 1) / num_threads; + size_t num_blocks = (num_aligned_elements_per_input + num_threads - 1) / num_threads; dim3 block(num_threads); dim3 grid(num_blocks); - reduce_fp8_in_bf16_out_cuda - <<>>(inputs, output, scale, num_inputs, input_size); + reduce_fp8_in_bf16_out_cuda + <<>>(inputs, output, scale, num_inputs, input_size, + num_aligned_elements_per_input, tot_input_size); } template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale, @@ -2565,23 +2668,50 @@ template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream); +template __global__ void __launch_bounds__(MAX_THREADS / 4) - reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size) { + reduce_bf16_cuda(void *inputs, void *output, const int num_inputs, const int input_size, + const int num_aligned_elements_per_input, const int tot_input_size) { + half_dtype *inputs_half = reinterpret_cast(inputs); + half_dtype *output_half = reinterpret_cast(output); + + transformer_engine::VectorizedLoader loader(inputs_half, tot_input_size); + transformer_engine::VectorizedStorer storer(output_half, input_size); + const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; - half *inputs_half = reinterpret_cast(inputs); - float accum_buf = static_cast(inputs_half[tid]); + if (tid >= num_aligned_elements_per_input) { + return; + } + float accum_buf[nvec]; // NOLINT(*) + + loader.load(tid, tot_input_size); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + accum_buf[i] = static_cast(loader.separate()[i]); + } + for (int input_id = 1; input_id < num_inputs; ++input_id) { + loader.load(tid + num_aligned_elements_per_input * input_id, tot_input_size); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + accum_buf[i] += static_cast(loader.separate()[i]); + } + } #pragma unroll - for (int i = 1; i < num_inputs; i++) { - accum_buf += static_cast(inputs_half[tid + input_size * i]); + for (int i = 0; i < nvec; ++i) { + storer.separate()[i] = static_cast(accum_buf[i]); } - half *output_half = reinterpret_cast(output); - output_half[tid] = (half)accum_buf; + storer.store(tid, input_size); } void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cudaStream_t stream) { + constexpr int nvec = 32; + assert(input_size % nvec == 0); + const int num_aligned_elements_per_input = input_size / nvec; + const int tot_input_size = input_size * num_inputs; size_t num_threads = MAX_THREADS / 4; - size_t num_blocks = (input_size + num_threads - 1) / num_threads; + size_t num_blocks = (num_aligned_elements_per_input + num_threads - 1) / num_threads; dim3 block(num_threads); dim3 grid(num_blocks); - reduce_bf16_cuda<<>>(inputs, output, num_inputs, input_size); + reduce_bf16_cuda<<>>( + inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size); } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 57e68afce0..84defcdb23 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -34,11 +34,7 @@ using ExtBarrierOp = std::function; #define NVTE_MAX_REQUESTS 1024 #define NVTE_LAUNCH_GPU 1 #define NVTE_LAUNCH_CPU 2 -#define NVTE_MAX_NVLINK 8 - -#define UB_MEM_UC_CONTIG 1 -#define UB_MEM_MC_CREATED 2 -#define UB_MEM_ALLOCATED 4 +#define NVTE_MAX_NVLINK 32 #define NVTE_UB_MEM_UC_CONTIG 1 #define NVTE_UB_MEM_MC_CREATED 2 @@ -124,11 +120,8 @@ struct communicator { ar_nvrank; // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup // (_splitar init used) would be equal to (nvsize,0) for regular comm_create int ar2_nvsize, ar2_firstgpu, ar2_nvrank; // with ar_nvsize as a step - int pipe_id; // which allreduce set of groups (pipeline rank in range of 0..pipeline_size) int sm_arch; - int num_nodes, my_node, - first_node; // comm_inter communicator, per-rail allreduce (might have subset of nodes) - int num2_nodes, my2_node, first2_node; // with num_nodes as a stride + int num_nodes, my_node; // max value for running block counters in hostflags int basecounter[userbuffers_op_types]; // NOLINT(*) @@ -136,20 +129,11 @@ struct communicator { void *mem_mr[NVTE_MAX_REGIONS]; - ub_request *fifo; - int nblocks, alignblock, minblock, asyncblocks, active_nreqs; - ub_request active_req[userbuffers_op_types]; // NOLINT(*) - int padding[7]; - volatile int head; - int padding2[15]; - volatile int tail; - // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks) ExtAllgatherOp _allgather; ExtBarrierOp _barrier; ExtComm comm_world; - ExtComm comm_inter; // reduction group communicator (subset of the nodes) along GPU rail ExtComm comm_intra; // full intranode (all ndev GPUS) #ifdef NVTE_UB_WITH_MPI MPI_Request mpihndl[NVTE_MAX_SHARP]; @@ -199,11 +183,6 @@ void destroy_communicator_mpi(communicator *comm); returned offset is offset of gpubuff relative to buffer registered */ -int pipe_rank(communicator *comm, - int step); // helper function to help walk across allreduce1 x allreduce2 groups - // data-parallel and tensor-parallel position within data and tensor - // groups would be preserved - int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc); /* returns handler and registers buffers. assumed to be collective i.e. you use same groups and dont mix buffers for different operations returns -1 if cant register (too many preregistered @@ -213,7 +192,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * // for TP-parallelism, only single node is implemented void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, - communicator *comm, cudaStream_t stream = 0); + communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); /* each Rank input is allgather2_userbuff_inplace: offset+myrank*elements @@ -228,21 +208,26 @@ for(int slice=0;slice void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, const int colelements, const int strideelements, - communicator *comm, cudaStream_t stream = 0); + communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); template void reducescatter2_userbuff_fp8(void *output, float *scale, const int handler, const int offset, - const int elements, communicator *comm, cudaStream_t stream = 0); + const int elements, communicator *comm, cudaStream_t stream = 0, + cudaEvent_t comm_launch_event = 0); template void reducescatter2_userbuff_strided_atomic_fp8(void *output, float *scale, const int handler, const int offset, const int rowelements, diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 4e95fc24de..c3a556edba 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -1,32 +1,137 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include +#include + #include "./common.h" #include "./utils.cuh" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" namespace transformer_engine { namespace { __global__ void __launch_bounds__(1) - update_tensor_scale_inv_kernel(const float* __restrict__ scale_ptr, - float* __restrict__ scale_inv_ptr) { + update_tensor_scale_inv_kernel(const float *__restrict__ scale_ptr, + float *__restrict__ scale_inv_ptr) { const float scale = scale_ptr == nullptr ? 1 : *scale_ptr; reciprocal(scale_inv_ptr, scale); } } // namespace -void update_tensor_scale_inv(Tensor* t, cudaStream_t stream) { - if (t->scale_inv.dptr != nullptr) { +void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { + if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) { + NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv."); update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( - reinterpret_cast(t->scale.dptr), reinterpret_cast(t->scale_inv.dptr)); + reinterpret_cast(t->scale.dptr), + reinterpret_cast(t->scale_inv.dptr)); + } +} + +void checkCuDriverContext(CUstream stream) { + CUcontext ctx; + const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx); + switch (driver_status) { + case CUDA_SUCCESS: + break; + + case CUDA_ERROR_INVALID_CONTEXT: + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &ctx, current_device); + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, ctx); + break; + + default: + const char *desc_NVTE_CHECK_CUDA_DRIVER; + cuda_driver::call("cuGetErrorString", driver_status, &desc_NVTE_CHECK_CUDA_DRIVER); + NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); } } +CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { + static const std::unordered_map dtypeMapping = { + {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32}, + {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16}, + {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16}, + {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}, + {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}}; + return dtypeMapping.at(dtype); +} + +// Set up parameters to create TMA descriptor. +void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, + const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, + const uint32_t shmemX, const uint32_t stride_elems, + const uint32_t offset_elems, const size_t type_size) { + // Get a function pointer to the cuTensorMapEncodeTiled driver API + static PFN_cuTensorMapEncodeTiled cuDriverTensorMapEncodeTiled = []() { + void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled"); + return reinterpret_cast(driver_ptr); + }(); + // rank is the number of dimensions of the array + constexpr uint32_t rank = 2; + uint64_t size[rank] = {globalX, globalY}; + + // The stride is the number of bytes to traverse from the first element of one row to the next + uint64_t stride[rank - 1] = {stride_elems * type_size}; + + // The boxSize is the size of the shared memory buffer that is used as the + // source/destination of a TMA transfer + uint32_t boxSize[rank] = {shmemX, shmemY}; + + // The distance between elements in units of sizeof(element) + uint32_t elemStride[rank] = {1, 1}; + + const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype); + void *dataPtr = + reinterpret_cast(reinterpret_cast(tensor.dptr) + offset_elems * type_size); + + NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), + "Tensor data pointer must be 16B aligned"); + + const int TMA_needed_size = TMA_gmem_alignment / type_size; + NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size, + "-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX); + + // Create the tensor descriptor. + NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled( + &tensorMap, // CUtensorMap *tensorMap, + tensorDataType, + rank, // cuuint32_t tensorRank, + dataPtr, // void *globalAddress, + size, // const cuuint64_t *globalDim, + stride, // const cuuint64_t *globalStrides, + boxSize, // const cuuint32_t *boxDim, + elemStride, // const cuuint32_t *elementStrides, + // Interleave patterns can be used to accelerate loading of values that + // are less than 4 bytes long. + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + + // Swizzling can be used to avoid shared memory bank conflicts. + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + + // L2 Promotion can be used to widen the effect of a cache-policy to a wider + // set of L2 cache lines. + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + // CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B, + + // Any element that is outside of bounds will be set to zero by the TMA transfer. + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); +} + +bool is_supported_by_CC_100() { + int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); + + return deviceComputeCapability >= 100; +} + } // namespace transformer_engine diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 8830c8875d..ac58398551 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -7,12 +7,14 @@ #ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_ +#include #include #include #include #include #include +#include #include #include #include @@ -22,10 +24,44 @@ #include #include "./nvtx.h" +#include "./util/cuda_driver.h" #include "./util/logging.h" namespace transformer_engine { +std::string to_string(const DType type); +std::string to_string(const NVTEScalingMode &mode); + +inline bool is_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_block_scaling(const NVTEScalingMode &mode) { return !is_tensor_scaling(mode); } + +inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { + return mode == NVTE_DELAYED_TENSOR_SCALING; +} + +inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } + +inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { + NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", + end, " in a vector with ", shape.size(), " entries"); + size_t ret = 1; + for (size_t i = begin; i < end; ++i) { + ret *= shape[i]; + } + return ret; +} + +inline size_t product(const std::vector &shape) { + size_t ret = 1; + for (const auto &elem : shape) { + ret *= elem; + } + return ret; +} + struct SimpleTensor { void *dptr; std::vector shape; @@ -33,20 +69,142 @@ struct SimpleTensor { SimpleTensor(void *dptr, const std::vector &shape, DType dtype) : dptr(dptr), shape(shape), dtype(dtype) {} + + SimpleTensor(const NVTEBasicTensor &tensor) // NOLINT + : dptr(tensor.data_ptr), + shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim), + dtype(static_cast(tensor.dtype)) {} + SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {} + + operator NVTEBasicTensor() const { + const NVTEShape shape = {this->shape.data(), this->shape.size()}; + return {dptr, static_cast(dtype), shape}; + } + + int numel() const { + size_t acc = 1; + for (const auto &dim : shape) { + acc *= dim; + } + return acc; + } }; struct Tensor { SimpleTensor data; + SimpleTensor columnwise_data; SimpleTensor amax; SimpleTensor scale; SimpleTensor scale_inv; + SimpleTensor columnwise_scale_inv; + + NVTEScalingMode scaling_mode; Tensor() : data(), + columnwise_data(), amax(nullptr, {1}, DType::kFloat32), scale(nullptr, {1}, DType::kFloat32), - scale_inv(nullptr, {1}, DType::kFloat32) {} + scale_inv(nullptr, {1}, DType::kFloat32), + columnwise_scale_inv(nullptr, {1}, DType::kFloat32), + scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} + + int numel() const { + size_t acc = 1; + for (const auto dim : shape()) { + acc *= dim; + } + return acc; + } + + bool has_data() const noexcept { return data.dptr != nullptr; } + + // Check for size (not just pointer) for 0-dim or no token cases. + bool has_columnwise_data() const noexcept { + return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0; + } + + DType dtype() const { + if (has_data()) return data.dtype; + if (has_columnwise_data()) return columnwise_data.dtype; + // Fallback, used e.g. in workspace + return data.dtype; + } + + std::vector shape() const { + /* Note: We sometimes experience spurious compiler errors + * (-Wstringop-overflow) from this function. It appears that GCC + * has some bugs with std::vector (see + * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569). + */ + switch (scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: + if (!has_data() && has_columnwise_data()) { + std::vector ret; + if (!columnwise_data.shape.empty()) { + for (size_t i = 1; i < columnwise_data.shape.size(); i++) { + ret.push_back(columnwise_data.shape[i]); + } + ret.push_back(columnwise_data.shape.front()); + } + return ret; + } else { + return data.shape; + } + break; + case NVTE_MXFP8_1D_SCALING: + if (!has_data() && has_columnwise_data()) { + return columnwise_data.shape; + } else { + return data.shape; + } + break; + default: + NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\""); + return {}; + } + } + + /*! Matrix height after tensor is flattened to 2D + * + * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted + * as a (D1*D2*...*D(n-1), Dn) matrix. + */ + size_t flat_first_dim() const { + const auto &full_shape = shape(); + size_t ret = 1; + if (!full_shape.empty()) { + for (size_t i = 0; i < full_shape.size() - 1; i++) { + ret *= full_shape[i]; + } + } + return ret; + } + + /*! Matrix width after tensor is flattened to 2D + * + * If a tensor has dimensions (D1, D2, ..., Dn), it is reinterpreted + * as a (D1*D2*...*D(n-1), Dn) matrix. + */ + size_t flat_last_dim() const { + const auto &full_shape = shape(); + if (full_shape.empty()) { + return 1; + } else { + return full_shape.back(); + } + } +}; + +struct QuantizationConfig { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0f; + + static constexpr size_t attr_sizes[] = { + sizeof(bool), // force_pow_2_scales + sizeof(float) // amax_epsilon + }; }; template @@ -62,6 +220,10 @@ using fp16 = half; using bf16 = nv_bfloat16; using fp8e4m3 = __nv_fp8_e4m3; using fp8e5m2 = __nv_fp8_e5m2; +#if CUDA_VERSION >= 12080 +using fp8e8m0 = __nv_fp8_e8m0; +#endif +using e8m0_t = uint8_t; namespace detail { @@ -80,6 +242,9 @@ TRANSFORMER_ENGINE_TYPE_NAME(half) TRANSFORMER_ENGINE_TYPE_NAME(nv_bfloat16) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3) TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2) +#if CUDA_VERSION >= 12080 +TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0) +#endif #undef TRANSFORMER_ENGINE_TYPE_NAME } // namespace detail @@ -150,6 +315,10 @@ struct TypeInfo { using type = fp8e5m2; \ { __VA_ARGS__ } \ } break; \ + case DType::kFloat8E8M0: { \ + using type = byte; \ + { __VA_ARGS__ } \ + } break; \ default: \ NVTE_ERROR("Invalid type."); \ } @@ -181,6 +350,25 @@ struct TypeInfo { NVTE_ERROR("Invalid type."); \ } +#define TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(dtype, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat16: { \ + using type = fp16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ @@ -236,15 +424,31 @@ struct TypeInfo { NVTE_ERROR("Invalid type for 16 bit."); \ } -//////////////////////////////////////////////////////////////////////////////////////////////////// +#define TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(SCALE_DIM, DIM, ...) \ + switch (SCALE_DIM) { \ + case 1: { \ + constexpr size_t DIM = 1; \ + { __VA_ARGS__ } \ + } break; \ + case 32: { \ + constexpr size_t DIM = 32; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Invalid size of the MX scaling factor."); \ + } \ + } -inline size_t product(const std::vector &shape) { - size_t ret = 1; - for (const auto &elem : shape) { - ret *= elem; +#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ + if (CONDITION) { \ + constexpr bool FLAG = true; \ + { __VA_ARGS__ } \ + } else { \ + constexpr bool FLAG = false; \ + { __VA_ARGS__ } \ } - return ret; -} + +//////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { int log2_value = 0; @@ -269,8 +473,26 @@ struct is_fp8 : std::true_type {}; template <> struct is_fp8 : std::true_type {}; +// [128,4] rowwise and [4,128] colwise alignment requirements for the tensor with scaling factors +constexpr size_t scale_tensor_alignment_X_rowwise = 4; +constexpr size_t scale_tensor_alignment_Y_rowwise = 128; +constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_Y_colwise = 4; + +// Alignment requirements for the Tensor Memory Accelerator (TMA) +constexpr int TMA_gmem_alignment = 16; // global memory address alignment + +inline bool is_aligned_ptr(const void *ptr, size_t alignment) { + return reinterpret_cast(ptr) % alignment == 0; +} + +inline bool is_aligned_tensor_data(const Tensor &t, size_t alignment) { + return is_aligned_ptr(static_cast(t.data.dptr), alignment); +} + size_t typeToSize(const DType type); +void CheckNoopTensor(const Tensor &t, const std::string &name); void CheckInputTensor(const Tensor &t, const std::string &name); void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false); @@ -286,6 +508,18 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream); #define NVTE_API_CALL(api_name) \ transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); +void checkCuDriverContext(CUstream stream); + +CUtensorMapDataType get_CUtensorMapDataType(DType dtype); + +// Set up parameters to create TMA descriptor. +void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, + const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, + const uint32_t shmemX, const uint32_t stride_elems, + const uint32_t offset_elems, const size_t type_size); + +bool is_supported_by_CC_100(); + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_COMMON_H_ diff --git a/transformer_engine/common/cudnn_utils.cpp b/transformer_engine/common/cudnn_utils.cpp index 35e2d11799..eaf6de680a 100644 --- a/transformer_engine/common/cudnn_utils.cpp +++ b/transformer_engine/common/cudnn_utils.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -57,9 +57,13 @@ cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t) } } -void nvte_cudnn_handle_init() { - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); -} +void nvte_cudnn_handle_init() { auto _ = cudnnExecutionPlanManager::Instance().GetHandle(); } + +namespace detail { + +void CreateCuDNNHandle(cudnnHandle_t* handle) { NVTE_CHECK_CUDNN(cudnnCreate(handle)); } + +} // namespace detail } // namespace transformer_engine @@ -68,6 +72,6 @@ namespace cudnn_frontend { // This is needed to define the symbol `cudnn_dlhandle` // When using the flag NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING // to enable dynamic loading. -void *cudnn_dlhandle = nullptr; +void* cudnn_dlhandle = nullptr; } // namespace cudnn_frontend diff --git a/transformer_engine/common/cudnn_utils.h b/transformer_engine/common/cudnn_utils.h index d2827b637a..0016ad7f55 100644 --- a/transformer_engine/common/cudnn_utils.h +++ b/transformer_engine/common/cudnn_utils.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -10,37 +10,25 @@ #include #include #include - -#include -#include +#include #include "transformer_engine/transformer_engine.h" +#include "util/handle_manager.h" namespace transformer_engine { -cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); +namespace detail { -cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); +void CreateCuDNNHandle(cudnnHandle_t* handle); -class cudnnExecutionPlanManager { - public: - static cudnnExecutionPlanManager &Instance() { - static thread_local cudnnExecutionPlanManager instance; - return instance; - } +} // namespace detail - cudnnHandle_t GetCudnnHandle() { - static thread_local std::once_flag flag; - std::call_once(flag, [&] { cudnnCreate(&handle_); }); - return handle_; - } +cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t); - ~cudnnExecutionPlanManager() {} +cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t); - private: - cudnnHandle_t handle_ = nullptr; -}; +using cudnnExecutionPlanManager = detail::HandleManager; } // namespace transformer_engine -#endif +#endif // TRANSFORMER_ENGINE_CUDNN_UTILS_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9cde765401..a4aec22edb 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -37,7 +37,18 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD; + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + return NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -51,12 +62,14 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_SBHD_SB2HD: case NVTE_QKV_Layout::NVTE_SBHD_SBH2D: case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: return NVTE_QKV_Format::NVTE_SBHD; case NVTE_QKV_Layout::NVTE_BS3HD: case NVTE_QKV_Layout::NVTE_BSH3D: case NVTE_QKV_Layout::NVTE_BSHD_BS2HD: case NVTE_QKV_Layout::NVTE_BSHD_BSH2D: case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: return NVTE_QKV_Format::NVTE_BSHD; case NVTE_QKV_Layout::NVTE_T3HD: case NVTE_QKV_Layout::NVTE_TH3D: @@ -64,6 +77,56 @@ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) { case NVTE_QKV_Layout::NVTE_THD_TH2D: case NVTE_QKV_Layout::NVTE_THD_THD_THD: return NVTE_QKV_Format::NVTE_THD; + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + return NVTE_QKV_Format::NVTE_SBHD_2BSHD; + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + return NVTE_QKV_Format::NVTE_BSHD_2SBHD; + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: + return NVTE_QKV_Format::NVTE_THD_2BSHD; + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + return NVTE_QKV_Format::NVTE_THD_2SBHD; + default: + NVTE_ERROR("qkv_layout not supported!"); + } +} + +// map NVTE_QKV_Layout to NVTE_QKV_Format for Q +NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout) { + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + switch (qkv_format) { + case NVTE_QKV_Format::NVTE_SBHD: + case NVTE_QKV_Format::NVTE_SBHD_2BSHD: + return NVTE_QKV_Format::NVTE_SBHD; + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_BSHD_2SBHD: + return NVTE_QKV_Format::NVTE_BSHD; + case NVTE_QKV_Format::NVTE_THD: + case NVTE_QKV_Format::NVTE_THD_2BSHD: + case NVTE_QKV_Format::NVTE_THD_2SBHD: + return NVTE_QKV_Format::NVTE_THD; + default: + NVTE_ERROR("qkv_layout not supported!"); + } +} + +// map NVTE_QKV_Layout to NVTE_QKV_Format for KV +NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { + NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + switch (qkv_format) { + case NVTE_QKV_Format::NVTE_SBHD: + case NVTE_QKV_Format::NVTE_BSHD_2SBHD: + case NVTE_QKV_Format::NVTE_THD_2SBHD: + return NVTE_QKV_Format::NVTE_SBHD; + case NVTE_QKV_Format::NVTE_BSHD: + case NVTE_QKV_Format::NVTE_SBHD_2BSHD: + case NVTE_QKV_Format::NVTE_THD_2BSHD: + return NVTE_QKV_Format::NVTE_BSHD; + case NVTE_QKV_Format::NVTE_THD: + return NVTE_QKV_Format::NVTE_THD; default: NVTE_ERROR("qkv_layout not supported!"); } @@ -81,6 +144,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( const int sm_arch_ = cuda::sm_arch(device_id); NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); auto cudnn_runtime_version = cudnnGetVersion(); @@ -93,17 +158,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( const bool supported_ragged_offset_size = (!requires_64bit_ragged_offset || cudnn_runtime_version >= 90500); - if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) && - (sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) && - (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) && - (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim_qk == 64) && - (head_dim_v == 64) && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) || - ((cudnn_runtime_version >= 90201) && (max_seqlen_q % 128 == 0) && - (max_seqlen_kv % 128 == 0) && (head_dim_qk == 128) && (head_dim_v == 128) && - ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || - (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))) && + if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && + sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // 8.9: t3hd, max_s=512, d=64, padding + ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && + qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && + max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + (cudnn_runtime_version >= 90700 && + // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // sm90: fwd d<=256, bwd d=128 only + // sm100: fwd d<=128, bwd d<=128 + ((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && !requires_64bit_ragged_offset) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; @@ -135,7 +214,12 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( !requires_64bit_ragged_offset) { flag_m512 = true; } - if ( // architecture + if ( + // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging + // special conditions for blackwell + // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7 + !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) && + // architecture ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && // sequence length @@ -152,7 +236,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - ((cudnn_runtime_version >= 8906) && + (cudnn_runtime_version >= 8906 && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || (bias_type == NVTE_Bias_Type::NVTE_ALIBI && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK && @@ -161,43 +245,99 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && sm_arch_ >= 90) || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) || - ((cudnn_runtime_version >= 90000) && + (cudnn_runtime_version >= 90000 && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) && // mask type + // pre-8.9.6: causal ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - ((cudnn_runtime_version >= 8906) && + // 8.9.6: {bshd, sbhd} + {no_mask, causal, padding, padding_causal} + (cudnn_runtime_version >= 8906 && + (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - ((cudnn_runtime_version >= 90300) && - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && - max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && - bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + // 9.1: adds thd + {padding, padding_causal} + (cudnn_runtime_version >= 90100 && qkv_format == NVTE_QKV_Format::NVTE_THD && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.3: adds {bshd, sbhd} + causal_bottom_right + self/cross-attn (sq <= skv) + (cudnn_runtime_version >= 90300 && (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) && - max_seqlen_q <= max_seqlen_kv && dropout == 0.0)) && + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + // 9.5: adds {paged_kv_bshd, paged_kv_sbhd} + {padding, padding_causal, padding_causal_bottom_right} + (cudnn_runtime_version >= 90500 && + layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD && + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv)) && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + // 9.6: adds {bshd, sbhd, thd} + padding_causal_bottom_right + self/cross-attn (sq <= skv) + (cudnn_runtime_version >= 90600 && + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0 && max_seqlen_q <= max_seqlen_kv && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + // 9.7: removes s_q/s_kv % 64 = 0 for {causal_bottom_right, padding_causal_bottom_right} + // for any q_format/kv_format, and paged/non-paged + (cudnn_runtime_version >= 90700 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + ((attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0) || + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) && + max_seqlen_q <= max_seqlen_kv)))) && // bias + mask combination (!(cudnn_runtime_version >= 8906 && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && // qkv format - ((qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD) || + (qkv_format == NVTE_QKV_Format::NVTE_SBHD || qkv_format == NVTE_QKV_Format::NVTE_BSHD || (qkv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90 && ((cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups) || - (cudnn_runtime_version >= 90600)))) && + cudnn_runtime_version >= 90600)) || + ((q_format == NVTE_QKV_Format::NVTE_SBHD || q_format == NVTE_QKV_Format::NVTE_BSHD || + (q_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90) || + kv_format == NVTE_QKV_Format::NVTE_SBHD || kv_format == NVTE_QKV_Format::NVTE_BSHD || + (kv_format == NVTE_QKV_Format::NVTE_THD && sm_arch_ >= 90)) && + cudnn_runtime_version >= 90700)) && // sliding window + // pre-9.2: full attn, causal ((cudnn_runtime_version < 90200 && window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || + // 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd} (cudnn_runtime_version >= 90200 && ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && max_seqlen_q == max_seqlen_kv)) && - dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + max_seqlen_q <= max_seqlen_kv && dropout == 0.0 && + bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || - qkv_format == NVTE_QKV_Format::NVTE_SBHD))))) && + qkv_format == NVTE_QKV_Format::NVTE_SBHD)))) || + // 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd} + (cudnn_runtime_version >= 90600 && + ((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) || + ((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 && + ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK && + // TODO(cyang): fix bug for BRCM + cross-attention on sm100 + (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && + cudnn_runtime_version <= 90700) || + cudnn_runtime_version > 90700)))) || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK || + (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK && + (sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv && + cudnn_runtime_version <= 90700) || + cudnn_runtime_version > 90700))))) && + max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && + dropout == 0.0)))) && // check 64-bit ragged offset support (supported_ragged_offset_size)) { flag_arb = true; @@ -279,7 +419,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, t = input_QKV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( @@ -361,7 +501,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con t = input_QKV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( @@ -415,22 +555,23 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } } // NVTE fused attention FWD with packed KV -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, - NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = reinterpret_cast(cu_seqlens_q); const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = reinterpret_cast(page_table_k); + const Tensor *input_page_table_v = reinterpret_cast(page_table_v); const Tensor *input_rng_state = reinterpret_cast(rng_state); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_KV = reinterpret_cast(KV); @@ -455,13 +596,42 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const } size_t t_q = 0; size_t t_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_KV->data.shape[0]; } + int64_t num_pages_k = 0; + int64_t num_pages_v = 0; + int64_t page_size_k = 0; + int64_t page_size_v = 0; + int64_t max_pages_per_seq_k = 0; + int64_t max_pages_per_seq_v = 0; + if (input_page_table_k->data.dptr != nullptr) { + max_pages_per_seq_k = input_page_table_k->data.shape[1]; + } + if (input_page_table_v->data.dptr != nullptr) { + max_pages_per_seq_v = input_page_table_v->data.shape[1]; + } + if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { + num_pages_k = input_KV->data.shape[0]; + page_size_k = input_KV->data.shape[1]; + num_pages_v = num_pages_v; + page_size_v = page_size_v; + } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { + num_pages_k = input_KV->data.shape[1]; + page_size_k = input_KV->data.shape[0]; + num_pages_v = num_pages_v; + page_size_v = page_size_v; + } + } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); @@ -481,11 +651,12 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, - input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, - handle); + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, + input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -546,13 +717,16 @@ void nvte_fused_attn_bwd_kvpacked( } size_t t_q = 0; size_t t_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_KV->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); @@ -614,7 +788,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, @@ -626,6 +801,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_cu_seqlens_kv = reinterpret_cast(cu_seqlens_kv); const Tensor *input_cu_seqlens_q_padded = reinterpret_cast(cu_seqlens_q_padded); const Tensor *input_cu_seqlens_kv_padded = reinterpret_cast(cu_seqlens_kv_padded); + const Tensor *input_page_table_k = reinterpret_cast(page_table_k); + const Tensor *input_page_table_v = reinterpret_cast(page_table_v); const Tensor *input_rng_state = reinterpret_cast(rng_state); const Tensor *input_Q = reinterpret_cast(Q); const Tensor *input_K = reinterpret_cast(K); @@ -636,20 +813,51 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *wkspace = reinterpret_cast(workspace); auto ndim = input_Q->data.shape.size(); + auto ndim_kv = input_K->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim_kv - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; size_t t_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } + int64_t num_pages_k = 0; + int64_t num_pages_v = 0; + int64_t page_size_k = 0; + int64_t page_size_v = 0; + int64_t max_pages_per_seq_k = 0; + int64_t max_pages_per_seq_v = 0; + if (input_page_table_k->data.dptr != nullptr) { + max_pages_per_seq_k = input_page_table_k->data.shape[1]; + } + if (input_page_table_v->data.dptr != nullptr) { + max_pages_per_seq_v = input_page_table_v->data.shape[1]; + } + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) { + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (kv_format == NVTE_QKV_Format::NVTE_BSHD) { + num_pages_k = input_K->data.shape[0]; + page_size_k = input_K->data.shape[1]; + num_pages_v = input_V->data.shape[0]; + page_size_v = input_V->data.shape[1]; + } else if (kv_format == NVTE_QKV_Format::NVTE_SBHD) { + num_pages_k = input_K->data.shape[1]; + page_size_k = input_K->data.shape[0]; + num_pages_v = input_V->data.shape[1]; + page_size_v = input_V->data.shape[0]; + } + } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); @@ -669,11 +877,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, is_training, attn_scale, + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, - wkspace, stream, handle); + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, + input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -723,20 +932,24 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *wkspace = reinterpret_cast(workspace); auto ndim = input_Q->data.shape.size(); + auto ndim_kv = input_K->data.shape.size(); size_t b = input_cu_seqlens_q->data.shape[0] - 1; size_t h_q = input_Q->data.shape[ndim - 2]; - size_t h_kv = input_K->data.shape[ndim - 2]; + size_t h_kv = input_K->data.shape[ndim_kv - 2]; size_t d_qk = input_Q->data.shape[ndim - 1]; - size_t d_v = input_V->data.shape[ndim - 1]; + size_t d_v = input_V->data.shape[ndim_kv - 1]; size_t t_q = 0; size_t t_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD) { t_q = input_Q->data.shape[0]; + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { t_kv = input_K->data.shape[0]; } - auto handle = cudnnExecutionPlanManager::Instance().GetCudnnHandle(); + auto handle = cudnnExecutionPlanManager::Instance().GetHandle(); const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_K->data.dtype); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index f242502261..2ce93f196a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -50,14 +50,16 @@ namespace transformer_engine { namespace fused_attn { void fused_attn_arbitrary_seqlen_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, - int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, - bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, + int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, + int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, + float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, - void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, - size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, + void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -66,25 +68,35 @@ void fused_attn_arbitrary_seqlen_fwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); - if (is_bottom_right && s_q == s_kv) { + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; } - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (is_training && dropout_probability != 0.0f); - bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); + NVTE_QKV_Format q_format = nvte_get_q_format(layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); + if (is_paged_kv) { + NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); + } + // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; - if (is_ragged && cudnn_runtime_version >= 90600) { + if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); // replace batch size and maximum sequence lengths with maximum token counts // for query and key/value so the graph is static within each quantization bucket b = max_b; - s_q = max_t_q; - s_kv = max_t_kv; + s_q = is_ragged_q ? max_t_q : s_q; + s_kv = is_ragged_kv ? max_t_kv : s_kv; } const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; @@ -96,6 +108,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( s_kv, d_qk, d_v, + num_pages_k, + num_pages_v, + page_size_k, + page_size_v, + max_pages_per_seq_k, + max_pages_per_seq_v, bias_b, bias_h, scaling_factor, @@ -122,6 +140,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // bias std::shared_ptr, // seq_q std::shared_ptr, // seq_kv + std::shared_ptr, // page_table_k + std::shared_ptr, // page_table_v std::shared_ptr, // offset_q std::shared_ptr, // offset_k std::shared_ptr, // offset_v @@ -150,6 +170,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr Q, K, V, attn_scale; std::shared_ptr bias, seq_q, seq_kv; + std::shared_ptr page_table_k, page_table_v; std::shared_ptr offset_q, offset_k, offset_v, offset_o, offset_stats; std::shared_ptr dropout_seed, dropout_offset; @@ -159,17 +180,36 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector v_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_qk, q_stride.data(), layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); - generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, - NVTE_QKV_Matrix::NVTE_V_Matrix); + if (is_paged_kv) { + generateMatrixStrides(num_pages_k, hg, page_size_k, page_size_v, d_qk, k_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(num_pages_v, hg, page_size_k, page_size_v, d_v, v_stride.data(), + layout, NVTE_QKV_Matrix::NVTE_V_Matrix); + } else { + generateMatrixStrides(b, hg, s_q, s_kv, d_qk, k_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_K_Matrix); + generateMatrixStrides(b, hg, s_q, s_kv, d_v, v_stride.data(), layout, + NVTE_QKV_Matrix::NVTE_V_Matrix); + } - if (is_ragged) { + Q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride)); + if (is_ragged_q) { offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_q") .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + Q->set_ragged_offset(offset_q); + } + K = mha_graph->tensor(fe::graph::Tensor_attributes().set_name("K").set_stride(k_stride)); + V = mha_graph->tensor(fe::graph::Tensor_attributes().set_name("V").set_stride(v_stride)); + if (is_paged_kv) { + K->set_dim({num_pages_k, hg, page_size_k, d_qk}); + V->set_dim({num_pages_v, hg, page_size_v, d_v}); + } else if (is_ragged_kv) { offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_k") .set_dim({b + 1, 1, 1, 1}) @@ -180,34 +220,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) - .set_ragged_offset(offset_q)); - K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) - .set_ragged_offset(offset_k)); - V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) - .set_ragged_offset(offset_v)); + K->set_dim({b, hg, s_kv, d_qk}).set_ragged_offset(offset_k); + V->set_dim({b, hg, s_kv, d_v}).set_ragged_offset(offset_v); } else { - Q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride)); - K = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride)); - V = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride)); + K->set_dim({b, hg, s_kv, d_qk}); + V->set_dim({b, hg, s_kv, d_v}); } attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -226,7 +243,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_attn_scale(attn_scale); if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_options.set_sliding_window_length(window_size_left + 1); + sdpa_options.set_diagonal_band_left_bound(window_size_left + 1); } sdpa_options.set_alibi_mask(is_alibi); @@ -253,6 +270,24 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_padding_mask(is_padding).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); } + if (is_paged_kv) { + page_table_k = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("page_table_k") + .set_dim({b, 1, max_pages_per_seq_k, 1}) + .set_stride({{max_pages_per_seq_k, max_pages_per_seq_v, 1, 1}}) + .set_data_type(fe::DataType_t::INT32)); + page_table_v = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("page_table_v") + .set_dim({b, 1, max_pages_per_seq_v, 1}) + .set_stride({{max_pages_per_seq_v, max_pages_per_seq_v, 1, 1}}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_options.set_paged_attention_k_table(page_table_k); + sdpa_options.set_paged_attention_v_table(page_table_v); + sdpa_options.set_paged_attention_max_seq_len_kv(static_cast(s_kv)); + } + if (is_dropout) { dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("Seed") @@ -272,37 +307,27 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - if (is_ragged) { + O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); + if (is_ragged_q) { offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_o") .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - O->set_output(true) - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) - .set_ragged_offset(offset_o); - } else { - O->set_output(true).set_dim({b, h, s_q, d_v}).set_stride(o_stride); + O->set_ragged_offset(offset_o); } - if (is_ragged && cudnn_runtime_version >= 90600) { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); + if (is_ragged_q && cudnn_runtime_version >= 90600) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_stats") .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - Stats->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, 1, h, 1}) - .set_ragged_offset(offset_stats); + Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { - Stats->set_output(true) - .set_data_type(fe::DataType_t::FLOAT) - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}); + Stats->set_stride({h * s_q, s_q, 1, 1}); } std::tuple, // Q @@ -315,9 +340,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr); - auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) + auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v) + : std::make_tuple(nullptr, nullptr); + auto offset_qo_tuple = + is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); + auto offset_kv_tuple = + is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); + auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600) ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) @@ -329,16 +358,16 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = - std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, - padding_tuple, offset_qkvo_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = std::tuple_cat( + std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, + page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, offset_q, offset_k, - offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = + auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v, + offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed @@ -350,11 +379,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( const size_t num_bytes_per_ragged_offset = alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); size_t seqlen_offsets_workspace_size = 0; - if (is_ragged) { - if (cudnn_runtime_version >= 90600) { - seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset; + if (is_ragged_q || is_ragged_kv) { + size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; } else { - seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset; } } if (workspace == nullptr) { @@ -390,28 +420,49 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[seq_kv] = devActualSeqlenKV; } - if (is_ragged) { + if (is_paged_kv) { + variant_pack[page_table_k] = devPtrPageTableK; + variant_pack[page_table_v] = devPtrPageTableV; + } + + if (is_ragged_q || is_ragged_kv) { constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block) / nthreads_per_block; - void *devOffsetsQ = + void *devOffsets = static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; - void *devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; - void *devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; - void *devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + void *devOffsetsQ = nullptr; + void *devOffsetsO = nullptr; + if (is_ragged_q) { + devOffsetsQ = devOffsets; + devOffsetsO = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; + } + void *devOffsetsK = nullptr; + void *devOffsetsV = nullptr; + if (is_ragged_kv) { + devOffsetsK = static_cast(devOffsets) + + static_cast(is_ragged_q) * 2 * num_bytes_per_ragged_offset; + devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; + } void *devOffsetsS = nullptr; - if (cudnn_runtime_version >= 90600) { - devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; + if (is_ragged_q && cudnn_runtime_version >= 90600) { + devOffsetsS = static_cast(devOffsets) + + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * + num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); - variant_pack[offset_q] = devOffsetsQ; - variant_pack[offset_k] = devOffsetsK; - variant_pack[offset_v] = devOffsetsV; - variant_pack[offset_o] = devOffsetsO; - if (cudnn_runtime_version >= 90600) { + if (is_ragged_q) { + variant_pack[offset_q] = devOffsetsQ; + variant_pack[offset_o] = devOffsetsO; + } + if (is_ragged_kv) { + variant_pack[offset_k] = devOffsetsK; + variant_pack[offset_v] = devOffsetsV; + } + if (is_ragged_q && cudnn_runtime_version >= 90600) { variant_pack[offset_stats] = devOffsetsS; } } @@ -446,26 +497,37 @@ void fused_attn_arbitrary_seqlen_bwd_impl( (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_bottom_right = ((mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) || (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); - if (is_bottom_right && s_q == s_kv) { + bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) || + (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)); + if (is_bottom_right && s_q == s_kv && !is_padding) { is_causal = true; is_bottom_right = false; } - bool is_padding = ((mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)); bool is_dropout = (dropout_probability != 0.0f); - bool is_ragged = (nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_THD); + NVTE_QKV_Format q_format = nvte_get_q_format(layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); + bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); + bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); const auto cudnn_runtime_version = cudnnGetVersion(); const int device_id = cuda::current_device(); const int sm_arch_ = cuda::sm_arch(device_id); + + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); + bool is_paged_kv = (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD); + if (is_paged_kv) { + NVTE_CHECK(is_padding, "Paged attention requires padding mask!"); + } + // keep original batch size because cu_seqlens are created with [b+1] shape int64_t actual_b = b; - if (is_ragged && cudnn_runtime_version >= 90600) { + if ((is_ragged_q || is_ragged_kv) && cudnn_runtime_version >= 90600) { NVTE_CHECK(is_padding, "Ragged QKV input requires padding or padding_causal mask!"); // replace batch size and maximum sequence lengths with maximum token counts // for query and key/value so the graph is static within each quantization bucket b = max_b; - s_q = max_t_q; - s_kv = max_t_kv; + s_q = is_ragged_q ? max_t_q : s_q; + s_kv = is_ragged_kv ? max_t_kv : s_kv; } // We choose between 32-bit and 64-bit offsets depending on need. @@ -480,6 +542,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl( s_kv, d_qk, d_v, + 0, + 0, + 0, + 0, + 0, + 0, bias_b, bias_h, scaling_factor, @@ -556,12 +624,42 @@ void fused_attn_arbitrary_seqlen_bwd_impl( generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - if (is_ragged) { + q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Q") + .set_dim({b, h, s_q, d_qk}) + .set_stride(q_stride)); + k = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("K") + .set_dim({b, hg, s_kv, d_qk}) + .set_stride(k_stride)); + v = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("V") + .set_dim({b, hg, s_kv, d_v}) + .set_stride(v_stride)); + o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("O") + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride)); + dO = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("dO") + .set_dim({b, h, s_q, d_v}) + .set_stride(o_stride)); + if (is_ragged_q) { offset_q = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_q") .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_o") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + q->set_ragged_offset(offset_q); + o->set_ragged_offset(offset_o); + dO->set_ragged_offset(offset_o); + } + if (is_ragged_kv) { offset_k = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_k") .set_dim({b + 1, 1, 1, 1}) @@ -572,77 +670,24 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - offset_o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_o") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) - .set_ragged_offset(offset_q)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) - .set_ragged_offset(offset_k)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) - .set_ragged_offset(offset_v)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) - .set_ragged_offset(offset_o)); - dO = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO") - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride) - .set_ragged_offset(offset_o)); - } else { - q = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("Q") - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride)); - k = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("K") - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride)); - v = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("V") - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride)); - o = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("O") - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride)); - dO = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("dO") - .set_dim({b, h, s_q, d_v}) - .set_stride(o_stride)); + k->set_ragged_offset(offset_k); + v->set_ragged_offset(offset_v); } - if (is_ragged && cudnn_runtime_version >= 90600) { + + stats = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("stats") + .set_dim({b, h, s_q, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + if (is_ragged_q && cudnn_runtime_version >= 90600) { offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("offset_stats") .set_dim({b + 1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, 1, h, 1}) - .set_data_type(fe::DataType_t::FLOAT) - .set_ragged_offset(offset_stats)); + stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); } else { - stats = mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("stats") - .set_dim({b, h, s_q, 1}) - .set_stride({h * s_q, s_q, 1, 1}) - .set_data_type(fe::DataType_t::FLOAT)); + stats->set_stride({h * s_q, s_q, 1, 1}); } attn_scale = mha_graph->tensor(fe::graph::Tensor_attributes() @@ -659,12 +704,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); - if (is_ragged && cudnn_runtime_version >= 90600) { + if (is_ragged_q && cudnn_runtime_version >= 90600) { sdpa_backward_options.set_max_total_seq_len_q(s_q); } + if (is_ragged_kv && cudnn_runtime_version >= 90600) { + sdpa_backward_options.set_max_total_seq_len_kv(s_kv); + } if (cudnn_runtime_version >= 90200 && window_size_left != -1) { - sdpa_backward_options.set_sliding_window_length(window_size_left + 1); + sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1); } if (cudnn_runtime_version >= 90000) { @@ -723,23 +771,15 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); - if (is_ragged) { - dQ->set_output(true) - .set_dim({b, h, s_q, d_qk}) - .set_stride(q_stride) - .set_ragged_offset(offset_q); - dK->set_output(true) - .set_dim({b, hg, s_kv, d_qk}) - .set_stride(k_stride) - .set_ragged_offset(offset_k); - dV->set_output(true) - .set_dim({b, hg, s_kv, d_v}) - .set_stride(v_stride) - .set_ragged_offset(offset_v); - } else { - dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); - dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride); - dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride); + dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); + dK->set_output(true).set_dim({b, hg, s_kv, d_qk}).set_stride(k_stride); + dV->set_output(true).set_dim({b, hg, s_kv, d_v}).set_stride(v_stride); + if (is_ragged_q) { + dQ->set_ragged_offset(offset_q); + } + if (is_ragged_kv) { + dK->set_ragged_offset(offset_k); + dV->set_ragged_offset(offset_v); } std::tuple, // q @@ -756,9 +796,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); - auto offset_qkvo_tuple = is_ragged ? std::make_tuple(offset_q, offset_k, offset_v, offset_o) - : std::make_tuple(nullptr, nullptr, nullptr, nullptr); - auto offset_s_tuple = (is_ragged && cudnn_runtime_version >= 90600) + auto offset_qo_tuple = + is_ragged_q ? std::make_tuple(offset_q, offset_o) : std::make_tuple(nullptr, nullptr); + auto offset_kv_tuple = + is_ragged_kv ? std::make_tuple(offset_k, offset_v) : std::make_tuple(nullptr, nullptr); + auto offset_s_tuple = (is_ragged_q && cudnn_runtime_version >= 90600) ? std::make_tuple(offset_stats) : std::make_tuple(nullptr); auto dropout_tuple = is_dropout ? std::make_tuple(dropout_seed, dropout_offset) @@ -772,14 +814,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple, - offset_qkvo_tuple, offset_s_tuple, dropout_tuple); + offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv, - offset_q, offset_k, offset_v, offset_o, offset_stats, dropout_seed, dropout_offset] = + offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed @@ -791,11 +833,12 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const size_t num_bytes_per_ragged_offset = alignTo<16>((b + 1) * typeToSize(ragged_offset_type)); size_t seqlen_offsets_workspace_size = 0; - if (is_ragged) { - if (cudnn_runtime_version >= 90600) { - seqlen_offsets_workspace_size = 5 * num_bytes_per_ragged_offset; + if (is_ragged_q || is_ragged_kv) { + size_t count = 2 * (static_cast(is_ragged_q) + static_cast(is_ragged_kv)); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + seqlen_offsets_workspace_size = (count + 1) * num_bytes_per_ragged_offset; } else { - seqlen_offsets_workspace_size = 4 * num_bytes_per_ragged_offset; + seqlen_offsets_workspace_size = count * num_bytes_per_ragged_offset; } } if (workspace == nullptr) { @@ -844,28 +887,44 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[seq_kv] = devActualSeqlenKV; } - if (is_ragged) { + if (is_ragged_q || is_ragged_kv) { constexpr size_t nthreads_per_block = 128; const size_t grid = (b + nthreads_per_block) / nthreads_per_block; - void *devOffsetsQ = + void *devOffsets = static_cast(workspace) + plan_workspace_size + actual_seqlen_workspace_size; - void *devOffsetsK = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; - void *devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; - void *devOffsetsO = static_cast(devOffsetsV) + num_bytes_per_ragged_offset; + void *devOffsetsQ = nullptr; + void *devOffsetsO = nullptr; + if (is_ragged_q) { + devOffsetsQ = devOffsets; + devOffsetsO = static_cast(devOffsetsQ) + num_bytes_per_ragged_offset; + } + void *devOffsetsK = nullptr; + void *devOffsetsV = nullptr; + if (is_ragged_kv) { + devOffsetsK = static_cast(devOffsets) + + static_cast(is_ragged_q) * 2 * num_bytes_per_ragged_offset; + devOffsetsV = static_cast(devOffsetsK) + num_bytes_per_ragged_offset; + } void *devOffsetsS = nullptr; - if (cudnn_runtime_version >= 90600) { - devOffsetsS = static_cast(devOffsetsO) + num_bytes_per_ragged_offset; + if (is_ragged_q && cudnn_runtime_version >= 90600) { + devOffsetsS = static_cast(devOffsets) + + (static_cast(is_ragged_q) + static_cast(is_ragged_kv)) * 2 * + num_bytes_per_ragged_offset; } const NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout); cu_seqlens_padded_to_offsets<<>>( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); - variant_pack[offset_q] = devOffsetsQ; - variant_pack[offset_k] = devOffsetsK; - variant_pack[offset_v] = devOffsetsV; - variant_pack[offset_o] = devOffsetsO; - if (cudnn_runtime_version >= 90600) { + if (is_ragged_q) { + variant_pack[offset_q] = devOffsetsQ; + variant_pack[offset_o] = devOffsetsO; + } + if (is_ragged_kv) { + variant_pack[offset_k] = devOffsetsK; + variant_pack[offset_v] = devOffsetsV; + } + if (is_ragged_q && cudnn_runtime_version >= 90600) { variant_pack[offset_stats] = devOffsetsS; } } @@ -986,11 +1045,12 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, bias_b, bias_h, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, devPtrK, - devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, + devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, + handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1094,20 +1154,23 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { + size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; void *devPtrKV = input_KV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); size_t stride = 0; if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { stride = typeToSize(QKV_type) * num_gqa_groups * head_dim; @@ -1133,13 +1196,19 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + void *devPtrPageTableK = page_table_k->data.dptr; + void *devPtrPageTableV = page_table_v->data.dptr; size_t max_batch_size = 0; size_t max_tokens_q = 0; size_t max_tokens_kv = 0; - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { max_batch_size = get_max_batch_size(batch); + } + if (q_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_q = get_max_tokens(num_tokens_q); + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_kv = get_max_tokens(num_tokens_kv); } @@ -1149,7 +1218,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1167,7 +1236,7 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1202,11 +1271,13 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, - devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1265,10 +1336,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t max_batch_size = 0; size_t max_tokens_q = 0; size_t max_tokens_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { max_batch_size = get_max_batch_size(batch); + } + if (q_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_q = get_max_tokens(num_tokens_q); + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_kv = get_max_tokens(num_tokens_kv); } @@ -1318,17 +1394,20 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, - size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, + size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); void *devPtrQ = input_Q->data.dptr; void *devPtrK = input_K->data.dptr; void *devPtrV = input_V->data.dptr; @@ -1347,13 +1426,19 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; + void *devPtrPageTableK = page_table_k->data.dptr; + void *devPtrPageTableV = page_table_v->data.dptr; size_t max_batch_size = 0; size_t max_tokens_q = 0; size_t max_tokens_kv = 0; - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { max_batch_size = get_max_batch_size(batch); + } + if (q_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_q = get_max_tokens(num_tokens_q); + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_kv = get_max_tokens(num_tokens_kv); } @@ -1363,7 +1448,7 @@ void fused_attn_arbitrary_seqlen_fwd( Aux_CTX_Tensors->size = 3; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1381,7 +1466,7 @@ void fused_attn_arbitrary_seqlen_fwd( Aux_CTX_Tensors->size = 2; Tensor *output_S = reinterpret_cast(Aux_CTX_Tensors->tensors[0]); output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; @@ -1416,11 +1501,13 @@ void fused_attn_arbitrary_seqlen_fwd( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, is_training, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, devPtrQ, - devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, + devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, + devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, + &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1469,10 +1556,15 @@ void fused_attn_arbitrary_seqlen_bwd( size_t max_batch_size = 0; size_t max_tokens_q = 0; size_t max_tokens_kv = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { max_batch_size = get_max_batch_size(batch); + } + if (q_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_q = get_max_tokens(num_tokens_q); + } + if (kv_format == NVTE_QKV_Format::NVTE_THD) { max_tokens_kv = get_max_tokens(num_tokens_kv); } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index 3a1216f891..e1a20274f4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -38,13 +38,15 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, + float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, @@ -61,13 +63,15 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, - size_t num_tokens_kv, bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, + size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, + const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 9341ebf5f9..08e0642b29 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h index a5b25f3279..171fe846ce 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index f8fe458219..eacd8b53b4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -1670,8 +1670,6 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - NVTE_CHECK(~is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!"); - NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); try { FADescriptor_v1 descriptor{b, @@ -1681,6 +1679,12 @@ void fused_attn_fp8_fwd_impl_v1( s_kv, d, d, + 0, + 0, + 0, + 0, + 0, + 0, bias_b, bias_h, scaling_factor, @@ -1798,36 +1802,33 @@ void fused_attn_fp8_fwd_impl_v1( // sdpa_options.set_bias(bias); // } - // if (is_padding) { - // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_q") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_kv") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // sdpa_options.set_padding_mask(is_padding) - // .set_seq_len_q(seq_q) - // .set_seq_len_kv(seq_kv); - // } + if (is_padding) { + seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_options.set_padding_mask(is_padding).set_seq_len_q(seq_q).set_seq_len_kv(seq_kv); + } - // if (is_dropout) { - // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Seed") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Offset") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // sdpa_options.set_dropout( - // dropout_probability, dropout_seed, dropout_offset); - // } + if (is_dropout) { + dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); + } auto [O, Stats, amax_s, amax_o] = mha_graph->sdpa_fp8( Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_options); @@ -1919,29 +1920,28 @@ void fused_attn_fp8_fwd_impl_v1( {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; - // if (is_bias) { - // variant_pack[bias] = devPtrBias; - // } - - // if (is_padding) { - // constexpr size_t nthreads_per_block = 128; - // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) - // + b * sizeof(int32_t); - // cu_seqlens_to_actual_seqlens<<>>( - // b, static_cast(devPtrCuSeqlensQ), - // static_cast(devPtrCuSeqlensKV), - // static_cast(devActualSeqlenQ), - // static_cast(devActualSeqlenKV)); - // variant_pack[seq_q] = devActualSeqlenQ; - // variant_pack[seq_kv] = devActualSeqlenKV; - // } - - // if (is_dropout) { - // variant_pack[dropout_seed] = devPtrDropoutSeed; - // variant_pack[dropout_offset] = devPtrDropoutOffset; - // } + /* if (is_bias) { + variant_pack[bias] = devPtrBias; + } */ + + if (is_padding) { + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void* devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void* devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) + static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenKV)); + variant_pack[seq_q] = devActualSeqlenQ; + variant_pack[seq_kv] = devActualSeqlenKV; + } + + if (is_dropout) { + variant_pack[dropout_seed] = devPtrDropoutSeed; + variant_pack[dropout_offset] = devPtrDropoutOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { NVTE_ERROR(e.what()); @@ -1974,8 +1974,6 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); - NVTE_CHECK(~is_padding, "FP8 fused attention does not support padding/padding_causal mask yet!"); - NVTE_CHECK(~is_dropout, "FP8 fused attention does not support dropout yet!"); try { FADescriptor_v1 descriptor{b, @@ -1985,6 +1983,12 @@ void fused_attn_fp8_bwd_impl_v1( s_kv, d, d, + 0, + 0, + 0, + 0, + 0, + 0, bias_b, bias_h, scaling_factor, @@ -2151,36 +2155,35 @@ void fused_attn_fp8_bwd_impl_v1( // } // } - // if (is_padding) { - // seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_q") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("seq_kv") - // .set_dim({b, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT32)); - // sdpa_backward_options.set_padding_mask(is_padding) - // .set_seq_len_q(seq_q) - // .set_seq_len_kv(seq_kv); - // } + if (is_padding) { + seq_q = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_q") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + seq_kv = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("seq_kv") + .set_dim({b, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT32)); + sdpa_backward_options.set_padding_mask(is_padding) + .set_seq_len_q(seq_q) + .set_seq_len_kv(seq_kv); + } - // if (is_dropout) { - // dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Seed") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() - // .set_name("Offset") - // .set_dim({1, 1, 1, 1}) - // .set_stride({1, 1, 1, 1}) - // .set_data_type(fe::DataType_t::INT64)); - // sdpa_backward_options.set_dropout( - // dropout_probability, dropout_seed, dropout_offset); - // } + if (is_dropout) { + dropout_seed = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Seed") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + dropout_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Offset") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::INT64)); + sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); + } auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph->sdpa_fp8_backward( q, k, v, o, dO, stats, descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, @@ -2308,34 +2311,32 @@ void fused_attn_fp8_bwd_impl_v1( {amax_dP, devPtrAmaxdP}, }; - // if (is_bias) { - // variant_pack[bias] = devPtrBias; - // if ((bias_b == 1) && (bias_h == h)) { - // variant_pack[dBias] = devPtrdBias; - // } else { - // variant_pack[dBias] = nullptr; - // } - // } - - // if (is_padding) { - // constexpr size_t nthreads_per_block = 128; - // const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - // void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - // void *devActualSeqlenKV = static_cast(devActualSeqlenQ) - // + b * sizeof(int32_t); - // cu_seqlens_to_actual_seqlens<<>>( - // b, static_cast(devPtrCuSeqlensQ), - // static_cast(devPtrCuSeqlensKV), - // static_cast(devActualSeqlenQ), - // static_cast(devActualSeqlenKV)); - // variant_pack[seq_q] = devActualSeqlenQ; - // variant_pack[seq_kv] = devActualSeqlenKV; - // } - - // if (is_dropout) { - // variant_pack[dropout_seed] = devPtrDropoutSeed; - // variant_pack[dropout_offset] = devPtrDropoutOffset; - // } + /* if (is_bias) { + variant_pack[bias] = devPtrBias; + if ((bias_b == 1) && (bias_h == h)) { + variant_pack[dBias] = devPtrdBias; + } else { + variant_pack[dBias] = nullptr; + } + } */ + + if (is_padding) { + constexpr size_t nthreads_per_block = 128; + const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; + void* devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; + void* devActualSeqlenKV = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); + cu_seqlens_to_actual_seqlens<<>>( + b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) + static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), + static_cast(devActualSeqlenKV)); + variant_pack[seq_q] = devActualSeqlenQ; + variant_pack[seq_kv] = devActualSeqlenKV; + } + + if (is_dropout) { + variant_pack[dropout_seed] = devPtrDropoutSeed; + variant_pack[dropout_offset] = devPtrDropoutOffset; + } NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException& e) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 55830d3cda..3daf45d162 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index a053c55fb6..516f7b84c5 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -117,6 +117,7 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 } break; case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD: if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix) || @@ -223,6 +224,9 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 break; case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD: case NVTE_QKV_Layout::NVTE_THD_THD_THD: + case NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD: if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { strideA[batch_dim_idx] = s_q * h * d; @@ -243,6 +247,52 @@ void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int6 strideA[hidden_transpose_dim_idx] = 1; } break; + case NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = s_kv * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } + break; + case NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD: + case NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD: + if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = b * h * d; + strideA[hidden_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose) || + (matrix == NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose)) { + strideA[batch_dim_idx] = h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_transpose_dim_idx] = b * h * d; + strideA[hidden_transpose_dim_idx] = 1; + } else if ((matrix == NVTE_QKV_Matrix::NVTE_Q_Matrix) || + (matrix == NVTE_QKV_Matrix::NVTE_O_Matrix)) { + strideA[batch_dim_idx] = s_q * h * d; + strideA[head_dim_idx] = d; + strideA[seqlen_dim_idx] = h * d; + strideA[hidden_dim_idx] = 1; + } + break; } if (matrix == NVTE_QKV_Matrix::NVTE_S_Matrix) { @@ -379,28 +429,44 @@ __device__ void cu_seqlens_padded_to_offsets_impl( size_t tid = blockIdx.x * blockDim.x + threadIdx.x; auto cu_seqlens_id = min(tid, actual_b); if (tid <= max_b) { - offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id]; if (offsets_s != nullptr) { offsets_s[tid] = h * cu_seqlens_q_padded[cu_seqlens_id]; } - switch (layout_group) { - case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; - offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; - offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id]; - break; - case NVTE_QKV_Layout_Group::NVTE_3HD: - case NVTE_QKV_Layout_Group::NVTE_H3D: - offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; - offsets_k[tid] = offsets_q[cu_seqlens_id]; - offsets_v[tid] = offsets_q[cu_seqlens_id]; - break; - case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; - offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; - offsets_v[tid] = offsets_k[cu_seqlens_id]; - break; + if (offsets_q != nullptr && offsets_o != nullptr) { + offsets_o[tid] = h * d_v * cu_seqlens_q_padded[cu_seqlens_id]; + switch (layout_group) { + case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD: + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + break; + case NVTE_QKV_Layout_Group::NVTE_3HD: + case NVTE_QKV_Layout_Group::NVTE_H3D: + offsets_q[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + break; + case NVTE_QKV_Layout_Group::NVTE_HD_2HD: + case NVTE_QKV_Layout_Group::NVTE_HD_H2D: + offsets_q[tid] = h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + break; + } + } + if (offsets_k != nullptr && offsets_v != nullptr) { + switch (layout_group) { + case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD: + offsets_k[tid] = hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; + offsets_v[tid] = hg * d_v * cu_seqlens_kv_padded[cu_seqlens_id]; + break; + case NVTE_QKV_Layout_Group::NVTE_3HD: + case NVTE_QKV_Layout_Group::NVTE_H3D: + offsets_k[tid] = 3 * h * d_qk * cu_seqlens_q_padded[cu_seqlens_id]; + offsets_v[tid] = offsets_k[cu_seqlens_id]; + break; + case NVTE_QKV_Layout_Group::NVTE_HD_2HD: + case NVTE_QKV_Layout_Group::NVTE_HD_H2D: + offsets_k[tid] = 2 * hg * d_qk * cu_seqlens_kv_padded[cu_seqlens_id]; + offsets_v[tid] = offsets_k[cu_seqlens_id]; + break; + } } } } @@ -433,6 +499,7 @@ DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_at std::array offsets_qkvo{}; switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: + case NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD: offsets_qkvo[0] = num_attn_heads * head_dim_qk * max_seqlen_q; offsets_qkvo[1] = num_gqa_groups * head_dim_qk * max_seqlen_kv; offsets_qkvo[2] = num_gqa_groups * head_dim_v * max_seqlen_kv; diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index f790d3b567..30702a875d 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -93,6 +93,12 @@ struct FADescriptor_v1 { std::int64_t s_kv; std::int64_t d_qk; std::int64_t d_v; + std::int64_t num_pages_k; + std::int64_t num_pages_v; + std::int64_t page_size_k; + std::int64_t page_size_v; + std::int64_t max_pages_per_seq_k; + std::int64_t max_pages_per_seq_v; std::int64_t bias_b; std::int64_t bias_h; float attnScale; @@ -108,13 +114,16 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t bwd_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { - return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, bias_b, bias_h, attnScale, isTraining, - dropoutProbability, layout, mask_type, window_size_left, window_size_right, - deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < - std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.bias_b, - rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, - rhs.mask_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, - rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type); + return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, + page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, + attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, + window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < + std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, + rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, + rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, + rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, + rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, + rhs.bwd_tensor_type); } }; diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index 26f104d3ed..7f35ddd70b 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu index 841edcf043..4628b37949 100644 --- a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu index 08fd32af9c..2d31f82bab 100644 --- a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu index 8571887ee6..7d97680ec3 100644 --- a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 593ec086d7..d24d114c29 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -14,7 +14,9 @@ #include #include "../common.h" +#include "../util/handle_manager.h" #include "../util/logging.h" +#include "common/util/cuda_runtime.h" namespace { @@ -46,20 +48,118 @@ uint32_t _getAlignment(uintptr_t address) { } } +inline void CreateCublasHandle(cublasLtHandle_t *handle) { + NVTE_CHECK_CUBLAS(cublasLtCreate(handle)); +} + +struct GemmParam { + void *A; + void *B; + cublasOperation_t transA; + cublasOperation_t transB; + transformer_engine::DType Atype; + transformer_engine::DType Btype; + void *A_scale_inv; + void *B_scale_inv; + int lda; + int ldb; + + GemmParam(cublasOperation_t transA, cublasOperation_t transB) + : A(nullptr), + B(nullptr), + transA(transA), + transB(transB), + Atype(transformer_engine::DType::kNumTypes), + Btype(transformer_engine::DType::kNumTypes), + A_scale_inv(nullptr), + B_scale_inv(nullptr), + lda(0), + ldb(0) {} +}; + +GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cublasOperation_t transA, + const transformer_engine::Tensor &B, const cublasOperation_t transB, + const int k, const int lda, const int ldb) { + using namespace transformer_engine; + NVTE_CHECK(A.scaling_mode == B.scaling_mode, + "Inputs A and B to GEMM need to have the same scaling mode!"); + NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); + NVTE_CHECK(B.has_data() || B.has_columnwise_data(), "Input B does not hold any data!"); + GemmParam ret(transA, transB); + + ret.lda = lda; + ret.ldb = ldb; + + if (is_tensor_scaling(A.scaling_mode)) { + ret.A = A.data.dptr; + ret.A_scale_inv = A.scale_inv.dptr; + if (transA == CUBLAS_OP_T) { + ret.Atype = A.data.dtype; + } else { + ret.Atype = A.has_columnwise_data() ? A.columnwise_data.dtype : A.data.dtype; + if (is_fp8_dtype(ret.Atype)) { + int arch = cuda::sm_arch(cuda::current_device()); + if (arch < 100) { + // Hopper and Ada - we need to use columnwise_data and change transA + NVTE_CHECK(A.has_columnwise_data(), "Input A is not suitable for columnwise usage!"); + ret.A = A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = k; + } + } + } + ret.B = B.data.dptr; + ret.B_scale_inv = B.scale_inv.dptr; + if (transB == CUBLAS_OP_T) { + ret.Btype = B.has_columnwise_data() ? B.columnwise_data.dtype : B.data.dtype; + if (is_fp8_dtype(ret.Btype)) { + int arch = cuda::sm_arch(cuda::current_device()); + if (arch < 100) { + // Hopper and Ada - we need to use columnwise_data and change transA + NVTE_CHECK(B.has_columnwise_data(), "Input B is not suitable for columnwise usage!"); + ret.B = B.columnwise_data.dptr; + ret.transB = CUBLAS_OP_N; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = k; + } + } + } else { + ret.Btype = B.data.dtype; + } + } else { + // If not tensor scaling (which includes also high precision types), we need to + // use the proper version of data + // We leave the transA/B values as is, since Blackwell supports transposes + ret.A = transA ? A.data.dptr : A.columnwise_data.dptr; + ret.Atype = transA ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = transA ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.B = transB ? B.columnwise_data.dptr : B.data.dptr; + ret.Btype = transB ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = transB ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + } + return ret; +} + } // namespace namespace transformer_engine { +using cublasHandleManager = detail::HandleManager; + void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, int m, int n, int k, int lda, int ldb, int ldd, cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { - void *A = inputA->data.dptr; - void *A_scale_inverse = inputA->scale_inv.dptr; - void *B = inputB->data.dptr; - void *B_scale_inverse = inputB->scale_inv.dptr; + // Return immediately if GEMM is trivial + if (m <= 0 || n <= 0) { + return; + } + NVTE_CHECK(k > 0); + + const GemmParam ¶m = CanonicalizeGemmInput(*inputA, transa, *inputB, transb, k, lda, ldb); void *C = outputD->data.dptr; void *D = outputD->data.dptr; void *D_scale = outputD->scale.dptr; @@ -72,15 +172,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, counter = inputCounter->data.dptr; } const bool gelu = pre_gelu_out != nullptr; - const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype); - const cudaDataType_t A_type = get_cuda_dtype(inputA->data.dtype); - const cudaDataType_t B_type = get_cuda_dtype(inputB->data.dtype); + const bool use_fp8 = is_fp8_dtype(param.Atype) || is_fp8_dtype(param.Btype); + + const cudaDataType_t A_type = get_cuda_dtype(param.Atype); + const cudaDataType_t B_type = get_cuda_dtype(param.Btype); const cudaDataType_t D_type = get_cuda_dtype(outputD->data.dtype); const cudaDataType_t bias_type = get_cuda_dtype(inputBias->data.dtype); - NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr, + NVTE_CHECK(!is_fp8_dtype(param.Atype) || param.A_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); - NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr, + NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); // check consistency of arguments: @@ -98,8 +199,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, float zero = 0.0; float beta = (accumulate) ? one : zero; - cublasLtHandle_t handle; - NVTE_CHECK_CUBLAS(cublasLtCreate(&handle)); + cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); cublasLtMatmulDesc_t operationDesc = nullptr; cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr; @@ -117,17 +217,17 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // Create matrix descriptors. Not setting any extra attributes. - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, transa == CUBLAS_OP_N ? m : k, - transa == CUBLAS_OP_N ? k : m, lda)); - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, transb == CUBLAS_OP_N ? k : n, - transb == CUBLAS_OP_N ? n : k, ldb)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Adesc, A_type, param.transA == CUBLAS_OP_N ? m : k, + param.transA == CUBLAS_OP_N ? k : m, param.lda)); + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Bdesc, B_type, param.transB == CUBLAS_OP_N ? k : n, + param.transB == CUBLAS_OP_N ? n : k, param.ldb)); NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescCreate(&operationDesc, gemm_compute_type, CUDA_R_32F)); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, - &transa, sizeof(transa))); + ¶m.transA, sizeof(param.transA))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, - &transb, sizeof(transb))); + ¶m.transB, sizeof(param.transB))); // Set math SM count if (math_sm_count != 0) { NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -143,12 +243,52 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, - &A_scale_inverse, sizeof(A_scale_inverse))); - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, - CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, - &B_scale_inverse, sizeof(B_scale_inverse))); + + // Scaling factors. +#if CUDA_VERSION >= 12080 + cublasLtMatmulMatrixScale_t scaling_mode; +#endif + if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { + void *A_scale_inverse = param.A_scale_inv; + void *B_scale_inverse = param.B_scale_inv; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); +#if CUDA_VERSION >= 12080 + scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; + } else if ((is_block_scaling(inputA->scaling_mode) && is_block_scaling(inputB->scaling_mode))) { + fp8e8m0 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + fp8e8m0 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + scaling_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0; + // Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling. + // CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set. + if (cublasLtGetVersion() <= 120803) { + const int64_t dummy_a_vec_stride = 1; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride, + sizeof(dummy_a_vec_stride))); + } +#endif + } else { + NVTE_ERROR("Not implemented scaling modes: " + to_string(inputA->scaling_mode) + " and " + + to_string(inputB->scaling_mode) + "."); + } + +#if CUDA_VERSION >= 12080 + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); +#endif if (is_fp8_dtype(outputD->data.dtype)) { // Accumulation mode not supported for FP8 output C = nullptr; @@ -156,8 +296,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &D_scale, sizeof(D_scale))); NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax))); - // For FP8 output, cuBLAS requires C_type to be same as bias_type - NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, bias_type, m, n, ldd)); +#if CUDA_VERSION >= 12080 + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_MODE, &scaling_mode, sizeof(scaling_mode))); +#endif + // For FP8 output, cuBLAS requires C_type to match bias_type and + // be FP16/BF16 + const cudaDataType_t C_type = bias ? bias_type : CUDA_R_16BF; + NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, C_type, m, n, ldd)); } else { NVTE_CHECK_CUBLAS(cublasLtMatrixLayoutCreate(&Cdesc, D_type, m, n, ldd)); } @@ -235,8 +381,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); - const auto A_alignment = _getAlignment(reinterpret_cast(A)); - const auto B_alignment = _getAlignment(reinterpret_cast(B)); + const auto A_alignment = _getAlignment(reinterpret_cast(param.A)); + const auto B_alignment = _getAlignment(reinterpret_cast(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -260,8 +406,8 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, // D = alpha * (A * B) + beta * C NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, static_cast(&one), /* alpha */ - A, /* A */ - Adesc, B, /* B */ + param.A, /* A */ + Adesc, param.B, /* B */ Bdesc, static_cast(&beta), /* beta */ C, /* C */ Cdesc, D, /* D */ @@ -270,7 +416,10 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, workspaceSize, stream)); /* stream */ // Update FP8 scale-inv in output tensor - if (is_fp8_dtype(outputD->data.dtype)) { + // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated. + // TODO: Changing gemm interface so that D->scale_inv is allocated and the scale_inv can be + // calculated here. + if (is_fp8_dtype(outputD->data.dtype) && outputD->scale_inv.dptr) { update_tensor_scale_inv(outputD, stream); } @@ -309,9 +458,14 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = reinterpret_cast(pre_gelu_out); Tensor *wspace = reinterpret_cast(workspace); - const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; - const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; - const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; + const size_t A0 = inputA->flat_first_dim(); + const size_t A1 = inputA->flat_last_dim(); + const size_t B0 = inputB->flat_first_dim(); + const size_t B1 = inputB->flat_last_dim(); + + const int m = transa ? A0 : A1; + const int k = transa ? A1 : A0; + const int n = transb ? B1 : B0; int lda, ldb, ldd; if (transa && !transb) { // TN lda = k; @@ -357,6 +511,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor const Tensor *inputCounter = reinterpret_cast(counter); Tensor *wspace = reinterpret_cast(workspace); + NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && + is_delayed_tensor_scaling(inputB->scaling_mode), + "Atomic GEMM only supports delayed scaling."); + const int m = transa ? inputA->data.shape[0] : inputA->data.shape[1]; const int k = transa ? inputA->data.shape[1] : inputA->data.shape[0]; const int n = transb ? inputB->data.shape[1] : inputB->data.shape[0]; diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 656c647fd4..49029ed588 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -19,7 +19,9 @@ extern "C" { /* Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ -/*! \brief Compute activation of the input. +/*! \brief Computes activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] input Input tensor for activation. * \param[in,out] output Output tensor. @@ -39,17 +41,59 @@ enum class NVTE_Activation_Type { SREGLU, }; +/*! \brief Computes the GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Compute activation gradient. +/*! \brief Computes the GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] grad Incoming gradient. * \param[in] input Input tensor for activation. @@ -59,19 +103,57 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the SiLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Quick GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the Squared ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient. + * \param[in] input Input tensor for activation. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Compute gated activation of the input. +/*! \brief Computes the gated GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. * * \param[in] input Input tensor of shape [N, H * 2]. * \param[in,out] output Output tensor of shape [N, H]. @@ -80,15 +162,54 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu */ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Quick GeLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Squared ReLU activation of the input. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); -/*! \brief Compute gated activation gradient. +/*! \brief Computes the gated GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * * \param[in] grad Incoming gradient of shape [N, H]. * \param[in] input Forward input tensor of shape [N, H * 2]. * \param[in,out] output Outgoing gradient of shape [N, H * 2]. @@ -97,15 +218,51 @@ void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Quick GeLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Squared ReLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 32f16922b9..d57975b2f4 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -1,11 +1,11 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ /*! \file cast.h - * \brief Functions to cast to/from FP8. + * \brief Functions to cast to/from FP8/MXFP8. */ #ifndef TRANSFORMER_ENGINE_CAST_H_ @@ -17,21 +17,200 @@ extern "C" { #endif -/*! \brief Cast tensor to FP8. +/* Cast the tensor to FP8 (or microscaling FP8 if the compute capability of the device is 10.0 or newer) + * The implementation is per the microscaling format MXFP8 defined by the OCP specification: + * https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf * - * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8 tensor. - * \param[in] stream CUDA stream used for the operation. + * Supported modes of scaling (live scaling): + * 1) Rowwise scaling (along the dim=0) computes one set of the output data, which includes: + * - the scaled output tensor + * - the corresponding scaling factors + * The scaling factors are computed for blocks of the shape [1,32] + * (i.e., each scaling factor spans 32 contiguous elements along rows). + * + * 2) Columwise scaling (along the dim=1) computes one set of the output data. + * The scaling factors are computed for blocks of the shape [32,1] + * (i.e., each scaling factor spans 32 contiguous elements along columns). + * + * 3) Both rowwise AND columnwise scaling (along the dim=0 and the dim=1) + * computes two sets of the output data: both 1) and 2). + * + * The shape of the MX block must be specified in the 'output' argument, + * and can be either [1,32] or [32,1] as no other shapes are currently supported. + * + * To cast the input tensor to the MXFP8, the scaling_mode.delayed_scaling parameter + * of the output tensor should be set to 0. + */ + +/*! \brief Casts input tensor to FP8/MXFP8. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] noop Noop tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + cudaStream_t stream); + +/*! \brief Casts input tensor to MXFP8. Additionally, reduces the input along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workplace, cudaStream_t stream); + +/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the GeLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the SiLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the ReLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Quick GeLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); + +/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. + * Additionally, reduces the result of the Squared ReLU backward along columns. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * This function produces 2 results: + * - `output` is equal to `cast(dact(input))` + * - `dbias` is equal to `reduce(dact(input), dim=1)` + * + * Calling this function with the workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor to be cast. + * \param[in] act_input Activation input tensor. + * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[out] dbias Result of the reduction of the input along columns. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. */ -void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); -/*! \brief Cast tensor from FP8. +/*! \brief Casts input tensor from reduced to higher precision. + * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, + * the block dequantization (MXFP8) of the specified shape of the block will be used. + * In case of the MXFP8 dequantization, the dequantized values are stored to the rowwise + * data of the output tensor, regardless of whether the row- or columnwise scaling is used. * - * \param[in] input Input tensor to be cast. - * \param[out] output Output tensor. + * \param[in] input Input FP8/MXFP8 tensor to be cast. + * \param[in,out] output Output tensor. * \param[in] stream CUDA stream used for the operation. */ -void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); +void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h index 9043162bcb..678ffe9191 100644 --- a/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h +++ b/transformer_engine/common/include/transformer_engine/cast_transpose_noop.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -17,11 +17,26 @@ extern "C" { #endif +/*! \brief Transposes the input, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * + * \param[in] input Input tensor. + * \param[in] noop Noop tensor. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream); -void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, - NVTETensor cast_output, NVTETensor transposed_output, +/*! \brief Casts and transposes the input, providing the option to immediately exit the kernel + * based on the value of the 'noop' tensor. + * + * \param[in] input Input tensor. + * \param[in] noop Noop tensor. + * \param[in,out] output Output tensor. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream); #ifdef __cplusplus diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 17ecca5ff0..293c57526d 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -53,6 +53,8 @@ class CommOverlapCore { int _cga_size; int _use_ce; int _ub_reg; + int _gemm_priority; + int _comm_priority; bool _atomic_gemm{false}; bool _is_p2p{false}; @@ -62,13 +64,16 @@ class CommOverlapCore { bool _ubuf_scale_inv_initialized{false}; std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm; + cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; public: + CommOverlapCore() {} // dummy constructor for exposing type to Python + CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, - int num_splits, int num_max_streams, int comm_cga_size, int num_comm_sm, - bool set_sm_margin, bool use_ce, bool atomic_gemm); + int num_splits, int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm); virtual ~CommOverlapCore(); @@ -77,25 +82,76 @@ class CommOverlapCore { _ubuf_scale_inv_initialized = true; } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, + const std::vector &shape); + + TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, + const std::vector &shape); + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } bool is_fp8_ubuf() { return _ubuf.element_size() == 1; } + + virtual void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, CommOverlapType comm_type, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, + TensorWrapper &rs_output, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, + bool grad, bool accumulate, bool use_split_accumulator, + TensorWrapper &B_copy, cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } + + virtual void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { protected: int _rs_kernel_type; + bool _rs_overlap_first_gemm; cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; public: + CommOverlapBase() {} // dummy constructor for exposing type to Python + CommOverlapBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, + bool set_sm_margin = true, bool atomic_gemm = false, + bool rs_overlap_first_gemm = false); virtual ~CommOverlapBase(); @@ -103,97 +159,124 @@ class CommOverlapBase : public CommOverlapCore { ** Bulk GEMM + COMM ** This function assumes the communication input is pre-copied to _ubuf */ - void bulk_overlap(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, TensorWrapper &D, - TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, - bool grad, bool accumulate, bool use_split_accumulator, - CommOverlapType comm_type, TensorWrapper &rs_output, cudaStream_t stream_main); + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) override; + + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } /* ** Split FPROP GEMM + ReduceScatter */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, - TensorWrapper &rs_output, cudaStream_t stream_main); + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; /* ** Split FPROP GEMM + ReduceScatter */ - void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, TensorWrapper &rs_output, - cudaStream_t stream_main); + bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { protected: bool _is_reduce_scatter{false}; bool _use_multiatomic_ag{false}; - + bool _aggregate; int _next_rank; int _prev_rank; int _rank_round_tp; - int _aggregate; int _num_ubuf_chunks; int _self_chunk_id; - std::vector _ubufs; - - cudaStream_t _stream_send; + std::vector _stream_send; cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; public: + CommOverlapP2PBase() {} // dummy constructor for exposing type to Python + CommOverlapP2PBase(const std::vector &buffer_shape, DType buffer_dtype, int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes, int tp_size, ExtAllgatherOp allgather_handle, ExtBarrierOp barrier_handle, CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, - int comm_cga_size = 1, int num_comm_sm = 1, bool set_sm_margin = false, - bool use_ce = true, bool atomic_gemm = false, bool aggregate = false); + int comm_cga_size = 1, int gemm_priority = 0, int comm_priority = 0, + int num_comm_sm = 1, bool set_sm_margin = false, bool use_ce = true, + bool atomic_gemm = false, bool aggregate = false); virtual ~CommOverlapP2PBase(); + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); + + void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, + TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, + TensorWrapper &workspace, bool grad, bool accumulate, + bool use_split_accumulator, CommOverlapType comm_type, TensorWrapper &rs_output, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } + /* ** Split AllGather + AtomicGEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ - void atomic_gemm_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main); + void atomic_gemm_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, + cudaStream_t stream_main) override; /* ** Split AllGather + GEMM using P2P communication ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is needed to have AG ** outputs in each rank to be in the contiguous memory space after all ring exchange phases. */ - void split_overlap_ag(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_ag(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &B_copy, - cudaStream_t stream_main); + cudaStream_t stream_main) override; /* ** Split ReduceScatter + GEMM using P2P communication */ - void atomic_gemm_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, - TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, - TensorWrapper &workspace, bool grad, bool accumulate, - bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main); + void atomic_gemm_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, + bool transb, TensorWrapper &D, TensorWrapper &bias, + TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, + bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, + cudaStream_t stream_main) override; /* ** Split ReduceScatter + GEMM using P2P communication */ - void split_overlap_rs(TensorWrapper &A, bool transa, TensorWrapper &B, bool transb, + void split_overlap_rs(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, TensorWrapper &D, TensorWrapper &bias, TensorWrapper &pre_gelu_out, TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, - cudaStream_t stream_main); + cudaStream_t stream_main) override; }; // CommOverlapP2PBase } // namespace transformer_engine diff --git a/transformer_engine/common/include/transformer_engine/cudnn.h b/transformer_engine/common/include/transformer_engine/cudnn.h index c5e4bc23a9..70acead631 100644 --- a/transformer_engine/common/include/transformer_engine/cudnn.h +++ b/transformer_engine/common/include/transformer_engine/cudnn.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index ae08f2a4aa..3c7b3f5817 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -25,24 +25,34 @@ extern "C" { * head size, and the total number of tokens in a batch, i.e. `t = sum(s_i) for i = 0...b-1`. * `SBHD` and `BSHD`-based layouts are used when sequences in a batch are of equal length * or padded to the same length, and `THD`-based layouts are used when sequences have - * different lengths in a batch. + * different lengths in a batch. `Paged_KV`-based layouts are used for paged attention. */ enum NVTE_QKV_Layout { - NVTE_SB3HD = 0, /*!< SB3HD layout */ - NVTE_SBH3D = 1, /*!< SBH3D layout */ - NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ - NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ - NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ - NVTE_BS3HD = 5, /*!< BS3HD layout */ - NVTE_BSH3D = 6, /*!< BSH3D layout */ - NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ - NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ - NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ - NVTE_T3HD = 10, /*!< T3HD layout */ - NVTE_TH3D = 11, /*!< TH3D layout */ - NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ - NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ - NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ + NVTE_SB3HD = 0, /*!< SB3HD layout */ + NVTE_SBH3D = 1, /*!< SBH3D layout */ + NVTE_SBHD_SB2HD = 2, /*!< SBHD_SB2HD layout */ + NVTE_SBHD_SBH2D = 3, /*!< SBHD_SBH2D layout */ + NVTE_SBHD_SBHD_SBHD = 4, /*!< SBHD_SBHD_SBHD layout */ + NVTE_BS3HD = 5, /*!< BS3HD layout */ + NVTE_BSH3D = 6, /*!< BSH3D layout */ + NVTE_BSHD_BS2HD = 7, /*!< BSHD_BS2HD layout */ + NVTE_BSHD_BSH2D = 8, /*!< BSHD_BSH2D layout */ + NVTE_BSHD_BSHD_BSHD = 9, /*!< BSHD_BSHD_BSHD layout */ + NVTE_T3HD = 10, /*!< T3HD layout */ + NVTE_TH3D = 11, /*!< TH3D layout */ + NVTE_THD_T2HD = 12, /*!< THD_T2HD layout */ + NVTE_THD_TH2D = 13, /*!< THD_TH2D layout */ + NVTE_THD_THD_THD = 14, /*!< THD_THD_THD layout */ + NVTE_SBHD_BSHD_BSHD = 15, /*!< SBHD_BSHD_BSHD layout */ + NVTE_BSHD_SBHD_SBHD = 16, /*!< BSHD_SBHD_SBHD layout */ + NVTE_THD_BSHD_BSHD = 17, /*!< THD_BSHD_BSHD layout */ + NVTE_THD_SBHD_SBHD = 18, /*!< THD_SBHD_SBHD layout */ + NVTE_Paged_KV_BSHD_BSHD_BSHD = 19, /*!< Paged_KV_BSHD_BSHD_BSHD layout */ + NVTE_Paged_KV_BSHD_SBHD_SBHD = 20, /*!< Paged_KV_BSHD_SBHD_SBHD layout */ + NVTE_Paged_KV_SBHD_BSHD_BSHD = 21, /*!< Paged_KV_SBHD_BSHD_BSHD layout */ + NVTE_Paged_KV_SBHD_SBHD_SBHD = 22, /*!< Paged_KV_SBHD_SBHD_SBHD layout */ + NVTE_Paged_KV_THD_BSHD_BSHD = 23, /*!< Paged_KV_THD_BSHD_BSHD layout */ + NVTE_Paged_KV_THD_SBHD_SBHD = 24, /*!< Paged_KV_THD_SBHD_SBHD layout */ }; /*! \enum NVTE_QKV_Layout_Group @@ -59,18 +69,28 @@ enum NVTE_QKV_Layout_Group { NVTE_HD_H2D = 3, /*! HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD */ NVTE_HD_HD_HD = 4, + /*! Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD */ + NVTE_Paged_KV_HD_HD_HD = 5, }; /*! \enum NVTE_QKV_Format * \brief QKV formats */ enum NVTE_QKV_Format { - /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD */ + /*! SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD, Paged_KV_SBHD_SBHD_SBHD */ NVTE_SBHD = 0, - /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD */ + /*! BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD, Paged_KV_BSHD_BSHD_BSHD */ NVTE_BSHD = 1, /*! THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD */ NVTE_THD = 2, + /*! BSHD format for Q and SBHD format for KV, i.e. BSHD_SBHD_SBHD, Paged_KV_BSHD_SBHD_SBHD */ + NVTE_BSHD_2SBHD = 3, + /*! SBHD format for Q and BSHD format for KV, i.e. SBHD_BSHD_BSHD, Paged_KV_SBHD_BSHD_BSHD */ + NVTE_SBHD_2BSHD = 4, + /*! THD format for Q and BSHD format for KV, i.e. THD_BSHD_BSHD, Paged_KV_THD_BSHD_BSHD */ + NVTE_THD_2BSHD = 5, + /*! THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD */ + NVTE_THD_2SBHD = 6, }; /*! \enum NVTE_Bias_Type @@ -135,6 +155,22 @@ NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout); */ NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout); +/*! \brief Get Q format for a given QKV layout. + * + * \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd. + * + * \return q format, e.g. sbhd. + */ +NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout); + +/*! \brief Get KV format for a given QKV layout. + * + * \param[in] qkv_layout QKV layout, e.g. sbhd_bshd_bshd. + * + * \return kv format, e.g. bshd. + */ +NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); + /*! \brief Get fused attention backend based on input parameters. * * \param[in] q_dtype The data type of Tensor Q. @@ -312,6 +348,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. + * \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k]. + * \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v]. * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. @@ -329,16 +367,14 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, - NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, + NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -445,6 +481,8 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv_padded Cumulative sequence offsets for KV, [batch_size + 1]. + * \param[in] page_table_k Page table for K cache, [batch_size, max_pages_per_seq_k]. + * \param[in] page_table_v Page table for V cache, [batch_size, max_pages_per_seq_v]. * \param[in] rng_state Seed and offset of CUDA random number generator. * \param[in] max_seqlen_q Max sequence length used for computing for Q. * it may be >= max(seqlen_q_i) for i=0,...batch_size-1. @@ -465,7 +503,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack* Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index b7b9b93881..41a0e3bc76 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 1cdbfd2eb5..2cb99f3d28 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/layer_norm.h b/transformer_engine/common/include/transformer_engine/layer_norm.h deleted file mode 100644 index 3bb4d47f29..0000000000 --- a/transformer_engine/common/include/transformer_engine/layer_norm.h +++ /dev/null @@ -1,159 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file layer_norm.h - * \brief LayerNorm functions. - */ - -#ifndef TRANSFORMER_ENGINE_LAYER_NORM_H_ -#define TRANSFORMER_ENGINE_LAYER_NORM_H_ - -#include "transformer_engine.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/*! \brief Compute LayerNorm on the input. - * - * The formula used: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta - * @f] - * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. - * - * \param[in] x Input tensor of shape [N, H]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[in] beta Beta tensor of shape [H]. - * \param[in] epsilon Value added to denominator for numerical stability. - * \param[in,out] z Output tensor of shape [N, H]. - * \param[out] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[out] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier); - -/*! \brief Compute LayerNorm with zero-centered gamma on the input. - * - * The formula used: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta - * @f] - * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. - * - * \param[in] x Input tensor of shape [N, H]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[in] beta Beta tensor of shape [H]. - * \param[in] epsilon Value added to denominator for numerical stability. - * \param[in,out] z Output tensor of shape [N, H]. - * \param[out] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[out] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier); - -/*! \brief Compute backward of LayerNorm. - * - * This function computes the gradient of function: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta - * @f] - * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. - * - * Calling this function with workspace, barrier, dgamma_part and dbeta_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. - * - * \param[in] dz Incoming gradient tensor of shape [N, H]. - * \param[in] x Forward input tensor of shape [N, H]. - * \param[in] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[in] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[out] dx Output gradient of shape [N, H]. - * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dbeta Gradient for beta tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[out] dbeta_part Storage for partial bias gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part, - NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); - -/*! \brief Compute backward of LayerNorm with zero-centered gamma. - * - * This function computes the gradient of function: - * @f[ - * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}(1 + \gamma) + \beta - * @f] - * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. - * - * Calling this function with workspace, barrier, dgamma_part and dbeta_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. - * - * \param[in] dz Incoming gradient tensor of shape [N, H]. - * \param[in] x Forward input tensor of shape [N, H]. - * \param[in] mu Mean of the input calculated over the last dimension. - * Shape: [N]. - * \param[in] rsigma Inverse of the variance of the input calculated over - * the last dimension. Shape: [N]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[out] dx Output gradient of shape [N, H]. - * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dbeta Gradient for beta tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[out] dbeta_part Storage for partial bias gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. - */ -void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, - NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier); -#ifdef __cplusplus -} // extern "C" -#endif - -#endif // TRANSFORMER_ENGINE_LAYER_NORM_H_ diff --git a/transformer_engine/common/include/transformer_engine/rmsnorm.h b/transformer_engine/common/include/transformer_engine/normalization.h similarity index 54% rename from transformer_engine/common/include/transformer_engine/rmsnorm.h rename to transformer_engine/common/include/transformer_engine/normalization.h index dc995e3c24..8c34540e34 100644 --- a/transformer_engine/common/include/transformer_engine/rmsnorm.h +++ b/transformer_engine/common/include/transformer_engine/normalization.h @@ -1,15 +1,15 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ -/*! \file rmsnorm.h - * \brief RMSNorm functions. +/*! \file normalization.h + * \brief LayerNorm and RMSNorm functions. */ -#ifndef TRANSFORMER_ENGINE_RMSNORM_H_ -#define TRANSFORMER_ENGINE_RMSNORM_H_ +#ifndef TRANSFORMER_ENGINE_NORMALIZATION_H_ +#define TRANSFORMER_ENGINE_NORMALIZATION_H_ #include "transformer_engine.h" @@ -17,41 +17,73 @@ extern "C" { #endif -/*! \brief Compute RMSNorm on the input. +/*! \brief Compute LayerNorm on the input. * * The formula used: * @f[ - * y = \frac{x}{RMS_\varepsilon(x)}\gamma - * @f] - * where - * @f[ - * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon} + * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta * @f] * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. + * Calling this function with workspace set to empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] x Input tensor of shape [N, H]. * \param[in] gamma Gamma tensor of shape [H]. + * \param[in] beta Beta tensor of shape [H]. * \param[in] epsilon Value added to denominator for numerical stability. * \param[in,out] z Output tensor of shape [N, H]. - * \param[out] rsigma Reciprocal of the root mean square of the input - * calculated over the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. + * \param[out] mu Mean of the input calculated over the last dimension. + * Shape: [N]. + * \param[out] rsigma Inverse of the variance of the input calculated over + * the last dimension. Shape: [N]. + * \param[out] workspace Workspace tensor. * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta, + const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); + +/*! \brief Compute backward of LayerNorm. + * + * This function computes the gradient of function: + * @f[ + * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}}\gamma + \beta + * @f] + * else + * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. + * + * Calling this function with workspace set to empty tensor will not perform the operation, + * but instead set the shape and type of these tensors to the required values. + * + * \param[in] dz Incoming gradient tensor of shape [N, H]. + * \param[in] x Forward input tensor of shape [N, H]. + * \param[in] mu Mean of the input calculated over the last dimension. + * Shape: [N]. + * \param[in] rsigma Inverse of the variance of the input calculated over + * the last dimension. Shape: [N]. + * \param[in] gamma Gamma tensor of shape [H]. + * \param[out] dx Output gradient of shape [N, H]. + * \param[out] dgamma Gradient for gamma tensor of shape [H]. + * \param[out] dbeta Gradient for beta tensor of shape [H]. * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. */ -void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z, - NVTETensor rsigma, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); +void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor mu, + const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx, + NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream); -/*! \brief Compute RMSNorm with zero-centered gamma on the input. +/*! \brief Compute RMSNorm. * * The formula used: * @f[ - * y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma) + * y = \frac{x}{RMS_\varepsilon(x)}\gamma * @f] * where * @f[ @@ -68,14 +100,14 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep * \param[in,out] z Output tensor of shape [N, H]. * \param[out] rsigma Reciprocal of the root mean square of the input * calculated over the last dimension. Shape: [N]. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. */ -void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, - NVTETensor z, NVTETensor rsigma, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier); +void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z, + NVTETensor rsigma, NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); /*! \brief Compute backward of RMSNorm. * @@ -100,53 +132,25 @@ void nvte_rmsnorm1p_fwd(const NVTETensor x, const NVTETensor gamma, const float * \param[in] gamma Gamma tensor of shape [H]. * \param[out] dx Output gradient of shape [N, H]. * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. */ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma, - NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); -/*! \brief Compute backward of RMSNorm with zero-centered gamma. +/*! \brief Helper to enable cuDNN backend for normalization * - * This function computes the gradient of function: - * @f[ - * y = \frac{x}{RMS_\varepsilon(x)}(1 + \gamma) - * @f] - * where - * @f[ - * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon} - * @f] - * with respect to \f$x\f$ and \f$gamma\f$. - * - * Calling this function with workspace, barrier, dgamma_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. - * - * \param[in] dz Incoming gradient tensor of shape [N, H]. - * \param[in] x Forward input tensor of shape [N, H]. - * \param[in] rsigma Reciprocal of the root mean square of the input - * calculated over the last dimension. Shape: [N]. - * \param[in] gamma Gamma tensor of shape [H]. - * \param[out] dx Output gradient of shape [N, H]. - * \param[out] dgamma Gradient for gamma tensor of shape [H]. - * \param[out] dgamma_part Storage for partial gamma gradient. - * \param[in] stream CUDA stream used for the operation. - * \param[in] multiprocessorCount Number of SMs in the device. - * \param[out] workspace Workspace tensor. - * \param[out] barrier Barrier tensor. + * \param[in] bool Enable if True */ -void nvte_rmsnorm1p_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, - const NVTETensor gamma, NVTETensor dx, NVTETensor dgamma, - NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier); +void nvte_enable_cudnn_norm_fwd(bool enable); +void nvte_enable_cudnn_norm_bwd(bool enable); #ifdef __cplusplus } // extern "C" #endif -#endif // TRANSFORMER_ENGINE_RMSNORM_H_ +#endif // TRANSFORMER_ENGINE_NORMALIZATION_H_ diff --git a/transformer_engine/common/include/transformer_engine/padding.h b/transformer_engine/common/include/transformer_engine/padding.h index a419b38234..4258463b1b 100644 --- a/transformer_engine/common/include/transformer_engine/padding.h +++ b/transformer_engine/common/include/transformer_engine/padding.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/permutation.h b/transformer_engine/common/include/transformer_engine/permutation.h index c6263bf87e..195075c975 100644 --- a/transformer_engine/common/include/transformer_engine/permutation.h +++ b/transformer_engine/common/include/transformer_engine/permutation.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 61b1f231b8..44614bbe6b 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -28,16 +28,10 @@ extern "C" { * \param[in] amax_history History of maximum absolute values. * Shape: [history_length, num_scales] * \param[in] scale Scaling factor for casting to FP8. Shape: [num_scales] - * \param[in] scale_inv Scaling factor for casting from FP8. Shape: [num_scales] - * \param[in] scale_inv_mask Boolean mask indicating scale_inv entries to update. May be - * empty, in which case all scale_inv entries are updated. - * Shape: [num_scales] * \param[out] updated_amax_history Updated history of maximum absolute values. * Shape: [history_length, num_scales] * \param[out] updated_scale Updated scaling factor for casting to FP8. * Shape: [num_scales] - * \param[out] updated_scale_inv Updated scaling factor for casting from FP8. - * Shape: [num_scales] * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and * "most_recent". * \param[in] fp8_dtype FP8 datatype. @@ -45,9 +39,8 @@ extern "C" { * \param[in] stream CUDA stream. */ void nvte_delayed_scaling_recipe_amax_and_scale_update( - const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale, - NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, + const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history, + NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream); /*! \brief Bulk-update FP8 scaling factors with delayed scaling recipe after amax reduction. @@ -55,7 +48,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( * Operations performed include, updating the most recent amax history * with the relevant segment of global reduction buffer if it's not 0, * rotating the amax history based on the rule below, and updating the - * scales and scale_invs. + * scales. * * The amax history is rotated by -1 (e.g. the first entry shifts to * the last, the last entry shifts to the second to last) and the @@ -69,8 +62,6 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( * Shape: num_tensors x [history_length, num_scales] * \param[in,out] scales List of scaling factors for casting to FP8. * Shape: num_tensors x [num_scales] - * \param[in,out] scale_invs List of scaling factors for casting from FP8. - * Shape: num_tensors x [num_scales] * \param[in] amax_compute_algo Method to reduce amax history. Options are "max" and * "most_recent". * \param[in] fp8_dtype FP8 datatype. @@ -79,8 +70,31 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update( */ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( const NVTETensor amax_reduction_buffer, std::vector amax_histories, - std::vector scales, std::vector scale_invs, - const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream); + std::vector scales, const char* amax_compute_algo, NVTEDType fp8_dtype, + float margin, cudaStream_t stream); + +/*! \brief Compute an FP8 tensor's amax. + * + * The amax (maximum absolute value) of the input tensor is computed + * and written to the amax buffer of the output tensor. + * + * \param[in] input Input tensor. Must be unquantized. + * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! \brief Update an FP8 tensor's scale based on its amax. + * + * This is only supported for FP8 tensors with per-tensor scaling. + * Options are primarily intended for FP8 current-scaling recipes. + * + * \param[in,out] output FP8 tensor with per-tensor scaling. + * \param[in] config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/include/transformer_engine/softmax.h b/transformer_engine/common/include/transformer_engine/softmax.h index 6a6fc15fa6..9f1c423172 100644 --- a/transformer_engine/common/include/transformer_engine/softmax.h +++ b/transformer_engine/common/include/transformer_engine/softmax.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h new file mode 100644 index 0000000000..de5a11eb73 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -0,0 +1,37 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast.h + * \brief Functions to cast to/from FP8. + */ + +#ifndef TRANSFORMER_ENGINE_SWIZZLE_H_ +#define TRANSFORMER_ENGINE_SWIZZLE_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM + * + * \param[in] input Input tensor with non-swizzled scale_inv. + * \param[in,out] output Output tensor which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. + */ +void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_SWIZZLE_H_ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index d302518235..dd1cfb8ddb 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -30,6 +30,7 @@ enum NVTEDType { kNVTEBFloat16 = 5, /*!< 16-bit bfloat (E8M7) */ kNVTEFloat8E4M3 = 6, /*!< 8-bit float (E4M3) */ kNVTEFloat8E5M2 = 7, /*!< 8-bit float (E5M2) */ + kNVTEFloat8E8M0 = 8, /*!< 8-bit float (E8M0) */ kNVTENumTypes /*!< Number of supported types */ }; @@ -43,6 +44,45 @@ struct NVTEShape { size_t ndim; }; +/*! \struct NVTEBasicTensor + * \brief A basic tensor type used to populate parameters of NVTETensor. + * It does not own the memory it points to. + */ +struct NVTEBasicTensor { + void *data_ptr; + NVTEDType dtype; + NVTEShape shape; +}; + +/*! \enum NVTETensorParam + * \brief Indicates the kind of the tensor parameter to set/get. + */ +enum NVTETensorParam { + kNVTERowwiseData = 0, /*!< Data usable in rowwise manner */ + kNVTEColumnwiseData = 1, /*!< Data usable in columnwise manner */ + kNVTEScale = 2, /*!< Scale tensor */ + kNVTEAmax = 3, /*!< Amax tensor */ + kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */ + kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ + kNVTENumTensorParams +}; + +/*! \enum NVTEScalingMode + * \brief Tensor data format. + */ +enum NVTEScalingMode { + /*! Either an unquantized tensor or an FP8 tensor with per-tensor scaling + * + * Not necessary used for delayed tensor scaling. The unintuitive + * name reflects legacy usage. + */ + NVTE_DELAYED_TENSOR_SCALING = 0, + /*! Single scale per block of 32 elements consecutive in either + rowwise or columnwise direction */ + NVTE_MXFP8_1D_SCALING = 1, + NVTE_INVALID_SCALING +}; + /*! \brief TE Tensor type * * NVTETensor is a contiguous tensor type storing a pointer @@ -53,21 +93,15 @@ typedef void *NVTETensor; /*! \brief Create a new TE tensor. * - * Create a new TE tensor with a given shape, datatype and data. + * Create a new TE tensor. Before use its parameters need to be set. * TE tensors are just wrappers on top of raw data and do not * own memory. * - * \param[in] dptr Pointer to the tensor data. - * \param[in] shape Shape of the tensor. - * \param[in] dtype Data type of the tensor. - * \param[in] amax_dptr Pointer to the AMAX value. - * \param[in] scale_dptr Pointer to the scale value. - * \param[in] scale_inv_dptr Pointer to the inverse of scale value. + * \param[in] scaling_mode Scaling mode of the tensor. * * \return A new TE tensor. */ -NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, - float *amax_dptr, float *scale_dptr, float *scale_inv_dptr); +NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode); /*! \brief Destroy a TE tensor. * @@ -78,14 +112,22 @@ NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType */ void nvte_destroy_tensor(NVTETensor tensor); -/*! \brief Get a raw pointer to the tensor's data. +/*! \brief Get a raw pointer to the tensor's rowwise data. * * \param[in] tensor Tensor. * - * \return A raw pointer to tensor's data. + * \return A raw pointer to tensor's rowwise data. */ void *nvte_tensor_data(const NVTETensor tensor); +/*! \brief Get a raw pointer to the tensor's columnwise data. + * + * \param[in] tensor Tensor. + * + * \return A raw pointer to tensor's columnwise data. + */ +void *nvte_tensor_columnwise_data(const NVTETensor tensor); + /*! \brief Get a tensor's data shape. * * \param[in] tensor Tensor. @@ -94,6 +136,14 @@ void *nvte_tensor_data(const NVTETensor tensor); */ NVTEShape nvte_tensor_shape(const NVTETensor tensor); +/*! \brief Get a tensor's data shape. + * + * \param[in] tensor Tensor. + * + * \return A shape of the input tensor. + */ +NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor); + /*! \brief Get a tensor's number of dimensions. * * \param[in] tensor Tensor. @@ -159,6 +209,46 @@ float *nvte_tensor_scale(const NVTETensor tensor); */ float *nvte_tensor_scale_inv(const NVTETensor tensor); +/*! \brief Get a tensor's scale_inv shape. + * + * \param[in] tensor Tensor. + * + * \return A scale_inv shape of the input tensor. + */ +NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor); + +/*! \brief Reset tensor value to zero. + * + * \param[in] tensor Tensor. + * + * \return A scale_inv shape of the input tensor. + */ +void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream); + +/*! \brief Set a parameter of the tensor. + * + * \param[in/out] tensor Tensor. + * \param[in] param_name The parameter to be set. + * \param[in] param The value to be set. + */ +void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, + const NVTEBasicTensor *param); + +/*! \brief Get a value of the parameter of the tensor. + * + * \param[in] tensor Tensor. + * \param[in] param_name The parameter to be set. + */ +NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name); + +/*! \brief Get the granularity of scaling of this tensor. + * + * \param[in] tensor Tensor. + * + * \return A struct containing the granularity of tensor's scaling. + */ +NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor); + /*! \struct NVTETensorPack \brief Pack of tensors, generally used for auxiliary outputs. */ @@ -179,6 +269,57 @@ void nvte_tensor_pack_create(NVTETensorPack *pack); */ void nvte_tensor_pack_destroy(NVTETensorPack *pack); +/*! \brief Configuration for tensor quantization. */ +typedef void *NVTEQuantizationConfig; + +/*! \enum NVTEQuantizationConfigAttribute + * \brief Type of option for tensor quantization. + */ +enum NVTEQuantizationConfigAttribute { + /*! Whether to force power of 2 scales */ + kNVTEQuantizationConfigForcePow2Scales = 0, + /*! Small value to add to amax for numerical stability */ + kNVTEQuantizationConfigAmaxEpsilon = 1, + kNVTEQuantizationConfigNumAttributes +}; + +/*! \brief Create a new quantization config. + * \return A new quantization config. + */ +NVTEQuantizationConfig nvte_create_quantization_config(); + +/*! \brief Query an option in quantization config. + * + * \param[in] config Quantization config. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. Ignored if + * NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. + */ +void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written); + +/*! \brief Set an option in quantization config. + * + * \param[in] config Quantization config. + * \param[in] attr Option type. + * \param[out] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, const void *buf, + size_t size_in_bytes); + +/*! \brief Destroy a quantization config. + * + * \param[in] config Config to be destroyed. + */ +void nvte_destroy_quantization_config(NVTEQuantizationConfig config); + #ifdef __cplusplus } // extern "C" @@ -201,6 +342,7 @@ enum class DType { kBFloat16 = 5, kFloat8E4M3 = 6, kFloat8E5M2 = 7, + kFloat8E8M0 = 8, kNumTypes }; @@ -220,12 +362,23 @@ class TensorWrapper { * \param[in] dtype Data type of the tensor. * \param[in] amax_dptr Pointer to the AMAX value. * \param[in] scale_dptr Pointer to the scale value. + * \param[in] scale_inv_shape Shape of scale_inv * \param[in] scale_inv_dptr Pointer to the inverse of scale value. */ TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr, - float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr) - : tensor_(nvte_create_tensor(dptr, shape, static_cast(dtype), amax_dptr, - scale_dptr, scale_inv_dptr)) {} + float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr, + const NVTEShape scale_inv_shape = defaultShape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) { + tensor_ = nvte_create_tensor(scaling_mode); + NVTEBasicTensor data = {dptr, static_cast(dtype), shape}; + nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data); + NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, defaultShape}; + nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax); + NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, defaultShape}; + nvte_set_tensor_param(&tensor_, kNVTEScale, &scale); + NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape}; + nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv); + } /*! \brief Constructs new TensorWrapper. * @@ -238,19 +391,23 @@ class TensorWrapper { * \param[in] dtype Data type of the tensor. * \param[in] amax_dptr Pointer to the AMAX value. * \param[in] scale_dptr Pointer to the scale value. + * \param[in] scale_inv_shape Shape of scale_inv * \param[in] scale_inv_dptr Pointer to the inverse of scale value. */ TensorWrapper(void *dptr, const std::vector &shape, const DType dtype, float *amax_dptr = nullptr, float *scale_dptr = nullptr, - float *scale_inv_dptr = nullptr) + float *scale_inv_dptr = nullptr, const std::vector &scale_inv_shape = {1}, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) : TensorWrapper(dptr, NVTEShape{shape.data(), shape.size()}, dtype, amax_dptr, scale_dptr, - scale_inv_dptr) {} + scale_inv_dptr, NVTEShape{scale_inv_shape.data(), scale_inv_shape.size()}, + scaling_mode) {} /*! \brief Constructs new empty TensorWrapper. * * Create a new empty TE tensor which holds nothing. */ - TensorWrapper() : TensorWrapper(nullptr, std::vector(), DType::kFloat32) {} + explicit TensorWrapper(const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_tensor(scaling_mode)) {} /*! \brief TensorWrapper destructor. */ ~TensorWrapper() { nvte_destroy_tensor(tensor_); } @@ -283,6 +440,70 @@ class TensorWrapper { return *this; } + // Parameter setters + template + TensorWrapper &set_parameter(const NVTETensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_tensor_param(&tensor_, param, &data); + return *this; + } + + template + TensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTERowwiseData, dptr, type, shape); + } + + template + TensorWrapper &set_columnwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseData, dptr, type, shape); + } + + template + TensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEScale, dptr, type, shape); + } + + template + TensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEAmax, dptr, type, shape); + } + + template + TensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTERowwiseScaleInv, dptr, type, shape); + } + + template + TensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape); + } + + // Parameter getters + + NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { + return nvte_get_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { return get_parameter(kNVTERowwiseData); } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTERowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEColumnwiseScaleInv); + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -298,6 +519,15 @@ class TensorWrapper { return nvte_tensor_shape(tensor_); } + /*! \brief Get the shape of this TensorWrapper. + * + * \return Shape of this TensorWrapper. + */ + const NVTEShape columnwise_shape() const noexcept { + if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + return nvte_tensor_columnwise_shape(tensor_); + } + /*! \brief Get the size of this TensorWrapper in the given dimension. * * \param[in] size_t Dimension index. @@ -325,7 +555,7 @@ class TensorWrapper { * \return Number of elements in the tensor. */ size_t numel() const noexcept { - if (tensor_ == nullptr || this->dptr() == nullptr) return 0; + if (tensor_ == nullptr) return 0; return nvte_tensor_numel(tensor_); } @@ -366,6 +596,15 @@ class TensorWrapper { return nvte_tensor_data(tensor_); } + /*! \brief Get a raw pointer to the tensor's data. + * + * \return A raw pointer to tensor's data. + */ + void *columnwise_dptr() const noexcept { + if (tensor_ == nullptr) return nullptr; + return nvte_tensor_columnwise_data(tensor_); + } + /*! \brief Get a pointer to the tensor's amax data. * * \return A pointer to tensor's amax data. @@ -393,11 +632,90 @@ class TensorWrapper { return nvte_tensor_scale_inv(tensor_); } + /*! \brief Get the scale_inv_shape of this TensorWrapper. + * + * \return scale_inv_shape of this TensorWrapper. + */ + const NVTEShape scale_inv_shape() const noexcept { + if (tensor_ == nullptr) return NVTEShape{nullptr, 0}; + return nvte_tensor_scale_inv_shape(tensor_); + } + + /*! \brief Get a scaling mode of the tensor. + * + * \return Scaling mode of the tensor. + */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_tensor_scaling_mode(tensor_); + } + + void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = {&defaultData, 1}; + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { return {s.data(), s.size()}; } + /*! \brief Wrapped NVTETensor. */ NVTETensor tensor_ = nullptr; }; +/*! \struct QuantizationConfigWrapper + * \brief C++ wrapper for NVTEQuantizationConfigWrapper. + */ +class QuantizationConfigWrapper { + public: + QuantizationConfigWrapper() : config_{nvte_create_quantization_config()} {} + + QuantizationConfigWrapper(const QuantizationConfigWrapper &) = delete; + QuantizationConfigWrapper &operator=(const QuantizationConfigWrapper &) = delete; + + QuantizationConfigWrapper(QuantizationConfigWrapper &&other) : config_{other.config_} { + other.config_ = nullptr; + } + QuantizationConfigWrapper &operator=(QuantizationConfigWrapper &&other) { + if (config_ != nullptr) { + nvte_destroy_quantization_config(config_); + } + config_ = other.config_; + other.config_ = nullptr; + return *this; + } + + ~QuantizationConfigWrapper() { + if (config_ != nullptr) { + nvte_destroy_quantization_config(config_); + config_ = nullptr; + } + } + + /*! \brief Get the underlying NVTEQuantizationConfig. + * + * \return NVTEQuantizationConfig held by this QuantizationConfigWrapper. + */ + operator NVTEQuantizationConfig() const noexcept { return config_; } + + /*! \brief Set whether to force power of 2 scales */ + void set_force_pow_2_scales(bool force_pow_2_scales) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigForcePow2Scales, + &force_pow_2_scales, sizeof(bool)); + } + + /*! \brief Set small value to add to amax */ + void set_amax_epsilon(float amax_epsilon) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigAmaxEpsilon, + &amax_epsilon, sizeof(float)); + } + + private: + /*! \brief Wrapped NVTEQuantizationConfig. */ + NVTEQuantizationConfig config_ = nullptr; +}; + } // namespace transformer_engine #endif // __cplusplus diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index ef3d344b05..a7db5cba47 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -20,16 +20,16 @@ extern "C" { /*! \brief Cast and transpose the input. * * This function casts the input and produces 2 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. + * - rowwise data in `output` is the result of the cast + * - columnwise data in `output` is the transposed result of the cast. * - * \param[in] input Input tensor of shape [N, H]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. - * \param[in] stream CUDA stream used for the operation. + * \param[in] input Input tensor of shape [N, H]. + * \param[in,out] output Result of the cast and transpose. + * Shape of the rowwise data: [N, H]. + * Shape of the columnwise data: [H, N] + * \param[in] stream CUDA stream used for the operation. */ -void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, cudaStream_t stream); +void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Transpose the input. * @@ -41,25 +41,24 @@ void nvte_transpose(const NVTETensor input, NVTETensor transposed_output, cudaSt /*! \brief Cast and transpose the input. Additionally, reduce the input along the first dimension. * - * This function casts the input and produces 3 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. + * This function casts the input and produces 2 results: + * - `output` is the result of the cast (rowwise data) and transposed cast (columnwise data) * - `dbias` is the result of the reduction of the input along the first dimension. * * Calling this function with workspace being an empty tensor will not perform the operation, * but instead set the shape and type of the workspace tensor to the required values. * - * \param[in] input Input tensor of shape [N, H]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. - * \param[out] dbias Result of the reduction of the input along the - * first dimension. Shape: [H]. - * \param[out] workspace Workspace tensor. - * \param[in] stream CUDA stream used for the operation. + * \param[in] input Input tensor of shape [N, H]. + * \param[in,out] output Result of the cast and transpose. + * Shape of the rowwise data: [N, H]. + * Shape of the columnwise data: [H, N] + * \param[out] dbias Result of the reduction of the input along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. */ -void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream); +void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream); /*! \brief Transpose the FP8 input. Additionally, reduce the input along the first dimension. * @@ -82,102 +81,242 @@ void nvte_fp8_transpose_dbias(const NVTETensor input, NVTETensor transposed_outp /*! \brief Cast and transpose multiple tensors. * - * This function casts each input tensor and produces 2 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. - * - * \param[in] num_tensors Number of tensors. - * \param[in] input_list List of 2D input tensors. - * \param[in,out] cast_output_list List of casted tensors. Dimensions - * match tensors in input_list. - * \param[in,out] transposed_output_list List of casted and transposed - * tensors. Dimensions are transpose - * of tensors in input_list. - * \param[in] stream CUDA stream used for the operation. + * \param[in] num_tensors Number of tensors. + * \param[in] input_list List of 2D input tensors. + * \param[in,out] output_list List of casted tensors. Dimensions + * of their rowwise data members match + * tensors in input_list. Dimensions of + * their columnwise data members are + * transposed. + * \param[in] stream CUDA stream used for the operation. */ void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, - NVTETensor* cast_output_list, NVTETensor* transposed_output_list, - cudaStream_t stream); + NVTETensor* output_list, cudaStream_t stream); -/*! \brief Compute backward of ActLU operation on the input, then cast and transpose. Additionally, - * reduce the result of the SiLU backward along the first dimension. +/*! \brief Compute backward of GeLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the GeLU backward along the first dimension. * - * This function produces 3 results: - * - `cast_output` is equal to `cast(dact(input))` - * - `transposed_output` is equal to `transpose(cast(dact(input)))` + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` * - `dbias` is equal to `reduce(dact(input), axis=0)` * * Calling this function with workspace being an empty tensor will not perform the operation, * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] input Input tensor of shape [N, H]. - * \param[in] act_input Tensor used as input to the forward of SiLU operation. + * \param[in] act_input Tensor used as input for the operation of forward activation. * Shape [N, H]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H, N]. - * \param[out] dbias Result of the reduction of the dSiLU(input) along the + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the * first dimension. Shape: [H]. * \param[out] workspace Workspace tensor. * \param[in] stream CUDA stream used for the operation. - - Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ - void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of SiLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the SiLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of ReLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the ReLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of the Quick GeLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the Quick GeLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); +/*! \brief Compute backward of the Squared ReLU operation on the input, then cast and transpose. + * Additionally, reduce the result of the Squared ReLU backward along the first dimension. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * - `dbias` is equal to `reduce(dact(input), axis=0)` + * + * Calling this function with workspace being an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] act_input Tensor used as input for the operation of forward activation. + * Shape [N, H]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H]. + * Shape of columnwise data: [H, N]. + * \param[out] dbias Result of the reduction of the dact(input) along the + * first dimension. Shape: [H]. + * \param[out] workspace Workspace tensor. + * \param[in] stream CUDA stream used for the operation. + */ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream); -/*! \brief Compute dgeglu of the input, additionally does cast and transpose the dgeglu output. +/*! \brief Computes the gated GeLU activation of the input, additionally casts and transposes + * the output. * * This function produces 2 results: - * - `cast_output` is the result of the cast - * - `transposed_output` is the transposed result of the cast. + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` * * \param[in] input Input tensor of shape [N, H]. - * \param[in] gated_act_input Tensor used as input to the forward of GeGLU operation. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. * Shape [N, H * 2]. - * \param[in,out] cast_output Result of the cast. Shape: [N, H * 2]. - * \param[in,out] transposed_output Result of the cast and transpose. Shape: [H * 2, N]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. * \param[in] stream CUDA stream used for the operation. - - Supported activations: GeLU, SiLU, ReLU, QuickGeLU, SquaredReLU */ - void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated ReLU activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Quick GeLU activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Squared ReLU activation of the input, + * additionally casts and transposes the output. + * + * This function produces 2 results: + * - rowwise data of `output` is equal to `cast(dact(input))` + * - columnwise data of `output` is equal to `transpose(cast(dact(input)))` + * + * \param[in] input Input tensor of shape [N, H]. + * \param[in] gated_act_input Tensor used as input to the forward of + * gated activation operation. + * Shape [N, H * 2]. + * \param[in,out] output Result of the cast. + * Shape of rowwise data: [N, H * 2]. + * Shape of columnwise data: [H * 2, N]. + * \param[in] stream CUDA stream used for the operation. +*/ void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream); + NVTETensor output, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/layer_norm/ln.h b/transformer_engine/common/layer_norm/ln.h deleted file mode 100644 index 13543a10aa..0000000000 --- a/transformer_engine/common/layer_norm/ln.h +++ /dev/null @@ -1,239 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ -#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ - -#include - -#include -#include -#include -#include -#include - -#include "../common.h" - -namespace transformer_engine { -namespace layer_norm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams { - size_t workspace_bytes; - size_t barrier_size; - - int multiprocessorCount; - cudaStream_t stream; - - Params params; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct ParamsBase { - ParamsBase() - : ctas_per_col(0), - rows(0), - cols(0), - x(nullptr), - mu(nullptr), - rs(nullptr), - gamma(nullptr), - workspace(nullptr), - barrier(nullptr), - zero_centered_gamma(false) {} - - // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. - int ctas_per_col; - // Size of CTA group. - int ctas_per_row; - - // Input is interpreted as matrix. We normalize across columns. - int rows; - int cols; - - // Common data pointers. - void *x; - void *mu; - void *rs; - void *gamma; - - // Multi-CTA workspace in gmem. - void *workspace; - - // Multi-CTA sync barriers in gmem. - int *barrier; - - // Whether gamma is centered around 0 - bool zero_centered_gamma; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FwdParams : public ParamsBase { - FwdParams() : ParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {} - - // Output of LN FWD. - void *z; - void *beta; - float epsilon; - - // Scaling factor - void *scale; - - // AMax output - void *amax; - - // Inverse of scaling factor - void *scale_inv; - - // Whether to compute scale and amax - bool fp8_out; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct BwdParams : public ParamsBase { - BwdParams() - : ParamsBase(), - dz(nullptr), - dbeta_part(nullptr), - dgamma_part(nullptr), - dx(nullptr), - dbeta(nullptr), - dgamma(nullptr) {} - - // Input: gradient wrt. LN FWD output. - void *dz; - - // Workspace for Wgrad pre-reduction. - void *dbeta_part; - void *dgamma_part; - - // Output: Dgrad. - void *dx; - // Output: Wgrad. - void *dbeta; - void *dgamma; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function &, const bool)>; -using BwdFunction = std::function &, const bool)>; -using FunctionKey = uint64_t; -using FwdTunedRegistry = std::unordered_map; -using BwdTunedRegistry = std::unordered_map; -using FwdGeneralRegistry = std::unordered_map>; -using BwdGeneralRegistry = std::unordered_map>; - -extern FwdTunedRegistry FWD_TUNED_FUNCS; -extern BwdTunedRegistry BWD_TUNED_FUNCS; -extern FwdGeneralRegistry FWD_GENERAL_FUNCS; -extern BwdGeneralRegistry BWD_GENERAL_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TypeId {}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 0; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 1; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 2; -}; - -template <> -struct TypeId { - constexpr static uint32_t Value = 3; -}; - -template -struct Type2Key { - constexpr static uint32_t Value = TypeId::Value << S; -}; - -template -struct WeightType2Key : public Type2Key {}; - -template -struct InputType2Key : public Type2Key {}; - -template -struct OutputType2Key : public Type2Key {}; - -template -struct ComputeType2Key : public Type2Key {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Types2Key { - constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | - OutputType2Key::Value | ComputeType2Key::Value; - constexpr static inline uint64_t get(const uint64_t hidden_size) { - constexpr uint64_t type_key = Value; - return (type_key << 32) | hidden_size; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdTunedRegistrar { - explicit FwdTunedRegistrar(FwdFunction f) { - uint64_t key = Types2Key::get(HIDDEN_SIZE); - FWD_TUNED_FUNCS.insert({key, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdGeneralRegistrar { - explicit FwdGeneralRegistrar(FwdFunction f) { - uint64_t key = Types2Key::get(0); - FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdTunedRegistrar { - explicit BwdTunedRegistrar(BwdFunction f) { - uint64_t key = Types2Key::get(HIDDEN_SIZE); - BWD_TUNED_FUNCS.insert({key, f}); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdGeneralRegistrar { - explicit BwdGeneralRegistrar(BwdFunction f) { - uint64_t key = Types2Key::get(0); - BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace layer_norm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_H_ diff --git a/transformer_engine/common/layer_norm/ln_api.cpp b/transformer_engine/common/layer_norm/ln_api.cpp deleted file mode 100644 index 8a40450e59..0000000000 --- a/transformer_engine/common/layer_norm/ln_api.cpp +++ /dev/null @@ -1,457 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include - -#include -#include - -#include "../common.h" -#include "ln.h" - -/* - -Supported Type combinations: - -input compute weights output -======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp16 fp16 -fp32 fp32 bf16 bf16 -bf16 fp32 bf16 fp8 - -Remarks: -Output type = Weight type -Compute always in FP32 - -*/ - -namespace transformer_engine { -namespace layer_norm { - -using namespace transformer_engine; - -// Create registries and provide runtime versions of config hash functions. - -FwdTunedRegistry FWD_TUNED_FUNCS; -BwdTunedRegistry BWD_TUNED_FUNCS; -FwdGeneralRegistry FWD_GENERAL_FUNCS; -BwdGeneralRegistry BWD_GENERAL_FUNCS; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint32_t get_type_id(DType dtype) { - if (dtype == DType::kFloat16) { - return TypeId::Value; - } else if (dtype == DType::kBFloat16) { - return TypeId::Value; - } else if (dtype == DType::kFloat32) { - return TypeId::Value; - } else if (dtype == DType::kFloat8E4M3) { - return TypeId::Value; - } else { - NVTE_ERROR("Type not supported."); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size) { - using namespace layer_norm; - uint64_t type_key = get_type_id(wtype) | (get_type_id(itype) << 2) | (get_type_id(otype) << 4) | - (get_type_id(ctype) << 6); - uint64_t launcher_key = (type_key << 32) | hidden_size; - return launcher_key; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::FwdFunction& get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::FwdParams& params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void* ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) && - is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.beta) && - is_aligned(params.z) && layer_norm::FWD_TUNED_FUNCS.count(tuned_key) > 0) { - return layer_norm::FWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (layer_norm::FWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("FWD: Unsupported types."); - } - auto& general_func_map = layer_norm::FWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -layer_norm::BwdFunction& get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::BwdParams& params) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void* ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.mu) && - is_aligned(params.rs) && is_aligned(params.gamma) && is_aligned(params.dz) && - is_aligned(params.dx) && is_aligned(params.dbeta) && is_aligned(params.dgamma) && - is_aligned(params.dbeta_part) && is_aligned(params.dgamma_part) && - layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { - return layer_norm::BWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (layer_norm::BWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("BWD: Unsupported types."); - } - auto& general_func_map = layer_norm::BWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -size_t product(const std::vector& shape) { - size_t ret = 1; - for (auto s : shape) { - ret *= s; - } - return ret; -} - -} // namespace layer_norm - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void layernorm_fwd(const Tensor& x, // BxSxhidden_size - const Tensor& gamma, // hidden_size - const Tensor& beta, // hidden_size - const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, cudaStream_t stream, - const int multiprocessorCount, Tensor* workspace, Tensor* barrier, - const bool zero_centered_gamma) { - const auto itype = x.data.dtype; - const auto wtype = gamma.data.dtype; - const auto otype = z->data.dtype; - const bool fp8_out = is_fp8_dtype(otype); - const auto ctype = layer_norm::DType::kFloat32; - - NVTE_CHECK(x.data.shape.size() == 2); - - const size_t rows = x.data.shape[0]; - const size_t cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(gamma.data.shape == beta.data.shape); - NVTE_CHECK(hidden_size == cols); - - NVTE_CHECK(epsilon >= 0.f); - - NVTE_CHECK(z->data.shape == x.data.shape); - - NVTE_CHECK(mu->data.shape == std::vector{rows}); - NVTE_CHECK(mu->data.dtype == ctype); - - NVTE_CHECK(rsigma->data.shape == std::vector{rows}); - NVTE_CHECK(rsigma->data.dtype == ctype); - - layer_norm::LaunchParams launch_params; - - launch_params.multiprocessorCount = multiprocessorCount; - launch_params.stream = stream; - - // Set the kernel runtime parameters. - layer_norm::FwdParams& params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = mu->data.dptr; - params.rs = rsigma->data.dptr; - params.gamma = gamma.data.dptr; - params.beta = beta.data.dptr; - params.z = z->data.dptr; - params.epsilon = epsilon; - params.amax = z->amax.dptr; - params.scale = z->scale.dptr; - params.scale_inv = z->scale_inv.dptr; - params.fp8_out = fp8_out; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = layer_norm::get_fwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } - - if (workspace->data.dptr == nullptr) { - NVTE_CHECK(barrier->data.dptr == nullptr); - - workspace->data.dtype = layer_norm::DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = layer_norm::DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - CheckInputTensor(beta, "beta"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*mu, "mu"); - CheckOutputTensor(*rsigma, "rsigma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - } - - // Clear buffers - if (params.fp8_out) { - cudaMemsetAsync(params.amax, 0, layer_norm::product(z->amax.shape) * typeToSize(z->amax.dtype), - stream); - } - if (launch_params.barrier_size > 0) { - cudaMemsetAsync(params.barrier, 0, - layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. - launcher(launch_params, false); - - return; -} - -void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma, - const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta, - Tensor* dgamma_part, Tensor* dbeta_part, cudaStream_t stream, - const int multiprocessorCount, Tensor* workspace, Tensor* barrier, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = wtype; - auto ctype = DType::kFloat32; - - NVTE_CHECK(dz.data.dtype == otype); - NVTE_CHECK(mu.data.dtype == ctype); - NVTE_CHECK(rsigma.data.dtype == ctype); - - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(dz.data.shape == x.data.shape); - auto rows = x.data.shape[0]; - auto cols = x.data.shape[1]; - - auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(mu.data.shape[0] == rows); - NVTE_CHECK(mu.data.shape == rsigma.data.shape); - - NVTE_CHECK(gamma.data.shape[0] == cols); - - NVTE_CHECK(dx->data.shape == x.data.shape); - NVTE_CHECK(dx->data.dtype == x.data.dtype); - - NVTE_CHECK(dgamma->data.shape == gamma.data.shape); - NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); - - NVTE_CHECK(dbeta->data.shape == gamma.data.shape); - NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); - - layer_norm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.multiprocessorCount = multiprocessorCount; - - // Set the kernel runtime parameters. - layer_norm::BwdParams& params = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = mu.data.dptr; - params.rs = rsigma.data.dptr; - params.gamma = gamma.data.dptr; - params.dz = dz.data.dptr; - params.dx = dx->data.dptr; - params.dbeta = dbeta->data.dptr; - params.dgamma = dgamma->data.dptr; - params.dbeta_part = dbeta_part->data.dptr; - params.dgamma_part = dgamma_part->data.dptr; - params.zero_centered_gamma = zero_centered_gamma; - - auto launcher = layer_norm::get_bwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - - // Populate shape and dtypes for FW to allocate memory - if (dgamma_part->data.dptr == nullptr) { - NVTE_CHECK(dbeta_part->data.dptr == nullptr); - - dgamma_part->data.dtype = ctype; - dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - dbeta_part->data.dtype = ctype; - dbeta_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - workspace->data.dtype = layer_norm::DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = layer_norm::DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(dbeta_part->data.dptr != nullptr); - auto pdw_shape = - std::vector{static_cast(launch_params.params.ctas_per_col), hidden_size}; - - NVTE_CHECK(dgamma_part->data.dtype == ctype); - NVTE_CHECK(dgamma_part->data.shape == pdw_shape); - NVTE_CHECK(dbeta_part->data.dtype == ctype); - NVTE_CHECK(dbeta_part->data.shape == pdw_shape); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == layer_norm::DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - if (launch_params.workspace_bytes > 0) { - NVTE_CHECK(workspace->data.dptr != nullptr); - NVTE_CHECK(workspace->data.dtype == layer_norm::DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(mu, "mu"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - CheckOutputTensor(*dbeta, "dbeta"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - cudaMemsetAsync(params.barrier, 0, - layer_norm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. - launcher(launch_params, false); -} -} // namespace transformer_engine - -void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size - const NVTETensor gamma, // hidden_size - const NVTETensor beta, // hidden_size - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm_fwd); - using namespace transformer_engine; - layernorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - *reinterpret_cast(beta), epsilon, reinterpret_cast(z), - reinterpret_cast(mu), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), false); -} - -void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor dgamma_part, - NVTETensor dbeta_part, cudaStream_t stream, const int multiprocessorCount, - NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm_bwd); - using namespace transformer_engine; - layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(mu), *reinterpret_cast(rsigma), - *reinterpret_cast(gamma), reinterpret_cast(dx), - reinterpret_cast(dgamma), reinterpret_cast(dbeta), - reinterpret_cast(dgamma_part), reinterpret_cast(dbeta_part), - stream, multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), false); -} - -void nvte_layernorm1p_fwd(const NVTETensor x, // BxSxhidden_size - const NVTETensor gamma, // hidden_size - const NVTETensor beta, // hidden_size - const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm1p_fwd); - using namespace transformer_engine; - layernorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - *reinterpret_cast(beta), epsilon, reinterpret_cast(z), - reinterpret_cast(mu), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), true); -} - -void nvte_layernorm1p_bwd(const NVTETensor dz, // BxSxhidden_size - const NVTETensor x, // BxSxhidden_size - const NVTETensor mu, // BxS, FP32! - const NVTETensor rsigma, // BxS, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, - NVTETensor dgamma_part, NVTETensor dbeta_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_layernorm1p_bwd); - using namespace transformer_engine; - layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(mu), *reinterpret_cast(rsigma), - *reinterpret_cast(gamma), reinterpret_cast(dx), - reinterpret_cast(dgamma), reinterpret_cast(dbeta), - reinterpret_cast(dgamma_part), reinterpret_cast(dbeta_part), - stream, multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), true); -} diff --git a/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu deleted file mode 100644 index 17f1256910..0000000000 --- a/transformer_engine/common/layer_norm/ln_bwd_semi_cuda_kernel.cu +++ /dev/null @@ -1,345 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "ln.h" -#include "ln_bwd_kernels.cuh" -#include "ln_kernel_traits.h" - -using namespace transformer_engine::layer_norm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_bwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::reduce_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, - stream); - } - - using Kernel_traits_f = layer_norm::Kernel_traits_finalize; - - auto kernel_f = &layer_norm::ln_bwd_finalize_tuned_kernel; - kernel_f<<>>( - launch_params.params); -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Instantiate kernel - using Kernel_traits = Kernel_traits; - auto kernel = &ln_bwd_general_kernel; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); - } - - // Launch finalization kernel - constexpr uint32_t WARPS_M_FINAL = 4; - constexpr uint32_t WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = - (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); - auto kernel_final = - &ln_bwd_finalize_general_kernel; - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - kernel_final<<>>(launch_params.params); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, \ - configure_params); \ - } \ - static BwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static BwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create tuned launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... -// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4); -REGISTER_BWD_TUNED_LAUNCHER(30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); - -REGISTER_BWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); - -// Create general launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... -// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu deleted file mode 100644 index 0c85f4aeb7..0000000000 --- a/transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu +++ /dev/null @@ -1,413 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "ln.h" -#include "ln_fwd_kernels.cuh" -#include "ln_kernel_traits.h" - -using namespace transformer_engine::layer_norm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_fwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::Stats::stats_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); - } -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &ln_fwd_general_kernel; - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, configure_params); \ - } \ - static FwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG) \ - void ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static FwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - ln_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create tuned launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, bf16, bf16, fp8e4m3, fp32, 8, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp16, fp16, fp8e4m3, fp32, 8, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1536, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2304, fp32, fp32, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(3072, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(5120, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(10240, fp32, fp32, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8); - -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(16384, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(18432, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(20480, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, fp16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(24576, fp32, fp32, bf16, fp32, 2, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); -REGISTER_FWD_TUNED_LAUNCHER(25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); -REGISTER_FWD_TUNED_LAUNCHER(30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4); - -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(32768, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(40960, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, fp16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(49152, fp32, fp32, bf16, fp32, 4, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, fp16, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(65536, fp32, fp32, bf16, fp32, 8, 1, 4, 16); - -// Create general launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, bf16, fp32, 4, 1, 8); - -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index 0683ec01ea..546f7f3403 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -1,4 +1,20 @@ { - global: *nvte*; *transformer_engine*; + global: + extern "C++" { + nvte_*; + transformer_engine::cuda::sm_count*; + transformer_engine::cuda::sm_arch*; + transformer_engine::cuda::supports_multicast*; + transformer_engine::cuda::stream_priority_range*; + transformer_engine::cuda::current_device*; + transformer_engine::cuda_driver::get_symbol*; + transformer_engine::ubuf_built_with_mpi*; + *transformer_engine::rtc*; + transformer_engine::nvte_cudnn_handle_init*; + transformer_engine::typeToSize*; + *transformer_engine::CommOverlapBase*; + *transformer_engine::CommOverlapP2PBase*; + *transformer_engine::CommOverlapCore* + }; local: *; -}; +}; \ No newline at end of file diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp new file mode 100644 index 0000000000..ddda78d951 --- /dev/null +++ b/transformer_engine/common/normalization/common.cpp @@ -0,0 +1,517 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/* #include */ + +#include "common.h" + +#include +#include +#include +#include +#include + +#include "transformer_engine/normalization.h" +#include "transformer_engine/transformer_engine.h" + +/* + +Supported Type combinations: + +input compute weights output +======================================= +fp32 fp32 fp32 fp32 +fp16 fp32 fp16 fp16 +bf16 fp32 bf16 bf16 +fp32 fp32 fp16 fp16 +fp32 fp32 bf16 bf16 +bf16 fp32 bf16 fp8 + +Remarks: +Output type = Weight type +Compute always in FP32 + +*/ + +namespace transformer_engine { +namespace normalization { + +cudnn_frontend::NormFwdPhase_t get_cudnn_forward_phase(const bool training) { + return training ? cudnn_frontend::NormFwdPhase_t::TRAINING + : cudnn_frontend::NormFwdPhase_t::INFERENCE; +} + +TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, + NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, + uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, + bool is_tuned, NVTEScalingMode mode, bool training) { + // TODO: Add scaling_mode to general_key is needed + uint64_t general_key = static_cast(itype) | (static_cast(otype) << 3) | + (static_cast(ctype) << 6) | (static_cast(wtype) << 9) | + (uint32_t(NormType) << 12) | (uint32_t(NormStage)) << 14 | + (uint32_t(NormBackend) << 16) | (uint32_t(zero_centered_gamma) << 18) | + (uint32_t(mode) << 19) | (uint32_t(training) << 22); + return std::make_tuple(general_key, batch_size, hidden_size, is_tuned); +} + +template +TeNormalizationPlan::TeNormalizationPlan( + NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, + DType ctype, const size_t batch_size, const size_t hidden_size, const size_t sm_count, + const bool zero_centered_gamma, const bool is_tuned) + : _is_layernorm(NormType == NVTE_Norm_Type::LayerNorm) { + _launch_params.multiprocessorCount = sm_count; + + auto& kernel_params = _launch_params.params; + kernel_params.rows = batch_size; + kernel_params.cols = hidden_size; + kernel_params.zero_centered_gamma = zero_centered_gamma; + if constexpr (std::is_same_v) { + kernel_params.fp8_out = is_fp8_dtype(otype); + } + // TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those + auto key = get_key(NVTE_Norm_Backend::Te, NormType, NormStage, wtype, itype, otype, ctype, 0, + hidden_size, false, is_tuned); + _kernel = KernelRegistry::getKernel(key); + + this->_build(); +} + +template <> +void TeNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, + void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, + void* workspace_dptr, cudaStream_t stream) { + _launch_params.stream = stream; + + auto& kernel_params = _launch_params.params; + kernel_params.workspace = workspace_dptr; + kernel_params.x = x_dptr; + kernel_params.rs = rsigma_dptr; + kernel_params.gamma = gamma_dptr; + kernel_params.z = z->data.dptr; + kernel_params.epsilon = *reinterpret_cast(eps_dptr); + kernel_params.amax = z->amax.dptr; + kernel_params.scale = z->scale.dptr; + kernel_params.scale_inv = z->scale_inv.dptr; + + if (_is_layernorm) { + kernel_params.mu = mean_dptr; + kernel_params.beta = beta_dptr; + } + + _set_workspace(); + _kernel(_launch_params, false); +} + +template <> +void TeNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, + void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, + void* workspace_dptr, cudaStream_t stream) { + NVTE_ERROR("Backward normalization should not call the forward execute function!"); +} + +template +void TeNormalizationPlan::_build() { + _kernel(_launch_params, true); + _launch_params.alignWorkspace(); +} + +template +std::vector TeNormalizationPlan::getWorkspaceShape() const { + return {_launch_params.getTotalWorkspaceBytes(_is_layernorm)}; +} + +template +void TeNormalizationPlan::_set_workspace() { + if (_launch_params.getTotalWorkspaceBytes() > 0) { + auto workspace_dptr = reinterpret_cast(_launch_params.params.workspace); + + if (_launch_params.barrier_bytes > 0) { + _launch_params.params.barrier = + reinterpret_cast(workspace_dptr + _launch_params.workspace_bytes); + cudaMemsetAsync(_launch_params.params.barrier, 0, _launch_params.barrier_bytes, + _launch_params.stream); + } + if constexpr (std::is_same_v) { + _launch_params.params.dgamma_part = + workspace_dptr + _launch_params.workspace_bytes + _launch_params.barrier_bytes; + if (_is_layernorm) { + _launch_params.params.dbeta_part = + reinterpret_cast(_launch_params.params.dgamma_part) + + _launch_params.dgamma_part_bytes; + } + } + } +} + +template <> +void TeNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, + void* mean_dptr, void* rsigma_dptr, + void* dx_dptr, void* dz_dptr, + void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) { + NVTE_ERROR("Forward normalization should not call the backward execute function!"); +} + +template <> +void TeNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, + void* mean_dptr, void* rsigma_dptr, + void* dx_dptr, void* dz_dptr, + void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) { + _launch_params.stream = stream; + + auto& kernel_params = _launch_params.params; + kernel_params.workspace = workspace_dptr; + kernel_params.x = x_dptr; + kernel_params.gamma = gamma_dptr; + kernel_params.rs = rsigma_dptr; + kernel_params.dx = dx_dptr; + kernel_params.dz = dz_dptr; + kernel_params.dgamma = dgamma_dptr; + + if (_is_layernorm) { + kernel_params.mu = mean_dptr; + kernel_params.dbeta = dbeta_dptr; + } + + _set_workspace(); + _kernel(_launch_params, false); +} + +CudnnNormalizationPlan::CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, + DType wtype, DType itype, DType otype, DType ctype, + const size_t batch_size, const size_t hidden_size, + const size_t sm_count, + const bool zero_centered_gamma, + const NVTEScalingMode mode, bool training) + : _fp8_out(is_fp8_dtype(otype)), + _zero_centered(zero_centered_gamma), + _training(training), + _norm_stage(NormStage), + _norm_type(NormType) { + static_assert(CUDNN_FRONTEND_VERSION >= 10601, + "CUDNN_FRONTEND_VERSION should be at least 1.6.1!"); + + namespace fe = cudnn_frontend; + + if (is_tensor_scaling(mode)) { + _ndim_scale_block = 0; + } else { + NVTE_CHECK(mode == NVTE_MXFP8_1D_SCALING, "Unsupported scaling mode."); + _ndim_scale_block = 1; + } + + _scalar_dptr = std::make_unique(typeToSize(wtype)); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + wtype, cpp_dtype, *(reinterpret_cast(_scalar_dptr.get())) = (cpp_dtype)1.0f;); + + _handle = cudnnExecutionPlanManager::Instance().GetHandle(); + + _graph.set_io_data_type(get_cudnn_fe_dtype(itype)) + .set_intermediate_data_type(get_cudnn_fe_dtype(ctype)) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + + if (cudnnGetVersion() >= 90400) _graph.set_sm_count(sm_count); + + const auto batch_dim = static_cast(batch_size); + const auto hidden_dim = static_cast(hidden_size); + + // Create graph tensors + _x = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("X") + .set_dim({batch_dim, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(itype))); + + _gamma_zero = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("gamma_zero") + .set_dim({1, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(wtype))); + if (_zero_centered) { + _scalar_offset = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("one") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(wtype)) + .set_is_pass_by_value(true)); + auto centered_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::ADD) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _gamma = _graph.pointwise(_gamma_zero, _scalar_offset, centered_options); + _gamma->set_output(false).set_data_type(get_cudnn_fe_dtype(wtype)); + } else { + _gamma = _gamma_zero; + } + + // Create graph computation nodes + if (_norm_stage == NVTE_Norm_Stage::Forward) { + _eps = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("epsilon") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype)) + .set_is_pass_by_value(true)); + if (_norm_type == NVTE_Norm_Type::LayerNorm) { + _beta = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("bias") + .set_dim({1, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim}) + .set_data_type(get_cudnn_fe_dtype(wtype))); + auto norm_options = fe::graph::Layernorm_attributes() + .set_forward_phase(get_cudnn_forward_phase(_training)) + .set_epsilon(_eps) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto ret = _graph.layernorm(_x, _gamma, _beta, norm_options); + std::tie(_z, _mean, _rsigma) = std::make_tuple(ret[0], ret[1], ret[2]); + if (_training) _mean->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + } else { + auto norm_options = fe::graph::Rmsnorm_attributes() + .set_forward_phase(get_cudnn_forward_phase(_training)) + .set_epsilon(_eps) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto ret = _graph.rmsnorm(_x, _gamma, norm_options); + std::tie(_z, _rsigma) = std::make_tuple(ret[0], ret[1]); + } + + if (_training) _rsigma->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + + const auto ZDtype = _fp8_out ? ctype : otype; + _z->set_output(!_fp8_out).set_data_type(get_cudnn_fe_dtype(ZDtype)); + + if (_fp8_out) { + if (_ndim_scale_block == 0) { // tensor_scaling + // create a scale node + _z_scale = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("z_scale") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype))); + auto z_scale_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::MUL) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _z_fp8 = _graph.pointwise(_z, _z_scale, z_scale_options); + + _z_fp8->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + + // create an amax reduction node + _amax = _graph.reduction(_z, fe::graph::Reduction_attributes() + .set_mode(fe::ReductionMode_t::AMAX) + .set_compute_data_type(get_cudnn_fe_dtype(ctype))); + _amax->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)).set_dim({1, 1, 1, 1}); + _one_for_div = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("one_for_div") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype)) + .set_is_pass_by_value(true)); + auto div_options = fe::graph::Pointwise_attributes() + .set_mode(fe::PointwiseMode_t::DIV) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + _z_scale_inv = _graph.pointwise(_one_for_div, _z_scale, div_options); + _z_scale_inv->set_output(true).set_data_type(get_cudnn_fe_dtype(ctype)); + } else if (_ndim_scale_block == 1) { // 1d block scaling + auto z_2d = _graph.reshape(_z, fe::graph::Reshape_attributes()); + z_2d->set_dim({batch_dim, hidden_dim}); + + auto mx_quantize_row_opts = fe::graph::Block_scale_quantize_attributes() + .set_block_size(32) + .set_axis(1) + .set_transpose(false); + auto bs_row_ret = _graph.block_scale_quantize(z_2d, mx_quantize_row_opts); + std::tie(_z_mx_row, _sf_row) = std::make_tuple(bs_row_ret[0], bs_row_ret[1]); + _z_mx_row->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + _sf_row->set_output(true).set_data_type(fe::DataType_t::FP8_E8M0); //TODO + + if (_training) { + auto mx_quantize_col_opts = fe::graph::Block_scale_quantize_attributes() + .set_block_size(32) + .set_axis(0) + .set_transpose(false); + auto bs_col_ret = _graph.block_scale_quantize(z_2d, mx_quantize_col_opts); + std::tie(_z_mx_col, _sf_col) = std::make_tuple(bs_col_ret[0], bs_col_ret[1]); + _z_mx_col->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + _sf_col->set_output(true).set_data_type(fe::DataType_t::FP8_E8M0); + } + } else { + NVTE_ERROR("Unsupported scaling mode."); + } + } + } else { + _dz = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("dz") + .set_dim({batch_dim, hidden_dim, 1, 1}) + .set_stride({hidden_dim, 1, hidden_dim, hidden_dim})); + _rsigma = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("inv_var") + .set_dim({batch_dim, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype))); + _mean = _graph.tensor(fe::graph::Tensor_attributes() + .set_name("mean") + .set_dim({batch_dim, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ctype))); + if (_norm_type == NVTE_Norm_Type::LayerNorm) { + auto norm_options = fe::graph::Layernorm_backward_attributes() + .set_saved_mean_and_inv_variance(_mean, _rsigma) + .set_compute_data_type(get_cudnn_fe_dtype(ctype)); + auto ret = _graph.layernorm_backward(_dz, _x, _gamma, norm_options); + std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]); + _dbeta->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + } else { + auto norm_options = + fe::graph::Rmsnorm_backward_attributes().has_dbias(false).set_compute_data_type( + get_cudnn_fe_dtype(ctype)); + auto ret = _graph.rmsnorm_backward(_dz, _x, _gamma, _rsigma, norm_options); + std::tie(_dx, _dgamma, _dbeta) = std::make_tuple(ret[0], ret[1], ret[2]); + if (_dbeta != nullptr) NVTE_ERROR("cuDNN rmsnorm dbias incorrectly returned."); + } + _dx->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + _dgamma->set_output(true).set_data_type(get_cudnn_fe_dtype(otype)); + } + // Build the graph + this->_build(); +} + +void CudnnNormalizationPlan::_build() { + NVTE_CHECK(_graph.validate().is_good()); + NVTE_CHECK(_graph.build_operation_graph(_handle).is_good()); + NVTE_CHECK(_graph + .create_execution_plans( + {cudnn_frontend::HeurMode_t::A, cudnn_frontend::HeurMode_t::FALLBACK}) + .is_good()); + NVTE_CHECK(_graph.check_support(_handle).is_good()); + NVTE_CHECK( + _graph.build_plans(_handle, cudnn_frontend::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); +} + +std::vector CudnnNormalizationPlan::getWorkspaceShape() const { + return {static_cast(_graph.get_workspace_size())}; +} + +void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, + void* mean_dptr, void* eps_dptr, void* rsigma_dptr, + void* workspace_dptr, cudaStream_t stream) { + // Binding data pointers to graph tensors + _variant_pack = {{_x, x_dptr}, {_eps, eps_dptr}}; + + if (_training) _variant_pack.insert({{_rsigma, rsigma_dptr}}); + + if (_norm_type == NVTE_Norm_Type::LayerNorm) { + _variant_pack.insert({{_beta, beta_dptr}}); + if (_training) _variant_pack.insert({{_mean, mean_dptr}}); + } + + if (_zero_centered) + _variant_pack.insert( + {{_scalar_offset, reinterpret_cast(_scalar_dptr.get())}, {_gamma_zero, gamma_dptr}}); + else + _variant_pack.insert({{_gamma, gamma_dptr}}); + + if (_fp8_out && _ndim_scale_block == 0) { + _variant_pack.insert({{_one_for_div, reinterpret_cast(_one_dptr.get())}, + {_z_scale, z->scale.dptr}, + {_z_scale_inv, z->scale_inv.dptr}, + {_amax, z->amax.dptr}, + {_z_fp8, z->data.dptr}}); + } else if (_fp8_out && _ndim_scale_block == 1) { + _variant_pack.insert({{_z_mx_row, z->data.dptr}, {_sf_row, z->scale_inv.dptr}}); + if (_training) + _variant_pack.insert( + {{_z_mx_col, z->columnwise_data.dptr}, {_sf_col, z->columnwise_scale_inv.dptr}}); + } else { + _variant_pack.insert({{_z, z->data.dptr}}); + } + + // Execute the computation + NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream)); + NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good()); +} + +void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, + void* rsigma_dptr, void* dx_dptr, void* dz_dptr, + void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) { + // Binding data pointers to graph tensors + _variant_pack = { + {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}}; + + if (_zero_centered) + _variant_pack.insert({{_scalar_offset, reinterpret_cast(this->_scalar_dptr.get())}, + {_gamma_zero, gamma_dptr}}); + else + _variant_pack.insert({{_gamma, gamma_dptr}}); + + // layernorm should have valid mean_dptr and beta_dptr + if (mean_dptr && dbeta_dptr) _variant_pack.insert({{_mean, mean_dptr}, {_dbeta, dbeta_dptr}}); + + // Execute the computation + NVTE_CHECK_CUDNN(cudnnSetStream(_handle, stream)); + NVTE_CHECK(_graph.execute(_handle, _variant_pack, workspace_dptr).is_good()); +} + +NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan( + NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, + DType itype, DType otype, const size_t batch_size, const size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, + const NVTEScalingMode mode, const bool training) { + const DType ctype = DType::kFloat32; + bool is_tuned = is_aligned && (batch_size % 4 == 0); + auto key = get_key(NormBackend, NormType, NormStage, wtype, itype, otype, ctype, batch_size, + hidden_size, zero_centered_gamma, is_tuned, mode, training); + + auto it = normalizationPlanMap.find(key); + if (it != normalizationPlanMap.end()) { + return it->second.get(); + } + + std::unique_ptr plan; + if (NormBackend == NVTE_Norm_Backend::Cudnn) { + plan = std::make_unique(NormType, NormStage, wtype, itype, otype, ctype, + batch_size, hidden_size, sm_count, + zero_centered_gamma, mode, training); + } else if (NormStage == NVTE_Norm_Stage::Forward) { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned); + } else { + plan = std::make_unique>( + NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count, + zero_centered_gamma, is_tuned); + } + normalizationPlanMap.insert({key, std::move(plan)}); + return normalizationPlanMap[key].get(); +} + +bool& _cudnn_norm_fwd_flag() { + static bool flag = transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN"); + return flag; +} + +bool& _cudnn_norm_bwd_flag() { + static bool flag = transformer_engine::getenv("NVTE_NORM_BWD_USE_CUDNN"); + return flag; +} + +bool use_cudnn_norm_fwd() { return _cudnn_norm_fwd_flag(); } +bool use_cudnn_norm_bwd() { return _cudnn_norm_bwd_flag(); } + +} // namespace normalization +} // namespace transformer_engine + +void nvte_enable_cudnn_norm_fwd(bool enable) { + NVTE_API_CALL(nvte_enable_cudnn_norm_fwd); + transformer_engine::normalization::_cudnn_norm_fwd_flag() = enable; +} + +void nvte_enable_cudnn_norm_bwd(bool enable) { + NVTE_API_CALL(nvte_enable_cudnn_norm_bwd); + transformer_engine::normalization::_cudnn_norm_bwd_flag() = enable; +} diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h new file mode 100644 index 0000000000..ea0450f1c2 --- /dev/null +++ b/transformer_engine/common/normalization/common.h @@ -0,0 +1,387 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_ +#define TRANSFORMER_ENGINE_COMMON_NORM_COMMON_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../cudnn_utils.h" +#include "../util/system.h" + +namespace transformer_engine { + +namespace normalization { + +namespace fe = cudnn_frontend; + +template +struct LaunchParams { + size_t workspace_bytes = 0; + size_t barrier_bytes = 0; + size_t dgamma_part_bytes = 0; + int multiprocessorCount; + cudaStream_t stream; + + KernelParamsType params; + + size_t getTotalWorkspaceBytes(const bool _is_layernorm = true) const { + return (workspace_bytes + barrier_bytes + size_t(_is_layernorm + 1) * dgamma_part_bytes); + } + void alignWorkspace(size_t alignment = 16) { + workspace_bytes = DIVUP(workspace_bytes, alignment) * alignment; + barrier_bytes = DIVUP(barrier_bytes, alignment) * alignment; + dgamma_part_bytes = DIVUP(dgamma_part_bytes, alignment) * alignment; + } +}; + +struct KernelParamsBase { + KernelParamsBase() + : ctas_per_col(0), + rows(0), + cols(0), + x(nullptr), + mu(nullptr), + rs(nullptr), + gamma(nullptr), + workspace(nullptr), + barrier(nullptr), + zero_centered_gamma(false) {} + + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. + int ctas_per_col; + // Size of CTA group. + int ctas_per_row; + + // Input is interpreted as matrix. We normalize across columns. + int rows; + int cols; + + // Common data pointers. + void* x; + void* mu; + void* rs; + void* gamma; + + // Multi-CTA workspace in gmem. + void* workspace; + + // Multi-CTA sync barriers in gmem. + int* barrier; + + // Whether gamma is centered around 0 + bool zero_centered_gamma; +}; + +struct ForwardKernelParams : public KernelParamsBase { + ForwardKernelParams() + : KernelParamsBase(), z(nullptr), beta(nullptr), epsilon(0.f), fp8_out(false) {} + + // Output of LN FWD. + void* z; + void* beta; + float epsilon; + + // Scaling factor + void* scale; + int scale_byte_size; + + // Inverse of scaling factor + void* scale_inv; + + // AMax output + void* amax; + int amax_byte_size; + + // Whether to compute scale and amax + bool fp8_out; +}; + +struct BackwardKernelParams : public KernelParamsBase { + BackwardKernelParams() + : KernelParamsBase(), + dz(nullptr), + dbeta_part(nullptr), + dgamma_part(nullptr), + dx(nullptr), + dbeta(nullptr), + dgamma(nullptr) {} + + // Input: gradient wrt. LN FWD output. + void* dz; + + // Workspace for Wgrad pre-reduction. + void* dbeta_part; + void* dgamma_part; + + // Output: Dgrad. + void* dx; + // Output: Wgrad. + void* dbeta; + void* dgamma; +}; + +enum class NVTE_Norm_Backend { Te, Cudnn }; +enum class NVTE_Norm_Type { LayerNorm, RMSNorm }; +enum class NVTE_Norm_Stage { Forward, Backward }; + +using TupleKeyType = std::tuple; +struct TupleHash { + size_t operator()(const TupleKeyType& t) const { + // Generate a hash for a tuple by combining the hashes of its entries + // See: https://www.boost.org/doc/libs/1_55_0/doc/html/hash/reference.html#boost.hash_combine + size_t seed = 0; + std::hash hasher; + seed ^= hasher(std::get<0>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + seed ^= hasher(std::get<1>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + seed ^= hasher(std::get<2>(t)) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + return seed; + } +}; + +// Note: the default mode here should match with the default mode with QTensor +TupleKeyType get_key(NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, + NVTE_Norm_Stage NormStage, DType wtype, DType itype, DType otype, DType ctype, + uint64_t batch_size, uint64_t hidden_size, bool zero_centered_gamma, + bool is_tuned, NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, + bool training = true); + +template +class TeNormalizationRegistry { + private: + using Function = std::function&, const bool)>; + std::unordered_map tuned_function_map; + std::unordered_map> general_function_map; + + TeNormalizationRegistry() = default; + + static TeNormalizationRegistry& getInstance() { + static TeNormalizationRegistry registry; + return registry; + } + + public: + static int registerFunction(TupleKeyType key, + void (*func)(LaunchParams&, const bool)) { + auto [general_key, batch_size, hidden_size, is_tuned] = key; + if (is_tuned) + getInstance().tuned_function_map.emplace(key, Function(func)); + else + getInstance().general_function_map[general_key].emplace(hidden_size, Function(func)); + return 0; + } + + static Function getKernel(TupleKeyType key) { + auto& instance = getInstance(); + auto [general_key, batch_size, hidden_size, is_tuned] = key; + if (is_tuned) { + auto it = instance.tuned_function_map.find(key); + if (it != instance.tuned_function_map.end()) return it->second; + } + if (instance.general_function_map.count(general_key) == 0) { + NVTE_ERROR("Unavailable kernel for this normalization config."); + } + auto& general_func_map = instance.general_function_map.at(general_key); + auto func_iter = general_func_map.lower_bound(hidden_size); + if (func_iter == general_func_map.end()) { + return general_func_map.rbegin()->second; // Hidden size is too big, need to use multi-CTA + } else { + return func_iter->second; + } + } + + TeNormalizationRegistry(const TeNormalizationRegistry&) = delete; + TeNormalizationRegistry& operator=(const TeNormalizationRegistry&) = delete; + TeNormalizationRegistry(TeNormalizationRegistry&&) = delete; + TeNormalizationRegistry& operator=(TeNormalizationRegistry&&) = delete; +}; + +class NormalizationPlanBase { + public: + virtual ~NormalizationPlanBase() = default; + virtual std::vector getWorkspaceShape() const = 0; + + virtual void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, + cudaStream_t stream) = 0; + + virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, + void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) = 0; + + private: + virtual void _build() = 0; +}; + +template +class TeNormalizationPlan : public NormalizationPlanBase { + public: + TeNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, DType itype, + DType otype, DType ctype, const size_t batch_size, const size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, const bool is_tuned); + std::vector getWorkspaceShape() const override; + + void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, + void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + private: + void _set_workspace(); + void _build(); + + using KernelRegistry = TeNormalizationRegistry; + LaunchParams _launch_params; + std::function&, const bool)> _kernel; + + const bool _is_layernorm; +}; + +class CudnnNormalizationPlan : public NormalizationPlanBase { + public: + CudnnNormalizationPlan(NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, DType wtype, + DType itype, DType otype, DType ctype, const size_t batch_size, + const size_t hidden_size, const size_t sm_count, + const bool zero_centered_gamma, const NVTEScalingMode mode, + const bool training); + + std::vector getWorkspaceShape() const override; + + void execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr, void* mean_dptr, + void* eps_dptr, void* rsigma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, + void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) override; + + private: + void _build() override; + + const bool _zero_centered, _fp8_out; + int _ndim_scale_block; + const NVTE_Norm_Stage _norm_stage; + const NVTE_Norm_Type _norm_type; + std::unique_ptr _scalar_dptr; + std::unique_ptr _one_dptr = std::make_unique(1.0f); + // FWD + std::shared_ptr _x, _gamma_zero, _scalar_offset, _gamma, _beta, + _eps, _mean, _rsigma, _z, _z_scale, _one_for_div, _z_scale_inv, _amax, _z_fp8; + // MX FWD + std::shared_ptr _z_mx_row, _z_mx_col, _sf_row, _sf_col; + const bool _training; + // BWD + std::shared_ptr _dz, _dx, _dgamma, _dbeta; + + fe::graph::Graph _graph; + std::unordered_map, void*> _variant_pack; + cudnnHandle_t _handle; +}; + +class NormalizationPlanRegistry { + public: + static NormalizationPlanRegistry& getInstance() { + static thread_local NormalizationPlanRegistry instance; + return instance; + } + + NormalizationPlanBase* getNormalizationPlan( + NVTE_Norm_Backend NormBackend, NVTE_Norm_Type NormType, NVTE_Norm_Stage NormStage, + DType wtype, DType itype, DType otype, const size_t batch_size, const size_t hidden_size, + const size_t sm_count, const bool zero_centered_gamma, const bool is_aligned, + const NVTEScalingMode mode = NVTE_DELAYED_TENSOR_SCALING, const bool training = true); + + private: + NormalizationPlanRegistry() {} + NormalizationPlanRegistry(const NormalizationPlanRegistry&) = delete; + NormalizationPlanRegistry& operator=(const NormalizationPlanRegistry&) = delete; + + std::unordered_map, TupleHash> + normalizationPlanMap; +}; + +using byte = uint8_t; +using int32 = int32_t; +using fp32 = float; +using fp16 = half; +using bf16 = nv_bfloat16; +using fp8e4m3 = __nv_fp8_e4m3; +using fp8e5m2 = __nv_fp8_e5m2; + +template +struct TypeToDType; + +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat32; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat16; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kBFloat16; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat8E4M3; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kFloat8E5M2; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kInt32; +}; +template <> +struct TypeToDType { + static constexpr DType value = DType::kByte; +}; + +#define IS_TUNED(x) (strcmp(#x, "tuned") == 0 ? 1 : 0) + +// TE kernels have no template for batch_size and zero_centered_gamma, thus zero out those +#define REGISTER_NORM_BASE(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, \ + CTYPE, FUNC_NAME) \ + static int \ + register_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE = \ + TeNormalizationRegistry::registerFunction( \ + (get_key(NVTE_Norm_Backend::Te, NVTE_Norm_Type::NORM_TYPE, \ + NVTE_Norm_Stage::NORM_STAGE, (TypeToDType::value), \ + (TypeToDType::value), (TypeToDType::value), \ + (TypeToDType::value), 0, HIDDEN_SIZE, 0, IS_TUNED(LAUNCH_TYPE))), \ + FUNC_NAME) + +// Alignment check +template +bool is_ptr_aligned(const Args*... ptrs) { + return ((reinterpret_cast(ptrs) % Alignment == 0) && ...); +} + +bool use_cudnn_norm_fwd(); +bool use_cudnn_norm_bwd(); + +} // namespace normalization +} // namespace transformer_engine + +#endif diff --git a/transformer_engine/common/layer_norm/ln_kernel_traits.h b/transformer_engine/common/normalization/kernel_traits.h similarity index 88% rename from transformer_engine/common/layer_norm/ln_kernel_traits.h rename to transformer_engine/common/normalization/kernel_traits.h index a72726c325..78d9212de6 100644 --- a/transformer_engine/common/layer_norm/ln_kernel_traits.h +++ b/transformer_engine/common/normalization/kernel_traits.h @@ -1,19 +1,18 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ -#ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_ -#define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_KERNEL_TRAITS_H_ +#ifndef TRANSFORMER_ENGINE_COMMON_NORM_KERNEL_TRAITS_H_ +#define TRANSFORMER_ENGINE_COMMON_NORM_KERNEL_TRAITS_H_ #include "../common.h" #include "../utils.cuh" -//////////////////////////////////////////////////////////////////////////////////////////////////// - namespace transformer_engine { -namespace layer_norm { +namespace normalization { + template struct Kernel_traits_base { @@ -28,8 +27,6 @@ struct Kernel_traits_base { enum { THREADS_PER_WARP = 32 }; }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - template +#include + +#include +#include +#include +#include +#include + +#include "../../common.h" +#include "../common.h" + +namespace transformer_engine { + +using namespace normalization; + +void layernorm_fwd(const Tensor& x, // BxSxhidden_size + const Tensor& gamma, // hidden_size + const Tensor& beta, // hidden_size + const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && + !is_block_scaling(z->scaling_mode)) { + NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); + } + + NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(gamma.data.shape == beta.data.shape); + NVTE_CHECK(x.data.shape[1] == gamma.data.shape[0]); + + NVTE_CHECK(epsilon >= 0.f); + + NVTE_CHECK(z->data.shape == x.data.shape); + + NVTE_CHECK(mu->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(mu->data.dtype == DType::kFloat32); + + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); + CheckInputTensor(beta, "beta"); + + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*mu, "mu"); + CheckOutputTensor(*rsigma, "rsigma"); + } + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + + if (cudnn_backend) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr, + mu->data.dptr, rsigma->data.dptr); + } + + bool training = + is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; + + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Forward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + z->data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } + + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + NVTE_CHECK( + !is_block_scaling(z->scaling_mode) || (!training || z->columnwise_scale_inv.dptr != nullptr), + "Columnwise scale_inv must be allocated for NormFwdTraining!"); + plan->execute(z, x.data.dptr, gamma.data.dptr, beta.data.dptr, mu->data.dptr, + reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, + workspace->data.dptr, stream); + + // Compute FP8 transpose if required + if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) { + Tensor transpose_data; + transpose_data.data = z->columnwise_data; + transpose_data.scaling_mode = z->scaling_mode; + nvte_transpose(reinterpret_cast(z), reinterpret_cast(&transpose_data), + stream); + } + + return; +} + +void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Tensor& rsigma, + const Tensor& gamma, Tensor* dx, Tensor* dgamma, Tensor* dbeta, + Tensor* workspace, const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + using namespace transformer_engine; + NVTE_CHECK(dz.data.dtype == gamma.data.dtype); + NVTE_CHECK(mu.data.dtype == DType::kFloat32); + NVTE_CHECK(rsigma.data.dtype == mu.data.dtype); + + NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(dz.data.shape == x.data.shape); + + NVTE_CHECK(mu.data.shape[0] == x.data.shape[0]); + NVTE_CHECK(mu.data.shape == rsigma.data.shape); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + + NVTE_CHECK(dx->data.shape == x.data.shape); + NVTE_CHECK(dx->data.dtype == x.data.dtype); + + NVTE_CHECK(dgamma->data.shape == gamma.data.shape); + NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); + + NVTE_CHECK(dbeta->data.shape == gamma.data.shape); + NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(mu, "mu"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + CheckOutputTensor(*dbeta, "dbeta"); + } + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + if (use_cudnn_norm_bwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, + dx->data.dptr, dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr); + } + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::LayerNorm, NVTE_Norm_Stage::Backward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + gamma.data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr, workspace->data.dptr, stream); + } + return; +} +} // namespace transformer_engine + +void nvte_layernorm_fwd(const NVTETensor x, // BxSxhidden_size + const NVTETensor gamma, // hidden_size + const NVTETensor beta, // hidden_size + const float epsilon, NVTETensor z, NVTETensor mu, NVTETensor rsigma, + NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream) { + NVTE_API_CALL(nvte_layernorm_fwd); + using namespace transformer_engine; + layernorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), + *reinterpret_cast(beta), epsilon, reinterpret_cast(z), + reinterpret_cast(mu), reinterpret_cast(rsigma), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} + +void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size + const NVTETensor x, // BxSxhidden_size + const NVTETensor mu, // BxS, FP32! + const NVTETensor rsigma, // BxS, FP32! + const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_layernorm_bwd); + using namespace transformer_engine; + layernorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), + *reinterpret_cast(mu), *reinterpret_cast(rsigma), + *reinterpret_cast(gamma), reinterpret_cast(dx), + reinterpret_cast(dgamma), reinterpret_cast(dbeta), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} diff --git a/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh similarity index 97% rename from transformer_engine/common/layer_norm/ln_bwd_kernels.cuh rename to transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh index dbd0025244..b68e79cd98 100644 --- a/transformer_engine/common/layer_norm/ln_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -7,16 +7,15 @@ #ifndef TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ #define TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ -#include "../utils.cuh" -#include "ln.h" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace layer_norm { -using namespace transformer_engine; +namespace normalization { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( - layer_norm::BwdParams params) { + BackwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_N = Ktraits::WARPS_N }; @@ -119,8 +118,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( } reduce_t result = reducer.allreduce({mdy_local, mdyy_local}, sum); - mdy_local = layer_norm::Get<0>::of(result) * rn; - mdyy_local = layer_norm::Get<1>::of(result) * rn; + mdy_local = Get<0>::of(result) * rn; + mdyy_local = Get<1>::of(result) * rn; Ivec dx[LDGS]; idx = row * Ktraits::VEC_COLS + c; @@ -203,7 +202,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_tuned_kernel( template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finalize_tuned_kernel( - BwdParams params) { + BackwardKernelParams params) { using compute_t = typename Kernel_traits::compute_t; using weight_t = typename Kernel_traits::weight_t; using index_t = typename Kernel_traits::index_t; @@ -323,7 +322,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void ln_bwd_finaliz template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kernel( - layer_norm::BwdParams params) { + BackwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -424,8 +423,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_bwd_general_kerne // Reduce over row reduce_t result = reducer.allreduce({mdy, mdyy}, sum); - mdy = layer_norm::Get<0>::of(result) * rn; - mdyy = layer_norm::Get<1>::of(result) * rn; + mdy = Get<0>::of(result) * rn; + mdyy = Get<1>::of(result) * rn; // Compute dx #pragma unroll @@ -507,7 +506,7 @@ template __global__ __launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_general_kernel( - layer_norm::BwdParams params) { + BackwardKernelParams params) { enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; using Wvec = Vec; using Cvec = Vec; @@ -573,7 +572,7 @@ __launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void ln_bwd_finalize_gener } } -} // namespace layer_norm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_BWD_KERNELS_CUH_ diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu new file mode 100644 index 0000000000..f63edfb644 --- /dev/null +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -0,0 +1,331 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../../common.h" +#include "../common.h" +#include "../kernel_traits.h" +#include "ln_bwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_bwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::reduce_t) * 2; + } + launch_params.dgamma_part_bytes = + launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); + return; + } + + if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, + stream); + } + + using Kernel_traits_f = + Kernel_traits_finalize; + + auto kernel_f = &ln_bwd_finalize_tuned_kernel; + kernel_f<<>>( + launch_params.params); +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Instantiate kernel + using Kernel_traits = Kernel_traits; + auto kernel = &ln_bwd_general_kernel; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, + Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); + } + launch_params.dgamma_part_bytes = ctas_per_col * cols * sizeof(compute_t); + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } + + // Launch finalization kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); + auto kernel_final = + &ln_bwd_finalize_general_kernel; + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + kernel_final<<>>(launch_params.params); +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...) \ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... +// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 768, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, fp16, fp32, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1024, bf16, fp32, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 1536, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2048, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 2304, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3072, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, fp16, fp32, fp16, fp32, 1, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 3840, bf16, fp32, bf16, fp32, 1, 1, 4, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 4096, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, fp16, fp32, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 5120, bf16, fp32, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp32, fp32, fp32, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, fp16, fp32, fp16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 6144, bf16, fp32, bf16, fp32, 1, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 8192, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, fp16, fp16, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, fp16, fp32, fp16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, bf16, bf16, fp32, 2, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 10240, bf16, fp32, bf16, fp32, 2, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12288, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, fp16, fp16, fp16, fp32, 5, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, bf16, bf16, bf16, fp32, 5, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 12800, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, fp32, fp32, fp32, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, fp16, fp16, fp16, fp32, 4, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, fp16, fp32, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, bf16, bf16, bf16, fp32, 4, 1, 4, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 15360, bf16, fp32, bf16, fp32, 4, 1, 4, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 16384, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, fp16, fp16, fp16, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, bf16, bf16, bf16, fp32, 4, 1, 4, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 18432, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, fp32, fp32, fp32, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, fp16, fp16, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, fp16, fp32, fp16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, bf16, bf16, bf16, fp32, 4, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 20480, bf16, fp32, bf16, fp32, 4, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 24576, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, fp32, fp32, fp32, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, fp16, fp16, fp16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, fp16, fp32, fp16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, bf16, bf16, bf16, fp32, 5, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 25600, bf16, fp32, bf16, fp32, 5, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, fp32, fp32, fp32, fp32, 4, 1, 8, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, fp16, fp16, fp16, fp32, 4, 1, 8, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, fp16, fp32, fp16, fp32, 4, 1, 8, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, bf16, bf16, bf16, fp32, 4, 1, 8, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 30720, bf16, fp32, bf16, fp32, 4, 1, 8, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 32768, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, fp32, fp32, fp32, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, fp16, fp16, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, fp16, fp32, fp16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, bf16, bf16, bf16, fp32, 4, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 40960, bf16, fp32, bf16, fp32, 4, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 49152, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, fp16, fp32, fp16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 8, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, tuned, 65536, bf16, fp32, bf16, fp32, 8, 1, 8, 16, 4); + +// Create general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... +// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu new file mode 100644 index 0000000000..9336abc26c --- /dev/null +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -0,0 +1,395 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../common.h" +#include "../kernel_traits.h" +#include "ln_fwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_fwd_tuned_kernel; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::Stats::stats_t) * 2; + } + return; + } + + if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) + Kernel_traits::SMEM_BYTES_FWD, stream); + } +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &ln_fwd_general_kernel; + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); + } + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...) \ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, bf16, bf16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, fp8e4m3, fp32, 8, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp16, fp16, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp16, fp16, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp8e4m3, fp32, 8, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, fp8e4m3, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp8e4m3, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp8e4m3, fp32, 8, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 768, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1024, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 1536, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2048, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 2304, fp32, fp32, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3072, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp32, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp16, fp16, fp16, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, fp16, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, bf16, bf16, bf16, fp32, 1, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 3840, fp32, fp32, bf16, fp32, 1, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 4096, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 5120, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 6144, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 8192, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, bf16, bf16, bf16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 10240, fp32, fp32, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12288, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, fp32, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp16, fp16, fp16, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, fp16, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, bf16, bf16, bf16, fp32, 2, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 12800, fp32, fp32, bf16, fp32, 2, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, fp32, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, fp16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 15360, fp32, fp32, bf16, fp32, 2, 1, 4, 8); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 16384, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 18432, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, fp32, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 20480, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp16, fp16, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, fp16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, bf16, bf16, bf16, fp32, 2, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 24576, fp32, fp32, bf16, fp32, 2, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp16, fp16, fp16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, fp16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, bf16, bf16, bf16, fp32, 2, 1, 4, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 25600, fp32, fp32, bf16, fp32, 4, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, fp32, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp16, fp16, fp16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, fp16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, bf16, bf16, bf16, fp32, 4, 1, 4, 4); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 30720, fp32, fp32, bf16, fp32, 4, 1, 4, 4); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 32768, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp32, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 40960, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp32, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp16, fp16, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, fp16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, bf16, bf16, bf16, fp32, 4, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 49152, fp32, fp32, bf16, fp32, 4, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp32, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp16, fp16, fp16, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, fp16, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, bf16, bf16, bf16, fp32, 8, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, tuned, 65536, fp32, fp32, bf16, fp32, 8, 1, 4, 16); + +// Create general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, fp16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 128, fp32, fp32, bf16, fp32, 4, 1, 8); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 512, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 1024, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 2048, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp32, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp16, fp16, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, bf16, bf16, bf16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(LayerNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh similarity index 96% rename from transformer_engine/common/layer_norm/ln_fwd_kernels.cuh rename to transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index bd3741d1d1..eb2f62b4b0 100644 --- a/transformer_engine/common/layer_norm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -10,15 +10,16 @@ #include #include -#include "../utils.cuh" -#include "ln.h" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace layer_norm { +namespace normalization { using namespace transformer_engine; template -__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(FwdParams params) { +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( + ForwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -92,8 +93,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( stats_t s = stats.compute(xf, rn); - compute_t mu = layer_norm::Get<0>::of(s); - compute_t m2 = layer_norm::Get<1>::of(s); + compute_t mu = Get<0>::of(s); + compute_t m2 = Get<1>::of(s); if (bidn == 0 && warp_n == 0 && lane == 0) { mu_ptr[row] = mu; @@ -150,7 +151,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kernel( - FwdParams params) { + ForwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -315,7 +316,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne } } -} // namespace layer_norm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_LAYER_NORM_LN_FWD_KERNELS_CUH_ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp new file mode 100644 index 0000000000..8519fe1b64 --- /dev/null +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -0,0 +1,186 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include "../../common.h" +#include "../common.h" +#include "transformer_engine/normalization.h" +#include "transformer_engine/transpose.h" + +namespace transformer_engine { + +using namespace normalization; + +void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, + Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream) { + if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && + !is_block_scaling(z->scaling_mode)) { + NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); + } + + NVTE_CHECK(x.data.shape.size() == 2); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + NVTE_CHECK(epsilon >= 0.f); + + NVTE_CHECK(z->data.shape == x.data.shape); + + NVTE_CHECK(rsigma->data.shape == std::vector{x.data.shape[0]}); + NVTE_CHECK(rsigma->data.dtype == DType::kFloat32); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(x, "x"); + CheckInputTensor(gamma, "gamma"); + + CheckOutputTensor(*z, "z"); + CheckOutputTensor(*rsigma, "rsigma"); + } + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + bool cudnn_backend = use_cudnn_norm_fwd() || is_block_scaling(z->scaling_mode); + + bool training = + is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; + + if (cudnn_backend) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr); + } + + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Forward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + z->data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } + + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + NVTE_CHECK( + !is_block_scaling(z->scaling_mode) || (!training || z->columnwise_scale_inv.dptr != nullptr), + "Columnwise scale_inv must be allocated for NormFwdTraining!"); + plan->execute(z, x.data.dptr, gamma.data.dptr, nullptr /*beta*/, nullptr /*mu*/, + reinterpret_cast(const_cast(&epsilon)), rsigma->data.dptr, + workspace->data.dptr, stream); + + // Compute FP8 transpose if required + if (z->has_columnwise_data() && is_tensor_scaling(z->scaling_mode)) { + Tensor transpose_data; + transpose_data.data = z->columnwise_data; + transpose_data.scaling_mode = z->scaling_mode; + nvte_transpose(reinterpret_cast(z), reinterpret_cast(&transpose_data), + stream); + } + + return; +} + +void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma, + Tensor *dx, Tensor *dgamma, Tensor *workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream) { + using namespace transformer_engine; + + NVTE_CHECK(dz.data.dtype == gamma.data.dtype); + NVTE_CHECK(rsigma.data.dtype == DType::kFloat32); + + NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(dz.data.shape == x.data.shape); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + + NVTE_CHECK(dx->data.shape == x.data.shape); + NVTE_CHECK(dx->data.dtype == x.data.dtype); + + NVTE_CHECK(dgamma->data.shape == gamma.data.shape); + NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + } + + NVTE_Norm_Backend norm_backend; + bool is_aligned = true; + if (use_cudnn_norm_bwd()) { + // TODO: add check for GPU ARCH + norm_backend = NVTE_Norm_Backend::Cudnn; + } else { + norm_backend = NVTE_Norm_Backend::Te; + is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, dgamma->data.dptr); + } + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::Backward, + gamma.data.dtype, // wtype + x.data.dtype, // itype + gamma.data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, workspace->data.dptr, stream); + } + return; +} + +} // namespace transformer_engine + +void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size + const NVTETensor gamma, // hidden_size + const float epsilon, NVTETensor z, NVTETensor rsigma, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_rmsnorm_fwd); + using namespace transformer_engine; + rmsnorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), + epsilon, reinterpret_cast(z), reinterpret_cast(rsigma), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} + +void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size + const NVTETensor x, // Nxhidden_size + const NVTETensor rsigma, // N, FP32! + const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_rmsnorm_bwd); + using namespace transformer_engine; + rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), + *reinterpret_cast(rsigma), *reinterpret_cast(gamma), + reinterpret_cast(dx), reinterpret_cast(dgamma), + reinterpret_cast(workspace), multiprocessorCount, zero_centered_gamma, + stream); +} diff --git a/transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh similarity index 97% rename from transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh rename to transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh index 92fd850baa..5d8a5b765a 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -7,15 +7,15 @@ #ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ #define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ -#include "../utils.cuh" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace rmsnorm { -using namespace transformer_engine; +namespace normalization { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel( - BwdParams params) { + BackwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_M = Ktraits::WARPS_M }; enum { WARPS_N = Ktraits::WARPS_N }; @@ -172,7 +172,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke template __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_finalize_tuned_kernel( - BwdParams params) { + BackwardKernelParams params) { using compute_t = typename Kernel_traits::compute_t; using weight_t = typename Kernel_traits::weight_t; using index_t = typename Kernel_traits::index_t; @@ -276,7 +276,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_fi template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel( - BwdParams params) { + BackwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::ELTS_PER_LDG }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -430,8 +430,9 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_ template -__global__ __launch_bounds__( - WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel(BwdParams params) { +__global__ +__launch_bounds__(WARPS_M *WARPS_N *THREADS_PER_WARP) void rmsnorm_bwd_finalize_general_kernel( + BackwardKernelParams params) { enum { NUM_ELTS = BYTES_PER_LDG / sizeof(compute_t) }; using Wvec = Vec; using Cvec = Vec; @@ -474,7 +475,7 @@ __global__ __launch_bounds__( } } -} // namespace rmsnorm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu new file mode 100644 index 0000000000..fb5741b35b --- /dev/null +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -0,0 +1,206 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../common.h" +#include "../kernel_traits.h" +#include "rmsnorm_bwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_bwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::reduce_t) * 2; + } + launch_params.dgamma_part_bytes = + launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); + return; + } + + if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, + stream); + } + + using Kernel_traits_f = + Kernel_traits_finalize; + + auto kernel_f = &rmsnorm_bwd_finalize_tuned_kernel; + kernel_f<<>>( + launch_params.params); +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Instantiate kernel + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_bwd_general_kernel; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, + Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); + } + launch_params.dgamma_part_bytes = + launch_params.params.ctas_per_col * launch_params.params.cols * sizeof(compute_t); + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } + + // Launch finalization kernel + constexpr uint32_t WARPS_M_FINAL = 4; + constexpr uint32_t WARPS_N_FINAL = 1; + constexpr uint32_t ELTS_N_PER_CTA_FINAL = + (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); + auto kernel_final = + &rmsnorm_bwd_finalize_general_kernel; + dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); + dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); + kernel_final<<>>(launch_params.params); +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...) \ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create rmsnorm tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... +// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); + +// Create rmsnorm general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... +// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); + +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); +REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu new file mode 100644 index 0000000000..25bed95dc5 --- /dev/null +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -0,0 +1,210 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../common.h" +#include "../kernel_traits.h" +#include "rmsnorm_fwd_kernels.cuh" + +using namespace transformer_engine::normalization; + +template +void launch_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_fwd_tuned_kernel; + + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + launch_params.params.ctas_per_row = CTAS_PER_ROW; + launch_params.params.ctas_per_col = + launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; + if (Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_bytes = 2 * launch_params.params.ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * + Kernel_traits::CTAS_PER_ROW * + sizeof(typename Kernel_traits::Stats::stats_t) * 2; + } + return; + } + + if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + auto ctas_per_row = launch_params.params.ctas_per_row; + + if (ctas_per_row == 1) { + kernel<<>>( + launch_params.params); + } else { + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) + Kernel_traits::SMEM_BYTES_FWD, stream); + } +} + +template +void launch_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) + using Kernel_traits = Kernel_traits; + auto kernel = &rmsnorm_fwd_general_kernel; + auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; + + // Configure kernel params + const int rows = launch_params.params.rows; + const int cols = launch_params.params.cols; + int ctas_per_col = launch_params.params.ctas_per_col; + int ctas_per_row = launch_params.params.ctas_per_row; + if (configure_params) { + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; + ctas_per_row = ceil_div(cols, HIDDEN_SIZE); + ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); + launch_params.params.ctas_per_row = ctas_per_row; + launch_params.params.ctas_per_col = ctas_per_col; + if (launch_params.params.ctas_per_row > 1) { + launch_params.barrier_bytes = 2 * ctas_per_col * sizeof(index_t); + launch_params.workspace_bytes = + (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); + } + return; + } + + // Launch kernel + auto stream = launch_params.stream; + dim3 grid(ctas_per_row * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + if (ctas_per_row == 1) { + kernel<<>>(launch_params.params); + } else { + void *params_ = reinterpret_cast(&launch_params.params); + cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream); + } +} + +#define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ + OTYPE, CTYPE, ...) \ + namespace { \ + void \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ + LaunchParams &launch_params, const bool configure_params) { \ + launch_##LAUNCH_TYPE##_( \ + launch_params, configure_params); \ + } \ + REGISTER_NORM_BASE( \ + NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ + norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ + } // namespace + +// Create rmsnorm tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); + +// Create rmsnorm general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, fp16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 128, fp32, fp32, bf16, fp32, 4, 1, 8); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 512, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 1024, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp32, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp16, fp16, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, fp16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, bf16, bf16, bf16, fp32, 4, 1, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 2048, fp32, fp32, bf16, fp32, 4, 1, 16); + +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp32, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp16, fp16, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, fp16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, bf16, bf16, bf16, fp32, 1, 4, 16); +REGISTER_NORM_LAUNCHER(RMSNorm, Forward, general, 8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh similarity index 97% rename from transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh rename to transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh index c435ae3744..c631847395 100644 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -10,15 +10,15 @@ #include #include -#include "../utils.cuh" +#include "../../utils.cuh" +#include "../common.h" namespace transformer_engine { -namespace rmsnorm { -using namespace transformer_engine; +namespace normalization { template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_kernel( - FwdParams params) { + ForwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; enum { WARPS_N = Ktraits::WARPS_N }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -143,7 +143,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_kernel( - FwdParams params) { + ForwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; enum { NUM_ELTS = Ktraits::NUM_ELTS }; enum { WARPS_M = Ktraits::WARPS_M }; @@ -291,7 +291,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ } } -} // namespace rmsnorm +} // namespace normalization } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_FWD_KERNELS_CUH_ diff --git a/transformer_engine/common/nvtx.h b/transformer_engine/common/nvtx.h index 4625e0ab9d..ada7a59092 100644 --- a/transformer_engine/common/nvtx.h +++ b/transformer_engine/common/nvtx.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 2b894fbfdc..7e9e2a97f7 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index ba276ad406..50a0a10b5f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -39,19 +39,51 @@ class Format(Enum): HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) -class _OverrideLinearPrecision(NamedTuple): +@dataclass(frozen=True) +class MMParams: + """for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator) + apply split accumulator or not, turning it on will increase accuracy but impact gemm performance, + so only turn it on for certain gemms """ - Whether or not the execute the `fprop`, `dgrad`, and `wgrad` - GEMMs in higher precision when using FP8. + + use_split_accumulator: bool = True + + +@dataclass(frozen=True) +class QParams: + """Quantization parameters. + power_2_scale: use power of 2 scale parameter + amax_epsilon: optional minimum value of abs max + """ + + power_2_scale: bool = False + amax_epsilon: float = 0.0 + + +class Recipe: + """ + Base recipe class. """ - fprop: bool = False - dgrad: bool = False - wgrad: bool = False + def mxfp8(self): + """Whether the given recipe is MXFP8 block scaling.""" + return isinstance(self, MXFP8BlockScaling) + + def delayed(self): + """Whether the given recipe is delayed scaling.""" + return isinstance(self, DelayedScaling) + + def float8_current_scaling(self): + """Whether the given recipe is (per-tensor) current scaling.""" + return isinstance(self, Float8CurrentScaling) + + def float8_per_tensor_scaling(self): + """Whether the given recipe is per-tensor scaling.""" + return isinstance(self, (DelayedScaling, Float8CurrentScaling)) @dataclass() -class DelayedScaling: +class DelayedScaling(Recipe): """ Use the delayed scaling factor strategy. Use scale factor from previous iteration and record amax history of `amax_history_len` steps. @@ -92,9 +124,6 @@ def scaling_factor_compute(amax: Tensor, recipe: DelayedScaling) -> Tensor where `Tensor` is a framework tensor type. - override_linear_precision: Tuple(bool, bool, bool), default=(False, False, False) - Whether or not to execute the `fprop`, `dgrad`, and `wgrad` - GEMMs (respectively) in higher precision when using FP8. reduce_amax: bool, default = `True` By default, if `torch.distributed` is initialized, the `amax` value for FP8 tensors is reduced across the `fp8_group` (specified in the `fp8_autocast` @@ -137,7 +166,6 @@ def scaling_factor_compute(amax: Tensor, fp8_format: Format = Format.HYBRID amax_history_len: int = 1024 amax_compute_algo: Union[Literal["max", "most_recent"], Callable] = "max" - override_linear_precision: _OverrideLinearPrecision = _OverrideLinearPrecision() scaling_factor_compute_algo: Optional[Callable] = None reduce_amax: bool = True fp8_dpa: bool = False @@ -145,10 +173,6 @@ def scaling_factor_compute(amax: Tensor, def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert self.override_linear_precision in ( - (False, False, False), - (False, False, True), - ), "Only wgrad GEMM override is currently supported." if self.interval >= 0: warnings.warn( "`interval` argument is deprecated and unused. " @@ -161,7 +185,112 @@ def __repr__(self) -> str: f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " - f"wgrad_override={self.override_linear_precision.wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) + + +@dataclass() +class Float8CurrentScaling(Recipe): + """ + Use the per-tensor current scaling factor strategy. + Parameters + ---------- + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID + Controls the FP8 data format used during forward and backward + pass. + fp8_quant_fwd_inp: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + used for quantization of input tensor x + fp8_quant_fwd_weight: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + used for quantization of weight tensor w + fp8_quant_bwd_grad: QParams, default QParams{power_2_scale=False, amax_epsilon=0.0} + used for quantization of gradient tensor dY + fp8_gemm_fprop: MMParams, default MMParams.use_split_accumulator=False + used for calculating output y in forward pass + fp8_gemm_dgrad: MMParams, default MMParams.use_split_accumulator=True + use for calculating dgrad in backward pass + fp8_gemm_wgrad: MMParams, default MMParams.use_split_accumulator=True + use for calculating dgrad in backward pass + fp8_dpa: bool, default = `False` + Whether to enable FP8 dot product attention (DPA). When the model is placed in an + `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the + inputs from higher precision to FP8, performs attention in FP8, and casts tensors + back to higher precision as outputs. FP8 DPA currently is only supported in the + `FusedAttention` backend. + fp8_mha: bool, default = `False` + Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting + operations mentioned above at the DPA boundaries. Currently only standard MHA modules + i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When + `fp8_mha = False, fp8_dpa = True`, a typical MHA module works as + `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. + When `fp8_mha = True, fp8_dpa = True`, it becomes + `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + + Notes + ----- + * `fp8_dpa` and `fp8_mha` are Beta features, and their API and functionality are + subject to change in future Transformer Engine releases. + """ + + fp8_format: Format = Format.HYBRID + fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False) + fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + + def __repr__(self) -> str: + return ( + f"format={str(self.fp8_format).split('.')[1]}, " + f"fp8_quant_fwd_inp={self.fp8_quant_fwd_inp}, " + f"fp8_quant_fwd_weight={self.fp8_quant_fwd_weight}, " + f"fp8_quant_bwd_grad={self.fp8_quant_bwd_grad}, " + f"fp8_gemm_fprop={self.fp8_gemm_fprop}, " + f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " + f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}" + ) + + +@dataclass() +class MXFP8BlockScaling(Recipe): + """ + Use the MXFP8 scaling factor strategy. + + In this strategy, tensors are scaled in blockwise fashion. Each group + of 32 consecutive values is scaled together using their own scaling + factor. The type of the scaling factor is E8M0 (8 bits of exponent, + 0 bits of mantissa), equivalent to scaling by a power of 2. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), in this recipe the quantized tensor and its transpose + are not numerically equivalent. Due to this, when Transformer Engine + needs both the MXFP8 tensor and its transpose (e.g. to calculate both + forward and backward pass), during the quantization both versions are + computed from the high precision input to avoid double quantization + errors. + + Parameters + ---------- + fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 + Controls the FP8 data format used during forward and backward + pass. + """ + + margin: int = 0 + fp8_format: Format = Format.E4M3 + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + + def __repr__(self) -> str: + return f"margin={self.margin}, format={str(self.fp8_format).split('.')[1]}," diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu new file mode 100644 index 0000000000..3a25d71a3b --- /dev/null +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -0,0 +1,237 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "../util/vectorized_pointwise.h" + +namespace transformer_engine { +namespace { + +constexpr int amax_kernel_threads = 512; + +template +__launch_bounds__(amax_kernel_threads) __global__ + void amax_kernel(const InputType *input, float *amax, const size_t N, + const size_t num_aligned_elements) { + VectorizedLoader loader(input, N); + InputType max = 0.f; + const int warp_id = threadIdx.x / THREADS_PER_WARP; + const size_t M = num_aligned_elements; + + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { + loader.load(tid, N); +#pragma unroll + for (int i = 0; i < nvec; ++i) { + const InputType val = static_cast(loader.separate()[i]); + __builtin_assume(max >= InputType{0.f}); + if constexpr (std::is_same_v) { +#if __CUDA_ARCH__ >= 800 + max = __hmax(__habs(val), max); +#else // Turing + max = static_cast<__nv_bfloat16>( + fmaxf(fabsf(static_cast(val)), static_cast(max))); +#endif + } else if constexpr (std::is_same_v) { + max = __hmax(__habs(val), max); + } else { + max = fmaxf(fabsf(val), max); + } + } + } + + // Reduce amax over block + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + atomicMaxFloat(amax, max); + } +} + +template +void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { + // Zero out amax so we can update with atomic max + cudaMemsetAsync(amax, 0, sizeof(float), stream); + + // Return immediately if tensor is empty + if (N == 0) { + return; + } + + // Figure out alignment + auto align = CheckAlignment(N, nvec, input); + size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType)); + + // Figure out CUDA blocks + constexpr size_t threads = amax_kernel_threads; + size_t num_blocks = DIVUP(num_aligned_elements, threads); + constexpr size_t max_blocks = 65535; + num_blocks = std::min(num_blocks, max_blocks); + + // Launch kernel + switch (align) { + case Alignment::SAME_ALIGNED: + amax_kernel + <<>>(input, amax, N, num_aligned_elements); + break; + case Alignment::SAME_UNALIGNED: + amax_kernel + <<>>(input, amax, N, num_aligned_elements); + break; + case Alignment::DIFFERENT: { + // This case is a logic error, since there is only one pointer (input) + // in the alignment check. Still safe to process without vectorization. + amax_kernel<1, true, InputType><<>>(input, amax, N, N); + break; + } + } + + // Check results + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace +} // namespace transformer_engine + +void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax); + using namespace transformer_engine; + + // Check input tensor + NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)"); + const auto &input = *reinterpret_cast(input_); + NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor for amax computation must unquantized, " + "but got scaling_mode=", + to_string(input.scaling_mode)); + NVTE_CHECK(!is_fp8_dtype(input.data.dtype), + "Input tensor for amax computation must be unquantized, but got dtype=", + to_string(input.data.dtype)); + NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data"); + CheckInputTensor(input, "input_compute_amax"); + + // Check output tensor + NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); + auto &output = *reinterpret_cast(output_); + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, " + "but got scaling_mode=", + to_string(output.scaling_mode)); + NVTE_CHECK(output.amax.numel() == 1, + "Output tensor for amax computation has invalid amax tensor " + "(expected 1 entry, got shape=", + output.amax.shape, ")"); + NVTE_CHECK(output.amax.dptr != nullptr, + "Output tensor for amax computation has amax tensor without data"); + NVTE_CHECK(output.amax.dtype == DType::kFloat32, + "Output tensor for amax computation has invalid amax tensor " + "(expected FP32, got dtype=", + to_string(output.amax.dtype), ")"); + CheckOutputTensor(output, "output_compute_amax"); + + // Compute amax + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); + launch_amax_kernel(reinterpret_cast(input.data.dptr), + reinterpret_cast(output.amax.dptr), input.data.numel(), + stream);); // NOLINT(*) +} + +namespace transformer_engine { +namespace { + +__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, + const float max_fp8, const bool force_pow_2_scales, + const float epsilon) { + float amax = *amax_ptr; + if (amax < epsilon) { + amax = epsilon; + } + + float scale = 1.f; + + if (isinf(amax) || amax == 0.f) { + *scale_ptr = scale; + return; + } + + scale = max_fp8 / amax; + + // The amax is too small that the scale becoming infinite in FP32. In other word, + // the scale is not representable in FP32. + if (isinf(scale)) { + // use fp32 max to represent the scale + scale = std::numeric_limits::max(); + } + + if (isnan(scale)) { + scale = 1.f; + } + + if (force_pow_2_scales) { + uint32_t scale_bits = *reinterpret_cast(&scale); + scale_bits &= 0xFF800000; + // If the exponent was zero, we have a logic error. + __builtin_assume(scale_bits != 0); + __builtin_assume(scale_bits != 0x80000000); + scale = *reinterpret_cast(&scale_bits); + } + + *scale_ptr = scale; +} + +} // namespace +} // namespace transformer_engine + +void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConfig config_, + cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_scale_from_amax); + using namespace transformer_engine; + + // Check output tensor + NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); + auto &output = *reinterpret_cast(output_); + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Tensor must be FP8 tensor with per-tensor scaling, " + "but got scaling_mode=", + to_string(output.scaling_mode)); + NVTE_CHECK(is_fp8_dtype(output.data.dtype), + "Tensor must be FP8, but got dtype=", to_string(output.data.dtype)); + NVTE_CHECK(output.amax.numel() == 1, + "Tensor has invalid amax tensor (expected 1 entry, got shape=", output.amax.shape, + ")"); + NVTE_CHECK(output.amax.dptr != nullptr, "Tensor has amax tensor without data"); + NVTE_CHECK(output.amax.dtype == DType::kFloat32, + "Tensor has invalid amax tensor (expected FP32, got dtype=", + to_string(output.amax.dtype), ")"); + NVTE_CHECK(output.scale.numel() == 1, + "Tensor has invalid scale tensor (expected 1 entry, got shape=", output.scale.shape, + ")"); + NVTE_CHECK(output.scale.dptr != nullptr, "Tensor has scale tensor without data"); + NVTE_CHECK(output.scale.dtype == DType::kFloat32, + "Tensor has invalid scale tensor (expected FP32, got dtype=", + to_string(output.scale.dtype), ")"); + + // Check config + NVTE_CHECK(config_ != nullptr, "Invalid config (got NULL)"); + const auto &config = *reinterpret_cast(config_); + + // Maximum FP8 value + float max_fp8 = 0.f; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType, + max_fp8 = Quantized_Limits::max_norm;); + + // Update scale + compute_scale_from_amax_kernel<<<1, 1>>>(reinterpret_cast(output.amax.dptr), + reinterpret_cast(output.scale.dptr), max_fp8, + config.force_pow_2_scales, config.amax_epsilon); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/recipe/delayed_scaling.cu b/transformer_engine/common/recipe/delayed_scaling.cu index fcace6ac3d..658ce054da 100644 --- a/transformer_engine/common/recipe/delayed_scaling.cu +++ b/transformer_engine/common/recipe/delayed_scaling.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -46,7 +46,6 @@ struct AmaxParam { int num_scale = 0; float* amax_history = nullptr; float* scale = nullptr; - float* scale_inv = nullptr; }; // dummy struct for kernel_bulk's other params @@ -83,10 +82,9 @@ constexpr size_t bsize = 256; * Grid dims: num_scales x 1 x 1 */ __global__ void __launch_bounds__(bsize) - kernel(const float* amax_history_ptr, const float* scale_ptr, const float* scale_inv_ptr, - const unsigned char* scale_inv_mask_ptr, float* updated_amax_history_ptr, - float* updated_scale_ptr, float* updated_scale_inv_ptr, size_t amax_history_length, - size_t amax_history_stride, AmaxComputeAlgo amax_compute_algo, float scaled_max) { + kernel(const float* amax_history_ptr, const float* scale_ptr, float* updated_amax_history_ptr, + float* updated_scale_ptr, size_t amax_history_length, size_t amax_history_stride, + AmaxComputeAlgo amax_compute_algo, float scaled_max) { const size_t tid = threadIdx.x; const size_t bid = blockIdx.x; @@ -135,7 +133,7 @@ __global__ void __launch_bounds__(bsize) } } - // Update scale and scale inverse + // Update scale if (tid == 0) { // Update scale float scale; @@ -152,15 +150,6 @@ __global__ void __launch_bounds__(bsize) scale = std::numeric_limits::max(); } updated_scale_ptr[bid] = scale; - - // Update scale inverse - float scale_inv; - if (scale_inv_mask_ptr == nullptr || scale_inv_mask_ptr[bid]) { - scale_inv = 1 / scale; - } else { - scale_inv = scale_inv_ptr[bid]; - } - updated_scale_inv_ptr[bid] = scale_inv; } } @@ -232,7 +221,7 @@ __global__ void __launch_bounds__(bsize) } } - // Update scale and scale inverse + // Update scale if (tid == 0) { // Computing the scaling factor requires consideration of the following scenarios: // 1. amax == 0: @@ -259,7 +248,6 @@ __global__ void __launch_bounds__(bsize) scale = std::numeric_limits::max(); } p.param[bid].scale[count] = scale; - p.param[bid].scale_inv[count] = 1 / scale; } } } @@ -268,23 +256,12 @@ __global__ void __launch_bounds__(bsize) } // namespace -void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, const Tensor& scale_inv, - const Tensor& scale_inv_mask, Tensor* updated_amax_history_, - Tensor* updated_scale_, Tensor* updated_scale_inv_, +void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, + Tensor* updated_amax_history_, Tensor* updated_scale_, const std::string& amax_compute_algo, DType fp8_dtype, float margin, cudaStream_t stream) { auto& updated_amax_history = *updated_amax_history_; auto& updated_scale = *updated_scale_; - auto& updated_scale_inv = *updated_scale_inv_; - - // Number of elements in tensor - auto numel = [](const Tensor& tensor) -> size_t { - size_t acc = 1; - for (const auto& dim : tensor.data.shape) { - acc *= dim; - } - return acc; - }; // Check tensors NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(), @@ -293,18 +270,9 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons const size_t num_scales = amax_history.data.shape[1]; NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ", dtype_name(amax_history.data.dtype), "."); - NVTE_CHECK(numel(scale) == num_scales, "Expected ", num_scales, " elements, ", "but found ", - numel(scale), "."); + NVTE_CHECK(scale.numel() == num_scales, "Expected ", num_scales, " elements, ", "but found ", + scale.numel(), "."); NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), "."); - if (scale_inv_mask.data.dptr != nullptr) { - NVTE_CHECK(numel(scale_inv) == num_scales, "Expected ", num_scales, " elements, ", "but found ", - numel(scale_inv), "."); - NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32); - NVTE_CHECK(numel(scale_inv_mask) == num_scales, "Expected ", num_scales, " elements, ", - "but found ", numel(scale_inv_mask), "."); - NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, "Found ", - dtype_name(scale_inv_mask.data.dtype), "."); - } NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ", updated_amax_history.data.shape.size(), " dims."); NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ", @@ -313,14 +281,10 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons "but found ", updated_amax_history.data.shape[1]); NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ", dtype_name(updated_amax_history.data.dtype), "."); - NVTE_CHECK(numel(updated_scale) == num_scales, "Expected ", num_scales, " elements, ", - "but found ", numel(updated_scale), "."); + NVTE_CHECK(updated_scale.numel() == num_scales, "Expected ", num_scales, " elements, ", + "but found ", updated_scale.numel(), "."); NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ", dtype_name(updated_scale.data.dtype), "."); - NVTE_CHECK(numel(updated_scale_inv) == num_scales, "Expected ", num_scales, " elements, ", - "but found ", numel(updated_scale_inv), "."); - NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, "Got ", - dtype_name(updated_scale_inv.data.dtype), "."); // amax value to use for updating scaling factor AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID; @@ -340,11 +304,8 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons const size_t grid_size = num_scales; amax_and_scale_update_impl::kernel<<>>( static_cast(amax_history.data.dptr), static_cast(scale.data.dptr), - static_cast(scale_inv.data.dptr), - static_cast(scale_inv_mask.data.dptr), static_cast(updated_amax_history.data.dptr), - static_cast(updated_scale.data.dptr), - static_cast(updated_scale_inv.data.dptr), amax_history_length, num_scales, + static_cast(updated_scale.data.dptr), amax_history_length, num_scales, amax_compute_algo_, scaled_max); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -352,7 +313,6 @@ void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, cons void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, std::vector amax_histories, std::vector scales, - std::vector scale_invs, const std::string& amax_compute_algo, DType fp8_dtype, float margin, cudaStream_t stream) { using namespace transformer_engine; @@ -370,15 +330,6 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, // Expected maximum value after scale is applied const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin); - // Number of elements in tensor - auto numel = [](const Tensor* tensor) -> size_t { - size_t acc = 1; - for (const auto& dim : tensor->data.shape) { - acc *= dim; - } - return acc; - }; - // Number of tensors in the bulk const size_t num_tensors = amax_histories.size(); size_t num_remaining_tensors = num_tensors; @@ -404,22 +355,21 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, dtype_name(amax_histories[i]->data.dtype), "."); NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ", amax_histories[i]->data.shape.size(), " dims"); - NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, "Expected ", + NVTE_CHECK(amax_histories[i]->numel() == amax_history_length * num_scale, "Expected ", amax_history_length * num_scale, " elements, ", "but found ", - numel(amax_histories[i]), "."); + amax_histories[i]->numel(), "."); NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ", dtype_name(scales[i]->data.dtype), "."); NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(), " dims"); - NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ", - numel(scales[i]), "."); + NVTE_CHECK(scales[i]->numel() == num_scale, "Expected ", num_scale, " elements, ", "Found ", + scales[i]->numel(), "."); // amax parameters kernel_num_scales += num_scale; p.param[pi].num_scale = num_scale; p.param[pi].amax_history = static_cast(amax_histories[i]->data.dptr); p.param[pi].scale = static_cast(scales[i]->data.dptr); - p.param[pi].scale_inv = static_cast(scale_invs[i]->data.dptr); } // Launch CUDA kernel @@ -441,34 +391,30 @@ void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer, } // namespace transformer_engine void nvte_delayed_scaling_recipe_amax_and_scale_update( - const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv, - const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale, - NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, + const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history, + NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) { NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update); using namespace transformer_engine; delayed_scaling_recipe::amax_and_scale_update( *reinterpret_cast(amax_history), *reinterpret_cast(scale), - *reinterpret_cast(scale_inv), *reinterpret_cast(scale_inv_mask), reinterpret_cast(updated_amax_history), reinterpret_cast(updated_scale), - reinterpret_cast(updated_scale_inv), amax_compute_algo, - static_cast(fp8_dtype), margin, stream); + amax_compute_algo, static_cast(fp8_dtype), margin, stream); } void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( const NVTETensor amax_reduction_buffer, std::vector amax_histories, - std::vector scales, std::vector scale_invs, - const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) { + std::vector scales, const char* amax_compute_algo, NVTEDType fp8_dtype, + float margin, cudaStream_t stream) { NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction); using namespace transformer_engine; size_t num_tensors = amax_histories.size(); - std::vector t_amax_histories, t_scales, t_scale_invs; + std::vector t_amax_histories, t_scales; for (size_t i = 0; i < num_tensors; i++) { t_amax_histories.push_back(reinterpret_cast(amax_histories[i])); t_scales.push_back(reinterpret_cast(scales[i])); - t_scale_invs.push_back(reinterpret_cast(scale_invs[i])); } delayed_scaling_recipe::amax_and_scale_update_after_reduction( *reinterpret_cast(amax_reduction_buffer), t_amax_histories, t_scales, - t_scale_invs, amax_compute_algo, static_cast(fp8_dtype), margin, stream); + amax_compute_algo, static_cast(fp8_dtype), margin, stream); } diff --git a/transformer_engine/common/rmsnorm/rmsnorm.h b/transformer_engine/common/rmsnorm/rmsnorm.h deleted file mode 100644 index 8b4e1cf24e..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm.h +++ /dev/null @@ -1,89 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ -#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ - -#include - -#include -#include -#include -#include -#include - -#include "../common.h" -#include "../layer_norm/ln.h" - -namespace transformer_engine { -namespace rmsnorm { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct LaunchParams : public transformer_engine::layer_norm::LaunchParams {}; -struct FwdParams : public transformer_engine::layer_norm::FwdParams {}; -struct BwdParams : public transformer_engine::layer_norm::BwdParams {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -using FwdFunction = std::function &, const bool)>; -using BwdFunction = std::function &, const bool)>; -using FunctionKey = uint64_t; -using FwdTunedRegistry = std::unordered_map; -using BwdTunedRegistry = std::unordered_map; -using FwdGeneralRegistry = std::unordered_map>; -using BwdGeneralRegistry = std::unordered_map>; - -extern FwdTunedRegistry FWD_TUNED_FUNCS; -extern BwdTunedRegistry BWD_TUNED_FUNCS; -extern FwdGeneralRegistry FWD_GENERAL_FUNCS; -extern BwdGeneralRegistry BWD_GENERAL_FUNCS; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdTunedRegistrar { - explicit FwdTunedRegistrar(FwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); - FWD_TUNED_FUNCS.insert({key, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FwdGeneralRegistrar { - explicit FwdGeneralRegistrar(FwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(0); - FWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdTunedRegistrar { - explicit BwdTunedRegistrar(BwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(HIDDEN_SIZE); - BWD_TUNED_FUNCS.insert({key, f}); - } -}; - -////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct BwdGeneralRegistrar { - explicit BwdGeneralRegistrar(BwdFunction f) { - uint64_t key = layer_norm::Types2Key::get(0); - BWD_GENERAL_FUNCS[key].insert({HIDDEN_SIZE, f}); - } -}; - -} // namespace rmsnorm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_H_ diff --git a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/rmsnorm/rmsnorm_api.cpp deleted file mode 100644 index 9b143b2f85..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_api.cpp +++ /dev/null @@ -1,387 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include - -#include "../common.h" -#include "rmsnorm.h" -#include "transformer_engine/rmsnorm.h" - -/* - -Supported Type combinations: - -input compute weights output -======================================= -fp32 fp32 fp32 fp32 -fp16 fp32 fp16 fp16 -bf16 fp32 bf16 bf16 -fp32 fp32 fp32 fp16 -fp32 fp32 fp32 bf16 -fp32 fp32 fp32 fp8 -fp16 fp32 fp16 fp8 -bf16 fp32 bf16 fp8 - -Remarks: -Input type = Weight type -Compute always in FP32 - -*/ - -namespace transformer_engine { - -namespace layer_norm { -uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size); -} - -namespace rmsnorm { - -using namespace transformer_engine; - -FwdTunedRegistry FWD_TUNED_FUNCS; -BwdTunedRegistry BWD_TUNED_FUNCS; -FwdGeneralRegistry FWD_GENERAL_FUNCS; -BwdGeneralRegistry BWD_GENERAL_FUNCS; - -FwdFunction &get_fwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::FwdParams ¶ms) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && - is_aligned(params.gamma) && is_aligned(params.z) && FWD_TUNED_FUNCS.count(tuned_key) > 0) { - return FWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (FWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("FWD: Unsupported types."); - } - auto &general_func_map = FWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype, - const layer_norm::BwdParams ¶ms) { - // Look for tuned kernel - auto tuned_key = layer_norm::get_key(wtype, itype, otype, ctype, params.cols); - auto is_aligned = [](const void *ptr) -> bool { - // Assume vectorized memory accesses are <=16B - return reinterpret_cast(ptr) % 16 == 0; - }; - if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && - is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) && - is_aligned(params.dgamma) && is_aligned(params.dgamma_part) && - BWD_TUNED_FUNCS.count(tuned_key) > 0) { - return BWD_TUNED_FUNCS.at(tuned_key); - } - - // Pick general kernel - auto general_key = layer_norm::get_key(wtype, itype, otype, ctype, 0); - if (BWD_GENERAL_FUNCS.count(general_key) == 0) { - NVTE_ERROR("BWD: Unsupported types."); - } - auto &general_func_map = BWD_GENERAL_FUNCS.at(general_key); - auto func_iter = general_func_map.lower_bound(params.cols); - if (func_iter == general_func_map.end()) { - // Hidden size is too big, need to use multi-CTA - return general_func_map.rbegin()->second; - } else { - return func_iter->second; - } -} - -// //////////////////////////////////////////////////////////////////////////////////////////////////// - -inline size_t product(const std::vector &shape) { - return std::accumulate(shape.cbegin(), shape.cend(), size_t{1}, std::multiplies<>()); -} - -} // namespace rmsnorm - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, - Tensor *rsigma, cudaStream_t stream, const int multiprocessorCount, - Tensor *workspace, Tensor *barrier, const bool zero_centered_gamma) { - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = z->data.dtype; - const bool fp8_out = is_fp8_dtype(otype); - auto ctype = DType::kFloat32; - - NVTE_CHECK(x.data.shape.size() == 2); - - const size_t rows = x.data.shape[0]; - const size_t cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(hidden_size == cols); - NVTE_CHECK(epsilon >= 0.f); - - NVTE_CHECK(z->data.shape == x.data.shape); - - NVTE_CHECK(rsigma->data.shape == std::vector{rows}); - NVTE_CHECK(rsigma->data.dtype == ctype); - - rmsnorm::LaunchParams launch_params; - - launch_params.multiprocessorCount = multiprocessorCount; - launch_params.stream = stream; - - // Set the kernel runtime parameters. - rmsnorm::FwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = nullptr; - params.rs = rsigma->data.dptr; - params.gamma = gamma.data.dptr; - params.beta = nullptr; - params.z = z->data.dptr; - params.epsilon = epsilon; - params.amax = z->amax.dptr; - params.scale = z->scale.dptr; - params.scale_inv = z->scale_inv.dptr; - params.fp8_out = fp8_out; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = rmsnorm::get_fwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - if (launch_params.workspace_bytes == 0) { - launch_params.workspace_bytes = 1; - } - - if (workspace->data.dptr == nullptr) { - NVTE_CHECK(barrier->data.dptr == nullptr); - - workspace->data.dtype = DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - NVTE_CHECK(workspace->data.dtype == DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(x, "x"); - CheckInputTensor(gamma, "gamma"); - - CheckOutputTensor(*z, "z"); - CheckOutputTensor(*rsigma, "rsigma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - } - - // Clear buffers - if (params.fp8_out) { - cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype), - stream); - } - if (launch_params.barrier_size > 0) { - cudaMemsetAsync(params.barrier, 0, - rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. - launcher(launch_params, false); - - return; -} - -void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const Tensor &gamma, - Tensor *dx, Tensor *dgamma, Tensor *dgamma_part, cudaStream_t stream, - const int multiprocessorCount, Tensor *workspace, Tensor *barrier, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - auto itype = x.data.dtype; - auto wtype = gamma.data.dtype; - auto otype = wtype; - auto ctype = DType::kFloat32; - - NVTE_CHECK(dz.data.dtype == otype); - NVTE_CHECK(rsigma.data.dtype == ctype); - - NVTE_CHECK(x.data.shape.size() == 2); - NVTE_CHECK(dz.data.shape == x.data.shape); - - const auto rows = x.data.shape[0]; - const auto cols = x.data.shape[1]; - const auto hidden_size = gamma.data.shape[0]; - - NVTE_CHECK(gamma.data.shape[0] == cols); - - NVTE_CHECK(dx->data.shape == x.data.shape); - NVTE_CHECK(dx->data.dtype == x.data.dtype); - - NVTE_CHECK(dgamma->data.shape == gamma.data.shape); - NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); - - rmsnorm::LaunchParams launch_params; - launch_params.stream = stream; - launch_params.multiprocessorCount = multiprocessorCount; - - // Set the kernel runtime parameters. - rmsnorm::BwdParams ¶ms = launch_params.params; - params.rows = rows; - params.cols = cols; - params.x = x.data.dptr; - params.mu = nullptr; - params.rs = rsigma.data.dptr; - params.gamma = gamma.data.dptr; - params.dz = dz.data.dptr; - params.dx = dx->data.dptr; - params.dbeta = nullptr; - params.dgamma = dgamma->data.dptr; - params.dbeta_part = nullptr; - params.dgamma_part = dgamma_part->data.dptr; - params.zero_centered_gamma = zero_centered_gamma; - - // Request the kernel launcher. - auto launcher = rmsnorm::get_bwd_launcher(wtype, itype, otype, ctype, params); - - // Query the kernel-specific launch parameters. - launcher(launch_params, true); - - // Populate shape and dtypes for FW to allocate memory - if (dgamma_part->data.dptr == nullptr) { - dgamma_part->data.dtype = ctype; - dgamma_part->data.shape = {static_cast(launch_params.params.ctas_per_col), - hidden_size}; - - workspace->data.dtype = DType::kByte; - workspace->data.shape = {launch_params.workspace_bytes}; - - barrier->data.dtype = DType::kInt32; - barrier->data.shape = {launch_params.barrier_size}; - - return; - } else { - auto pdw_shape = - std::vector{static_cast(launch_params.params.ctas_per_col), hidden_size}; - NVTE_CHECK(dgamma_part->data.dtype == ctype); - NVTE_CHECK(dgamma_part->data.shape == pdw_shape); - } - - if (launch_params.barrier_size > 0) { - NVTE_CHECK(barrier->data.dptr != nullptr); - NVTE_CHECK(barrier->data.dtype == DType::kInt32); - NVTE_CHECK(barrier->data.shape == std::vector{launch_params.barrier_size}); - } - - if (launch_params.workspace_bytes > 0) { - NVTE_CHECK(workspace->data.dptr != nullptr); - NVTE_CHECK(workspace->data.dtype == DType::kByte); - NVTE_CHECK(workspace->data.shape == std::vector{launch_params.workspace_bytes}); - } - - // Tensor checks are delayed here in order to recover workspace sizes with null data - CheckInputTensor(dz, "dz"); - CheckInputTensor(x, "x"); - CheckInputTensor(rsigma, "rsigma"); - CheckInputTensor(gamma, "gamma"); - CheckOutputTensor(*dx, "dx"); - CheckOutputTensor(*dgamma, "dgamma"); - - if (launch_params.barrier_size > 0) { - params.workspace = workspace->data.dptr; - params.barrier = reinterpret_cast(barrier->data.dptr); - cudaMemsetAsync(params.barrier, 0, - rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype), - stream); - } - - // Launch the kernel. - launcher(launch_params, false); -} - -} // namespace transformer_engine - -void nvte_rmsnorm_fwd(const NVTETensor x, // Nxhidden_size - const NVTETensor gamma, // hidden_size - const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm_fwd); - using namespace transformer_engine; - rmsnorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - epsilon, reinterpret_cast(z), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), false); -} - -void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size - const NVTETensor x, // Nxhidden_size - const NVTETensor rsigma, // N, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm_bwd); - using namespace transformer_engine; - rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(rsigma), *reinterpret_cast(gamma), - reinterpret_cast(dx), reinterpret_cast(dgamma), - reinterpret_cast(dgamma_part), stream, multiprocessorCount, - reinterpret_cast(workspace), reinterpret_cast(barrier), false); -} - -void nvte_rmsnorm1p_fwd(const NVTETensor x, // Nxhidden_size - const NVTETensor gamma, // hidden_size - const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream, - const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm1p_fwd); - using namespace transformer_engine; - rmsnorm_fwd(*reinterpret_cast(x), *reinterpret_cast(gamma), - epsilon, reinterpret_cast(z), reinterpret_cast(rsigma), stream, - multiprocessorCount, reinterpret_cast(workspace), - reinterpret_cast(barrier), true); -} - -void nvte_rmsnorm1p_bwd(const NVTETensor dz, // Nxhidden_size - const NVTETensor x, // Nxhidden_size - const NVTETensor rsigma, // N, FP32! - const NVTETensor gamma, // hidden_size - NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, - cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, - NVTETensor barrier) { - NVTE_API_CALL(nvte_rmsnorm1p_bwd); - using namespace transformer_engine; - rmsnorm_bwd(*reinterpret_cast(dz), *reinterpret_cast(x), - *reinterpret_cast(rsigma), *reinterpret_cast(gamma), - reinterpret_cast(dx), reinterpret_cast(dgamma), - reinterpret_cast(dgamma_part), stream, multiprocessorCount, - reinterpret_cast(workspace), reinterpret_cast(barrier), true); -} diff --git a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu deleted file mode 100644 index 3215a6a9d4..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ /dev/null @@ -1,220 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "rmsnorm.h" -#include "rmsnorm_bwd_kernels.cuh" -#include "rmsnorm_kernel_traits.h" - -using namespace transformer_engine::rmsnorm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = - rmsnorm::Kernel_traits; - auto kernel = &rmsnorm_bwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::reduce_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, - stream); - } - - using Kernel_traits_f = - Kernel_traits_finalize; - - auto kernel_f = &rmsnorm::rmsnorm_bwd_finalize_tuned_kernel; - kernel_f<<>>( - launch_params.params); -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Instantiate kernel - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_bwd_general_kernel; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(typename Kernel_traits::reduce_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); - } - - // Launch finalization kernel - constexpr uint32_t WARPS_M_FINAL = 4; - constexpr uint32_t WARPS_N_FINAL = 1; - constexpr uint32_t ELTS_N_PER_CTA_FINAL = - (Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL * BYTES_PER_LDG_FINAL / sizeof(compute_t)); - auto kernel_final = - &rmsnorm_bwd_finalize_general_kernel; - dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); - dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); - kernel_final<<>>(launch_params.params); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_BWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, \ - configure_params); \ - } \ - static BwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_bwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_BWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ - void rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static BwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_bwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create rmsnorm tuned launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... -// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); - -REGISTER_BWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); -REGISTER_BWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); - -// Create rmsnorm general launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... -// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL - -REGISTER_BWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, fp16, fp32, fp16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8, 4); -REGISTER_BWD_GENERAL_LAUNCHER(128, bf16, fp32, bf16, fp32, 4, 1, 8, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(512, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4); - -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); -REGISTER_BWD_GENERAL_LAUNCHER(4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); diff --git a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu deleted file mode 100644 index 3c8e121540..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ /dev/null @@ -1,227 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "rmsnorm.h" -#include "rmsnorm_fwd_kernels.cuh" -#include "rmsnorm_kernel_traits.h" - -using namespace transformer_engine::rmsnorm; - -template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_fwd_tuned_kernel; - - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); - launch_params.params.ctas_per_row = CTAS_PER_ROW; - launch_params.params.ctas_per_col = - launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (Kernel_traits::CTAS_PER_ROW > 1) { - launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; - launch_params.workspace_bytes = launch_params.params.ctas_per_col * Kernel_traits::WARPS_M * - Kernel_traits::CTAS_PER_ROW * - sizeof(typename Kernel_traits::Stats::stats_t) * 2; - } - return; - } - - if (Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024) { - NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - Kernel_traits::SMEM_BYTES_FWD)); - } - auto stream = launch_params.stream; - auto ctas_per_col = launch_params.params.ctas_per_col; - auto ctas_per_row = launch_params.params.ctas_per_row; - - if (ctas_per_row == 1) { - kernel<<>>( - launch_params.params); - } else { - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); - } -} - -template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) - using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_fwd_general_kernel; - auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; - - // Configure kernel params - const int rows = launch_params.params.rows; - const int cols = launch_params.params.cols; - int ctas_per_col = launch_params.params.ctas_per_col; - int ctas_per_row = launch_params.params.ctas_per_row; - if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); - const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; - ctas_per_row = ceil_div(cols, HIDDEN_SIZE); - ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); - launch_params.params.ctas_per_row = ctas_per_row; - launch_params.params.ctas_per_col = ctas_per_col; - - launch_params.barrier_size = 0; - launch_params.workspace_bytes = 0; - if (launch_params.params.ctas_per_row > 1) { - launch_params.barrier_size = 2 * ctas_per_col; - launch_params.workspace_bytes = - (ctas_per_col * WARPS_M * ctas_per_row * sizeof(compute_t) * 2); - } - return; - } - - // Launch kernel - auto stream = launch_params.stream; - dim3 grid(ctas_per_row * ctas_per_col); - dim3 block(Kernel_traits::THREADS_PER_CTA); - if (ctas_per_row == 1) { - kernel<<>>(launch_params.params); - } else { - void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define REGISTER_FWD_TUNED_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, \ - WARPS_M, WARPS_N, BYTES_PER_LDG) \ - void rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_tuned_(launch_params, configure_params); \ - } \ - static FwdTunedRegistrar \ - reg_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_fwd_tuned_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -#define REGISTER_FWD_GENERAL_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, \ - BYTES_PER_LDG) \ - void rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - LaunchParams &launch_params, const bool configure_params) { \ - launch_general_(launch_params, configure_params); \ - } \ - static FwdGeneralRegistrar \ - reg_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ - rmsnorm_fwd_general_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Create rmsnorm tuned launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_TUNED_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(512, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(768, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp32, fp32, fp32, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, fp16, fp16, fp16, fp32, 1, 4, 1, 16); -REGISTER_FWD_TUNED_LAUNCHER(2048, bf16, bf16, bf16, fp32, 1, 4, 1, 16); - -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16); - -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16); -REGISTER_FWD_TUNED_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16); - -// Create rmsnorm general launch function and register. Macro signature: -// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, WARPS_M, WARPS_N, BYTES_PER_LDG - -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp8e4m3, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp8e4m3, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp8e4m3, fp32, 1, 4, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp16, fp16, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, fp16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, bf16, bf16, bf16, fp32, 4, 1, 8); -REGISTER_FWD_GENERAL_LAUNCHER(128, fp32, fp32, bf16, fp32, 4, 1, 8); - -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(512, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(1024, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp32, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp16, fp16, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, fp16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, bf16, bf16, bf16, fp32, 4, 1, 16); -REGISTER_FWD_GENERAL_LAUNCHER(2048, fp32, fp32, bf16, fp32, 4, 1, 16); - -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp32, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp16, fp16, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, fp16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, bf16, bf16, bf16, fp32, 1, 4, 16); -REGISTER_FWD_GENERAL_LAUNCHER(8192, fp32, fp32, bf16, fp32, 1, 4, 16); diff --git a/transformer_engine/common/rmsnorm/rmsnorm_kernel_traits.h b/transformer_engine/common/rmsnorm/rmsnorm_kernel_traits.h deleted file mode 100644 index 26d7da6400..0000000000 --- a/transformer_engine/common/rmsnorm/rmsnorm_kernel_traits.h +++ /dev/null @@ -1,42 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_ -#define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_ - -#include "../common.h" -#include "../layer_norm/ln_kernel_traits.h" -#include "../utils.cuh" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace transformer_engine { -namespace rmsnorm { - -template < - uint32_t HIDDEN_SIZE_, typename weight_t_, typename input_t_, typename output_t_, - typename compute_t_, typename index_t_, uint32_t THREADS_PER_CTA_, uint32_t BYTES_PER_LDG_, - typename Base = - layer_norm::Kernel_traits_finalize > -struct Kernel_traits_finalize : public Base {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template > -struct Kernel_traits : public Base {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace rmsnorm -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_KERNEL_TRAITS_H_ diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu new file mode 100644 index 0000000000..a0fffc783c --- /dev/null +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -0,0 +1,338 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "transformer_engine/transformer_engine.h" + +namespace { + +constexpr int TB_DIM = 32; +constexpr int NEW_SF_TILE_DIM_K = 16; +constexpr int N_SF_PER_TD_PER_TILE = 4; + +// output is in ~K-major interleaved blocks +constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; +constexpr int NEW_SF_TILE_DIM_M_I32 = 32; + +template +__device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { + // inp, 4-byte chunks [0,1,2,3, 4,5,6,7, 8,9,10,11, 12,13,14,15] + // out, swapping byte to form new 4-byte chunks [0,4,8,12, 1,5,9,13, 2,6,10,14, 3,7,11,15] + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t new_regs[kVectorSize]; + int32_t* regs = reinterpret_cast(regs_vec); + +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { +#pragma unroll + for (int j = 0; j < N_SF_PER_TD_PER_TILE; j++) { + new_regs[i * N_SF_PER_TD_PER_TILE + j] = + (((regs[i + 0 * N_TILE_PER_TD] >> 8 * j) & 0xFF)) | + (((regs[i + 1 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 8) | + (((regs[i + 2 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 16) | + (((regs[i + 3 * N_TILE_PER_TD] >> 8 * j) & 0xFF) << 24); + } + } +#pragma unroll + for (int i = 0; i < kVectorSize; i++) regs[i] = new_regs[i]; +} + +template +__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, + const int K) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + + // input is in M-major + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M / 4; + constexpr int SF_TILE_DIM_K_I32 = SF_TILE_DIM_K; + + const int M_i32 = M / 4; + const int K_i32 = K; + + int m_tiles_in_tb = N_TILE_PER_TD; + int k_tiles_in_tb = TB_DIM; + if (blockIdx.x == gridDim.x - 1) { + k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; + } + if (blockIdx.y == gridDim.y - 1) { + m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; + } + + const int32_t* input_i32 = reinterpret_cast(input) + + blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + + blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + int32_t* output_i32[N_TILE_PER_TD]; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + output_i32[i] = reinterpret_cast(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 + + (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; + } + extern __shared__ int slm[]; + + // load, global -> regs + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < m_tiles_in_tb * SF_TILE_DIM_M_I32 && + threadIdx.y < k_tiles_in_tb) { +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + regs_vec[i] = __ldg(reinterpret_cast( + input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD)); + } + + // local shuffle + regs_shuffle_with_bit_shifts(regs_vec); + + // store, regs -> shared + int tM = threadIdx.x * N_SF_PER_TD; + int* slm_tile = slm + (threadIdx.y * SF_TILE_SIZE_I32 + + tM / SF_TILE_DIM_M * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int i = 0; i < N_SF_PER_TD; i++) { + /* TODO rotate_i */ + slm_tile[(tM % SF_TILE_DIM_M) / NEW_SF_TILE_DIM_M_I32 + + ((tM + i) % NEW_SF_TILE_DIM_M_I32) * NEW_SF_TILE_DIM_K_I32] = + reinterpret_cast(regs_vec)[i]; + } + } + __syncthreads(); + + // store, shared -> global + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; +#pragma unroll + for (int i = 0; i < m_tiles_in_tb; i++) { + __align__(16) int4* output_v4i = reinterpret_cast(output_i32[i]); + __align__(16) int4* slm_v4i = + reinterpret_cast(slm + i * k_tiles_in_tb * SF_TILE_SIZE_I32); +#pragma unroll + for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4; + j += blockDim.x * blockDim.y) { + output_v4i[j] = slm_v4i[j]; + } + } +} + +template +__device__ inline void regs_shuffle(LType* regs_vec) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + if constexpr (N_TILE_PER_TD == 1) return; + + constexpr int kVectorSize = N_SF_PER_TD_PER_TILE * N_TILE_PER_TD; + int32_t tmp[kVectorSize]; + int32_t* ptr = reinterpret_cast(regs_vec); +#pragma unroll + for (int i = 0; i < kVectorSize; i++) + tmp[i % N_TILE_PER_TD * N_SF_PER_TD_PER_TILE + i / N_TILE_PER_TD] = ptr[i]; + +#pragma unroll + for (int i = 0; i < kVectorSize; i++) ptr[i] = tmp[i]; +} + +template +__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, + const int K) { + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + // input is in K-major + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + constexpr int SF_TILE_DIM_M_I32 = SF_TILE_DIM_M; + + int n_tiles_in_tb = N_TILES_IN_TB; + const int K_i32 = K / 4; + if (blockIdx.x == gridDim.x - 1) { + n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; + } + + const int* input_i32 = reinterpret_cast(input) + + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB; + int* output_i32 = reinterpret_cast(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + + blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + + extern __shared__ int4 slm_v4i[]; + + // load, global -> regs + LType regs_vec[N_SF_PER_TD_PER_TILE]; + if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { +#pragma unroll + for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { + regs_vec[i] = __ldg(reinterpret_cast( + input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD)); + } + + // shuffle regs + regs_shuffle(regs_vec); + +// store, regs -> shared +#pragma unroll + for (int i = 0; i < N_TILE_PER_TD; i++) { + /* TODO rotate i */ + slm_v4i[(threadIdx.x * N_TILE_PER_TD + i) * SF_TILE_SIZE_I32 / 4 + threadIdx.y] = + reinterpret_cast(regs_vec)[i]; + } + } + __syncthreads(); + + // store, shared -> global + int linear_id = threadIdx.y * blockDim.x + threadIdx.x; + __align__(16) int4* output_v4i = reinterpret_cast(output_i32); +#pragma unroll + for (int i = linear_id; i < SF_TILE_SIZE_I32 * n_tiles_in_tb / 4; i += blockDim.x * blockDim.y) { + output_v4i[i] = slm_v4i[i]; + } +} + +} // namespace + +namespace transformer_engine { + +void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { + if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { + NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + "."); + } + + // Do nothing if tensor is empty + if (input->data.numel() == 0) { + return; + } + + CheckInputTensor(*input, "scaling_factor_input"); + CheckInputTensor(*output, "scaling_factor_output"); + + auto& scaling_mode = input->scaling_mode; + + // 1D block scaling, row-wise or colum-wise + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + const int m = + input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; + const int k = + input->has_data() ? input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0]; + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + if (output->has_data()) { + NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), + output->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), + output->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } + + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + + dim3 block_size(TB_DIM, TB_DIM); + if (input->has_data()) { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + /* there is no int3 and misaligned if using int4/int2 */ + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + case 2: + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + case 1: + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_row_scaling_kernel + <<>>(input->scale_inv.dptr, + output->scale_inv.dptr, m, k); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + if (input->has_columnwise_data()) { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + case 2: + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + case 1: + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + swizzle_col_scaling_kernel + <<>>( + input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + + // 2D block scaling + } else { + NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); + } + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("CUDA Error: %s\n", cudaGetErrorString(err)); + exit(-1); + } +} +} // namespace transformer_engine + +/* + * WIP (Phuong): + * - Opt for bank conflicts + * - Adding swizzle for 2d-block scaling. +*/ +void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_scaling_factors); + using namespace transformer_engine; + swizzle_scaling_factors(reinterpret_cast(input), reinterpret_cast(output), + stream); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1a3b49f9fa..1f8bfca2c9 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1,76 +1,201 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include +#include +#include + #include "common.h" namespace transformer_engine { -size_t typeToSize(const transformer_engine::DType type) { +size_t typeToSize(const DType type) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, return TypeInfo::size;); // NOLINT(*) } -bool is_fp8_dtype(const transformer_engine::DType t) { - return t == transformer_engine::DType::kFloat8E4M3 || t == transformer_engine::DType::kFloat8E5M2; +bool is_fp8_dtype(const DType t) { return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2; } + +std::string to_string(const DType type) { + switch (type) { + case DType::kByte: + return "Byte"; + case DType::kBFloat16: + return "BFloat16"; + case DType::kFloat16: + return "Float16"; + case DType::kFloat32: + return "Float32"; + case DType::kFloat8E4M3: + return "Float8E4M3"; + case DType::kFloat8E5M2: + return "Float8E5M2"; + case DType::kFloat8E8M0: + return "Float8E8M0"; + case DType::kInt32: + return "Int32"; + case DType::kInt64: + return "Int64"; + default: + return concat_strings("Invalid type ", static_cast(type)); + } +} + +std::string to_string(const NVTEScalingMode &mode) { + switch (mode) { + case NVTE_DELAYED_TENSOR_SCALING: + return "Delayed Tensor Scaling"; + case NVTE_MXFP8_1D_SCALING: + return "MXFP8 1D Scaling"; + case NVTE_INVALID_SCALING: + return "Invalid Scaling"; + } + return "Invalid Scaling"; +} + +void CheckNoopTensor(const Tensor &t, const std::string &name) { + if (t.data.dptr != nullptr) { + NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(), + "."); + NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name, + " noop. Expected kFloat32."); + } +} + +void CheckScaleTensorShape(const Tensor &t, const std::string &name) { + NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!"); + if (is_tensor_scaling(t.scaling_mode)) { + // per-tensor scaling + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name, + "\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected (1), got ", + t.columnwise_scale_inv.shape, ")"); + } + } else { + if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { + // Need (4, 128) alignment even for e8 scaling factor + auto block_alignment = std::vector{128ul, 4ul}; + size_t expected_x, expected_y, alignment; + + if (t.has_data()) { + alignment = block_alignment[0]; + expected_x = + DIVUP(DIVUP(t.flat_first_dim(), static_cast(1)), alignment) * alignment; + alignment = block_alignment[1]; + expected_y = + DIVUP(DIVUP(t.flat_last_dim(), static_cast(32)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; + NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid scale_inv shape (expected ", expected, ", got ", + t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + alignment = block_alignment[1]; + expected_x = + DIVUP(DIVUP(t.flat_first_dim(), static_cast(32)), alignment) * alignment; + alignment = block_alignment[0]; + expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; + NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", + t.columnwise_scale_inv.shape, ")"); + } + } + } } void CheckInputTensor(const Tensor &t, const std::string &name) { - const DType type = t.data.dtype; + const DType type = t.dtype(); if (is_fp8_dtype(type)) { // FP8 input needs to have scale_inv - NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 input " + name + " must have inverse of scale."); - NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); - NVTE_CHECK(t.scale_inv.shape == std::vector{1}); + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor input ", name, + "_scale_inverse has invalid dtype " + "(expected Float32 or Byte, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || + t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor input ", name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float32 or Byte, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { - NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input " + name + "."); - NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input " + name + "."); - NVTE_CHECK(t.scale_inv.dptr == nullptr, - "Scale_inv is not supported for non-FP8 input " + name + "."); + NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name); + NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name); + NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); + NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, + "Scale_inv is not supported for non-FP8 input ", name); } - NVTE_CHECK(t.data.dptr != nullptr, "Input " + name + " is not allocated!"); + NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!"); + + CheckScaleTensorShape(t, name); } void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) { - const DType type = t.data.dtype; + const DType type = t.dtype(); if (is_fp8_dtype(type)) { - // FP8 output needs to have scale, amax and scale_inv - NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output " + name + " must have amax tensor."); - NVTE_CHECK(t.amax.dtype == DType::kFloat32); - NVTE_CHECK(t.amax.shape == std::vector{1}); - NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 output " + name + " must have scale."); - NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32); - NVTE_CHECK(t.scale_inv.shape == std::vector{1}); - NVTE_CHECK(t.scale.dptr != nullptr, "FP8 output " + name + " must have inverse of scale."); - NVTE_CHECK(t.scale.dtype == DType::kFloat32); - NVTE_CHECK(t.scale.shape == std::vector{1}); + // FP8 output needs to have scale, scale_inv and (if delayed scaling) amax + if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) { + NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ", + to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")"); + NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name, + " (expected 1 entry, got shape=", t.amax.shape, ")"); + } + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor output ", name, + "_scale_inverse has invalid dtype " + "(expected Float32 or Float8E8M0, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 || + t.columnwise_scale_inv.dtype == DType::kFloat8E8M0, + "FP8 scaling factor output ", name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float32 or Float8E8M0, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { - NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output " + name + "."); - NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output " + name + "."); - NVTE_CHECK(t.scale_inv.dptr == nullptr, - "Scale_inv is not supported for non-FP8 output " + name + "."); + NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); + NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); + NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); + NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, + "Scale_inv is not supported for non-FP8 input ", name); } if (!allow_empty) { - NVTE_CHECK(t.data.dptr != nullptr, "Output " + name + " is not allocated!"); + NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output ", name, " is not allocated!"); } + + CheckScaleTensorShape(t, name); } } // namespace transformer_engine -NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, float *amax, - float *scale, float *scale_inv) { +NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) { transformer_engine::Tensor *ret = new transformer_engine::Tensor; - ret->data.dptr = dptr; - ret->data.shape = std::vector(shape.data, shape.data + shape.ndim); - ret->data.dtype = static_cast(dtype); - ret->amax.dptr = amax; - ret->scale.dptr = scale; - ret->scale_inv.dptr = scale_inv; + ret->scaling_mode = scaling_mode; return ret; } @@ -81,49 +206,109 @@ void nvte_destroy_tensor(NVTETensor tensor) { } NVTEDType nvte_tensor_type(const NVTETensor tensor) { + if (tensor == nullptr) return kNVTEFloat32; return static_cast( - reinterpret_cast(tensor)->data.dtype); + reinterpret_cast(tensor)->dtype()); } NVTEShape nvte_tensor_shape(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); + if (tensor == nullptr) { + NVTE_ERROR("Invalid tensor"); + } NVTEShape ret; - ret.data = t.data.shape.data(); - ret.ndim = t.data.shape.size(); + + // Determine tensor shape depending on tensor format + const auto &t = *reinterpret_cast(tensor); + switch (t.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!t.has_data() && t.has_columnwise_data()) { + // We can infer tensor shape if FP8 tensor only has FP8 data + // transpose. However, NVTEShape only contains a pointer and + // cannot store temporary data. We hack around this by caching + // the tensor shape within the empty FP8 data. + auto &shape_cache = const_cast &>(t.data.shape); + shape_cache.clear(); + if (!t.columnwise_data.shape.empty()) { + for (size_t i = 1; i < t.columnwise_data.shape.size(); i++) { + shape_cache.push_back(t.columnwise_data.shape[i]); + } + shape_cache.push_back(t.columnwise_data.shape.front()); + } + ret.data = shape_cache.data(); + ret.ndim = shape_cache.size(); + } else { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + if (!t.has_data() && t.has_columnwise_data()) { + ret.data = t.columnwise_data.shape.data(); + ret.ndim = t.columnwise_data.shape.size(); + } else { + ret.data = t.data.shape.data(); + ret.ndim = t.data.shape.size(); + } + break; + } + default: + NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", + transformer_engine::to_string(t.scaling_mode), "\""); + } + return ret; } -size_t nvte_tensor_ndim(const NVTETensor tensor) { +NVTEShape nvte_tensor_columnwise_shape(const NVTETensor tensor) { + if (tensor == nullptr) { + NVTE_ERROR("Invalid tensor"); + } const auto &t = *reinterpret_cast(tensor); - return t.data.shape.size(); + NVTEShape ret; + ret.data = t.columnwise_data.shape.data(); + ret.ndim = t.columnwise_data.shape.size(); + return ret; } +size_t nvte_tensor_ndims(const NVTETensor tensor) { return nvte_tensor_shape(tensor).ndim; } + size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) { - const auto &t = *reinterpret_cast(tensor); - NVTE_CHECK(dim >= 0 && dim < t.data.shape.size(), "Invalid dimension index: ", dim); - return t.data.shape[dim]; + const auto &shape = nvte_tensor_shape(tensor); + NVTE_CHECK(0 <= dim && dim < shape.ndim, "Attempted to access index ", dim, + " in a shape array with ", shape.ndim, " entries"); + return shape.data[dim]; } size_t nvte_tensor_numel(const NVTETensor tensor) { - const auto &t = *reinterpret_cast(tensor); + const auto &shape = nvte_tensor_shape(tensor); size_t numel = 1; - for (auto size : t.data.shape) { - numel *= size; + for (size_t i = 0; i < shape.ndim; i++) { + numel *= shape.data[i]; } return numel; } size_t nvte_tensor_element_size(const NVTETensor tensor) { + if (tensor == nullptr) return sizeof(float); const auto &t = *reinterpret_cast(tensor); return transformer_engine::typeToSize(t.data.dtype); } void *nvte_tensor_data(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); return t.data.dptr; } +void *nvte_tensor_columnwise_data(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + return t.columnwise_data.dptr; +} + float *nvte_tensor_amax(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32, "Tensor's amax must have Float32 type!"); @@ -131,6 +316,7 @@ float *nvte_tensor_amax(const NVTETensor tensor) { } float *nvte_tensor_scale(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32, "Tensor's scale must have Float32 type!"); @@ -138,12 +324,83 @@ float *nvte_tensor_scale(const NVTETensor tensor) { } float *nvte_tensor_scale_inv(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; const auto &t = *reinterpret_cast(tensor); - NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32, - "Tensor's inverse of scale must have Float32 type!"); return reinterpret_cast(t.scale_inv.dptr); } +void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { + if (tensor == nullptr) return nullptr; + const auto &t = *reinterpret_cast(tensor); + return t.columnwise_scale_inv.dptr; +} + +NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { + if (tensor == nullptr) return {nullptr, 0}; + const auto &t = *reinterpret_cast(tensor); + NVTEShape ret; + ret.data = t.scale_inv.shape.data(); + ret.ndim = t.scale_inv.shape.size(); + return ret; +} + +void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, + const NVTEBasicTensor *param) { + NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL."); + NVTE_CHECK(*tensor != nullptr, "Tensor is not allocated."); + auto &t = *reinterpret_cast(*tensor); + switch (param_name) { + case kNVTERowwiseData: + t.data = *param; + break; + case kNVTEColumnwiseData: + t.columnwise_data = *param; + break; + case kNVTEScale: + t.scale = *param; + break; + case kNVTEAmax: + t.amax = *param; + break; + case kNVTERowwiseScaleInv: + t.scale_inv = *param; + break; + case kNVTEColumnwiseScaleInv: + t.columnwise_scale_inv = *param; + break; + default: + NVTE_ERROR("Unknown tensor parameter!"); + } +} + +NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) { + if (tensor == nullptr) { + return {nullptr, kNVTEFloat32, {nullptr, 0}}; + } + const auto &t = *reinterpret_cast(tensor); + switch (param_name) { + case kNVTERowwiseData: + return t.data; + case kNVTEColumnwiseData: + return t.columnwise_data; + case kNVTEScale: + return t.scale; + case kNVTEAmax: + return t.amax; + case kNVTERowwiseScaleInv: + return t.scale_inv; + case kNVTEColumnwiseScaleInv: + return t.columnwise_scale_inv; + default: + NVTE_ERROR("Unknown tensor parameter!"); + } +} + +NVTEScalingMode nvte_tensor_scaling_mode(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.scaling_mode; +} + void nvte_tensor_pack_create(NVTETensorPack *pack) { for (int i = 0; i < pack->MAX_SIZE; i++) { pack->tensors[i] = reinterpret_cast(new transformer_engine::Tensor); @@ -156,3 +413,92 @@ void nvte_tensor_pack_destroy(NVTETensorPack *pack) { delete t; } } + +void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { + const auto &t = *reinterpret_cast(tensor); + // Zero out tensor data if allocated + if (t.data.dptr != nullptr) { + size_t size_in_bytes = nvte_tensor_element_size(tensor) * nvte_tensor_numel(tensor); + cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); + } + // Set amax to 0 if allocated + if (t.amax.dptr != nullptr) { + cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream); + } +} + +NVTEQuantizationConfig nvte_create_quantization_config() { + return new transformer_engine::QuantizationConfig; +} + +void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, void *buf, + size_t size_in_bytes, size_t *size_written) { + // Write attribute size + NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, + "Invalid NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); + const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; + *size_written = attr_size; + + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for quantization config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + + // Write to buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)"); + const auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEQuantizationConfigForcePow2Scales: + std::memcpy(buf, &config_.force_pow_2_scales, attr_size); + break; + case kNVTEQuantizationConfigAmaxEpsilon: + std::memcpy(buf, &config_.amax_epsilon, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, + NVTEQuantizationConfigAttribute attr, const void *buf, + size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes, + "Invalid NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for quantization config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // Read from buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)"); + auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEQuantizationConfigForcePow2Scales: + std::memcpy(&config_.force_pow_2_scales, buf, attr_size); + break; + case kNVTEQuantizationConfigAmaxEpsilon: + std::memcpy(&config_.amax_epsilon, buf, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { + if (config != nullptr) { + delete reinterpret_cast(config); + } +} diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index dd45d0a668..7f3b9fb302 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -10,12 +10,12 @@ #include -#include "../common.h" #include "../util/rtc.h" #include "../util/string.h" #include "../utils.cuh" +#include "cast_transpose.h" -namespace transformer_engine { +namespace transformer_engine::detail { namespace { @@ -217,159 +217,146 @@ __global__ void __launch_bounds__(block_size) cast_transpose_general_kernel( } // namespace -void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *cast_output_, - Tensor *transposed_output_, cudaStream_t stream) { - Tensor &cast_output = *cast_output_; - Tensor &transposed_output = *transposed_output_; +void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream) { + Tensor &output = *output_; - // Check no-op flag - if (noop.data.dptr != nullptr) { - size_t numel = 1; - for (const auto &dim : noop.data.shape) { - numel *= dim; - } - NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, "."); - NVTE_CHECK(noop.data.dtype == DType::kFloat32); - NVTE_CHECK(noop.data.dptr != nullptr); - } - - // Check tensor dims + CheckNoopTensor(noop, "cast_transpose_noop"); CheckInputTensor(input, "cast_transpose_input"); - CheckOutputTensor(cast_output, "cast_output"); - CheckOutputTensor(transposed_output, "transposed_output"); - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions."); - NVTE_CHECK(transposed_output.data.shape.size() == 2, "Transposed output must have 2 dimensions."); - const size_t row_length = input.data.shape[1]; - const size_t num_rows = input.data.shape[0]; - NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output."); - NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output."); - NVTE_CHECK(transposed_output.data.shape[0] == row_length, - "Wrong dimension of transposed output."); - NVTE_CHECK(transposed_output.data.shape[1] == num_rows, "Wrong dimension of transposed output."); - - // Check tensor pointers - NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); - NVTE_CHECK(cast_output.data.dptr != nullptr, "Cast output is not allocated."); - NVTE_CHECK(transposed_output.data.dptr != nullptr, "Transposed output is not allocated."); - NVTE_CHECK(cast_output.data.dtype == transposed_output.data.dtype, + CheckOutputTensor(output, "cast_transpose_output"); + + // Check that inputs and outputs are available + NVTE_CHECK(input.has_data(), "Input is not allocated"); + NVTE_CHECK(output.has_data(), "Output rowwise data is not allocated"); + NVTE_CHECK(output.has_columnwise_data(), "Output columnwise is not allocated"); + + // Flatten tensor to 2D + NVTE_CHECK(input.data.shape == output.data.shape, + "Input and output shapes do not match (input=", input.data.shape, + ", output=", output.data.shape); + const size_t row_length = input.flat_last_dim(); + const size_t num_rows = input.flat_first_dim(); + NVTE_CHECK(output.flat_first_dim() == num_rows && output.flat_last_dim() == row_length, + "Invalid output dimensions (expected ", std::vector{num_rows, row_length}, + ", got ", std::vector{output.flat_first_dim(), output.flat_last_dim()}, ")"); + + // Check that cast and transposed output data matches + NVTE_CHECK(output.data.dtype == output.columnwise_data.dtype, "Cast and transposed output types must match."); - NVTE_CHECK(cast_output.amax.dptr == transposed_output.amax.dptr, - "Cast and transposed outputs need to share amax tensor."); - NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr, - "Cast and transposed outputs need to share scale tensor."); - NVTE_CHECK(cast_output.scale_inv.dptr == transposed_output.scale_inv.dptr, + NVTE_CHECK(output.scale_inv.dptr == output.columnwise_scale_inv.dptr, "Cast and transposed outputs need to share scale-inverse tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, InputType, + input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - cast_output.data.dtype, OutputType, - constexpr const char *itype_name = TypeInfo::name; - constexpr const char *otype_name = TypeInfo::name; - constexpr size_t itype_size = sizeof(InputType); - constexpr size_t otype_size = sizeof(OutputType); - - // Choose between runtime-compiled or statically-compiled kernel - const bool aligned = - (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); - if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel - // Pick kernel config - std::vector kernel_configs; - kernel_configs.reserve(16); - const size_t sm_count = static_cast(cuda::sm_count()); - auto add_config = [&](size_t load_size, size_t store_size) { - kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size, - store_size, sm_count); - }; - add_config(8, 8); - add_config(4, 8); - add_config(8, 4); - add_config(4, 4); - add_config(2, 8); - add_config(8, 2); - add_config(2, 4); - add_config(4, 2); - add_config(2, 2); - add_config(1, 8); - add_config(8, 1); - add_config(1, 4); - add_config(4, 1); - add_config(1, 2); - add_config(2, 1); - add_config(1, 1); - const auto &kernel_config = - *std::min_element(kernel_configs.begin(), kernel_configs.end()); - NVTE_CHECK(kernel_config.valid, "invalid kernel config"); - const size_t load_size = kernel_config.load_size; - const size_t store_size = kernel_config.store_size; - const size_t num_blocks = kernel_config.num_blocks; - - // Compile NVRTC kernel if needed and launch - auto &rtc_manager = rtc::KernelManager::instance(); - const std::string kernel_label = concat_strings( - "cast_transpose" - ",itype=", - itype_name, ",otype=", otype_name, ",load_size=", load_size, - ",store_size=", store_size); - if (!rtc_manager.is_compiled(kernel_label)) { - std::string code = string_code_transpose_rtc_cast_transpose_cu; - code = regex_replace(code, "__ITYPE__", itype_name); - code = regex_replace(code, "__OTYPE__", otype_name); - code = regex_replace(code, "__LOAD_SIZE__", load_size); - code = regex_replace(code, "__STORE_SIZE__", store_size); - code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); - code = regex_replace(code, "__BLOCK_SIZE__", block_size); - rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code, - "transformer_engine/common/transpose/rtc/cast_transpose.cu"); + output.dtype(), OutputType, + if (is_tensor_scaling(output.scaling_mode)) { + // delayed scaling and current scaling are two variants of per-tensor scaling + + constexpr const char *itype_name = TypeInfo::name; + constexpr const char *otype_name = TypeInfo::name; + constexpr size_t itype_size = sizeof(InputType); + constexpr size_t otype_size = sizeof(OutputType); + + // Choose between runtime-compiled or statically-compiled kernel + const bool aligned = + (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); + if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + const size_t sm_count = static_cast(cuda::sm_count()); + auto add_config = [&](size_t load_size, size_t store_size) { + kernel_configs.emplace_back(row_length, num_rows, itype_size, otype_size, load_size, + store_size, sm_count); + }; + add_config(8, 8); + add_config(4, 8); + add_config(8, 4); + add_config(4, 4); + add_config(2, 8); + add_config(8, 2); + add_config(2, 4); + add_config(4, 2); + add_config(2, 2); + add_config(1, 8); + add_config(8, 1); + add_config(1, 4); + add_config(4, 1); + add_config(1, 2); + add_config(2, 1); + add_config(1, 1); + const auto &kernel_config = + *std::min_element(kernel_configs.begin(), kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + const size_t load_size = kernel_config.load_size; + const size_t store_size = kernel_config.store_size; + const size_t num_blocks = kernel_config.num_blocks; + + // Compile NVRTC kernel if needed and launch + auto &rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = concat_strings( + "cast_transpose" + ",itype=", + itype_name, ",otype=", otype_name, ",load_size=", load_size, + ",store_size=", store_size); + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_cast_transpose_cu; + code = regex_replace(code, "__ITYPE__", itype_name); + code = regex_replace(code, "__OTYPE__", otype_name); + code = regex_replace(code, "__LOAD_SIZE__", load_size); + code = regex_replace(code, "__STORE_SIZE__", store_size); + code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); + code = regex_replace(code, "__BLOCK_SIZE__", block_size); + rtc_manager.compile(kernel_label, "cast_transpose_optimized_kernel", code, + "transformer_engine/common/transpose/rtc/cast_transpose.cu"); + } + rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream, + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(output.data.dptr), + static_cast(output.columnwise_data.dptr), + static_cast(output.scale.dptr), + static_cast(output.amax.dptr), + static_cast(output.scale_inv.dptr), row_length, num_rows); + } else { // Statically-compiled general kernel + constexpr size_t load_size = 4; + constexpr size_t store_size = 4; + constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; + constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; + const int num_blocks = + (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); + + cast_transpose_general_kernel + <<>>( + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(output.data.dptr), + static_cast(output.columnwise_data.dptr), + static_cast(output.scale.dptr), + static_cast(output.amax.dptr), + static_cast(output.scale_inv.dptr), row_length, num_rows); } - rtc_manager.launch(kernel_label, num_blocks, block_size, 0, stream, - static_cast(input.data.dptr), - reinterpret_cast(noop.data.dptr), - static_cast(cast_output.data.dptr), - static_cast(transposed_output.data.dptr), - static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), - static_cast(cast_output.scale_inv.dptr), row_length, - num_rows); - } else { // Statically-compiled general kernel - constexpr size_t load_size = 4; - constexpr size_t store_size = 4; - constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; - constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; - const int num_blocks = - (DIVUP(row_length, row_tile_size) * DIVUP(num_rows, col_tile_size)); - cast_transpose_general_kernel - <<>>( - static_cast(input.data.dptr), - reinterpret_cast(noop.data.dptr), - static_cast(cast_output.data.dptr), - static_cast(transposed_output.data.dptr), - static_cast(cast_output.scale.dptr), - static_cast(cast_output.amax.dptr), - static_cast(cast_output.scale_inv.dptr), row_length, num_rows); + } else { + NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode)); }); // NOLINT(*) ); // NOLINT(*) } -} // namespace transformer_engine +} // namespace transformer_engine::detail -void nvte_cast_transpose(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, cudaStream_t stream) { +void nvte_cast_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose); using namespace transformer_engine; auto noop = Tensor(); - cast_transpose(*reinterpret_cast(input), noop, - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), stream); + transformer_engine::detail::cast_transpose(*reinterpret_cast(input), noop, + reinterpret_cast(output), stream); } -void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, - NVTETensor cast_output, NVTETensor transposed_output, +void nvte_cast_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_with_noop); using namespace transformer_engine; - cast_transpose(*reinterpret_cast(input), *reinterpret_cast(noop), - reinterpret_cast(cast_output), - reinterpret_cast(transposed_output), stream); + transformer_engine::detail::cast_transpose(*reinterpret_cast(input), + *reinterpret_cast(noop), + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h new file mode 100644 index 0000000000..ed9bd5f5f7 --- /dev/null +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -0,0 +1,28 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ +#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ + +#include "../common.h" + +namespace transformer_engine::detail { + +void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream); + +template +void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream); + +template +void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output, + cudaStream_t stream); + +} // namespace transformer_engine::detail + +#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index a8361d57ea..8347e117ce 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -8,18 +8,19 @@ #include #include -#include +#include +#include #include -#include "../common.h" #include "../util/math.h" #include "../util/rtc.h" #include "../util/string.h" #include "../utils.cuh" +#include "cast_transpose.h" namespace transformer_engine { -namespace { +namespace detail { // String with RTC kernel implementation #include "string_code_transpose_rtc_cast_transpose_fusion_cu.h" @@ -177,16 +178,31 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out], void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /*cast*/ Tensor *workspace, const int nvec_out) { - const size_t row_length = cast_output.data.shape[1]; - const size_t num_rows = cast_output.data.shape[0]; + const size_t row_length = cast_output.flat_last_dim(); + const size_t num_rows = cast_output.flat_first_dim(); const size_t tile_size_y = (nvec_out * THREADS_PER_WARP); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); - workspace->data.shape = {num_rows_partial_dbias, row_length}; - workspace->data.dtype = DType::kFloat32; + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {num_rows_partial_dbias, row_length}; + workspace->data.dtype = DType::kFloat32; + } else { + // Check that workspace matches expected size + const size_t workspace_size = + std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, + std::multiplies()) * + typeToSize(workspace->data.dtype); + const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", + num_rows_partial_dbias, ",", row_length, "), found ())"); + NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", + num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), + "; found dims=", workspace->data.shape, + ", dtype=", typeToSize(workspace->data.dtype), ")"); + } } template @@ -248,11 +264,13 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt reduce_dbias_num_rows); } -template +template __global__ void __launch_bounds__(cast_transpose_num_threads) cast_transpose_fused_kernel_notaligned(const Param param, const size_t row_length, const size_t num_rows, const size_t num_tiles) { + static_assert(!(IS_DACT && IS_ACT), "forward and backward activation are mutually exclusive"); using IType = typename Param::InputType; using IType2 = typename Param::InputType2; using OType = typename Param::OutputType; @@ -373,6 +391,8 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) if constexpr (IS_DACT) { after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]) * OP(act_in[current_in ^ 1][j].data.elt[k], {}); + } else if constexpr (IS_ACT) { + after_dact[j].data.elt[k] = OP(in[current_in ^ 1][j].data.elt[k], {}); } else { after_dact[j].data.elt[k] = CType(in[current_in ^ 1][j].data.elt[k]); } @@ -449,78 +469,96 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) } static const char *ActTypeToString[] = { - "NoAct", // 0 - "Sigmoid", // 1 - "GeLU", // 2 - "QGeLU", // 3 - "SiLU", // 4 - "ReLU", // 5 - "SReLU" // 6 + "none", // 0 + "sigmoid", // 1 + "dsigmoid", // 2 + "gelu", // 3 + "dgelu", // 4 + "qgelu", // 5 + "dqgelu", // 6 + "silu", // 7 + "dsilu", // 8 + "relu", // 9 + "drelu", // 10 + "srelu", // 11 + "dsrelu" // 12 }; template -int get_dactivation_type() { - if (OP == &sigmoid) { - return 1; - } else if (OP == &dgelu) { - return 2; - } else if (OP == &dqgelu) { - return 3; - } else if (OP == &dsilu) { - return 4; - } else if (OP == &drelu) { - return 5; - } else if (OP == &dsrelu) { - return 6; - } else { - return 0; +constexpr int get_activation_type() { + constexpr decltype(OP) ActivationList[] = { + nullptr, // 0 + &sigmoid, // 1 + &dsigmoid, // 2 + &gelu, // 3 + &dgelu, // 4 + &qgelu, // 5 + &dqgelu, // 6 + &silu, // 7 + &dsilu, // 8 + &relu, // 9 + &drelu, // 10 + &srelu, // 11 + &dsrelu // 12 + }; +#pragma unroll + for (int i = 0; i < sizeof(ActivationList) / sizeof(ActivationList[0]); ++i) { + if (OP == ActivationList[i]) { + return i; + } } + return 0; } -template -void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor *cast_output, - Tensor *transposed_output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - if (workspace->data.dptr != nullptr) { +void cast_transpose_fused(const Tensor &input, const Tensor *act_input, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + // Check tensors, unless querying dbias workspace + if (!IS_DBIAS || workspace->data.dptr != nullptr) { CheckInputTensor(input, "cast_transpose_fused_input"); - CheckOutputTensor(*cast_output, "cast_output"); - CheckOutputTensor(*transposed_output, "transposed_output"); - if constexpr (IS_DBIAS) CheckOutputTensor(*dbias, "dbias"); - if constexpr (IS_DACT) CheckInputTensor(act_input, "act_input"); + CheckOutputTensor(*output, "output"); + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr && dbias->has_data()); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr && act_input->has_data()); + CheckInputTensor(*act_input, "act_input"); + } } - NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); - NVTE_CHECK(input.data.shape == cast_output->data.shape, - "Input and C output must have the same shape."); - const size_t row_length = input.data.shape[1]; - const size_t num_rows = input.data.shape[0]; + // Check that inputs and outputs are available + NVTE_CHECK(input.has_data(), "Input is not allocated"); + NVTE_CHECK(output->has_data(), "Output rowwise data is not allocated"); + NVTE_CHECK(output->has_columnwise_data(), "Output columnwise data is not allocated"); - NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); + // Flatten tensor to 2D + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes do not match (input=", input.data.shape, + ", output=", output->data.shape); + const size_t row_length = input.flat_last_dim(); + const size_t num_rows = input.flat_first_dim(); - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, - "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); + // Check that cast and transposed output data matches + NVTE_CHECK(output->data.dtype == output->columnwise_data.dtype, + "Cast and transposed output types must match."); + NVTE_CHECK(output->scale_inv.dptr == output->columnwise_scale_inv.dptr, + "Cast and transposed outputs need to share scale-inverse tensor."); if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); NVTE_CHECK(dbias->data.shape == std::vector{row_length}, "Wrong shape of DBias."); } if constexpr (IS_DACT) { - NVTE_CHECK(input.data.dtype == act_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input.data.shape, "Shapes of both inputs must match."); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); } TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, InputType, + input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - cast_output->data.dtype, OutputType, using InputType2 = InputType; + output->dtype(), OutputType, using InputType2 = InputType; using Param = CTDBiasDActParam; constexpr int itype_size = sizeof(InputType); @@ -584,8 +622,9 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * if (!jit_compiled) { num_blocks = DIVUP(num_tiles * n_warps_per_tile, n_warps_per_block); } if constexpr (IS_DBIAS) { + // Check workspace size + populate_cast_transpose_dbias_workspace_config(*output, workspace, nvec_out); if (workspace->data.dptr == nullptr) { - populate_cast_transpose_dbias_workspace_config(*cast_output, workspace, nvec_out); return; } } @@ -631,15 +670,15 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * Param param; param.input = reinterpret_cast(input.data.dptr); - param.output_c = reinterpret_cast(cast_output->data.dptr); - param.output_t = reinterpret_cast(transposed_output->data.dptr); - param.scale_ptr = reinterpret_cast(transposed_output->scale.dptr); - param.amax = reinterpret_cast(transposed_output->amax.dptr); - param.scale_inv = reinterpret_cast(cast_output->scale_inv.dptr); + param.output_c = reinterpret_cast(output->data.dptr); + param.output_t = reinterpret_cast(output->columnwise_data.dptr); + param.scale_ptr = reinterpret_cast(output->scale.dptr); + param.amax = reinterpret_cast(output->amax.dptr); + param.scale_inv = reinterpret_cast(output->scale_inv.dptr); if constexpr (IS_DBIAS) { param.workspace = reinterpret_cast(workspace->data.dptr); } if constexpr (IS_DACT) { - param.act_input = reinterpret_cast(act_input.data.dptr); + param.act_input = reinterpret_cast(act_input->data.dptr); } // Runtime-compiled tuned kernel @@ -648,9 +687,9 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * constexpr const char *itype2_name = TypeInfo::name; constexpr const char *otype_name = TypeInfo::name; - int dActType = 0; - if constexpr (IS_DACT) { - dActType = get_dactivation_type(); + int actType = 0; + if constexpr (IS_DACT || IS_ACT) { + actType = get_activation_type(); } // Compile NVRTC kernel if needed and launch @@ -660,7 +699,8 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * ",itype=", itype_name, ",itype2=", itype2_name, ",otype=", otype_name, ",load_size=", load_size, ",store_size=", store_size, ",IS_DBIAS=", IS_DBIAS, - ",IS_DACT=", IS_DACT, ",dactivationType=", ActTypeToString[dActType]); + ",IS_DACT=", IS_DACT, ",IS_ACT=", IS_ACT, + ",activationType=", ActTypeToString[actType]); if (!rtc_manager.is_compiled(kernel_label)) { std::string code = string_code_transpose_rtc_cast_transpose_fusion_cu; @@ -673,7 +713,8 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * code = regex_replace(code, "__BLOCK_SIZE__", cast_transpose_num_threads); code = regex_replace(code, "__IS_DBIAS__", IS_DBIAS); code = regex_replace(code, "__IS_DACT__", IS_DACT); - code = regex_replace(code, "__DACTIVATION_TYPE__", dActType); + code = regex_replace(code, "__IS_ACT__", IS_ACT); + code = regex_replace(code, "__ACTIVATION_TYPE__", actType); rtc_manager.compile( kernel_label, "cast_transpose_fusion_kernel_optimized", code, @@ -695,11 +736,11 @@ void cast_transpose_fused(const Tensor &input, const Tensor &act_input, Tensor * NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); cudaFuncSetAttribute( - cast_transpose_fused_kernel_notaligned, + cast_transpose_fused_kernel_notaligned, cudaFuncAttributePreferredSharedMemoryCarveout, 100); - cast_transpose_fused_kernel_notaligned + cast_transpose_fused_kernel_notaligned <<>>( param, row_length, num_rows, num_tiles); } @@ -1101,43 +1142,39 @@ __global__ void __launch_bounds__(cast_transpose_num_threads) template -void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, - Tensor *cast_output, Tensor *transposed_output, +void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_input, Tensor *output, cudaStream_t stream) { CheckInputTensor(input, "dgated_act_cast_transpose_input"); CheckInputTensor(gated_act_input, "dgated_act_cast_transpose_gated_act_input"); - CheckOutputTensor(*cast_output, "dgated_act_cast_transpose_cast_output"); - CheckOutputTensor(*transposed_output, "dgated_act_cast_transpose_transposed_output"); + CheckOutputTensor(*output, "dgated_act_cast_transpose_output"); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(gated_act_input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); + NVTE_CHECK(output->has_data() && output->has_columnwise_data(), + "Both rowwise and columnwise data need to be allocated."); + NVTE_CHECK(output->data.shape.size() == 2, "C output must have 2 dimensions."); + NVTE_CHECK(output->columnwise_data.shape.size() == 2, "T output must have 2 dimensions."); const size_t row_length = input.data.shape[1]; const size_t num_rows = input.data.shape[0]; NVTE_CHECK(gated_act_input.data.shape[0] == num_rows, "Wrong dimension of output."); NVTE_CHECK(gated_act_input.data.shape[1] == row_length * 2, "Wrong dimension of output."); - NVTE_CHECK(cast_output->data.shape[0] == num_rows, "Wrong dimension of output."); - NVTE_CHECK(cast_output->data.shape[1] == row_length * 2, "Wrong dimension of output."); - NVTE_CHECK(transposed_output->data.shape[0] == row_length * 2, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); + NVTE_CHECK(output->data.shape[0] == num_rows, "Wrong dimension of output."); + NVTE_CHECK(output->data.shape[1] == row_length * 2, "Wrong dimension of output."); + NVTE_CHECK(output->columnwise_data.shape[0] == row_length * 2, "Wrong dimension of T output."); + NVTE_CHECK(output->columnwise_data.shape[1] == num_rows, "Wrong dimension of T output."); NVTE_CHECK(input.data.dtype == gated_act_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, + NVTE_CHECK(output->data.dtype == output->columnwise_data.dtype, "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); - NVTE_CHECK(cast_output->scale_inv.dptr == transposed_output->scale_inv.dptr, + NVTE_CHECK(output->scale_inv.dptr == output->columnwise_scale_inv.dptr, "C and T outputs need to share scale inverse tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, InputType, + input.dtype(), InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - cast_output->data.dtype, OutputType, using InputType2 = InputType; + output->dtype(), OutputType, using InputType2 = InputType; /* dact fusion kernel uses more registers */ constexpr int desired_load_size_dact = 4; constexpr int desired_store_size_dact = 4; constexpr int itype_size = sizeof(InputType); @@ -1168,11 +1205,11 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu <<>>( reinterpret_cast(input.data.dptr), reinterpret_cast(gated_act_input.data.dptr), - reinterpret_cast(cast_output->data.dptr), - reinterpret_cast(transposed_output->data.dptr), - reinterpret_cast(cast_output->scale.dptr), - reinterpret_cast(cast_output->amax.dptr), - reinterpret_cast(cast_output->scale_inv.dptr), row_length, num_rows, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); } else { cudaFuncSetAttribute( @@ -1184,194 +1221,193 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu <<>>( reinterpret_cast(input.data.dptr), reinterpret_cast(gated_act_input.data.dptr), - reinterpret_cast(cast_output->data.dptr), - reinterpret_cast(transposed_output->data.dptr), - reinterpret_cast(cast_output->scale.dptr), - reinterpret_cast(cast_output->amax.dptr), - reinterpret_cast(cast_output->scale_inv.dptr), row_length, num_rows, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); }); // NOLINT(*) ); // NOLINT(*) } -} // namespace + +// Explicit template instantiation +template void cast_transpose_fused( + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t); +#define NVTE_INSTANTIATE_ACTIVATION(op) \ + template void cast_transpose_fused>( \ + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t); \ + template void cast_transpose_fused>( \ + const Tensor &, const Tensor *, Tensor *, Tensor *, Tensor *, cudaStream_t); +NVTE_INSTANTIATE_ACTIVATION(relu); +NVTE_INSTANTIATE_ACTIVATION(srelu); +NVTE_INSTANTIATE_ACTIVATION(gelu); +NVTE_INSTANTIATE_ACTIVATION(qgelu); +NVTE_INSTANTIATE_ACTIVATION(silu); +#undef NVTE_INSTANTIATE_ACTIVATION + +} // namespace detail } // namespace transformer_engine using ComputeType = typename transformer_engine::fp32; -void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor cast_output, - NVTETensor transposed_output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { +void nvte_cast_transpose_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; constexpr const NVTETensor activation_input = nullptr; - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(activation_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused( + *reinterpret_cast(input), reinterpret_cast(activation_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dgelu(const NVTETensor input, const NVTETensor act_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dgelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dgelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(act_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dsilu(const NVTETensor input, const NVTETensor silu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dsilu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dsilu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(silu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(silu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_drelu(const NVTETensor input, const NVTETensor relu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_drelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &drelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(relu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(relu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor srelu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dsrelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dsrelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(srelu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(srelu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_cast_transpose_dbias_dqgelu(const NVTETensor input, const NVTETensor qgelu_input, - NVTETensor cast_output, NVTETensor transposed_output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_cast_transpose_dbias_dqgelu); using namespace transformer_engine; + using namespace transformer_engine::detail; constexpr bool IS_DBIAS = true; constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; - constexpr auto dActivation = &dqgelu; - - cast_transpose_fused( - *reinterpret_cast(input), *reinterpret_cast(qgelu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - reinterpret_cast(dbias), reinterpret_cast(workspace), stream); + cast_transpose_fused>( + *reinterpret_cast(input), reinterpret_cast(qgelu_input), + reinterpret_cast(output), reinterpret_cast(dbias), + reinterpret_cast(workspace), stream); } void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dgelu; - constexpr auto Activation = &gelu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, gelu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor swiglu_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dsilu; - constexpr auto Activation = &silu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, silu>( *reinterpret_cast(input), *reinterpret_cast(swiglu_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &drelu; - constexpr auto Activation = &relu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, relu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dsrelu; - constexpr auto Activation = &srelu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, srelu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor gated_act_input, - NVTETensor cast_output, NVTETensor transposed_output, - cudaStream_t stream) { + NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu_cast_transpose); using namespace transformer_engine; + using namespace transformer_engine::detail; - constexpr auto dActivation = &dqgelu; - constexpr auto Activation = &qgelu; - - dgated_act_cast_transpose( + dgated_act_cast_transpose, qgelu>( *reinterpret_cast(input), *reinterpret_cast(gated_act_input), - reinterpret_cast(cast_output), reinterpret_cast(transposed_output), - stream); + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 4026016519..5cf316f45e 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -195,42 +195,44 @@ __global__ void __launch_bounds__(threads_per_block) } // namespace -void multi_cast_transpose(const std::vector input_list, - std::vector cast_output_list, - std::vector transposed_output_list, cudaStream_t stream) { +void multi_cast_transpose(const std::vector input_list, std::vector output_list, + cudaStream_t stream) { // Check that number of tensors is valid - NVTE_CHECK(cast_output_list.size() == input_list.size(), - "Number of input and C output tensors must match"); - NVTE_CHECK(transposed_output_list.size() == input_list.size(), - "Number of input and T output tensors must match"); + NVTE_CHECK(output_list.size() == input_list.size(), + "Number of input and output tensors must match"); if (input_list.empty()) { return; } // Check that tensor properties are valid DType itype = input_list[0]->data.dtype; - DType otype = cast_output_list[0]->data.dtype; + DType otype = output_list[0]->dtype(); for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { const auto& input = *input_list[tensor_id]; - const auto& cast_output = *cast_output_list[tensor_id]; - const auto& transposed_output = *transposed_output_list[tensor_id]; + const auto& output = *output_list[tensor_id]; CheckInputTensor(input, "multi_cast_transpose_input_" + std::to_string(tensor_id)); - CheckInputTensor(cast_output, "multi_cast_output_" + std::to_string(tensor_id)); - CheckInputTensor(transposed_output, "multi_transpose_output_" + std::to_string(tensor_id)); + CheckInputTensor(output, "multi_cast_transpose_output_" + std::to_string(tensor_id)); + //std::cout << *static_cast(output.data.dptr) << std::endl; + NVTE_CHECK(output.has_data() && output.has_columnwise_data(), + "Both rowwise and columnwise output data needs to be allocated."); NVTE_CHECK(input.data.dtype == itype, "Input tensor types do not match."); - NVTE_CHECK(cast_output.data.dtype == otype, "C output tensor types do not match."); - NVTE_CHECK(transposed_output.data.dtype == otype, "T output tensor types do not match."); + NVTE_CHECK(output.data.dtype == otype, "C output tensor types do not match."); + NVTE_CHECK(output.data.dtype == otype, "T output tensor types do not match."); - NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions."); - NVTE_CHECK(cast_output.data.shape == input.data.shape, - "C output tensor shape does not match input tensor."); - NVTE_CHECK(transposed_output.data.shape.size() == 2, - "T output tensor shape does not match input tensor."); - NVTE_CHECK(transposed_output.data.shape[0] == input.data.shape[1], - "T output tensor shape does not match input tensor."); - NVTE_CHECK(transposed_output.data.shape[1] == input.data.shape[0], - "T output tensor shape does not match input tensor."); + NVTE_CHECK(input.data.shape.size() == 2, "Input tensor must have 2 dimensions, but shape is ", + input.data.shape); + NVTE_CHECK(output.data.shape == input.data.shape, "C output tensor shape ", output.data.shape, + "does not match input tensor shape ", input.data.shape); + NVTE_CHECK(output.columnwise_data.shape.size() == 2, "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); + NVTE_CHECK(output.columnwise_data.shape[0] == input.data.shape[1], "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); + NVTE_CHECK(output.columnwise_data.shape[1] == input.data.shape[0], "T output tensor shape ", + output.columnwise_data.shape, "does not match input tensor shape ", + input.data.shape); } // Input matrices are divided into tiles @@ -287,11 +289,11 @@ void multi_cast_transpose(const std::vector input_list, // Add tensor to kernel argument struct const int pos = kernel_args.num_tensors; kernel_args.input_list[pos] = const_cast(input_list[tensor_id]->data.dptr); - kernel_args.output_c_list[pos] = cast_output_list[tensor_id]->data.dptr; - kernel_args.output_t_list[pos] = transposed_output_list[tensor_id]->data.dptr; - kernel_args.scale_list[pos] = cast_output_list[tensor_id]->scale.dptr; - kernel_args.amax_list[pos] = cast_output_list[tensor_id]->amax.dptr; - kernel_args.scale_inv_list[pos] = cast_output_list[tensor_id]->scale_inv.dptr; + kernel_args.output_c_list[pos] = output_list[tensor_id]->data.dptr; + kernel_args.output_t_list[pos] = output_list[tensor_id]->columnwise_data.dptr; + kernel_args.scale_list[pos] = output_list[tensor_id]->scale.dptr; + kernel_args.amax_list[pos] = output_list[tensor_id]->amax.dptr; + kernel_args.scale_inv_list[pos] = output_list[tensor_id]->scale_inv.dptr; kernel_args.num_rows_list[pos] = num_rows; kernel_args.row_length_list[pos] = row_length; kernel_args.block_range[pos + 1] = kernel_args.block_range[pos] + num_tiles; @@ -327,15 +329,13 @@ void multi_cast_transpose(const std::vector input_list, } // namespace transformer_engine void nvte_multi_cast_transpose(size_t num_tensors, const NVTETensor* input_list, - NVTETensor* cast_output_list, NVTETensor* transposed_output_list, - cudaStream_t stream) { + NVTETensor* output_list, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_cast_transpose); using namespace transformer_engine; - std::vector input_list_, cast_output_list_, transposed_output_list_; + std::vector input_list_, output_list_; for (size_t i = 0; i < num_tensors; ++i) { input_list_.push_back(reinterpret_cast(const_cast(input_list[i]))); - cast_output_list_.push_back(reinterpret_cast(cast_output_list[i])); - transposed_output_list_.push_back(reinterpret_cast(transposed_output_list[i])); + output_list_.push_back(reinterpret_cast(output_list[i])); } - multi_cast_transpose(input_list_, cast_output_list_, transposed_output_list_, stream); + multi_cast_transpose(input_list_, output_list_, stream); } diff --git a/transformer_engine/common/transpose/rtc/cast_transpose.cu b/transformer_engine/common/transpose/rtc/cast_transpose.cu index 07244a42e9..952d70f38b 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu index 4ba1cb4c69..34359561aa 100644 --- a/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -22,7 +22,9 @@ constexpr size_t WARPS_PER_TILE = __WARPS_PER_TILE__; constexpr size_t BLOCK_SIZE = __BLOCK_SIZE__; constexpr bool IS_DBIAS = __IS_DBIAS__; constexpr bool IS_DACT = __IS_DACT__; -constexpr size_t DACT_TYPE = __DACTIVATION_TYPE__; +constexpr bool IS_ACT = __IS_ACT__; +static_assert(!(IS_DACT && IS_ACT), "forward and backward activation are mutually exclusive"); +constexpr size_t ACT_TYPE = __ACTIVATION_TYPE__; constexpr size_t NVEC_IN = LOAD_SIZE / sizeof(IType); constexpr size_t NVEC_OUT = STORE_SIZE / sizeof(OType); @@ -33,14 +35,20 @@ using OVec = Vec; using Param = CTDBiasDActParam; using OP = CType (*)(const CType, const Empty &); -constexpr OP Activation[] = { +constexpr OP ActivationList[] = { nullptr, // 0 - &dsigmoid, // 1 - &dgelu, // 2 - &dqgelu, // 3 - &dsilu, // 4 - &drelu, // 5 - &dsrelu // 6 + &sigmoid, // 1 + &dsigmoid, // 2 + &gelu, // 3 + &dgelu, // 4 + &qgelu, // 5 + &dqgelu, // 6 + &silu, // 7 + &dsilu, // 8 + &relu, // 9 + &drelu, // 10 + &srelu, // 11 + &dsrelu // 12 }; } // namespace @@ -175,7 +183,10 @@ __global__ void __launch_bounds__(BLOCK_SIZE) if constexpr (IS_DACT) { in_cast_fp32[j].data.elt[k] = static_cast(in[current_in ^ 1][j].data.elt[k]) * - Activation[DACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {}); + ActivationList[ACT_TYPE](act_in[current_in ^ 1][j].data.elt[k], {}); + } else if constexpr (IS_ACT) { + in_cast_fp32[j].data.elt[k] = + ActivationList[ACT_TYPE](in[current_in ^ 1][j].data.elt[k], {}); } else { in_cast_fp32[j].data.elt[k] = static_cast(in[current_in ^ 1][j].data.elt[k]); } diff --git a/transformer_engine/common/transpose/rtc/transpose.cu b/transformer_engine/common/transpose/rtc/transpose.cu index 09758698f6..6d05c68106 100644 --- a/transformer_engine/common/transpose/rtc/transpose.cu +++ b/transformer_engine/common/transpose/rtc/transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index 5e8ef80ae4..26740a3837 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -205,17 +205,8 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match."); - // Number of elements in tensor - auto numel = [](const Tensor &tensor) -> size_t { - size_t acc = 1; - for (const auto &dim : tensor.data.shape) { - acc *= dim; - } - return acc; - }; - if (noop.data.dptr != nullptr) { - NVTE_CHECK(numel(noop) == 1, "Expected 1 element, ", "but found ", numel(noop), "."); + NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), "."); NVTE_CHECK(noop.data.dtype == DType::kFloat32); NVTE_CHECK(noop.data.dptr != nullptr); } diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index c032371940..fba3710beb 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -8,8 +8,8 @@ #include #include -#include -#include +#include +#include #include "../common.h" #include "../utils.cuh" @@ -376,8 +376,24 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ const size_t num_rows_partial_dbias = DIVUP(num_rows, tile_size_y); - workspace->data.shape = {num_rows_partial_dbias, row_length}; - workspace->data.dtype = DType::kFloat32; + if (workspace->data.dptr == nullptr) { + // Set workspace size + workspace->data.shape = {num_rows_partial_dbias, row_length}; + workspace->data.dtype = DType::kFloat32; + } else { + // Check that workspace matches expected size + const size_t workspace_size = + std::accumulate(workspace->data.shape.begin(), workspace->data.shape.end(), 1, + std::multiplies()) * + typeToSize(workspace->data.dtype); + const size_t required_size = num_rows_partial_dbias * row_length * typeToSize(DType::kFloat32); + NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", + num_rows_partial_dbias, ",", row_length, "), found ())"); + NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", + num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), + "; found dims=", workspace->data.shape, + ", dtype=", typeToSize(workspace->data.dtype), ")"); + } } template @@ -426,10 +442,9 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor constexpr int nvec_in = desired_load_size / type_size; constexpr int nvec_out = desired_store_size / type_size; - if (workspace->data.dptr == nullptr) { - populate_transpose_dbias_workspace_config(input, workspace, nvec_out); - return; - } + // Check workspace size + populate_transpose_dbias_workspace_config(input, workspace, nvec_out); + if (workspace->data.dptr == nullptr) { return; } NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index dd03afd21b..22a50025df 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -1,91 +1,147 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +#include +#include +#include #include +#include +#include +#include + #include "../common.h" +#include "../transpose/cast_transpose.h" #include "../util/vectorized_pointwise.h" #include "../utils.cuh" +#include "cast_kernels.cuh" +#include "dequantize_kernels.cuh" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/activation.h" +#include "transformer_engine/transpose.h" + +void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize); + using namespace transformer_engine; -namespace transformer_engine { + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; -namespace detail { + detail::quantize_helper(input, grad, nullptr, output, + dbias, workspace, stream); +} -struct Empty {}; +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_noop); + using namespace transformer_engine; -__device__ inline fp32 identity(fp32 value, const Empty &) { return value; } + constexpr bool IS_DBIAS = false; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr NVTETensor dbias = nullptr; + constexpr NVTETensor workspace = nullptr; + constexpr const NVTETensor grad = nullptr; -struct DequantizeParam { - const fp32 *scale_inv; -}; + detail::quantize_helper(input, grad, noop, output, + dbias, workspace, stream); +} + +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias); + using namespace transformer_engine; -__device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam ¶m) { - return value * (*(param.scale_inv)); + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr bool IS_ACT = false; + constexpr const NVTETensor activation_input = nullptr; + + detail::quantize_helper( + activation_input, input, nullptr, output, dbias, workspace, stream); } -} // namespace detail - -void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - NVTE_CHECK(!is_fp8_dtype(input.data.dtype), "Input must be in higher precision."); - - NVTE_CHECK(is_fp8_dtype(output->data.dtype), "Output must have FP8 type."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, - stream);); // NOLINT(*) - ); // NOLINT(*) +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); } -void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); - - NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - output->data.dtype, OType, constexpr int nvec = 32 / sizeof(OType); - detail::DequantizeParam p; - p.scale_inv = reinterpret_cast(input.scale_inv.dptr); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, - stream);); // NOLINT(*) - ); // NOLINT(*) +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); +} + +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); } -} // namespace transformer_engine +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); +} -void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_fp8_quantize); +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsrelu); using namespace transformer_engine; - fp8_quantize(*reinterpret_cast(input), reinterpret_cast(output), - stream); + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + constexpr bool IS_ACT = false; + + detail::quantize_helper>( + activation_input, input, nullptr, output, dbias, workspace, stream); } -void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_fp8_dequantize); +void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_dequantize); using namespace transformer_engine; - fp8_dequantize(*reinterpret_cast(input), reinterpret_cast(output), - stream); + detail::dequantize_helper(*reinterpret_cast(input), + reinterpret_cast(output), stream); } diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh new file mode 100644 index 0000000000..e2240ba658 --- /dev/null +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -0,0 +1,1091 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast_gated_kernels.cuh + * \brief CUDA gated activations kernels to cast to/from FP8/MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ + +#include +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.cuh" +#include "math.h" +#include "ptx.cuh" + +namespace transformer_engine { + +template +__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) { + return DIVUP(static_cast(N), static_cast(M)) * M; +} + +namespace gated_kernels { + +constexpr size_t ALIGNMENT_SIZE = 128; +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 512; +constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 +constexpr size_t BUFFERS_NUM = 2; +constexpr size_t BUFFER_DIM_Y = 32; +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 +static_assert(ITERATIONS >= 1); + +__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act, + const __grid_constant__ CUtensorMap tensor_map_output_gate, + float *const amax_ptr, float *const scale_inv_ptr, + const float *const scale_ptr, const size_t rows, const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + extern __shared__ char dshmem_unaligned[]; + const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); + const uint64_t dshmem_aligned_as_uint = + DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; + char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + + constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + constexpr size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; + + constexpr size_t in_act_mem = buff_size_aligned_in; + constexpr size_t in_gate_mem = buff_size_aligned_in; + constexpr size_t in_mem = in_act_mem + in_gate_mem; + + constexpr size_t out_act_mem = buff_size_aligned_out; + constexpr size_t out_gate_mem = buff_size_aligned_out; + constexpr size_t out_mem = out_act_mem + out_gate_mem; + + // const size_t in_transaction_size = grad_mem + in_mem; + constexpr size_t in_transaction_size = buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + // uint64_t *mbar = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); + const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + // Prefetch data of the first stage + + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, + TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, + chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } else { + copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const int buff = it % BUFFERS_NUM; + const int next_it = it + 1; + if (next_it < ITERATIONS) { + const int next_buff = next_it % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3( + &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, + &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, + in_transaction_size, &mbar[next_it], is_master_thread); + } else { + copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, + chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, + &mbar[next_it], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + OType *out_act_sh_curr = out_act_sh + buff * buff_elems; + OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + + float after_dact = dact_x * grad_elt * gate_elt; + float after_dgate = act_x * grad_elt; + + out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); + out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); + + amax = fmaxf(amax, fabsf(after_dact)); + amax = fmaxf(amax, fabsf(after_dgate)); + } else { + const float after_act = ActOP(act_elt, {}) * gate_elt; + out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); + amax = fmaxf(amax, fabsf(after_act)); + } + } + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + + // dGeLU + ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, + chunk_it_offset_y, + reinterpret_cast(out_act_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_sh_curr)); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + + const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); + + extern __shared__ char dshmem_unaligned[]; + const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); + const uint64_t dshmem_aligned_as_uint = + DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; + char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + + const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_elems_total = BUFFERS_NUM * buff_elems; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + const size_t out_mem = out_act_mem + out_gate_mem; + + // const size_t in_transaction_size = grad_mem + in_mem; + const size_t in_transaction_size = (IS_DGATED ? 3 : 2) * buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + + OType *out_act_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_rowwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + + OType *out_act_colwise_sh = out_act_rowwise_sh; + OType *out_gate_colwise_sh = out_gate_rowwise_sh; + + if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); + out_gate_colwise_sh = + reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); + } + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act_rowwise = + reinterpret_cast(&tensor_map_output_act_rowwise); + const uint64_t *TMAP_output_gate_rowwise = + reinterpret_cast(&tensor_map_output_gate_rowwise); + const uint64_t *TMAP_output_act_colwise = + reinterpret_cast(&tensor_map_output_act_colwise); + const uint64_t *TMAP_output_gate_colwise = + reinterpret_cast(&tensor_map_output_gate_colwise); + + __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + const bool is_master_thread = (threadIdx.x == 0); + + if (is_master_thread) { +// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); + + int parity = 0; + + // Prefetch data of the first stage + if (is_master_thread) { + // Initiate bulk tensor copy + // Grad + if constexpr (IS_DGATED) { + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_grad_sh[0]), + TMAP_grad_in, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + } + + // Act + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_act_sh[0]), + TMAP_in_act, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + + // Gate + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_gate_sh[0]), + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, + &mbar[0]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[0]); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const int buff = it % BUFFERS_NUM; + const int next_it = it + 1; + const size_t row_base = chunk_offset_Y + it * BUFFER_DIM_Y; + if (next_it < ITERATIONS) { + if (is_master_thread) { + const int next_buff = next_it % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + // Initiate bulk tensor copy + if constexpr (IS_DGATED) { + // Grad + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + } + // Act + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_act_sh[next_buff * buff_elems]), TMAP_in_act, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + // Gate + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_gate_sh[next_buff * buff_elems]), TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[next_it]); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + OType *out_act_rowwise_sh_curr = out_act_rowwise_sh + buff * buff_elems; + OType *out_gate_rowwise_sh_curr = out_gate_rowwise_sh + buff * buff_elems; + OType *out_act_colwise_sh_curr = out_act_colwise_sh + buff * buff_elems; + OType *out_gate_colwise_sh_curr = out_gate_colwise_sh + buff * buff_elems; + + // Assuming one iteration covers exactly 32 rows + const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; + const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; + + float after_dact_reg[BUFFER_STAGES_NUM]; + float after_dgate_reg[BUFFER_STAGES_NUM]; + float thread_Y_mx_block_amax = 0.0f; + float thread_Y_mx_block_amax_gate = 0.0f; + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_dact_reg[stage] = dact_x * grad_elt * gate_elt; + after_dgate_reg[stage] = act_x * grad_elt; + } else { + after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; + } + + if constexpr (USE_ROWWISE_SCALING) { + if constexpr (IS_DGATED) { + // dgate + float amax = fabsf(after_dgate_reg[stage]); + const float mx_block_X_amax = warp_reduce_max_broadcast(amax); + const e8m0_t biased_exponent_X = + float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + + out_gate_rowwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal_X * after_dgate_reg[stage]); + + // Only single thread writes the computed scaling factor + if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent_X; + } + } + float amax = fabsf(after_dact_reg[stage]); + const float mx_block_X_amax = warp_reduce_max_broadcast(amax); + const e8m0_t biased_exponent_X = + float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + + out_act_rowwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal_X * after_dact_reg[stage]); + + // Only single thread writes the computed scaling factor + if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; + const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent_X; + } + } + + if constexpr (USE_COLWISE_SCALING) { + __builtin_assume(thread_Y_mx_block_amax >= 0); + __builtin_assume(thread_Y_mx_block_amax_gate >= 0); + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); + if constexpr (IS_DGATED) { + thread_Y_mx_block_amax_gate = + fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); + } + } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool row_out_of_bounds = (row_base >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + if constexpr (IS_DGATED) { + // Colwise max reduction of the amax element + if (tid_Y > 0) { + stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; + } + __syncthreads(); + if (tid_Y == 0) { +#pragma unroll + for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + thread_Y_mx_block_amax_gate = + fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); + } + stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax + } + __syncthreads(); + + const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + + // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section + if constexpr (!USE_ROWWISE_SCALING) { + __builtin_assume(mx_block_Y_amax >= 0); + } + + const e8m0_t biased_exponent = + float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + // Only single thread writes the computed scaling factor + // Also assuming one iteration covers exactly 32 rows + if ((tid_Y == 0) && !out_of_bounds) { + const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + } + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + out_gate_colwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal * after_dgate_reg[stage]); + } + } + // Colwise max reduction of the amax element + if (tid_Y > 0) { + stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax; + } + __syncthreads(); + if (tid_Y == 0) { +#pragma unroll + for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); + } + stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax + } + __syncthreads(); + + const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + + // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section + if constexpr (!USE_ROWWISE_SCALING) { + __builtin_assume(mx_block_Y_amax >= 0); + } + + const e8m0_t biased_exponent = + float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + // Only single thread writes the computed scaling factor + // Also assuming one iteration covers exactly 32 rows + if ((tid_Y == 0) && !out_of_bounds) { + const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + } + +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + out_act_colwise_sh_curr[shmem_idx] = + static_cast(scale_reciprocal * after_dact_reg[stage]); + } + } // endif USE_COLWISE_SCALING + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + + // dGeLU + if constexpr (USE_ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_act_rowwise_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_rowwise_sh_curr)); + } + } + + // dGeLU + if constexpr (USE_COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_act_colwise_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_colwise_sh_curr)); + } + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + + NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act{}; + alignas(64) CUtensorMap tensor_map_output_gate{}; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, + cols, 0, sizeof(IType)); + } + + const uint32_t tensor_stride_elems = output_cols; + + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, 0, sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, cols, sizeof(OType)); + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); + const size_t shmem_size = ALIGNMENT_SIZE + grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem); // + mbar_mem; + + cudaFuncSetAttribute( + cast_fp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + cast_fp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + const bool USE_ROWWISE_SCALING = output->has_data(); + const bool USE_COLWISE_SCALING = output->has_columnwise_data(); + + if (USE_ROWWISE_SCALING) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (USE_COLWISE_SCALING) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + + // TODO: Make more general + const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1; + const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1; + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; + size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + e8m0_t *const scales_rowwise_ptr = + USE_ROWWISE_SCALING ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_act_colwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(IType)); + } + + const uint32_t tensor_stride_elems = output_cols; + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, sizeof(IType)); + + if (USE_ROWWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, + sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, + SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, + sizeof(OType)); + } + + if (USE_COLWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, + rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, + 0, sizeof(OType)); + create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, + rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, + cols, sizeof(OType)); + } + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + const size_t buff_size_aligned_out = + DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + + const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; + + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + + // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); + // const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem + mbar_mem; + + const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem; + + cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise);); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(input, "gated_act_input"); + CheckOutputTensor(*output, "gated_act_output"); + NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); + NVTE_CHECK(output->data.shape.size() == 2, "Output must have 2 dimensions."); + NVTE_CHECK(input.data.shape[0] == output->data.shape[0], + "Input shape[0] must be equal to output shape[0]."); + NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2, + "Input shape[1] must be 2x larger than output shape[1]."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + GatedActivationKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), output->data.shape[0], + output->data.shape[1], {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(grad, "dgated_act_grad"); + CheckInputTensor(input, "dgated_act_input"); + CheckOutputTensor(*output, "dgated_act_output"); + NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), + "Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(), + ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, + "Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2, + "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes must match. Input shape: ", input.data.shape, + ", output shape: ", output->data.shape, "."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + if (!is_fp8_dtype(output->data.dtype) || + is_delayed_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + DGatedActivationKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), + grad.flat_last_dim(), {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + cudaStream_t stream) { + checkCuDriverContext(stream); + constexpr bool allow_empty = false; + CheckInputTensor(gated_input, "gated_input"); + CheckOutputTensor(*output, "output", allow_empty); + + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + if constexpr (IS_DGATED) { + CheckInputTensor(grad, "grad"); + NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); + NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); + NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); + NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); + } + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + bool is_fp8_rowwise_output = true; + bool is_fp8_colwise_output = true; + if (output->has_data()) { + is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); + } + if (output->has_columnwise_data()) { + is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); + } + + const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; + + if (is_delayed_tensor_scaling(output->scaling_mode)) { + if (use_tma_kernels) { + cast_fp8_gated(grad, gated_input, output, stream); + } else { + if constexpr (IS_DGATED) { + cast_dgated(grad, gated_input, output, stream); + } else { + cast_gated(gated_input, output, stream); + } + } + } else if (is_mxfp_scaling(output->scaling_mode)) { + if (use_tma_kernels) { + cast_mxfp8_gated(grad, gated_input, output, stream); + } else { + NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", + "by 32, got input of shape ", gated_input.data.shape); + } + } else { + NVTE_ERROR("Not supported scaling mode"); + } +} +} // namespace gated_kernels + +namespace detail { + +template +void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, + cudaStream_t stream) { + using namespace gated_kernels; + Tensor grad_empty_tensor; + const Tensor &grad_tensor = + IS_DGATED ? *(reinterpret_cast(grad)) : grad_empty_tensor; + const Tensor gated_input_tensor = *reinterpret_cast(gated_input); + Tensor *output_tensor = reinterpret_cast(output); + + if (is_supported_by_CC_100()) { + quantize_gated(grad_tensor, gated_input_tensor, + output_tensor, stream); + } else { + if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { + if constexpr (IS_DGATED) { + cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); + } else { + cast_gated(gated_input_tensor, output_tensor, stream); + } + } else { + // MX scaling + NVTE_ERROR("Not supported by the Arch < 10.0"); + } + } +} +} // namespace detail + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh new file mode 100644 index 0000000000..ba2890ada3 --- /dev/null +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -0,0 +1,1273 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file cast_kernels.cuh + * \brief CUDA kernels to cast to/from FP8/MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ + +#include +#include +#include +#include + +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.cuh" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { + +constexpr size_t MXFP8_CHUNK_DIM_Y = 64; +constexpr size_t MXFP8_CHUNK_DIM_X = 64; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; +constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; +constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; +constexpr size_t MXFP8_BUFFERS_NUM = 2; +constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM); + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported +constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 +constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = + MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 +constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = + MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 +constexpr size_t MXFP8_BUFF_STAGES_NUM = + MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 +constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 +static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + if (noop != nullptr && noop[0] == 1.0f) return; + } + + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = + SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = + SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; + const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; + const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + // const int thread_offset_X_colwise = tid_colwise_X; + + const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y; + const int dbias_rowwise_block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise; + const int dbias_colwise_offset_Y = blockIdx.y; + const int dbias_colwise_block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X; + const int dbias_stride = cols; + + Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; + float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + partial_dbias_rowwise[i].clear(); + } + } else { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + partial_dbias_colwise[i] = 0; + } + } + } + + // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned + __shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) + OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) + OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1); + + const bool is_master_thread = (threadIdx.x == 0); + + float block_amax = 0; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; +#pragma unroll + for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { + const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; + const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + + const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; + const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = + scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = + scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = + scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = + scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { + const int buff = iter % MXFP8_BUFFERS_NUM; + const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + if (next_iter < MXFP8_ITERATIONS) { + const int next_buff = next_iter % MXFP8_BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, + &mbar[next_iter], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec act_in; + Vec out_c; + + const int iteration_scale_rowwise_offset_Y = + scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + +#pragma unroll + for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = (row >= rows); + + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + } + + float thread_amax = 0; + float in_compute[ELEMS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in.data.elt[j]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[j]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + if (!out_of_bounds) { + partial_dbias_rowwise[chunk_X].data.elt[j] += elt; + } + } + in_compute[j] = elt; + + if constexpr (IS_ACT || IS_DACT) { + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); + const e8m0_t biased_exponent = + float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); + + // Only single thread writes the computed scaling factor + if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { + const int global_scales_offset_Y = + iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; + const int global_scales_offset_X = + scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; + const int scale_idx = + global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = exp2f_rcp(biased_exponent); + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); + } + out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); + } + } + + if constexpr (USE_COLWISE_SCALING) { + const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + float in_compute[SCALE_DIM_Y]; + + float amax = 0; +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const size_t row = row_base + i; + const bool row_out_of_bounds = (row >= rows); + const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + + float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if (!out_of_bounds) { + partial_dbias_colwise[chunk_X] += elt; + } + } + in_compute[i] = elt; + if constexpr (IS_ACT || IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(amax >= 0); + block_amax = fmaxf(block_amax, amax); + + const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; + const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; + const int scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = exp2f_rcp(biased_exponent); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + out_colwise_sh[buff][i][tid_colwise_X] = + static_cast(in_compute[i] * block_scale_inverse); + } + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (USE_ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_rowwise_sh[buff])); + } + if constexpr (USE_COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_colwise_sh[buff])); + } + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + } + + if constexpr (IS_DBIAS) { + if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { + constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; + constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; + constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; + __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; + + if (tid_rowwise_Y > 0) { +#pragma unroll + for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { + partial_dbias_rowwise[c].store_to( + &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]); + } + } + __syncthreads(); + + if (tid_rowwise_Y == 0) { +#pragma unroll + for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { + Vec other_row_dbias; + const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; + const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; + + const int left_bound = dbias_rowwise_offset_X; + const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + +#pragma unroll + for (int i = 0; i < Y; ++i) { + other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; + } + } + + // Vectorized store when all elements are inside the boundaries + if (right_bound < cols) { + partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]); + } else if (left_bound < cols && right_bound >= cols) { + // Element-by-element store when some elements cross the boundaries + const int in_bound_elts_count = cols - left_bound; + partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0, + in_bound_elts_count); + } + } + } + } else { +#pragma unroll + for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { + const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; + const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; + const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias_colwise[i]; + } + } + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t FP8_CHUNK_DIM_Y = 128; +constexpr size_t FP8_CHUNK_DIM_X = 128; +constexpr size_t FP8_THREADS_PER_CHUNK = 128; +constexpr size_t FP8_BUFFERS_NUM = 2; +constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); + +constexpr size_t FP8_BUFFER_DIM_Y = 16; +constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 +constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 + +constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) + cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output, + float *const dbias_workspace, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, + const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + + const int tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const int tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + + const int thread_offset_Y = tid_Y; + const int thread_offset_X = tid_X; + + const int dbias_offset_Y = blockIdx.y + tid_Y; + const int my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; + const bool col_out_of_bounds = my_column >= cols; + const int dbias_stride = cols; + + float partial_dbias = 0.f; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size * (IS_DACT ? 2 : 1); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + const int chunk_offset_Y = block_offset_Y; + const int chunk_offset_X = block_offset_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { + const int buff = iter % FP8_BUFFERS_NUM; + const int next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; + if (next_iter < FP8_ITERATIONS) { + const int next_buff = next_iter % FP8_BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { + const int stage_offset_Y = stage; + const int shmem_offset_y = thread_offset_Y + stage_offset_Y; + const int shmem_offset_x = thread_offset_X; + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = row >= rows; + const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; + + float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if constexpr (IS_DACT) { + if (!out_of_bounds) { + partial_dbias += elt; + } + } else { + // If no activation, elt is 0 so we can safely do this + partial_dbias += elt; + } + } + __builtin_assume(amax >= 0); + if (IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + if constexpr (IS_DBIAS) { + const int dbias_offset_X = my_column; + const int dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t CHUNKS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; +constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; +constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; +constexpr size_t CHUNKS_PER_ITERATION = 32; +constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; +constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; +constexpr size_t SHMEM_BUFFERS = 2; +static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const int block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const IType *input = input_ptr + block_offset; + OType *output = output_ptr + block_offset; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + + constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const int buff = iter % SHMEM_BUFFERS; + const int it_offset = iter * SHMEM_DIM; + + const int next_iter = iter + 1; + const int next_buff = next_iter % SHMEM_BUFFERS; + const int next_iter_offset = next_iter * SHMEM_DIM; + + if (next_iter < ITERATIONS) { + copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, + &(mbar[next_iter]), is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { + const int shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + float elt = static_cast(in_sh[buff][shmem_offset]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(elt)); + out_sh[buff][shmem_offset] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + ptx::cp_async_bulk_tensor_1d_shared_to_global( + reinterpret_cast(output + it_offset), + reinterpret_cast(&out_sh[buff]), transaction_size_OUT); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; +template +__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) + reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, const int rows, + const int cols) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + thread_id * nvec; + OType *const thread_out_base = dbias_output + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} + +template +void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, + cudaStream_t stream) { + constexpr int reduce_dbias_store_bytes = 8; // stg.64 + constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); + const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); + + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); +} + +template +static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { + const size_t N = product(input.data.shape); + + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + NVTE_CHECK(isFullTile, "Only full tiles are supported."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + const size_t chunks = DIVUP(N, CHUNK_SIZE); + const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + const float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(THREADS_PER_BLOCK); + const dim3 grid(blocks); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + const IType *input_ptr = reinterpret_cast(input.data.dptr); + OType *output_ptr = reinterpret_cast(output->data.dptr); + + cast_fp8_1D_kernel<<>>( + input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { + checkCuDriverContext(stream); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); + const size_t blocks_Y = chunks_Y; + const size_t blocks_X = chunks_X; + + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(FP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + } + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, sizeof(OType)); + + cast_fp8_2D_kernel + <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, + workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols); + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void mxfp8_quantize(const Tensor &input, const Tensor *act_input, + const Tensor *noop, // TODO (ksivamani) + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); + checkCuDriverContext(stream); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + const auto &input_shape = input.data.shape; + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (use_rowwise_scaling) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + // TODO: Make more general + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + + const dim3 block(MXFP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y, + MXFP8_SHMEM_DIM_X, cols, 0, sizeof(IType)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(IType)); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(OType)); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, + cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, + sizeof(OType)); + } + + cast_mxfp8_2D_kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} + +namespace detail { + +using Empty = transformer_engine::Empty; + +__device__ inline float identity(float value, const Empty &) { return value; } + +struct DequantizeParam { + const float *scale_inv; +}; + +__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { + return value * (*(param.scale_inv)); +} + +} // namespace detail + +template +void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input->data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +namespace { + +static bool is_full_tile_1D_tensor(const Tensor *const t) { + const size_t N = product(t->data.shape); + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + return isFullTile; +} + +bool dimensions_supported_by_TMA(const Tensor *const t) { + const size_t cols = t->flat_last_dim(); + constexpr int TMA_bytes = 16; + const int alignment_requirement = TMA_bytes / typeToSize(t->dtype()); + return cols % alignment_requirement == 0; +} + +} // namespace + +// Supported by the Arch >= 10.0 +template +void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!IS_DBIAS && !IS_DACT) { + if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_gmem_alignment) && + is_aligned_tensor_data(*output, TMA_gmem_alignment)) { + // Aligned AND FP8 + cast_fp8_1D(input, output, stream); + } else { + // Unaligned + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } else if (!IS_DBIAS && IS_DACT) { + if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_gmem_alignment) && + is_aligned_tensor_data(*output, TMA_gmem_alignment) && + is_aligned_tensor_data(*act_input, TMA_gmem_alignment)) { + // Aligned AND FP8 (+dAct) + cast_fp8_2D(input, act_input, output, dbias, workspace, + stream); + } else { + // Unaligned + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } else { + cast_fp8_2D(input, act_input, output, dbias, workspace, + stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8_quantize(input, act_input, noop, output, dbias, + workspace, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + } +} + +// Supported by the Arch < 10.0 +template +void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, + cudaStream_t stream) { + if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { + // zhongboz: should we just ignore IS_ACT here? + NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + + " on GPU with compute capability < 10.0."); + } + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + } +} + +template +void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr); + CheckInputTensor(*act_input, "activation_input"); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); + } + + NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // Supported by the Arch >= 10.0 + if (is_supported_by_CC_100()) { + fp8_quantize_arch_ge_100(input, act_input, noop, output, + dbias, workspace, stream); + } else { + // Supported by the Arch < 10.0 + fp8_quantize_arch_l_100(input, act_input, noop, output, + dbias, workspace, stream); + } +} + +namespace detail { + +template +void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + const Tensor *input_tensor; + const Tensor *activation_input_tensor; + if constexpr (IS_DBIAS || IS_DACT) { + // backward - input is incoming gradient + input_tensor = reinterpret_cast(grad); + activation_input_tensor = reinterpret_cast(input); + } else { + // forward = input is activation input + input_tensor = reinterpret_cast(input); + activation_input_tensor = nullptr; + } + auto output_tensor = reinterpret_cast(output); + auto dbias_tensor = reinterpret_cast(dbias); + auto workspace_tensor = reinterpret_cast(workspace); + const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); + + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } + } else if (output_tensor->has_data()) { + fp8_quantize( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8_quantize( + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } +} + +} // namespace detail +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index 797a11c43c..48fb5d77d9 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -1,11 +1,9 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ -#include - #include #include "../common.h" @@ -13,84 +11,6 @@ namespace transformer_engine { -namespace { - -/*! \brief Wrapper class for a shared library - * - * \todo Windows support - */ -class Library { - public: - explicit Library(const char *filename) { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - // TODO Windows support - NVTE_ERROR("Shared library initialization is not supported with Windows"); -#else - handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL); - NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed"); -#endif // _WIN32 or _WIN64 or __WINDOW__ - } - - ~Library() { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - // TODO Windows support -#else - if (handle_ != nullptr) { - dlclose(handle_); - } -#endif // _WIN32 or _WIN64 or __WINDOW__ - } - - Library(const Library &) = delete; // move-only - - Library(Library &&other) noexcept { swap(*this, other); } - - Library &operator=(Library other) noexcept { - // Copy-and-swap idiom - swap(*this, other); - return *this; - } - - friend void swap(Library &first, Library &second) noexcept; - - void *get() noexcept { return handle_; } - - const void *get() const noexcept { return handle_; } - - /*! \brief Get pointer corresponding to symbol in shared library */ - void *get_symbol(const char *symbol) { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - // TODO Windows support - NVTE_ERROR("Shared library initialization is not supported with Windows"); -#else - void *ptr = dlsym(handle_, symbol); - NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library"); - return ptr; -#endif // _WIN32 or _WIN64 or __WINDOW__ - } - - private: - void *handle_ = nullptr; -}; - -void swap(Library &first, Library &second) noexcept { - using std::swap; - swap(first.handle_, second.handle_); -} - -/*! \brief Lazily-initialized shared library for CUDA driver */ -Library &cuda_driver_lib() { -#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) - constexpr char lib_name[] = "nvcuda.dll"; -#else - constexpr char lib_name[] = "libcuda.so.1"; -#endif - static Library lib(lib_name); - return lib; -} - -} // namespace - namespace cuda_driver { void *get_symbol(const char *symbol) { diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index 9dc1114580..dcad582210 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/cuda_nvml.cpp b/transformer_engine/common/util/cuda_nvml.cpp new file mode 100644 index 0000000000..0af9cd7411 --- /dev/null +++ b/transformer_engine/common/util/cuda_nvml.cpp @@ -0,0 +1,26 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "cuda_nvml.h" + +#include "shared_lib_wrapper.h" + +namespace transformer_engine { + +namespace cuda_nvml { + +/*! \brief Lazily-initialized shared library for CUDA NVML */ +Library &cuda_nvml_lib() { + constexpr char lib_name[] = "libnvidia-ml.so.1"; + static Library lib(lib_name); + return lib; +} + +void *get_symbol(const char *symbol) { return cuda_nvml_lib().get_symbol(symbol); } + +} // namespace cuda_nvml + +} // namespace transformer_engine diff --git a/transformer_engine/common/util/cuda_nvml.h b/transformer_engine/common/util/cuda_nvml.h new file mode 100644 index 0000000000..14131a3cdd --- /dev/null +++ b/transformer_engine/common/util/cuda_nvml.h @@ -0,0 +1,69 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ + +#include + +#include + +#include "../common.h" +#include "../util/string.h" + +namespace transformer_engine { + +namespace cuda_nvml { + +/*! \brief Get pointer corresponding to symbol in CUDA NVML library */ +void *get_symbol(const char *symbol); + +/*! \brief Call function in CUDA NVML library + * + * The CUDA NVML library (libnvidia-ml.so.1 on Linux) may be different at + * compile-time and run-time. + * + * \param[in] symbol Function name + * \param[in] args Function arguments + */ +template +inline nvmlReturn_t call(const char *symbol, ArgTs... args) { + using FuncT = nvmlReturn_t(ArgTs...); + FuncT *func = reinterpret_cast(get_symbol(symbol)); + return (*func)(args...); +} + +/*! \brief Get NVML error string + * + * \param[in] rc NVML return code + */ +inline const char *get_nvml_error_string(nvmlReturn_t rc) { + using FuncT = const char *(nvmlReturn_t); + FuncT *func = reinterpret_cast(get_symbol("nvmlErrorString")); + return (*func)(rc); +} + +} // namespace cuda_nvml + +} // namespace transformer_engine + +#define NVTE_CHECK_CUDA_NVML(expr) \ + do { \ + const nvmlReturn_t status_NVTE_CHECK_CUDA_NVML = (expr); \ + if (status_NVTE_CHECK_CUDA_NVML != NVML_SUCCESS) { \ + const char *desc_NVTE_CHECK_CUDA_NVML = \ + ::transformer_engine::cuda_nvml::get_nvml_error_string(status_NVTE_CHECK_CUDA_NVML); \ + NVTE_ERROR("NVML Error: ", desc_NVTE_CHECK_CUDA_NVML); \ + } \ + } while (false) + +#define VA_ARGS(...) , ##__VA_ARGS__ +#define NVTE_CALL_CHECK_CUDA_NVML(symbol, ...) \ + do { \ + NVTE_CHECK_CUDA_NVML(::transformer_engine::cuda_nvml::call(#symbol VA_ARGS(__VA_ARGS__))); \ + } while (false) + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_ diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 8d2e852988..8b6bb52397 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -81,6 +81,26 @@ int sm_count(int device_id) { return cache[device_id]; } +void stream_priority_range(int *low_priority, int *high_priority, int device_id) { + static std::vector> cache(num_devices()); + static std::vector flags(num_devices()); + if (device_id < 0) { + device_id = current_device(); + } + NVTE_CHECK(0 <= device_id && device_id < num_devices(), "invalid CUDA device ID"); + auto init = [&]() { + int ori_dev = current_device(); + if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(device_id)); + int min_pri, max_pri; + NVTE_CHECK_CUDA(cudaDeviceGetStreamPriorityRange(&min_pri, &max_pri)); + if (device_id != ori_dev) NVTE_CHECK_CUDA(cudaSetDevice(ori_dev)); + cache[device_id] = std::make_pair(min_pri, max_pri); + }; + std::call_once(flags[device_id], init); + *low_priority = cache[device_id].first; + *high_priority = cache[device_id].second; +} + bool supports_multicast(int device_id) { #if CUDART_VERSION >= 12010 // NOTE: This needs to be guarded at compile time because the diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index ea1ba84772..072eacd623 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -38,6 +38,16 @@ int sm_arch(int device_id = -1); */ int sm_count(int device_id = -1); +/* \brief Minimum and maximum stream priorities supported on device + * + * \param[in] device_id CUDA device (default is current device) + * + * \param[out] low_priority Lowest priority value on device. + * + * \param[out] high_priority Highest priority value on device. + */ +void stream_priority_range(int *low_priority, int *high_priority, int device_id = -1); + /* \brief CUDA Multicast support status for device * * \param[in] device_id CUDA device (default is current device) diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh new file mode 100644 index 0000000000..e529289640 --- /dev/null +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -0,0 +1,360 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_kernels.cuh + * \brief CUDA kernels to cast from MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ + +#include +#include +#include +#include + +#include +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/vectorized_pointwise.h" +#include "../utils.cuh" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/activation.h" +#include "transformer_engine/transpose.h" + +namespace transformer_engine { + +namespace dequantization { + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 128; +constexpr size_t BUFFERS_NUM = 2; + +constexpr size_t ELEMS_PER_THREAD = 16; +constexpr size_t BUFFER_DIM_Y = 16; // only 32 is supported +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 16 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = CHUNK_DIM_X / ELEMS_PER_THREAD; // 8 = 128 / 16 +constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X; // 128 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(ITERATIONS >= 1); + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const e8m0_t *const scales_ptr, const size_t rows, const size_t cols, + const size_t scales_stride) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + + constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 + constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + + constexpr size_t THREADS_PER_SCALE_X_ROWWISE = + DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 + constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 + + const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; + const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; + const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + + const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; + const int tid_rowwise_X = threadIdx.x % THREADS_PER_CHUNK_X_ROWWISE; + // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + + const int thread_offset_Y = tid_rowwise_Y; + const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; + // const int thread_offset_X_colwise = tid_colwise_X; + + // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned + __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + + constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; + constexpr int transaction_size = shmem_buff_size; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + if (is_master_thread) { +// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_CHUNK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); + + int parity = 0; + constexpr int iteration_zero = 0; + constexpr int buffer_zero = 0; + if (is_master_thread) { + const int chunk_stage_offset_Y = chunk_offset_Y; + const int chunk_stage_offset_X = chunk_offset_X; + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[buffer_zero]), + reinterpret_cast(&tensor_map_input), chunk_stage_offset_X, + chunk_stage_offset_Y, &mbar[iteration_zero]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[iteration_zero], transaction_size); + + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[iteration_zero]); + } + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const int buff = iter % BUFFERS_NUM; + const int next_iter = iter + 1; + if (next_iter < ITERATIONS) { + if (is_master_thread) { + const int next_buff = next_iter % BUFFERS_NUM; + const int chunk_it_offset_y = chunk_offset_Y + next_iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + reinterpret_cast(&in_sh[next_buff]), + reinterpret_cast(&tensor_map_input), chunk_it_offset_x, + chunk_it_offset_y, &mbar[next_iter]); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(&mbar[next_iter], transaction_size); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(&mbar[next_iter]); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + + const int scale_offset_Y = + USE_ROWWISE_SCALING ? (scales_rowwise_chunk_offset_Y + iter * BUFFER_DIM_Y + tid_rowwise_Y) + : (scales_colwise_chunk_offset_Y + (iter * BUFFER_DIM_Y) / SCALE_DIM_Y); + + const int scale_offset_X = + USE_ROWWISE_SCALING + ? (scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE) + : (scales_colwise_chunk_offset_X + tid_colwise_X); + + const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; + const e8m0_t biased_exponent = scales_ptr[scale_idx]; + const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + + if constexpr (USE_ROWWISE_SCALING) { + Vec in; + Vec out; + + const int shmem_offset_y = thread_offset_Y; + const int shmem_offset_x = thread_offset_X_rowwise; + in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); + +#pragma unroll + for (int j = 0; j < ELEMS_PER_THREAD; ++j) { + out.data.elt[j] = static_cast(block_scale * static_cast(in.data.elt[j])); + } + out.store_to(&out_sh[buff][shmem_offset_y][shmem_offset_x]); + } else { +#pragma unroll + for (int i = 0; i < BUFFER_DIM_Y; ++i) { + const float elt = static_cast(in_sh[buff][i][tid_colwise_X]); + out_sh[buff][i][tid_colwise_X] = static_cast(block_scale * elt); + } + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + // Destroy barrier. This invalidates the memory region of the barrier. If + // further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + constexpr int nvec = 32 / sizeof(OType); + detail::DequantizeParam p; + p.scale_inv = reinterpret_cast(input.scale_inv.dptr); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), nullptr, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, + stream);); // NOLINT(*) + ); // NOLINT(*) +} + +static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + bool use_rowwise_scaling = input.has_data(); + bool use_colwise_scaling = input.has_columnwise_data(); + checkCuDriverContext(stream); + + const auto &input_shape = input.data.shape; + NVTE_CHECK(input_shape.size() >= 2, "Input must have at least 2 dimensions."); + + if (use_rowwise_scaling) { + NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + } + + if (use_colwise_scaling) { + NVTE_CHECK(input.has_columnwise_data(), "Cannot dequantize tensor without columnwise data."); + NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type."); + } + + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // TODO: Make more general + const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; + const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); + + const size_t unpadded_scales_Y_rowwise = rows; + const size_t unpadded_scales_X_rowwise = DIVUP(cols, scale_dim_X_rowwise); + const size_t unpadded_scales_Y_colwise = DIVUP(rows, scale_dim_Y_colwise); + const size_t unpadded_scales_X_colwise = cols; + + const size_t scales_Y_rowwise = + DIVUP(unpadded_scales_Y_rowwise, scale_tensor_alignment_Y_rowwise) * + scale_tensor_alignment_Y_rowwise; + const size_t scales_X_rowwise = + DIVUP(unpadded_scales_X_rowwise, scale_tensor_alignment_X_rowwise) * + scale_tensor_alignment_X_rowwise; + const size_t scales_Y_colwise = + DIVUP(unpadded_scales_Y_colwise, scale_tensor_alignment_Y_colwise) * + scale_tensor_alignment_Y_colwise; + const size_t scales_X_colwise = + DIVUP(unpadded_scales_X_colwise, scale_tensor_alignment_X_colwise) * + scale_tensor_alignment_X_colwise; + + const e8m0_t *const scales_ptr = + use_rowwise_scaling ? reinterpret_cast(input.scale_inv.dptr) + : reinterpret_cast(input.columnwise_scale_inv.dptr); + + const size_t scales_stride = use_rowwise_scaling ? scales_X_rowwise : scales_X_colwise; + + const SimpleTensor &input_data = use_rowwise_scaling ? input.data : input.columnwise_data; + + const dim3 block(THREADS_PER_CHUNK); + const dim3 grid(chunks_X, chunks_Y); + + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_Y_colwise, SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + scale_dim_X_rowwise, SCALE_DIM_X, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input_data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(IType)); + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols, 0, sizeof(OType)); + + dequantize_mxfp8_kernel + <<>>(tensor_map_input, tensor_map_output, scales_ptr, + rows, cols, scales_stride);); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace dequantization + +namespace detail { + +void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if (is_tensor_scaling(input.scaling_mode)) { + dequantization::fp8_dequantize(input, output, stream); + } else if (is_mxfp_scaling(input.scaling_mode)) { + if (is_supported_by_CC_100()) { + dequantization::mxfp8_dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } +} + +} // namespace detail + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ diff --git a/transformer_engine/common/util/handle_manager.h b/transformer_engine/common/util/handle_manager.h new file mode 100644 index 0000000000..adb2f55587 --- /dev/null +++ b/transformer_engine/common/util/handle_manager.h @@ -0,0 +1,52 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ +#define TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ + +#include + +#include "cuda_runtime.h" +#include "logging.h" + +namespace transformer_engine::detail { + +template +class HandleManager { + public: + static HandleManager& Instance() { + static thread_local HandleManager instance; + return instance; + } + + Handle GetHandle() { + static thread_local std::vector initialized(handles_.size(), false); + const int device_id = cuda::current_device(); + NVTE_CHECK(0 <= device_id && device_id < handles_.size(), "invalid CUDA device ID"); + if (!initialized[device_id]) { + Create(&(handles_[device_id])); + initialized[device_id] = true; + } + return handles_[device_id]; + } + + ~HandleManager() { + if (Destroy != nullptr) { + for (auto& handle : handles_) { + Destroy(handle); + } + } + } + + private: + HandleManager() : handles_(cuda::num_devices(), nullptr) {} + + std::vector handles_ = nullptr; +}; + +} // namespace transformer_engine::detail + +#endif // TRANSFORMER_ENGINE_UTIL_HANDLE_MANAGER_H_ diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 7972db3162..10a4ec28dc 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 26204cddb8..2d425d6753 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index 017d2e6a56..e90d2de558 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh new file mode 100644 index 0000000000..a22b930ecd --- /dev/null +++ b/transformer_engine/common/util/ptx.cuh @@ -0,0 +1,300 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file ptx.cuh + * \brief BW PTX + */ + +#ifndef TRANSFORMER_ENGINE_PTX_CUH_ +#define TRANSFORMER_ENGINE_PTX_CUH_ + +#include +#include + +namespace transformer_engine { +namespace ptx { + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init +__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval +__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive +__device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), "r"(tx_count) + : "memory"); +} + +__device__ __forceinline__ void fence_mbarrier_init_release_cluster() { + asm volatile("fence.mbarrier_init.release.cluster;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared( + uint64_t *dst_shmem, const uint64_t *src_global_ptr, const uint32_t size, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 1 thread as set above) + // - TMA hardware substracts bytes from expect_tx counter, must reach zero + asm volatile( + "cp.async.bulk.shared::cta.global" + ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"(dst_shmem_ptr), + "l"(src_global_ptr), "r"(size), "r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// global -> shared::cluster +__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( + uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x, + const uint32_t offset_y, uint64_t *mbar) { + uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + // triggers async copy, i.e. the thread continues until wait() on mbarrier + // barrier condition: + // - leader must arrive (i.e. 1 thread as set above) + // - TMA hardware substracts bytes from expect_tx counter, must reach zero + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr), + "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, + const uint64_t *src_shmem, + const uint32_t size) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(dst_global_ptr), + "r"(src_shmem_ptr), "r"(size) + : "memory"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( + const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y, + uint64_t *src_shmem) { + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::"l"( + tensor_map_ptr), + "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr) + : "memory"); +} + +__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { + uint32_t waitComplete; + asm volatile( + "{\n\t .reg .pred P_OUT; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P_OUT; \n" + "}" + : "=r"(waitComplete) + : "r"(mbar_ptr), "r"(parity) + : "memory"); + return static_cast(waitComplete); +} + +__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group +__device__ __forceinline__ void cp_async_bulk_commit_group() { + asm volatile("cp.async.bulk.commit_group;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +__device__ __forceinline__ void cp_async_bulk_wait_group() { + asm volatile("cp.async.bulk.wait_group 0;"); +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +template +__device__ __forceinline__ void cp_async_bulk_wait_group_read() { + asm volatile("cp.async.bulk.wait_group.read 0;"); +} + +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() { + asm volatile("cp.async.bulk.wait_group.read 0;"); +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() { + asm volatile("cp.async.bulk.wait_group.read 1;"); +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() { + asm volatile("cp.async.bulk.wait_group.read 2;"); +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { + asm volatile("cp.async.bulk.wait_group.read 4;"); +} + +// Proxy fence (bi-directional): +__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); } +__device__ __forceinline__ void fence_proxy_async_shared_cta() { + asm volatile("fence.proxy.async.shared::cta;"); +} + +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +} // namespace ptx + +namespace { + +template +__forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + // Syncthreads so initialized barrier is visible to all threads. + __syncthreads(); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Destroy barrier. This invalidates the memory region of the barrier. If + // further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int iter = 0; iter < num_barriers; ++iter) { + ptx::mbarrier_invalid(&mbar[iter]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, + const size_t num_bytes, uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_1d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), + num_bytes, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, + const size_t chunk_Y, const size_t num_bytes, + uint64_t *barrier, const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X, + chunk_Y, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, + const size_t chunk_X1, const size_t chunk_Y1, + void *dst2, const void *src2, + const size_t chunk_X2, const size_t chunk_Y2, + const size_t num_bytes, uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, + chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst2), + reinterpret_cast(src2), + chunk_X2, chunk_Y2, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +__forceinline__ __device__ void copy_2d_to_sharedx3( + void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, void *dst2, + const void *src2, const size_t chunk_X2, const size_t chunk_Y2, void *dst3, const void *src3, + const size_t chunk_X3, const size_t chunk_Y3, const size_t num_bytes, uint64_t *barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst), + reinterpret_cast(src), chunk_X1, + chunk_Y1, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst2), + reinterpret_cast(src2), + chunk_X2, chunk_Y2, barrier); + + ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(dst3), + reinterpret_cast(src3), + chunk_X3, chunk_Y3, barrier); + + // Arrive on the barrier and tell how many bytes are expected to come in. + ptx::mbarrier_arrive_expect_tx(barrier, 3 * num_bytes); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +} // namespace +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_PTX_CUH_ diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 432ac815ec..b8c8df37ee 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -14,66 +14,116 @@ #include "cuda_runtime.h" -#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ - pybind11::enum_(m, "DType") \ - .value("kByte", transformer_engine::DType::kByte) \ - .value("kInt32", transformer_engine::DType::kInt32) \ - .value("kFloat32", transformer_engine::DType::kFloat32) \ - .value("kFloat16", transformer_engine::DType::kFloat16) \ - .value("kBFloat16", transformer_engine::DType::kBFloat16) \ - .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ - pybind11::enum_(m, "NVTE_Bias_Type") \ - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ - .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ - pybind11::enum_(m, "NVTE_Mask_Type") \ - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ - .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ - .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ - .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ - NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ - pybind11::enum_(m, "NVTE_QKV_Layout") \ - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \ - pybind11::enum_(m, "NVTE_Fused_Attn_Backend") \ - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ - pybind11::enum_(m, "CommOverlapType") \ - .value("RS", transformer_engine::CommOverlapType::RS) \ - .value("AG", transformer_engine::CommOverlapType::AG); \ - pybind11::enum_(m, "CommOverlapAlgo") \ - .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ - .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ - .value("SPLIT_PIPELINED_AG_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ - .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ - .value("SPLIT_PIPELINED_RS_P2P", \ - transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ - .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ - .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ - m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ - py::call_guard(), py::arg("device_id") = -1); \ - m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ +#define NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) \ + pybind11::enum_(m, "DType", pybind11::module_local()) \ + .value("kByte", transformer_engine::DType::kByte) \ + .value("kInt32", transformer_engine::DType::kInt32) \ + .value("kFloat32", transformer_engine::DType::kFloat32) \ + .value("kFloat16", transformer_engine::DType::kFloat16) \ + .value("kBFloat16", transformer_engine::DType::kBFloat16) \ + .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ + .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ + .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ + .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \ + .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \ + pybind11::enum_(m, "NVTE_Mask_Type", pybind11::module_local()) \ + .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \ + .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \ + .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \ + .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \ + .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ + .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ + NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) \ + .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ + .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ + .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD) \ + .value("NVTE_SBHD_2BSHD", NVTE_QKV_Format::NVTE_SBHD_2BSHD) \ + .value("NVTE_BSHD_2SBHD", NVTE_QKV_Format::NVTE_BSHD_2SBHD) \ + .value("NVTE_THD_2BSHD", NVTE_QKV_Format::NVTE_THD_2BSHD) \ + .value("NVTE_THD_2SBHD", NVTE_QKV_Format::NVTE_THD_2SBHD); \ + pybind11::enum_(m, "NVTE_QKV_Layout", pybind11::module_local()) \ + .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \ + .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \ + .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) \ + .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) \ + .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) \ + .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) \ + .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) \ + .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) \ + .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) \ + .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) \ + .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) \ + .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \ + .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \ + .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \ + .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \ + .value("NVTE_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_SBHD_BSHD_BSHD) \ + .value("NVTE_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_BSHD_SBHD_SBHD) \ + .value("NVTE_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_THD_BSHD_BSHD) \ + .value("NVTE_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_THD_SBHD_SBHD) \ + .value("NVTE_Paged_KV_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_BSHD_BSHD) \ + .value("NVTE_Paged_KV_BSHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_BSHD_SBHD_SBHD) \ + .value("NVTE_Paged_KV_SBHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_BSHD_BSHD) \ + .value("NVTE_Paged_KV_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_SBHD_SBHD_SBHD) \ + .value("NVTE_Paged_KV_THD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_BSHD_BSHD) \ + .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD); \ + pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ + .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ + .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ + .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ + .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ + pybind11::enum_(m, "CommOverlapType", \ + pybind11::module_local()) \ + .value("RS", transformer_engine::CommOverlapType::RS) \ + .value("AG", transformer_engine::CommOverlapType::AG); \ + pybind11::enum_(m, "CommOverlapAlgo", \ + pybind11::module_local()) \ + .value("BULK_OVERLAP_AG", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_AG) \ + .value("BULK_OVERLAP_RS", transformer_engine::CommOverlapAlgo::BULK_OVERLAP_RS) \ + .value("SPLIT_PIPELINED_AG_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_AG_P2P) \ + .value("SPLIT_PIPELINED_RS", transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS) \ + .value("SPLIT_PIPELINED_RS_P2P", \ + transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ + .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ + .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + py::class_>(m, "CommOverlapCore", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapCore(); }), \ + py::call_guard()) \ + .def("is_atomic_gemm", &transformer_engine::CommOverlapCore::is_atomic_gemm, \ + py::call_guard()) \ + .def("is_p2p_overlap", &transformer_engine::CommOverlapCore::is_p2p_overlap, \ + py::call_guard()) \ + .def("is_fp8_ubuf", &transformer_engine::CommOverlapCore::is_fp8_ubuf, \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapBase", pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapBase(); }), \ + py::call_guard()); \ + py::class_, \ + transformer_engine::CommOverlapCore>(m, "CommOverlapP2PBase", \ + pybind11::module_local()) \ + .def(py::init([]() { return new transformer_engine::CommOverlapP2PBase(); }), \ + py::call_guard()); \ + m.def("device_supports_multicast", &transformer_engine::cuda::supports_multicast, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def( \ + "get_stream_priority_range", \ + [](int device_id = -1) { \ + int low_pri, high_pri; \ + transformer_engine::cuda::stream_priority_range(&low_pri, &high_pri, device_id); \ + return std::make_pair(low_pri, high_pri); \ + }, \ + py::call_guard(), py::arg("device_id") = -1); \ + m.def("ubuf_built_with_mpi", &transformer_engine::ubuf_built_with_mpi, \ py::call_guard()); #endif diff --git a/transformer_engine/common/util/rtc.cpp b/transformer_engine/common/util/rtc.cpp index c03654bfc5..bc286dd621 100644 --- a/transformer_engine/common/util/rtc.cpp +++ b/transformer_engine/common/util/rtc.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/rtc.h b/transformer_engine/common/util/rtc.h index 2c79d038b2..820b16c206 100644 --- a/transformer_engine/common/util/rtc.h +++ b/transformer_engine/common/util/rtc.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/shared_lib_wrapper.h b/transformer_engine/common/util/shared_lib_wrapper.h new file mode 100644 index 0000000000..3ccc8239b8 --- /dev/null +++ b/transformer_engine/common/util/shared_lib_wrapper.h @@ -0,0 +1,64 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ + +#include + +namespace transformer_engine { + +/*! \brief Wrapper class for a shared library + * + * \todo Windows support + */ +class Library { + public: + explicit Library(const char *filename) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support + NVTE_ERROR("Shared library initialization is not supported with Windows"); +#else + handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL); + NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed"); +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + ~Library() { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support +#else + if (handle_ != nullptr) { + dlclose(handle_); + } +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + Library(const Library &) = delete; // move-only + + void *get() noexcept { return handle_; } + + const void *get() const noexcept { return handle_; } + + /*! \brief Get pointer corresponding to symbol in shared library */ + void *get_symbol(const char *symbol) { +#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__) + // TODO Windows support + NVTE_ERROR("Shared library initialization is not supported with Windows"); +#else + void *ptr = dlsym(handle_, symbol); + NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library"); + return ptr; +#endif // _WIN32 or _WIN64 or __WINDOW__ + } + + private: + void *handle_ = nullptr; +}; + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_ diff --git a/transformer_engine/common/util/string.h b/transformer_engine/common/util/string.h index c0a2aa1077..0064144102 100644 --- a/transformer_engine/common/util/string.h +++ b/transformer_engine/common/util/string.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -13,15 +13,34 @@ namespace transformer_engine { -/*! \brief Convert to C-style or C++-style string */ +inline const std::string &to_string_like(const std::string &val) noexcept { return val; } + +constexpr const char *to_string_like(const char *val) noexcept { return val; } + +/* \brief Convert arithmetic type to string */ template ::value>::type> inline std::string to_string_like(const T &val) { return std::to_string(val); } -inline const std::string &to_string_like(const std::string &val) noexcept { return val; } - -constexpr const char *to_string_like(const char *val) noexcept { return val; } +/* \brief Convert container to string */ +template ::value>::type, + typename = decltype(std::declval().begin())> +inline std::string to_string_like(const T &container) { + std::string str; + str.reserve(1024); // Assume strings are <1 KB + str += "("; + bool first = true; + for (const auto &val : container) { + if (!first) { + str += ","; + } + str += to_string_like(val); + first = false; + } + str += ")"; + return str; +} /*! \brief Convert arguments to strings and concatenate */ template diff --git a/transformer_engine/common/util/string_header.h.in b/transformer_engine/common/util/string_header.h.in index adbbb90d73..b9fa83a94f 100644 --- a/transformer_engine/common/util/string_header.h.in +++ b/transformer_engine/common/util/string_header.h.in @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/common/util/system.cpp b/transformer_engine/common/util/system.cpp deleted file mode 100644 index 0659061b47..0000000000 --- a/transformer_engine/common/util/system.cpp +++ /dev/null @@ -1,76 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "../util/system.h" - -#include -#include -#include -#include -#include -#include - -#include "../common.h" - -namespace transformer_engine { - -namespace { - -template -inline typename std::enable_if::value, T>::type getenv_helper( - const char *variable, const T &default_value) { - // Implementation for numeric types - const char *env = std::getenv(variable); - if (env == nullptr || env[0] == '\0') { - return default_value; - } - T value; - std::istringstream iss(env); - iss >> value; - NVTE_CHECK(iss, "Invalid environment variable value"); - return value; -} - -template -inline typename std::enable_if::value, T>::type getenv_helper( - const char *variable, const T &default_value) { - // Implementation for string-like types - const char *env = std::getenv(variable); - if (env == nullptr || env[0] == '\0') { - return default_value; - } else { - return env; - } -} - -} // namespace - -#define NVTE_INSTANTIATE_GETENV(T, default_value) \ - template <> \ - T getenv(const char *variable, const T &default_value_) { \ - return getenv_helper(variable, default_value_); \ - } \ - template <> \ - T getenv(const char *variable) { \ - return getenv_helper(variable, default_value); \ - } -NVTE_INSTANTIATE_GETENV(bool, false); -NVTE_INSTANTIATE_GETENV(float, 0.f); -NVTE_INSTANTIATE_GETENV(double, 0.); -NVTE_INSTANTIATE_GETENV(int8_t, 0); -NVTE_INSTANTIATE_GETENV(int16_t, 0); -NVTE_INSTANTIATE_GETENV(int32_t, 0); -NVTE_INSTANTIATE_GETENV(int64_t, 0); -NVTE_INSTANTIATE_GETENV(uint8_t, 0); -NVTE_INSTANTIATE_GETENV(uint16_t, 0); -NVTE_INSTANTIATE_GETENV(uint32_t, 0); -NVTE_INSTANTIATE_GETENV(uint64_t, 0); -NVTE_INSTANTIATE_GETENV(std::string, std::string()); -NVTE_INSTANTIATE_GETENV(std::filesystem::path, std::filesystem::path()); - -bool file_exists(const std::string &path) { return static_cast(std::ifstream(path.c_str())); } - -} // namespace transformer_engine diff --git a/transformer_engine/common/util/system.h b/transformer_engine/common/util/system.h index 67626f7167..5636ab5095 100644 --- a/transformer_engine/common/util/system.h +++ b/transformer_engine/common/util/system.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -7,27 +7,96 @@ #ifndef TRANSFORMER_ENGINE_COMMON_UTIL_SYSTEM_H_ #define TRANSFORMER_ENGINE_COMMON_UTIL_SYSTEM_H_ +#include +#include +#include +#include +#include #include -#include "../common.h" +#include "logging.h" namespace transformer_engine { +namespace detail { + +/*! \brief Template specialization to get the env var for numeric data types */ +template +inline typename std::enable_if::value, T>::type getenv_helper( + const char *variable, const T &default_value) { + // Implementation for numeric types + const char *env = std::getenv(variable); + if (env == nullptr || env[0] == '\0') { + return default_value; + } + T value; + std::istringstream iss(env); + iss >> value; + NVTE_CHECK(iss, "Invalid environment variable value"); + return value; +} + +/*! \brief Template specialization to get the env var for string-like data types */ +template +inline typename std::enable_if::value, T>::type getenv_helper( + const char *variable, const T &default_value) { + // Implementation for string-like types + const char *env = std::getenv(variable); + if (env == nullptr || env[0] == '\0') { + return default_value; + } else { + return env; + } +} + +/*! \brief Template specialization to get the default values for different +* numeric data types +*/ +template +inline T getenv_default_value() { + return 0; +} + +/*! \brief Template specialization to get the default values for bool */ +template <> +inline bool getenv_default_value() { + return false; +} + +/*! \brief Template specialization to get the default values for string */ +template <> +inline std::string getenv_default_value() { + return std::string(); +} + +/*! \brief Template specialization to get the default values for filesystem +* path data type */ +template <> +inline std::filesystem::path getenv_default_value() { + return std::filesystem::path(); +} + +} // namespace detail + /*! \brief Get environment variable and convert to type * * If the environment variable is unset or empty, a falsy value is * returned. - */ +*/ template -T getenv(const char *variable); +inline T getenv(const char *variable) { + return detail::getenv_helper(variable, detail::getenv_default_value()); +} /*! \brief Get environment variable and convert to type */ template -T getenv(const char *variable, const T &default_value); +inline T getenv(const char *variable, const T &default_value) { + return detail::getenv_helper(variable, default_value); +} -/*! \brief Check if a file exists and can be read */ -bool file_exists(const std::string &path); +inline bool file_exists(const std::string &path) { + return std::filesystem::exists(path) && std::filesystem::is_regular_file(path); +} } // namespace transformer_engine - #endif // TRANSFORMER_ENGINE_COMMON_UTIL_SYSTEM_H_ diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 8653bf45a4..420b9ed3bb 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -44,6 +44,13 @@ class VectorizedStorage { return *this; } inline __device__ ~VectorizedStorage() {} + + /* \brief Access to separate elements. */ + inline __device__ DType *separate() { return scratch_.separate; } + + inline __device__ const DType *separate() const { return scratch_.separate; } + + inline __device__ LType &aligned() { return scratch_.aligned; } }; // Returns const LType is DType is const @@ -167,9 +174,11 @@ constexpr int unary_kernel_threads = 512; template __launch_bounds__(unary_kernel_threads) __global__ - void unary_kernel(const InputType *input, OutputType *output, const ComputeType *scale, - ComputeType *amax, ComputeType *scale_inv, Param p, const size_t N, - const size_t num_aligned_elements) { + void unary_kernel(const InputType *input, const ComputeType *noop, OutputType *output, + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, Param p, + const size_t N, const size_t num_aligned_elements) { + if (noop != nullptr && noop[0] == 1.0f) return; + VectorizedLoader loader(input, N); VectorizedStorer storer(output, N); ComputeType max = 0; @@ -322,9 +331,9 @@ Alignment CheckAlignment(const size_t lead_dim, const int nvec, const T... ptrs) template -void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, const fp32 *scale, - fp32 *amax, fp32 *scale_inv, const size_t N, const Param params, - cudaStream_t stream) { +void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, + const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, + const Param params, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -337,16 +346,16 @@ void VectorizedUnaryKernelLauncher(const InputType *input, OutputType *output, c switch (align) { case Alignment::SAME_ALIGNED: unary_kernel<<>>( - input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: unary_kernel<<>>( - input, output, scale, amax, scale_inv, params, N, num_aligned_elements); + input, noop, output, scale, amax, scale_inv, params, N, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize unary_kernel<1, true, fp32, Param, OP><<>>( - input, output, scale, amax, scale_inv, params, N, N); + input, noop, output, scale, amax, scale_inv, params, N, N); break; } } @@ -395,18 +404,19 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; const size_t id_y = tid / num_aligned_elements; VectorizedLoader loader0(input + id_y * n * 2, n); VectorizedLoader loader1(input + id_y * n * 2 + n, n); VectorizedStorer storer(output + id_y * n, n); - ComputeType max = 0; - ComputeType s = 1; - if constexpr (is_fp8::value) { - if (scale != nullptr) s = *scale; - } - const int warp_id = threadIdx.x / THREADS_PER_WARP; loader0.load(id_x, n); loader1.load(id_x, n); @@ -423,21 +433,20 @@ __launch_bounds__(unary_kernel_threads) __global__ storer.separate()[i] = static_cast(static_cast(temp)); } storer.store(id_x, n); - - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } - // Update scale-inverse - if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { - reciprocal(scale_inv, s); - } + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); } } } @@ -482,9 +491,17 @@ template __launch_bounds__(unary_kernel_threads) __global__ void dgated_act_kernel(const InputType *grad, const InputType *input, OutputType *output, + const ComputeType *scale, ComputeType *amax, ComputeType *scale_inv, const size_t m, const size_t n, const Param p, const size_t num_aligned_elements) { const size_t M = num_aligned_elements * m; + ComputeType max = 0; + ComputeType s = 1; + if constexpr (is_fp8::value) { + if (scale != nullptr) s = *scale; + } + const int warp_id = threadIdx.x / THREADS_PER_WARP; + for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) { const size_t id_x = tid % num_aligned_elements; const size_t id_y = tid / num_aligned_elements; @@ -507,12 +524,35 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; ComputeType after_dgate = grad_val * Activation(gelu_in, p); + if constexpr (is_fp8::value) { + __builtin_assume(max >= 0); + max = fmaxf(fabsf(after_dgelu), max); + after_dgelu = after_dgelu * s; + max = fmaxf(fabsf(after_dgate), max); + after_dgate = after_dgate * s; + } + storer0.separate()[i] = static_cast(after_dgelu); storer1.separate()[i] = static_cast(after_dgate); } storer0.store(id_x, n); storer1.store(id_x, n); } + if constexpr (is_fp8::value) { + // Reduce amax over block + if (amax != nullptr) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); + } + } + + // Update scale-inverse + if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { + reciprocal(scale_inv, s); + } + } } template void DGatedActivationKernelLauncher(const InputType *grad, const InputType *input, - OutputType *output, const size_t m, const size_t n, - const Param &p, cudaStream_t stream) { + OutputType *output, const fp32 *scale, fp32 *amax, + fp32 *scale_inv, const size_t m, const size_t n, const Param &p, + cudaStream_t stream) { if (m != 0 && n != 0) { size_t num_aligned_elements = get_num_aligned_elements(grad, n, nvec, sizeof(InputType)); constexpr size_t threads = unary_kernel_threads; @@ -532,18 +573,19 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu switch (auto align = CheckAlignment(n, nvec, input, input + n, output, output + n)) { case Alignment::SAME_ALIGNED: dgated_act_kernel - <<>>(grad, input, output, m, n, p, - num_aligned_elements); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, num_aligned_elements); break; case Alignment::SAME_UNALIGNED: dgated_act_kernel - <<>>(grad, input, output, m, n, p, - num_aligned_elements); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, num_aligned_elements); break; case Alignment::DIFFERENT: { // If the pointers are aligned differently we cannot vectorize dgated_act_kernel<1, true, ComputeType, Param, Activation, Dactivation> - <<>>(grad, input, output, m, n, p, n); + <<>>(grad, input, output, scale, amax, scale_inv, m, n, + p, n); break; } } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 6703ce728c..227b3aaa48 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -819,6 +819,21 @@ __device__ __forceinline__ float warp_reduce_max(const float m) { return tmp; } +__forceinline__ __device__ float warp_reduce_max_broadcast(const float val) { + float val_tmp = val; +#pragma unroll + for (int offset = THREADS_PER_WARP / 2; offset > 0; offset /= 2) { + const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset); + __builtin_assume(val_tmp >= 0); + __builtin_assume(val_other >= 0); + val_tmp = fmaxf(val_tmp, val_other); + } + // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id + constexpr int subwarp_lane_zero = 0; + val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero); + return val_tmp; +} + template __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int warpid) { __shared__ float staging[num_warps]; @@ -829,7 +844,7 @@ __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int war staging[warpid] = my_warp_max; } __syncthreads(); - compute_t result = 0; + compute_t result = 0.f; if (warpid == 0) { const float my_max = threadIdx.x < num_warps ? staging[threadIdx.x] : 0; result = warp_reduce_max(my_max); @@ -837,6 +852,29 @@ __device__ __forceinline__ compute_t reduce_max(const compute_t m, const int war return result; } +/** + * Max reduction in subwarps + * E.g., if nvec=4, each warp processes 128 elements (32 x 4), that covers four MXFP8 scaling factors. + * To compute an actual scaling factor for 32 consequentive elements, only 8 threads need to participate, + * thus splitting the warp into 4x smaller subwarps 8-thread width. + * 'Butterfly' reduction is used inside subwarps. + */ +template +__forceinline__ __device__ float subwarp_reduce_max_broadcast(const float val) { + float val_tmp = val; +#pragma unroll + for (int offset = subwarp_width / 2; offset > 0; offset /= 2) { + const float val_other = __shfl_down_sync(0xFFFFFFFF, val_tmp, offset, subwarp_width); + __builtin_assume(val_tmp >= 0); + __builtin_assume(val_other >= 0); + val_tmp = fmaxf(val_tmp, val_other); + } + // Broadcast the amax to other threads of the subwarp from the zero subwarp lane_id + constexpr int subwarp_lane_zero = 0; + val_tmp = __shfl_sync(0xFFFFFFFF, val_tmp, subwarp_lane_zero, subwarp_width); + return val_tmp; +} + // Works only on positive values __device__ __forceinline__ void atomicMaxFloat(float *addr, const float value) { atomicMax(reinterpret_cast(addr), __float_as_int(value)); @@ -857,6 +895,79 @@ __device__ __forceinline__ void reciprocal(float *value_inv, const float *value_inv = __frcp_rn(value); } +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using fp8e4m3 = __nv_fp8_e4m3; +using fp8e5m2 = __nv_fp8_e5m2; +using e8m0_t = uint8_t; + +constexpr uint32_t FP32_MANTISSA_BITS = 23; +constexpr uint32_t FP32_EXPONENT_BIAS = 127; + +enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 }; + +template +struct Numeric_Traits; + +template <> +struct Numeric_Traits { + static constexpr int maxUnbiasedExponent = 8; + static constexpr double maxNorm = 448; +}; + +template <> +struct Numeric_Traits { + static constexpr int maxUnbiasedExponent = 15; + static constexpr double maxNorm = 57344; +}; + +template +struct Quantized_Limits { + static constexpr int max_unbiased_exponent = Numeric_Traits::maxUnbiasedExponent; + static constexpr float max_norm = Numeric_Traits::maxNorm; + static constexpr float max_norm_rcp = 1.0 / max_norm; + static constexpr float emax = 1 << max_unbiased_exponent; + static constexpr float emax_rcp = 1.0 / emax; +}; + +__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } +#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); +#else + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +#endif +} + +__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { + return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); +} + } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ diff --git a/transformer_engine/common/utils.py b/transformer_engine/common/utils.py index 6fd9d141b4..a808e1571f 100644 --- a/transformer_engine/common/utils.py +++ b/transformer_engine/common/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """The utilities for Transformer Engine""" diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 05adbd624c..6dbe9c0e1d 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -1,17 +1,22 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Transformer Engine bindings for JAX""" # pylint: disable=wrong-import-position,wrong-import-order +import sys import logging +import importlib +import importlib.util import ctypes from importlib.metadata import version from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension +_logger = logging.getLogger(__name__) + def _load_library(): """Load shared library with Transformer Engine C extensions""" @@ -30,15 +35,15 @@ def _load_library(): "TransformerEngine package version mismatch. Found" f" {module_name} v{version(module_name)}, transformer-engine" f" v{version('transformer-engine')}, and transformer-engine-cu12" - f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" - " transformer-engine[jax]==VERSION'" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using " + "'pip3 install transformer-engine[jax]==VERSION'" ) if is_package_installed("transformer-engine-cu12"): if not is_package_installed(module_name): - logging.info( - "Could not find package %s. Install transformer-engine using 'pip" - " install transformer-engine[jax]==VERSION'", + _logger.info( + "Could not find package %s. Install transformer-engine using " + "'pip3 install transformer-engine[jax]==VERSION'", module_name, ) @@ -47,13 +52,20 @@ def _load_library(): so_dir = get_te_path() / "transformer_engine" so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: - so_dir = get_te_path() - so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) + try: + so_dir = get_te_path() / "transformer_engine" / "wheel_lib" + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) + except StopIteration: + so_dir = get_te_path() + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) - return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) + spec = importlib.util.spec_from_file_location(module_name, so_path) + solib = importlib.util.module_from_spec(spec) + sys.modules[module_name] = solib + spec.loader.exec_module(solib) -_TE_JAX_LIB_CTYPES = _load_library() +_load_library() from . import flax from .fp8 import fp8_autocast, update_collections, get_delayed_scaling from .fp8 import NVTE_FP8_COLLECTION_NAME diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 3ecc9bcd75..06629291da 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -1,20 +1,23 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX multi-head attention modules""" - +from __future__ import annotations from enum import Enum from functools import partial -from typing import Optional, Tuple +from typing import Optional, Tuple, Union +import warnings + from jax.ad_checkpoint import checkpoint_name import jax import jax.numpy as jnp +from flax.linen import make_attention_mask -from transformer_engine.transformer_engine_jax import NVTE_Bias_Type -from transformer_engine.transformer_engine_jax import NVTE_Mask_Type -from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout -from transformer_engine.transformer_engine_jax import NVTE_QKV_Format -from transformer_engine.transformer_engine_jax import nvte_get_qkv_format +from transformer_engine_jax import NVTE_Bias_Type +from transformer_engine_jax import NVTE_Mask_Type +from transformer_engine_jax import NVTE_QKV_Layout +from transformer_engine_jax import NVTE_QKV_Format +from transformer_engine_jax import nvte_get_qkv_format from . import cpp_extensions as tex @@ -46,6 +49,42 @@ class AttnMaskType(Enum): CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK + def is_causal(self): + """Returns True if the mask is a causal mask""" + return self in [ + AttnMaskType.CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_MASK, + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + def is_padding(self): + """Returns True if the mask includes padding""" + return self in [ + AttnMaskType.PADDING_MASK, + AttnMaskType.PADDING_CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + def is_bottom_right(self): + """Returns True if the causal mask is calculated from the bottom-right section""" + return self in [ + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + +class QKVFormat(Enum): + """ + SBHD: q,k,v memory layout with [s, b, ..., h, d] + BSHD: q,k,v memory layout with [b, s, ..., h, d] + THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence. + """ + + SBHD = NVTE_QKV_Format.NVTE_SBHD + BSHD = NVTE_QKV_Format.NVTE_BSHD + THD = NVTE_QKV_Format.NVTE_THD + class QKVLayout(Enum): """ @@ -66,17 +105,68 @@ class QKVLayout(Enum): THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD + def get_qkv_format(self): + """ + Return the corresponding qkv_format (BSHD, SBHD, THD) + """ + return QKVFormat(nvte_get_qkv_format(self.value)) -class QKVFormat(Enum): - """ - SBHD: q,k,v memory layout with [s, b, ..., h, d] - BSHD: q,k,v memory layout with [b, s, ..., h, d] - THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence. - """ + def is_qkvpacked(self): + """ + Return True if the query, key, value is packed + """ + return self in [QKVLayout.BS3HD, QKVLayout.T3HD] - SBHD = NVTE_QKV_Format.NVTE_SBHD - BSHD = NVTE_QKV_Format.NVTE_BSHD - THD = NVTE_QKV_Format.NVTE_THD + def is_kvpacked(self): + """ + Return True if the key, value is packed + """ + return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD] + + def is_separate(self): + """ + Return True if the query, key, value are three separate tensors + """ + return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD] + + def is_thd(self): + """ + Return True if the layout belongs to THD + """ + return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] + + def to_qkvpacked(self): + """ + Return the corresponding qkvpacked format, useful when adjusting q, k, v layout + """ + qkv_format = self.get_qkv_format() + if qkv_format == QKVFormat.BSHD: + return QKVLayout.BS3HD + if qkv_format == QKVFormat.THD: + return QKVLayout.T3HD + raise ValueError(f"Unsupported {qkv_format=}") + + def to_kvpacked(self): + """ + Return the corresponding kvpacked format, useful when adjusting q, k, v layout + """ + qkv_format = self.get_qkv_format() + if qkv_format == QKVFormat.BSHD: + return QKVLayout.BSHD_BS2HD + if qkv_format == QKVFormat.THD: + return QKVLayout.THD_T2HD + raise ValueError(f"Unsupported {qkv_format=}") + + def to_separate(self): + """ + Return the corresponding separate format, useful when adjusting q, k, v layout + """ + qkv_format = self.get_qkv_format() + if qkv_format == QKVFormat.BSHD: + return QKVLayout.BSHD_BSHD_BSHD + if qkv_format == QKVFormat.THD: + return QKVLayout.THD_THD_THD + raise ValueError(f"Unsupported {qkv_format=}") class CPStrategy(Enum): @@ -92,71 +182,67 @@ class CPStrategy(Enum): RING = 2 -def get_qkv_format(qkv_layout): +class ReorderStrategy(Enum): """ - Get qkv_format from qkv_layout + Defines the tokens re-order strategy for context parallel load balancing for causal mask. + + - DualChunkSwap: This strategy splits each query into two chunks and do the mirror swap between + GPUs. This is currently used for non-THD load balance. It requires the max_seqlens be the + mulitple of 2 * cp_size. + Examples: + - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; GPU2: [8, 9, 10, 11]; GPU3: [12, 13, 14, 15]; + - After reorder: GPU0: [0, 1, 14, 15]; GPU1: [4, 5, 10, 11]; GPU2: [8, 9, 6, 7]; GPU3: [12, 13, 2, 3] + + - Striped: This strategy distributes the tokens in a striped (interleaved) manner across + the sequence. This is currently used for THD load balance. + Example: Consider 4 GPUs with seqlens=16. + - Before reorder: GPU0: [0, 1, 2, 3]; GPU1: [4, 5, 6, 7]; ...; GPU3: [12, 13, 14, 15] + - After reorder: GPU0: [0, 4, 8, 12]; GPU1: [1, 5, 9, 13]; ...; GPU3: [3, 7, 11, 15] """ - return QKVFormat(nvte_get_qkv_format(qkv_layout.value)) + + DualChunkSwap = 0 + Striped = 1 def make_swa_mask( - max_seqlen_q: int, - max_seqlen_kv: int, + segment_pos_q: jnp.ndarray, + segment_pos_kv: jnp.ndarray, window_size: Optional[Tuple[int, int]] = None, - attn_mask_type: AttnMaskType = AttnMaskType.NO_MASK, dtype: jax.typing.DTypeLike = jnp.float32, ): """ - Generate sliding window mask. `True` or `1` means keep the element. - - For `CAUSAL_BOTTOM_RIGHT_MASK` and `PADDING_CAUSAL_BOTTOM_RIGHT_MASK` mask type, - the sliding window diagonal is aligned to the bottom right corner, and for other - mask types, the top left corner. - - Parameters - ---------- - max_seqlen_q: int - Maximum sequence length for queries. - max_seqlen_kv: int - Maximum sequence length for keys and values. - window_size: Optional[Tuple[int, int]] = None - Sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Negative number in window size means infinity window. - `None` means no sliding window. - attn_mask_type: AttnMaskType, default = AttnMaskType.NO_MASK - dtype: jax.typing.DTypeLike, default=jnp.float32 - The mask data type. - Returns - ---------- - swa_mask: jax.numpy.tensor - Matrix with shape [max_seqlen_q, max_seqlen_kv]. Elements with value 1 are the positions - that will get attention, value 0 are the masked out positions. + Generate a sliding window mask (1 = attend, 0 = masked). + + Args: + segment_pos_q (jnp.ndarray): + Query positions within each segment. For example, a batch with segment_ids = + [[1, 1, 1, 2, 2, 2, 2, 2]] yields segment_pos = + [[0, 1, 2, 0, 1, 2, 3, 4]]. + segment_pos_kv (jnp.ndarray): + Key/value positions within each segment. + window_size (Optional[Tuple[int, int]], optional): + Sliding window size for local attention, where query at position i attends to keys + in [i - window_size[0], i + window_size[1]] inclusive. A negative number means an + infinite window; None means no sliding window. + Defaults to None. + dtype (jax.typing.DTypeLike, optional): + Mask data type. Defaults to jnp.float32. + + Returns: + jnp.ndarray: + The mask with shape [b, 1, max_seqlen_q, max_seqlen_kv]. """ - swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype) - if window_size is None: - return swa_mask - bottom_right_masks = [ - AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, - ] - left_window, right_window = window_size - if attn_mask_type in bottom_right_masks: - if left_window < 0: - left_window = max_seqlen_kv - if right_window < 0: - right_window = max_seqlen_kv - bottom_right_shift = max_seqlen_kv - max_seqlen_q - swa_mask = jnp.triu(swa_mask, k=-left_window + bottom_right_shift) - swa_mask = jnp.tril(swa_mask, k=right_window + bottom_right_shift) + if window_size is not None: + left_window, right_window = window_size else: - if left_window < 0: - left_window = max_seqlen_q - if right_window < 0: - right_window = max_seqlen_q - swa_mask = jnp.triu(swa_mask, k=-left_window) - swa_mask = jnp.tril(swa_mask, k=right_window) - return swa_mask + left_window = right_window = jnp.inf + left_window = jnp.inf if left_window < 0 else left_window + right_window = jnp.inf if right_window < 0 else right_window + pos_q = jnp.expand_dims(segment_pos_q, axis=-1) + pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2) + inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window) + inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3) + return inv_swa_mask.astype(dtype) def canonicalize_attn_mask_type(attn_mask_type: str): @@ -212,9 +298,9 @@ def make_helper(attn_mask_type): return tex.FusedAttnHelper( q_dtype, kv_dtype, - qkv_layout.value, - attn_bias_type.value, - attn_mask_type.value, + qkv_layout, + attn_bias_type, + attn_mask_type, dropout_probability, q_num_heads, kv_num_heads, @@ -224,44 +310,312 @@ def make_helper(attn_mask_type): (-1, -1) if window_size is None else window_size, ) - if not make_helper(attn_mask_type).is_fused_attn_kernel_available(): - return False - - return True + return make_helper(attn_mask_type).is_fused_attn_kernel_available() def _obtain_batch_and_max_seqlen(qkv, qkv_layout): - match qkv_layout: - case QKVLayout.BS3HD | QKVLayout.T3HD: - assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}" - batch, q_max_seqlen, *_ = qkv[0].shape - kv_max_seqlen = q_max_seqlen - case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD: - assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}" - batch, q_max_seqlen, *_ = qkv[0].shape - kv_max_seqlen = qkv[1].shape[1] - case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD: - assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}" - batch, q_max_seqlen, *_ = qkv[0].shape - kv_max_seqlen = qkv[1].shape[1] - case _: - raise ValueError(f"Unsupported {qkv_layout=}") + if qkv_layout.is_qkvpacked(): + assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}" + batch, q_max_seqlen, *_ = qkv[0].shape + kv_max_seqlen = q_max_seqlen + elif qkv_layout.is_kvpacked(): + assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}" + batch, q_max_seqlen, *_ = qkv[0].shape + kv_max_seqlen = qkv[1].shape[1] + elif qkv_layout.is_separate(): + assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}" + batch, q_max_seqlen, *_ = qkv[0].shape + kv_max_seqlen = qkv[1].shape[1] + else: + raise ValueError(f"Unsupported {qkv_layout=}") return batch, q_max_seqlen, kv_max_seqlen -def reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): +def reorder_causal_load_balancing(tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int): """Reorders a tensor for load balancing the compute of causal attention.""" - seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0 - return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, False) + if strategy == ReorderStrategy.DualChunkSwap: + return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, False) + if strategy == ReorderStrategy.Striped: + return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, False) + raise ValueError(f"Unsupported {strategy=}") -def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: QKVFormat): +def inverse_reorder_causal_load_balancing( + tensor, strategy: ReorderStrategy, cp_size: int, seq_dim: int +): """Inverse operation of `reorder_causal_load_balancing`.""" - seq_dim = 1 if tensor_format == QKVFormat.BSHD else 0 - return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True) + if strategy == ReorderStrategy.DualChunkSwap: + return tex.attention.reorder_causal_dual_chunk_swap(tensor, cp_size, seq_dim, True) + if strategy == ReorderStrategy.Striped: + return tex.attention.reorder_causal_striped(tensor, cp_size, seq_dim, True) + raise ValueError(f"Unsupported {strategy=}") -def fused_attn( +def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq): + # bincount map with 0s + bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1)) + seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) + seqlens = seqlens_with_zero[..., 1:] + + def _find_offsets(x): + same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) + first_column = x[..., :1] != 0 + same_as_previous = jnp.hstack((first_column, same_as_previous)) + return jax.vmap(partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1))( + same_as_previous + ).squeeze(-1) + + offsets = _find_offsets(segment_ids) + return seqlens, offsets + + +def _mask_to_seqlens_offset(mask, max_segments_per_seq): + assert mask.shape[1] == 1 + row_ids = mask.squeeze(axis=1).max(axis=-1) + q_seqlen, q_offset = _get_seqlens_and_offsets(row_ids, max_segments_per_seq) + col_ids = mask.squeeze(axis=1).max(axis=-2) + kv_seqlen, kv_offset = _get_seqlens_and_offsets(col_ids, max_segments_per_seq) + return q_seqlen, q_offset, kv_seqlen, kv_offset + + +def _segment_ids_pos_to_seqlens_offsets( + segment_ids_q, + segment_ids_kv, + segment_pos_q, + segment_pos_kv, + attn_mask_type, + window_size, + max_segments_per_seq, +): + # (1 = attend, 0 = masked) + segment_mask = make_attention_mask( + segment_ids_q, + segment_ids_kv, + jnp.equal, + ) + segment_mask_with_id = make_attention_mask( + segment_ids_q, + segment_ids_kv, + lambda x, y: jnp.equal(x, y) * x, + ) + attn_mask = segment_mask + if attn_mask_type.is_causal(): + causal_mask = make_attention_mask( + segment_pos_q, + segment_pos_kv, + jnp.greater_equal, + ) + attn_mask = jnp.logical_and(segment_mask, causal_mask) + + swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) + attn_mask = jnp.logical_and(attn_mask, swa_mask) + + attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) + q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset( + attn_mask_with_id, max_segments_per_seq + ) + return q_seqlen, kv_seqlen, q_offset, kv_offset + + +def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type): + # convert the mask to seqlens, mask doesn't support ragged offsets + if not attn_mask_type.is_padding(): + q_max_seqlen = segment_ids_q.shape[-1] + kv_max_seqlen = segment_ids_kv.shape[-1] + q_seq_lens = jnp.full_like(q_max_seqlen, q_max_seqlen, dtype=jnp.int32) + kv_seq_lens = jnp.full_like(kv_max_seqlen, kv_max_seqlen, dtype=jnp.int32) + else: + q_seq_lens = jnp.sum(segment_ids_q, axis=-1).astype(jnp.int32) + kv_seq_lens = jnp.sum(segment_ids_kv, axis=-1).astype(jnp.int32) + return q_seq_lens, kv_seq_lens + + +@jax.tree_util.register_pytree_node_class +class SequenceDescriptor: + """A class to descibe the sequences with flexible initialization. + - SequenceDescriptor.from_seqlens + For non-THD (non-packed) cases, where each batch has only 1 sequence. + - SequenceDescriptor.from_seqlens_and_offsets + For THD (packed) cases, where each batch may have not only 1 sequence. + - SequenceDescriptor.from_segment_ids_and_pos + Experimental feature for THD (packed) cases with context parallelism. + """ + + seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + seq_offsets: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + segment_ids: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + segment_pos: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + + def __init__(self, seqlens=None, seq_offsets=None, segment_ids=None, segment_pos=None): + """ + Initialize to Tuple(jnp.zeros, jnp.zeros) because the primitive only accepts pure jax array + """ + self.seqlens = (jnp.zeros(0), jnp.zeros(0)) if seqlens is None else seqlens + self.seq_offsets = (jnp.zeros(0), jnp.zeros(0)) if seq_offsets is None else seq_offsets + self.segment_ids = (jnp.zeros(0), jnp.zeros(0)) if segment_ids is None else segment_ids + self.segment_pos = (jnp.zeros(0), jnp.zeros(0)) if segment_pos is None else segment_pos + + def tree_flatten(self): + """ + Flatten method to register as a pytree node + """ + return ((self.seqlens, self.seq_offsets, self.segment_ids, self.segment_pos), None) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """ + Unflatten method to register as a pytree node + """ + del aux_data + return cls(*children) + + def get_seqlens_and_offsets( + self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq + ): + """ + Acquire the seqlens/offsets for cuDNN backend + """ + q_segment_ids, kv_segment_ids = self.segment_ids + q_segment_pos, kv_segment_pos = self.segment_pos + assert q_segment_ids.shape == q_segment_pos.shape + assert kv_segment_ids.shape == kv_segment_pos.shape + # No segment_ids/segment_pos + if q_segment_ids.size + kv_segment_ids.size == 0: + return self.seqlens, self.seq_offsets + + if qkv_layout.is_thd(): + q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets( + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + attn_mask_type, + window_size, + max_segments_per_seq, + ) + else: + q_seqlens, kv_seqlens = _segment_ids_to_seqlens( + q_segment_ids, + kv_segment_ids, + attn_mask_type, + ) + q_offsets = kv_offsets = jnp.zeros(0) + return (q_seqlens, kv_seqlens), (q_offsets, kv_offsets) + + @classmethod + def _expand_to_pair( + cls, value: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]] + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Internal helper to ensure a single value expands into a pair (q, kv). + """ + if isinstance(value, tuple): + if len(value) != 2: + raise ValueError("Input tuple must have exactly 2 elements.") + return value + + if isinstance(value, jnp.ndarray): + return value, value # Duplicate for q=kv case + + raise TypeError( + "Expected a jax.numpy.ndarray or a tuple of two jax.numpy.ndarray, " + f"but got {type(value).__name__}." + ) + + @classmethod + def from_seqlens( + cls, + seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + ) -> SequenceDescriptor: + """ + Factory method for inputs with sequence lengths only (non-THD). + Args: + seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens): + - q_seqlens (jnp.ndarray): + Sequence lengths for the query, with shape [batch]. + - kv_seqlen (jnp.ndarray): + Sequence lengths for the key and value, with shape [batch]. + Return: + A SequenceDescriptor with only seqlens initialized. + """ + q_seqlens, kv_seqlens = cls._expand_to_pair(seqlens) + return cls(seqlens=(q_seqlens, kv_seqlens)) + + @classmethod + def from_seqlens_and_offsets( + cls, + seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + seq_offsets: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + ) -> SequenceDescriptor: + """ + Factory method for inputs with sequence lengths and offsets (THD). + Args: + seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens): + - q_seqlens (jnp.ndarray): + Sequence lengths for the query, with shape [batch, max_seqlen]. + Unused positions are padded with -1. + - kv_seqlen (jnp.ndarray): + Sequence lengths for the key and value, with shape [batch, max_seqlen]. + Unused positions are padded with -1. + seq_offsets(Tuple(jnp.ndarray, jnp.ndarray)) = (q_offsets, kv_offsets) + - q_seq_offsets (jnp.ndarray): + The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1]. + Unused positions are padded with -1. + - kv_seq_offsets (jnp.ndarray): + The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1]. + Unused positions are padded with -1. + Return: + A SequenceDescriptor with seqlens/seq_offsets initialized. + """ + q_seqlens, kv_seqlens = cls._expand_to_pair(seqlens) + q_offsets, kv_offsets = cls._expand_to_pair(seq_offsets) + return cls(seqlens=(q_seqlens, kv_seqlens), seq_offsets=(q_offsets, kv_offsets)) + + @classmethod + def from_segment_ids_and_pos( + cls, + segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> SequenceDescriptor: + """ + Experimental factory method for inputs with segment IDs and optional positions. (THD) + Args: + segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids): + - q_segment_ids (jnp.ndarray): + Query segment ids start with 1, with shape [batch, max_seqlen]. + 0s are treated as paddings. + - kv_segment_ids (jnp.ndarray): + Key, value segment ids start with 1, with shape [batch, max_seqlen]. + 0s are treated as paddings. + segment_pos(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_pos, kv_segment_pos) + - q_segment_pos (jnp.ndarray): + The position inside each segment for query, with shape [batch, max_seqlen]. + - kv_segment_pos (jnp.ndarray): + The position inside each segment for key, value, with shape [batch, max_seqlen]. + Return: + A SequenceDescriptor with segment_ids/segment_pos initialized. + """ + q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) + + if segment_pos is not None: + segment_pos = cls._expand_to_pair(segment_pos) + else: + + def generate_default_pos(segment_ids): + seqlen = segment_ids.shape[-1] + return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape) + + q_seg_pos = generate_default_pos(q_seg_ids) + kv_seg_pos = generate_default_pos(kv_seg_ids) + segment_pos = (q_seg_pos, kv_seg_pos) + + return cls( + segment_ids=(q_seg_ids, kv_seg_ids), + segment_pos=segment_pos, + ) + + +def _legacy_fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], mask: Optional[jnp.ndarray], @@ -296,9 +650,9 @@ def fused_attn( Intra-sequence padding is not valid. The padded tokens can only on the right-most. Otherwise the results will be wrong. seed (Optional[jnp.ndarray]): Optional random seed for dropout. - attn_bias_type (NVTE_Bias_Type): Type of attention bias. - attn_mask_type (NVTE_Mask_Type): Type of attention mask. - qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. + attn_bias_type (AttnBiasType): Type of attention bias. + attn_mask_type (AttnMaskType): Type of attention mask. + qkv_layout (QKVLayout): Layout of the QKV tensors. scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. @@ -310,28 +664,26 @@ def fused_attn( (jnp.ndarray): The output tensor from the fused attention. """ assert ( - get_qkv_format(qkv_layout) != QKVFormat.THD + not qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format." # Check inputs qkv match qkv_layout: - case NVTE_QKV_Layout.NVTE_BS3HD: + case QKVLayout.BS3HD: assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + case QKVLayout.BSHD_BS2HD: assert ( len(qkv) == 2 ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + case QKVLayout.BSHD_BSHD_BSHD: assert ( len(qkv) == 3 ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" + case _: + raise ValueError(f"Unknown {qkv_layout=}") # convert the mask to seqlens, mask doesn't support ragged offsets - if attn_mask_type in [ - AttnMaskType.NO_MASK, - AttnMaskType.CAUSAL_MASK, - AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - ]: + if not attn_mask_type.is_padding(): batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32) kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32) @@ -348,10 +700,7 @@ def fused_attn( output = _fused_attn( qkv, bias, - q_seq_lens, - kv_seq_lens, - None, - None, + SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)), seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, @@ -390,79 +739,31 @@ def fused_attn_thd( context_parallel_axis: str = "", ): """ - (Experimental) Perform THD (packed) cuDNN fused attention. - - This function implements the following formula: - BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 - Args: - qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors. - It supports three formats: - - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key, - and value have the same shape (e.g., self-attention). - - `(query, kv_packed)`: For separate query and KV packed format, typically used when - query has a different shape (e.g., cross-attention). - - `(query, key, value)`: For separate query, key, and value tensors. - bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores. - q_seqlen (jnp.ndarray): - Sequence lengths for the query, with shape [batch, max_seqlen]. Unused positions are - padded with -1. - kv_seqlen (jnp.ndarray): - Sequence lengths for the key and value, with shape [batch, max_seqlen]. Unused positions - are padded with -1. - q_seq_offsets (jnp.ndarray): - The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1]. - Unused positions are padded with -1. - kv_seq_offsets (jnp.ndarray): - The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1]. - Unused positions are padded with -1. - seed (Optional[jnp.ndarray]): Optional random seed for dropout. - attn_bias_type (NVTE_Bias_Type): Type of attention bias. - attn_mask_type (NVTE_Mask_Type): Type of attention mask. - qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. - scaling_factor (float): Scaling factor for the attention scores. - dropout_probability (float): Dropout probability to apply during attention. - is_training (bool): Flag indicating whether the model is in training mode. - max_segments_per_seq (int): - Indicating the maximum number of segments inside a sequence. This parameter is to - constrain the limit usage and need to be static during the e2e training. The XLA compile - time and memory consumption is proportional to `max_segments_per_seq`. - window_size (Optional[Tuple[int, int]]): - Sliding window size. - context_parallel_causal_load_balanced (bool): - Indicates the sequences are ordered for causal mask load balancing when running context parallelism. - context_parallel_axis (str): The name of the context parallel axis. - Returns: - (jnp.ndarray): The output tensor from the fused attention. - - Examples: - >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens - >>> b, s, h, d = 2, 4, 12, 64 - >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16) - >>> # 3 segments in first seq, 2 segments in second seq - >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]]) - >>> # seq_offsets need to include the end offset of the last segments - >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]]) - >>> out = fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens, - q_seq_offsets, kv_seq_offsets, None, - AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, - QKVLayout.T3HD, 0.125, 0, True, 3) + Deprecated THD fused attn, please use fusd_attn with SequenceDescriptor """ + warnings.warn( + "fused_attn_thd is deprecated, please use fused_attn with SequenceDescriptor", + DeprecationWarning, + ) + assert ( - get_qkv_format(qkv_layout) == QKVFormat.THD + qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format." # Check inputs qkv match qkv_layout: - case NVTE_QKV_Layout.NVTE_T3HD: + case QKVLayout.T3HD: assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" - case NVTE_QKV_Layout.NVTE_THD_T2HD: + case QKVLayout.THD_T2HD: assert ( len(qkv) == 2 ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" - case NVTE_QKV_Layout.NVTE_THD_THD_THD: + case QKVLayout.THD_THD_THD: assert ( len(qkv) == 3 ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" + case _: + raise ValueError(f"Unknown {qkv_layout=}") batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) assert q_seq_lens.shape == (batch, q_max_seqlen) @@ -473,10 +774,9 @@ def fused_attn_thd( output = _fused_attn( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + SequenceDescriptor.from_seqlens_and_offsets( + (q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets) + ), seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, @@ -494,15 +794,12 @@ def fused_attn_thd( return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], - q_seq_lens: jnp.ndarray, - kv_seq_lens: jnp.ndarray, - q_seq_offsets: Optional[jnp.ndarray], - kv_seq_offsets: Optional[jnp.ndarray], - seed: jnp.ndarray, + sequence_descriptor: SequenceDescriptor, + seed: Optional[jnp.ndarray], attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, qkv_layout: QKVLayout, @@ -518,10 +815,7 @@ def _fused_attn( output, _ = _fused_attn_fwd_rule( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, seed, attn_bias_type, attn_mask_type, @@ -541,10 +835,7 @@ def _fused_attn( def _fused_attn_fwd_rule( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, seed, attn_bias_type, attn_mask_type, @@ -561,14 +852,11 @@ def _fused_attn_fwd_rule( output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, seed, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - qkv_layout=qkv_layout.value, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training, @@ -584,10 +872,7 @@ def _fused_attn_fwd_rule( return output, ( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, softmax_aux, rng_state, output, @@ -612,10 +897,7 @@ def _fused_attn_bwd_rule( ( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, softmax_aux, rng_state, output, @@ -627,13 +909,10 @@ def _fused_attn_bwd_rule( rng_state, output, dz, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - qkv_layout=qkv_layout.value, + sequence_descriptor, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training, @@ -645,7 +924,137 @@ def _fused_attn_bwd_rule( ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None - return grad_qkv, grad_bias, None, None, None, None, None + return ( + grad_qkv, + grad_bias, + None, + None, + ) _fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule) + + +def fused_attn( + qkv: Tuple[jnp.ndarray, ...], + bias: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, + seed: Optional[jnp.ndarray], + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + scaling_factor: float, + dropout_probability: float, + is_training: bool, + max_segments_per_seq: int = 1, + window_size: Optional[Tuple[int, int]] = None, + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", +): + """ + Perform cuDNN fused attention. + + This function implements the following formula: + BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 + Args: + qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors. + It supports three formats: + - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key, + and value have the same shape (e.g., self-attention). + - `(query, kv_packed)`: For separate query and KV packed format, typically used when + query has a different shape (e.g., cross-attention). + - `(query, key, value)`: For separate query, key, and value tensors. + bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores. + sequence_descriptor (SequenceDescriptor): Descriptor for how to describe the sequence. + seed (Optional[jnp.ndarray]): Optional random seed for dropout. + attn_bias_type (AttnBiasType): Type of attention bias. + attn_mask_type (AttnMaskType): Type of attention mask. + qkv_layout (QKVLayout): Layout of the QKV tensors. + scaling_factor (float): Scaling factor for the attention scores. + dropout_probability (float): Dropout probability to apply during attention. + is_training (bool): Flag indicating whether the model is in training mode. + max_segments_per_seq (int): + Indicating the maximum number of segments inside a sequence. This parameter is to + constrain the limit usage and need to be static during the e2e training. The XLA compile + time and memory consumption is proportional to `max_segments_per_seq`. + window_size (Optional[Tuple[int, int]]): + Sliding window size. + context_parallel_causal_load_balanced (bool): + Indicates the sequences are ordered for causal mask load balancing when running context parallelism. + context_parallel_axis (str): The name of the context parallel axis. + Returns: + (jnp.ndarray): The output tensor from the fused attention. + + Examples (non-THD, also known as non-packed): + >>> # q_segment_ids = [[1, 1, 1, 0], [1, 1, 0, 0]], 0 means padded tokens + >>> # kv_segment_ids = [[1, 0, 0, 0], [1, 1, 0, 0]], 0 means padded tokens + >>> b, s, h, d = 2, 4, 12, 64 + >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16) + >>> q_seq_lens = jnp.asarray([3, 2]) + >>> kv_seq_lens = jnp.asarray([1, 2]) + >>> sequence_desc = SequenceDescriptor.from_seqlens( + seqlens=(q_seq_lens, kv_seq_lens)) + >>> out = fused_attn((qkv,), None, sequence_desc, None, + AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, + QKVLayout.BS3HD, 0.125, 0, True, 3) + + Examples (THD, also known as packed): + >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens + >>> # segment_pos = [[0, 1, 0, 0], [0, 1, 0, 1]] + >>> b, s, h, d = 2, 4, 12, 64 + >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16) + >>> # 3 segments in first seq, 2 segments in second seq + >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]]) + >>> # seq_offsets need to include the end offset of the last segments + >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]]) + >>> sequence_desc = SequenceDescriptor.from_seqlens_and_offsets( + seqlens=(q_seq_lens, kv_seq_lens), + seq_offsets=(q_seq_offsets, kv_seq_offsets)) + >>> out = fused_attn((qkv,), None, sequence_desc, None, + AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK, + QKVLayout.T3HD, 0.125, 0, True, 3) + """ + if sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray): + warnings.warn( + "Pass mask to fused_attn is deprecated, please use SequenceDescriptor instead. " + + "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.", + DeprecationWarning, + ) + if max_segments_per_seq != 1: + raise ValueError("Passing mask is only supported for non-THD case.") + return _legacy_fused_attn( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + window_size=window_size, + context_parallel_strategy=context_parallel_strategy, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, + ) + output = _fused_attn( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + window_size=window_size, + context_parallel_strategy=context_parallel_strategy, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, + ) + + return output diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index 579daa8e41..dfb68c113c 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Python interface for c++ extensions""" diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 44b396ad55..c9c40de7e3 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1,20 +1,20 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom ops for activation""" from typing import Tuple, Sequence, Union, Callable import operator from functools import reduce, partial +from packaging import version import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi -from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import NVTE_Activation_Type +import transformer_engine_jax +from transformer_engine_jax import NVTE_Activation_Type from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper @@ -28,6 +28,11 @@ from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + __all__ = ["act_lu", "dact_lu", "act_lu_fp8"] @@ -98,7 +103,7 @@ def abstract(x_aval, *, act_enum): # pylint: disable=unused-argument assert x_shape[-2] == 2 or x_shape[-2] == 1 hidden_size = x_shape[-1] batch_shapes = x_shape[:-2] - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval out_shape = (batch_shapes) + (hidden_size,) out_aval = out_aval.update(shape=out_shape, dtype=dtype) @@ -225,7 +230,7 @@ def abstract(dz_aval, x_aval, *, act_enum): # pylint: disable=unused-argument i_hidden_size = dz_aval.shape[-1] g_hidden_size = x_aval.shape[-1] assert i_hidden_size == g_hidden_size - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval return out_aval diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6591861057..47425fe6d5 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1,13 +1,14 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom ops for attention""" -from dataclasses import dataclass -from functools import partial, reduce, cache import operator import os -from typing import Optional, Tuple import warnings +from dataclasses import dataclass, replace +from functools import partial, reduce +from typing import Optional, Tuple +from packaging import version import jax import jax.numpy as jnp @@ -15,19 +16,18 @@ from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi - -from transformer_engine.jax.attention import CPStrategy - -from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import ( - NVTE_Bias_Type, - NVTE_Mask_Type, - NVTE_QKV_Layout, - NVTE_QKV_Format, - NVTE_Fused_Attn_Backend, - nvte_get_qkv_format, + +import transformer_engine_jax +from transformer_engine_jax import NVTE_Fused_Attn_Backend +from transformer_engine.jax.attention import ( + AttnBiasType, + AttnMaskType, + QKVLayout, + QKVFormat, + CPStrategy, + SequenceDescriptor, ) + from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper from .misc import ( @@ -37,7 +37,6 @@ get_padded_spec, get_cudnn_version, is_ffi_enabled, - get_xla_flag, ) from ..sharding import ( global_mesh_resource, @@ -50,6 +49,12 @@ ) +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + + __all__ = [ "FusedAttnHelper", "fused_attn_fwd", @@ -79,9 +84,9 @@ class _FusedAttnConfig: Passes static configuration properties of fused attention. """ - attn_bias_type: NVTE_Bias_Type - attn_mask_type: NVTE_Mask_Type - qkv_layout: NVTE_QKV_Layout + attn_bias_type: AttnBiasType + attn_mask_type: AttnMaskType + qkv_layout: QKVLayout scaling_factor: float dropout_probability: float is_training: bool @@ -99,9 +104,9 @@ class FusedAttnHelper: q_dtype: jnp.dtype kv_dtype: jnp.dtype - qkv_layout: NVTE_QKV_Layout - attn_bias_type: NVTE_Bias_Type - attn_mask_type: NVTE_Mask_Type + qkv_layout: QKVLayout + attn_bias_type: AttnBiasType + attn_mask_type: AttnMaskType dropout_probability: float q_num_heads: int kv_num_heads: int @@ -119,9 +124,9 @@ def get_fused_attn_backend(self): return transformer_engine_jax.get_fused_attn_backend( jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype), - self.qkv_layout, - self.attn_bias_type, - self.attn_mask_type, + self.qkv_layout.value, + self.attn_bias_type.value, + self.attn_mask_type.value, self.dropout_probability, self.q_num_heads, self.kv_num_heads, @@ -133,7 +138,6 @@ def get_fused_attn_backend(self): ) @staticmethod - @cache def is_non_deterministic_allowed(): """Check if non-deterministic kernels are allowed""" return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) @@ -141,24 +145,25 @@ def is_non_deterministic_allowed(): @staticmethod def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): """Parse qkv aval""" - match qkv_layout: - case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: - *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape - kv_batch_shape = q_batch_shape - kv_max_seqlen = q_max_seqlen - num_gqa_groups = attn_heads - kv_head_dim = q_head_dim - assert nqkv == 3 - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape - assert nkv == 2 - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape - assert k_aval.shape == v_aval.shape - case _: - raise ValueError(f"Unexpected {qkv_layout=}") + if qkv_layout.get_qkv_format() == QKVFormat.SBHD: + raise NotImplementedError + if qkv_layout.is_qkvpacked(): + *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape + kv_batch_shape = q_batch_shape + kv_max_seqlen = q_max_seqlen + num_gqa_groups = attn_heads + kv_head_dim = q_head_dim + assert nqkv == 3 + elif qkv_layout.is_kvpacked(): + *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape + assert nkv == 2 + elif qkv_layout.is_separate(): + *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape + assert k_aval.shape == v_aval.shape, f"{k_aval.shape=} {v_aval.shape=}" + else: + raise ValueError(f"Unexpected {qkv_layout=}") assert q_batch_shape == kv_batch_shape assert q_head_dim == kv_head_dim assert q_aval.dtype == k_aval.dtype == v_aval.dtype @@ -212,9 +217,8 @@ def generate_cu_seqlen(actual_seqlen): """ Generating cumsum seqlen for a batch """ - cu_seqlen = jnp.cumsum(actual_seqlen, axis=-1) - cu_seqlen = jnp.where(actual_seqlen < 0, -1, cu_seqlen) - cu_seqlen = jnp.insert(cu_seqlen, 0, values=0, axis=-1) + actual_seqlen = jnp.where(actual_seqlen < 0, 0, actual_seqlen) + cu_seqlen = jnp.cumulative_sum(actual_seqlen, include_initial=True) return cu_seqlen @@ -225,7 +229,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): name = "te_fused_attn_forward" multiple_results = True - impl_static_args = (9,) + impl_static_args = (13,) inner_primitive = None outer_primitive = None @@ -235,11 +239,15 @@ def abstract( k_aval, v_aval, bias_aval, + seed_aval, q_seqlen_or_cu_seqlen_aval, kv_seqlen_or_cu_seqlen_aval, _q_seq_offsets, _k_seq_offsets, - seed_aval, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, *, config: _FusedAttnConfig, ): @@ -250,8 +258,13 @@ def abstract( k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - assert q_dtype == k_dtype == v_dtype == bias_dtype - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + assert ( + q_dtype == k_dtype == v_dtype == bias_dtype + ), f"q_dtype={q_dtype}, k_dtype={k_dtype}, v_dtype={v_dtype}, bias_dtype={bias_dtype}" + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype, ( + f"q_seqlen_or_cu_seqlen_aval={q_seqlen_or_cu_seqlen_aval}," + f" kv_seqlen_or_cu_seqlen_aval={kv_seqlen_or_cu_seqlen_aval}" + ) batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = ( FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) @@ -303,7 +316,7 @@ def abstract( rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) - if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == AttnBiasType.NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -323,9 +336,9 @@ def abstract( head_dim, config.scaling_factor, config.dropout_probability, - config.attn_bias_type, - config.attn_mask_type, - config.qkv_layout, + config.attn_bias_type.value, + config.attn_mask_type.value, + config.qkv_layout.value, jax_dtype_to_te_dtype(q_aval.dtype), config.is_training, config.max_segments_per_seq, @@ -355,11 +368,15 @@ def lowering( k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, *, config: _FusedAttnConfig, ): @@ -374,7 +391,7 @@ def lowering( input_batch = reduce(operator.mul, batch_shape) - if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == AttnBiasType.NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -388,11 +405,15 @@ def lowering( k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering input_batch=input_batch, bias_batch=bias_batch, q_max_seqlen=q_max_seqlen, @@ -404,9 +425,9 @@ def lowering( max_segments_per_seq=config.max_segments_per_seq, scaling_factor=float(config.scaling_factor), dropout_probability=float(config.dropout_probability), - bias_type=int(config.attn_bias_type), - mask_type=int(config.attn_mask_type), - qkv_layout=int(config.qkv_layout), + bias_type=int(config.attn_bias_type.value), + mask_type=int(config.attn_mask_type.value), + qkv_layout=int(config.qkv_layout.value), is_training=config.is_training, deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=config.window_size[0], @@ -418,11 +439,11 @@ def lowering( k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, ] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ @@ -467,16 +488,36 @@ def impl( k, v, bias, + seed, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config: _FusedAttnConfig, ): assert FusedAttnFwdPrimitive.inner_primitive is not None - if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: + sequence_descriptor = SequenceDescriptor( + seqlens=(q_seqlen, kv_seqlen), + seq_offsets=(q_seq_offsets, k_seq_offsets), + segment_ids=(_q_segment_ids, _kv_segment_ids), + segment_pos=(_q_segment_pos, _kv_segment_pos), + ) + + (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( + sequence_descriptor.get_seqlens_and_offsets( + config.attn_mask_type, + config.qkv_layout, + config.window_size, + config.max_segments_per_seq, + ) + ) + + if config.qkv_layout.is_thd(): def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -494,20 +535,11 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return offsets_2d - match config.qkv_layout: - case NVTE_QKV_Layout.NVTE_T3HD: - kv_max_seqlen = q_max_seqlen = q.shape[-4] - kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) - case NVTE_QKV_Layout.NVTE_THD_T2HD: - q_max_seqlen = q.shape[-3] - q_batch = reduce(operator.mul, q.shape[:-3]) - kv_max_seqlen = k.shape[-4] - kv_batch = reduce(operator.mul, k.shape[:-4]) - case NVTE_QKV_Layout.NVTE_THD_THD_THD: - q_max_seqlen = q.shape[-3] - q_batch = reduce(operator.mul, q.shape[:-3]) - kv_max_seqlen = k.shape[-3] - kv_batch = reduce(operator.mul, k.shape[:-3]) + batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( + q, k, v, config.qkv_layout + ) + assert len(batch) == 1, f"Expected len(batch) == 1, but got {len(batch)=}" + kv_batch = q_batch = batch[0] # Gather valid q_seqlen, which is greater than 0 # cuDNN version < 9.3.0: @@ -518,6 +550,7 @@ def convert_to_2d(offsets, batch, max_seqlen): fill_value = 0 else: fill_value = -1 + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) @@ -525,15 +558,17 @@ def convert_to_2d(offsets, batch, max_seqlen): # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) + # Gather valid q_seq_offsets, which is greater and equal to 0 # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] - q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0) - k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0) - - # Set the unused position to max size (batch * max_seqlen) + # And set the unused position to max size (batch * max_seqlen) # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] - q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets) - k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets) + q_seq_offsets = _fix_len_take( + q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen + ) + k_seq_offsets = _fix_len_take( + k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen + ) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) @@ -543,11 +578,15 @@ def convert_to_2d(offsets, batch, max_seqlen): k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=config, ) return output, softmax_aux, rng_state @@ -556,7 +595,7 @@ def convert_to_2d(offsets, batch, max_seqlen): def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnFwdPrimitive.outer_primitive is not None - q_bdim, *_, seed_bdim = batch_dims + q_bdim, _, _, _, seed_bdim, *_ = batch_dims out_bdims = q_bdim, q_bdim, seed_bdim return ( @@ -568,29 +607,28 @@ def batcher(batched_args, batch_dims, *, config): def infer_sharding_from_operands(config, mesh, arg_infos, result_infos): del result_infos q_spec = get_padded_spec(arg_infos[0]) - match config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: - # q_spec = (...batch, q_seqlen, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) - ) - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: - # q_spec = (...batch, q_seqlen, head, hidden) - # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) - ) - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: - # q_spec = (...batch, q_seqlen, head, hidden) - # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) - ) - case _: - raise ValueError(f"Unsupported {config.qkv_layout=}") + if config.qkv_layout.is_qkvpacked(): + # q_spec = (...batch, q_seqlen, 3, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None) + ) + elif config.qkv_layout.is_kvpacked(): + # q_spec = (...batch, q_seqlen, head, hidden) + # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) + ) + elif config.qkv_layout.is_separate(): + # q_spec = (...batch, q_seqlen, head, hidden) + # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], None) + ) + else: + raise ValueError(f"Unsupported {config.qkv_layout=}") rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) return (out_sharding, softmax_aux_sharding, rng_state_sharding) @@ -601,7 +639,11 @@ def partition(config, mesh, arg_infos, result_infos): rng_state_sharding = seed_sharding = NamedSharding( mesh, PartitionSpec(get_all_mesh_axes(), None) ) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) impl = partial(FusedAttnFwdPrimitive.impl, config=config) return mesh, impl, out_shardings, arg_shardings @@ -617,7 +659,7 @@ class FusedAttnBwdPrimitive(BasePrimitive): name = "te_fused_attn_backward" multiple_results = True - impl_static_args = (12,) + impl_static_args = (16,) inner_primitive = None outer_primitive = None @@ -635,6 +677,10 @@ def abstract( kv_seqlen_or_cu_seqlen_aval, _q_seq_offsets, _k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, *, config, ): @@ -655,7 +701,7 @@ def abstract( FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, config.qkv_layout) ) - if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == AttnBiasType.NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -675,9 +721,9 @@ def abstract( head_dim, config.scaling_factor, config.dropout_probability, - config.attn_bias_type, - config.attn_mask_type, - config.qkv_layout, + config.attn_bias_type.value, + config.attn_mask_type.value, + config.qkv_layout.value, jax_dtype_to_te_dtype(q_aval.dtype), config.is_training, deterministic, @@ -719,6 +765,10 @@ def lowering( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, *, config, ): @@ -733,7 +783,7 @@ def lowering( input_batch = reduce(operator.mul, batch_shape) - if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type == AttnBiasType.NO_BIAS: bias_batch = bias_heads = 0 else: *bias_batch_shape, bias_heads, _, _ = bias_aval.shape @@ -755,6 +805,10 @@ def lowering( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering input_batch=input_batch, bias_batch=bias_batch, q_max_seqlen=q_max_seqlen, @@ -766,9 +820,9 @@ def lowering( max_segments_per_seq=config.max_segments_per_seq, scaling_factor=float(config.scaling_factor), dropout_probability=float(config.dropout_probability), - bias_type=int(config.attn_bias_type), - mask_type=int(config.attn_mask_type), - qkv_layout=int(config.qkv_layout), + bias_type=int(config.attn_bias_type.value), + mask_type=int(config.attn_mask_type.value), + qkv_layout=int(config.qkv_layout.value), is_training=config.is_training, deterministic=not FusedAttnHelper.is_non_deterministic_allowed(), window_size_left=config.window_size[0], @@ -840,11 +894,31 @@ def impl( kv_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config, ): assert FusedAttnBwdPrimitive.inner_primitive is not None - if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: + sequence_descriptor = SequenceDescriptor( + seqlens=(q_seqlen, kv_seqlen), + seq_offsets=(q_seq_offsets, k_seq_offsets), + segment_ids=(_q_segment_ids, _kv_segment_ids), + segment_pos=(_q_segment_pos, _kv_segment_pos), + ) + + (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( + sequence_descriptor.get_seqlens_and_offsets( + config.attn_mask_type, + config.qkv_layout, + config.window_size, + config.max_segments_per_seq, + ) + ) + + if config.qkv_layout.is_thd(): def _fix_len_take(x, condition, fill_value=-1): x_shape = x.shape @@ -863,20 +937,11 @@ def convert_to_2d(offsets, batch, max_seqlen): ) return offsets_2d - match config.qkv_layout: - case NVTE_QKV_Layout.NVTE_T3HD: - kv_max_seqlen = q_max_seqlen = q.shape[-4] - kv_batch = q_batch = reduce(operator.mul, q.shape[:-4]) - case NVTE_QKV_Layout.NVTE_THD_T2HD: - q_max_seqlen = q.shape[-3] - q_batch = reduce(operator.mul, q.shape[:-3]) - kv_max_seqlen = k.shape[-4] - kv_batch = reduce(operator.mul, k.shape[:-4]) - case NVTE_QKV_Layout.NVTE_THD_THD_THD: - q_max_seqlen = q.shape[-3] - q_batch = reduce(operator.mul, q.shape[:-3]) - kv_max_seqlen = k.shape[-3] - kv_batch = reduce(operator.mul, k.shape[:-3]) + batch, q_max_seqlen, kv_max_seqlen, *_ = FusedAttnHelper.parse_qkv_aval( + q, k, v, config.qkv_layout + ) + assert len(batch) == 1 + kv_batch = q_batch = batch[0] # Gather valid q_seqlen, which is greater than 0 # cuDNN version < 9.3.0: @@ -894,15 +959,17 @@ def convert_to_2d(offsets, batch, max_seqlen): # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) + # Gather valid q_seq_offsets, which is greater and equal to 0 # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] - q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0) - k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0) - - # Set the unused position to max size (batch * max_seqlen) + # And set the unused position to max size (batch * max_seqlen) # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] - q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets) - k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets) + q_seq_offsets = _fix_len_take( + q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen + ) + k_seq_offsets = _fix_len_take( + k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen + ) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) @@ -920,6 +987,10 @@ def convert_to_2d(offsets, batch, max_seqlen): kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=config, ) return dq, dk, dv, dbias @@ -960,7 +1031,10 @@ def partition(config, mesh, arg_infos, result_infos): dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) def sharded_impl( @@ -976,6 +1050,10 @@ def sharded_impl( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( q, @@ -990,10 +1068,14 @@ def sharded_impl( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=config, ) global_dbias = local_dbias - if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type is not AttnBiasType.NO_BIAS: global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh) return local_dq, local_dk, local_dv, global_dbias @@ -1003,7 +1085,7 @@ def sharded_impl( register_primitive(FusedAttnBwdPrimitive) -def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contiguous: bool): +def reorder_causal_dual_chunk_swap(tensor, cp_size: int, seq_dim: int, to_contiguous: bool): """Reorders a tensor for load balancing the compute of causal attention.""" if cp_size == 1: return tensor @@ -1013,7 +1095,7 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu # Need to ensure we have 2 pairs to swap for balancing between cp ranks if tensor.shape[seq_dim] % (cp_size * 2) != 0: - raise ValueError(f"{tensor.shape=} is not a multiple of {cp_size*2=}") + raise ValueError(f"{tensor.shape[seq_dim]=} is not a multiple of {cp_size*2=}") # [B, S, H, D] -> [B, 2*cp_size, S/2*cp_size, D] # [S, B, H, D] -> [2*cp_size, S/2*cp_size, B, H, D] @@ -1055,6 +1137,33 @@ def reorder_causal_load_balancing(tensor, cp_size: int, seq_dim: int, to_contigu return combined.reshape(ori_tensor_shape) +def reorder_causal_striped(tensor, cp_size: int, seq_dim: int, is_inverse: bool): + """Reorders a tensor for load balancing with striped pattern""" + origin_shape = tensor.shape + if origin_shape[seq_dim] % cp_size != 0: + raise ValueError( + "Expected origin_shape[seq_dim] is multiple of cp_size but got" + f" {origin_shape[seq_dim]=} and {cp_size=}" + ) + + if not is_inverse: + new_shape = [ + *origin_shape[:seq_dim], + *[origin_shape[seq_dim] // cp_size, cp_size], + *origin_shape[seq_dim + 1 :], + ] + else: + new_shape = [ + *origin_shape[:seq_dim], + *[cp_size, origin_shape[seq_dim] // cp_size], + *origin_shape[seq_dim + 1 :], + ] + + chunked_tensor = tensor.reshape(new_shape) + reordered_chunked_tensor = jnp.swapaxes(chunked_tensor, seq_dim, seq_dim + 1) + return reordered_chunked_tensor.reshape(origin_shape) + + @dataclass(frozen=True) class _FusedAttnCPWithAllGatherHelper: """Helper class to assist with running the all-gather strategy for CP attention.""" @@ -1066,17 +1175,17 @@ def check_supported(self): """Checks if the context parallel implementation is supported by the given arguments.""" header = "Context parallel fused attention" - allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] + allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD] if self.config.qkv_layout not in allowed_layouts: raise ValueError( f"{header} only supports layouts:" f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" ) - if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS: + if self.config.attn_bias_type != AttnBiasType.NO_BIAS: raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") - allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] + allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK] if self.config.attn_mask_type not in allowed_masks: raise ValueError( f"{header} only supports masking types: " @@ -1094,8 +1203,8 @@ def check_supported(self): def get_adjusted_mask(self): """Converts the mask for context parallelism.""" - if self.config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: - return NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK + if self.config.attn_mask_type == AttnMaskType.CAUSAL_MASK: + return AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK return self.config.attn_mask_type def get_step_config(self) -> _FusedAttnConfig: @@ -1122,14 +1231,13 @@ def ag(x): ) if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) - x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=True) + x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=True) return x - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - return ag(k), v - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - return ag(k), ag(v) + if self.config.qkv_layout.is_kvpacked(): + return ag(k), v + if self.config.qkv_layout.is_separate(): + return ag(k), ag(v) return k, v # fall through @@ -1139,7 +1247,7 @@ def reduce_scatter_dkv(self, dk, dv): def rs(x): if self.config.context_parallel_load_balanced: cp_size = get_mesh_axis_size(self.config.cp_axis, self.mesh) - x = reorder_causal_load_balancing(x, cp_size, 1, to_contiguous=False) + x = reorder_causal_dual_chunk_swap(x, cp_size, 1, to_contiguous=False) return lax_paral_op( x, @@ -1150,11 +1258,10 @@ def rs(x): tiled=True, ) - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - return rs(dk), dv - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - return rs(dk), rs(dv) + if self.config.qkv_layout.is_kvpacked(): + return rs(dk), dv + if self.config.qkv_layout.is_separate(): + return rs(dk), rs(dv) return dk, dv # fall through @@ -1191,11 +1298,10 @@ def slice_kv(self, k, v, slice_seq_len): def sliced(x): return lax.dynamic_slice_in_dim(x, 0, slice_seq_len, axis=1) - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - return sliced(k), v - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - return sliced(k), sliced(v) + if self.config.qkv_layout.is_kvpacked(): + return sliced(k), v + if self.config.qkv_layout.is_separate(): + return sliced(k), sliced(v) return k, v # fall through @@ -1205,13 +1311,12 @@ def pad_kv(self, dk, dv, pad_seq_len): def pad(x, npad): return jnp.pad(x, npad, "constant", constant_values=0.0) - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]] - return pad(dk, npad), dv - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]] - return pad(dk, npad), pad(dv, npad) + if self.config.qkv_layout.is_kvpacked(): + npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0], [0, 0]] + return pad(dk, npad), dv + if self.config.qkv_layout.is_separate(): + npad = [[0, 0], [0, pad_seq_len], [0, 0], [0, 0]] + return pad(dk, npad), pad(dv, npad) return dk, dv # fall through @@ -1241,10 +1346,26 @@ def partition(config, mesh, arg_infos, result_infos): rng_state_sharding = seed_sharding = NamedSharding( mesh, PartitionSpec(get_all_mesh_axes(), None) ) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed): + def impl( + q, + k, + v, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) @@ -1267,7 +1388,7 @@ def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): results = [] for sub_idx in range(2): - if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: + if config.attn_mask_type == AttnMaskType.NO_MASK: k_unmasked, v_unmasked = k, v # full kv used for unmasked else: k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) @@ -1281,11 +1402,15 @@ def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): k_unmasked, v_unmasked, bias, + seed, q_seqlen_for_step, kv_seqlen_for_step, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(), ) results.append((output, softmax_aux, rng_state)) @@ -1358,13 +1483,31 @@ def impl( kv_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) # See comment in FusedAttnCPFwdPrimitive.partition for why we define this function. def _cross_attn_bwd( - idx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen + idx, + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): kv_max_seqlen = k.shape[1] kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) @@ -1381,7 +1524,7 @@ def _cross_attn_bwd( results = [] for sub_idx in range(2): - if config.attn_mask_type == NVTE_Mask_Type.NVTE_NO_MASK: + if config.attn_mask_type == AttnMaskType.NO_MASK: k_unmasked, v_unmasked = k, v # full kv used for unmasked else: k_unmasked, v_unmasked = helper.slice_kv(k, v, kv_seqlens_for_rank[sub_idx]) @@ -1403,11 +1546,15 @@ def _cross_attn_bwd( kv_seqlen_for_step, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(), ) # pad dk/dv to be unsliced shape so we can reduce scatter over all ranks. - if config.attn_mask_type != NVTE_Mask_Type.NVTE_NO_MASK: + if config.attn_mask_type != AttnMaskType.NO_MASK: pad_length = kv_max_seqlen - kv_seqlens_for_rank[sub_idx] dk_local, dv_local = helper.pad_kv(dk_local, dv_local, pad_length) @@ -1434,6 +1581,10 @@ def _cross_attn_bwd( doutput, q_seqlen, kv_seqlen, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ) for idx in range(cp_size) ] @@ -1460,39 +1611,37 @@ class _FusedAttnCPWithP2PHelper: def use_scanloop(): """Returns true if the implementation will use a scan loop for iteration.""" use_scan = bool(int(os.getenv("NVTE_FUSED_RING_ATTENTION_USE_SCAN", "1"))) - - # nvbug(4675071): Disable the HLO verifier for channel ID checks. - # A WAR was added to XLA: https://github.com/openxla/xla/pull/16779 - def truthy(val): - return val.lower() in ["1", "true"] - - x = use_scan and get_xla_flag( - "--xla_experimental_ignore_channel_id", default=False, cast=truthy - ) - return x + return use_scan def check_supported(self): """Checks if the context parallel implementation is supported by the given arguments.""" header = "Context parallel fused ring attention" - allowed_layouts = [NVTE_QKV_Layout.NVTE_BSHD_BS2HD, NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD] + if self.config.qkv_layout.is_thd(): + allowed_layouts = [QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] + else: + allowed_layouts = [QKVLayout.BSHD_BS2HD, QKVLayout.BSHD_BSHD_BSHD] + if self.config.qkv_layout not in allowed_layouts: raise ValueError( f"{header} only supports layouts:" f" {','.join(map(str, allowed_layouts))} got: {self.config.qkv_layout}" ) - if self.config.attn_bias_type != NVTE_Bias_Type.NVTE_NO_BIAS: + if self.config.attn_bias_type != AttnBiasType.NO_BIAS: raise ValueError(f"{header} does not support bias got: {self.config.attn_bias_type}") - allowed_masks = [NVTE_Mask_Type.NVTE_NO_MASK, NVTE_Mask_Type.NVTE_CAUSAL_MASK] + if self.config.qkv_layout.is_thd(): + allowed_masks = [AttnMaskType.PADDING_CAUSAL_MASK] + else: + allowed_masks = [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK] if self.config.attn_mask_type not in allowed_masks: raise ValueError( f"{header} only supports masking types: " f" {','.join(map(str, allowed_masks))} got: {self.config.attn_mask_type}" ) - if self.config.max_segments_per_seq != 1: + if not self.config.qkv_layout.is_thd() and self.config.max_segments_per_seq != 1: raise ValueError( f"{header} only supports max_segments_per_seq == 1 got:" f" {self.config.max_segments_per_seq}" @@ -1507,8 +1656,7 @@ def check_supported(self): if not self.use_scanloop(): warnings.warn( "Scan loop is disabled for fused ring attention. To enable set" - " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment and" - " add --xla_experimental_ignore_channel_id=true to XLA_FLAGS." + " NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 in your environment" ) def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: @@ -1516,7 +1664,7 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: return _FusedAttnConfig( attn_bias_type=self.config.attn_bias_type, attn_mask_type=attn_mask_type, - qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, + qkv_layout=QKVLayout.BSHD_BS2HD, scaling_factor=self.config.scaling_factor, dropout_probability=self.config.dropout_probability, is_training=self.config.is_training, @@ -1529,33 +1677,38 @@ def get_step_config(self, attn_mask_type) -> _FusedAttnConfig: def stack_kv(self, k, v): """Stacks k and v tensors if not stacked.""" _not_used = jnp.zeros(0, dtype=k.dtype) - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - return k - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - return jnp.stack([k, v], axis=2) + if self.config.qkv_layout.is_kvpacked(): + return k + if self.config.qkv_layout.is_separate(): + return jnp.stack([k, v], axis=2) return _not_used def unstack_kv(self, kv): """Un-stacks k and v tensors if not stacked.""" _not_used = jnp.zeros(0, dtype=kv.dtype) - match self.config.qkv_layout: - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: - return kv, _not_used - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: - return jnp.unstack(kv, axis=2) + if self.config.qkv_layout.is_kvpacked(): + return kv, _not_used + if self.config.qkv_layout.is_separate(): + return jnp.unstack(kv, axis=2) return _not_used, _not_used # fall through def permute_kv(self, kv, cp_perm): """Permutes kv around the ring as described by cp_perm.""" return lax_paral_op(kv, lax.ppermute, self.config.cp_axis, mesh=self.mesh, perm=cp_perm) - def correct_softmax_aux(self, softmax_aux, softmax_aux_per_step): - """Apply soft max correction after an attention step.""" - max_scale = jnp.maximum(softmax_aux, softmax_aux_per_step) - min_scale = jnp.minimum(softmax_aux, softmax_aux_per_step) - new_softmax_aux = max_scale + jnp.log(1 + jnp.exp(min_scale - max_scale)) - return new_softmax_aux + @staticmethod + def correct_output_and_softmax_aux(output, softmax_aux, partial_output, partial_softmax_aux): + """ + Corrects the output and softmax_aux tensor after each iteration of ring attention. + + See https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 for + derivation of this equation. + """ + new_out = output - jax.nn.sigmoid(partial_softmax_aux - softmax_aux).transpose( + 0, 2, 1, 3 + ) * (output - partial_output) + new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - partial_softmax_aux) + return new_out, new_aux def adjust_seqlen(self, seqlen, max_seqlen, idx): """Adjust the sequence length per step.""" @@ -1589,7 +1742,9 @@ def partition(config, mesh, arg_infos, result_infos): rng_state_sharding = seed_sharding = NamedSharding( mesh, PartitionSpec(get_all_mesh_axes(), None) ) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) def ring_attn_fwd_impl( @@ -1597,11 +1752,15 @@ def ring_attn_fwd_impl( k, v, bias, + seed, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): _not_used = jnp.zeros(0, dtype=v.dtype) @@ -1616,10 +1775,7 @@ def ring_attn_fwd_impl( cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] - output_per_steps = jnp.zeros((cp_size, *q.shape), dtype=q.dtype) - softmax_aux_per_steps = jnp.zeros( - (cp_size, batch, head, q_max_seqlen, 1), dtype=jnp.float32 - ) + output = jnp.zeros(q.shape).astype(jnp.float32) softmax_aux = jnp.full((batch, head, q_max_seqlen, 1), -jnp.inf, dtype=jnp.float32) # RNG shape should be the shared shape. This is unused for ring attention as we do not @@ -1628,7 +1784,7 @@ def ring_attn_fwd_impl( rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) def scan_kv_block(idx, carry): - kv, softmax_aux, output_per_steps, softmax_aux_per_steps = carry + kv, output, softmax_aux = carry # Send KV block to next step so we can overlap compute. kv_next = helper.permute_kv(kv, cp_perm) @@ -1641,17 +1797,21 @@ def mask_compute(attn_mask_type): kv, _not_used, bias, + seed, q_seqlen_per_step, kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - seed, - helper.get_step_config(attn_mask_type), + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=helper.get_step_config(attn_mask_type), ) return output_per_step, softmax_aux_per_step - causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK) - no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK) + causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK) + no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK) def half_kv_no_mask_compute(): q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) @@ -1662,12 +1822,16 @@ def half_kv_no_mask_compute(): kv_part, _not_used, bias, + seed, q_seqlen_per_step, kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - seed, - config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=helper.get_step_config(AttnMaskType.NO_MASK), ) return output_per_step, softmax_aux_per_step @@ -1680,12 +1844,16 @@ def half_q_no_mask_compute(): kv, _not_used, bias, + seed, q_seqlen_per_step, kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - seed, - config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=helper.get_step_config(AttnMaskType.NO_MASK), ) output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1) softmax_aux_per_step = jnp.concat( @@ -1704,7 +1872,7 @@ def skip_compute(): ) return output_per_step, softmax_aux_per_step - if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: + if config.attn_mask_type == AttnMaskType.CAUSAL_MASK: # This is for nested jax.lax.cond def jax_cond_wrap(): if config.context_parallel_load_balanced: @@ -1719,25 +1887,38 @@ def jax_cond_wrap(): else: output_per_step, softmax_aux_per_step = no_mask_compute() - softmax_aux = helper.correct_softmax_aux(softmax_aux, softmax_aux_per_step) - output_per_steps = output_per_steps.at[idx].set(output_per_step) - softmax_aux_per_steps = softmax_aux_per_steps.at[idx].set(softmax_aux_per_step) + def skip_correction(output, softmax_aux, output_per_step, softmax_aux_per_step): + # No correction done here but we cast outputs to float32 and perform reduction + # in full precision. + # pylint: disable=unused-argument + return output_per_step.astype(jnp.float32), softmax_aux_per_step - return (kv_next, softmax_aux, output_per_steps, softmax_aux_per_steps) + def correction(output, softmax_aux, output_per_step, softmax_aux_per_step): + return helper.correct_output_and_softmax_aux( + output, softmax_aux, output_per_step, softmax_aux_per_step + ) + + # first step there is no correction we get initial output and stats + output, softmax_aux = lax.cond( + (idx == 0), + skip_correction, + correction, + output, + softmax_aux, + output_per_step, + softmax_aux_per_step, + ) - carry = (kv, softmax_aux, output_per_steps, softmax_aux_per_steps) + return (kv_next, output, softmax_aux) + + carry = (kv, output, softmax_aux) if helper.use_scanloop(): carry = lax.fori_loop(0, cp_size, scan_kv_block, carry) else: for i in range(0, cp_size): carry = scan_kv_block(i, carry) - (kv, softmax_aux, output_per_steps, softmax_aux_per_steps) = carry + (kv, output, softmax_aux) = carry - output = jnp.zeros(q.shape).astype(jnp.float32) - for idx in range(cp_size): - output = output + output_per_steps[idx].astype(jnp.float32) * jnp.exp( - softmax_aux_per_steps[idx] - softmax_aux - ).transpose(0, 2, 1, 3) output = output.astype(q.dtype) return output, softmax_aux, rng_state @@ -1789,6 +1970,10 @@ def ring_attn_bwd_impl( kv_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): _not_used = jnp.zeros(0, dtype=output.dtype) @@ -1833,12 +2018,16 @@ def mask_compute(attn_mask_type): kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(attn_mask_type), ) return dq_per_step, dk_dv_per_step, dbias_per_step - causal_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_CAUSAL_MASK) - no_mask_compute = partial(mask_compute, NVTE_Mask_Type.NVTE_NO_MASK) + causal_mask_compute = partial(mask_compute, AttnMaskType.CAUSAL_MASK) + no_mask_compute = partial(mask_compute, AttnMaskType.NO_MASK) def half_kv_no_mask_compute(): q_seqlen_per_step = helper.adjust_seqlen(q_seqlen, q_max_seqlen, idx) @@ -1857,7 +2046,11 @@ def half_kv_no_mask_compute(): kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=helper.get_step_config(AttnMaskType.NO_MASK), ) dk_dv_per_step = jnp.concat( [dk_dv_per_step, jnp.zeros_like(dk_dv_per_step)], axis=1 @@ -1891,7 +2084,11 @@ def half_q_no_mask_compute(): kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=helper.get_step_config(AttnMaskType.NO_MASK), ) dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1) return dq_per_step, dk_dv_per_step, dbias_per_step @@ -1899,7 +2096,7 @@ def half_q_no_mask_compute(): def skip_compute(): return jnp.zeros_like(q), jnp.zeros_like(kv), jnp.zeros_like(bias) - if config.attn_mask_type == NVTE_Mask_Type.NVTE_CAUSAL_MASK: + if config.attn_mask_type == AttnMaskType.CAUSAL_MASK: # This is for nested jax.lax.cond def jax_cond_wrap(): if config.context_parallel_load_balanced: @@ -1917,7 +2114,7 @@ def jax_cond_wrap(): kv_next, dk_dv = jnp.unstack(kv_dk_dv) dq = dq + dq_per_step dk_dv = dk_dv + dk_dv_per_step - if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type is not AttnBiasType.NO_BIAS: dbias = dbias + dbias_per_step return (kv_next, dq, dk_dv, dbias) @@ -1934,7 +2131,7 @@ def jax_cond_wrap(): dk_dv = helper.permute_kv(dk_dv, cp_perm) global_dbias = dbias - if config.attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: + if config.attn_bias_type is not AttnBiasType.NO_BIAS: global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) dk, dv = helper.unstack_kv(dk_dv) @@ -1946,6 +2143,271 @@ def jax_cond_wrap(): register_primitive(FusedRingAttnBwdPrimitive) +class FusedRingAttnStripedFwdPrimitive(FusedAttnFwdPrimitive): + """ + Fused Striped Ring Attention Forward Primitive + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" + if not is_context_parallel: + return FusedAttnFwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + helper = _FusedAttnCPWithP2PHelper(mesh, config) + helper.check_supported() + + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding( + mesh, PartitionSpec(get_all_mesh_axes(), None) + ) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) + out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) + + def fwd_impl( + q, + k, + v, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + ): + if q_segment_ids.size == 0 or kv_segment_ids.size == 0: + raise ValueError("THD + ring attn only supports passing seqment_ids/pos") + + _not_used = jnp.zeros(0, dtype=v.dtype) + + # Combine KV tensors if separate for better permute scheduling and performance. + # Eventually XLA should perform this automatically. + kv = helper.stack_kv(k, v) + if not config.qkv_layout.is_qkvpacked(): + subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked()) + else: + subblock_config = config + + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] + + batch, q_max_seqlen, head, _ = q.shape + output = jnp.zeros(q.shape).astype(jnp.float32) + softmax_aux = jnp.zeros((batch, q_max_seqlen, head, 1), dtype=jnp.float32) + + # RNG shape should be the shared shape. This is unused for ring attention as we do not + # support dropout currently. + rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:]) + rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) + + def scan_kv_block(idx, carry): + kv, kv_segment_ids, kv_segment_pos, output, softmax_aux = carry + + # TODO(rewang): To check whether we need special handle for the last idx + # Send KV block to next step so we can overlap compute. + kv_next = helper.permute_kv(kv, cp_perm) + kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm) + kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm) + + output_per_step, softmax_aux_per_step, _ = FusedAttnFwdPrimitive.impl( + q, + kv, + _not_used, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + subblock_config, + ) + + # TODO(rewang): THD softmax_aux layout is acutally [B, S, H] + softmax_aux_per_step = softmax_aux_per_step.reshape((batch, q_max_seqlen, head, 1)) + + def skip_correction(_output, _softmax_aux, output_per_step, softmax_aux_per_step): + # No correction done here but we cast outputs to float32 and perform reduction + # in full precision. + return output_per_step.astype(jnp.float32), softmax_aux_per_step + + def correction(output, softmax_aux, output_per_step, softmax_aux_per_step): + new_out = output - jax.nn.sigmoid(softmax_aux_per_step - softmax_aux) * ( + output - output_per_step + ) + new_aux = softmax_aux - jax.nn.log_sigmoid(softmax_aux - softmax_aux_per_step) + return new_out, new_aux + + # first step there is no correction we get initial output and stats + output, softmax_aux = lax.cond( + idx == 0, + skip_correction, + correction, + output, + softmax_aux, + output_per_step, + softmax_aux_per_step, + ) + + return (kv_next, kv_segment_ids_next, kv_segment_pos_next, output, softmax_aux) + + carry = (kv, kv_segment_ids, kv_segment_pos, output, softmax_aux) + if helper.use_scanloop(): + carry = lax.fori_loop(0, cp_size, scan_kv_block, carry) + else: + for i in range(0, cp_size): + carry = scan_kv_block(i, carry) + (_, _, _, output, softmax_aux) = carry + + softmax_aux = softmax_aux.reshape((batch, head, q_max_seqlen, 1)) + + return output.astype(q.dtype), softmax_aux, rng_state + + return mesh, fwd_impl, out_shardings, arg_shardings + + +register_primitive(FusedRingAttnStripedFwdPrimitive) + + +class FusedRingAttnStripedBwdPrimitive(FusedAttnBwdPrimitive): + """ + Fused Striped Ring Attention Backward Primitive + """ + + @staticmethod + def partition(config, mesh, arg_infos, result_infos): + is_context_parallel = get_mesh_axis_size(config.cp_axis, mesh) > 1 + assert ( + not is_context_parallel or config.window_size[0] == -1 + ), "Sliding window attention is not supported when context parallelism is enabled" + if not is_context_parallel: + return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) + + arg_shardings = tuple(arg.sharding for arg in arg_infos) + # dq, dk, dv, dbias sharding = q, k, v, bias sharding + out_shardings = tuple(arg.sharding for arg in arg_infos[:4]) + + helper = _FusedAttnCPWithP2PHelper(mesh, config) + helper.check_supported() + + def bwd_impl( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + ): + + if q_segment_ids.size == 0 or kv_segment_ids.size == 0: + raise ValueError("THD + ring attn only supports passing seqment_ids/pos") + + _not_used = jnp.zeros(0, dtype=output.dtype) + + # Combine KV tensors if separate for better permute scheduling and performance. + # Eventually XLA should perform this automatically. + kv = helper.stack_kv(k, v) + if not config.qkv_layout.is_qkvpacked(): + subblock_config = replace(config, qkv_layout=config.qkv_layout.to_kvpacked()) + else: + subblock_config = config + + cp_size = get_mesh_axis_size(config.cp_axis, mesh) + cp_perm = [(i, (i + 1) % cp_size) for i in range(cp_size)] + + dq = jnp.zeros_like(q) + dkv = jnp.zeros_like(kv) + dbias = jnp.zeros_like(bias) + + def scan_kv_block(_idx, carry): + kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias = carry + + # Start communication that feeds the next iteration. + # We further combine the tensors to improve overlap. + kv_dkv = jnp.stack([kv, dkv]) + kv_dkv = helper.permute_kv(kv_dkv, cp_perm) + kv_segment_ids_next = helper.permute_kv(kv_segment_ids, cp_perm) + kv_segment_pos_next = helper.permute_kv(kv_segment_pos, cp_perm) + + def compute(): + dq_per_step, dkv_per_step, _, dbias_per_step = FusedAttnBwdPrimitive.impl( + q, + kv, + _not_used, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + config=subblock_config, + ) + return dq_per_step, dkv_per_step, dbias_per_step + + dq_per_step, dkv_per_step, dbias_per_step = compute() + + kv_next, dkv = jnp.unstack(kv_dkv) + dq += dq_per_step + dkv += dkv_per_step + if config.attn_bias_type is not AttnBiasType.NO_BIAS: + dbias = dbias + dbias_per_step + + return (kv_next, kv_segment_ids_next, kv_segment_pos_next, dq, dkv, dbias) + + carry = (kv, kv_segment_ids, kv_segment_pos, dq, dkv, dbias) + if helper.use_scanloop(): + carry = lax.fori_loop(0, cp_size, scan_kv_block, carry) + else: + for idx in range(cp_size): + carry = scan_kv_block(idx, carry) + (_, _, _, dq, dkv, dbias) = carry + + # Final permute to put gradients back to their final resting place. + dkv = helper.permute_kv(dkv, cp_perm) + + global_dbias = dbias + if config.attn_bias_type is not AttnBiasType.NO_BIAS: + global_dbias = all_reduce_sum_along_dp_fsdp(dbias, mesh) + + dk, dv = helper.unstack_kv(dkv) + return dq, dk, dv, global_dbias + + return mesh, bwd_impl, out_shardings, arg_shardings + + +register_primitive(FusedRingAttnStripedBwdPrimitive) + + def _maybe_context_parallel_axis(cp_axis: str): if not cp_axis: gmr = global_mesh_resource() @@ -1959,14 +2421,11 @@ def _maybe_context_parallel_axis(cp_axis: str): def fused_attn_fwd( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], - q_seqlen: jnp.ndarray, - kv_seqlen: jnp.ndarray, - q_seq_offsets: Optional[jnp.ndarray], - kv_seq_offsets: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, seed: Optional[jnp.ndarray], - attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, - qkv_layout: NVTE_QKV_Layout, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, scaling_factor: float, dropout_probability: float, is_training: bool, @@ -1997,9 +2456,9 @@ def fused_attn_fwd( kv_seq_offsets (jnp.ndarray): The offsets in the sequence dim for the query, with shape [batch + 1,]. seed (Optional[jnp.ndarray]): Optional random seed for dropout. - attn_bias_type (NVTE_Bias_Type): Type of attention bias. - attn_mask_type (NVTE_Mask_Type): Type of attention mask. - qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. + attn_bias_type (AttnBiasType): Type of attention bias. + attn_mask_type (AttnMaskType): Type of attention mask. + qkv_layout (QKVLayout): Layout of the QKV tensors. scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. @@ -2015,30 +2474,26 @@ def fused_attn_fwd( (jnp.ndarray): The output tensor from the fused attention. """ seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training) - - assert (q_seq_offsets is None) == ( - kv_seq_offsets is None - ), "Both q_seq_offsets and kv_seq_offsets must be either None or have values." - is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD - # For optional tensors, which custom calls doesn't support None _not_used = jnp.zeros(0, dtype=qkv[0].dtype) - match qkv_layout: - case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: - assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = [*qkv, _not_used, _not_used] - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: - assert ( - len(qkv) == 2 - ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = [*qkv, _not_used] - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: - assert ( - len(qkv) == 3 - ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = qkv - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + + if qkv_layout.is_qkvpacked(): + assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = [*qkv, _not_used, _not_used] + elif qkv_layout.is_kvpacked(): + assert ( + len(qkv) == 2 + ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = [*qkv, _not_used] + elif qkv_layout.is_separate(): + assert ( + len(qkv) == 3 + ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = qkv + else: + raise ValueError(f"Unknown {qkv_layout=}") + + if attn_bias_type == AttnBiasType.NO_BIAS: assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) @@ -2055,21 +2510,23 @@ def fused_attn_fwd( cp_axis=_maybe_context_parallel_axis(context_parallel_axis), ) - primative = None + primitive = None match context_parallel_strategy: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: - primative = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive + primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive case CPStrategy.RING: - primative = FusedRingAttnFwdPrimitive.outer_primitive + # We must use stripe attention for THD-RING + if qkv_layout.is_thd(): + primitive = FusedRingAttnStripedFwdPrimitive.outer_primitive + else: + primitive = FusedRingAttnFwdPrimitive.outer_primitive - return primative.bind( + seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) + return primitive.bind( *qkv_for_primitive, bias, - q_seqlen, - kv_seqlen, - q_seq_offsets if is_ragged else _not_used, - kv_seq_offsets if is_ragged else _not_used, seed, + *seq_desc_flatten, config=fused_config, ) @@ -2081,13 +2538,10 @@ def fused_attn_bwd( rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray, - q_seqlen: jnp.ndarray, - kv_seqlen: jnp.ndarray, - q_seq_offsets: Optional[jnp.ndarray], - kv_seq_offsets: Optional[jnp.ndarray], - attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, - qkv_layout: NVTE_QKV_Layout, + sequence_descriptor: SequenceDescriptor, + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, scaling_factor: float, dropout_probability: float, is_training: bool, @@ -2119,9 +2573,9 @@ def fused_attn_bwd( The offsets in the sequence dim for the query, with shape [batch + 1,]. kv_seq_offsets (jnp.ndarray): The offsets in the sequence dim for the query, with shape [batch + 1,]. - attn_bias_type (NVTE_Bias_Type): Type of attention bias. - attn_mask_type (NVTE_Mask_Type): Type of attention mask. - qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors. + attn_bias_type (AttnBiasType): Type of attention bias. + attn_mask_type (AttnMaskType): Type of attention mask. + qkv_layout (QKVLayout): Layout of the QKV tensors. scaling_factor (float): Scaling factor for the attention scores. dropout_probability (float): Dropout probability to apply during attention. is_training (bool): Flag indicating whether the model is in training mode. @@ -2139,31 +2593,26 @@ def fused_attn_bwd( same format as the input `qkv`. - The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`. """ - - assert (q_seq_offsets is None) == ( - kv_seq_offsets is None - ), "Both q_seq_offsets and kv_seq_offsets must be either None or have values." - is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD - # For optional tensors, which custom calls doesn't support None _not_used = jnp.zeros(0, dtype=qkv[0].dtype) - match qkv_layout: - case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: - assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = [*qkv, _not_used, _not_used] - case NVTE_QKV_Layout.NVTE_BSHD_BS2HD | NVTE_QKV_Layout.NVTE_THD_T2HD: - assert ( - len(qkv) == 2 - ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = [*qkv, _not_used] - case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD | NVTE_QKV_Layout.NVTE_THD_THD_THD: - assert ( - len(qkv) == 3 - ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" - qkv_for_primitive = qkv - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + if qkv_layout.is_qkvpacked(): + assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = [*qkv, _not_used, _not_used] + elif qkv_layout.is_kvpacked(): + assert ( + len(qkv) == 2 + ), f"qkv=(query, packed_kv) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = [*qkv, _not_used] + elif qkv_layout.is_separate(): + assert ( + len(qkv) == 3 + ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" + qkv_for_primitive = qkv + else: + raise ValueError(f"Unknown {qkv_layout=}") + + if attn_bias_type == AttnBiasType.NO_BIAS: assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) @@ -2180,24 +2629,25 @@ def fused_attn_bwd( cp_axis=_maybe_context_parallel_axis(context_parallel_axis), ) - primative = None + primitive = None match context_parallel_strategy: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: - primative = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive + primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive case CPStrategy.RING: - primative = FusedRingAttnBwdPrimitive.outer_primitive + if qkv_layout.is_thd(): + primitive = FusedRingAttnStripedBwdPrimitive.outer_primitive + else: + primitive = FusedRingAttnBwdPrimitive.outer_primitive - *qkv_grads, bias_grad = primative.bind( + seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) + *qkv_grads, bias_grad = primitive.bind( *qkv_for_primitive, bias, softmax_aux, rng_state, output, doutput, - q_seqlen, - kv_seqlen, - q_seq_offsets if is_ragged else _not_used, - kv_seq_offsets if is_ragged else _not_used, + *seq_desc_flatten, config=fused_config, ) return tuple(qkv_grads[: len(qkv)]), bias_grad diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 3d88c1f078..1f148c86ab 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE base custom ops""" @@ -7,7 +7,7 @@ from abc import ABCMeta, abstractmethod from functools import partial -from jax import core +from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py index 1075030a0d..66b5e1c923 100644 --- a/transformer_engine/jax/cpp_extensions/custom_call.py +++ b/transformer_engine/jax/cpp_extensions/custom_call.py @@ -1,17 +1,22 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom call""" from dataclasses import dataclass from enum import IntEnum +from packaging import version +import jax from jax.interpreters import mlir -import jax.extend as jex - -from transformer_engine import transformer_engine_jax +import transformer_engine_jax from .misc import is_ffi_enabled +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + try: from jaxlib.hlo_helpers import custom_call except ImportError: @@ -30,11 +35,11 @@ class CustomCallAPIVersion(IntEnum): for _name, _value in transformer_engine_jax.registrations().items(): if _name.endswith("_ffi"): if is_ffi_enabled(): - jex.ffi.register_ffi_target( + ffi.register_ffi_target( _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value ) else: - jex.ffi.register_ffi_target( + ffi.register_ffi_target( _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 1f13484b98..4f65a2c3c7 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE miscellaneous for custom ops""" @@ -15,8 +15,8 @@ from jax import dtypes from jax.interpreters.mlir import dtype_to_ir_type -from transformer_engine.transformer_engine_jax import DType as TEDType -from transformer_engine import transformer_engine_jax +from transformer_engine_jax import DType as TEDType +import transformer_engine_jax from ..sharding import get_padded_spec as te_get_padded_spec diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index fd6cc09de9..ed8f5dde7a 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1,22 +1,21 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom ops for normalization""" -from functools import partial, reduce, cache import operator import os import warnings +from functools import partial, reduce, cache +from packaging import version import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi -from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import DType as TEDType +import transformer_engine_jax from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper @@ -31,6 +30,11 @@ from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + __all__ = [ "layernorm_fwd", @@ -75,14 +79,14 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): mu_rsigama_dtype = jnp.float32 - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype) assert gamma_aval.size == beta_aval.size hidden_size = gamma_aval.size assert x_aval.size % hidden_size == 0 - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // hidden_size, # batch size hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype @@ -96,18 +100,15 @@ def abstract(x_aval, gamma_aval, beta_aval, **kwargs): wkspace_aval = out_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = out_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, mu_aval, rsigma_aval, wkspace_aval, barrier_aval + return out_aval, mu_aval, rsigma_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ LayerNorm fwd outer primitive abstract """ - out_aval, mu_aval, rsigma_aval, _, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs) + out_aval, mu_aval, rsigma_aval, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs) return out_aval, mu_aval, rsigma_aval @staticmethod @@ -151,7 +152,7 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, output_type), @@ -160,9 +161,6 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma, beta] operand_shapes = [x_shape, g_shape, b_shape] @@ -174,15 +172,9 @@ def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype zero_centered_gamma, epsilon, sm_margin, @@ -198,7 +190,7 @@ def impl(x, gamma, beta, zero_centered_gamma, epsilon): to describe implementation """ assert LayerNormFwdPrimitive.inner_primitive is not None - out, mu, rsigma, _, _ = LayerNormFwdPrimitive.inner_primitive.bind( + out, mu, rsigma, _ = LayerNormFwdPrimitive.inner_primitive.bind( x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) return out, mu, rsigma @@ -374,42 +366,28 @@ def abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval, **kwargs): assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1] assert mu_dtype == rsigma_dtype == jnp.float32 - dx_aval = core.raise_to_shaped(dz_aval) - dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval) - - wkspace_info, barrier_info, dgamma_part_info, dbeta_part_info = ( - transformer_engine_jax.get_layernorm_bwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - True, - kwargs["zero_centered_gamma"], - kwargs["epsilon"], - get_backward_sm_margin(), - ) + dx_aval = dz_aval + dgamma_aval = dbeta_aval = gamma_aval + + (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # input te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + True, + kwargs["zero_centered_gamma"], + kwargs["epsilon"], + get_backward_sm_margin(), ) wkspace_aval = dx_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = dx_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - dgamma_part_aval = dgamma_aval.update( - shape=dgamma_part_info[0], dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]) - ) - dbeta_part_aval = dbeta_aval.update( - shape=dbeta_part_info[0], dtype=te_dtype_to_jax_dtype(dbeta_part_info[1]) - ) return ( dx_aval, dgamma_aval, dbeta_aval, wkspace_aval, - barrier_aval, - dgamma_part_aval, - dbeta_part_aval, ) @staticmethod @@ -417,9 +395,7 @@ def outer_abstract(*args, **kwargs): """ LayerNorm bwd outer primitive abstract """ - dx_aval, dgamma_aval, dbeta_aval, _, _, _, _ = LayerNormBwdPrimitive.abstract( - *args, **kwargs - ) + dx_aval, dgamma_aval, dbeta_aval, _ = LayerNormBwdPrimitive.abstract(*args, **kwargs) return dx_aval, dgamma_aval, dbeta_aval @staticmethod @@ -470,20 +446,14 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): sm_margin = get_backward_sm_margin() - wkspace_aval, barrier_aval, dgamma_part_aval, dbeta_part_aval = ctx.avals_out[-4:] + wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_norm_descriptor( batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - dgamma_part_aval.shape, - dbeta_part_aval.shape, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - jax_dtype_to_te_dtype(dgamma_part_aval.dtype), - jax_dtype_to_te_dtype(dbeta_part_aval.dtype), zero_centered_gamma, epsilon, sm_margin, @@ -496,7 +466,7 @@ def lowering(ctx, dz, x, mu, rsigma, gamma, *, zero_centered_gamma, epsilon): @staticmethod def impl(dz, x, mu, rsigma, gamma, zero_centered_gamma, epsilon): assert LayerNormBwdPrimitive.inner_primitive is not None - dx, dgamma, dbeta, _, _, _, _ = LayerNormBwdPrimitive.inner_primitive.bind( + dx, dgamma, dbeta, _ = LayerNormBwdPrimitive.inner_primitive.bind( dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon ) return dx, dgamma, dbeta @@ -624,13 +594,13 @@ def abstract(x_aval, gamma_aval, **kwargs): rsigama_dtype = jnp.float32 - out_aval = core.raise_to_shaped(x_aval) + out_aval = x_aval rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype) hidden_size = gamma_aval.size assert x_aval.size % hidden_size == 0 - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // hidden_size, # batch size hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype @@ -644,18 +614,15 @@ def abstract(x_aval, gamma_aval, **kwargs): wkspace_aval = out_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = out_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, rsigma_aval, wkspace_aval, barrier_aval + return out_aval, rsigma_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ RMSNorm fwd outer primitive abstract """ - out_aval, rsigma_aval, _, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs) + out_aval, rsigma_aval, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs) return out_aval, rsigma_aval @staticmethod @@ -688,7 +655,7 @@ def lowering(ctx, x, gamma, *, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, x_type.element_type), @@ -696,9 +663,6 @@ def lowering(ctx, x, gamma, *, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma] operand_shapes = [x_shape, g_shape] @@ -710,15 +674,9 @@ def lowering(ctx, x, gamma, *, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -734,7 +692,7 @@ def impl(x, gamma, epsilon): to describe implementation """ assert RmsNormFwdPrimitive.inner_primitive is not None - out, rsigma, _, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon) + out, rsigma, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon) return out, rsigma @staticmethod @@ -830,39 +788,31 @@ def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs): assert rsigma_aval.shape == x_aval.shape[:-1] assert rsigma_dtype == jnp.float32 - dx_aval = core.raise_to_shaped(dz_aval) - dgamma_aval = core.raise_to_shaped(gamma_aval) - - wkspace_info, barrier_info, dgamma_part_info, _ = ( - transformer_engine_jax.get_layernorm_bwd_workspace_sizes( - x_aval.size // gamma_aval.size, # batch size - gamma_aval.size, # hidden size - jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype - jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype - False, - False, - kwargs["epsilon"], - get_backward_sm_margin(), - ) + dx_aval = dz_aval + dgamma_aval = gamma_aval + + (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes( + x_aval.size // gamma_aval.size, # batch size + gamma_aval.size, # hidden size + jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype + jax_dtype_to_te_dtype(gamma_aval.dtype), # weight te_dtype + False, + False, + kwargs["epsilon"], + get_backward_sm_margin(), ) wkspace_aval = dx_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = dx_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - dgamma_part_aval = dgamma_aval.update( - shape=dgamma_part_info[0], dtype=te_dtype_to_jax_dtype(dgamma_part_info[1]) - ) - return dx_aval, dgamma_aval, wkspace_aval, barrier_aval, dgamma_part_aval + return dx_aval, dgamma_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ RMSNorm bwd outer primitive abstract """ - dx_aval, dgamma_aval, _, _, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs) + dx_aval, dgamma_aval, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs) return dx_aval, dgamma_aval @staticmethod @@ -896,7 +846,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): hidden_size = reduce(operator.mul, g_shape) batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval, dgamma_part_aval = ctx.avals_out[-3:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(x_shape, x_type.element_type), @@ -904,12 +854,6 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), - ir.RankedTensorType.get( - dgamma_part_aval.shape, jax_dtype_to_ir_dtype(dgamma_part_aval.dtype) - ), ] operands = [dz, rsigma, x, gamma] operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape] @@ -921,15 +865,9 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - dgamma_part_aval.shape, - (0,), # no dbeta_part for RMSnorm jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - jax_dtype_to_te_dtype(dgamma_part_aval.dtype), - TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -942,7 +880,7 @@ def lowering(ctx, dz, x, rsigma, gamma, *, epsilon): @staticmethod def impl(dz, x, rsigma, gamma, epsilon): assert RmsNormBwdPrimitive.inner_primitive is not None - dx, dgamma, _, _, _ = RmsNormBwdPrimitive.inner_primitive.bind( + dx, dgamma, _ = RmsNormBwdPrimitive.inner_primitive.bind( dz, x, rsigma, gamma, epsilon=epsilon ) return dx, dgamma @@ -1066,7 +1004,7 @@ def abstract( assert gamma_aval.size == beta_aval.size - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // gamma_aval.size, # batch size gamma_aval.size, # hidden size jax_dtype_to_te_dtype(x_aval.dtype), # in type @@ -1084,18 +1022,15 @@ def abstract( wkspace_aval = x_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = x_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval, barrier_aval + return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ LayerNorm fwd (fp8 out) outer primitive abstract """ - out_aval, mu_aval, rsigma_aval, updated_amax_aval, _, _ = LayerNormFwdFp8Primitive.abstract( + out_aval, mu_aval, rsigma_aval, updated_amax_aval, _ = LayerNormFwdFp8Primitive.abstract( *args, **kwargs ) return out_aval, mu_aval, rsigma_aval, updated_amax_aval @@ -1158,7 +1093,7 @@ def lowering( batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), @@ -1168,9 +1103,6 @@ def lowering( ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma, beta, amax, scale, scale_inv] operand_shapes = [ @@ -1189,15 +1121,9 @@ def lowering( batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype zero_centered_gamma, epsilon, sm_margin, @@ -1215,7 +1141,7 @@ def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, to describe implementation """ assert LayerNormFwdFp8Primitive.inner_primitive is not None - out, mu, rsigma, updated_amax, _, _ = LayerNormFwdFp8Primitive.inner_primitive.bind( + out, mu, rsigma, updated_amax, _ = LayerNormFwdFp8Primitive.inner_primitive.bind( x, gamma, beta, @@ -1394,7 +1320,7 @@ def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtyp rsigama_dtype = jnp.float32 - wkspace_info, barrier_info = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( + (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes( x_aval.size // hidden_size, # batch_size hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), # in te_dtype @@ -1412,18 +1338,15 @@ def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtyp wkspace_aval = x_aval.update( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) ) - barrier_aval = x_aval.update( - shape=barrier_info[0], dtype=te_dtype_to_jax_dtype(barrier_info[1]) - ) - return out_aval, rsigma_aval, amax_aval, wkspace_aval, barrier_aval + return out_aval, rsigma_aval, amax_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ RMSNorm fwd (fp8 out) outer primitive abstract """ - out_aval, rsigma_aval, amax_aval, _, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs) + out_aval, rsigma_aval, amax_aval, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs) return out_aval, rsigma_aval, amax_aval @staticmethod @@ -1476,7 +1399,7 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): batch_shape = out_shape[:-1] batch_size = reduce(operator.mul, x_shape) // hidden_size - wkspace_aval, barrier_aval = ctx.avals_out[-2:] + wkspace_aval = ctx.avals_out[-1] out_types = [ ir.RankedTensorType.get(out_shape, ir_out_dtype), @@ -1485,9 +1408,6 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): ir.RankedTensorType.get( wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype) ), - ir.RankedTensorType.get( - barrier_aval.shape, jax_dtype_to_ir_dtype(barrier_aval.dtype) - ), ] operands = [x, gamma, amax, scale, scale_inv] operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] @@ -1499,15 +1419,9 @@ def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon): batch_size, hidden_size, wkspace_aval.size, - barrier_aval.size, - (0,), # no dgamma_part in FWD pass - (0,), # no dbeta_part in BWD pass jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(gamma_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - jax_dtype_to_te_dtype(barrier_aval.dtype), - TEDType.kByte, # dummy dgamma_part te_dtype - TEDType.kByte, # dummy dbeta_part te_dtype False, # RMSNorm doesn't support zero_centered_gamma epsilon, sm_margin, @@ -1525,7 +1439,7 @@ def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon): to describe implementation """ assert RmsNormFwdFp8Primitive.inner_primitive is not None - out, rsigma, amax, _, _ = RmsNormFwdFp8Primitive.inner_primitive.bind( + out, rsigma, amax, _ = RmsNormFwdFp8Primitive.inner_primitive.bind( x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon ) return out, rsigma, amax diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 062bbbf0fb..d944612ef5 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1,18 +1,18 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom ops for quantization""" from typing import Tuple +from packaging import version import jax import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi -from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import DType as TEDType +import transformer_engine_jax +from transformer_engine_jax import DType as TEDType from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper @@ -25,6 +25,11 @@ ) from ..sharding import all_reduce_max_along_all_axes_except_PP +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + __all__ = ["cast_fp8"] diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index a12943f4c2..888e6a897a 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom ops for softmax""" @@ -6,21 +6,26 @@ from functools import partial, reduce import operator import warnings +from packaging import version import jax import jax.numpy as jnp -from jax import core, dtypes +from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi -from transformer_engine import transformer_engine_jax +import transformer_engine_jax from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled from ..softmax import SoftmaxType +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + __all__ = [ "scaled_softmax_fwd", @@ -126,7 +131,7 @@ def forward_abstract(logits_aval, scale_factor): assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported assert q_seqlen > 1 - out_aval = core.raise_to_shaped(logits_aval) + out_aval = logits_aval return out_aval @staticmethod @@ -237,7 +242,7 @@ def backward_abstract( assert dz_aval.shape == softmax_out_aval.shape - dx_aval = core.raise_to_shaped(dz_aval) + dx_aval = dz_aval return dx_aval @staticmethod @@ -578,7 +583,7 @@ def abstract(logits_aval, mask_aval, scale_factor): # pylint: disable=unused-ar assert mask_shape[-2] == q_seqlen assert mask_shape[-1] == k_seqlen - out_aval = core.raise_to_shaped(logits_aval) + out_aval = logits_aval return out_aval @staticmethod diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index 2338572e30..ca42126e4b 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -1,20 +1,20 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX/TE custom ops for transpose""" +import operator from functools import partial, reduce from typing import Tuple, Sequence, Union, Callable -import operator +from packaging import version import jax import jax.numpy as jnp from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi -from transformer_engine import transformer_engine_jax -from transformer_engine.transformer_engine_jax import DType as TEDType +import transformer_engine_jax +from transformer_engine_jax import DType as TEDType from .base import BasePrimitive, register_primitive from .custom_call import custom_caller, CustomCallArgsWrapper @@ -33,6 +33,11 @@ from .quantization import _jax_cast_fp8 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +if version.parse(jax.__version__) >= version.parse("0.5.0"): + from jax import ffi # pylint: disable=ungrouped-imports +else: + from jax.extend import ffi # pylint: disable=ungrouped-imports + __all__ = [ "transpose", diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 02e6aaf9d5..6c3e2aa97d 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -81,25 +81,18 @@ struct CustomCallNormDescriptor { size_t batch_size; size_t hidden_size; size_t wkspace_size; - size_t barrier_size; - Shape dgamma_part_shape; - Shape dbeta_part_shape; DType x_dtype; DType w_dtype; DType wkspace_dtype; - DType barrier_dtype; - DType dgamma_part_dtype; - DType dbeta_part_dtype; bool zero_centered_gamma; float eps; int sm_margin; }; -pybind11::bytes PackCustomCallNormDescriptor( - size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size, - const std::vector &dgamma_part_shape, const std::vector &dbeta_part_shape, - DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype, - DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin); +pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, + size_t wkspace_size, DType x_dtype, DType w_dtype, + DType wkspace_dtype, bool zero_centered_gamma, + float eps, int sm_margin); struct SoftmaxDescriptor { size_t batch_size; diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 9d5fb4f7b4..a5457fa032 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -1,11 +1,12 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include "transformer_engine/activation.h" #include "extensions.h" +#include "transformer_engine/cast.h" #include "transformer_engine/transpose.h" #include "xla/ffi/api/c_api.h" @@ -264,8 +265,8 @@ Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act auto *output = output_buf->untyped_data(); auto act_input_dims = act_input_buf.dimensions(); - auto m = product(act_input_dims, 0, act_input_dims.size() - 2); - auto n = act_input_dims.back(); + auto m = static_cast(product(act_input_dims, 0, act_input_dims.size() - 2)); + auto n = static_cast(act_input_dims.back()); auto act_len = act_input_dims.end()[-2]; auto input_shape = std::vector{m, n}; @@ -332,18 +333,27 @@ pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_ auto output_trans_shape = std::vector{hidden_size, batch_size}; auto dbias_shape = std::vector{hidden_size}; - auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto dact_input_tensor = TensorWrapper(nullptr, dact_input_shape, in_dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); - auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); - auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); + // Evil hack to specify TE impl + // Note: nvte_quantize_dbias_dgelu chooses its internal impl based + // on what pointers are allocated, e.g. whether to output with + // column-wise data. However, we don't have access to any allocated + // buffers in this function. We pass a dummy pointer as a + // workaround. + int temp = 0; + + auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); + auto dact_input_tensor = + TensorWrapper(reinterpret_cast(&temp), dact_input_shape, in_dtype); + auto output_tensor = TensorWrapper(); + output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); + auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); TensorWrapper dummy_workspace; // For now, all dbias_dact(-s) have the same workspace size - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), dummy_workspace.data(), nullptr); + nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), dummy_workspace.data(), nullptr); auto work_shape = MakeShapeVector(dummy_workspace.shape()); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); @@ -384,37 +394,32 @@ void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *o auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); switch (act_enum) { case NVTE_Activation_Type::GELU: - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::SILU: - nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::RELU: - nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::QGELU: - nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; case NVTE_Activation_Type::SRELU: - nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -468,37 +473,32 @@ Error_Type DActLuDBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); auto act_type = static_cast(act_enum); switch (act_type) { case NVTE_Activation_Type::GELU: - nvte_cast_transpose_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::SILU: - nvte_cast_transpose_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::RELU: - nvte_cast_transpose_dbias_drelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::QGELU: - nvte_cast_transpose_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; case NVTE_Activation_Type::SRELU: - nvte_cast_transpose_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), workspace_tensor.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -555,29 +555,29 @@ void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *o auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); switch (act_enum) { case NVTE_Activation_Type::GEGLU: nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::SWIGLU: nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::REGLU: nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::QGEGLU: nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::SREGLU: nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); @@ -622,30 +622,30 @@ Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_bu auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto act_type = static_cast(act_enum); switch (act_type) { case NVTE_Activation_Type::GEGLU: nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::SWIGLU: nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::REGLU: nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), - output_trans_tensor.data(), stream); + stream); break; case NVTE_Activation_Type::QGEGLU: nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; case NVTE_Activation_Type::SREGLU: nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(), - output_tensor.data(), output_trans_tensor.data(), stream); + output_tensor.data(), stream); break; default: NVTE_ERROR("Unsupported ActivationEnum"); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 4bde10fc46..a824e5b83b 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -129,6 +129,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); + auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); @@ -164,15 +165,16 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), - nullptr); + ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), + dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, + kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_fwd( q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); @@ -213,14 +215,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto layout_group = nvte_get_qkv_layout_group(qkv_layout); static void FusedAttnForwardImpl( - cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens, - void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output, - void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, - size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, - size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, - float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, - bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) { + cudaStream_t stream, void *q, void *k, void *v, void *bias, void *seed, void *q_cu_seqlens, + void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux, + void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, + size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, + size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, + float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic, int64_t window_size_left, int64_t window_size_right) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ @@ -229,6 +231,10 @@ static void FusedAttnForwardImpl( if (is_ragged) { auto output_size = input_batch * q_max_seqlen * attn_heads * head_dim; cudaMemsetAsync(output, 0, output_size * typeToSize(dtype), stream); + + // Memset to 0xF0 for filling large negative numbers + auto softmax_aux_size = input_batch * q_max_seqlen * attn_heads; + cudaMemsetAsync(softmax_aux, 0xF0, softmax_aux_size * sizeof(float), stream); } /* Output tensors */ @@ -252,6 +258,7 @@ static void FusedAttnForwardImpl( backend, softmax_aux); /* Call the underlying NVTE API */ + auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); @@ -269,9 +276,10 @@ static void FusedAttnForwardImpl( nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), + dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, + is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; @@ -279,13 +287,13 @@ static void FusedAttnForwardImpl( auto q_tensor = TensorWrapper(q, q_shape, dtype); auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); - nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), + q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -303,11 +311,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s void *k = buffers[1]; void *v = buffers[2]; void *bias = buffers[3]; - void *q_cu_seqlens = buffers[4]; - void *kv_cu_seqlens = buffers[5]; - void *q_seq_offsets = is_ragged ? buffers[6] : nullptr; - void *k_seq_offsets = is_ragged ? buffers[7] : nullptr; - void *seed = buffers[8]; + void *seed = buffers[4]; + void *q_cu_seqlens = buffers[5]; + void *kv_cu_seqlens = buffers[6]; + void *q_seq_offsets = is_ragged ? buffers[7] : nullptr; + void *k_seq_offsets = is_ragged ? buffers[8] : nullptr; /* Output buffer from XLA */ void *output = buffers[9]; @@ -316,7 +324,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s void *workspace = buffers[12]; FusedAttnForwardImpl( - stream, q, k, v, bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, seed, + stream, q, k, v, bias, seed, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, output, softmax_aux, rng_state, workspace, descriptor.input_batch, descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen, descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim, @@ -354,24 +362,24 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, - Buffer_Type v_buf, Buffer_Type bias_buf, + Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type seed_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, - Buffer_Type seed_buf, Result_Type output_buf, + Variadic_Buffer_Type _unused_args, Result_Type output_buf, Result_Type softmax_aux_buf, Result_Type rng_state_buf, Result_Type workspace_buf, Dictionary attrs) { FUSED_ATTN_FFI_GET_ATTRS; FusedAttnForwardImpl( stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), - bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(), - is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, - is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(), - output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), - workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, - scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, - is_training, deterministic, window_size_left, window_size_right); + bias_buf.untyped_data(), seed_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), + kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, + is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(), + softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(), + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, + head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, + mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left, + window_size_right); return ffi_with_cuda_error_check(); } @@ -383,11 +391,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI, .Arg() // k .Arg() // v .Arg() // bias + .Arg() // seed_buf .Arg() // q_cu_seqlens .Arg() // kv_cu_seqlens .Arg() // q_seq_offsets .Arg() // k_seq_offsets - .Arg() // seed_buf + .RemainingArgs() // _cp_aux_args unused .Ret() // output .Ret() // softmax_aux .Ret() // rng_state @@ -535,7 +544,8 @@ static void FusedAttnBackwardImpl( auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); if (is_ragged) { - cudaMemsetAsync(dq, 0, transformer_engine::product(qkv_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype), + stream); } nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 @@ -553,8 +563,9 @@ static void FusedAttnBackwardImpl( auto dq_tensor = TensorWrapper(dq, q_shape, dtype); auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype); if (is_ragged) { - cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dk, 0, transformer_engine::product(kv_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype), + stream); } nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -576,9 +587,9 @@ static void FusedAttnBackwardImpl( auto dk_tensor = TensorWrapper(dk, k_shape, dtype); auto dv_tensor = TensorWrapper(dv, v_shape, dtype); if (is_ragged) { - cudaMemsetAsync(dq, 0, transformer_engine::product(q_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dk, 0, transformer_engine::product(k_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dv, 0, transformer_engine::product(v_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream); + cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream); } nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -642,9 +653,9 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T Buffer_Type output_buf, Buffer_Type doutput_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, - Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf, - Result_Type dbias_buf, Result_Type workspace_buf, - Dictionary attrs) { + Variadic_Buffer_Type _unused_args, Result_Type dq_buf, + Result_Type dk_buf, Result_Type dv_buf, Result_Type dbias_buf, + Result_Type workspace_buf, Dictionary attrs) { FUSED_ATTN_FFI_GET_ATTRS; FusedAttnBackwardImpl( @@ -677,6 +688,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI, .Arg() // kv_cu_seqlens .Arg() // q_seq_offsets .Arg() // k_seq_offsets + .RemainingArgs() // _cp_aux_args unused .Ret() // dq .Ret() // dk .Ret() // dv diff --git a/transformer_engine/jax/csrc/extensions/cudnn.cpp b/transformer_engine/jax/csrc/extensions/cudnn.cpp index 95f505e226..19fe33b818 100644 --- a/transformer_engine/jax/csrc/extensions/cudnn.cpp +++ b/transformer_engine/jax/csrc/extensions/cudnn.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index 8b627aad35..f991aeea18 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h index d886064cae..ab1d34cf5a 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.h +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/misc.cpp b/transformer_engine/jax/csrc/extensions/misc.cpp index 357a5679db..b1445e5bed 100644 --- a/transformer_engine/jax/csrc/extensions/misc.cpp +++ b/transformer_engine/jax/csrc/extensions/misc.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 7f6179e91c..7cb83a0f9e 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -26,5 +26,13 @@ struct Shape { std::vector MakeShapeVector(NVTEShape shape); +inline size_t product(const std::vector &shape) { + size_t ret = 1; + for (const auto &elem : shape) { + ret *= elem; + } + return ret; +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 9bd9951916..95b33708f0 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -1,11 +1,11 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +#include "transformer_engine/normalization.h" + #include "extensions.h" -#include "transformer_engine/layer_norm.h" -#include "transformer_engine/rmsnorm.h" namespace transformer_engine { namespace jax { @@ -25,40 +25,36 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); // dummy tensor wrappers that will carry workspace size info later - TensorWrapper dummy_work_tensor, dummy_barrier_tensor; + TensorWrapper dummy_work_tensor; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; if (is_layer_norm) { auto beta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32); - layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, - output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), nullptr, - num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); + nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, + output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + dummy_work_tensor.data(), num_sm, zero_centered_gamma, nullptr); } else { + // TODO(Phuong): Verify and remove this check NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), - rsigma_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), - dummy_barrier_tensor.data()); + rsigma_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, + nullptr); } auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); - auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), - std::make_pair(barrier_shape, dummy_barrier_tensor.dtype())); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype())); } void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspace_size, - size_t barrier_size, bool zero_centered_gamma, float eps, void *input, - DType in_dtype, void *weight, DType w_dtype, void *bias, void *output, - DType out_dtype, void *workspace, DType work_dtype, void *barrier, - DType barrier_dtype, void *mu, void *rsigma, float *amax, float *scale, - float *scale_inv, int sm_margin, cudaStream_t stream) { + bool zero_centered_gamma, float eps, void *input, DType in_dtype, + void *weight, DType w_dtype, void *bias, void *output, DType out_dtype, + void *workspace, DType work_dtype, void *mu, void *rsigma, float *amax, + float *scale, float *scale_inv, int sm_margin, cudaStream_t stream) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; auto workspace_shape = std::vector{workspace_size}; - auto barrier_shape = std::vector{barrier_size}; auto is_layer_norm = (bias) ? true : false; auto input_tensor = TensorWrapper(input, input_shape, in_dtype); @@ -71,23 +67,21 @@ void LayerNormForwardImpl(size_t batch_size, size_t hidden_size, size_t workspac auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32); auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_fwd_func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, work_dtype); - auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); if (is_layer_norm) { auto beta_tensor = TensorWrapper(bias, weight_shape, w_dtype); auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32); - layernorm_fwd_func(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, - output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), stream, num_sm, - workspace_tensor.data(), barrier_tensor.data()); + nvte_layernorm_fwd(input_tensor.data(), gamma_tensor.data(), beta_tensor.data(), eps, + output_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + workspace_tensor.data(), num_sm, zero_centered_gamma, stream); } else { NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_fwd(input_tensor.data(), gamma_tensor.data(), eps, output_tensor.data(), - rsigma_tensor.data(), stream, num_sm, workspace_tensor.data(), - barrier_tensor.data()); + rsigma_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, + stream); } } @@ -96,20 +90,17 @@ Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buff Buffer_Type *scale_buf, Buffer_Type *scale_inv_buf, Result_Type *output_buf, Result_Type *mu_buf, Result_Type *rsigma_buf, Result_Type *amax_out_buf, - Result_Type *wkspace_buf, Result_Type *barrier_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_, - bool is_layer_norm, bool is_fp8) { + Result_Type *wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_, bool is_layer_norm, bool is_fp8) { auto in_dtype = convert_ffi_datatype_to_te_dtype((*x_buf).element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype((*gamma_buf).element_type()); auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type()); - auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type()); auto *input = x_buf->untyped_data(); auto *weight = gamma_buf->untyped_data(); auto *output = (*output_buf)->untyped_data(); auto *rsigma = (*rsigma_buf)->untyped_data(); auto *workspace = (*wkspace_buf)->untyped_data(); - auto *barrier = (*barrier_buf)->untyped_data(); void *bias = nullptr; void *mu = nullptr; @@ -135,17 +126,15 @@ Error_Type LayerNormForwardImplFFI(cudaStream_t stream, Buffer_Type *x_buf, Buff auto x_size = product(x_buf->dimensions()); auto gamma_size = product(gamma_buf->dimensions()); auto wkspace_size = product((*wkspace_buf)->dimensions()); - auto barrier_size = product((*barrier_buf)->dimensions()); auto hidden_size = gamma_size; auto batch_size = x_size / gamma_size; float eps = static_cast(eps_); int sm_margin = static_cast(sm_margin_); - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); return ffi_with_cuda_error_check(); } @@ -154,11 +143,10 @@ Error_Type LayerNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer Buffer_Type scale_inv_buf, Result_Type output_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type amax_out_buf, Result_Type wkspace_buf, - Result_Type barrier_buf, bool zero_centered_gamma, double eps_, - int64_t sm_margin_) { + bool zero_centered_gamma, double eps_, int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, &amax_buf, &scale_buf, &scale_inv_buf, &output_buf, &mu_buf, &rsigma_buf, &amax_out_buf, - &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, true, // is_layer_norm true // is_fp8 ); @@ -178,7 +166,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI .Ret() // rsigma .Ret() // amax_out .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -187,15 +174,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardFP8Handler, LayerNormForwardFP8FFI Error_Type LayerNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, - Result_Type barrier_buf, bool zero_centered_gamma, double eps_, - int64_t sm_margin_) { + bool zero_centered_gamma, double eps_, int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, &beta_buf, nullptr, // amax_buf nullptr, // scale_buf, nullptr, // scale_inv_buf, &output_buf, &mu_buf, &rsigma_buf, nullptr, // amax_out_buf, - &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, true, // is_layer_norm false // is_fp8 ); @@ -211,7 +197,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormForwardHandler, LayerNormForwardFFI, .Ret() // mu .Ret() // rsigma .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -221,14 +206,14 @@ Error_Type RMSNormForwardFP8FFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_T Buffer_Type amax_buf, Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf, Result_Type rsigma_buf, Result_Type amax_out_buf, - Result_Type wkspace_buf, Result_Type barrier_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + Result_Type wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, nullptr, // beta_buf, &amax_buf, &scale_buf, &scale_inv_buf, &output_buf, nullptr, // mu_buf, - &rsigma_buf, &amax_out_buf, &wkspace_buf, &barrier_buf, - zero_centered_gamma, eps_, sm_margin_, + &rsigma_buf, &amax_out_buf, &wkspace_buf, zero_centered_gamma, + eps_, sm_margin_, false, // is_layer_norm true // is_fp8 ); @@ -246,7 +231,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI, .Ret() // rsigma .Ret() // amax_out .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -254,8 +238,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardFP8Handler, RMSNormForwardFP8FFI, Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type gamma_buf, Result_Type output_buf, Result_Type rsigma_buf, - Result_Type wkspace_buf, Result_Type barrier_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + Result_Type wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { return LayerNormForwardImplFFI(stream, &x_buf, &gamma_buf, nullptr, // beta_buf, nullptr, // amax_buf, @@ -265,7 +249,7 @@ Error_Type RMSNormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type nullptr, // mu_buf, &rsigma_buf, nullptr, // amax_out_buf, - &wkspace_buf, &barrier_buf, zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, false, // is_layer_norm false // is_fp8 ); @@ -279,7 +263,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormForwardHandler, RMSNormForwardFFI, .Ret() // output .Ret() // rsigma .Ret() // wkspace - .Ret() // barrier .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -303,50 +286,34 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid auto wgrad_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); // dummy tensor wrappers that will carry workspace size info later - TensorWrapper dummy_work_tensor, dummy_barrier_tensor; - TensorWrapper dummy_dgamma_part_tensor, dummy_dbeta_part_tensor; + TensorWrapper dummy_work_tensor; auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - // initialize dBeta information here -- layernorm will modify but RMSnorm will not - std::vector dbeta_part_shape; if (is_layer_norm) { auto mu_tensor = TensorWrapper(nullptr, intermediates_shape, intermediates_dtype); auto dbeta_tensor = TensorWrapper(nullptr, weight_shape, w_dtype); - layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), - dbeta_tensor.data(), dummy_dgamma_part_tensor.data(), - dummy_dbeta_part_tensor.data(), nullptr, num_sm, dummy_work_tensor.data(), - dummy_barrier_tensor.data()); + dbeta_tensor.data(), dummy_work_tensor.data(), num_sm, zero_centered_gamma, + nullptr); - dbeta_part_shape = MakeShapeVector(dummy_dbeta_part_tensor.shape()); } else { NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), - xgrad_tensor.data(), wgrad_tensor.data(), dummy_dgamma_part_tensor.data(), - nullptr, num_sm, dummy_work_tensor.data(), dummy_barrier_tensor.data()); - - dbeta_part_shape = std::vector{0, 0}; + xgrad_tensor.data(), wgrad_tensor.data(), dummy_work_tensor.data(), num_sm, + zero_centered_gamma, nullptr); } auto work_shape = MakeShapeVector(dummy_work_tensor.shape()); - auto barrier_shape = MakeShapeVector(dummy_barrier_tensor.shape()); - auto dgamma_part_shape = MakeShapeVector(dummy_dgamma_part_tensor.shape()); - return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype()), - std::make_pair(barrier_shape, dummy_barrier_tensor.dtype()), - std::make_pair(dgamma_part_shape, dummy_dgamma_part_tensor.dtype()), - std::make_pair(dbeta_part_shape, dummy_dbeta_part_tensor.dtype())); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_work_tensor.dtype())); } void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace_size, - size_t barrier_size, Shape dgamma_part_shape, Shape dbeta_part_shape, bool zero_centered_gamma, float eps, void *input, DType in_dtype, void *weight, DType w_dtype, void *ograd, void *workspace, - DType wkspace_dtype, void *barrier, DType barrier_dtype, void *mu, - void *rsigma, void *xgrad, void *wgrad, void *dbeta, void *dgamma_part, - DType dgamma_dtype, void *dbeta_part, DType dbeta_dtype, int sm_margin, - cudaStream_t stream) { + DType wkspace_dtype, void *mu, void *rsigma, void *xgrad, void *wgrad, + void *dbeta, int sm_margin, cudaStream_t stream) { auto input_shape = std::vector{batch_size, hidden_size}; auto weight_shape = std::vector{hidden_size}; auto intermediates_shape = std::vector{batch_size}; @@ -368,28 +335,23 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace auto wgrad_tensor = TensorWrapper(wgrad, weight_shape, w_dtype); auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - sm_margin; - auto layernorm_bwd_func = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; auto workspace_shape = std::vector{wkspace_size}; auto workspace_tensor = TensorWrapper(workspace, workspace_shape, wkspace_dtype); - auto barrier_shape = std::vector{barrier_size}; - auto barrier_tensor = TensorWrapper(barrier, barrier_shape, barrier_dtype); - auto dgamma_part_tensor = TensorWrapper(dgamma_part, dgamma_part_shape.to_vector(), dgamma_dtype); if (is_layer_norm) { auto mu_tensor = TensorWrapper(mu, intermediates_shape, intermediates_dtype); auto dbeta_tensor = TensorWrapper(dbeta, weight_shape, w_dtype); - auto dbeta_part_tensor = TensorWrapper(dbeta_part, dbeta_part_shape.to_vector(), dbeta_dtype); - layernorm_bwd_func(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), + nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), xgrad_tensor.data(), wgrad_tensor.data(), - dbeta_tensor.data(), dgamma_part_tensor.data(), dbeta_part_tensor.data(), - stream, num_sm, workspace_tensor.data(), barrier_tensor.data()); + dbeta_tensor.data(), workspace_tensor.data(), num_sm, zero_centered_gamma, + stream); } else { NVTE_CHECK(!zero_centered_gamma, "rmsnorm doesn't support zero_centered_gamma."); nvte_rmsnorm_bwd(dz_tensor.data(), x_tensor.data(), rsigma_tensor.data(), gamma_tensor.data(), - xgrad_tensor.data(), wgrad_tensor.data(), dgamma_part_tensor.data(), stream, - num_sm, workspace_tensor.data(), barrier_tensor.data()); + xgrad_tensor.data(), wgrad_tensor.data(), workspace_tensor.data(), num_sm, + zero_centered_gamma, stream); } } @@ -397,15 +359,11 @@ Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Bu Buffer_Type *mu_buf, Buffer_Type *rsigma_buf, Buffer_Type *gamma_buf, Result_Type *xgrad_buf, Result_Type *wgrad_buf, Result_Type *dbeta_buf, - Result_Type *wkspace_buf, Result_Type *barrier_buf, - Result_Type *dgamma_part_buf, Result_Type *dbeta_part_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_, - bool is_layer_norm) { + Result_Type *wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_, bool is_layer_norm) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf->element_type()); auto wkspace_dtype = convert_ffi_datatype_to_te_dtype((*wkspace_buf)->element_type()); - auto barrier_dtype = convert_ffi_datatype_to_te_dtype((*barrier_buf)->element_type()); - auto dgamma_part_dtype = convert_ffi_datatype_to_te_dtype((*dgamma_part_buf)->element_type()); auto *ograd = dz_buf->untyped_data(); auto *rsigma = rsigma_buf->untyped_data(); @@ -414,62 +372,37 @@ Error_Type LayerNormBackwardImplFFI(cudaStream_t stream, Buffer_Type *dz_buf, Bu auto *xgrad = (*xgrad_buf)->untyped_data(); auto *wgrad = (*wgrad_buf)->untyped_data(); auto *workspace = (*wkspace_buf)->untyped_data(); - auto *barrier = (*barrier_buf)->untyped_data(); - auto *dgamma_part = (*dgamma_part_buf)->untyped_data(); void *mu = nullptr; void *dbeta = nullptr; - void *dbeta_part = nullptr; - auto dbeta_part_dtype = DType::kByte; if (is_layer_norm) { mu = (*mu_buf).untyped_data(); dbeta = (*dbeta_buf)->untyped_data(); - dbeta_part = (*dbeta_part_buf)->untyped_data(); - dbeta_part_dtype = convert_ffi_datatype_to_te_dtype((*dbeta_part_buf)->element_type()); } auto x_size = product(x_buf->dimensions()); auto gamma_size = product(gamma_buf->dimensions()); auto wkspace_size = product((*wkspace_buf)->dimensions()); - auto barrier_size = product((*barrier_buf)->dimensions()); auto hidden_size = gamma_size; auto batch_size = x_size / gamma_size; - Shape dgamma_part_shape; - auto dgamma_part_dims = (*dgamma_part_buf)->dimensions(); - std::vector dgamma_parts_dims_vector(dgamma_part_dims.begin(), dgamma_part_dims.end()); - dgamma_part_shape.from_vector(dgamma_parts_dims_vector); - - Shape dbeta_part_shape; - if (is_layer_norm) { - auto dbeta_part_dims = (*dbeta_part_buf)->dimensions(); - std::vector dbeta_parts_dims_vector(dbeta_part_dims.begin(), dbeta_part_dims.end()); - dbeta_part_shape.from_vector(dbeta_parts_dims_vector); - } else { - dbeta_part_shape.from_vector({0, 0}); - } - float eps = static_cast(eps_); int sm_margin = static_cast(sm_margin_); - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, sm_margin, stream); + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma, + xgrad, wgrad, dbeta, sm_margin, stream); return ffi_with_cuda_error_check(); } Error_Type LayerNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, Buffer_Type mu_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf, Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type dbeta_buf, - Result_Type wkspace_buf, Result_Type barrier_buf, - Result_Type dgamma_part_buf, Result_Type dbeta_part_buf, - bool zero_centered_gamma, double eps_, int64_t sm_margin_) { + Result_Type wkspace_buf, bool zero_centered_gamma, double eps_, + int64_t sm_margin_) { return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, &mu_buf, &rsigma_buf, &gamma_buf, - &xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf, &barrier_buf, - &dgamma_part_buf, &dbeta_part_buf, zero_centered_gamma, eps_, - sm_margin_, + &xgrad_buf, &wgrad_buf, &dbeta_buf, &wkspace_buf, + zero_centered_gamma, eps_, sm_margin_, true // is_layer_norm ); } @@ -486,9 +419,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI, .Ret() // wgrad .Ret() // dbeta .Ret() // wkspace - .Ret() // barrier - .Ret() // dgamma_part - .Ret() // dbeta_part .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -497,15 +427,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI, Error_Type RMSNormBackwardFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, Buffer_Type rsigma_buf, Buffer_Type gamma_buf, Result_Type xgrad_buf, Result_Type wgrad_buf, Result_Type wkspace_buf, - Result_Type barrier_buf, Result_Type dgamma_part_buf, bool zero_centered_gamma, double eps_, int64_t sm_margin_) { return LayerNormBackwardImplFFI(stream, &dz_buf, &x_buf, nullptr, // mu_buf &rsigma_buf, &gamma_buf, &xgrad_buf, &wgrad_buf, nullptr, // dbeta_buf, - &wkspace_buf, &barrier_buf, &dgamma_part_buf, - nullptr, // dbeta_part_buf, - zero_centered_gamma, eps_, sm_margin_, + &wkspace_buf, zero_centered_gamma, eps_, sm_margin_, false // is_layer_norm ); } @@ -520,8 +447,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RMSNormBackwardHandler, RMSNormBackwardFFI, .Ret() // xgrad .Ret() // wgrad .Ret() // wkspace - .Ret() // barrier - .Ret() // dgamma_part .Attr("zero_centered_gamma") .Attr("eps") .Attr("sm_margin"), @@ -540,7 +465,6 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque auto *rsigma = buffers[8]; auto *amax_out = buffers[9]; auto *workspace = buffers[10]; - auto *barrier = buffers[11]; NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive"); @@ -548,21 +472,18 @@ void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -573,7 +494,6 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s auto *mu = buffers[4]; auto *rsigma = buffers[5]; auto *workspace = buffers[6]; - auto *barrier = buffers[7]; float *amax = nullptr; float *scale = nullptr; @@ -583,20 +503,17 @@ void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, s auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto out_dtype = in_dtype; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -605,15 +522,9 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto dgamma_part_shape = desc.dgamma_part_shape; - auto dbeta_part_shape = desc.dbeta_part_shape; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto dgamma_part_dtype = desc.dgamma_part_dtype; - auto dbeta_part_dtype = desc.dbeta_part_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; @@ -627,15 +538,10 @@ void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, auto *wgrad = buffers[6]; auto *dbeta = buffers[7]; auto *workspace = buffers[8]; - auto *barrier = buffers[9]; - auto *dgamma_part = buffers[10]; - auto *dbeta_part = buffers[11]; - - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, sm_margin, stream); + + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma, + xgrad, wgrad, dbeta, sm_margin, stream); } void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -648,7 +554,6 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, auto *rsigma = buffers[6]; auto *amax_out = buffers[7]; auto *workspace = buffers[8]; - auto *barrier = buffers[9]; NVTE_CHECK(amax_out == amax, "amax not bound to amax_out in TE/JAX RSMNormForwardFP8 primitive."); void *bias = nullptr; @@ -658,20 +563,17 @@ void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; auto out_dtype = DType::kFloat8E4M3; - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -680,7 +582,6 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz auto *output = buffers[2]; auto *rsigma = buffers[3]; auto *workspace = buffers[4]; - auto *barrier = buffers[5]; void *bias = nullptr; void *mu = nullptr; @@ -692,20 +593,17 @@ void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, siz auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; auto out_dtype = in_dtype; - LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, zero_centered_gamma, - eps, input, in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, - wkspace_dtype, barrier, barrier_dtype, mu, rsigma, amax, scale, scale_inv, - sm_margin, stream); + LayerNormForwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, bias, output, out_dtype, workspace, wkspace_dtype, + mu, rsigma, amax, scale, scale_inv, sm_margin, stream); } void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { @@ -716,36 +614,24 @@ void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, si auto *xgrad = buffers[4]; auto *wgrad = buffers[5]; auto *workspace = buffers[6]; - auto *barrier = buffers[7]; - auto *dgamma_part = buffers[8]; void *mu = nullptr; void *dbeta = nullptr; - void *dbeta_part = nullptr; const auto &desc = *UnpackOpaque(opaque, opaque_len); auto batch_size = desc.batch_size; auto hidden_size = desc.hidden_size; auto wkspace_size = desc.wkspace_size; - auto barrier_size = desc.barrier_size; - auto dgamma_part_shape = desc.dgamma_part_shape; - Shape dbeta_part_shape; - dbeta_part_shape.from_vector({0, 0}); auto in_dtype = desc.x_dtype; auto w_dtype = desc.w_dtype; auto wkspace_dtype = desc.wkspace_dtype; - auto barrier_dtype = desc.barrier_dtype; - auto dgamma_part_dtype = desc.dgamma_part_dtype; - auto dbeta_part_dtype = DType::kByte; auto eps = desc.eps; auto zero_centered_gamma = desc.zero_centered_gamma; auto sm_margin = desc.sm_margin; - LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, barrier_size, dgamma_part_shape, - dbeta_part_shape, zero_centered_gamma, eps, input, in_dtype, weight, - w_dtype, ograd, workspace, wkspace_dtype, barrier, barrier_dtype, mu, - rsigma, xgrad, wgrad, dbeta, dgamma_part, dgamma_part_dtype, dbeta_part, - dbeta_part_dtype, sm_margin, stream); + LayerNormBackwardImpl(batch_size, hidden_size, wkspace_size, zero_centered_gamma, eps, input, + in_dtype, weight, w_dtype, ograd, workspace, wkspace_dtype, mu, rsigma, + xgrad, wgrad, dbeta, sm_margin, stream); } } // namespace jax diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 298478603b..151a1d869a 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -32,24 +32,17 @@ pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector &shap return PackOpaque(desc); } -pybind11::bytes PackCustomCallNormDescriptor( - size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size, - const std::vector &dgamma_part_shape, const std::vector &dbeta_part_shape, - DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype, - DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin) { +pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size, + size_t wkspace_size, DType x_dtype, DType w_dtype, + DType wkspace_dtype, bool zero_centered_gamma, + float eps, int sm_margin) { CustomCallNormDescriptor desc{}; desc.batch_size = batch_size; desc.hidden_size = hidden_size; desc.wkspace_size = wkspace_size; - desc.barrier_size = barrier_size; - desc.dgamma_part_shape.from_vector(dgamma_part_shape); - desc.dbeta_part_shape.from_vector(dbeta_part_shape); desc.x_dtype = x_dtype; desc.w_dtype = w_dtype; desc.wkspace_dtype = wkspace_dtype; - desc.barrier_dtype = barrier_dtype; - desc.dgamma_part_dtype = dgamma_part_dtype; - desc.dbeta_part_dtype = dbeta_part_dtype; desc.zero_centered_gamma = zero_centered_gamma; desc.eps = eps; desc.sm_margin = sm_margin; diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 9b5c156e5d..9c92fe8b33 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -61,34 +61,43 @@ pybind11::dict Registrations() { dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler); dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler); - dict["te_dact_lu_dbias_cast_transpose_ffi"] = - EncapsulateFunction(DActLuDBiasCastTransposeHandler); - dict["te_dgated_act_lu_cast_transpose_ffi"] = - EncapsulateFunction(DGatedActLuCastTransposeHandler); + dict["te_dact_lu_dbias_cast_transpose_ffi"] = EncapsulateFFI(DActLuDBiasCastTransposeHandler); + dict["te_dgated_act_lu_cast_transpose_ffi"] = EncapsulateFFI(DGatedActLuCastTransposeHandler); // Quantization dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler); dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); // Softmax - dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler); - dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler); - dict["te_scaled_masked_softmax_forward_ffi"] = - EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler); + dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler); + dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler); + dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler); dict["te_scaled_masked_softmax_backward_ffi"] = - EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler); + EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler); dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] = - EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler); + EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler); dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] = - EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler); + EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler); // Normalization - dict["te_layernorm_forward_ffi"] = EncapsulateFFI(LayerNormForwardHandler); - dict["te_layernorm_forward_fp8_ffi"] = EncapsulateFFI(LayerNormForwardFP8Handler); - dict["te_layernorm_backward_ffi"] = EncapsulateFFI(LayerNormBackwardHandler); - dict["te_rmsnorm_forward_ffi"] = EncapsulateFunction(RMSNormForwardHandler); - dict["te_rmsnorm_forward_fp8_ffi"] = EncapsulateFunction(RMSNormForwardFP8Handler); - dict["te_rmsnorm_backward_ffi"] = EncapsulateFunction(RMSNormBackwardHandler); + dict["te_layernorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardHandler)); + dict["te_layernorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormForwardFP8Handler)); + dict["te_layernorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(LayerNormBackwardHandler)); + dict["te_rmsnorm_forward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardHandler)); + dict["te_rmsnorm_forward_fp8_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormForwardFP8Handler)); + dict["te_rmsnorm_backward_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(RMSNormBackwardHandler)); // Attention pybind11::dict fused_attn_forward_ffi; diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index d08368657e..71d1456287 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -25,7 +25,7 @@ void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t op auto input_tensor = TensorWrapper(input, shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv); - nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); } Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -48,7 +48,7 @@ Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type a auto input_tensor = TensorWrapper(input, shape, in_dtype); auto output_tensor = TensorWrapper(output, shape, out_dtype, amax_out, scale, scale_inv); - nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream); + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); return ffi_with_cuda_error_check(); } @@ -76,7 +76,7 @@ void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv); auto output_tensor = TensorWrapper(output, shape, desc.out_dtype); - nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); + nvte_dequantize(input_tensor.data(), output_tensor.data(), stream); } Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -96,7 +96,7 @@ Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv); auto output_tensor = TensorWrapper(output, shape, out_dtype); - nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream); + nvte_dequantize(input_tensor.data(), output_tensor.data(), stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/jax/csrc/extensions/softmax.cpp b/transformer_engine/jax/csrc/extensions/softmax.cpp index f54ebefcb0..1cf281e64b 100644 --- a/transformer_engine/jax/csrc/extensions/softmax.cpp +++ b/transformer_engine/jax/csrc/extensions/softmax.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/extensions/transpose.cpp b/transformer_engine/jax/csrc/extensions/transpose.cpp index 8480081a68..af347f45b2 100644 --- a/transformer_engine/jax/csrc/extensions/transpose.cpp +++ b/transformer_engine/jax/csrc/extensions/transpose.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -7,6 +7,7 @@ #include "transformer_engine/transpose.h" #include "extensions.h" +#include "transformer_engine/cast.h" #include "xla/ffi/api/ffi.h" namespace transformer_engine { @@ -89,13 +90,12 @@ void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size auto input_trans_shape = std::vector{n, m}; auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); - auto input_cast_tensor = + auto output_tensor = TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto input_cast_trans_tensor = TensorWrapper(input_cast_trans, input_trans_shape, desc.out_dtype, - amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(input_cast_trans, desc.out_dtype, input_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); - nvte_cast_transpose(input_tensor.data(), input_cast_tensor.data(), input_cast_trans_tensor.data(), - stream); + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); } Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -131,11 +131,11 @@ Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); + + nvte_quantize(input_tensor.data(), output_tensor.data(), stream); - nvte_cast_transpose(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - stream); return ffi_with_cuda_error_check(); } @@ -159,15 +159,22 @@ pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hi auto output_trans_shape = std::vector{hidden_size, batch_size}; auto dbias_shape = std::vector{hidden_size}; - auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); - auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); - auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); + // Evil hack to specify TE impl + // Note: nvte_quantize_dbias chooses its internal impl based on what + // pointers are allocated, e.g. whether to output with column-wise + // data. However, we don't have access to any allocated buffers in + // this function. We pass a dummy pointer as a workaround. + int temp = 0; + + auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); + auto output_tensor = TensorWrapper(reinterpret_cast(&temp), output_shape, out_dtype); + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); + auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); TensorWrapper dummy_workspace; - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), dummy_workspace.data(), nullptr); + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), + dummy_workspace.data(), nullptr); auto work_shape = MakeShapeVector(dummy_workspace.shape()); return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); @@ -203,14 +210,14 @@ void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace.data(), stream); + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), + workspace.data(), stream); } Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -253,13 +260,13 @@ Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buf auto input_tensor = TensorWrapper(input, input_shape, in_dtype); auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv); - auto output_trans_tensor = - TensorWrapper(output_trans, output_trans_shape, out_dtype, amax_out, scale, scale_inv); + output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape); + output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector{1}); auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); - nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(), - dbias_tensor.data(), workspace_tensor.data(), stream); + nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), + workspace_tensor.data(), stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/jax/csrc/utils.cu b/transformer_engine/jax/csrc/utils.cu index 8ca34013b3..2229c85165 100644 --- a/transformer_engine/jax/csrc/utils.cu +++ b/transformer_engine/jax/csrc/utils.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index 32de33bac9..01d950e168 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/jax/dot.py b/transformer_engine/jax/dot.py index 8981af8b7c..826b94a983 100644 --- a/transformer_engine/jax/dot.py +++ b/transformer_engine/jax/dot.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX te modules""" @@ -25,7 +25,7 @@ def type_safe_dot_general( """ if fp8_meta_pkg is None: - kernel = jnp.asarray(kernel, x.dtype) + assert x.dtype == kernel.dtype, f"lhs dtype = {x.dtype}, rhs dtype = {kernel.dtype}" return jax.lax.dot_general(x, kernel, (contracting_dims, ((), ()))) amax_list = fp8_meta_pkg.amax_list diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 6655091caa..f386bdce22 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Transformer Engine bindings for JAX""" diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8b13c47cd4..d814c2d4df 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ @@ -8,8 +8,8 @@ import operator from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union -import jax.numpy as jnp import numpy as np +import jax.numpy as jnp from flax import linen as nn from flax.linen import partitioning as nn_partitioning from jax import lax @@ -57,19 +57,15 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga def _create_layernorm_parameters( - layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, dtype + layernorm_type, shape, scale_init, scale_axes, bias_init, bias_axes, input_dtype, dtype ): - scale = nn_partitioning.param_with_axes( - "scale", scale_init, shape, jnp.float32, axes=scale_axes - ) - scale = jnp.asarray(scale, dtype) + scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes) + scale = scale.astype(input_dtype) layernorm_type = canonicalize_layernorm_type(layernorm_type) if layernorm_type == "layernorm": - bias = nn_partitioning.param_with_axes( - "ln_bias", bias_init, shape, jnp.float32, axes=bias_axes - ) - bias = jnp.asarray(bias, dtype) + bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes) + bias = bias.astype(input_dtype) else: assert layernorm_type == "rmsnorm" bias = None @@ -158,15 +154,15 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp heads = inputs.shape[1] q_seqlen = inputs.shape[2] k_seqlen = inputs.shape[3] - dtype = inputs.dtype + input_dtype = inputs.dtype logits = inputs if self.softmax_type is not SoftmaxType.SCALED and is_softmax_kernel_available( - self.softmax_type, batch, heads, q_seqlen, k_seqlen, inputs.dtype + self.softmax_type, batch, heads, q_seqlen, k_seqlen, input_dtype ): if bias is not None: - logits = logits + bias.astype(dtype) + logits = logits + bias.astype(input_dtype) mask_ = mask if self.softmax_type is not SoftmaxType.SCALED_MASKED: @@ -178,25 +174,27 @@ def __call__(self, inputs: Array, mask: Array = None, bias: Array = None) -> jnp if mask is not None: attention_bias = lax.select( mask > 0, - jnp.full(mask.shape, -1e10).astype(dtype), - jnp.full(mask.shape, 0.0).astype(dtype), + jnp.full(mask.shape, -1e10), + jnp.full(mask.shape, 0.0), ) + attention_bias = attention_bias.astype(input_dtype) if bias is not None: attention_bias = _combine_biases(attention_bias, bias) if attention_bias is not None: - logits = logits + attention_bias.astype(dtype) + logits = logits + attention_bias.astype(input_dtype) # For the case that self.softmax == SoftmaxType.SCALED_UPPER_TRIANG_MASKED # and kernel is unavailable, then try on pure scaled softmax custom calls. if is_softmax_kernel_available( - SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, dtype + SoftmaxType.SCALED, batch, heads, q_seqlen, k_seqlen, input_dtype ): outputs = softmax(logits, None, self.scale_factor, SoftmaxType.SCALED) else: outputs = jax_nn.softmax(logits * self.scale_factor) + assert input_dtype == outputs.dtype return outputs @@ -260,8 +258,8 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- - dtype : jax.numpy.dtype, default = jax.numpy.float32 - the data type used to allocate the initial parameters. + dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = False Indicate whether the input tensors were switched axis of batch and sequence length dimension. If set to True, the input tensors @@ -280,7 +278,8 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods def __post_init__(self): self.scale_init = _obtain_default_layernorm_scale_init_if_need( - self.scale_init, self.zero_centered_gamma + self.scale_init, + self.zero_centered_gamma, ) super().__post_init__() @@ -299,6 +298,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: outputs : jax.numpy.ndarray Output tensors. """ + input_dtype = x.dtype features = x.shape[-1] scale, ln_bias = _create_layernorm_parameters( @@ -308,9 +308,10 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: self.scale_axes, self.bias_init, self.bias_axes, + input_dtype, self.dtype, ) - return layernorm( + out = layernorm( x, scale, ln_bias, @@ -318,6 +319,8 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: zero_centered_gamma=self.zero_centered_gamma, epsilon=self.epsilon, ) + assert out.dtype == input_dtype + return out class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-methods @@ -401,7 +404,7 @@ class DenseGeneral(TransformerEngineBase): Optimization parameters ----------------------- - dtype : jax.numpy.dtype, default = jax.numpy.float32 + dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch @@ -424,7 +427,9 @@ class DenseGeneral(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) super().__post_init__() @nn.compact @@ -443,25 +448,25 @@ def __call__(self, inputs: Array) -> Array: Output tensors. """ + input_dtype = inputs.dtype features = _canonicalize_tuple(self.features) axis = _canonicalize_tuple(self.axis) - inputs = jnp.asarray(inputs, self.dtype) axis = _normalize_axes(axis, inputs.ndim) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) - - kernel = jnp.reshape(kernel, kernel_shape) + if not FP8Helper.is_fp8_enabled(): + kernel = kernel.astype(input_dtype) if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, features, jnp.float32, axes=self.bias_axes + "bias", self.bias_init, features, self.dtype, axes=self.bias_axes ) - bias = bias.astype(self.dtype) + bias = bias.astype(input_dtype) else: bias = None @@ -490,11 +495,11 @@ def __call__(self, inputs: Array) -> Array: "lora_a_kernel", self.kernel_init, lora_a_kernel_init_shape, - jnp.float32, + self.dtype, axes=lora_a_kernel_axes, ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) - lora_a_kernel = lora_a_kernel.astype(self.dtype) + lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) @@ -502,10 +507,10 @@ def __call__(self, inputs: Array) -> Array: "lora_b_kernel", nn.initializers.zeros, lora_b_kernel_shape, - jnp.float32, + self.dtype, axes=lora_b_kernel_axes, ) - lora_b_kernel = lora_b_kernel.astype(self.dtype) + lora_b_kernel = lora_b_kernel.astype(input_dtype) y += _apply_low_rank_adaptation( inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha @@ -514,6 +519,8 @@ def __call__(self, inputs: Array) -> Array: if bias is not None: bias_shape = (1,) * (y.ndim - bias.ndim) + bias.shape y += jnp.reshape(bias, bias_shape) + + assert y.dtype == input_dtype return y @@ -595,7 +602,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): Optimization parameters ----------------------- - dtype : jax.numpy.dtype, default = jax.numpy.float32 + dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch @@ -633,9 +640,15 @@ class LayerNormDenseGeneral(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, + "fan_in", + "truncated_normal", + dtype=self.dtype, + ) self.scale_init = _obtain_default_layernorm_scale_init_if_need( - self.scale_init, self.zero_centered_gamma + self.scale_init, + self.zero_centered_gamma, ) super().__post_init__() @@ -658,6 +671,7 @@ def __call__(self, inputs: Array) -> Array: If :attr:`return_layernorm_output=False`, then this would be None. """ + input_dtype = inputs.dtype ln_output = None fuse_layernorm = ( @@ -679,6 +693,7 @@ def __call__(self, inputs: Array) -> Array: self.scale_axes, self.ln_bias_init, self.ln_bias_axes, + input_dtype, self.dtype, ) @@ -709,10 +724,10 @@ def __call__(self, inputs: Array) -> Array: kernel_shape = tuple(y.shape[ax] for ax in axis) + features kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features kernel = nn_partitioning.param_with_axes( - "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes + "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes ) - - kernel = jnp.reshape(kernel, kernel_shape) + if not FP8Helper.is_fp8_enabled(): + kernel = kernel.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) @@ -755,11 +770,11 @@ def __call__(self, inputs: Array) -> Array: "lora_a_kernel", self.kernel_init, lora_a_kernel_init_shape, - jnp.float32, + self.dtype, axes=lora_a_kernel_axes, ) lora_a_kernel = jnp.reshape(lora_a_kernel, lora_a_kernel_shape) - lora_a_kernel = lora_a_kernel.astype(self.dtype) + lora_a_kernel = lora_a_kernel.astype(input_dtype) lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1]) lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape) @@ -767,10 +782,10 @@ def __call__(self, inputs: Array) -> Array: "lora_b_kernel", nn.initializers.zeros, lora_b_kernel_shape, - jnp.float32, + self.dtype, axes=lora_b_kernel_axes, ) - lora_b_kernel = lora_b_kernel.astype(self.dtype) + lora_b_kernel = lora_b_kernel.astype(input_dtype) z += _apply_low_rank_adaptation( y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha @@ -779,9 +794,9 @@ def __call__(self, inputs: Array) -> Array: bias = None if self.use_bias: bias = nn_partitioning.param_with_axes( - "bias", self.bias_init, features, jnp.float32, axes=self.bias_axes + "bias", self.bias_init, features, self.dtype, axes=self.bias_axes ) - bias = bias.astype(self.dtype) + bias = bias.astype(input_dtype) if bias is not None: bias_shape = (1,) * (z.ndim - bias.ndim) + bias.shape @@ -790,6 +805,7 @@ def __call__(self, inputs: Array) -> Array: if self.depth_scaling is not None: z = z / self.depth_scaling + assert z.dtype == input_dtype return z, ln_output # dense_output, layer_norm_output @@ -894,7 +910,7 @@ class LayerNormMLP(TransformerEngineBase): Optimization parameters ----------------------- - dtype : jax.numpy.dtype, default = jax.numpy.float32 + dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. transpose_batch_sequence : bool, default = True Indicate whether the input tensors were switched axis of batch @@ -935,9 +951,12 @@ class LayerNormMLP(TransformerEngineBase): def __post_init__(self): if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ) self.scale_init = _obtain_default_layernorm_scale_init_if_need( - self.scale_init, self.zero_centered_gamma + self.scale_init, + self.zero_centered_gamma, ) super().__post_init__() @@ -962,6 +981,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: If :attr:`return_layernorm_output=False`, then this would be None. """ + input_dtype = inputs.dtype ln_output = None fuse_layernorm = ( @@ -1007,6 +1027,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: self.scale_axes, self.ln_bias_init, self.ln_bias_axes, + input_dtype, self.dtype, ) @@ -1033,7 +1054,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): for _ in range(num_kernels): key, init_key = jax_random.split(key) kernels.append(self.kernel_init(init_key, *init_args)) - return jnp.stack(kernels, axis=stack_axis, dtype=jnp.float32) + return jnp.stack(kernels, axis=stack_axis, dtype=self.dtype) wi_fp8_meta_pkg = None wo_fp8_meta_pkg = None @@ -1054,10 +1075,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations, -2, kernel_1_each_shape, - jnp.float32, + self.dtype, axes=self.kernel_axes_1, ) kernel_1 = jnp.reshape(kernel_1, kernel_1_shape) + if not FP8Helper.is_fp8_enabled(): + kernel_1 = kernel_1.astype(input_dtype) hidden_size = inputs.shape[-1] hidden_size_tuple = _canonicalize_tuple(hidden_size) kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple @@ -1066,10 +1089,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_kernel", self.kernel_init, kernel_2_param_shape, - jnp.float32, + self.dtype, axes=self.kernel_axes_2, ) kernel_2 = jnp.reshape(kernel_2, kernel_2_shape) + if not FP8Helper.is_fp8_enabled(): + kernel_2 = kernel_2.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) ffn1_ckpt_name = "ffn1" @@ -1081,15 +1106,23 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): if self.use_bias: bias_1_shape = intermediate_dim bias_1 = nn_partitioning.param_with_axes( - "wi_bias", self.bias_init, bias_1_shape, jnp.float32, axes=self.bias_axes_1 + "wi_bias", + self.bias_init, + bias_1_shape, + self.dtype, + axes=self.bias_axes_1, ) - bias_1 = bias_1.astype(self.dtype) + bias_1 = bias_1.astype(input_dtype) bias_2_shape = (hidden_size,) bias_2 = nn_partitioning.param_with_axes( - "wo_bias", self.bias_init, bias_2_shape, jnp.float32, axes=self.bias_axes_2 + "wo_bias", + self.bias_init, + bias_2_shape, + self.dtype, + axes=self.bias_axes_2, ) - bias_2 = bias_2.astype(self.dtype) + bias_2 = bias_2.astype(input_dtype) else: bias_1 = None bias_2 = None @@ -1156,11 +1189,11 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): num_activations, -2, wi_lora_a_kernel_init_each_shape, - jnp.float32, + self.dtype, axes=wi_lora_a_kernel_axes, ) wi_lora_a_kernel = jnp.reshape(wi_lora_a_kernel, wi_lora_a_kernel_shape) - wi_lora_a_kernel = wi_lora_a_kernel.astype(self.dtype) + wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype) wi_lora_b_kernel_shape = ( num_activations, @@ -1172,10 +1205,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wi_lora_b_kernel", nn.initializers.zeros, wi_lora_b_kernel_shape, - jnp.float32, + self.dtype, axes=wi_lora_b_kernel_axes, ) - wi_lora_b_kernel = wi_lora_b_kernel.astype(self.dtype) + wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype) x += _apply_low_rank_adaptation( y, @@ -1189,10 +1222,14 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): bias_1 = None if self.use_bias: bias_1 = nn_partitioning.param_with_axes( - "wi_bias", self.bias_init, intermediate_dim, jnp.float32, axes=self.bias_axes_1 + "wi_bias", + self.bias_init, + intermediate_dim, + self.dtype, + axes=self.bias_axes_1, ) - bias_1 = bias_1.astype(self.dtype) bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape + bias_1 = bias_1.astype(input_dtype) x += jnp.reshape(bias_1, bias_1_shape) x = checkpoint_name(x, ffn1_ckpt_name) @@ -1207,6 +1244,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): z = functools.reduce(operator.mul, activations) # Remove act axis z = jnp.reshape(z, (*z.shape[:-2], -1)) + z = z.astype(input_dtype) z = nn.Dropout( rate=self.intermediate_dropout_rate, @@ -1215,6 +1253,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): )(z, deterministic=deterministic) z = with_sharding_constraint_by_logical_axes(z, self.dot_2_input_axes) + z = z.astype(input_dtype) # DenseGeneral 2 out = type_safe_dot_general( @@ -1228,10 +1267,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_lora_a_kernel", self.kernel_init, wo_lora_a_kernel_shape, - jnp.float32, + self.dtype, axes=wo_lora_a_kernel_axes, ) - wo_lora_a_kernel = wo_lora_a_kernel.astype(self.dtype) + wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype) wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size) wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape) @@ -1239,10 +1278,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): "wo_lora_b_kernel", nn.initializers.zeros, wo_lora_b_kernel_shape, - jnp.float32, + self.dtype, axes=wo_lora_b_kernel_axes, ) - wo_lora_b_kernel = wo_lora_b_kernel.astype(self.dtype) + wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype) out += _apply_low_rank_adaptation( z, @@ -1256,11 +1295,16 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): bias_2 = None if self.use_bias: bias_2 = nn_partitioning.param_with_axes( - "wo_bias", self.bias_init, (hidden_size,), jnp.float32, axes=self.bias_axes_2 + "wo_bias", + self.bias_init, + (hidden_size,), + self.dtype, + axes=self.bias_axes_2, ) - bias_2 = bias_2.astype(self.dtype) + bias_2 = bias_2.astype(input_dtype) out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) out = checkpoint_name(out, ffn2_ckpt_name) + assert out.dtype == input_dtype return out, ln_output # Output, layner_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index cb71188221..69fb74ba31 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ @@ -24,7 +24,7 @@ from .module import DenseGeneral, LayerNormDenseGeneral, LayerNormMLP from .module import LayerNorm, Softmax -from ..attention import AttnBiasType, AttnMaskType, QKVLayout +from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type from ..attention import fused_attn from ..softmax import SoftmaxType @@ -142,6 +142,8 @@ def __call__( assert key.shape[-2] == value.shape[-2], "k, v num_attention_heads must match." assert query.shape[-1] == key.shape[-1], "q, k head_dim must match." + input_dtype = query.dtype + if self.scale_factor is None: scale_factor = 1.0 / sqrt(query.shape[-1]) else: @@ -194,15 +196,18 @@ def __call__( if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: attn_weights += bias - def apply_swa_mask(attn_mask_type: AttnMaskType, original_mask: Array) -> Array: + def apply_swa_mask(original_mask: Array) -> Array: """Apply the sliding window mask to a given mask""" + batch = original_mask.shape[0] max_seqlen_q = original_mask.shape[-2] max_seqlen_kv = original_mask.shape[-1] - swa_mask = make_swa_mask(max_seqlen_q, max_seqlen_kv, self.window_size, attn_mask_type) - # In swa_mask 0 is masked out, in original_mask 1 is masked out - swa_mask = 1 - swa_mask.astype(original_mask.dtype) - swa_mask_bcast = jnp.broadcast_to(swa_mask, original_mask.shape) - new_mask = jnp.where(original_mask == 0, swa_mask_bcast, original_mask) + # TODO(rewang): Support THD format pos + pos_q = jnp.broadcast_to(jnp.arange(max_seqlen_q), (batch, max_seqlen_q)) + pos_kv = jnp.broadcast_to(jnp.arange(max_seqlen_kv), (batch, max_seqlen_kv)) + # In inv_swa_mask 0 is masked out, in original_mask 1 is masked out + inv_swa_mask = make_swa_mask(pos_q, pos_kv, self.window_size, original_mask.dtype) + swa_mask = 1 - inv_swa_mask + new_mask = jnp.where(original_mask == 0, swa_mask, original_mask) return new_mask def convert_to_softmax_type(attn_mask_type, mask): @@ -213,7 +218,7 @@ def convert_to_softmax_type(attn_mask_type, mask): if attn_mask_type == AttnMaskType.CAUSAL_MASK and self.window_size is None: mask = None if mask is not None: - mask = apply_swa_mask(attn_mask_type, mask) + mask = apply_swa_mask(mask) # Currently cuDNN backend only supports SWA for causal/padding_causal, follow this if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask @@ -230,7 +235,7 @@ def convert_to_softmax_type(attn_mask_type, mask): attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)( attn_weights, mask, bias - ).astype(self.dtype) + ).astype(input_dtype) if is_gqa: attn_weights = attn_weights.reshape(attn_weights_with_groups_shape) @@ -240,9 +245,12 @@ def convert_to_softmax_type(attn_mask_type, mask): dropout_shape = list(attn_weights.shape) # TODO(rewang): add attention dropout broadcast dimension arguments for users keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape) - multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype) + multiplier = keep.astype(input_dtype) / jnp.asarray(keep_prob, dtype=input_dtype) attn_weights = attn_weights * multiplier + assert ( + attn_weights.dtype == input_dtype + ), f"output={attn_weights.dtype}, input={input_dtype}" if self.transpose_batch_sequence: if is_gqa: return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape) @@ -250,6 +258,7 @@ def convert_to_softmax_type(attn_mask_type, mask): if is_gqa: return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape) + return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value) @@ -262,6 +271,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me scale_factor: Optional[float] = None transpose_batch_sequence: bool = False window_size: Optional[Tuple[int, int]] = None + max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" @@ -271,7 +281,7 @@ def __call__( query: Array, key: Array, value: Array, - mask: Optional[Array] = None, + sequence_descriptor: Optional[SequenceDescriptor] = None, bias: Optional[Array] = None, *, dropout_rng: Optional[PRNGKey] = None, @@ -288,8 +298,7 @@ def __call__( scale_factor = self.scale_factor del self.scale_factor - # TODO(rewang): integrate THD format - if self.qkv_layout == QKVLayout.BS3HD: + if self.qkv_layout.is_qkvpacked(): """qkvpacked format, treat query: qkvpacked tensor, shape = [..., 3, h, d] key: ignore @@ -301,7 +310,7 @@ def __call__( x = fused_attn( (qkv_packed,), bias, - mask, + sequence_descriptor, seed, attn_mask_type=self.attn_mask_type, attn_bias_type=self.attn_bias_type, @@ -310,10 +319,11 @@ def __call__( dropout_probability=self.attention_dropout, is_training=not deterministic, window_size=self.window_size, + max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, ) - elif self.qkv_layout == QKVLayout.BSHD_BS2HD: + elif self.qkv_layout.is_kvpacked(): """kvpacked format, treat query: query tensor, shape = [..., h, d] key: kvpacked tensor, shape = [..., 2, h, d] @@ -326,7 +336,7 @@ def __call__( x = fused_attn( (query, kv_packed), bias, - mask, + sequence_descriptor, seed, attn_mask_type=self.attn_mask_type, attn_bias_type=self.attn_bias_type, @@ -335,10 +345,11 @@ def __call__( dropout_probability=self.attention_dropout, is_training=not deterministic, window_size=self.window_size, + max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, ) - elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD: + elif self.qkv_layout.is_separate(): if self.transpose_batch_sequence: query = query.transpose([1, 0, 2, 3]) key = key.transpose([1, 0, 2, 3]) @@ -346,7 +357,7 @@ def __call__( x = fused_attn( (query, key, value), bias, - mask, + sequence_descriptor, seed, attn_mask_type=self.attn_mask_type, attn_bias_type=self.attn_bias_type, @@ -355,6 +366,7 @@ def __call__( dropout_probability=self.attention_dropout, is_training=not deterministic, window_size=self.window_size, + max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, ) @@ -364,6 +376,7 @@ def __call__( if self.transpose_batch_sequence: x = x.transpose([1, 0, 2, 3]) + assert x.dtype == query.dtype return x @@ -432,6 +445,8 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + .. note:: THD format only supports 'padding' or 'causal_padding' mask type. + attn_bias_type: Optional[str], default = None Type of the attention bias passed in the attention. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. @@ -446,13 +461,15 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods qkv_layout: str, default = 'bshd_bshd_bshd' Specifies the dimensional layout format for the query, key, and value tensors in __call__(). It indicates how the inputs are processed. - Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd'}. Where + Available options: {'bs3hd', 'bshd_bs2hd', 'bshd_bshd_bshd', 't3hd', 'thd_t2hd', 'thd_thd_thd'}. * bs3hd: query tensor is treated as a qkvpacked tensor with shape = [b, s, 3, h, d]. key and value arguments in :attr:`__call__()` are ignored in this layout. * bshd_bs2hd: query tensor with shape = [b, s, h, d]. key tensor is treaded as a kvpacked tensor with shape = [b, s, 2, h, d]. `value` argument in :attr:`__call__()` is ignored. * bshd_bshd_bshd: query, key, and value are seperated with shape = [b, s, h, d]. + * t3hd/thd_t2hd/thd_thd_thd: Have the same layout as bshd series, but it allows multiple + sequences to be packed in a batch, also known as sequence packing. Explanation of denotations: @@ -471,13 +488,15 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods should be in (seqlen, batch, ...), otherwise (batch, seqlen, ...). window_size: Optional[Tuple[int, int]], default = None Sliding window size. The default value is no sliding window. + max_segments_per_seq: Optional[int], default = 1 + The maximum number of segments per sequence, also used for THD format (sequence packing). context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. Optimization parameters ----------------------- - dtype: jax.numpy.dtype, default = jax.numpy.float32 + dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. """ @@ -494,6 +513,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods scale_factor: Optional[float] = None transpose_batch_sequence: bool = True window_size: Optional[Tuple[int, int]] = None + max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" @@ -503,10 +523,11 @@ def __call__( query: Array, key: Array, value: Array, - mask: Optional[Array] = None, + sequence_descriptor: Optional[Union[SequenceDescriptor, Array]] = None, bias: Optional[Array] = None, *, deterministic: bool = False, + mask: Optional[Union[SequenceDescriptor, Array]] = None, ) -> Array: """ Parameters @@ -533,6 +554,16 @@ def __call__( outputs: jax.numpy.ndarray Output tensors. """ + input_dtype = query.dtype + + if mask is not None: + if sequence_descriptor is not None: + raise ValueError( + "sequence_descriptor and mask cannot be provided at the same time." + ) + warnings.warn("mask is deprecated, please use sequence_descriptor instead.") + sequence_descriptor = mask + del mask # For internal API, we use enum to maintain if self.attn_bias_type is None: @@ -596,16 +627,18 @@ def __call__( if not use_fused_attn: # unfused attention only supports splitted query, key, value - if qkv_layout == QKVLayout.BS3HD: + if qkv_layout.is_qkvpacked(): query, key, value = jnp.split(query, [1, 2], axis=-3) query, key, value = map( functools.partial(jnp.squeeze, axis=-3), [query, key, value] ) - elif qkv_layout == QKVLayout.BSHD_BS2HD: + elif qkv_layout.is_kvpacked(): key, value = jnp.split(key, [1], axis=-3) key, value = map(functools.partial(jnp.squeeze, axis=-3), [key, value]) else: - assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD + assert qkv_layout.is_separate() + + assert sequence_descriptor is None or isinstance(sequence_descriptor, jnp.ndarray) x = _UnfusedDotProductAttention( attention_dropout=self.attention_dropout, @@ -616,7 +649,15 @@ def __call__( scale_factor=scale_factor, transpose_batch_sequence=self.transpose_batch_sequence, window_size=self.window_size, - )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) + )( + query, + key, + value, + sequence_descriptor, + bias, + dropout_rng=dropout_rng, + deterministic=deterministic, + ) else: x = _FusedDotProductAttention( attention_dropout=self.attention_dropout, @@ -627,10 +668,19 @@ def __call__( transpose_batch_sequence=self.transpose_batch_sequence, qkv_layout=qkv_layout, window_size=self.window_size, + max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, - )(query, key, value, mask, bias, dropout_rng=dropout_rng, deterministic=deterministic) - + )( + query, + key, + value, + sequence_descriptor, + bias, + dropout_rng=dropout_rng, + deterministic=deterministic, + ) + assert x.dtype == input_dtype, f"output_dtype={x.dtype}, input_dtype={input_dtype}" return x @@ -671,10 +721,10 @@ def alternate_impl(): sin, cos = generate_sin_cos(time_scales) x1, x2 = jnp.split(x, 2, axis=-1) - part_1 = (x1 * cos - x2 * sin).astype(x.dtype) - part_2 = (x2 * cos + x1 * sin).astype(x.dtype) + part_1 = (x1 * cos - x2 * sin).astype(dtype=x.dtype) + part_2 = (x2 * cos + x1 * sin).astype(dtype=x.dtype) - output = jnp.concatenate([part_1, part_2], axis=-1) + output = jnp.concatenate([part_1, part_2], axis=-1, dtype=x.dtype) return output def consecutive_impl(): @@ -877,8 +927,8 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods Optimization parameters ----------------------- - dtype: jax.numpy.dtype, default = jax.numpy.float32 - The data type used to allocate the initial parameters. + dtype: jax.numpy.dtype, default = jax.numpy.float32 + The data type used for computation. fuse_qkv_params: bool, default = True If set to True, this module exposes a single fused parameter for query-key-value for self-attention and key-value for @@ -973,7 +1023,9 @@ def __post_init__(self): ) if self.kernel_init is None: - self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") + self.kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", dtype=self.dtype + ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads super().__post_init__() @@ -1017,6 +1069,11 @@ def __call__( Output tensors. """ + assert ( + inputs_q.dtype == inputs_kv.dtype + ), f"q.dtype = {inputs_q.dtype}, kv.dtype = {inputs_kv.dtype}" + input_dtype = inputs_q.dtype + def query_init(*args): depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) return self.kernel_init(*args) / (depth_scaling if self.scaled_query_init else 1.0) @@ -1194,8 +1251,11 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): assert ln_out is not None inputs_kv = ln_out + query = query.astype(input_dtype) key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv) + key = key.astype(input_dtype) value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv) + value = value.astype(input_dtype) query = checkpoint_name(query, "query_proj") key = checkpoint_name(key, "key_proj") value = checkpoint_name(value, "value_proj") @@ -1260,7 +1320,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): f"expected query shape {expected_shape} instead got {query.shape}." ) - cur_index = cache_index.value + cur_index = cache_index.value.astype(jnp.int32) one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype) one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape) key = cached_key.value + key * one_hot_indices @@ -1349,6 +1409,9 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): )(x) out = checkpoint_name(out, "out_proj") + assert ( + inputs_q.dtype == out.dtype + ), f"output_dtype={out.dtype}, input_dtype={inputs_q.dtype}" return out, ln_out @@ -1434,7 +1497,7 @@ def __call__(self, q_seqlen, k_seqlen, bidirectional=True): "rel_embedding", self.embedding_init, (self.num_attention_heads, self.num_buckets), - jnp.float32, + self.dtype, axes=self.embedding_axes, ) @@ -1670,10 +1733,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods def __post_init__(self): if self.mha_kernel_init is None: - self.mha_kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal") + self.mha_kernel_init = nn.initializers.variance_scaling( + 1.0, "fan_in", "normal", dtype=self.dtype + ) if self.mlp_kernel_init is None: self.mlp_kernel_init = nn.initializers.variance_scaling( - 1.0, "fan_in", "truncated_normal" + 1.0, "fan_in", "truncated_normal", dtype=self.dtype ) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads @@ -1723,6 +1788,7 @@ def __call__( outputs: jax.numpy.ndarray Output tensors. """ + input_dtype = inputs.dtype assert ( self.layer_type in TransformerLayerType ), f"layer_type should be one of TransformerLayerType, but got {self.layer_type}." @@ -1760,7 +1826,9 @@ def generate_batch_seqlen_logical_axes(is_shared_seq=None): max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, - embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"), + embedding_init=nn.initializers.variance_scaling( + 1.0, "fan_avg", "uniform", dtype=self.dtype + ), name="relpos_bias", ) else: @@ -1987,5 +2055,5 @@ def hidden_dropout(x, deterministic): dtype=self.dtype, name="output_layernorm", )(z) - + assert z.dtype == input_dtype, f"output_dtype={z.dtype}, input_dtype={input_dtype}" return z diff --git a/transformer_engine/jax/fp8.py b/transformer_engine/jax/fp8.py index 5df8ce4386..04ac6dd57d 100644 --- a/transformer_engine/jax/fp8.py +++ b/transformer_engine/jax/fp8.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ @@ -14,9 +14,9 @@ from flax.core.frozen_dict import FrozenDict from flax.linen import fp8_ops -from transformer_engine.transformer_engine_jax import DType -from transformer_engine.transformer_engine_jax import get_cublasLt_version -from transformer_engine.transformer_engine_jax import ( +from transformer_engine_jax import DType +from transformer_engine_jax import get_cublasLt_version +from transformer_engine_jax import ( get_cuda_version, get_device_compute_capability, ) @@ -354,11 +354,6 @@ def fp8_autocast( assert ( fp8_recipe.scaling_factor_compute_algo is None ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX." - assert fp8_recipe.override_linear_precision == ( - False, - False, - False, - ), "DelayedScaling override_linear_precision isn't supported by TE/JAX." assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX." if mesh_resource is None: diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index 4f2e83d9a2..2f120443dd 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX layernorm modules""" diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index bbf0b0f52b..c2d76c1fd3 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX MLP modules""" diff --git a/transformer_engine/jax/praxis/__init__.py b/transformer_engine/jax/praxis/__init__.py index 5be51a6d71..5352f1f53b 100644 --- a/transformer_engine/jax/praxis/__init__.py +++ b/transformer_engine/jax/praxis/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Praxis related Modules""" diff --git a/transformer_engine/jax/praxis/module.py b/transformer_engine/jax/praxis/module.py index b82c0915e4..ce407f94fc 100644 --- a/transformer_engine/jax/praxis/module.py +++ b/transformer_engine/jax/praxis/module.py @@ -1,9 +1,10 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ Praxis Modules """ +from dataclasses import field from functools import partial from typing import Callable, Iterable, Sequence, Tuple, Union @@ -74,7 +75,9 @@ class LayerNorm(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] = () transpose_batch_sequence: bool = False @@ -129,7 +132,9 @@ class Linear(TransformerEngineBaseLayer): out_features: int = 512 kernel_axes: Tuple[str, ...] = () use_bias: bool = True - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] = () enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 @@ -174,11 +179,15 @@ class LayerNormLinear(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = WeightInit.Constant(1.0) + ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=1.0) + ) ln_bias_axes: Tuple[str, ...] = () kernel_axes: Tuple[str, ...] = () use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes: Tuple[str, ...] = () enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 @@ -237,12 +246,16 @@ class LayerNormMLP(TransformerEngineBaseLayer): zero_centered_gamma: bool = False scale_init: WeightInit = None scale_axes: Tuple[str, ...] = () - ln_bias_init: WeightInit = WeightInit.Constant(1.0) + ln_bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=1.0) + ) ln_bias_axes: Tuple[str, ...] = () kernel_axes_1: Tuple[str, ...] = () kernel_axes_2: Tuple[str, ...] = () use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) bias_axes_1: Tuple[str, ...] = () bias_axes_2: Tuple[str, ...] = () enable_low_rank_adaptation: bool = False diff --git a/transformer_engine/jax/praxis/transformer.py b/transformer_engine/jax/praxis/transformer.py index f2ac802f10..f441834355 100644 --- a/transformer_engine/jax/praxis/transformer.py +++ b/transformer_engine/jax/praxis/transformer.py @@ -1,9 +1,10 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ Praxis Modules related Transformer """ +from dataclasses import field from functools import partial from typing import Optional, Sequence, Tuple import warnings @@ -138,7 +139,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer): zero_centered_gamma: bool = False return_layernorm_output: bool = False use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) attn_mask_type: str = "causal" attn_bias_type: Optional[str] = None enable_rotary_pos_emb: bool = False @@ -275,7 +278,9 @@ class TransformerLayer(TransformerEngineBaseLayer): dropout_rng_name: str = "dropout" mlp_activations: Sequence[str] = ("relu",) use_bias: bool = False - bias_init: WeightInit = WeightInit.Constant(0.0) + bias_init: WeightInit = field( # pylint: disable=invalid-field-call + default_factory=partial(WeightInit.Constant, scale=0.0) + ) apply_residual_connection_post_layernorm: bool = False output_layernorm: bool = False float32_attention_logits: bool = False diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index c2219e3ba9..4f5cc4df20 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -37,7 +37,7 @@ from pybind11.setup_helpers import build_ext as BuildExtension os.environ["NVTE_PROJECT_BUILDING"] = "1" -CMakeBuildExtension = get_build_ext(BuildExtension) +CMakeBuildExtension = get_build_ext(BuildExtension, True) if __name__ == "__main__": diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index f2da288be5..c24e550198 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """ diff --git a/transformer_engine/jax/softmax.py b/transformer_engine/jax/softmax.py index c63ee85e5d..9b32002388 100644 --- a/transformer_engine/jax/softmax.py +++ b/transformer_engine/jax/softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """JAX softmax modules""" diff --git a/transformer_engine/paddle/MANIFEST.in b/transformer_engine/paddle/MANIFEST.in deleted file mode 100644 index 0c814f95da..0000000000 --- a/transformer_engine/paddle/MANIFEST.in +++ /dev/null @@ -1,3 +0,0 @@ -recursive-include build_tools *.* -recursive-include common_headers *.* -recursive-include csrc *.* diff --git a/transformer_engine/paddle/__init__.py b/transformer_engine/paddle/__init__.py deleted file mode 100644 index 50cf2186d6..0000000000 --- a/transformer_engine/paddle/__init__.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Transformer Engine bindings for Paddle""" - -# pylint: disable=wrong-import-position,wrong-import-order - -import logging -from importlib.metadata import version - -from transformer_engine.common import is_package_installed - - -def _load_library(): - """Load shared library with Transformer Engine C extensions""" - module_name = "transformer_engine_paddle" - - if is_package_installed(module_name): - assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`." - assert is_package_installed( - "transformer_engine_cu12" - ), "Could not find `transformer-engine-cu12`." - assert ( - version(module_name) - == version("transformer-engine") - == version("transformer-engine-cu12") - ), ( - "TransformerEngine package version mismatch. Found" - f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-cu12" - f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" - " transformer-engine[paddle]==VERSION'" - ) - - if is_package_installed("transformer-engine-cu12"): - if not is_package_installed(module_name): - logging.info( - "Could not find package %s. Install transformer-engine using 'pip" - " install transformer-engine[paddle]==VERSION'", - module_name, - ) - - from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import - - -_load_library() -from .fp8 import fp8_autocast -from .layer import ( - Linear, - LayerNorm, - LayerNormLinear, - LayerNormMLP, - FusedScaleMaskSoftmax, - DotProductAttention, - MultiHeadAttention, - TransformerLayer, - RotaryPositionEmbedding, -) -from .recompute import recompute diff --git a/transformer_engine/paddle/constants.py b/transformer_engine/paddle/constants.py deleted file mode 100644 index 69d3859b8f..0000000000 --- a/transformer_engine/paddle/constants.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Constants""" - -from enum import Enum - -import paddle - -from transformer_engine import transformer_engine_paddle as tex - - -class FP8FwdTensors(Enum): - """Used as named indices on the `scale`, `scale_inv`, - and `amax` tensors in the `FP8TensorMeta` class.""" - - GEMM1_INPUT = 0 - GEMM1_WEIGHT = 1 - GEMM1_OUTPUT = 2 - GEMM2_INPUT = 3 - GEMM2_WEIGHT = 4 - GEMM2_OUTPUT = 5 - - -class FP8BwdTensors(Enum): - """Used as named indices on the `scale`, `scale_inv`, - and `amax` tensors in the `FP8TensorMeta` class.""" - - GRAD_OUTPUT1 = 0 - GRAD_INPUT1 = 1 - GRAD_OUTPUT2 = 2 - GRAD_INPUT2 = 3 - - -""" -Map from paddle dtype to TE dtype -""" -TE_DType = { - paddle.uint8: tex.DType.kByte, - paddle.int32: tex.DType.kInt32, - paddle.float32: tex.DType.kFloat32, - paddle.float16: tex.DType.kFloat16, - paddle.bfloat16: tex.DType.kBFloat16, -} - -AttnMaskTypes = ("causal", "padding", "no_mask") - -AttnTypes = ("self", "cross") - -LayerTypes = ("encoder", "decoder") - -GemmParallelModes = ("row", "column", None) - -dist_group_type = paddle.distributed.collective.Group - -RecomputeFunctionNames = ("unpack", "backward") - -AttnBiasType = { - "no_bias": tex.NVTE_Bias_Type.NVTE_NO_BIAS, - "pre_scale_bias": tex.NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS, - "post_scale_bias": tex.NVTE_Bias_Type.NVTE_POST_SCALE_BIAS, -} - -AttnMaskType = { - "no_mask": tex.NVTE_Mask_Type.NVTE_NO_MASK, - "padding": tex.NVTE_Mask_Type.NVTE_PADDING_MASK, - "causal": tex.NVTE_Mask_Type.NVTE_CAUSAL_MASK, -} - -FusedAttnBackend = { - "F16_max512_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, - "F16_arbitrary_seqlen": tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - "No_Backend": tex.NVTE_Fused_Attn_Backend.NVTE_No_Backend, -} diff --git a/transformer_engine/paddle/cpp_extensions.py b/transformer_engine/paddle/cpp_extensions.py deleted file mode 100644 index 281be66a8c..0000000000 --- a/transformer_engine/paddle/cpp_extensions.py +++ /dev/null @@ -1,1199 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""TE FP8 extensions and GEMMs""" - -import math -from typing import Optional, Tuple, Union -import paddle -import paddle.nn.functional as F -from transformer_engine import transformer_engine_paddle as tex -from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors -from .fp8 import FP8TensorMeta, get_global_fp8_state - -BACKEND_F16m512_THREADS_PER_CTA = 128 -BACKEND_F16arb_ELTS_PER_THREADS = 16 - - -def gemm( - A: paddle.Tensor, - B: paddle.Tensor, - dtype: paddle.dtype, - workspace: paddle.Tensor, - gelu: bool = False, - gelu_input: Optional[paddle.Tensor] = None, - grad: bool = False, - accumulate: bool = False, - layout: str = "TN", - out: Optional[paddle.Tensor] = None, - out_dtype: Optional[paddle.dtype] = None, - bias: Optional[paddle.Tensor] = None, - use_bias: bool = False, -) -> Tuple[Union[paddle.Tensor, None], ...]: - """Non FP8 GEMM.""" - - assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." - transa = layout[0] == "T" - transb = layout[1] == "T" - - if out is None: - if accumulate: - out = paddle.zeros( - shape=[ - B.shape[1] if transb else B.shape[0], - A.shape[0] if transa else A.shape[1], - ], - dtype=out_dtype if out_dtype is not None else dtype, - ) - else: - out = paddle.empty( - shape=[ - B.shape[1] if transb else B.shape[0], - A.shape[0] if transa else A.shape[1], - ], - dtype=out_dtype if out_dtype is not None else dtype, - ) - - if gelu and not grad: - gelu_input = paddle.empty_like(out, dtype=dtype) - elif not gelu: - gelu_input = None - - if grad and use_bias: - grad_bias = paddle.empty(shape=[B.shape[1]], dtype=out.dtype) - else: - grad_bias = None - - bias = bias if use_bias else None - - assert ( - A.dtype == dtype and B.dtype == dtype - ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}" - input_dtype = TE_DType[dtype] - output_dtype = TE_DType[out.dtype] - if use_bias: - bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype] - else: - bias_dtype = output_dtype - - tex.te_gemm( - A, - None, - B, - None, - grad_bias if grad else bias, - out, - None, # out_scale - None, # out_amax - gelu_input, - workspace, - 0, # A_index - 0, # B_index - 0, # D_index - int(input_dtype), - int(input_dtype), - int(output_dtype), - int(bias_dtype), - transa, - transb, - grad, - workspace.shape[0], - accumulate, - False, # use_split_accumulator - 0, # math_sm_count - ) - - return out, grad_bias, gelu_input - - -def fp8_gemm( - A: paddle.Tensor, - A_scale_inv: paddle.Tensor, - A_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - A_dtype: tex.DType, - B: paddle.Tensor, - B_scale_inv: paddle.Tensor, - B_fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - B_dtype: tex.DType, - out_dtype: paddle.dtype, - workspace: paddle.Tensor, - gelu: bool = False, - accumulate: bool = False, - out: Optional[paddle.Tensor] = None, - out_index=None, - fp8_meta_tensor: FP8TensorMeta = None, - bias: Optional[paddle.Tensor] = None, - use_bias: bool = False, - use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, -) -> paddle.Tensor: - """TN layout GEMM with fp8 inputs.""" - - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert fp8_meta_tensor is not None and out_index is not None - - if out is None: - if accumulate: - out = paddle.zeros( - shape=[ - B.shape[0], - A.shape[0], - ], - dtype=out_dtype, - ) - else: - out = paddle.empty( - shape=[ - B.shape[0], - A.shape[0], - ], - dtype=out_dtype, - ) - - # Use bfloat16 as default bias_dtype - bias_dtype = paddle.bfloat16 if bias is None else bias.dtype - if gelu: - gelu_input = paddle.empty_like(out, dtype=bias_dtype) - else: - gelu_input = None - bias_dtype = TE_DType[bias_dtype] - - out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype - - tex.te_gemm( - A, - A_scale_inv, - B, - B_scale_inv, - bias if use_bias else None, - out, - None if out_index is None else fp8_meta_tensor.scale, - None if out_index is None else fp8_meta_tensor.amax_history, - gelu_input, # this is pre_gelu_out - workspace, - A_fp8_tensor.value, - B_fp8_tensor.value, - 0 if out_index is None else out_index, - int(A_dtype), - int(B_dtype), - int(out_dtype), - int(bias_dtype), - True, # transa - False, # transb - False, # grad - workspace.shape[0], - accumulate, - use_split_accumulator, - 0, # math_sm_count - ) - - return out, gelu_input - - -def cast_to_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - out: Optional[paddle.Tensor] = None, -) -> paddle.Tensor: - """Cast input to FP8""" - if out is None: - out = paddle.empty( - shape=inp.shape, - dtype=paddle.uint8, - ) - else: - assert out.shape == inp.shape, "Output shape does not match input shape." - assert out.dtype == paddle.uint8, "Output should be of uint8 dtype." - - tex.cast_to_fp8( - inp, - fp8_meta_tensor.scale, - out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - return out - - -def cast_from_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - itype: tex.DType, - otype: tex.DType, -) -> paddle.Tensor: - """Cast input from FP8""" - return tex.cast_from_fp8( - inp, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(itype), - int(otype), - ) - - -def transpose( - inp: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """Transpose input""" - return tex.te_transpose( - inp, - int(otype), - ) - - -def cast_transpose( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - cast_out: Optional[paddle.Tensor] = None, - transpose_out: Optional[paddle.Tensor] = None, -) -> Union[Tuple[paddle.Tensor, paddle.Tensor], None]: - """Cast + Transpose with FP8 output""" - if cast_out is None: - cast_out = paddle.empty( - shape=inp.shape, - dtype=paddle.uint8, - ) - else: - assert cast_out.shape == inp.shape, "cast_out shape does not match input shape." - assert cast_out.dtype == paddle.uint8, "cast_out should be of uint8 dtype." - - if transpose_out is None: - transpose_out = paddle.empty( - shape=[inp.shape[1], inp.shape[0]], - dtype=paddle.uint8, - ) - else: - assert transpose_out.shape == [ - inp.shape[1], - inp.shape[0], - ], "Transposed output shape does not match input shape." - assert transpose_out.dtype == paddle.uint8, "Output should be of uint8 dtype." - - tex.te_cast_transpose( - inp, - fp8_meta_tensor.scale, - cast_out, - transpose_out, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return cast_out, transpose_out - - -def cast_transpose_bgrad( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> Union[Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor], None]: - """Fused Cast + Transpose + Bias Grad""" - grad_bias, cast_out, transpose_out, _, _ = tex.te_cast_transpose_bgrad( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return grad_bias, cast_out, transpose_out - - -def te_gelu( - inp: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """Non FP8 GELU""" - return tex.te_gelu( - inp, - int(otype), - ) - - -def gelu_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> paddle.Tensor: - """GELU + FP8 cast""" - out, _, _ = tex.te_gelu_fp8( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return out - - -def swiglu( - inp: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """Non FP8 SWIGLU""" - return tex.te_swiglu( - inp, - int(otype), - ) - - -def swiglu_pd( - inp: paddle.Tensor, -) -> paddle.Tensor: - """Native SWIGLU""" - gate_out, up_out = paddle.chunk(inp, chunks=2, axis=-1) - out = F.silu(gate_out) * up_out - return out - - -def swiglu_fp8( - inp: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> paddle.Tensor: - """SWIGLU + FP8 cast""" - out, _, _ = tex.te_swiglu_fp8( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return out - - -def dswiglu( - grad_output: paddle.Tensor, - swiglu_input: paddle.Tensor, - otype: tex.DType, -) -> paddle.Tensor: - """dSWIGLU""" - return tex.te_dswiglu( - grad_output, - swiglu_input, - int(otype), - ) - - -def dgelu_cast_transpose_bgrad_fp8( - grad_output: paddle.Tensor, - gelu_input: paddle.Tensor, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """ - Fused dgelu + cast / transpose / reduce the result of - the GELU backward along the first dimension - """ - cast_dgelu, transpose_dgelu, dbias, _, _ = tex.te_cast_transpose_bgrad_dgelu( - grad_output, - gelu_input, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor.value, - int(otype), - ) - - return cast_dgelu, transpose_dgelu, dbias - - -def layernorm_fwd_fp8( - inp: paddle.Tensor, - weight: paddle.Tensor, - bias: paddle.Tensor, - eps: float, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """LayerNorm with FP8 output""" - out, mu, rsigma, _, _ = tex.te_layernorm_fwd_fp8( - inp, - weight, - bias, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - eps, - fp8_tensor.value, - int(otype), - sm_margin, - zero_centered_gamma, - ) - return out, mu, rsigma - - -def layernorm_fwd( - inp: paddle.Tensor, - weight: paddle.Tensor, - bias: paddle.Tensor, - eps: float, - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 LayerNorm forward""" - return tex.te_layernorm_fwd(inp, weight, bias, eps, int(otype), sm_margin, zero_centered_gamma) - - -def layernorm_bwd( - dz: paddle.Tensor, - x: paddle.Tensor, - mu: paddle.Tensor, - rsigma: paddle.Tensor, - gamma: paddle.Tensor, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 LayerNorm backward""" - return tex.te_layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma) - - -def rmsnorm_fwd( - inp: paddle.Tensor, - weight: paddle.Tensor, - eps: float, - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 RMSNorm forward""" - return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin, zero_centered_gamma) - - -def rmsnorm_fwd_fp8( - inp: paddle.Tensor, - weight: paddle.Tensor, - eps: float, - fp8_meta_tensor: FP8TensorMeta, - fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors], - otype: tex.DType, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """RMSNorm with FP8 output""" - out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8( - inp, - weight, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - eps, - fp8_tensor.value, - int(otype), - sm_margin, - zero_centered_gamma, - ) - return out, rsigma - - -def rmsnorm_bwd( - dz: paddle.Tensor, - x: paddle.Tensor, - rsigma: paddle.Tensor, - gamma: paddle.Tensor, - sm_margin: int = 0, - zero_centered_gamma: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Non-FP8 RMSNorm backward""" - return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma) - - -def mask_to_cu_seqlens( - mask: paddle.Tensor, - need_kv: bool = False, -) -> paddle.Tensor: - """Convert mask to cu_seqlens""" - # mask shape: [b, 1, s_q, s_kv] - if get_global_fp8_state().is_cudagraph_enabled(): - raise RuntimeError("mask_to_cu_seqlens is not supported with cuda graphs.") - q_seqlen, kv_seqlen = mask.shape[2], mask.shape[3] - q_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32) - q_cu_seqlens[0] = 0 - kv_cu_seqlens = None - if need_kv: - kv_cu_seqlens = paddle.empty(shape=[mask.shape[0] + 1], dtype=paddle.int32) - kv_cu_seqlens[0] = 0 - tex.mask_to_cu_seqlens(mask, q_cu_seqlens, kv_cu_seqlens, q_seqlen, kv_seqlen, need_kv) - return q_cu_seqlens, kv_cu_seqlens - - -def fused_attn_fwd_qkvpacked( - qkv: paddle.Tensor, - cu_seqlens: paddle.Tensor, - is_training: bool, - max_seqlen: int, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - Bias: paddle.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention FWD for packed QKV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - - b = cu_seqlens.shape[0] - 1 - total_seqs = qkv.shape[0] * qkv.shape[1] - h = qkv.shape[3] - d = qkv.shape[4] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - if bias_type != "no_bias": - assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert Bias.shape == [ - 1, - h, - max_seqlen, - max_seqlen, - ], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." - assert Bias.dtype == qkv.dtype, "bias tensor must be in the same dtype as qkv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - rng_elts_per_thread = None - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen * max_seqlen + BACKEND_F16m512_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_THREADS_PER_CTA - - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - out = paddle.full(shape=[b, max_seqlen, h, d], fill_value=0, dtype=qkv.dtype) - else: - out = paddle.empty(shape=[b, max_seqlen, h, d], dtype=qkv.dtype) - - if is_training: - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen, max_seqlen], dtype=qkv.dtype) - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen, 1], dtype="float32") - else: - raise ValueError("Unsupported fused attention backend.") - else: - softmax_aux = None - - rng_state = paddle.empty( - shape=[ - 2, - ], - dtype=paddle.int64, - ) - - # execute kernel - tex.te_fused_attn_fwd_qkvpacked( - qkv, - cu_seqlens, - Bias, - out, - softmax_aux, - rng_state, - b, - h, - d, - total_seqs, - max_seqlen, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - rng_elts_per_thread, - ) - return out, softmax_aux, rng_state - - -def fused_attn_bwd_qkvpacked( - qkv: paddle.Tensor, - cu_seqlens: paddle.Tensor, - rng_state: paddle.Tensor, - o: paddle.Tensor, - d_o: paddle.Tensor, - softmax_aux: paddle.Tensor, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - max_seqlen: int, - qkv_dtype: tex.DType, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bs3hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", - deterministic: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention BWD for packed QKV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - - b = cu_seqlens.shape[0] - 1 - total_seqs = qkv.shape[0] * qkv.shape[1] - h = qkv.shape[3] - d = qkv.shape[4] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - dqkv = paddle.full(shape=qkv.shape, fill_value=0, dtype=qkv.dtype) - else: - dqkv = paddle.empty(shape=qkv.shape, dtype=qkv.dtype) - - if bias_type != "no_bias": - if qkv_format == "thd": - dbias = paddle.zero(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) - else: - dbias = paddle.empty(shape=[1, h, max_seqlen, max_seqlen], dtype=qkv.dtype) - else: - dbias = None - # execute kernel - dqkv, dbias = tex.te_fused_attn_bwd_qkvpacked( - qkv, - cu_seqlens, - o, - d_o, - softmax_aux, - dqkv, - dbias, - rng_state, - b, - h, - d, - total_seqs, - max_seqlen, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - deterministic, - ) - - return dqkv, dbias - - -def fused_attn_fwd_kvpacked( - q: paddle.Tensor, - kv: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - is_training: bool, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - Bias: paddle.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bs2hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention FWD for packed KV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - - b = cu_seqlens_q.shape[0] - 1 - total_seqs_q = q.shape[0] * q.shape[1] - total_seqs_kv = kv.shape[0] * kv.shape[1] - h = q.shape[2] - d = q.shape[3] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - if bias_type != "no_bias": - assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert Bias.shape == [ - 1, - h, - max_seqlen_q, - max_seqlen_kv, - ], "bias tensor must be in [1, h, max_seqlen, max_seqlen] shape." - assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as q and kv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - rng_elts_per_thread = None - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_THREADS_PER_CTA - - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) - else: - out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype) - - if is_training: - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32") - else: - raise ValueError("Unsupported fused attention backend.") - else: - softmax_aux = None - - rng_state = paddle.empty( - shape=[ - 2, - ], - dtype=paddle.int64, - ) - - # execute kernel - tex.te_fused_attn_fwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - Bias, - out, - softmax_aux, - rng_state, - b, - h, - d, - total_seqs_q, - total_seqs_kv, - max_seqlen_q, - max_seqlen_kv, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - rng_elts_per_thread, - ) - - return out, softmax_aux, rng_state - - -def fused_attn_bwd_kvpacked( - q: paddle.Tensor, - kv: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - rng_state: paddle.Tensor, - o: paddle.Tensor, - d_o: paddle.Tensor, - softmax_aux: paddle.Tensor, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bs2hd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", - deterministic: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Fused Attention BWD for packed KV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - - b = cu_seqlens_q.shape[0] - 1 - total_seqs_q = q.shape[0] * q.shape[1] - total_seqs_kv = kv.shape[0] * kv.shape[1] - h = q.shape[2] - d = q.shape[3] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) - dkv = paddle.full(shape=kv.shape, fill_value=0, dtype=kv.dtype) - else: - dq = paddle.empty(shape=q.shape, dtype=q.dtype) - dkv = paddle.empty(shape=kv.shape, dtype=kv.dtype) - if bias_type != "no_bias": - if qkv_format == "thd": - dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = None - # execute kernel - tex.te_fused_attn_bwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - o, - d_o, - softmax_aux, - dq, - dkv, - dbias, - rng_state, - b, - h, - d, - total_seqs_q, - total_seqs_kv, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - deterministic, - ) - return dq, dkv, dbias - - -def fused_attn_fwd( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - is_training: bool, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - Bias: paddle.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bshd_bshd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", -) -> Tuple[paddle.Tensor, paddle.Tensor]: - """Fused Attention FWD for unpacked QKV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - assert ( - qkv_layout == "bshd_bshd_bshd" - ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." - b = cu_seqlens_q.shape[0] - 1 - - h = q.shape[-2] - d = q.shape[-1] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - if bias_type != "no_bias": - assert Bias is not None, "bias tensor cannot be None when bias_type is not no_bias." - assert Bias.shape == [ - 1, - h, - max_seqlen_q, - max_seqlen_kv, - ], "bias tensor must be in [1, h, max_seqlen_q, max_seqlen_kv] shape." - assert Bias.dtype == q.dtype, "bias tensor must be in the same dtype as qkv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - rng_elts_per_thread = None - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_THREADS_PER_CTA - - # BF16/FP16 fused attention API from fmha_v2 - if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - out = paddle.full(shape=[b, max_seqlen_q, h, d], fill_value=0, dtype=q.dtype) - else: - out = paddle.empty(shape=[b, max_seqlen_q, h, d], dtype=q.dtype) - - if is_training: - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - softmax_aux = paddle.empty(shape=[b, h, max_seqlen_q, 1], dtype="float32") - else: - raise ValueError("Unsupported fused attention backend.") - else: - softmax_aux = None - - rng_state = paddle.empty( - shape=[ - 2, - ], - dtype=paddle.int64, - ) - - # execute kernel - tex.te_fused_attn_fwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - Bias, - out, - softmax_aux, - rng_state, - b, - h, - d, - max_seqlen_q, - max_seqlen_kv, - is_training, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - rng_elts_per_thread, - ) - return out, softmax_aux, rng_state - - -def fused_attn_bwd( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_kv: paddle.Tensor, - rng_state: paddle.Tensor, - o: paddle.Tensor, - d_o: paddle.Tensor, - softmax_aux: paddle.Tensor, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - max_seqlen_q: int, - max_seqlen_kv: int, - qkv_dtype: tex.DType, - attn_scale: float = None, - dropout: float = 0.0, - set_zero: bool = True, - qkv_layout: str = "bshd_bshd_bshd", - bias_type: str = "no_bias", - attn_mask_type: str = "padding", - deterministic: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - """Fused Attention BWD for packed KV input""" - - assert qkv_dtype in ( - tex.DType.kBFloat16, - tex.DType.kFloat16, - ), "Only support bf16/fp16 for fused attention." - assert ( - cu_seqlens_q.shape == cu_seqlens_kv.shape - ), "cu_seqlens_q and cu_seqlens_kv must have the same shape" - assert ( - qkv_layout == "bshd_bshd_bshd" - ), "Only support bshd_bshd_bshd layout for unpacked QKV input for now." - - b = cu_seqlens_q.shape[0] - 1 - h = q.shape[-2] - d = q.shape[-1] - - if attn_scale is None: - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - set_zero = True - if set_zero: - dq = paddle.full(shape=q.shape, fill_value=0, dtype=q.dtype) - dk = paddle.full(shape=k.shape, fill_value=0, dtype=k.dtype) - dv = paddle.full(shape=v.shape, fill_value=0, dtype=v.dtype) - else: - dq = paddle.empty(shape=q.shape, dtype=q.dtype) - dk = paddle.empty(shape=k.shape, dtype=k.dtype) - dv = paddle.empty(shape=v.shape, dtype=v.dtype) - if bias_type != "no_bias": - if qkv_format == "thd": - dbias = paddle.zero(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = paddle.empty(shape=[1, h, max_seqlen_q, max_seqlen_kv], dtype=q.dtype) - else: - dbias = None - # execute kernel - tex.te_fused_attn_bwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - o, - d_o, - softmax_aux, - dq, - dk, - dv, - dbias, - rng_state, - b, - h, - d, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - dropout, - qkv_layout, - bias_type, - attn_mask_type, - int(qkv_dtype), - deterministic, - ) - return dq, dk, dv, dbias - - -def scaled_softmax_forward( - inp: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled softmax forward""" - return tex.te_scaled_softmax_forward(inp, scale_factor) - - -def scaled_softmax_backward( - out_grad: paddle.Tensor, - softmax_results: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled softmax backward""" - tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor) - return out_grad - - -def scaled_masked_softmax_forward( - inp: paddle.Tensor, - mask: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled masked softmax forward""" - - return tex.te_scaled_masked_softmax_forward(inp, mask, scale_factor) - - -def scaled_masked_softmax_backward( - out_grad: paddle.Tensor, - softmax_results: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled masked softmax backward""" - tex.te_scaled_softmax_backward(out_grad, softmax_results, scale_factor) - return out_grad - - -def scaled_upper_triang_masked_softmax_forward( - inp: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled upper triang masked softmax forward""" - return tex.te_scaled_upper_triang_masked_softmax_forward(inp, scale_factor) - - -def scaled_upper_triang_masked_softmax_backward( - out_grad: paddle.Tensor, - softmax_results: paddle.Tensor, - scale_factor: float, -) -> paddle.Tensor: - """scaled upper triang masked softmax backward""" - tex.te_scaled_upper_triang_masked_softmax_backward(out_grad, softmax_results, scale_factor) - return out_grad diff --git a/transformer_engine/paddle/csrc/common.cpp b/transformer_engine/paddle/csrc/common.cpp deleted file mode 100644 index 5e35a28a6b..0000000000 --- a/transformer_engine/paddle/csrc/common.cpp +++ /dev/null @@ -1,84 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "common.h" - -namespace transformer_engine { -namespace paddle_ext { - -TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector &shape, - const DType type) { - return TensorWrapper(const_cast(data_ptr), shape, type); -} - -TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type) { - return TensorWrapper(data_ptr, shape, type); -} - -TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector &shape, const DType type, - void *amax_ptr, void *scale_ptr, void *scale_inv_ptr) { - return TensorWrapper(data_ptr, shape, type, reinterpret_cast(amax_ptr), - reinterpret_cast(scale_ptr), - reinterpret_cast(scale_inv_ptr)); -} - -TensorWrapper MakeNvteTensor(paddle::Tensor &tensor) { // NOLINT - return MakeNvteTensor(tensor.data(), GetShapeArray(tensor), Paddle2NvteDType(tensor.dtype())); -} - -TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor) { - return MakeNvteTensor(const_cast(tensor.data()), GetShapeArray(tensor), - Paddle2NvteDType(tensor.dtype())); -} - -paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place, - bool init_to_zeros) { - auto size = shape.ndim; - if (size == 2 && init_to_zeros) { - return paddle::zeros({static_cast(shape.data[0]), static_cast(shape.data[1])}, - Nvte2PaddleDType(type), place); - } else if (size == 2) { - return paddle::empty({static_cast(shape.data[0]), static_cast(shape.data[1])}, - Nvte2PaddleDType(type), place); - } else if (size == 1 && init_to_zeros) { - return paddle::zeros({static_cast(shape.data[0])}, Nvte2PaddleDType(type), place); - } else if (size == 1) { - return paddle::empty({static_cast(shape.data[0])}, Nvte2PaddleDType(type), place); - } - NVTE_CHECK(false, "Should never reach here! func: AllocateSpace"); -} - -// MHA utils -// convert QKV layout to enum -NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout) { - static const std::unordered_map layout_map = { - {"sb3hd", NVTE_QKV_Layout::NVTE_SB3HD}, - {"sbh3d", NVTE_QKV_Layout::NVTE_SBH3D}, - {"sbhd_sb2hd", NVTE_QKV_Layout::NVTE_SBHD_SB2HD}, - {"sbhd_sbh2d", NVTE_QKV_Layout::NVTE_SBHD_SBH2D}, - {"sbhd_sbhd_sbhd", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD}, - {"bs3hd", NVTE_QKV_Layout::NVTE_BS3HD}, - {"bsh3d", NVTE_QKV_Layout::NVTE_BSH3D}, - {"bshd_bs2hd", NVTE_QKV_Layout::NVTE_BSHD_BS2HD}, - {"bshd_bsh2d", NVTE_QKV_Layout::NVTE_BSHD_BSH2D}, - {"bshd_bshd_bshd", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD}, - {"t3hd", NVTE_QKV_Layout::NVTE_T3HD}, - {"th3d", NVTE_QKV_Layout::NVTE_TH3D}, - {"thd_t2hd", NVTE_QKV_Layout::NVTE_THD_T2HD}, - {"thd_th2d", NVTE_QKV_Layout::NVTE_THD_TH2D}, - {"thd_thd_thd", NVTE_QKV_Layout::NVTE_THD_THD_THD}, - }; - - auto it = layout_map.find(qkv_layout); - if (it != layout_map.end()) { - return it->second; - } else { - NVTE_ERROR("Invalid QKV layout string: " + qkv_layout); - } -} - -} // namespace paddle_ext -} // namespace transformer_engine diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h deleted file mode 100644 index 6ce250432a..0000000000 --- a/transformer_engine/paddle/csrc/common.h +++ /dev/null @@ -1,186 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "common/util/logging.h" -#include "paddle/extension.h" -#include "paddle/phi/backends/all_context.h" - -namespace transformer_engine { -namespace paddle_ext { -// Paddle Tensor Utils -template -inline const void *GetDataPtr(const paddle::Tensor &x, int64_t index) { - if (index < 0 || index >= x.numel()) { - NVTE_ERROR("Index out of bound"); - } - return reinterpret_cast(x.data() + static_cast(index)); -} - -template -inline void *GetDataPtr(paddle::Tensor &x, int64_t index) { // NOLINT - if (index < 0 || index >= x.numel()) { - NVTE_ERROR("Index out of bound"); - } - return reinterpret_cast(x.data() + static_cast(index)); -} - -template -inline const void *GetOptionalDataPtr(const paddle::optional &x, int64_t index) { - return x ? GetDataPtr(*x, index) : nullptr; -} - -template -inline void *GetOptionalDataPtr(paddle::optional &x, int64_t index) { // NOLINT - return x ? GetDataPtr(*x, index) : nullptr; -} - -inline const void *GetOptionalDataPtr(const paddle::optional &x) { - return x ? x->data() : nullptr; -} - -inline void *GetOptionalDataPtr(paddle::optional &x) { // NOLINT - return x ? x->data() : nullptr; -} - -inline std::vector GetShapeArray(const paddle::Tensor &x) { - std::vector shapes; - for (auto dim : x.shape()) { - shapes.push_back(static_cast(dim)); - } - return shapes; -} - -inline std::vector GetShapeArray(const paddle::optional &x) { - if (x) return GetShapeArray(x.get()); - return {0}; -} - -paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place, - bool init_to_zeros = 0); - -// DType Utils -inline paddle::DataType Nvte2PaddleDType(DType t) { - switch (t) { - case DType::kInt32: - case DType::kFloat32: - return paddle::DataType::FLOAT32; - case DType::kFloat16: - return paddle::DataType::FLOAT16; - case DType::kBFloat16: - return paddle::DataType::BFLOAT16; - case DType::kByte: - case DType::kFloat8E4M3: - case DType::kFloat8E5M2: - return paddle::DataType::UINT8; - default: - NVTE_ERROR("Invalid type"); - } -} - -inline DType Paddle2NvteDType(paddle::DataType t) { - switch (t) { - case paddle::DataType::FLOAT16: - return DType::kFloat16; - case paddle::DataType::FLOAT32: - return DType::kFloat32; - case paddle::DataType::BFLOAT16: - return DType::kBFloat16; - case paddle::DataType::BOOL: - return DType::kByte; - case paddle::DataType::UINT8: - return DType::kByte; - case paddle::DataType::INT32: - return DType::kInt32; - case paddle::DataType::INT64: - return DType::kInt64; - default: - NVTE_ERROR("Invalid type"); - } -} - -inline DType Int2NvteDType(int64_t dtype) { - if (dtype >= 0 && dtype < static_cast(DType::kNumTypes)) { - return static_cast(dtype); - } else { - NVTE_ERROR("Type not supported."); - } -} - -// get the fused attention backend -inline NVTE_Fused_Attn_Backend get_fused_attn_backend( - const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim) { - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim, head_dim, -1, -1); - return fused_attention_backend; -} - -// CUDA Utils -class cudaDevicePropertiesManager { - public: - static cudaDevicePropertiesManager &Instance() { - static thread_local cudaDevicePropertiesManager instance; - return instance; - } - - int GetMultiProcessorCount() { - if (!prop_queried_) { - int device_id; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - cudaGetDeviceProperties(&prop_, device_id); - prop_queried_ = true; - } - return prop_.multiProcessorCount; - } - - int GetMajor() { - if (!prop_queried_) { - int device_id; - NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); - cudaGetDeviceProperties(&prop_, device_id); - prop_queried_ = true; - } - return prop_.major; - } - - private: - bool prop_queried_ = false; - cudaDeviceProp prop_; -}; - -// NVTE Tensor Utils -TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector &shape, - const DType type); -TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type); -TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector &shape, const DType type, - void *amax_ptr, void *scale_ptr, void *scale_inv_ptr); -TensorWrapper MakeNvteTensor(paddle::Tensor &tensor); // NOLINT -TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor); - -NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout); - -} // namespace paddle_ext -} // namespace transformer_engine diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu deleted file mode 100644 index 583cd0f47a..0000000000 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ /dev/null @@ -1,1794 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include - -#include "common.h" -#include "common/common.h" -#include "paddle/phi/backends/gpu/cuda/cuda_graph.h" - -namespace transformer_engine { -namespace paddle_ext { - -// convert bias type to enum -NVTE_Bias_Type get_nvte_bias_type(const std::string bias_type) { - if (bias_type == "no_bias") { - return NVTE_Bias_Type::NVTE_NO_BIAS; - } else if (bias_type == "pre_scale_bias") { - return NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS; - } else if (bias_type == "post_scale_bias") { - return NVTE_Bias_Type::NVTE_POST_SCALE_BIAS; - } else { - NVTE_ERROR("Invalid bias type. \n"); - } -} - -// convert attn mask type to enum -NVTE_Mask_Type get_nvte_mask_type(const std::string mask_type) { - if (mask_type == "padding") { - return NVTE_Mask_Type::NVTE_PADDING_MASK; - } else if (mask_type == "causal") { - return NVTE_Mask_Type::NVTE_CAUSAL_MASK; - } else if (mask_type == "no_mask") { - return NVTE_Mask_Type::NVTE_NO_MASK; - } else { - NVTE_ERROR("Invalid attention mask type. \n"); - } -} - -void cast_to_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &output, // NOLINT - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), shape, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - - nvte_fp8_quantize(input_cu.data(), output_cu.data(), input.stream()); -} - -std::vector cast_from_fp8(const paddle::Tensor &input, - const paddle::Tensor &scale_inv, int64_t index, - int64_t itype, int64_t otype) { - auto shape = GetShapeArray(input); - - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype))); - auto input_cu = - MakeNvteTensor(const_cast(input.data()), shape, Int2NvteDType(itype), nullptr, - nullptr, const_cast(GetDataPtr(scale_inv, index))); - auto output_cu = MakeNvteTensor(output); - - nvte_fp8_dequantize(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_transpose(const paddle::Tensor &input, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty({input.shape()[1], input.shape()[0]}, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(const_cast(input.data()), {M, N}, Int2NvteDType(otype)); - auto output_cu = MakeNvteTensor(output.data(), {N, M}, Int2NvteDType(otype)); - - nvte_transpose(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -void te_cast_transpose(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &output_cast, // NOLINT - paddle::Tensor &output_transpose, // NOLINT - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto input_cu = MakeNvteTensor(input); - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - auto output_cast_cu = MakeNvteTensor(output_cast.data(), {M, N}, Int2NvteDType(otype), amax_data, - scale_data, scale_inv_data); - auto output_transpose_cu = MakeNvteTensor(output_transpose.data(), {N, M}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - - nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - input.stream()); -} - -std::vector te_cast_transpose_bgrad(const paddle::Tensor &grad_output, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(grad_output); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto grad_bias = - paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place()); - auto grad_output_cast = - paddle::empty_like(grad_output, Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place()); - auto grad_output_transpose = - paddle::empty({grad_output.shape()[1], grad_output.shape()[0]}, - Nvte2PaddleDType(Int2NvteDType(otype)), grad_output.place()); - - auto input_cu = MakeNvteTensor(grad_output); - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - auto output_cast_cu = MakeNvteTensor(grad_output_cast.data(), {M, N}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - auto output_transpose_cu = - MakeNvteTensor(grad_output_transpose.data(), {N, M}, Int2NvteDType(otype), amax_data, - scale_data, scale_inv_data); - auto dbias_cu = MakeNvteTensor(grad_bias); - transformer_engine::TensorWrapper workspace; - - nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - dbias_cu.data(), workspace.data(), grad_output.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - nvte_cast_transpose_dbias(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - dbias_cu.data(), workspace.data(), grad_output.stream()); - - return {grad_bias, grad_output_cast, grad_output_transpose}; -} - -void te_gemm(const paddle::Tensor &A, const paddle::optional &A_scale_inverse, - const paddle::Tensor &B, const paddle::optional &B_scale_inverse, - const paddle::optional &bias, paddle::Tensor &D, // NOLINT - paddle::optional &D_scale, // NOLINT - paddle::optional &D_amax, // NOLINT - paddle::optional &pre_gelu_out, paddle::Tensor &workspace, // NOLINT - int64_t A_index, int64_t B_index, int64_t D_index, int64_t A_type, int64_t B_type, - int64_t D_type, int64_t bias_type, bool transa, bool transb, bool grad, - int64_t workspace_size, bool accumulate, bool use_split_accumulator, - int64_t math_sm_count) { - auto te_A = MakeNvteTensor( - const_cast(A.data()), GetShapeArray(A), Int2NvteDType(A_type), nullptr, nullptr, - const_cast(GetOptionalDataPtr(A_scale_inverse, A_index))); - auto te_B = MakeNvteTensor( - const_cast(B.data()), GetShapeArray(B), Int2NvteDType(B_type), nullptr, nullptr, - const_cast(GetOptionalDataPtr(B_scale_inverse, B_index))); - auto te_D = MakeNvteTensor(D.data(), GetShapeArray(D), Int2NvteDType(D_type), - GetOptionalDataPtr(D_amax, D_index), - GetOptionalDataPtr(D_scale, D_index), nullptr); - - auto te_bias = MakeNvteTensor(const_cast(GetOptionalDataPtr(bias)), GetShapeArray(bias), - Int2NvteDType(bias_type)); - - DType gelu_dtype = pre_gelu_out ? Paddle2NvteDType(pre_gelu_out->dtype()) : Int2NvteDType(D_type); - auto te_pre_gelu_out = - MakeNvteTensor(GetOptionalDataPtr(pre_gelu_out), GetShapeArray(pre_gelu_out), gelu_dtype); - auto te_workspace = - MakeNvteTensor(workspace.data(), {static_cast(workspace_size)}, DType::kByte); - - nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), - transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, A.stream()); -} - -std::vector te_gelu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto output = paddle::empty_like(input, Nvte2PaddleDType(DType::kByte), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), GetShapeArray(input), Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - - nvte_gelu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_gelu(const paddle::Tensor &input, int64_t otype) { - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(input), Int2NvteDType(otype)); - - nvte_gelu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_swiglu(const paddle::Tensor &input, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2}, - Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(output), Int2NvteDType(otype)); - - nvte_swiglu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_swiglu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2}, - Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto output_cu = MakeNvteTensor( - output.data(), GetShapeArray(output), Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - - nvte_swiglu(input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_dswiglu(const paddle::Tensor &grad, const paddle::Tensor &input, - int64_t otype) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - - auto input_cu = MakeNvteTensor(input.data(), {M, N}, Paddle2NvteDType(input.dtype())); - auto grad_cu = MakeNvteTensor(grad.data(), {M, N / 2}, Paddle2NvteDType(grad.dtype())); - auto output_cu = MakeNvteTensor(output.data(), {M, N}, Paddle2NvteDType(output.dtype())); - - nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), input.stream()); - - return {output}; -} - -std::vector te_cast_transpose_bgrad_dgelu(const paddle::Tensor &grad_output, - const paddle::Tensor &gelu_input, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - int64_t index, int64_t otype) { - auto shape = GetShapeArray(grad_output); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t M = shape[0]; - size_t N = shape[1]; - - // DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = - paddle::empty({grad_output.shape()[1]}, grad_output.dtype(), grad_output.place()); - - auto dgelu = paddle::empty_like(grad_output, Nvte2PaddleDType(DType::kByte), grad_output.place()); - - auto dgelu_transpose = paddle::empty({grad_output.shape()[1], grad_output.shape()[0]}, - Nvte2PaddleDType(DType::kByte), grad_output.place()); - - void *amax_data = GetDataPtr(amax, index); - void *scale_data = const_cast(GetDataPtr(scale, index)); - void *scale_inv_data = GetDataPtr(scale_inv, index); - - TensorWrapper workspace; - - auto gelu_input_cu = MakeNvteTensor(gelu_input); - auto input_cu = MakeNvteTensor(grad_output); - auto cast_output_cu = MakeNvteTensor(dgelu.data(), {M, N}, Int2NvteDType(otype), amax_data, - scale_data, scale_inv_data); - auto transposed_output_cu = MakeNvteTensor(dgelu_transpose.data(), {N, M}, Int2NvteDType(otype), - amax_data, scale_data, scale_inv_data); - auto dbias_cu = MakeNvteTensor(grad_bias); - - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - grad_output.stream()); - - // Fill workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), grad_output.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - grad_output.stream()); - - return {dgelu, dgelu_transpose, grad_bias}; -} - -std::vector te_layernorm_fwd_fp8(const paddle::Tensor &input, - const paddle::Tensor &weight, - const paddle::Tensor &bias, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - float eps, int64_t index, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto mu = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto beta_cu = MakeNvteTensor(bias); - auto z_cu = MakeNvteTensor( - ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - - // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - return {ln_out, mu, rsigma}; -} - -std::vector te_layernorm_fwd(const paddle::Tensor &input, - const paddle::Tensor &weight, - const paddle::Tensor &bias, float eps, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, input.dtype(), input.place()); - auto mu = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto beta_cu = MakeNvteTensor(bias); - auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace and barrier tensors with the required config - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - - // Actual call to fwd kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - return {ln_out, mu, rsigma}; -} - -std::vector te_layernorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x, - const paddle::Tensor &mu, const paddle::Tensor &rsigma, - const paddle::Tensor &gamma, int64_t sm_margin, - bool zero_centered_gamma) { - auto dx = paddle::empty_like(x, x.dtype(), x.place()); - auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - auto dbeta = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - - TensorWrapper workspace, barrier, dgamma_part, dbeta_part; - - auto dz_cu = MakeNvteTensor(dz); - auto x_cu = MakeNvteTensor(x); - auto mu_cu = MakeNvteTensor(mu); - auto rsigma_cu = MakeNvteTensor(rsigma); - auto gamma_cu = MakeNvteTensor(gamma); - auto dx_cu = MakeNvteTensor(dx); - auto dgamma_cu = MakeNvteTensor(dgamma); - auto dbeta_cu = MakeNvteTensor(dbeta); - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(), - num_sm - sm_margin, workspace.data(), barrier.data()); - - // Alloc space for Tensors. - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true); - auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place()); - auto dbeta_part_data = AllocateSpace(dbeta_part.shape(), dbeta_part.dtype(), x.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype()); - dbeta_part = MakeNvteTensor(dbeta_part_data.data(), dbeta_part.shape(), dbeta_part.dtype()); - - // Actual call to bwd kernel. - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), dz.stream(), - num_sm - sm_margin, workspace.data(), barrier.data()); - - return {dx, dgamma, dbeta}; -} - -std::vector te_rmsnorm_fwd(const paddle::Tensor &input, - const paddle::Tensor &weight, float eps, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, input.dtype(), input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto z_cu = MakeNvteTensor(ln_out.data(), {N, H}, Int2NvteDType(otype)); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace and barrier tensors with the required config - - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - - // Actual call to fwd kernel - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - return {ln_out, rsigma}; -} - -std::vector te_rmsnorm_fwd_fp8(const paddle::Tensor &input, - const paddle::Tensor &weight, - const paddle::Tensor &scale, - paddle::Tensor &amax, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - float eps, int64_t index, int64_t otype, - int64_t sm_margin, bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); - auto shape = GetShapeArray(input); - NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions."); - - size_t N = shape[0]; - size_t H = shape[1]; - - auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place()); - auto rsigma = paddle::empty({static_cast(N)}, paddle::DataType::FLOAT32, input.place()); - auto input_cu = MakeNvteTensor(input); - auto gamma_cu = MakeNvteTensor(weight); - auto z_cu = MakeNvteTensor( - ln_out.data(), {N, H}, Int2NvteDType(otype), GetDataPtr(amax, index), - const_cast(GetDataPtr(scale, index)), GetDataPtr(scale_inv, index)); - auto rsigma_cu = MakeNvteTensor(rsigma); - TensorWrapper workspace, barrier; - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates workspace and barrier tensors with the required config - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - // Fill workspace and barrier - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), input.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), input.place(), true); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - - // Actual call to fwd kernel - nvte_rmsnorm_fwd(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - input.stream(), num_sm - sm_margin, workspace.data(), barrier.data()); - - return {ln_out, rsigma}; -} - -std::vector te_rmsnorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x, - const paddle::Tensor &rsigma, - const paddle::Tensor &gamma, int64_t sm_margin, - bool zero_centered_gamma) { - NVTE_CHECK(zero_centered_gamma == false, "zero_centered_gamma is not supported yet for RMSNorm."); - auto dx = paddle::empty_like(x, x.dtype(), x.place()); - auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place()); - - TensorWrapper workspace, barrier, dgamma_part; - - auto dz_cu = MakeNvteTensor(dz); - auto x_cu = MakeNvteTensor(x); - auto rsigma_cu = MakeNvteTensor(rsigma); - auto gamma_cu = MakeNvteTensor(gamma); - auto dx_cu = MakeNvteTensor(dx); - auto dgamma_cu = MakeNvteTensor(dgamma); - - auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount(); - - // This call populates tensors with the required config. - nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin, - workspace.data(), barrier.data()); - - // Alloc space for Tensors. - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), x.place()); - auto barrier_data = AllocateSpace(barrier.shape(), barrier.dtype(), x.place(), true); - auto dgamma_part_data = AllocateSpace(dgamma_part.shape(), dgamma_part.dtype(), x.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - barrier = MakeNvteTensor(barrier_data.data(), barrier.shape(), barrier.dtype()); - dgamma_part = MakeNvteTensor(dgamma_part_data.data(), dgamma_part.shape(), dgamma_part.dtype()); - - // Actual call to bwd kernel. - nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), dz.stream(), num_sm - sm_margin, - workspace.data(), barrier.data()); - - return {dx, dgamma}; -} - -__global__ void set_rng_state( - [[maybe_unused]] unsigned int - identifier, // This is used to relate kernel to cudaGraph nodes please refer to https://github.com/PaddlePaddle/Paddle/pull/60516 - std::pair seed_offset, int64_t *rng_state_ptr) { - rng_state_ptr[0] = static_cast(seed_offset.first); - rng_state_ptr[1] = static_cast(seed_offset.second); -} - -void UpdateRandomGenerator(phi::Place place, cudaStream_t stream, int rng_elts_per_thread, - paddle::Tensor &rng_state) { - // extract random number generator seed and offset - const phi::DeviceContext *dev_ctx = - paddle::experimental::DeviceContextPool::Instance().Get(place); - - phi::Generator *gen_cuda = dev_ctx->GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - int64_t *rng_state_p = static_cast(rng_state.data()); -#if PADDLE_VERSION > 261 - auto state_index = gen_cuda->GetStateIndex(); - - auto parameterSetter = [gen_cuda, state_index, - rng_elts_per_thread](phi::backends::gpu::gpuKernelParams ¶ms) { - // ensure the generator use correct state index - gen_cuda->SetStateIndex(state_index); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - params.As>(1) = seed_offset; - }; - - phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = - [=](unsigned int id) { - void *functionPtr = reinterpret_cast(&set_rng_state); - cudaFunction_t cudaFunc; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); - set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p); - return cudaFunc; - }; - phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, - cudaKernelCallback); -#else - set_rng_state<<<1, 1, 0, stream>>>(0, seed_offset, rng_state_p); -#endif -} - -void te_fused_attn_fwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens, - const paddle::optional &Bias, - paddle::Tensor &O, // NOLINT - paddle::optional &softmax_aux, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs, - int64_t max_seqlen, bool is_training, float attn_scale, - float p_dropout, const std::string &qkv_layout, - const std::string &bias_type, const std::string &attn_mask_type, - const int64_t qkv_type, int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. \n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_QKV = MakeNvteTensor(QKV); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor(O); - } else { // TODO: support fp8 - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - UpdateRandomGenerator(QKV.place(), QKV.stream(), rng_elts_per_thread, rng_state); - auto te_rng_state = MakeNvteTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd_qkvpacked(te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, workspace.data(), QKV.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -// fused attention BWD with packed QKV -void te_fused_attn_bwd_qkvpacked(const paddle::Tensor &QKV, const paddle::Tensor &cu_seqlens, - const paddle::Tensor &O, const paddle::Tensor &dO, - const paddle::Tensor &softmax_aux, - paddle::Tensor &dQKV, // NOLINT - paddle::optional &dBias, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs, - int64_t max_seqlen, float attn_scale, float p_dropout, - const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, int64_t qkv_type, - bool deterministic) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_QKV = MakeNvteTensor(QKV); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQKV = MakeNvteTensor(dQKV); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; // 1. softmax_aux 2. rng_state - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = - std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen), static_cast(max_seqlen)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens; - te_cu_seqlens = MakeNvteTensor(cu_seqlens.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - deterministic, workspace.data(), QKV.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), QKV.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), dummy_seq_offsets.data(), max_seqlen, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - deterministic, workspace.data(), QKV.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -void te_fused_attn_fwd_kvpacked( - const paddle::Tensor &Q, const paddle::Tensor &KV, const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &cu_seqlens_kv, const paddle::optional &Bias, - paddle::Tensor &O, // NOLINT - paddle::optional &softmax_aux, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs_q, int64_t total_seqs_kv, - int64_t max_seqlen_q, int64_t max_seqlen_kv, bool is_training, float attn_scale, - float p_dropout, const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, const int64_t qkv_type, int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. \n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor( - Q.data(), - {static_cast(total_seqs_q), static_cast(h), static_cast(d)}, - qkv_dtype); - te_KV = MakeNvteTensor( - KV.data(), - {static_cast(total_seqs_kv), 2, static_cast(h), static_cast(d)}, - qkv_dtype); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor( - O.data(), - {static_cast(total_seqs_q), static_cast(h), static_cast(d)}, - qkv_dtype); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - UpdateRandomGenerator(Q.place(), Q.stream(), rng_elts_per_thread, rng_state); - auto te_rng_state = MakeNvteTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), - dummy_seq_offsets.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -// fused attention BWD with packed KV -void te_fused_attn_bwd_kvpacked(const paddle::Tensor &Q, const paddle::Tensor &KV, - const paddle::Tensor &cu_seqlens_q, - const paddle::Tensor &cu_seqlens_kv, const paddle::Tensor &O, - const paddle::Tensor &dO, const paddle::Tensor &softmax_aux, - paddle::Tensor &dQ, // NOLINT - paddle::Tensor &dKV, // NOLINT - paddle::optional &dBias, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t total_seqs_q, - int64_t total_seqs_kv, int64_t max_seqlen_q, int64_t max_seqlen_kv, - float attn_scale, float p_dropout, const std::string &qkv_layout, - const std::string &bias_type, const std::string &attn_mask_type, - int64_t qkv_type, bool deterministic) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_KV = MakeNvteTensor(KV); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQ = MakeNvteTensor(dQ); - te_dKV = MakeNvteTensor(dKV); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = - std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen_q), static_cast(max_seqlen_kv)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, - -1, -1, deterministic, workspace.data(), Q.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_kvpacked( - te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), te_dBias.data(), te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, - -1, -1, deterministic, workspace.data(), Q.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -void te_fused_attn_fwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V, - const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv, - const paddle::optional &Bias, - paddle::Tensor &O, // NOLINT - paddle::optional &softmax_aux, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, - bool is_training, float attn_scale, float p_dropout, - const std::string &qkv_layout, const std::string &bias_type, - const std::string &attn_mask_type, const int64_t qkv_type, - int64_t rng_elts_per_thread) { - if (is_training && !softmax_aux) { - NVTE_ERROR("softmax_aux must be provided when training. \n"); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_K = MakeNvteTensor(K); - te_V = MakeNvteTensor(V); - te_S = MakeNvteTensor(nullptr, std::vector{0}, DType::kFloat32); - te_O = MakeNvteTensor(O); - } else { // TODO: support fp8 - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - if ((bias_type != "no_bias") && Bias) { - auto bias_shape = Bias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_Bias = MakeNvteTensor(GetOptionalDataPtr(Bias), shape, DType::kFloat32); - } - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // extract random number generator seed and offset - auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(Q.place()); - auto gen_cuda = dev_ctx->GetGenerator(); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - auto stream = Q.stream(); - auto rng_state_p = static_cast(rng_state.data()); -#if PADDLE_VERSION > 261 - auto state_index = gen_cuda->GetStateIndex(); - auto parameterSetter = [gen_cuda, state_index, - rng_elts_per_thread](phi::backends::gpu::gpuKernelParams ¶ms) { - // ensure the generator use correct state index - gen_cuda->SetStateIndex(state_index); - auto seed_offset = gen_cuda->IncrementOffset(rng_elts_per_thread); - params.As>(1) = seed_offset; - }; - - phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = - [=](unsigned int id) { - void *functionPtr = reinterpret_cast(&set_rng_state); - cudaFunction_t cudaFunc; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); - set_rng_state<<<1, 1, 0, stream>>>(id, seed_offset, rng_state_p); - return cudaFunc; - }; - phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, - cudaKernelCallback); -#else - set_rng_state<<<1, 1, 0, stream>>>(0, seed_offset, rng_state_p); -#endif - - auto te_rng_state = MakeNvteTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, - p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - output_s->data.dptr = GetOptionalDataPtr(softmax_aux); - - // execute the kernel - nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), dummy_seq_offsets.data(), dummy_seq_offsets.data(), - te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, - p_dropout, qkv_layout_enum, bias_type_enum, attn_mask_type_enum, -1, -1, - workspace.data(), Q.stream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -void te_fused_attn_bwd(const paddle::Tensor &Q, const paddle::Tensor &K, const paddle::Tensor &V, - const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &cu_seqlens_kv, - const paddle::Tensor &O, const paddle::Tensor &dO, - const paddle::Tensor &softmax_aux, - paddle::Tensor &dQ, // NOLINT - paddle::Tensor &dK, // NOLINT - paddle::Tensor &dV, // NOLINT - paddle::optional &dBias, // NOLINT - paddle::Tensor &rng_state, // NOLINT - int64_t b, int64_t h, int64_t d, int64_t max_seqlen_q, int64_t max_seqlen_kv, - float attn_scale, float p_dropout, const std::string &qkv_layout, - const std::string &bias_type, const std::string &attn_mask_type, - int64_t qkv_type, bool deterministic) { - TensorWrapper te_dBias; - if (bias_type != "no_bias" && dBias) { - auto bias_shape = dBias->shape(); - std::vector shape{bias_shape.begin(), bias_shape.end()}; - te_dBias = MakeNvteTensor(GetOptionalDataPtr(dBias), shape, DType::kFloat32); - } - - auto qkv_dtype = Int2NvteDType(qkv_type); - // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; - if (qkv_dtype == DType::kBFloat16 || qkv_dtype == DType::kFloat16) { - // BF16 or FP16 - te_Q = MakeNvteTensor(Q); - te_K = MakeNvteTensor(K); - te_V = MakeNvteTensor(V); - te_O = MakeNvteTensor(O); - te_dO = MakeNvteTensor(dO); - te_S = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dP = MakeNvteTensor(nullptr, std::vector(0), DType::kFloat32); - te_dQ = MakeNvteTensor(dQ); - te_dK = MakeNvteTensor(dK); - te_dV = MakeNvteTensor(dV); - } else { - NVTE_ERROR("Fused attention only supports BF16/FP16 data types. \n"); - } - - // convert strings to enums - NVTE_QKV_Layout qkv_layout_enum = get_nvte_qkv_layout(qkv_layout); - NVTE_Bias_Type bias_type_enum = get_nvte_bias_type(bias_type); - NVTE_Mask_Type attn_mask_type_enum = get_nvte_mask_type(attn_mask_type); - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - nvte_aux_tensor_pack.size = 2; - auto *output_s = reinterpret_cast(nvte_aux_tensor_pack.tensors[0]); - auto *fwd_rng_state = reinterpret_cast(nvte_aux_tensor_pack.tensors[1]); - output_s->data.shape = - std::vector({static_cast(b), static_cast(h), - static_cast(max_seqlen_q), static_cast(max_seqlen_kv)}); - output_s->data.dptr = const_cast(softmax_aux.data()); - fwd_rng_state->data.shape = std::vector({2}); - fwd_rng_state->data.dptr = const_cast(rng_state.data()); - - // create cu_seqlens tensorwrappers - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = - MakeNvteTensor(cu_seqlens_q.data(), {static_cast(b + 1)}, DType::kInt32); - te_cu_seqlens_kv = - MakeNvteTensor(cu_seqlens_kv.data(), {static_cast(b + 1)}, DType::kInt32); - - // create workspace - TensorWrapper workspace; - - auto dummy_seq_offsets = TensorWrapper(nullptr, {static_cast(b + 1)}, DType::kInt32); - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream()); - - // allocate memory for workspace - auto workspace_data = AllocateSpace(workspace.shape(), workspace.dtype(), Q.place()); - workspace = MakeNvteTensor(workspace_data.data(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - dummy_seq_offsets.data(), dummy_seq_offsets.data(), max_seqlen_q, - max_seqlen_kv, attn_scale, p_dropout, qkv_layout_enum, bias_type_enum, - attn_mask_type_enum, -1, -1, deterministic, workspace.data(), Q.stream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); -} - -std::vector te_scaled_softmax_forward(const paddle::Tensor &input, - float scale_factor) { - NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK( - (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - const int batches = input.shape()[0]; - const int attn_heads = input.shape()[1]; - const int query_seq_len = input.shape()[2]; - const int key_seq_len = input.shape()[3]; - - NVTE_CHECK(key_seq_len <= 4096); - NVTE_CHECK(query_seq_len > 1); - - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor, - input.stream()); - - return {softmax_results}; -} - -void te_scaled_softmax_backward(paddle::Tensor &output_grads, // NOLINT - const paddle::Tensor &softmax_results, float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. - nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), - output_grads_cu.data(), scale_factor, softmax_results.stream()); -} - -std::vector te_scaled_masked_softmax_forward(const paddle::Tensor &input, - const paddle::Tensor &mask, - float scale_factor) { - NVTE_CHECK(input.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(mask.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK( - (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - const int batches = input.shape()[0]; - const int pad_batches = mask.shape()[0]; - const int attn_heads = input.shape()[1]; - const int query_seq_len = input.shape()[2]; - const int key_seq_len = input.shape()[3]; - - NVTE_CHECK(key_seq_len <= 4096); - NVTE_CHECK(query_seq_len > 1); - NVTE_CHECK(pad_batches == 1 || pad_batches == batches); - NVTE_CHECK(mask.shape()[1] == 1); - NVTE_CHECK(mask.shape()[2] == query_seq_len); - NVTE_CHECK(mask.shape()[3] == key_seq_len); - - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto mask_cu = MakeNvteTensor(mask); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(), softmax_results_cu.data(), - scale_factor, input.stream()); - - return {softmax_results}; -} - -void te_scaled_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT - const paddle::Tensor &softmax_results, float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 4, "expected 4D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 4, "expected 4D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. - nvte_scaled_softmax_backward(output_grads_cu.data(), softmax_results_cu.data(), - output_grads_cu.data(), scale_factor, softmax_results.stream()); -} - -std::vector te_scaled_upper_triang_masked_softmax_forward( - const paddle::Tensor &input, float scale_factor) { - NVTE_CHECK(input.shape().size() == 3, "expected 3D tensor"); - NVTE_CHECK( - (input.dtype() == paddle::DataType::FLOAT16) || (input.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - - const int attn_batches = input.shape()[0]; - const int seq_len = input.shape()[1]; - NVTE_CHECK(seq_len <= 2048); - - // Output - auto softmax_results = paddle::empty_like(input, input.dtype(), input.place()); - - auto input_cu = MakeNvteTensor(input); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(), - scale_factor, input.stream()); - - return {softmax_results}; -} - -void te_scaled_upper_triang_masked_softmax_backward(paddle::Tensor &output_grads, // NOLINT - const paddle::Tensor &softmax_results, - float scale_factor) { - NVTE_CHECK(output_grads.shape().size() == 3, "expected 3D tensor"); - NVTE_CHECK(softmax_results.shape().size() == 3, "expected 3D tensor"); - - NVTE_CHECK((output_grads.dtype() == paddle::DataType::FLOAT16) || - (output_grads.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK((softmax_results.dtype() == paddle::DataType::FLOAT16) || - (softmax_results.dtype() == paddle::DataType::BFLOAT16), - "Only fp16 and bf16 are supported"); - NVTE_CHECK(output_grads.shape()[1] == output_grads.shape()[2]); - - auto output_grads_cu = MakeNvteTensor(output_grads); - auto softmax_results_cu = MakeNvteTensor(softmax_results); - - // Produce gradients in place. - nvte_scaled_upper_triang_masked_softmax_backward( - output_grads_cu.data(), softmax_results_cu.data(), output_grads_cu.data(), scale_factor, - softmax_results.stream()); -} - -__global__ void UpdateFP8MetaKernel( - [[maybe_unused]] unsigned int - identifier, // This is used to relate kernel to cudaGraph nodes please refer to https://github.com/PaddlePaddle/Paddle/pull/60516 - const float *amax, const float *rolled_amax_history, const bool *non_weight_mask, - float *amax_history, float *scale, float *scale_inv, bool update_weight_scale_inv, float margin, - float fp8_max, size_t history_numel, size_t amax_numel) { - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - if (idx >= history_numel) { - return; - } - - amax_history[idx] = rolled_amax_history[idx]; - - if (idx < amax_numel) { - float sf = (fp8_max / amax[idx]) / powf(2.0f, margin); - float scale_reg = ((amax[idx] > 0.0f) && isfinite(amax[idx])) ? sf : scale[idx]; - scale[idx] = scale_reg; - if (update_weight_scale_inv || non_weight_mask[idx]) scale_inv[idx] = 1.0f / scale_reg; - amax_history[idx] = 0.0f; - } -} - -constexpr int BLOCK_SIZE = 512; - -void amax_and_scale_update_inplace(paddle::Tensor &amax_history, // NOLINT - paddle::Tensor &scale, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - const paddle::Tensor &non_weight_mask, int64_t fp8_dtype, - float margin, const std::string &amax_compute) { - auto amax_history_ = MakeNvteTensor(amax_history); - auto scale_ = MakeNvteTensor(scale); - auto scale_inv_ = MakeNvteTensor(scale_inv); - const auto non_weight_mask_ = MakeNvteTensor(non_weight_mask); - nvte_delayed_scaling_recipe_amax_and_scale_update( - amax_history_.data(), scale_.data(), scale_inv_.data(), non_weight_mask_.data(), - amax_history_.data(), scale_.data(), scale_inv_.data(), amax_compute.c_str(), - static_cast(fp8_dtype), margin, amax_history.stream()); -} - -void amax_and_scale_update_inplace_legacy( - paddle::Tensor &amax_history, // NOLINT - paddle::Tensor &scale, // NOLINT - paddle::Tensor &scale_inv, // NOLINT - const paddle::Tensor &non_weight_mask, - const paddle::optional ¤t_step_id_tensor, bool update_weight_scale_inv, - bool fwd_update, float fp8_max, float margin, const std::string &amax_compute) { -#if PADDLE_VERSION > 261 - NVTE_CHECK(amax_compute == "max" || amax_compute == "most_recent"); - - paddle::Tensor amax; - - if (amax_compute == "max") { - amax = amax_history.max({0}); - } else { - amax = amax_history.slice(0, 1); - } - - const auto rolled_amax_history = amax_history.roll({-1}, {0}); - - auto amax_history_numel = amax_history.numel(); - auto amax_numel = amax.numel(); - size_t num_blocks = (amax_history_numel + BLOCK_SIZE - 1) / BLOCK_SIZE; - - const int *current_step_id_ptr = - reinterpret_cast(GetOptionalDataPtr(current_step_id_tensor)); - auto parameterSetter = [current_step_id_ptr, - fwd_update](phi::backends::gpu::gpuKernelParams ¶ms) { - if (fwd_update) { - int current_step_id = *current_step_id_ptr; - params.As(7) = (current_step_id == 0); - } - }; - - const float *amax_ptr = amax.data(); - const float *rolled_amax_history_ptr = rolled_amax_history.data(); - const bool *non_weight_mask_ptr = non_weight_mask.data(); - float *amax_history_ptr = amax_history.data(); - float *scale_ptr = scale.data(); - float *scale_inv_ptr = scale_inv.data(); - - phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = - [=](unsigned int id) { - void *functionPtr = reinterpret_cast(&UpdateFP8MetaKernel); - cudaFunction_t cudaFunc; - PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, functionPtr)); - UpdateFP8MetaKernel<<>>( - id, amax_ptr, rolled_amax_history_ptr, non_weight_mask_ptr, amax_history_ptr, scale_ptr, - scale_inv_ptr, update_weight_scale_inv, margin, fp8_max, amax_history_numel, - amax_numel); - NVTE_CHECK_CUDA(cudaGetLastError()); - return cudaFunc; - }; - phi::backends::gpu::CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter, - cudaKernelCallback); -#else - NVTE_ERROR( - "amax_and_scale_update_inplace_legacy is not supported in old version of PaddlePaddle\n"); -#endif -} - -void update_latest_amax_history_inplace(paddle::Tensor &history, // NOLINT - const paddle::Tensor &amax) { - // Copy amax to history[0] - NVTE_CHECK_CUDA(cudaMemcpyAsync(history.data(), amax.data(), amax.numel() * SizeOf(amax.dtype()), - cudaMemcpyDeviceToDevice, amax.stream())); -} - -__global__ __launch_bounds__(BLOCK_SIZE) void mask_to_actual_seqlens_kernel( - const bool *mask, int32_t *q_actual_seqlen, int32_t *kv_actual_seqlen, int q_seqlen, - int kv_seqlen, bool need_kv) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage q_smem; - __shared__ typename BlockReduce::TempStorage kv_smem; - unsigned int tid = threadIdx.x; - unsigned int batch_offset = blockIdx.x * q_seqlen * kv_seqlen; - - // load mask, convert to 1/0, do accumulation - int q = 0, kv = 0; - for (unsigned int q_idx = tid * kv_seqlen; q_idx < q_seqlen * kv_seqlen; - q_idx += BLOCK_SIZE * kv_seqlen) { - q += (mask[q_idx + batch_offset] ? 0 : 1); - } - - if (need_kv) { - for (unsigned int kv_idx = tid; kv_idx < kv_seqlen; kv_idx += BLOCK_SIZE) { - kv += (mask[kv_idx + batch_offset] ? 0 : 1); - } - } - __syncthreads(); - - // compute cub::BlockReduce - int q_sum, kv_sum; - q_sum = BlockReduce(q_smem).Sum(q); - if (need_kv) kv_sum = BlockReduce(kv_smem).Sum(kv); - - // write result for this block to global mem - if (tid == 0) { - q_actual_seqlen[blockIdx.x + 1] = q_sum; - if (need_kv) { - kv_actual_seqlen[blockIdx.x + 1] = kv_sum; - } - } -} - -__global__ __launch_bounds__(BLOCK_SIZE) void block_prefix_sum_inplace(int32_t *x, int n) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage smem; - // +1 to ignore the first element - int i = blockIdx.x * blockDim.x + threadIdx.x + 1; - - // load data - int32_t thread_data[1]; - thread_data[0] = i < n ? x[i] : 0; - __syncthreads(); - - // CUB block prefix sum - BlockScan(smem).InclusiveSum(thread_data, thread_data); - __syncthreads(); - - // write result - if (i < n) { - x[i] = thread_data[0]; - } -} - -void mask_to_cu_seqlens(const paddle::Tensor &mask, - paddle::Tensor &q_cu_seqlen, // NOLINT - paddle::optional &kv_cu_seqlen, // NOLINT - int q_seqlen, int kv_seqlen, bool need_kv) { - if (need_kv) { - NVTE_CHECK(GetOptionalDataPtr(kv_cu_seqlen) != nullptr, - "kv_cu_seqlen must be provided when need_kv is true"); - } - mask_to_actual_seqlens_kernel<<>>( - mask.data(), q_cu_seqlen.data(), - reinterpret_cast(GetOptionalDataPtr(kv_cu_seqlen)), q_seqlen, kv_seqlen, need_kv); - // q_cu_seqlen shape: [bs+1], assume bs is not too large (<=512), so we can use a single block - // to do prefix sum - NVTE_CHECK(q_cu_seqlen.numel() - 1 <= BLOCK_SIZE, "batch size too large, kernel may fail"); - block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>(q_cu_seqlen.data(), - q_cu_seqlen.numel()); - if (need_kv) { - block_prefix_sum_inplace<<<1, BLOCK_SIZE, 0, mask.stream()>>>( - reinterpret_cast(GetOptionalDataPtr(kv_cu_seqlen)), kv_cu_seqlen->numel()); - } -} - -} // namespace paddle_ext -} // namespace transformer_engine - -PD_BUILD_OP(te_gemm) - .Inputs({"A", paddle::Optional("A_scale_inverse"), "B", paddle::Optional("B_scale_inverse"), - paddle::Optional("bias"), "_D", paddle::Optional("_D_scale"), - paddle::Optional("_D_amax"), paddle::Optional("_pre_gelu_out"), "_workspace"}) - .Outputs({"D", paddle::Optional("D_scale"), paddle::Optional("D_amax"), - paddle::Optional("pre_gelu_out"), "workspace"}) - .Attrs({"A_index: int64_t", "B_index: int64_t", "D_index: int64_t", "A_type: int64_t", - "B_type: int64_t", "D_type: int64_t", "bias_type: int64_t", "transa: bool", - "transb: bool", "grad: bool", "workspace_size: int64_t", "accumulate: bool", - "use_split_accumulator: bool", "math_sm_count: int64_t"}) - .SetInplaceMap({{"_D", "D"}, - {paddle::Optional("_D_scale"), paddle::Optional("D_scale")}, - {paddle::Optional("_D_amax"), paddle::Optional("D_amax")}, - {paddle::Optional("_pre_gelu_out"), paddle::Optional("pre_gelu_out")}, - {"_workspace", "workspace"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gemm)); - -PD_BUILD_OP(cast_to_fp8) - .Inputs({"Input", "Scale", "_Output", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Amax", "ScaleInv"}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetInplaceMap({{"_Output", "Output"}, {"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_to_fp8)); - -PD_BUILD_OP(cast_from_fp8) - .Inputs({"Input", "ScaleInv"}) - .Outputs({"Output"}) - .Attrs({"index: int64_t", "itype: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::cast_from_fp8)); - -PD_BUILD_OP(te_transpose) - .Inputs({"Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_transpose)); - -PD_BUILD_OP(te_cast_transpose) - .Inputs({"Input", "Scale", "_CastedOutput", "_TransposedOutput", "_Amax", "_ScaleInv"}) - .Outputs({"CastedOutput", "TransposedOutput", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_CastedOutput", "CastedOutput"}, - {"_TransposedOutput", "TransposedOutput"}, - {"_Amax", "Amax"}, - {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose)); - -PD_BUILD_OP(te_cast_transpose_bgrad) - .Inputs({"GradOutput", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"dBias", "CastedOutput", "TransposedOutput", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose_bgrad)); - -PD_BUILD_OP(te_gelu_fp8) - .Inputs({"Input", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu_fp8)); - -PD_BUILD_OP(te_gelu) - .Inputs({"Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu)); - -PD_BUILD_OP(te_swiglu) - .Inputs({"Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu)); - -PD_BUILD_OP(te_swiglu_fp8) - .Inputs({"Input", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu_fp8)); - -PD_BUILD_OP(te_dswiglu) - .Inputs({"Grad", "Input"}) - .Outputs({"Output"}) - .Attrs({"otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_dswiglu)); - -PD_BUILD_OP(te_cast_transpose_bgrad_dgelu) - .Inputs({"GradOutput", "GeluInput", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"CastedDgelu", "TransposedDgelu", "Dbias", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"index: int64_t", "otype: int64_t"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_cast_transpose_bgrad_dgelu)); - -PD_BUILD_OP(te_layernorm_fwd_fp8) - .Inputs({"Input", "Weight", "Bias", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "Mu", "Rsigma", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t", - "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_fwd_fp8)); - -PD_BUILD_OP(te_layernorm_fwd) - .Inputs({"Input", "Weight", "Bias"}) - .Outputs({"Output", "Mu", "Rsigma"}) - .Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_fwd)); - -PD_BUILD_OP(te_layernorm_bwd) - .Inputs({"Dz", "X", "Mu", "Rsigma", "Gamma"}) - .Outputs({"Dx", "Dgamma", "Dbeta"}) - .Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_layernorm_bwd)); - -PD_BUILD_OP(te_rmsnorm_fwd) - .Inputs({"Input", "Weight"}) - .Outputs({"Output", "InvVariance"}) - .Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd)); - -PD_BUILD_OP(te_rmsnorm_fwd_fp8) - .Inputs({"Input", "Weight", "Scale", "_Amax", "_ScaleInv"}) - .Outputs({"Output", "InvVariance", "Amax", "ScaleInv"}) - .SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}}) - .Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t", - "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd_fp8)); - -PD_BUILD_OP(te_rmsnorm_bwd) - .Inputs({"Dz", "X", "Rsigma", "Gamma"}) - .Outputs({"Dx", "Dgamma"}) - .Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_bwd)); - -PD_BUILD_OP(te_fused_attn_fwd_qkvpacked) - .Inputs({"QKV", "cu_seqlens", paddle::Optional("Bias"), "_O", paddle::Optional("_softmax_aux"), - "_rng_state"}) - .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t", - "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "rng_elts_per_thread: int64_t"}) - .SetInplaceMap({{"_O", "O"}, - {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}, - {"_rng_state", "rng_state"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_qkvpacked)); - -PD_BUILD_OP(te_fused_attn_bwd_qkvpacked) - .Inputs({"QKV", "cu_seqlens", "O", "dO", "softmax_aux", "_dQKV", paddle::Optional("_dBias"), - "rng_state"}) - .Outputs({"dQKV", paddle::Optional("dBias")}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs: int64_t", "max_seqlen: int64_t", - "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "deterministic: bool"}) - .SetInplaceMap({{"_dQKV", "dQKV"}, {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_qkvpacked)); - -PD_BUILD_OP(te_fused_attn_fwd_kvpacked) - .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O", - paddle::Optional("_softmax_aux"), "_rng_state"}) - .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t", - "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", - "is_training: bool", "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "rng_elts_per_thread: int64_t"}) - .SetInplaceMap({{"_O", "O"}, - {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}, - {"_rng_state", "rng_state"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd_kvpacked)); - -PD_BUILD_OP(te_fused_attn_bwd_kvpacked) - .Inputs({"Q", "KV", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dKV", - paddle::Optional("_dBias"), "rng_state"}) - .Outputs({"dQ", "dKV", paddle::Optional("dBias")}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "total_seqs_q: int64_t", - "total_seqs_kv: int64_t", "max_seqlen_q: int64_t", "max_seqlen_kv: int64_t", - "attn_scale: float", "p_dropout: float", "qkv_layout: std::string", - "bias_type: std::string", "attn_mask_type: std::string", "qkv_type: int64_t", - "deterministic: bool"}) - .SetInplaceMap({{"_dQ", "dQ"}, - {"_dKV", "dKV"}, - {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd_kvpacked)); - -PD_BUILD_OP(te_fused_attn_fwd) - .Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", paddle::Optional("Bias"), "_O", - paddle::Optional("_softmax_aux"), "_rng_state"}) - .Outputs({"O", paddle::Optional("softmax_aux"), "rng_state"}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", - "max_seqlen_kv: int64_t", "is_training: bool", "attn_scale: float", "p_dropout: float", - "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", - "qkv_type: int64_t", "rng_elts_per_thread: int64_t"}) - .SetInplaceMap({{"_O", "O"}, - {paddle::Optional("_softmax_aux"), paddle::Optional("softmax_aux")}, - {"_rng_state", "rng_state"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_fwd)); - -PD_BUILD_OP(te_fused_attn_bwd) - .Inputs({"Q", "K", "V", "cu_seqlens_q", "cu_seqlens_kv", "O", "dO", "softmax_aux", "_dQ", "_dK", - "_dV", paddle::Optional("_dBias"), "rng_state"}) - .Outputs({"dQ", "dK", "dV", paddle::Optional("dBias")}) - .Attrs({"b: int64_t", "h: int64_t", "d: int64_t", "max_seqlen_q: int64_t", - "max_seqlen_kv: int64_t", "attn_scale: float", "p_dropout: float", - "qkv_layout: std::string", "bias_type: std::string", "attn_mask_type: std::string", - "qkv_type: int64_t", "deterministic: bool"}) - .SetInplaceMap({{"_dQ", "dQ"}, - {"_dK", "dK"}, - {"_dV", "dV"}, - {paddle::Optional("_dBias"), paddle::Optional("dBias")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_fused_attn_bwd)); - -PD_BUILD_OP(te_scaled_softmax_forward) - .Inputs({"input"}) - .Outputs({"softmax_results"}) - .Attrs({"scale_factor: float"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_softmax_forward)); - -PD_BUILD_OP(te_scaled_softmax_backward) - .Inputs({"out_grad_", "softmax_results"}) - .Outputs({"out_grad"}) - .Attrs({"scale_factor: float"}) - .SetInplaceMap({{"out_grad_", "out_grad"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_softmax_backward)); - -PD_BUILD_OP(te_scaled_masked_softmax_forward) - .Inputs({"input", "mask"}) - .Outputs({"softmax_results"}) - .Attrs({"scale_factor: float"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_masked_softmax_forward)); - -PD_BUILD_OP(te_scaled_masked_softmax_backward) - .Inputs({"out_grad_", "softmax_results"}) - .Outputs({"out_grad"}) - .Attrs({"scale_factor: float"}) - .SetInplaceMap({{"out_grad_", "out_grad"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_scaled_masked_softmax_backward)); - -PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_forward) - .Inputs({"input"}) - .Outputs({"softmax_results"}) - .Attrs({"scale_factor: float"}) - .SetKernelFn( - PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_forward)); - -PD_BUILD_OP(te_scaled_upper_triang_masked_softmax_backward) - .Inputs({"out_grad_", "softmax_results"}) - .Outputs({"out_grad"}) - .Attrs({"scale_factor: float"}) - .SetInplaceMap({{"out_grad_", "out_grad"}}) - .SetKernelFn( - PD_KERNEL(transformer_engine::paddle_ext::te_scaled_upper_triang_masked_softmax_backward)); - -PD_BUILD_OP(amax_and_scale_update_inplace_legacy) - .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask", - paddle::Optional("current_step_id_tensor")}) - .Outputs({"amax_history", "scale", "scale_inv"}) - .SetInplaceMap({{"_amax_history", "amax_history"}, - {"_scale", "scale"}, - {"_scale_inv", "scale_inv"}}) - .Attrs({"update_weight_scale_inv: bool", "fwd_update: bool", "fp8_max: float", "margin: float", - "amax_compute: std::string"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace_legacy)); - -PD_BUILD_OP(amax_and_scale_update_inplace) - .Inputs({"_amax_history", "_scale", "_scale_inv", "non_weight_mask"}) - .Outputs({"amax_history", "scale", "scale_inv"}) - .SetInplaceMap({{"_amax_history", "amax_history"}, - {"_scale", "scale"}, - {"_scale_inv", "scale_inv"}}) - .Attrs({"fp8_dtype: int64_t", "margin: float", "amax_compute: std::string"}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::amax_and_scale_update_inplace)); - -PD_BUILD_OP(update_latest_amax_history_inplace) - .Inputs({"_history", "amax"}) - .Outputs({"history"}) - .SetInplaceMap({{"_history", "history"}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::update_latest_amax_history_inplace)); - -PD_BUILD_OP(mask_to_cu_seqlens) - .Inputs({"mask", "_q_cu_seqlen", paddle::Optional("_kv_cu_seqlen")}) - .Outputs({"q_cu_seqlen", paddle::Optional("kv_cu_seqlen")}) - .Attrs({"q_seqlen: int", "kv_seqlen: int", "need_kv: bool"}) - .SetInplaceMap({{"_q_cu_seqlen", "q_cu_seqlen"}, - {paddle::Optional("_kv_cu_seqlen"), paddle::Optional("kv_cu_seqlen")}}) - .SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::mask_to_cu_seqlens)); diff --git a/transformer_engine/paddle/csrc/extensions.cpp b/transformer_engine/paddle/csrc/extensions.cpp deleted file mode 100644 index 128b7e2856..0000000000 --- a/transformer_engine/paddle/csrc/extensions.cpp +++ /dev/null @@ -1,63 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "common.h" - -namespace transformer_engine { -namespace paddle_ext { - -size_t get_cublasLt_version() { return cublasLtGetVersion(); } - -PYBIND11_MODULE(transformer_engine_paddle, m) { - // Misc - m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); - m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend"); - m.def("get_nvte_qkv_layout", &get_nvte_qkv_layout, "Get qkv layout enum by the string"); - // Data structures - py::enum_(m, "DType", py::module_local()) - .value("kByte", DType::kByte) - .value("kInt32", DType::kInt32) - .value("kFloat32", DType::kFloat32) - .value("kFloat16", DType::kFloat16) - .value("kBFloat16", DType::kBFloat16) - .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); - - py::enum_(m, "NVTE_Bias_Type") - .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) - .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) - .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); - - py::enum_(m, "NVTE_Mask_Type") - .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) - .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) - .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK); - - py::enum_(m, "NVTE_QKV_Layout") - .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) - .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) - .value("NVTE_SBHD_SB2HD", NVTE_QKV_Layout::NVTE_SBHD_SB2HD) - .value("NVTE_SBHD_SBH2D", NVTE_QKV_Layout::NVTE_SBHD_SBH2D) - .value("NVTE_SBHD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD) - .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD) - .value("NVTE_BSH3D", NVTE_QKV_Layout::NVTE_BSH3D) - .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD) - .value("NVTE_BSHD_BSH2D", NVTE_QKV_Layout::NVTE_BSHD_BSH2D) - .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) - .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD) - .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) - .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) - .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) - .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); - - py::enum_(m, "NVTE_Fused_Attn_Backend", py::module_local()) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) - .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) - .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) - .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); -} -} // namespace paddle_ext -} // namespace transformer_engine diff --git a/transformer_engine/paddle/distributed.py b/transformer_engine/paddle/distributed.py deleted file mode 100644 index 75630ed28e..0000000000 --- a/transformer_engine/paddle/distributed.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Methods needed for distributed training.""" - -import os -import warnings -from contextlib import contextmanager -from typing import Any, Optional, Union, Tuple - -import paddle - -import paddle.distributed.fleet.base.topology as tp -from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker -from paddle.distributed.fleet.layers.mpu import mp_ops - -try: - # This feature is not supported as of Paddle 2.6. - from paddle.distributed.fleet.meta_parallel import ( - PipelineParallelMicroStepLocations, - register_global_pipeline_parallel_hook, - ) -except ImportError: - print("Cannot find register_global_pipeline_parallel_hook !") - register_global_pipeline_parallel_hook = None - -from .constants import dist_group_type - -_weight_split_axis = { - "transformer_engine": {"row": 1, "column": 0}, - "paddle": {"row": 0, "column": 1}, -} - - -def get_tp_group_and_world_size( - tp_group: Union[dist_group_type, None], enable_tp: bool = True -) -> Tuple[Union[dist_group_type, None], int]: - """Get TP group and world size using Fleet API""" - if not (paddle.distributed.is_initialized() and enable_tp): - return None, 1 - model_parallel_group = ( - tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group() if tp_group is None else tp_group - ) - world_size = ( - tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size() - if tp_group is None - else tp_group.nranks - ) - """ - When using TP, the NCCL communication needs to be scheduled - before the GEMM for a guaranteed overlap. From the host side - in TE, the comm calls are always launched first, but to ensure - that the GEMM isn't scheduled first, the environment variable - `CUDA_DEVICE_MAX_CONNECTIONS` needs to be set to 1 to force a - single channel. - """ - num_cuda_work_queues = int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) - if num_cuda_work_queues != 1: - warnings.warn( - "To guarantee overlapping TP and SP collectives with the backward" - "GEMMs, set environment variable CUDA_DEVICE_MAX_CONNECTIONS = 1" - ) - - return model_parallel_group, world_size - - -def is_pp_enabled() -> bool: - """Check if pipeline parallel is enabled""" - if not paddle.distributed.is_initialized(): - return False - - return tp._HYBRID_PARALLEL_GROUP.get_pipe_parallel_world_size() > 1 - - -def register_pp_fwd_begin_hook(forward_begin_hook): - """Register the pp hook if register_global_pipeline_parallel_hook exist""" - if register_global_pipeline_parallel_hook is not None: - register_global_pipeline_parallel_hook( - PipelineParallelMicroStepLocations.FORWARD_BEGIN, forward_begin_hook - ) - - -@contextmanager -def track_rng_state(enable: bool, **kwargs) -> None: - """ - Applies get_rng_state_tracker().rng_state() to the context. - If not enabled, it does nothing. - """ - if enable: - with get_rng_state_tracker().rng_state(**kwargs): - yield - else: - yield - - -def set_tensor_dist_attr(tensor: paddle.Tensor, is_parallel: bool, axis: int) -> None: - """Set distributed attributes for the input tensor""" - tensor.is_distributed = is_parallel - if is_parallel: - tensor.split_axis = axis - - -def set_weight_tensor_dist_attr( - tensor: paddle.Tensor, is_parallel: bool, parallel_mode: Optional[str], backend: str -) -> None: - """Set distributed attributes for the weight tensor""" - if not is_parallel or parallel_mode is None: - return - set_tensor_dist_attr(tensor, is_parallel, axis=_weight_split_axis[backend][parallel_mode]) - - -def allreduce( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, - sync_op: bool = True, -) -> Tuple[paddle.Tensor, Any]: - """All-reduce the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if tp_group is None or tp_group.nranks == 1: - return input_ - - # All-reduce. - if sync_op: - output = mp_ops._mp_allreduce( - input_, - group=tp_group, - use_calc_stream=True, - use_model_parallel=True, - ) - return output, None - - wait_handle = paddle.distributed.all_reduce( - input_, - op=paddle.distributed.ReduceOp.SUM, - group=tp_group, - sync_op=False, - ) - - output = input_ - - return output, wait_handle - - -def allgather( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, - sync_op: bool = True, - axis: int = 0, -) -> Tuple[paddle.Tensor, Any]: - """All-gather the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if tp_group is None or tp_group.nranks == 1: - return input_, None - - parallelism = tp_group.nranks - output_shape = input_.shape - output_shape[axis] = output_shape[axis] * parallelism - output = paddle.empty(shape=output_shape, dtype=input_.dtype) - wait_handle = tp_group.process_group.all_gather_into_tensor(output, input_, sync_op) - if sync_op: - wait_handle.wait() - return output, None - return output, wait_handle - - -def reduce_scatter( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, - sync_op: bool = True, -) -> [paddle.Tensor, Any]: - """Reduce-scatter the input tensor across model parallel group.""" - - # Bypass the function if we are using only 1 GPU. - if tp_group is None or tp_group.nranks == 1: - return input_, None - - parallelism = tp_group.nranks - output_shape = input_.shape - assert input_.shape[0] % parallelism == 0, ( - f"Input sequence length {input_.shape[0]} can't be divided " - f"exactly by sequence parallelism {parallelism}" - ) - output_shape[0] = output_shape[0] // parallelism - output = paddle.empty(shape=output_shape, dtype=input_.dtype) - wait_handle = paddle.distributed.stream.reduce_scatter( - output, input_, op=paddle.distributed.ReduceOp.SUM, group=tp_group, sync_op=sync_op - ) - if sync_op: - return output, None - return output, wait_handle - - -def identity( - input_: paddle.Tensor, - tp_group: Optional[dist_group_type] = None, -) -> paddle.Tensor: - """ - Identity when forward. - Allreduce across model parallel group when backward. - """ - output = mp_ops._c_identity(input_, group=tp_group) - - return output - - -def mark_as_sequence_parallel_parameter(parameter: paddle.Tensor): - """ - Set sequence_parallel attribute to input tensor. It is used for registering allreduce - hooks in PaddleNLP sequence parallel training. - """ - setattr(parameter, "sequence_parallel", True) diff --git a/transformer_engine/paddle/fp8.py b/transformer_engine/paddle/fp8.py deleted file mode 100644 index b9b315a150..0000000000 --- a/transformer_engine/paddle/fp8.py +++ /dev/null @@ -1,370 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""FP8 utilities for TransformerEngine""" - -from contextlib import contextmanager -from typing import Tuple, Optional, Dict, Any, Union - -import numpy as np - -import paddle -from transformer_engine import transformer_engine_paddle as tex -from transformer_engine.common.recipe import DelayedScaling, Format - -from .constants import dist_group_type -from .fp8_buffer import FP8MetaFwdBuffer, FP8MetaBwdBuffer, FP8RecomputeBuffer - -__all__ = ["fp8_autocast"] - -# FP8 support -_is_fp8_available = None -_reason_for_no_fp8 = "" - - -def _check_fp8_support() -> Tuple[bool, str]: - """Return if fp8 support is available""" - - # Check GPU arch - arch = paddle.device.cuda.get_device_capability() - if arch >= (9, 0): # hopper and above - return True, "" - if arch < (8, 9): # pre-ada - return False, "Device compute capability 8.9 or higher required for FP8 execution." - - # Special handling for Ada - if tex.get_cublasLt_version() < 120103: - return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." - if not paddle.version.cuda(): - return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." - if tuple(int(v) for v in paddle.version.cuda().split(".")) < (12, 1): - return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." - return True, "" - - -def is_fp8_available() -> Tuple[bool, str]: - """Return if fp8 support is available""" - global _is_fp8_available, _reason_for_no_fp8 - if _is_fp8_available is None: - _is_fp8_available, _reason_for_no_fp8 = _check_fp8_support() - return _is_fp8_available, _reason_for_no_fp8 - - -class FP8State: - """Stores FP8 state""" - - def __init__(self): - self._fp8_enabled = False - self._fp8_calibration = False - self._fp8_recipe = None - self._fp8_distributed_group = None - self._is_first_fp8_module = False - self._fp8_autocast_counter = 0 - self._fp8_autocast_depth = 0 - self._fp8_recompute_enabled = False - self._use_cudagraph = False - self._fp8_fwd_buffer = FP8MetaFwdBuffer() - self._fp8_bwd_buffer = FP8MetaBwdBuffer() - self._fp8_recompute_buffer = FP8RecomputeBuffer() - - def is_fp8_enabled(self) -> bool: - """Is FP8 enabled""" - return self._fp8_enabled - - def is_fp8_calibration(self) -> bool: - """Is FP8 calibration""" - return self._fp8_calibration - - def get_fp8_recipe(self) -> DelayedScaling: - """Return the fp8 recipe""" - return self._fp8_recipe - - @staticmethod - def get_default_fp8_recipe() -> DelayedScaling: - """FP8 recipe with default args.""" - return DelayedScaling() - - def get_autocast_id(self) -> int: - """Returns the number of times of entering the `fp8_autocast` context. - as a unique ID for different training steps.""" - return self._fp8_autocast_counter - - def is_first_fp8_module(self): - """Returns `True` only the first time when called multiple - times from within the same `fp8_autocast` context. - """ - tmp = self._is_first_fp8_module - self._is_first_fp8_module = False - return tmp - - def get_fp8_group(self) -> Union[dist_group_type, None]: - """Return the fp8 group for scale/amax comm""" - return self._fp8_distributed_group - - def get_fp8_fwd_buffer(self) -> FP8MetaFwdBuffer: - """Returns global fp8 forward buffer.""" - return self._fp8_fwd_buffer - - def get_fp8_bwd_buffer(self) -> FP8MetaBwdBuffer: - """Returns global fp8 backward buffer.""" - return self._fp8_bwd_buffer - - def is_fp8_recompute_enabled(self) -> bool: - """Is FP8 recompute enabled""" - return self._fp8_recompute_enabled - - def get_fp8_recompute_buffer(self) -> FP8RecomputeBuffer: - """Returns global fp8 recompute buffer.""" - return self._fp8_recompute_buffer - - def is_cudagraph_enabled(self) -> bool: - """Is CUDAGraph enabled""" - return self._use_cudagraph - - def enable_cudagraph(self): - """Enable CUDA Graphs. Once CUDA Graphs are enabled, they cannot be disabled within the same execution context at current implementation.""" - self._use_cudagraph = True - self._fp8_fwd_buffer.enable_cudagraph() - self._fp8_bwd_buffer.enable_cudagraph() - if self._fp8_recompute_enabled: - raise RuntimeError("Currently, We do not allow recompute with cudagraph") - - def enter( - self, - enabled: bool, - calibrating: bool, - fp8_recipe: Optional[DelayedScaling], - fp8_group: Optional[dist_group_type], - ) -> None: - """Called when entering 'fp8_autocast'""" - self.saved_states = ( - self._fp8_enabled, - self._fp8_calibration, - self._fp8_recipe, - self._fp8_distributed_group, - self._is_first_fp8_module, - ) - - self._fp8_enabled = enabled - self._fp8_calibration = calibrating - self._fp8_recipe = self.get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe - self._fp8_distributed_group = fp8_group - - if self._fp8_autocast_depth == 0: - self._is_first_fp8_module = True - self._fp8_autocast_counter += 1 - self._fp8_autocast_depth += 1 - - def exit(self): - """Called when exiting 'fp8_autocast'""" - # Restore saved states - ( - self._fp8_enabled, - self._fp8_calibration, - self._fp8_recipe, - self._fp8_distributed_group, - self._is_first_fp8_module, - ) = self.saved_states - - self._fp8_autocast_depth -= 1 - - if self._fp8_autocast_depth == 0: - self._fp8_fwd_buffer.finalize() - - -_global_fp8_state = FP8State() - - -def get_global_fp8_state() -> FP8State: - """Get global fp8 state""" - return _global_fp8_state - - -@contextmanager -def fp8_autocast( - enabled: bool = False, - calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, - fp8_group: Optional[dist_group_type] = None, -) -> None: - """ - Context manager for FP8 usage. - - .. code-block:: python - - with fp8_autocast(enabled=True): - out = model(inp) - - .. note:: - - Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors - with shapes where both dimensions are divisible by 16. In terms of the input to the full - Transformer network, this typically requires padding sequence length to be multiple of 16. - - .. note:: - - When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once - inside a single `fp8_autocast` region. This is unsupported behavior because the amax - reduction is handled during the exit of the `fp8_autocast` context. Calling the same - module more than once inside an `fp8_autocast` region overrides the amax tensors - before reduction can occur. - - Parameters - ---------- - enabled: bool, default = `False` - whether or not to enable fp8 - calibrating: bool, default = `False` - calibration mode allows collecting statistics such as amax and scale - data of fp8 tensors even when executing without fp8 enabled. This is - useful for saving an inference ready fp8 checkpoint while training - using a higher precision. - fp8_recipe: recipe.DelayedScaling, default = `None` - recipe used for FP8 training. - fp8_group: paddle.distributed.collective.Group, default = `None` - distributed group over which amaxes for the fp8 tensors - are reduced at the end of each training step. - """ - try: - _global_fp8_state.enter(enabled, calibrating, fp8_recipe, fp8_group) - - if enabled: - fp8_available, reason_for_no_fp8 = is_fp8_available() - assert fp8_available, reason_for_no_fp8 - yield - finally: - _global_fp8_state.exit() - - -def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return tex.DType.kFloat8E4M3 - return tex.DType.kFloat8E5M2 - - -def amax_and_scale_update( - fp8_meta: Dict[str, Any], - fwd_update: bool, - update_weight_scale_inv: bool = True, - current_step_id_tensor: Optional[paddle.Tensor] = None, - use_cudagraph: bool = False, -) -> None: - """Updates fp8 amaxes/scales for fwd | bwd.""" - amax_compute = fp8_meta["recipe"].amax_compute_algo - sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo - fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd" - fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd" - - if not callable(amax_compute) and sf_compute is None: - non_weight_mask = fp8_meta[fp8_meta_tensor_key].non_weight_mask - - if use_cudagraph: - tex.amax_and_scale_update_inplace_legacy( - _amax_history=fp8_meta[fp8_meta_tensor_key].amax_history, - _scale=fp8_meta[fp8_meta_tensor_key].scale, - _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv, - non_weight_mask=non_weight_mask, - current_step_id_tensor=current_step_id_tensor, - update_weight_scale_inv=update_weight_scale_inv, - fwd_update=fwd_update, - fp8_max=fp8_meta[fp8_max_key], - margin=float(fp8_meta["recipe"].margin), - amax_compute=amax_compute, - ) - else: - if update_weight_scale_inv: - # we pass nullptr into kernel when we need to update_weight_scale_inv - non_weight_mask = paddle.empty([0]) - tex.amax_and_scale_update_inplace( - _amax_history=fp8_meta[fp8_meta_tensor_key].amax_history, - _scale=fp8_meta[fp8_meta_tensor_key].scale, - _scale_inv=fp8_meta[fp8_meta_tensor_key].scale_inv, - non_weight_mask=non_weight_mask, - fp8_dtype=int(get_fp8_te_dtype(fp8_meta["recipe"], fwd_update)), - margin=float(fp8_meta["recipe"].margin), - amax_compute=amax_compute, - ) - - else: - raise ValueError( - "We only support the fp8 recipe with 'max' or 'most_recent' " - "amax_compute_algo and default scaling_factor_compute_algo at this " - "moment." - ) - - -class FP8TensorMeta: - """Holds FP8 scaling and amax history for FP8 layers""" - - def __init__(self, is_forward: bool): - self.scale = paddle.Tensor() - self.scale_inv = paddle.Tensor() - self.amax_history = paddle.Tensor() - self.non_weight_mask = paddle.Tensor() - self.is_initialized = False - self.is_forward = is_forward - - def get_non_weight_mask(self, num_gemms: int): - """Needed for calculation of scale inverses to - preserve scale_inv when caching FP8 weights""" - if self.is_forward: - # [True, False, True]: -> [input, weight, output] - return paddle.to_tensor([True, False, True] * num_gemms) - # [True, True]: -> [grad_output, grad_input] - return paddle.to_tensor([True, True] * num_gemms) - - def prepare(self, num_gemms: int, amax_history_len: int) -> None: - """Prepare scales and amax tensors. It is called during fprop in each iteration. - If the meta tensors are not initialized yet, initialization is performed. If already - initialized, resize the meta tensors if amax_history_len has changed.""" - - if self.is_initialized: - # Handle changed amax history size. - curr_len = self.amax_history.shape[0] - num_fp8_tensors = self.amax_history.shape[1] - if amax_history_len < curr_len: - self.amax_history = self.amax_history[:amax_history_len] - elif amax_history_len > curr_len: - extra_rows = amax_history_len - curr_len - self.amax_history = paddle.concat( - [ - self.amax_history, - paddle.zeros((extra_rows, num_fp8_tensors), dtype="float32"), - ], - axis=0, - ) - return - - # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and - # 2 (grad_output and grad_input) for bwd - num_fp8_tensors = num_gemms * 3 if self.is_forward else num_gemms * 2 - - self.scale = paddle.ones(num_fp8_tensors, dtype="float32") - self.scale_inv = paddle.ones(num_fp8_tensors, dtype="float32") - self.amax_history = paddle.zeros([amax_history_len, num_fp8_tensors], dtype="float32") - self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms) - - self.is_initialized = True - - def to_numpy(self): - """Convert FP8 meta tensors to numpy.""" - assert self.is_initialized, "FP8TensorMeta is not initialized yet." - return { - "scale": self.scale.numpy(), - "scale_inv": self.scale_inv.numpy(), - "amax_history": self.amax_history.numpy(), - } - - def from_numpy(self, data: Dict[str, np.array]): - """Set FP8 meta tensors from numpy""" - self.scale = paddle.to_tensor(data["scale"]) - self.scale_inv = paddle.to_tensor(data["scale_inv"]) - self.amax_history = paddle.to_tensor(data["amax_history"]) - - num_fp8_tensors = self.scale.shape[0] - num_gemms = num_fp8_tensors // 3 if self.is_forward else num_fp8_tensors // 2 - self.non_weight_mask = self.get_non_weight_mask(num_gemms=num_gemms) - - self.is_initialized = True diff --git a/transformer_engine/paddle/fp8_buffer.py b/transformer_engine/paddle/fp8_buffer.py deleted file mode 100644 index a880ca8107..0000000000 --- a/transformer_engine/paddle/fp8_buffer.py +++ /dev/null @@ -1,350 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""FP8 meta buffer for FP8 amax reduction""" - -from abc import ABC, abstractmethod -from collections import deque -from functools import partial -import os -from typing import Dict, Any, List, Union - -import numpy as np -import paddle -from transformer_engine import transformer_engine_paddle as tex - -from .constants import dist_group_type, RecomputeFunctionNames - - -class FP8MetaBufferBase(ABC): - """ - A global buffer that holds FP8 meta for reduction across trainers. - """ - - def __init__(self): - self._global_amax = {} - self._buffer_delete_key = None - self._amax_reduce_wait_func = None - self._dp_amax_reduce_interval = None - self._contiguous_amax = None - self._use_cudagraph = False - self._dp_amax_reduce_idx = 0 - - @staticmethod - @abstractmethod - def _get_meta_tensor_key(): - """Returns scaling key in `fp8_meta`.""" - - @staticmethod - @abstractmethod - def _get_buffer_position_key(): - """Returns module position key in `fp8_meta`.""" - - @staticmethod - @abstractmethod - def _get_autocast_key(): - """Returns autocast id key in `fp8_meta`.""" - - def _get_amax_buffer_key(self, fp8_meta: Dict[str, Any]) -> str: - """Return a key in `_global_amax` for the AMAX storage.""" - return f"AMAX_{fp8_meta[self._get_autocast_key()]}" - - def _execute_deletion(self) -> None: - """Delete the key from global amax buffer.""" - if self._buffer_delete_key is not None and self._buffer_delete_key in self._global_amax: - del self._global_amax[self._buffer_delete_key] - - def _wait_handle_and_split( - self, - contiguous_amax: paddle.Tensor, - chunk_sizes: List[int], - amax_buffer_key: str, - wait_handle: Union[bool, None], - ) -> None: - """Wait for amax reduction to finish and then copy reduced amax to buffer""" - if wait_handle is not None: - wait_handle.wait() - if self._use_cudagraph: - splited_list = list(contiguous_amax.split(chunk_sizes)) - for amax, split in zip(self._global_amax[amax_buffer_key], splited_list): - amax.copy_(split, False) - else: - self._global_amax[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes)) - - def _global_amax_reduction( - self, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - ) -> None: - """Concatenate, reduce, and split amaxes in the global buffer.""" - - def _reduce_tensor_across_group_op_max(tensor, group, sync_op): - if paddle.distributed.is_initialized(): - wait_handle = paddle.distributed.all_reduce( - tensor, - op=paddle.distributed.ReduceOp.MAX, - group=group, - sync_op=sync_op, - ) - return wait_handle - return None - - amax_buffer_key = self._get_amax_buffer_key(fp8_meta) - # Key already deleted. - if amax_buffer_key not in self._global_amax: - return None - - # Reduce AMAX in DP-domain at an interval. - if self._dp_amax_reduce_interval is None: - self._dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1")) - - tp_amax_reduce = False - reduce_group = -1 # Set value that will raise error if not set. `None` is a valid group. - if self._dp_amax_reduce_idx == 0: - reduce_group = fp8_meta["fp8_group"] - else: - tp_amax_reduce = True - self._dp_amax_reduce_idx = (self._dp_amax_reduce_idx + 1) % self._dp_amax_reduce_interval - - if tp_amax_reduce: - if tp_size > 1: - reduce_group = tp_group - else: - return None - - chunk_sizes = [x.shape[0] for x in self._global_amax[amax_buffer_key]] - if self._use_cudagraph: - # we need to ensure the _contiguous_amax is address-stable under cudagraph - if self._contiguous_amax is None: - self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key]) - else: - self._contiguous_amax.copy_( - paddle.concat(self._global_amax[amax_buffer_key]), False - ) - else: - self._contiguous_amax = paddle.concat(self._global_amax[amax_buffer_key]) - - wait_handle = _reduce_tensor_across_group_op_max( - self._contiguous_amax, - reduce_group, - not fp8_meta["async_amax_reduction"], - ) - - if wait_handle is not None and self._use_cudagraph: - # we need to ensure record/wait does not cross the boundary of the graph - wait_handle.wait() - wait_handle = None - - return partial( - self._wait_handle_and_split, - self._contiguous_amax, - chunk_sizes, - amax_buffer_key, - wait_handle, - ) - - def add_amax(self, fp8_meta: Dict[str, Any]) -> None: - """Append `amax_history` to global buffer.""" - buffer_key = self._get_amax_buffer_key(fp8_meta) - fp8_meta_tensor_key = self._get_meta_tensor_key() - buffer_position_key = self._get_buffer_position_key() - - if buffer_key not in self._global_amax: - self._global_amax[buffer_key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - else: - self._global_amax[buffer_key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - - if buffer_position_key not in fp8_meta: - fp8_meta[buffer_position_key] = len(self._global_amax[buffer_key]) - 1 - - # Catch incorrect fp8_autocast usage. - assert fp8_meta[buffer_position_key] == len(self._global_amax[buffer_key]) - 1, ( - "Same module is being invoked more than once inside an `fp8_autocast` " - "region when using FP8 with amax reduction. This behavior is currently " - "unsupported. For more details and correct usage, please see " - "https://github.com/NVIDIA/TransformerEngine/pull/93." - ) - - def copy_amax_from_buffer(self, fp8_meta: Dict[str, Any]) -> None: - """Populate current amax with the correct location from buffer.""" - fp8_meta_tensor_key = self._get_meta_tensor_key() - buffer_position_key = self._get_buffer_position_key() - if buffer_position_key not in fp8_meta: - return - - amax_buffer_key = self._get_amax_buffer_key(fp8_meta) - assert amax_buffer_key in self._global_amax, "TE internal error." - - # Copy amax to amax_history[0] - tex.update_latest_amax_history_inplace( - _history=fp8_meta[fp8_meta_tensor_key].amax_history, - amax=self._global_amax[amax_buffer_key][fp8_meta[buffer_position_key]], - ) - - def set_for_deletion(self, fp8_meta: Dict[str, Any]) -> None: - """Delete this amax key from global buffer during autocast end.""" - if self._get_autocast_key() not in fp8_meta: - return - self._buffer_delete_key = self._get_amax_buffer_key(fp8_meta) - - def get_amax_reduce_handle(self) -> Union[bool, None]: - """Return AMAX reduction wait handle.""" - return self._amax_reduce_handle - - def wait(self) -> None: - """Wait for reduced amax to be available in buffer.""" - if self._amax_reduce_wait_func is not None: - self._amax_reduce_wait_func() # pylint: disable=not-callable - self._amax_reduce_wait_func = None - - def to_numpy(self) -> Dict[str, List[np.array]]: - """Convert to numpy arrays""" - out = {} - for k, v in self._global_amax.items(): - out[k] = [tensor.numpy() for tensor in v] - return out - - def from_numpy(self, buffer: Dict[str, np.array]) -> None: - """Set buffer values from numpy arrays""" - for k, v in buffer.items(): - self._global_amax[k] = [paddle.to_tensor(arr) for arr in v] - - def enable_cudagraph(self): - """Enable CUDA Graphs.""" - self._use_cudagraph = True - - -class FP8MetaFwdBuffer(FP8MetaBufferBase): - """FP8Meta Buffer for forward""" - - @staticmethod - def _get_meta_tensor_key() -> str: - """Returns scaling key in `fp8_meta`.""" - return "scaling_fwd" - - @staticmethod - def _get_buffer_position_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "global_fp8_buffer_pos_fwd" - - @staticmethod - def _get_autocast_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "autocast_id_fwd" - - def set_for_amax_reduction( - self, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - ) -> None: - """Sets up the function to call during autocast exit.""" - self._amax_global_reduce_func = partial( - self._global_amax_reduction, - fp8_meta, - tp_group, - tp_size, - ) - - def finalize(self) -> None: - """ - Called at FP8 autocast end. - Performs AMAX reduction and delete unused buffer entries. - """ - if hasattr(self, "_amax_global_reduce_func") and callable(self._amax_global_reduce_func): - self._amax_reduce_wait_func = self._amax_global_reduce_func() - self._execute_deletion() - - -class FP8MetaBwdBuffer(FP8MetaBufferBase): - """FP8Meta Buffer for backward""" - - @staticmethod - def _get_meta_tensor_key() -> str: - """Returns scaling key in `fp8_meta`.""" - return "scaling_bwd" - - @staticmethod - def _get_buffer_position_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "global_fp8_buffer_pos_bwd" - - @staticmethod - def _get_autocast_key() -> str: - """Returns module position key in `fp8_meta`.""" - return "autocast_id_bwd" - - def finalize( - self, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - ) -> None: - """ - Called at FP8 autocast end in backward. - Performs AMAX reduction and delete unused buffer entries. - """ - self._amax_reduce_wait_func = self._global_amax_reduction( - fp8_meta, tp_group, tp_size - ) # _wait_handle_and_split - self._execute_deletion() - - -class FP8RecomputeBuffer: - """Buffer used to hold FP8 meta tensors for recompute""" - - def __init__(self): - self._global_amax = [] - - @staticmethod - def get_buffer_position_key(): - """Returns the key (in fp8_meta) for recompute buffer position""" - return "recompute_buffer_pos" - - def stash_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None: - """Stash the scaling factors and amaxes for recompute""" - buffer_position_key = self.get_buffer_position_key() - - to_copy = [ - fp8_meta["scaling_fwd"].amax_history.clone(), - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ] - - if buffer_position_key in fp8_meta: - self._global_amax[fp8_meta[buffer_position_key]].append(to_copy) - else: - self._global_amax.append(deque()) - self._global_amax[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(self._global_amax) - 1 - - def retrieve_fp8_meta_tensors(self, fp8_meta: Dict[str, Any]) -> None: - """Switch to the previously saved scaling factors and amaxes""" - # Store updated amaxes and scales from phase 1 post forward. - fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history - fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale - fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv - - # Retrieve stashed amaxes and scales from phase 1 pre forward. - buffer_position_key = self.get_buffer_position_key() - stashed_fp8_meta = self._global_amax[fp8_meta[buffer_position_key]].popleft() - - # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0] - fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1] - fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2] - - @staticmethod - def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: - """Restore latest scaling factors and amaxes after recompute forward run.""" - assert "updated_amax_history_fwd" in fp8_meta, ( - "Recompute internal error." - " If you are not using recompute, please check if" - " the forward function is called from one of these functions: " - f"{RecomputeFunctionNames}. If so, consider change the function name " - "or set NVTE_DISABLE_RECOMPUTE=1." - ) - fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"] - fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"] - fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"] diff --git a/transformer_engine/paddle/layer/__init__.py b/transformer_engine/paddle/layer/__init__.py deleted file mode 100644 index 58eb6a7c56..0000000000 --- a/transformer_engine/paddle/layer/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Layer level Paddle APIs""" - -from .attention import DotProductAttention, MultiHeadAttention, RotaryPositionEmbedding -from .layernorm import LayerNorm -from .layernorm_linear import LayerNormLinear -from .layernorm_mlp import LayerNormMLP -from .linear import Linear -from .softmax import FusedScaleMaskSoftmax -from .transformer import TransformerLayer diff --git a/transformer_engine/paddle/layer/attention.py b/transformer_engine/paddle/layer/attention.py deleted file mode 100644 index 3ff5a42ff5..0000000000 --- a/transformer_engine/paddle/layer/attention.py +++ /dev/null @@ -1,1161 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Attntion API""" - -import math -import os -import warnings -from typing import Optional, Tuple, Union - -import paddle -import paddle.nn.functional as F - -try: - from paddle.incubate.nn.functional import fused_rotary_position_embedding -except ImportError: - fused_rotary_position_embedding = None -from transformer_engine import transformer_engine_paddle as tex - -from .layernorm_linear import LayerNormLinear -from .linear import Linear -from .softmax import FusedScaleMaskSoftmax -from ..constants import ( - AttnTypes, - TE_DType, - AttnBiasType, - AttnMaskType, - FusedAttnBackend, - dist_group_type, -) -from ..cpp_extensions import ( - fused_attn_fwd_qkvpacked, - fused_attn_bwd_qkvpacked, - fused_attn_fwd_kvpacked, - fused_attn_bwd_kvpacked, - fused_attn_fwd, - fused_attn_bwd, - mask_to_cu_seqlens, -) -from ..distributed import get_tp_group_and_world_size, track_rng_state -from ..utils import attention_mask_func, divide -from ..recompute import recompute - -__all__ = ["DotProductAttention", "MultiHeadAttention", "RotaryPositionEmbedding"] - - -def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: - """ - Used to repeat the key and value states for GQA. - The hidden states go from (batch, seqlen, num_gqa_groups, head_size) - to (batch, seqlen, num_heads, head_size) - """ - batch, seqlen, num_gqa_groups, head_size = hidden_states.shape - if n_rep == 1: - return hidden_states - - hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) - return hidden_states.reshape([batch, seqlen, num_gqa_groups * n_rep, head_size]) - - -class RotaryPositionEmbedding(paddle.nn.Layer): - """ - Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864. - """ - - def __init__( - self, - dim: int, - max_position_embeddings: int, - ): - """ - Parameters - ---------- - dim: int - rotary embedding dimension - max_position_embeddings: int - max_position_embeddings before position interpolation - """ - super().__init__() - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.inv_freq = 1.0 / ( - 10000 ** (paddle.cast(paddle.arange(0, dim, 2), dtype="float32") / self.dim) - ) - self._set_cos_sin_cache(seq_len=max_position_embeddings) - - def _set_cos_sin_cache(self, seq_len): - self.max_seq_len_cached = seq_len - # [seq_len] - t = paddle.arange(seq_len, dtype="float32") - # [seq_len, dim/2] - freqs = paddle.einsum("i,j->ij", t, self.inv_freq) - # [seq_len, dim] - emb = paddle.concat([freqs, freqs], axis=-1) - # [1, seqlen, 1, dim] - self.cos_cached = emb.cos()[None, :, None, :] - self.sin_cached = emb.sin()[None, :, None, :] - - def forward(self, max_seq_len: int): - """ - Create rotary position embedding frequencies - - Parameters - ---------- - max_seq_len: int - sequence length of a sample - """ - cos = self.cos_cached[:, :, :max_seq_len, ...] - sin = self.sin_cached[:, :, :max_seq_len, ...] - return (cos, sin) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return paddle.concat([-x2, x1], axis=-1) # shape is the same as x - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None): - """Applies rotary positional embedding to the input.""" - - if position_ids is None: - # Note: Only for LlamaForCausalLMPipe model pretraining - cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] - sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] - else: - cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] - sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class FusedAttnFuncPackedQKV(paddle.autograd.PyLayer): - """Function for FusedAttention with packed QKV input""" - - @staticmethod - def forward( - ctx, - qkv, - cu_seqlens, - attn_bias, - max_seqlen, - attn_scale, - qkv_dtype, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - is_training, - deterministic, - fused_attention_backend, - ): - """Forward function for FusedAttention with packed QKV input""" - out, softmax_aux, rng_state = fused_attn_fwd_qkvpacked( - qkv, - cu_seqlens, - is_training, - max_seqlen, - qkv_dtype, - fused_attention_backend, - attn_bias, - attn_scale, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - ) - - ctx.save_for_backward(qkv, out, cu_seqlens, rng_state, softmax_aux) - ctx.max_seqlen = max_seqlen - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.set_zero = set_zero - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.deterministic = deterministic - ctx.fused_attention_backend = fused_attention_backend - - return out - - @staticmethod - def backward(ctx, d_out): - """Backward function for FusedAttention with packed QKV input""" - qkv, out, cu_seqlens, rng_state, softmax_aux = ctx.saved_tensor() - dqkv, *rest = fused_attn_bwd_qkvpacked( - qkv, - cu_seqlens, - rng_state, - out, - d_out, - softmax_aux, - ctx.fused_attention_backend, - ctx.max_seqlen, - ctx.qkv_dtype, - ctx.attn_scale, - ctx.dropout_p, - ctx.set_zero, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.deterministic, - ) - - # if no_bias, return dqkv - if ctx.attn_bias_type == "no_bias": - return (dqkv, None) - # else, return (dqkv, dbias) - return (dqkv, None, rest[0]) - - -class FusedAttnFuncPackedKV(paddle.autograd.PyLayer): - """Function for FusedAttention with packed KV input""" - - @staticmethod - def forward( - ctx, - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - attn_bias, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - qkv_dtype, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - is_training, - deterministic, - fused_attention_backend, - ): - """Forward function for FusedAttention with packed KV input""" - out, softmax_aux, rng_state = fused_attn_fwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - is_training, - max_seqlen_q, - max_seqlen_kv, - qkv_dtype, - fused_attention_backend, - attn_bias, - attn_scale, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - ) - - ctx.save_for_backward(q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.set_zero = set_zero - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.deterministic = deterministic - ctx.fused_attention_backend = fused_attention_backend - - return out - - @staticmethod - def backward(ctx, d_out): - """Backward function for FusedAttention with packed KV input""" - q, kv, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor() - dq, dkv, *rest = fused_attn_bwd_kvpacked( - q, - kv, - cu_seqlens_q, - cu_seqlens_kv, - rng_state, - out, - d_out, - softmax_aux, - ctx.fused_attention_backend, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ctx.qkv_dtype, - ctx.attn_scale, - ctx.dropout_p, - ctx.set_zero, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.deterministic, - ) - - # if no_bias, return dq, dkv - if ctx.attn_bias_type == "no_bias": - return (dq, dkv, None, None) - # else, return (dq, dkv, dbias) - return (dq, dkv, None, None, rest[0]) - - -class FusedAttnFunc(paddle.autograd.PyLayer): - """Function for FusedAttention with separate Q, K, V tensors""" - - @staticmethod - def forward( - ctx, - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - attn_bias, - max_seqlen_q, - max_seqlen_kv, - attn_scale, - qkv_dtype, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - is_training, - deterministic, - fused_attention_backend, - ): - """Forward function for FusedAttention with separate Q, K, V tensors""" - out, softmax_aux, rng_state = fused_attn_fwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - is_training, - max_seqlen_q, - max_seqlen_kv, - qkv_dtype, - fused_attention_backend, - attn_bias, - attn_scale, - dropout_p, - set_zero, - qkv_layout, - attn_bias_type, - attn_mask_type, - ) - - ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux) - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.set_zero = set_zero - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.deterministic = deterministic - ctx.fused_attention_backend = fused_attention_backend - - return out - - @staticmethod - def backward(ctx, d_out): - """Backward function for FusedAttention with separate Q, K, V tensors""" - q, k, v, out, cu_seqlens_q, cu_seqlens_kv, rng_state, softmax_aux = ctx.saved_tensor() - dq, dk, dv, *rest = fused_attn_bwd( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_kv, - rng_state, - out, - d_out, - softmax_aux, - ctx.fused_attention_backend, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ctx.qkv_dtype, - ctx.attn_scale, - ctx.dropout_p, - ctx.set_zero, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.deterministic, - ) - # if no_bias, return dq, dk, dv - if ctx.attn_bias_type == "no_bias": - return (dq, dk, dv, None, None) - # else, return (dq, dk, dv, dbias) - return (dq, dk, dv, None, None, rest[0]) - - -class DotProductAttention(paddle.nn.Layer): - """ - Allows the model to jointly attend to information from different - representation subspaces as described in the paper: - `Attention Is All You Need `_. - - .. note:: - - Argument :attr:`attention_mask` will be ignored in the `forward` call when - :attr:`attn_mask_type` is set to `"causal"`. - - .. warning:: - - Fused attention backward uses a non-deterministic algorithm when workspace - optimization is not enabled. To use a deterministic algorithm, set the - environment variable :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` - - Parameters - ---------- - num_attention_heads: int - number of attention heads in the transformer layer. - kv_channels: int - number of channels in the key and value tensors. - num_gqa_groups : Optional[int] = None - number of GQA groups in the transformer layer. - Grouped Query Attention is described in - `this paper `_. - This only affects the keys and values, not the queries. - GQA-1 is equivalent to Multi-Query Attention - (`MQA `_), while GQA-H - is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - attention_dropout: float, default = 0.1 - dropout probability for the dropout op during multi-head attention. - attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` - type of attention mask passed into softmax operation. - attention_type: {'self', 'cross'}, default = `self` - type of attention operation. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for attention operation. - """ - - def __init__( - self, - num_attention_heads: int, - kv_channels: int, - num_gqa_groups: Optional[int] = None, - attention_dropout: float = 0.1, - attn_mask_type: str = "causal", - attention_type: str = "self", - tp_size: int = 1, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.attn_mask_type = attn_mask_type - self.attention_dropout = attention_dropout - self.attention_type = attention_type - self.qkv_layout = "bshd_bshd_bshd" - self.hidden_size_per_attention_head = kv_channels - self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - self.tp_size = tp_size - self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups - self.num_gqa_groups_per_partition = int(self.num_gqa_groups // tp_size) - self.num_queries_per_key_value = num_attention_heads // self.num_gqa_groups - - self.backend = backend - - self.use_fused_attention = bool(int(os.getenv("NVTE_FUSED_ATTN", "1"))) - - self.deterministic = not bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) - - # To use the workspace optimization path for determinism, please - # set NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT=1 for cuDNN >=8.9.5 and <9.0.0, - # and set NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 for cuDNN >=9.0.0. - cudnn_version = paddle.get_cudnn_version() - if 8905 <= cudnn_version < 9000: - if self.deterministic: - # workspace optimization path is deterministic - os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" - - # CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT - # - unset: enables workspace optimization when required workspace is <= 256MB - # or when bias gradient needs to be computed - # - n: enables workspace optimization when required workspace is <= n bytes - # - -1: enables workspace optimization always - # - 0: disables workspace optimization always - if "NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT" in os.environ: - if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "0": - os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "0" - if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1": - os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1" - - if not self.use_fused_attention and backend == "transformer_engine": - warnings.warn("Fused attention is not enabled, falling back to Paddle backend") - self.backend = "paddle" - - if self.backend != "transformer_engine": - self.scale_mask_softmax = FusedScaleMaskSoftmax( - attn_mask_type, attention_mask_func, backend=self.backend - ) - - def forward( - self, - query_layer: paddle.Tensor, - key_layer: paddle.Tensor, - value_layer: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - ) -> paddle.Tensor: - """ - Dot Product Attention Layer. - - .. note:: - - Argument :attr:`attention_mask` will be ignored when :attr:`attn_mask_type` - is set to `"causal"`. - - - Parameters - ---------- - query_layer : paddle.Tensor - Query tensor. - key_layer : paddle.Tensor - Key tensor. - value_layer : paddle.Tensor - Value tensor. - attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out softmax input when not using attention. - core_attention_bias_type: str, default = `no_bias` - only support no_bias type currently, {`no_bias`} - core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, default = `True` - Whether to use the fast path to set output tensors to 0 or not. - """ - - backend = self.backend - - assert key_layer.shape == value_layer.shape, "Keys and values must have the same shape!" - assert ( - key_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" - - if backend == "transformer_engine": - max_s_q = query_layer.shape[1] - max_s_kv = max_s_q if self.attention_type == "self" else key_layer.shape[1] - self.fused_attention_backend = tex.get_fused_attn_backend( - TE_DType[query_layer.dtype], - TE_DType[query_layer.dtype], - tex.get_nvte_qkv_layout(self.qkv_layout), - AttnBiasType[core_attention_bias_type], - AttnMaskType[self.attn_mask_type], - self.attention_dropout, - query_layer.shape[-2], - key_layer.shape[-2] if key_layer is not None else query_layer.shape[-2], - max_s_q, - max_s_kv, - query_layer.shape[-1], - ) - - is_backend_avail = self.fused_attention_backend in [ - FusedAttnBackend["F16_max512_seqlen"], - FusedAttnBackend["F16_arbitrary_seqlen"], - ] - if is_backend_avail and self.use_fused_attention: - return self._te_forward( - query_layer, - key_layer, - value_layer, - attention_mask, - core_attention_bias_type, - core_attention_bias, - set_zero, - ) - warnings.warn("Fused attention is not enabled, falling back to Paddle backend") - backend = "paddle" - self.scale_mask_softmax = FusedScaleMaskSoftmax( - self.attn_mask_type, attention_mask_func, backend=backend - ) - if backend == "paddle": - if core_attention_bias_type != "no_bias": - warnings.warn( - "Paddle backend dot product attention does not support bias yet. " - "Bias will be ignored." - ) - return self._pd_forward(query_layer, key_layer, value_layer, attention_mask) - raise AttributeError(f"Backend {backend} is not supported.") - - def _te_forward( - self, - query_layer: paddle.Tensor, - key_layer: paddle.Tensor, - value_layer: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - ) -> paddle.Tensor: - - if self.attention_type == "self": - # self attention - q: [b, s, h, d] kv: None - assert ( - len(query_layer.shape) == 4 - and len(key_layer.shape) == 4 - and len(value_layer.shape) == 4 - ), "q,k,v shape must be [b, s, h, d] for dot product self attention" - max_seqlen = query_layer.shape[1] - if self.attn_mask_type == "causal" or attention_mask is None: - cu_seqlens = paddle.arange( - 0, - (query_layer.shape[0] + 1) * query_layer.shape[1], - step=query_layer.shape[1], - dtype="int32", - ) - else: - cu_seqlens, _ = mask_to_cu_seqlens(attention_mask, need_kv=False) - qkv_dtype = TE_DType[query_layer.dtype] - - output = FusedAttnFunc.apply( - query_layer, - key_layer, - value_layer, - cu_seqlens, - cu_seqlens, - core_attention_bias, - max_seqlen, - max_seqlen, - 1.0 / self.norm_factor, - qkv_dtype, - self.attention_dropout if self.training else 0.0, - set_zero, - self.qkv_layout, - core_attention_bias_type, - self.attn_mask_type, - self.training, - self.deterministic, - self.fused_attention_backend, - ) - elif self.attention_type == "cross": - # cross attention - q: [b, s_q, h, d] k,v: [b, s_kv, h, d] - assert ( - len(query_layer.shape) == 4 - and len(key_layer.shape) == 4 - and len(value_layer.shape) == 4 - ), ( - "query shape must be [b, s_q, h, d] and key shape must be [b, s_kv, h, d]" - "for dot product cross attention" - ) - assert attention_mask is not None, "attention_mask must be provided for cross attention" - max_seqlen_q = query_layer.shape[1] - max_seqlen_kv = key_layer.shape[1] - cu_seqlens_q, cu_seqlens_kv = mask_to_cu_seqlens(attention_mask, need_kv=True) - qkv_dtype = TE_DType[query_layer.dtype] - output = FusedAttnFunc.apply( - query_layer, - key_layer, - value_layer, - cu_seqlens_q, - cu_seqlens_kv, - core_attention_bias, - max_seqlen_q, - max_seqlen_kv, - 1.0 / self.norm_factor, - qkv_dtype, - self.attention_dropout if self.training else 0.0, - set_zero, - self.qkv_layout, - core_attention_bias_type, - self.attn_mask_type, - self.training, - self.deterministic, - self.fused_attention_backend, - ) - else: - raise ValueError("attention_type must be one of ['self', 'cross']") - return output - - def _pd_forward( - self, - query_layer: paddle.Tensor, - key_layer: paddle.Tensor, - value_layer: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - ) -> paddle.Tensor: - - q = query_layer - k = repeat_kv(key_layer, self.num_queries_per_key_value) - v = repeat_kv(value_layer, self.num_queries_per_key_value) - - q = paddle.transpose(x=q, perm=[0, 2, 1, 3]) - k = paddle.transpose(x=k, perm=[0, 2, 1, 3]) - v = paddle.transpose(x=v, perm=[0, 2, 1, 3]) - - product = paddle.matmul(x=q * (1.0 / self.norm_factor), y=k, transpose_y=True) - attention_probs = self.scale_mask_softmax(product, attention_mask, scale=None) - - if self.attention_dropout > 0: - attention_probs = F.dropout( - attention_probs, - self.attention_dropout, - training=self.training, - ) - - out = paddle.matmul(attention_probs, v) - out = paddle.transpose(out, perm=[0, 2, 1, 3]) # [b, s, h, d] - # out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) - return out - - -class MultiHeadAttention(paddle.nn.Layer): - """ - Multi-head Attention (MHA), including Query, - Key, Value and Output projection. - - Parameters - ---------- - hidden_size: int - hidden size of the model. - num_attention_heads: int - number of attention heads. - attention_dropout: float, default = 0.1 - dropout probability for the dropout op during multi-head attention. - layernorm_epsilon: float, default = 1e-5 - epsilon to use in the layer norm operations. - weight_attr: Union[paddle.ParamAttr, None], default = `None` - paddle.ParamAttr object for the weight parameter. - bias_attr: Union[paddle.ParamAttr, None, bool], default = `None` - paddle.ParamAttr object for the bias parameter. - attn_mask_type: {'causal', 'padding', 'no_mask'}, default = `causal` - type of attention mask passed into softmax operation. - params_dtype: Optional[paddle.dtype], default = `None` - data type for the weights and biases. - return_layernorm_output: bool, default = `False` - whether to return the output of the layernorm operation. - input_layernorm: bool, default = `False` - whether to apply layernorm to the input. - attention_type: {'self', 'cross'}, default = `self` - type of attention operation. - normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' - type of normalization applied. - zero_centered_gamma: bool, default = `False` - whether to zero initialize the gamma of the layernorm operation. - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for attention operation. If set to 'paddle', a framework - only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, QKV and FC1 layers are used as Column Parallel - whereas PROJ and FC2 is used as Row Parallel as described - `here `_. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - num_gqa_groups : int, default = `None` - number of GQA groups in the transformer layer. - Grouped Query Attention is described in - `this paper `_. - This only affects the keys and values, not the querys. - GQA-1 is equivalent to Multi-Query Attention - (`MQA `_), while GQA-H - is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - rng_state_name : str, default = `local_seed` - Controls the rng state used for dropout on attention probs. The - specified rng should be set different seeds for different TP ranks. - It will be ignored if `set_parallel_mode` is False. The specified - name should be registered through - `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker() - .add(rng_state_name, seed)`. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - - """ - - def __init__( - self, - hidden_size: int, - num_attention_heads: int, - attention_dropout: float = 0.1, - layernorm_epsilon: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - max_sequence_length: Optional[int] = None, - attn_mask_type: str = "causal", - params_dtype: Optional[paddle.dtype] = None, - return_layernorm_output: bool = False, - input_layernorm: bool = False, - attention_type: str = "self", - normalization: str = "LayerNorm", - zero_centered_gamma: bool = False, - set_parallel_mode: bool = False, - sequence_parallel: bool = False, - tp_group: Optional[dist_group_type] = None, - num_gqa_groups: Optional[int] = None, - fuse_wgrad_accumulation: bool = False, - rng_state_name: str = "local_seed", - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.input_layernorm = input_layernorm - self.attention_type = attention_type - self.return_layernorm_output = return_layernorm_output - self.params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype - self.max_sequence_length = max_sequence_length - self.weight_attr = weight_attr - self.bias_attr = bias_attr - self.attn_mask_type = attn_mask_type - - assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" - - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=set_parallel_mode - ) - self.tensor_parallel = self.tp_size > 1 - self.sequence_parallel = self.tensor_parallel and sequence_parallel - self.hidden_size_per_attention_head = hidden_size // num_attention_heads - self.num_attention_heads = num_attention_heads - self.set_parallel_mode = set_parallel_mode - self.rng_state_name = rng_state_name - self.backend = backend - - self.num_attention_heads_per_partition = divide(self.num_attention_heads, self.tp_size) - self.num_gqa_groups = num_attention_heads if num_gqa_groups is None else num_gqa_groups - assert ( - self.num_attention_heads % self.num_gqa_groups == 0 - ), "The number of attention heads must be divisible by the number of GQA groups!" - assert ( - self.num_gqa_groups % self.tp_size == 0 - ), "The number of GQA groups must be divisible by tensor parallel size!" - self.num_gqa_groups_per_partition = int(self.num_gqa_groups // self.tp_size) - self.hidden_size_kv = int(hidden_size * self.num_gqa_groups // self.num_attention_heads) - qkv_parallel_mode = "column" if set_parallel_mode else None - - if self.attention_type == "self": - if self.input_layernorm: - self.layernorm_qkv = LayerNormLinear( - hidden_size, - hidden_size + 2 * self.hidden_size_kv, - eps=layernorm_epsilon, - weight_attr=self.weight_attr, - bias_attr=self.bias_attr, - return_layernorm_output=return_layernorm_output, - normalization=normalization, - zero_centered_gamma=zero_centered_gamma, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - else: - self.qkv = Linear( - hidden_size, - hidden_size + 2 * self.hidden_size_kv, - self.weight_attr, - self.bias_attr, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - - else: # cross attention - if self.input_layernorm: - self.layernorm_query = LayerNormLinear( - hidden_size, - hidden_size, - eps=layernorm_epsilon, - weight_attr=self.weight_attr, - bias_attr=self.bias_attr, - return_layernorm_output=return_layernorm_output, - normalization=normalization, - zero_centered_gamma=zero_centered_gamma, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - else: - self.query_layer = Linear( - hidden_size, - hidden_size, - self.weight_attr, - self.bias_attr, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - self.key_value = Linear( - hidden_size, - 2 * self.hidden_size_kv, - self.weight_attr, - self.bias_attr, - parallel_mode=qkv_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - - # Attention. - self.core_attention = DotProductAttention( - self.num_attention_heads, - self.hidden_size_per_attention_head, - self.num_gqa_groups, - attention_dropout, - attn_mask_type=attn_mask_type, - attention_type=self.attention_type, - tp_size=self.tp_size, - backend=self.backend, - ) - - # Linear - self.proj = Linear( - hidden_size, - hidden_size, - self.weight_attr, - self.bias_attr, - parallel_mode="row" if set_parallel_mode else None, - sequence_parallel=self.sequence_parallel, - tp_group=self.tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=self.backend, - ) - - def forward( - self, - hidden_states: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - encoder_output: Optional[paddle.Tensor] = None, - rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - recompute_core_attention: bool = False, - is_first_microbatch: Optional[bool] = None, - ) -> Tuple[Union[paddle.Tensor, None], ...]: - """ - MultiHeadAttention Layer. - - Parameters - ---------- - hidden_states : paddle.Tensor - Input tensor. - attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out softmax input when not using attention. - encoder_output : Optional[paddle.Tensor], default = `None` - Output of the encoder layer. - rotary_pos_emb: Tuple[paddle.Tensor, paddle.Tensor], default = `None` - Embeddings for query and key tensors for applying rotary position - embedding. By default no input embedding is applied. - core_attention_bias_type: str, default = `no_bias` - only support no_bias type currently, {`no_bias`} - core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, default = `True` - Whether to use the fast path to set output tensors to 0 or not. - recompute_core_attention: bool, default = `False` - If true, forward activations for core attention are recomputed - during the backward pass in order to save memory that would - otherwise be occupied to store the forward activations until - backprop. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - - if self.attn_mask_type != "causal" and attention_mask is not None: - assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor" - - input_dim = len(hidden_states.shape) - if input_dim == 2: - # hidden_states: [b * s_q, hidden_size] - # need to get max_seq_len from attention_mask - assert self.max_sequence_length is not None, "max_sequence_length must be provided" - max_seq_len = self.max_sequence_length - elif input_dim == 3: - # hidden_states: [b, s_q, hidden_size] - max_seq_len = hidden_states.shape[1] - else: - raise ValueError(f"hidden_states should have 2 or 3 dimensions, got {input_dim}.") - - layernorm_output = None - if self.attention_type == "self": - if self.input_layernorm: - layernorm_qkv_outputs = self.layernorm_qkv( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - if self.return_layernorm_output: - mixed_qkv_layer, layernorm_output = layernorm_qkv_outputs - else: - mixed_qkv_layer = layernorm_qkv_outputs - else: - mixed_qkv_layer = self.qkv( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - - num_queries_per_key_value = ( - self.num_attention_heads_per_partition // self.num_gqa_groups_per_partition - ) - - # [b, s_q, hidden_size+2*hidden_size_kv] --> [b, s_q, (h/ng+2), ng, d] - mixed_qkv_layer = mixed_qkv_layer.reshape( - shape=[ - -1, - max_seq_len, - (num_queries_per_key_value + 2), - self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - ] - ) - - # [b, s_q, (h/ng+2), ng, d] - # --> [b, s_q, (h/ng), ng, d] [b, s_q, 1, ng, d] [b, s_q, 1, ng, d] - query_layer, key_layer, value_layer = paddle.split( - mixed_qkv_layer, - num_or_sections=(num_queries_per_key_value, 1, 1), - axis=2, - ) - - # query: -> [b, s, h, d] - # key, value: -> [b, s, ng, d] - query_layer, key_layer, value_layer = ( - x.reshape(shape=[x.shape[0], x.shape[1], -1, self.hidden_size_per_attention_head]) - for x in (query_layer, key_layer, value_layer) - ) - - else: # cross attention - mixed_kv_layer = self.key_value( - encoder_output, - is_first_microbatch=is_first_microbatch, - ) - # [b, s_kv, 2 * hidden_size] --> [b, s_kv, 2, num_heads, head_size] - mixed_kv_layer = mixed_kv_layer.reshape( - shape=[ - 0, - 0, - 2 * self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - ] - ) - - # [b, s_kv, 2 * ng, head_size] - # --> 2 [b, s_kv, ng, head_size] - key_layer, value_layer = paddle.split( - mixed_kv_layer, - num_or_sections=2, - axis=2, - ) - - if self.input_layernorm: - layernorm_query_outputs = self.layernorm_query( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - if self.return_layernorm_output: - query_layer, layernorm_output = layernorm_query_outputs - else: - query_layer = layernorm_query_outputs - else: - query_layer = self.query_layer( - hidden_states, - is_first_microbatch=is_first_microbatch, - ) - - # [b, s, hidden_size] --> [b, s, h, d] - query_layer = query_layer.reshape( - shape=[ - -1, - max_seq_len, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ] - ) - - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - if fused_rotary_position_embedding is None: - query_layer, key_layer = apply_rotary_pos_emb( - query_layer, key_layer, q_pos_emb, k_pos_emb - ) - else: - query_layer, key_layer, _ = fused_rotary_position_embedding( - query_layer, - key_layer, - v=None, - sin=k_pos_emb, - cos=q_pos_emb, - position_ids=None, - use_neox_rotary_style=False, - ) - - with track_rng_state(enable=self.tensor_parallel, name=self.rng_state_name): - if recompute_core_attention: - context_layer = recompute( - self.core_attention, - query_layer, - key_layer, - value_layer, - attention_mask, - core_attention_bias_type, - core_attention_bias, - set_zero, - use_reentrant=False, - ) - else: - context_layer = self.core_attention( - query_layer=query_layer, - key_layer=key_layer, - value_layer=value_layer, - attention_mask=attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - ) - - if input_dim == 3: - context_layer = paddle.reshape( - context_layer, [-1, max_seq_len, context_layer.shape[2] * context_layer.shape[3]] - ) - else: # input_dim == 2 - context_layer = paddle.reshape( - context_layer, [-1, context_layer.shape[2] * context_layer.shape[3]] - ) - - # Output. [b, s, hidden] - attention_output = self.proj(context_layer, is_first_microbatch=is_first_microbatch) - - if self.input_layernorm and self.return_layernorm_output: - return attention_output, layernorm_output - return attention_output diff --git a/transformer_engine/paddle/layer/base.py b/transformer_engine/paddle/layer/base.py deleted file mode 100644 index adbd1ce269..0000000000 --- a/transformer_engine/paddle/layer/base.py +++ /dev/null @@ -1,571 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Base modules and utilities for TransformerEngine Paddle API""" - -from abc import ABC, abstractmethod -from contextlib import contextmanager -import os -import pickle -from typing import Generator, Dict, Tuple, Union, Any, List, Optional - -import numpy as np - -import paddle - -try: - from paddle.base import core - from paddle.base.framework import _dygraph_tracer -except ImportError: - from paddle.fluid import core - from paddle.fluid.framework import _dygraph_tracer - -from ..constants import FP8FwdTensors, FP8BwdTensors, dist_group_type -from ..cpp_extensions import cast_transpose, cast_transpose_bgrad, cast_to_fp8, transpose -from ..fp8 import ( - FP8State, - FP8TensorMeta, - amax_and_scale_update, - get_global_fp8_state, - get_fp8_te_dtype, -) -from ..distributed import allgather, register_pp_fwd_begin_hook, is_pp_enabled -from ..profile import nvtx_range -from ..recompute import is_in_recompute_phase -from ..fp8_buffer import FP8RecomputeBuffer - -_2X_ACC_FPROP = False -_2X_ACC_DGRAD = True -_2X_ACC_WGRAD = True -_cublas_workspace = None - - -def get_cublas_workspace_size_bytes() -> None: - """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if paddle.device.cuda.get_device_capability()[0] >= 9: - return 33_554_432 - return 4_194_304 - - -def get_workspace() -> paddle.Tensor: - """Returns workspace for cublas.""" - global _cublas_workspace - if _cublas_workspace is None: - _cublas_workspace = paddle.empty( - [get_cublas_workspace_size_bytes()], - dtype="uint8", - ) - return _cublas_workspace - - -class TransformerEngineBaseLayer(paddle.nn.Layer, ABC): - """Base TE Layer.""" - - def __init__(self) -> None: - super().__init__() - assert "gpu" in paddle.device.get_device(), "TransformerEngine needs CUDA." - self.fp8_initialized = False - self.fp8_enabled = False - self.fp8_calibration = False - self.fp8_meta = {} - self.fp8_meta["fp8_checkpoint"] = False - self.fp8_meta["fp8_group"] = None - self.fp8_meta["recipe"] = FP8State.get_default_fp8_recipe() - self.fp8_meta["scaling_fwd"] = FP8TensorMeta(is_forward=True) - self.fp8_meta["scaling_bwd"] = FP8TensorMeta(is_forward=False) - self.tp_group = None - self.tp_size = 1 - self.sequence_parallel = False - self.fp8_meta["autocast_id_fwd_stack"] = [] - self.fp8_meta["async_amax_reduction"] = bool( - int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0")) - ) - # weights that stored in fp16 would be cast into fp8 every first microstep - self.fp8_weights = [] - self.fp8_weight_cache = {} - self.registered_pp_start_callback = False - self.current_step_id = None - - def set_activation_dtype(self, inp: paddle.Tensor) -> None: - """Get activation data type for AMP.""" - tracer = _dygraph_tracer() - if tracer and tracer._amp_level != core.AmpLevel.O0: - # Set activation_dtype to the Paddle AMP dtype if under 'paddle.amp.auto_cast' context - if tracer._amp_dtype == "float32": - self.activation_dtype = paddle.float32 - elif tracer._amp_dtype == "bfloat16": - self.activation_dtype = paddle.bfloat16 - elif tracer._amp_dtype == "float16": - self.activation_dtype = paddle.float16 - else: - raise RuntimeError(f"AMP format {tracer._amp_dtype} is not supported.") - else: - # If not under paddle.amp.auto_cast, set activation_dtype to the input dtype. - # Also, make sure the parameters match the input dtype. - - # Skip the check if activation_dtype is already set and if activation_dtype - # matches input dtype. If they do not match, e.g, when user switch from AMP - # training to normal training, activation_dtype will still be updated. - if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype: - return - - dtype = inp.dtype - - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) - - self.activation_dtype = dtype - - # This routine is shared across FP8 and FP8_calibration paths so should not actually - # assume FP8 execution. - def fp8_init(self, num_gemms: int = 1) -> None: - """Initialize fp8 related metadata and tensors during fprop.""" - global_fp8_state = get_global_fp8_state() - self.fp8_enabled = global_fp8_state.is_fp8_enabled() - self.fp8_calibration = global_fp8_state.is_fp8_calibration() - self.fp8_meta["fp8_checkpoint"] = self.fp8_enabled or self.fp8_calibration - - if self.fp8_enabled or self.fp8_calibration: - # FP8 init has already been run and recipe is the same, don't do anything. - if ( - self.fp8_initialized - and global_fp8_state.get_fp8_recipe() == self.fp8_meta["recipe"] - ): - return - - # Set FP8, recipe, and other FP8 metadata - self.fp8_meta["recipe"] = global_fp8_state.get_fp8_recipe() - self.fp8_meta["fp8_group"] = global_fp8_state.get_fp8_group() - - # Set FP8_MAX per tensor according to recipe - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd - - # Allocate scales and amaxes - amax_history_len = self.fp8_meta["recipe"].amax_history_len - self.fp8_meta["scaling_fwd"].prepare(num_gemms, amax_history_len) - self.fp8_meta["scaling_bwd"].prepare(num_gemms, amax_history_len) - self.fp8_initialized = True - else: - # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False - return - - def set_fp8_weights(self) -> None: - """Initializes FP8 weights for the module""" - if not self.fp8_enabled: - return - - for i, weight in enumerate(self.fp8_weights, start=1): - weight_cast_key = f"weight{i}_fp8" - weight_transpose_key = f"weight{i}_t_fp8" - - if ( - weight_cast_key in self.fp8_weight_cache - and self.fp8_weight_cache[weight_cast_key].shape == weight.shape - ): - return - - self.fp8_weight_cache[weight_cast_key] = paddle.empty( - shape=weight.shape, - dtype=paddle.uint8, - ) - - self.fp8_weight_cache[weight_transpose_key] = paddle.empty( - shape=[weight.shape[1], weight.shape[0]], - dtype=paddle.uint8, - ) - - def _get_fp8_state(self) -> paddle.Tensor: - """Dump FP8 state to paddle.Tensor.""" - state = None - if self.fp8_meta["fp8_checkpoint"]: - state = {} - state["scaling_fwd"] = self.fp8_meta["scaling_fwd"].to_numpy() - state["scaling_bwd"] = self.fp8_meta["scaling_bwd"].to_numpy() - state["global_fp8_fwd_buffer"] = get_global_fp8_state().get_fp8_fwd_buffer().to_numpy() - state["global_fp8_bwd_buffer"] = get_global_fp8_state().get_fp8_bwd_buffer().to_numpy() - # Store other pickelable values. - extra = {} - for k, v in self.fp8_meta.items(): - if isinstance(v, (bool, int, float, str)): - extra[k] = v - state["extra_fp8_variables"] = extra - - state_serialized = pickle.dumps(state) - state_tensor = paddle.to_tensor(np.frombuffer(state_serialized, dtype=np.uint8)) - - return state_tensor - - @paddle.no_grad() - def state_dict( - self, - destination=None, - include_sublayers=True, - structured_name_prefix="", - use_hook=True, - ): - """Save FP8 State when checkpointing.""" - st = super().state_dict( - destination=destination, - include_sublayers=include_sublayers, - structured_name_prefix=structured_name_prefix, - use_hook=use_hook, - ) - st["fp8_state"] = self._get_fp8_state() - return st - - def _set_fp8_state(self, state: paddle.Tensor) -> None: - """Load previous state.""" - if state is None: - return - - state = pickle.loads(state.numpy().tobytes()) - if state is None: - return - - # Load fp8 meta tensors. - self.fp8_meta["scaling_fwd"].from_numpy(state["scaling_fwd"]) - self.fp8_meta["scaling_bwd"].from_numpy(state["scaling_bwd"]) - - # Restore global FP8 buffer states. - global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() - global_fp8_bwd_buffer = get_global_fp8_state().get_fp8_bwd_buffer() - global_fp8_fwd_buffer.from_numpy(state["global_fp8_fwd_buffer"]) - global_fp8_bwd_buffer.from_numpy(state["global_fp8_bwd_buffer"]) - - # Load extra items. - self.fp8_meta.update(state["extra_fp8_variables"]) - self.fp8_meta["recipe"].amax_history_len = self.fp8_meta["scaling_fwd"].amax_history.shape[ - 0 - ] - recompute_buffer_pos_key = FP8RecomputeBuffer.get_buffer_position_key() - if recompute_buffer_pos_key in self.fp8_meta: - del self.fp8_meta[recompute_buffer_pos_key] - - @paddle.no_grad() - def set_state_dict(self, state_dict, use_structured_name=True): - """Restore FP8 State from checkpoint.""" - fp8_state_tensor = state_dict.pop("fp8_state") - self._set_fp8_state(fp8_state_tensor) - - return super().set_state_dict(state_dict) - - @contextmanager - def prepare_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Union[bool, None], - num_gemms: int = 1, - ) -> Generator[paddle.Tensor, None, None]: - """Checks and prep for FWD. - The context manager is needed because there isn't a way for a module to know - if it's the last FP8 module in the forward autocast. It is useful - to setup the forward aggregated amax reduction for every module - just in case. The autocast exit will pick up the most recent one. - """ - - if self.fp8_enabled and is_in_recompute_phase(): - global_recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer() - global_recompute_buffer.retrieve_fp8_meta_tensors(self.fp8_meta) - else: - self.set_activation_dtype(inp) - self.fp8_init(num_gemms=num_gemms) - - # Create persistent tensors for fp8 weights and their transposes - # only when fp8 weight caching is used. - if is_first_microbatch is not None: - self.set_fp8_weights() - - if self.fp8_enabled and self.sequence_parallel: - assert self.fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across tensor parallel group is " - "necessary when using sequence parallelism with FP8." - ) - - update_weight_scale_inv = is_first_microbatch is None or is_first_microbatch - - # Previous iteration was grad_enabled - if self.fp8_meta.get("update_amax_and_scale_fwd", False): - global_fp8_fwd_buffer = get_global_fp8_state().get_fp8_fwd_buffer() - global_fp8_fwd_buffer.wait() - # Register PP forward begin hook when CUDAGraph is enabled. - # NOTE(tizheng): register_pp_fwd_begin_hook prevents layer parameters from being freed - # when the layer object is deleted. Need to find a better way. - if get_global_fp8_state().is_cudagraph_enabled() and self.current_step_id is None: - self.current_step_id = paddle.to_tensor( - [1], dtype=paddle.int32, place=paddle.CPUPlace() - ) - - def current_step_id_callback( - step_id=None, **kwargs - ): # pylint: disable=unused-argument - self.current_step_id.copy_( - paddle.to_tensor( - [step_id], dtype=paddle.int32, place=paddle.CPUPlace() - ), - True, - ) - - if is_pp_enabled(): - register_pp_fwd_begin_hook(current_step_id_callback) - - if self.fp8_meta["recipe"].reduce_amax: - global_fp8_fwd_buffer.copy_amax_from_buffer(self.fp8_meta) - amax_and_scale_update( - self.fp8_meta, - fwd_update=True, - update_weight_scale_inv=update_weight_scale_inv, - current_step_id_tensor=self.current_step_id, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - global_fp8_fwd_buffer.set_for_deletion(self.fp8_meta) - else: - amax_and_scale_update( - self.fp8_meta, - fwd_update=True, - update_weight_scale_inv=update_weight_scale_inv, - current_step_id_tensor=self.current_step_id, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - - if self.fp8_enabled and self.training: - # Setup for amax reduction - if self.fp8_meta["recipe"].reduce_amax: - global_fp8_state = get_global_fp8_state() - self.fp8_meta["first_module"] = global_fp8_state.is_first_fp8_module() - self.fp8_meta["autocast_id_fwd"] = global_fp8_state.get_autocast_id() - self.fp8_meta["autocast_id_fwd_stack"].append(self.fp8_meta["autocast_id_fwd"]) - self.fp8_meta["update_amax_and_scale_fwd"] = True - else: - self.fp8_meta["update_amax_and_scale_fwd"] = False - - # Activation recomputation is used and this is the first forward phase. - if ( - self.fp8_enabled - and self.training - and get_global_fp8_state().is_fp8_recompute_enabled() - ): - global_recompute_buffer = get_global_fp8_state().get_fp8_recompute_buffer() - global_recompute_buffer.stash_fp8_meta_tensors(self.fp8_meta) - - with nvtx_range(self.__class__.__name__ + " forward"): - yield inp - - if self.fp8_enabled and is_in_recompute_phase(): - FP8RecomputeBuffer.restore_fp8_meta_tensors(self.fp8_meta) - return - - if self.fp8_enabled and self.training and self.fp8_meta["recipe"].reduce_amax: - global_fp8_state = get_global_fp8_state() - global_fp8_fwd_buffer = global_fp8_state.get_fp8_fwd_buffer() - global_fp8_fwd_buffer.add_amax(self.fp8_meta) - global_fp8_fwd_buffer.set_for_amax_reduction( - self.fp8_meta, - self.tp_group, - self.tp_size, - ) - - @staticmethod - @contextmanager - def prepare_backward( - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - tp_group: dist_group_type, - tp_size: int, - name: str = "", - ) -> Generator[None, None, None]: - """Checks and prep for BWD.""" - if fp8_enabled: - global_fp8_state = get_global_fp8_state() - global_fp8_bwd_buffer = global_fp8_state.get_fp8_bwd_buffer() - global_fp8_bwd_buffer.wait() - - if fp8_meta["recipe"].reduce_amax: - global_fp8_bwd_buffer.copy_amax_from_buffer(fp8_meta) - amax_and_scale_update( - fp8_meta, - fwd_update=False, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - global_fp8_bwd_buffer.set_for_deletion(fp8_meta) - - # Get new backward key. - fp8_meta["autocast_id_bwd"] = fp8_meta["autocast_id_fwd_stack"].pop(0) - else: - amax_and_scale_update( - fp8_meta, - fwd_update=False, - use_cudagraph=get_global_fp8_state().is_cudagraph_enabled(), - ) - - with nvtx_range(name + " backward"): - yield - - if fp8_enabled and fp8_meta["recipe"].reduce_amax: - global_fp8_bwd_buffer.add_amax(fp8_meta) - if fp8_meta["first_module"]: - global_fp8_bwd_buffer.finalize(fp8_meta, tp_group, tp_size) - - @staticmethod - def grad_output_preprocess( - ctx, grad_output: paddle.Tensor, row_parallel_mode: bool - ) -> Tuple[Union[paddle.Tensor, None], ...]: - """Utility function for backward. - Returns tuple in order (all optional/None based on training precion/recipe): - R1: gathered `grad_output` in higher precision. - R2: gathered `grad_output` in FP8. - R3: R2 transposed. - R4: bias gradient on R1. - """ - grad_output_mat = grad_output.reshape((-1, grad_output.shape[-1])) - gather_grad_output = row_parallel_mode and ctx.sequence_parallel - - # No-FP8 case: bgrad is fused with wgrad for this case. - if not ctx.fp8_enabled: - if gather_grad_output: - grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group) - return grad_output_mat, None, None, None - - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - if gather_grad_output: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather - if ctx.use_bias: - bgrad = grad_output_mat.sum(axis=0) - else: - bgrad = None - grad_output_c = cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - grad_output_c, _ = allgather(grad_output_c, ctx.tp_group) - grad_output_t = transpose(grad_output_c, fp8_dtype_backward) - - return grad_output_mat, grad_output_c, grad_output_t, bgrad - - # FP8 case with gather and non-FP8 wgrad - grad_output_mat, _ = allgather(grad_output_mat, ctx.tp_group) - - # FP8 case without gather: cast, transpose, bgrad fused - if ctx.use_bias: - bgrad, grad_output_c, grad_output_t = cast_transpose_bgrad( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - grad_output_c, grad_output_t = cast_transpose( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - grad_output_t = None - grad_output_c = cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - bgrad = None - return grad_output_mat, grad_output_c, grad_output_t, bgrad - - @abstractmethod - def forward(self): - """Needs override.""" - - def get_fp8_weights_scratchpad_and_cast( - self, - is_first_microbatch: Union[bool, None], - ) -> List[Optional[paddle.Tensor]]: - """ - Fetch the fp8 weight tensor placeholders if they exist (when - `is_first_microbatch` is not `None`) - """ - if not self.fp8_enabled or is_first_microbatch is None: - return [None, None] * len(self.fp8_weights) - - out_list = [] - for i, _ in enumerate(self.fp8_weights, start=1): - weight_cast_key = f"weight{i}_fp8" - weight_transpose_key = f"weight{i}_t_fp8" - - assert ( - weight_cast_key in self.fp8_weight_cache - ), "TE internal error: fp8 weight buffer is not found" - - weight_fp8 = self.fp8_weight_cache[weight_cast_key] - weight_t_fp8 = self.fp8_weight_cache[weight_transpose_key] - - # Disable fp8 weight cache - # is_first_microbatch is None -> we cast the weights into fp8 every micro step - # Enalbe fp8 weight cache - # is_first_microbatch == true -> we cast the weights into fp8 every micro step - - out_list.extend([weight_fp8, weight_t_fp8]) - - # is cudagraph is enabled we cast the weight before the pp pipe - # we only register the callback once - if get_global_fp8_state().is_cudagraph_enabled() and ( - not self.registered_pp_start_callback and is_pp_enabled() - ): - - fp8_dtype_forward = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True) - - def cast_callback(step_id=None, **kwargs): # pylint: disable=unused-argument - update_fp8_weights = step_id == 0 - - for i, weight in enumerate(self.fp8_weights, start=1): - weight_cast_key = f"weight{i}_fp8" - weight_transpose_key = f"weight{i}_t_fp8" - - assert ( - weight_cast_key in self.fp8_weight_cache - ), "TE internal error: fp8 weight buffer is not found" - - weight_fp8 = self.fp8_weight_cache[weight_cast_key] - weight_t_fp8 = self.fp8_weight_cache[weight_transpose_key] - - if paddle.is_grad_enabled(): - if update_fp8_weights: - cast_transpose( - weight, - self.fp8_meta["scaling_fwd"], - ( - FP8FwdTensors.GEMM1_WEIGHT - if i == 1 - else FP8FwdTensors.GEMM2_WEIGHT - ), - fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, - ) - else: - if update_fp8_weights: - cast_to_fp8( - weight, - self.fp8_meta["scaling_fwd"], - ( - FP8FwdTensors.GEMM1_WEIGHT - if i == 1 - else FP8FwdTensors.GEMM2_WEIGHT - ), - fp8_dtype_forward, - out=weight_fp8, - ) - - cast_callback(0 if is_first_microbatch else 1) - register_pp_fwd_begin_hook(cast_callback) - self.registered_pp_start_callback = True - return out_list diff --git a/transformer_engine/paddle/layer/layernorm.py b/transformer_engine/paddle/layer/layernorm.py deleted file mode 100644 index 208e39ea03..0000000000 --- a/transformer_engine/paddle/layer/layernorm.py +++ /dev/null @@ -1,197 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Linear API""" - -import os -from typing import Union, Tuple - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from ..constants import TE_DType -from ..cpp_extensions import layernorm_fwd, layernorm_bwd -from ..distributed import mark_as_sequence_parallel_parameter - -__all__ = ["LayerNorm"] - - -class _LayerNorm(paddle.autograd.PyLayer): - """TE Non-FP8 LayerNorm""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - ln_weight: paddle.Tensor, - ln_bias: paddle.Tensor, - eps: float, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - zero_centered_gamma: bool, - ) -> paddle.Tensor: - # Make sure input dimensions are compatible - in_features = ln_weight.shape[0] - assert inp.shape[-1] == in_features, "LayerNorm not possible" - inputmat = inp.reshape((-1, in_features)) - - ln_out, mu, rsigma = layernorm_fwd( - inputmat, - ln_weight, - ln_bias, - eps, - TE_DType[inp.dtype], - fwd_ln_sm_margin, - zero_centered_gamma, - ) - - ctx.save_for_backward(inputmat, ln_weight, mu, rsigma) - ctx.inp_shape = inp.shape - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.requires_dx = not inp.stop_gradient - ctx.requires_dw = not ln_weight.stop_gradient - ctx.requires_dbias = not ln_bias.stop_gradient - return ln_out.reshape(inp.shape) - - @staticmethod - def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - inputmat, ln_weight, mu, rsigma = ctx.saved_tensor() - d_ln_out = grad_output.reshape(inputmat.shape) - dxmat, dgamma, dbeta = layernorm_bwd( - d_ln_out, inputmat, mu, rsigma, ln_weight, ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma - ) - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None, - dgamma if ctx.requires_dw else None, - dbeta if ctx.requires_dbias else None, - ) - - -class LayerNorm(paddle.nn.Layer): - r""" - Applies Layer Normalization over a mini-batch of inputs as described in - the paper `Layer Normalization `__ - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * \gamma + \beta - - :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of - size :attr:`hidden_size` - - Parameters - ---------- - hidden_size : int - size of each input sample. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for softmax operation. - - Parallelism parameters - ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - """ - - def __init__( - self, - hidden_size: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - zero_centered_gamma: bool = False, - sequence_parallel: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.eps = eps - self.zero_centered_gamma = zero_centered_gamma - self.sequence_parallel = sequence_parallel - self.backend = backend - self._dtype = self._helper.get_default_dtype() - - self._weight_attr = weight_attr - if not self._weight_attr: - self._weight_attr = paddle.ParamAttr( - initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) - ) - - self._bias_attr = bias_attr - if self._bias_attr is False: - self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0), trainable=False) - - self.weight = self.create_parameter( - shape=[hidden_size], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - - self.bias = self.create_parameter( - shape=[hidden_size], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True, - ) - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.weight) - mark_as_sequence_parallel_parameter(self.bias) - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. - self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor: - """LayerNorm FWD""" - return _LayerNorm.apply( - inp, - self.weight, - self.bias, - self.eps, - self.fwd_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - ) - - def _pd_forward( - self, - inp: paddle.Tensor, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale." - ) - - return F.layer_norm( - x=inp, - normalized_shape=inp.shape[-1], - weight=self.weight, - bias=self.bias, - epsilon=self.eps, - ) - - def forward(self, *args, **kwargs): - """forward""" - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/layernorm_linear.py b/transformer_engine/paddle/layer/layernorm_linear.py deleted file mode 100644 index c39ad29957..0000000000 --- a/transformer_engine/paddle/layer/layernorm_linear.py +++ /dev/null @@ -1,721 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""LayerNormLinear API""" - -import warnings -import os -from typing import Union, Tuple, Dict, Any, Optional - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from ..cpp_extensions import ( - cast_to_fp8, - cast_from_fp8, - layernorm_fwd, - layernorm_fwd_fp8, - layernorm_bwd, - rmsnorm_fwd_fp8, - rmsnorm_fwd, - rmsnorm_bwd, -) - -from .base import TransformerEngineBaseLayer -from .linear import _linear_fwd, _linear_bwd -from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type -from ..distributed import ( - allreduce, - get_tp_group_and_world_size, - identity, - track_rng_state, - set_tensor_dist_attr, - set_weight_tensor_dist_attr, - mark_as_sequence_parallel_parameter, -) -from ..fp8 import get_fp8_te_dtype -from ..utils import ( - assert_dim_for_fp8_forward_exec, - cast_if_needed, - cast_if_needed_inplace, - divide, - save_for_backward_allow_none, - saved_tensor_allow_none, -) - -__all__ = ["LayerNormLinear"] - - -def _apply_normalization_fwd( - normalization: str, - inputmat: paddle.Tensor, - norm_weight: paddle.Tensor, - norm_bias: Union[paddle.Tensor, None], - out_fp8_index: FP8FwdTensors, - eps: float, - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - return_norm_output: bool, - fwd_norm_sm_margin: int, - zero_centered_gamma: bool, -): - """Performs LayerNorm + FP8_Cast for FP8 path. LayerNorm only for BF16 path""" - assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" - if normalization == "RMSNorm": - assert norm_bias is None, "RMSNorm does not support bias!" - norm_weight = cast_if_needed_inplace(norm_weight, activation_dtype) - if norm_bias is not None: - norm_bias = cast_if_needed_inplace(norm_bias, activation_dtype) - - norm_kwargs = { - "inp": inputmat, - "weight": norm_weight, - "eps": eps, - "otype": TE_DType[activation_dtype], - "sm_margin": fwd_norm_sm_margin, - "zero_centered_gamma": zero_centered_gamma, - } - - fwd_normalization_funcs = { - ("LayerNorm", True, True): layernorm_fwd, - ("LayerNorm", True, False): layernorm_fwd_fp8, - ("LayerNorm", False, True): layernorm_fwd, - ("LayerNorm", False, False): layernorm_fwd, - ("RMSNorm", True, True): rmsnorm_fwd, - ("RMSNorm", True, False): rmsnorm_fwd_fp8, - ("RMSNorm", False, True): rmsnorm_fwd, - ("RMSNorm", False, False): rmsnorm_fwd, - } - - if normalization == "LayerNorm": - norm_kwargs["bias"] = norm_bias - norm_fwd_func = fwd_normalization_funcs[(normalization, fp8_enabled, return_norm_output)] - - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if not return_norm_output: - fp8_kwargs = { - "fp8_meta_tensor": fp8_meta["scaling_fwd"], - "fp8_tensor": out_fp8_index, - "otype": fp8_dtype_forward, - } - norm_kwargs.update(fp8_kwargs) - - out_tuple = norm_fwd_func(**norm_kwargs) - - if normalization == "LayerNorm": - norm_out_return, mu, rsigma = out_tuple - else: # RMSNorm - norm_out_return, rsigma = out_tuple - mu = None - - if fp8_enabled and return_norm_output: - norm_out = cast_to_fp8( - norm_out_return, - fp8_meta["scaling_fwd"], - out_fp8_index, - fp8_dtype_forward, - ) - else: - norm_out = norm_out_return - - return ( - norm_out_return, - norm_out, - mu, - rsigma, - ) - - -def _apply_normalization_bwd( - normalization: str, - inputmat: paddle.Tensor, - dgrad: paddle.Tensor, - norm_weight: paddle.Tensor, - mu: Union[paddle.Tensor, None], - rsigma: paddle.Tensor, - grad_norm_out_return: paddle.Tensor, - return_norm_output: bool, - bwd_norm_sm_margin: int, - zero_centered_gamma: bool, -): - assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" - if normalization == "RMSNorm": - assert mu is None, "RMSNorm does not support bias!" - # LayerNorm gradient - d_norm_out = dgrad.reshape(inputmat.shape) - # Residual gradient - if return_norm_output: - d_norm_out = d_norm_out + grad_norm_out_return.reshape(d_norm_out.shape) - - norm_bwd_func = layernorm_bwd if normalization == "LayerNorm" else rmsnorm_bwd - norm_bwd_kwargs = { - "dz": d_norm_out, - "x": inputmat, - "rsigma": rsigma, - "gamma": norm_weight, - "sm_margin": bwd_norm_sm_margin, - "zero_centered_gamma": zero_centered_gamma, - } - if normalization == "LayerNorm": - norm_bwd_kwargs["mu"] = mu - - out_tuple = norm_bwd_func(**norm_bwd_kwargs) - if normalization == "LayerNorm": - dxmat, dgamma, dbeta = out_tuple - else: # RMSNorm - dxmat, dgamma = out_tuple - dbeta = None - - return dxmat, dgamma, dbeta - - -class _LayerNormLinear(paddle.autograd.PyLayer): - """TE implementation of LayerNormLinear""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - ln_weight: paddle.Tensor, - ln_bias: Union[paddle.Tensor, None], - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - bias: Union[paddle.Tensor, None], - use_bias: bool, - eps: float, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - return_layernorm_output: bool, - is_grad_enabled: bool, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - zero_centered_gamma: bool, - normalization: str, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - fuse_wgrad_accumulation: bool, - is_first_microbatch: bool, - ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: - if normalization == "RMSNorm": - assert ln_bias is None, "RMSNorm does not support bias!" - else: # LayerNorm - assert ln_bias is not None, "LayerNorm requires bias!" - # Make sure input dimensions are compatible - in_features = ln_weight.shape[0] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmat = inp.reshape((-1, in_features)) - if fp8_enabled: - assert_dim_for_fp8_forward_exec(inputmat) - assert_dim_for_fp8_forward_exec(weight) - - # LayerNorm Fwd + FP8 Cast - ( - ln_out_return, - ln_out, - mu, - rsigma, - ) = _apply_normalization_fwd( - normalization, - inputmat, - ln_weight, - ln_bias, - FP8FwdTensors.GEMM1_INPUT, - eps, - fp8_enabled, - fp8_meta, - activation_dtype, - return_layernorm_output, - fwd_ln_sm_margin, - zero_centered_gamma, - ) - - # Linear Fwd - out, weight_t_fp8 = _linear_fwd( - ln_out, - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_fp8, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - bias, - use_bias, - fp8_enabled, - fp8_calibration, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - - if is_grad_enabled: - save_for_backward_allow_none( - ctx, - inputmat, - ln_weight, - mu, - rsigma, - weight, - weight_t_fp8 if fp8_enabled else None, - ln_out, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None, - ) - - ctx.activation_dtype = activation_dtype - ctx.fp8_enabled = fp8_enabled - ctx.fp8_meta = fp8_meta - ctx.use_bias = use_bias - ctx.inp_shape = inp.shape - ctx.return_layernorm_output = return_layernorm_output - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.parallel_mode = parallel_mode - ctx.tensor_parallel = tensor_parallel - ctx.sequence_parallel = sequence_parallel - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.requires_dgrad = not inp.stop_gradient - ctx.requires_wgrad = not weight.stop_gradient - ctx.requires_bgrad = use_bias and not bias.stop_gradient - ctx.requires_ln_bgrad = ln_bias is not None and not ln_bias.stop_gradient - ctx.requires_ln_wgrad = not ln_weight.stop_gradient - ctx.is_first_microbatch = is_first_microbatch - ctx.has_ln_bias = ln_bias is not None - ctx.normalization = normalization - - # [*, in_features] -> [*, out_features] except first dimension changes for SP - out = out.reshape((-1, *inp.shape[1:-1], out.shape[-1])) - - if return_layernorm_output: - return out, ln_out_return.reshape(inp.shape) - return out - - @staticmethod - def backward( - ctx, *grad_outputs: Tuple[paddle.Tensor, ...] - ) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward( - ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormLinear" - ): - ( # pylint: disable=unbalanced-tuple-unpacking - inputmat, - ln_weight, - mu, - rsigma, - weight, - weight_t_fp8, - ln_out, - fwd_scale_inverses, - ) = saved_tensor_allow_none(ctx) - - ( - grad_output, - grad_output_c, - grad_output_t, - bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess( - ctx, grad_outputs[0], ctx.parallel_mode == "row" - ) - - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - # Prepare ln_out for Linear bwd - linear_inputmat = ln_out - if ctx.fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - if ctx.requires_wgrad and ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - linear_inputmat = cast_from_fp8( - ln_out, - ctx.fp8_meta["scaling_fwd"], - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - - # Linear Bwd - dgrad, wgrad, bgrad_ = _linear_bwd( - linear_inputmat, - None, # inputmat_t will be automatically computed if not provided - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - grad_output, - grad_output_c, - grad_output_t, - FP8BwdTensors.GRAD_OUTPUT1, - fwd_scale_inverses, - ctx.requires_bgrad, - ctx.fp8_enabled, - ctx.fp8_meta, - True, # Always compute dgrad to feed into LayerNorm bwd - ctx.requires_wgrad, - ctx.activation_dtype, - ctx.parallel_mode, - ctx.tensor_parallel, - ctx.sequence_parallel, - ctx.tp_group, - ctx.fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - - if not ctx.fp8_enabled: - # bgrad is fused with gemm for non-FP8 path - bgrad = bgrad_ - - # LayerNorm Bwd - dxmat, dgamma, dbeta = _apply_normalization_bwd( - ctx.normalization, - inputmat, - dgrad, - ln_weight, - mu, - rsigma, - grad_outputs[1] if ctx.return_layernorm_output else None, - ctx.return_layernorm_output, - ctx.bwd_ln_sm_margin, - ctx.zero_centered_gamma, - ) - - bgrad = bgrad if ctx.requires_bgrad else None - bgrad_out = (bgrad,) if ctx.use_bias else () - dbeta = dbeta if ctx.requires_ln_bgrad else None - dbeta_out = (dbeta,) if ctx.has_ln_bias else () - - if not ctx.fp8_enabled or ctx.is_first_microbatch is None: - weight_cache_grad = () - else: - # weight_fp8 and weight_t_fp8 are stop_gradient tensors - weight_cache_grad = (None, None) - - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - wgrad = None - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, - dgamma if ctx.requires_ln_wgrad else None, - *dbeta_out, - wgrad if ctx.requires_wgrad else None, - *weight_cache_grad, - *bgrad_out, - ) - - -class LayerNormLinear(TransformerEngineBaseLayer): - r""" - Applies layer normalization followed by linear transformation to the incoming data. - - Parameters - ---------- - in_features : int - size of each input sample. - out_features : int - size of each output sample. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' - type of normalization applied. - return_layernorm_output : bool, default = `False` - if set to `True`, output of layernorm is returned from the forward - together with the output of the linear transformation. - Example use case: residual connection for transformer module is - taken post layernorm. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - parallel_mode : {None, 'Column', 'Row'}, default = `None` - used to decide whether this Linear layer is Column Parallel Linear or Row - Parallel Linear as described `here `_. - When set to `None`, no communication is performed. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - """ - - def __init__( - self, - in_features: int, - out_features: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - normalization: str = "LayerNorm", - return_layernorm_output: bool = False, - zero_centered_gamma: bool = False, - parallel_mode: Optional[str] = None, - sequence_parallel: bool = False, - tp_group: Union[dist_group_type, None] = None, - fuse_wgrad_accumulation: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.eps = eps - self.normalization = normalization - assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" - self.return_layernorm_output = return_layernorm_output - self.zero_centered_gamma = zero_centered_gamma - self.backend = backend - - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self._dtype = self._helper.get_default_dtype() - - # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=parallel_mode is not None - ) - self.tensor_parallel = self.tp_size > 1 - self.parallel_mode = parallel_mode - assert ( - self.parallel_mode in GemmParallelModes - ), f"parallel_mode {parallel_mode} not supported" - - if self.parallel_mode == "column": - self.out_features = divide(self.out_features, self.tp_size) - elif self.parallel_mode == "row": - self.in_features = divide(self.in_features, self.tp_size) - - self.sequence_parallel = self.tensor_parallel and sequence_parallel - - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - - # LayerNorm weights - self.ln_weight = self.create_parameter( - shape=[self.in_features], - attr=paddle.ParamAttr( - initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) - ), - dtype=self._dtype, - is_bias=False, - ) - if self.normalization != "RMSNorm": - self.ln_bias = self.create_parameter( - shape=[self.in_features], - attr=paddle.ParamAttr(initializer=Constant(value=0.0)), - dtype=self._dtype, - is_bias=True, - ) - else: - self.ln_bias = None - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.ln_weight) - if self.ln_bias is not None: - mark_as_sequence_parallel_parameter(self.ln_bias) - - # Initialize Linear weight parameter - with track_rng_state(enable=self.tensor_parallel): - # TE linear weight is in column major - self.weight = self.create_parameter( - shape=( - [self.out_features, self.in_features] - if self.backend == "transformer_engine" - else [self.in_features, self.out_features] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.weight, self.tensor_parallel, self.parallel_mode, self.backend - ) - self.fp8_weights.append(self.weight) - - # Initialize Linear bias parameter - self.has_bias = self._bias_attr is not False - use_default_bias = self._bias_attr is None or self._bias_attr is True - if self.has_bias: - self.bias = self.create_parameter( - shape=[self.out_features], - attr=( - self._bias_attr - if not use_default_bias - else paddle.ParamAttr(initializer=Constant(value=0.0)) - ), - dtype=self._dtype, - is_bias=True, - ) - if parallel_mode == "column": - set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0) - if parallel_mode == "row" and self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.bias) - else: - self.bias = None - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias: - self.gemm_bias_fused_add = False - else: - self.gemm_bias_fused_add = True - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. - self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: - """ - Apply layer normalization to the input followed by a linear transformation. - """ - - with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp: - # Layer input should be casted outside PyLayer, as performing - # inplace cast to input tensors may cause problems when used - # together with Paddle native layers. - inp = cast_if_needed(inp, self.activation_dtype) - - # Get persistent fp8 weight buffer. None if buffer does not exist. - weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch) - - out = _LayerNormLinear.apply( - inp, - self.ln_weight, - self.ln_bias, - self.weight, - weight_fp8, - weight_t_fp8, - self.bias if self.gemm_bias_fused_add else None, - self.has_bias and self.gemm_bias_fused_add, - self.eps, - self.fp8_enabled, - self.fp8_calibration, - self.fp8_meta, - self.activation_dtype, - self.return_layernorm_output, - paddle.is_grad_enabled(), - self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.parallel_mode, - self.tensor_parallel, - self.sequence_parallel, - self.tp_group, - self.tp_size, - self.fuse_wgrad_accumulation, - is_first_microbatch, - ) - - if self.return_layernorm_output: - out, ln_out = out - - if not self.gemm_bias_fused_add: - out = out + cast_if_needed_inplace(self.bias, self.activation_dtype) - - if self.return_layernorm_output: - return out, ln_out - return out - - def _pd_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale." - ) - - if is_first_microbatch is not None: - warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored." - ) - - if self.normalization == "RMSNorm": - norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) - norm_out = inp * norm * self.ln_weight - else: # LayerNorm - norm_out = F.layer_norm( - x=inp, - normalized_shape=inp.shape[-1], - weight=self.ln_weight, - bias=self.ln_bias, - epsilon=self.eps, - ) - - if self.parallel_mode == "column" and self.tensor_parallel: - norm_out = identity(norm_out, self.tp_group) - out = F.linear(norm_out, self.weight, self.bias if self.gemm_bias_fused_add else None) - if self.parallel_mode == "row" and self.tensor_parallel: - out, _ = allreduce(out, self.tp_group) - out = out + self.bias if self.bias is not None else out - if self.return_layernorm_output: - return out, norm_out - return out - - def forward(self, *args, **kwargs): - """ - Apply layer normalization to the input followed by a linear transformation. - - Parameters - ---------- - inp : paddle.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/layernorm_mlp.py b/transformer_engine/paddle/layer/layernorm_mlp.py deleted file mode 100644 index 32f837183c..0000000000 --- a/transformer_engine/paddle/layer/layernorm_mlp.py +++ /dev/null @@ -1,1010 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""LayerNormMLP API""" - -import os -import warnings -from typing import Union, Tuple, Dict, Any, Optional - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from .base import TransformerEngineBaseLayer -from .layernorm_linear import _apply_normalization_fwd, _apply_normalization_bwd -from .linear import _linear_fwd_fp8, _linear_fwd_non_fp8, _linear_bwd_fp8, _linear_bwd_non_fp8 -from ..constants import TE_DType, FP8FwdTensors, FP8BwdTensors, dist_group_type -from ..cpp_extensions import ( - cast_from_fp8, - gelu_fp8, - swiglu_fp8, - swiglu, - dswiglu, - cast_transpose_bgrad, - dgelu_cast_transpose_bgrad_fp8, -) -from ..distributed import ( - allreduce, - get_tp_group_and_world_size, - identity, - track_rng_state, - set_tensor_dist_attr, - set_weight_tensor_dist_attr, - mark_as_sequence_parallel_parameter, -) -from ..fp8 import get_fp8_te_dtype -from ..utils import ( - assert_dim_for_fp8_forward_exec, - cast_if_needed, - cast_if_needed_inplace, - divide, - get_paddle_act_func, - save_for_backward_allow_none, - saved_tensor_allow_none, -) - -__all__ = ["LayerNormMLP"] - - -def _mlp_forward( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - fc1_weight: paddle.Tensor, - fc1_weight_fp8: Optional[paddle.Tensor], - fc1_weight_t_fp8: Optional[paddle.Tensor], - fc1_weight_fp8_index: FP8FwdTensors, - fc1_bias: Union[paddle.Tensor, None], - use_fc1_bias: bool, - fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT - fc2_weight: paddle.Tensor, - fc2_weight_fp8: Optional[paddle.Tensor], - fc2_weight_t_fp8: Optional[paddle.Tensor], - fc2_weight_fp8_index: FP8FwdTensors, - fc2_bias: Union[paddle.Tensor, None], - use_fc2_bias: bool, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - activation: str, - is_grad_enabled: bool, - set_parallel_mode: bool, - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - is_first_microbatch: bool, -): - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fc1_out, fc1_weight_t_fp8 = _linear_fwd_fp8( - inputmat, - inputmat_fp8_index, - fc1_weight, - fc1_weight_fp8, - fc1_weight_t_fp8, - fc1_weight_fp8_index, - fc1_bias, - use_fc1_bias, - fp8_meta, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - if activation == "gelu": - gelu_out = gelu_fp8( - fc1_out, - fp8_meta["scaling_fwd"], - fc2_input_fp8_index, - fp8_dtype_forward, - ) - elif activation == "swiglu": - gelu_out = swiglu_fp8( - fc1_out, - fp8_meta["scaling_fwd"], - fc2_input_fp8_index, - fp8_dtype_forward, - ) - else: - raise NotImplementedError("Activation type " + activation + " is not supported!") - - fc2_out, fc2_weight_t_fp8 = _linear_fwd_fp8( - gelu_out, - fc2_input_fp8_index, - fc2_weight, - fc2_weight_fp8, - fc2_weight_t_fp8, - fc2_weight_fp8_index, - fc2_bias, - use_fc2_bias, - fp8_meta, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - else: - fc1_outputs = _linear_fwd_non_fp8( - inputmat, - inputmat_fp8_index, - fc1_weight, - fc1_weight_fp8_index, - fc1_bias, - use_fc1_bias, - fp8_calibration, - fp8_meta, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - activation=activation, - ) - - if activation == "gelu": - fc1_out, gelu_out = fc1_outputs - elif activation == "swiglu": - fc1_out = fc1_outputs - gelu_out = swiglu(fc1_out, TE_DType[activation_dtype]) - else: - raise NotImplementedError("Activation type " + activation + " is not supported!") - - fc2_out = _linear_fwd_non_fp8( - gelu_out, - fc2_input_fp8_index, - fc2_weight, - fc2_weight_fp8_index, - fc2_bias, - use_fc2_bias, - fp8_calibration, - fp8_meta, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - ) - return ( - fc1_out, - gelu_out, - fc2_out, - fc1_weight_t_fp8 if fp8_enabled else None, - fc2_weight_t_fp8 if fp8_enabled else None, - ) - - -def _mlp_backward( - fc1_input: paddle.Tensor, # ln_out, BF16 / FP8 - fc1_input_fp8_index: FP8FwdTensors, - fc1_weight: paddle.Tensor, - fc1_weight_t_fp8: paddle.Tensor, - fc1_weight_fp8_index: FP8FwdTensors, - fc1_grad_output_fp8_index: FP8BwdTensors, # FP8BwdTensors.GRAD_OUTPUT2 - requires_fc1_wgrad: bool, - requires_fc1_bgrad: bool, - fc1_out: paddle.Tensor, - fc2_input: paddle.Tensor, # gelu_out - fc2_input_fp8_index: FP8FwdTensors, # FP8FwdTensors.GEMM2_INPUT - fc2_weight: paddle.Tensor, - fc2_weight_t_fp8: paddle.Tensor, - fc2_weight_fp8_index: FP8FwdTensors, - requires_fc2_wgrad: bool, - requires_fc2_bgrad: bool, - grad_output: paddle.Tensor, - grad_output_c: paddle.Tensor, - grad_output_t: paddle.Tensor, - grad_output_fp8_index: FP8BwdTensors, # FP8BwdTensors.GRAD_OUTPUT1 - fwd_scale_inverses: paddle.Tensor, - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - requires_dgrad: bool, - activation_dtype: paddle.dtype, - activation: str, - set_parallel_mode: bool, - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, -): - ( - fc1_dgrad, - fc1_wgrad, - fc1_bgrad, - fc2_wgrad, - fc2_bgrad, - ) = ( - None, - None, - None, - None, - None, - ) - - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - # FC2 Bwd - fp8_wgrad = not fp8_meta["recipe"].override_linear_precision.wgrad - if requires_fc2_wgrad and not fp8_wgrad: - fc2_input = cast_from_fp8( - fc2_input, - fp8_meta["scaling_fwd"], - fc2_input_fp8_index, - fp8_dtype_forward, - TE_DType[activation_dtype], - ) - - fc2_dgrad, fc2_wgrad = _linear_bwd_fp8( - fc2_input, - None, - fc2_input_fp8_index, - fc2_weight, - fc2_weight_t_fp8, - fc2_weight_fp8_index, - grad_output, - grad_output_c, - grad_output_t, - grad_output_fp8_index, - fwd_scale_inverses, - fp8_meta, - True, - requires_fc2_wgrad, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - - dgelu_t = None - fc1_bgrad_ = None - if activation == "gelu": - # GELU Bwd - dgelu, dgelu_t, fc1_bgrad_ = dgelu_cast_transpose_bgrad_fp8( - fc2_dgrad, - fc1_out, - fp8_meta["scaling_bwd"], - fc1_grad_output_fp8_index, - fp8_dtype_backward, - ) - elif activation == "swiglu": - dgelu = dswiglu(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) - fc1_bgrad_, dgelu, dgelu_t = cast_transpose_bgrad( - dgelu, - fp8_meta["scaling_bwd"], - fc1_grad_output_fp8_index, - fp8_dtype_backward, - ) - - if requires_fc1_bgrad: - fc1_bgrad = fc1_bgrad_ - - # FC1 Bwd - dgelu_no_fp8 = None - if requires_fc1_wgrad and not fp8_wgrad: - # TODO(tizheng) Paddle lacks fused dgelu_bgrad OP. Cast from dgrad(fp8) instead. - dgelu_no_fp8 = cast_from_fp8( - dgelu, - fp8_meta["scaling_bwd"], - fc1_grad_output_fp8_index, - fp8_dtype_backward, - TE_DType[activation_dtype], - ) - fc1_input = cast_from_fp8( - fc1_input, - fp8_meta["scaling_fwd"], - fc1_input_fp8_index, - fp8_dtype_forward, - TE_DType[activation_dtype], - ) - - fc1_dgrad, fc1_wgrad = _linear_bwd_fp8( - fc1_input, - None, - fc1_input_fp8_index, - fc1_weight, - fc1_weight_t_fp8, - fc1_weight_fp8_index, - dgelu_no_fp8, - dgelu, - dgelu_t, - fc1_grad_output_fp8_index, - fwd_scale_inverses, - fp8_meta, - requires_dgrad, - requires_fc1_wgrad, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - else: - dgelu, fc2_wgrad, fc2_bgrad = _linear_bwd_non_fp8( - fc2_input, - fc2_weight, - grad_output, - requires_fc2_bgrad, - True, - requires_fc2_wgrad, - activation_dtype, - "row" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - gelu_input=fc1_out, - activation=activation, - ) - - if activation == "swiglu": - dgelu = dswiglu(dgelu, fc1_out, TE_DType[dgelu.dtype]) - - fc1_dgrad, fc1_wgrad, fc1_bgrad = _linear_bwd_non_fp8( - fc1_input, - fc1_weight, - dgelu, - requires_fc1_bgrad, - requires_dgrad, - requires_fc1_wgrad, - activation_dtype, - "column" if set_parallel_mode else None, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - ) - return ( - fc1_dgrad, - fc1_wgrad, - fc1_bgrad, - fc2_wgrad, - fc2_bgrad, - ) - - -class _LayerNormMLP(paddle.autograd.PyLayer): - """TE implementation of LayerNormMLP""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - ln_weight: paddle.Tensor, - ln_bias: Union[paddle.Tensor, None], - fc1_weight: paddle.Tensor, - fc1_weight_fp8: Optional[paddle.Tensor], - fc1_weight_t_fp8: Optional[paddle.Tensor], - fc1_bias: Union[paddle.Tensor, None], - use_fc1_bias: bool, - fc2_weight: paddle.Tensor, - fc2_weight_fp8: Optional[paddle.Tensor], - fc2_weight_t_fp8: Optional[paddle.Tensor], - fc2_bias: Union[paddle.Tensor, None], - use_fc2_bias: bool, - eps: float, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - return_layernorm_output: bool, - is_grad_enabled: bool, - fwd_ln_sm_margin: int, - bwd_ln_sm_margin: int, - zero_centered_gamma: bool, - normalization: str, - activation: str, - set_parallel_mode: bool, - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - fuse_wgrad_accumulation: bool, - is_first_microbatch: bool, - ) -> Union[Tuple[paddle.Tensor, ...], paddle.Tensor]: - if normalization == "RMSNorm": - assert ln_bias is None, "RMSNorm does not support bias!" - else: # LayerNorm - assert ln_bias is not None, "LayerNorm requires bias!" - # Make sure input dimensions are compatible - in_features = ln_weight.shape[0] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmat = inp.reshape((-1, in_features)) - if fp8_enabled: - assert_dim_for_fp8_forward_exec(inputmat) - assert_dim_for_fp8_forward_exec(fc1_weight) - assert_dim_for_fp8_forward_exec(fc2_weight) - - # only support gelu for now - assert activation in ["gelu", "swiglu"], "Only gelu and swiglu are supported for now" - - # LayerNorm Fwd + FP8 Cast - ( - ln_out_return, - ln_out, - mu, - rsigma, - ) = _apply_normalization_fwd( - normalization, - inputmat, - ln_weight, - ln_bias, - FP8FwdTensors.GEMM1_INPUT, - eps, - fp8_enabled, - fp8_meta, - activation_dtype, - return_layernorm_output, - fwd_ln_sm_margin, - zero_centered_gamma, - ) - - ( - fc1_out, - gelu_out, - fc2_out, - fc1_weight_t_fp8, - fc2_weight_t_fp8, - ) = _mlp_forward( - ln_out, - FP8FwdTensors.GEMM1_INPUT, - fc1_weight, - fc1_weight_fp8, - fc1_weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - fc1_bias, - use_fc1_bias, - FP8FwdTensors.GEMM2_INPUT, - fc2_weight, - fc2_weight_fp8, - fc2_weight_t_fp8, - FP8FwdTensors.GEMM2_WEIGHT, - fc2_bias, - use_fc2_bias, - fp8_enabled, - fp8_calibration, - fp8_meta, - activation_dtype, - activation, - is_grad_enabled, - set_parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_first_microbatch, - ) - - if is_grad_enabled: - save_for_backward_allow_none( - ctx, - inputmat, - ln_weight, - mu, - rsigma, - ln_out, - fc1_out, - gelu_out, - fc1_weight, - fc1_weight_t_fp8, - fc2_weight, - fc2_weight_t_fp8, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None, - ) - ctx.activation_dtype = activation_dtype - ctx.activation = activation - ctx.fp8_enabled = fp8_enabled - ctx.fp8_meta = fp8_meta - ctx.use_fc1_bias = use_fc1_bias - ctx.use_fc2_bias = use_fc2_bias - ctx.inp_shape = inp.shape - ctx.return_layernorm_output = return_layernorm_output - ctx.bwd_ln_sm_margin = bwd_ln_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.set_parallel_mode = set_parallel_mode - ctx.tensor_parallel = tensor_parallel - ctx.sequence_parallel = sequence_parallel - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.requires_dgrad = not inp.stop_gradient - ctx.requires_fc1_wgrad = not fc1_weight.stop_gradient - ctx.requires_fc2_wgrad = not fc2_weight.stop_gradient - ctx.requires_fc1_bgrad = use_fc1_bias and not fc1_bias.stop_gradient - ctx.requires_fc2_bgrad = use_fc2_bias and not fc2_bias.stop_gradient - ctx.requires_ln_bgrad = ln_bias is not None and not ln_bias.stop_gradient - ctx.requires_ln_wgrad = not ln_weight.stop_gradient - ctx.is_first_microbatch = is_first_microbatch - ctx.has_ln_bias = ln_bias is not None - ctx.normalization = normalization - - # [*, in_features] -> [*, out_features] except first dimension changes for SP - fc2_out = fc2_out.reshape((-1, *inp.shape[1:-1], fc2_out.shape[-1])) - - if return_layernorm_output: - return fc2_out, ln_out_return.reshape(inp.shape) - return fc2_out - - @staticmethod - def backward( - ctx, *grad_outputs: Tuple[paddle.Tensor, ...] - ) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward( - ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_LayerNormMLP" - ): - ( # pylint: disable=unbalanced-tuple-unpacking - inputmat, - ln_weight, - mu, - rsigma, - ln_out, - fc1_out, - gelu_out, - fc1_weight, - fc1_weight_t_fp8, - fc2_weight, - fc2_weight_t_fp8, - fwd_scale_inverses, - ) = saved_tensor_allow_none(ctx) - - ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess - ( - grad_output, - grad_output_c, - grad_output_t, - fc2_bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess(ctx, grad_outputs[0], True) - - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - ( - fc1_dgrad, - fc1_wgrad, - fc1_bgrad, - fc2_wgrad, - fc2_bgrad_, - ) = _mlp_backward( - ln_out, - FP8FwdTensors.GEMM1_INPUT, - fc1_weight, - fc1_weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - FP8BwdTensors.GRAD_OUTPUT2, - ctx.requires_fc1_wgrad, - ctx.requires_fc1_bgrad, - fc1_out, - gelu_out, - FP8FwdTensors.GEMM2_INPUT, - fc2_weight, - fc2_weight_t_fp8, - FP8FwdTensors.GEMM2_WEIGHT, - ctx.requires_fc2_wgrad, - ctx.requires_fc2_bgrad, - grad_output, - grad_output_c, - grad_output_t, - FP8BwdTensors.GRAD_OUTPUT1, - fwd_scale_inverses, - ctx.fp8_enabled, - ctx.fp8_meta, - True, - ctx.activation_dtype, - ctx.activation, - ctx.set_parallel_mode, - ctx.tensor_parallel, - ctx.sequence_parallel, - ctx.tp_group, - ctx.fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - if not ctx.fp8_enabled: - # fc2_bias is fused with gemm for non-FP8 path - fc2_bgrad = fc2_bgrad_ - - # LayerNorm Bwd - dxmat, dgamma, dbeta = _apply_normalization_bwd( - ctx.normalization, - inputmat, - fc1_dgrad, - ln_weight, - mu, - rsigma, - grad_outputs[1] if ctx.return_layernorm_output else None, - ctx.return_layernorm_output, - ctx.bwd_ln_sm_margin, - ctx.zero_centered_gamma, - ) - - fc1_bgrad = fc1_bgrad if ctx.requires_fc1_bgrad else None - fc2_bgrad = fc2_bgrad if ctx.requires_fc2_bgrad else None - fc1_bgrad_out = (fc1_bgrad,) if ctx.use_fc1_bias else () - fc2_bgrad_out = (fc2_bgrad,) if ctx.use_fc2_bias else () - dbeta = dbeta if ctx.requires_ln_bgrad else None - dbeta_out = (dbeta,) if ctx.has_ln_bias else () - - if not ctx.fp8_enabled or ctx.is_first_microbatch is None: - fc1_weight_cache_grad = () - fc2_weight_cache_grad = () - else: - # weight_fp8 and weight_t_fp8 are stop_gradient tensors - fc1_weight_cache_grad = (None, None) - fc2_weight_cache_grad = (None, None) - - if ctx.requires_fc1_wgrad and ctx.fuse_wgrad_accumulation: - fc1_wgrad = None - if ctx.requires_fc2_wgrad and ctx.fuse_wgrad_accumulation: - fc2_wgrad = None - - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dgrad else None, - dgamma if ctx.requires_ln_wgrad else None, - *dbeta_out, - fc1_wgrad if ctx.requires_fc1_wgrad else None, - *fc1_weight_cache_grad, - *fc1_bgrad_out, - fc2_wgrad if ctx.requires_fc2_wgrad else None, - *fc2_weight_cache_grad, - *fc2_bgrad_out, - ) - - -class LayerNormMLP(TransformerEngineBaseLayer): - r""" - Applies layer normalization on the input followed by the MLP module, consisting of - 2 successive linear transformations, separated by the GeLU activation. - - Parameters - ---------- - hidden_size : int - size of each input sample. - ffn_hidden_size : int - intermediate size to which input samples are projected. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - normalization : { 'LayerNorm', 'RMSNorm' }, default = 'LayerNorm' - type of normalization applied. - activation : str, default = 'gelu' - activation function used. - Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'. - return_layernorm_output : bool, default = `False` - if set to `True`, output of layernorm is returned from the forward - together with the output of the linear transformation. - Example use case: residual connection for transformer module - is taken post layernorm. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, FC1 is used as Column Parallel and FC2 is used as Row - Parallel as described `here `_. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : paddle.distributed.collective.Group, default = `None` - tensor parallel process group. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - """ - - def __init__( - self, - hidden_size: int, - ffn_hidden_size: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - normalization: str = "LayerNorm", - activation: str = "gelu", - return_layernorm_output: bool = False, - zero_centered_gamma: bool = False, - set_parallel_mode: bool = False, - sequence_parallel: bool = False, - tp_group: Optional[dist_group_type] = None, - fuse_wgrad_accumulation: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.hidden_size = hidden_size - self.ffn_hidden_size = ffn_hidden_size - self.eps = eps - self.normalization = normalization - assert normalization in ["LayerNorm", "RMSNorm"], "Normalization type not supported" - self.activation = activation - self.return_layernorm_output = return_layernorm_output - self.zero_centered_gamma = zero_centered_gamma - self.backend = backend - - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self._dtype = self._helper.get_default_dtype() - - # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=set_parallel_mode - ) - self.tensor_parallel = self.tp_size > 1 - self.set_parallel_mode = set_parallel_mode - self.sequence_parallel = self.tensor_parallel and sequence_parallel - - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - - if self.set_parallel_mode: - self.size_per_partition = divide(self.ffn_hidden_size, self.tp_size) - else: - self.size_per_partition = self.ffn_hidden_size - - # LayerNorm weights - self.ln_weight = self.create_parameter( - shape=[self.hidden_size], - attr=paddle.ParamAttr( - initializer=Constant(value=0.0 if self.zero_centered_gamma else 1.0) - ), - dtype=self._dtype, - is_bias=False, - ) - - if self.normalization != "RMSNorm": - self.ln_bias = self.create_parameter( - shape=[self.hidden_size], - attr=paddle.ParamAttr(initializer=Constant(value=0.0)), - dtype=self._dtype, - is_bias=True, - ) - else: - self.ln_bias = None - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.ln_weight) - if self.ln_bias is not None: - mark_as_sequence_parallel_parameter(self.ln_bias) - - # FC1 weights - if self.activation in ["swiglu"]: - fc1_output_features = self.size_per_partition * 2 - else: - fc1_output_features = self.size_per_partition - - with track_rng_state(enable=self.tensor_parallel): - self.fc1_weight = self.create_parameter( - shape=( - [fc1_output_features, self.hidden_size] - if self.backend == "transformer_engine" - else [self.hidden_size, fc1_output_features] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.fc1_weight, self.tensor_parallel, parallel_mode="column", backend=self.backend - ) - self.fp8_weights.append(self.fc1_weight) - - self.has_bias = self._bias_attr is not False - use_default_bias = self._bias_attr is None or self._bias_attr is True - if use_default_bias: - self._bias_attr = paddle.ParamAttr(initializer=Constant(value=0.0)) - - if self.has_bias: - self.fc1_bias = self.create_parameter( - shape=[fc1_output_features], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True, - ) - set_tensor_dist_attr(self.fc1_bias, self.tensor_parallel, axis=0) - else: - self.fc1_bias = None - - # FC2 weights - self.fc2_weight = self.create_parameter( - shape=( - [self.hidden_size, self.size_per_partition] - if self.backend == "transformer_engine" - else [self.size_per_partition, self.hidden_size] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.fc2_weight, self.tensor_parallel, parallel_mode="row", backend=self.backend - ) - self.fp8_weights.append(self.fc2_weight) - - if self.has_bias: - self.fc2_bias = self.create_parameter( - shape=[self.hidden_size], - attr=self._bias_attr, - dtype=self._dtype, - is_bias=True, - ) - if self.set_parallel_mode and self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.fc2_bias) - else: - self.fc2_bias = None - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.set_parallel_mode and self.tensor_parallel and self.has_bias: - self.gemm_bias_fused_add = False - else: - self.gemm_bias_fused_add = True - - # These many SMs are subtracted from the total SM count when calling forward - # and backward LayerNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with LN. - self.fwd_ln_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: - """ - Apply layer normalization to the input followed by a linear transformation. - """ - - with self.prepare_forward(inp, num_gemms=2, is_first_microbatch=is_first_microbatch) as inp: - # Layer input should be casted outside PyLayer, as performing - # inplace cast to input tensors may cause problems when used - # together with Paddle native layers. - inp = cast_if_needed(inp, self.activation_dtype) - - # Get persistent fp8 weight buffer. None if buffer does not exist. - fc1_weight_fp8, fc1_weight_t_fp8, fc2_weight_fp8, fc2_weight_t_fp8 = ( - self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch) - ) - - out = _LayerNormMLP.apply( - inp, - self.ln_weight, - self.ln_bias, - self.fc1_weight, - fc1_weight_fp8, - fc1_weight_t_fp8, - self.fc1_bias, - self.has_bias, - self.fc2_weight, - fc2_weight_fp8, - fc2_weight_t_fp8, - self.fc2_bias, - self.has_bias, - self.eps, - self.fp8_enabled, - self.fp8_calibration, - self.fp8_meta, - self.activation_dtype, - self.return_layernorm_output, - paddle.is_grad_enabled(), - self.fwd_ln_sm_margin if paddle.is_grad_enabled() else self.inf_ln_sm_margin, - self.bwd_ln_sm_margin, - self.zero_centered_gamma, - self.normalization, - self.activation, - self.set_parallel_mode, - self.tensor_parallel, - self.sequence_parallel, - self.tp_group, - self.tp_size, - self.fuse_wgrad_accumulation, - is_first_microbatch, - ) - - if self.return_layernorm_output: - out, ln_out = out - - if not self.gemm_bias_fused_add: - out = out + cast_if_needed_inplace(self.fc2_bias, self.activation_dtype) - - if self.return_layernorm_output: - return out, ln_out - return out - - def _pd_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support LayerNorm with zero-centered scale." - ) - - if is_first_microbatch is not None: - warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored." - ) - - if self.normalization == "RMSNorm": - norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) - norm_out = inp * norm * self.ln_weight - else: # LayerNorm - norm_out = F.layer_norm( - x=inp, - normalized_shape=inp.shape[-1], - weight=self.ln_weight, - bias=self.ln_bias, - epsilon=self.eps, - ) - if self.set_parallel_mode and self.tensor_parallel: - norm_out = identity(norm_out, self.tp_group) - fc1_out = F.linear(norm_out, self.fc1_weight, self.fc1_bias) - act_func = get_paddle_act_func(self.activation) - act_out = act_func(fc1_out) - out = F.linear( - act_out, self.fc2_weight, self.fc2_bias if self.gemm_bias_fused_add else None - ) - if self.set_parallel_mode and self.tensor_parallel: - out, _ = allreduce(out, self.tp_group) - out = out + self.fc2_bias if self.fc2_bias is not None else out - if self.return_layernorm_output: - return out, norm_out - return out - - def forward(self, *args, **kwargs): - """ - Apply layer normalization to the input followed by a feedforward network (MLP Block). - - Parameters - ---------- - inp : paddle.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/linear.py b/transformer_engine/paddle/layer/linear.py deleted file mode 100644 index af35955a1c..0000000000 --- a/transformer_engine/paddle/layer/linear.py +++ /dev/null @@ -1,919 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Linear API""" - -import warnings -from typing import Union, Tuple, Dict, Any, Optional - -import paddle -import paddle.nn.functional as F -from paddle.nn.initializer import Constant - -from .base import ( - TransformerEngineBaseLayer, - get_workspace, - _2X_ACC_FPROP, - _2X_ACC_DGRAD, - _2X_ACC_WGRAD, -) - -from ..constants import FP8FwdTensors, FP8BwdTensors, GemmParallelModes, dist_group_type -from ..cpp_extensions import gemm, fp8_gemm, cast_to_fp8, cast_transpose, transpose -from ..distributed import ( - allgather, - allreduce, - get_tp_group_and_world_size, - identity, - reduce_scatter, - track_rng_state, - set_tensor_dist_attr, - set_weight_tensor_dist_attr, - mark_as_sequence_parallel_parameter, -) -from ..fp8 import get_fp8_te_dtype, get_global_fp8_state -from ..utils import ( - assert_dim_for_fp8_forward_exec, - cast_if_needed, - cast_if_needed_inplace, - divide, - get_bias_dtype, - save_for_backward_allow_none, - saved_tensor_allow_none, - clear_tensor_data, -) - -__all__ = ["Linear"] - - -def _linear_fwd_fp8( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - weight_fp8_index: FP8FwdTensors, - bias: paddle.Tensor, - use_bias: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - is_grad_enabled: bool, - is_first_microbatch: bool = None, -): - """FP8 path of Linear Fwd""" - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - bias_dtype = get_bias_dtype(activation_dtype) - bias = cast_if_needed(bias, bias_dtype) - - if parallel_mode == "column" and sequence_parallel: - inputmat_total, _ = allgather(inputmat, tp_group) - else: - inputmat_total = inputmat - - if not get_global_fp8_state().is_cudagraph_enabled(): - # if cuda graph is not enabled, we cast the weight here - update_fp8_weights = is_first_microbatch is None or is_first_microbatch - if is_grad_enabled: - if update_fp8_weights: - weight_fp8, weight_t_fp8 = cast_transpose( - weight, - fp8_meta["scaling_fwd"], - weight_fp8_index, - fp8_dtype_forward, - cast_out=weight_fp8, - transpose_out=weight_t_fp8, - ) - else: - weight_t_fp8 = None - if update_fp8_weights: - weight_fp8 = cast_to_fp8( - weight, - fp8_meta["scaling_fwd"], - weight_fp8_index, - fp8_dtype_forward, - out=weight_fp8, - ) - - out, _ = fp8_gemm( - weight_fp8, - fp8_meta["scaling_fwd"].scale_inv, - weight_fp8_index, - fp8_dtype_forward, - inputmat_total, - fp8_meta["scaling_fwd"].scale_inv, - inputmat_fp8_index, - fp8_dtype_forward, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - ) - - if parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) - - return out, weight_t_fp8 - - -def _linear_fwd_non_fp8( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_fp8_index: FP8FwdTensors, - bias: paddle.Tensor, - use_bias: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - activation: str = "", -): - """Non-FP8 path of Linear Fwd""" - - if parallel_mode == "column" and sequence_parallel: - inputmat_total, _ = allgather(inputmat, tp_group) - else: - inputmat_total = inputmat - - # Layer parameters are initialized as float32 dtype by default. - # Cast the parameters to activation_dtype if the current dtype - # does not match activation_dtype. The casting is inplace, so it - # only needs to performed once throughout the traing process. - weight = cast_if_needed_inplace(weight, activation_dtype) - bias = cast_if_needed_inplace(bias, activation_dtype) - - if fp8_calibration: - # amax of input - fp8_meta["scaling_fwd"].amax_history[0, inputmat_fp8_index.value] = paddle.max( - paddle.abs(inputmat_total) - ).item() - # amax of weight - fp8_meta["scaling_fwd"].amax_history[0, weight_fp8_index.value] = paddle.max( - paddle.abs(weight) - ).item() - fp8_meta["update_amax_and_scale_fwd"] = True - - outputs = gemm( - weight, - inputmat_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - gelu=(activation == "gelu"), - ) - - if activation == "gelu": - gelu_out, _, out = outputs - return out, gelu_out - - out, _, _ = outputs - - if parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) - return out - - -def _linear_fwd( - inputmat: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - weight_fp8_index: FP8FwdTensors, - bias: paddle.Tensor, - use_bias: bool, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - is_grad_enabled: bool, - is_first_microbatch: bool = None, - gather_output: bool = False, -): - if fp8_enabled: - out, weight_t_fp8 = _linear_fwd_fp8( - inputmat, - inputmat_fp8_index, - weight, - weight_fp8, - weight_t_fp8, - weight_fp8_index, - bias, - use_bias, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - ) - else: - out = _linear_fwd_non_fp8( - inputmat, - inputmat_fp8_index, - weight, - weight_fp8_index, - bias, - use_bias, - fp8_calibration, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - ) - if gather_output and tensor_parallel and parallel_mode == "column": - out, _ = allgather(out, tp_group, axis=-1) - - return ( - out, - weight_t_fp8 if fp8_enabled else None, - ) - - -def _linear_bwd_fp8( - inputmat: paddle.Tensor, - inputmat_t: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_t_fp8: paddle.Tensor, - weight_fp8_index: FP8FwdTensors, - grad_output: paddle.Tensor, - grad_output_c: paddle.Tensor, - grad_output_t: paddle.Tensor, - grad_output_fp8_index: FP8BwdTensors, - fwd_scale_inverses: paddle.Tensor, - fp8_meta: Dict[str, Any], - requires_dgrad: bool, - requires_wgrad: bool, - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, -): - dgrad, wgrad, handle = None, None, None - - # Overlap input AG with dgrad - inputmat_total = None - inputmat_t_total = None - if requires_wgrad and parallel_mode == "column" and sequence_parallel: - inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad) - else: - inputmat_total = inputmat - inputmat_t_total = inputmat_t - - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - if requires_dgrad: - dgrad, _ = fp8_gemm( - weight_t_fp8, - fwd_scale_inverses, - weight_fp8_index, - fp8_dtype_forward, - grad_output_c, - fp8_meta["scaling_bwd"].scale_inv, - grad_output_fp8_index, - fp8_dtype_backward, - activation_dtype, - get_workspace(), - use_split_accumulator=_2X_ACC_DGRAD, - ) - clear_tensor_data(grad_output_c) - - # Overlap dgrad-RS/AR with wgrad - if parallel_mode == "column" and sequence_parallel: - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False) - elif parallel_mode == "column" and tensor_parallel: - dgrad, handle = allreduce(dgrad, tp_group, sync_op=False) - - if requires_wgrad: - if not fp8_meta["recipe"].override_linear_precision.wgrad: - if inputmat_t_total is None: - inputmat_t_total = transpose(inputmat_total, fp8_dtype_backward) - clear_tensor_data(inputmat_total) - - wgrad, _ = fp8_gemm( - inputmat_t_total, - fwd_scale_inverses, - inputmat_fp8_index, - fp8_dtype_forward, - grad_output_t, - fp8_meta["scaling_bwd"].scale_inv, - grad_output_fp8_index, - fp8_dtype_backward, - "float32" if fuse_wgrad_accumulation else activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ) - clear_tensor_data(inputmat_t_total, grad_output_t) - else: - wgrad, _, _ = gemm( - inputmat_total, - grad_output, - activation_dtype, - get_workspace(), - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - layout="NT", - out=weight.main_grad if fuse_wgrad_accumulation else None, - out_dtype="float32" if fuse_wgrad_accumulation else None, - ) - clear_tensor_data(inputmat_total) - - if fuse_wgrad_accumulation: - weight.main_grad = wgrad - - if parallel_mode == "column" and tensor_parallel and handle is not None: - handle.wait() - if parallel_mode == "column" and sequence_parallel: - handle.wait() - - return dgrad, wgrad - - -def _linear_bwd_non_fp8( - inputmat: paddle.Tensor, - weight: paddle.Tensor, - grad_output: paddle.Tensor, - requires_bgrad: bool, - requires_dgrad: bool, - requires_wgrad: bool, - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, - gelu_input: Union[paddle.Tensor, None] = None, - activation: str = "", -): - """ - Performs Linear Backward. Optionally, fuses GELU backward and dbias. - """ - dgrad, wgrad, bgrad, handle = None, None, None, None - - # Overlap input AG with dgrad - inputmat_total = None - if requires_wgrad and parallel_mode == "column" and sequence_parallel: - inputmat_total, handle = allgather(inputmat, tp_group, sync_op=not requires_dgrad) - else: - inputmat_total = inputmat - - if requires_dgrad: - dgrad, _, _ = gemm( - weight, - grad_output, - activation_dtype, - get_workspace(), - layout="NN", - gelu=(activation == "gelu"), - gelu_input=gelu_input, - grad=True, - ) - # Overlap dgrad-RS/AR with wgrad - if parallel_mode == "column" and sequence_parallel: - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter(dgrad, tp_group, sync_op=False) - elif parallel_mode == "column" and tensor_parallel: - dgrad, handle = allreduce(dgrad, tp_group, sync_op=False) - - if requires_wgrad: - wgrad, bgrad, _ = gemm( - inputmat_total, - grad_output, - activation_dtype, - get_workspace(), - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - layout="NT", - out=weight.main_grad if fuse_wgrad_accumulation else None, - out_dtype="float32" if fuse_wgrad_accumulation else None, - use_bias=requires_bgrad, - ) - if fuse_wgrad_accumulation: - weight.main_grad = wgrad - - elif requires_bgrad: - bgrad = grad_output.sum(axis=0) - if parallel_mode == "column" and tensor_parallel and handle is not None: - handle.wait() - if parallel_mode == "column" and sequence_parallel and handle is not None: - handle.wait() - - return dgrad, wgrad, bgrad - - -def _linear_bwd( - inputmat: paddle.Tensor, - inputmat_t: paddle.Tensor, - inputmat_fp8_index: FP8FwdTensors, - weight: paddle.Tensor, - weight_t_fp8: paddle.Tensor, - weight_fp8_index: FP8FwdTensors, - grad_output: paddle.Tensor, - grad_output_c: paddle.Tensor, - grad_output_t: paddle.Tensor, - grad_output_fp8_index: FP8BwdTensors, - fwd_scale_inverses: paddle.Tensor, - requires_bgrad: bool, - fp8_enabled: bool, - fp8_meta: Dict[str, Any], - requires_dgrad: bool, - requires_wgrad: bool, - activation_dtype: paddle.dtype, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - fuse_wgrad_accumulation: bool, - accumulate_wgrad_into_param_main_grad: bool, -): - dgrad, wgrad, bgrad = None, None, None - if fp8_enabled: - dgrad, wgrad = _linear_bwd_fp8( - inputmat, - inputmat_t, - inputmat_fp8_index, - weight, - weight_t_fp8, - weight_fp8_index, - grad_output, - grad_output_c, - grad_output_t, - grad_output_fp8_index, - fwd_scale_inverses, - fp8_meta, - requires_dgrad, - requires_wgrad, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - ) - else: - dgrad, wgrad, bgrad = _linear_bwd_non_fp8( - inputmat, - weight, - grad_output, - requires_bgrad, - requires_dgrad, - requires_wgrad, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad=accumulate_wgrad_into_param_main_grad, - ) - return dgrad, wgrad, bgrad - - -class _Linear(paddle.autograd.PyLayer): - """TE implementation of Linear""" - - @staticmethod - def forward( - ctx, - weight: paddle.Tensor, - weight_fp8: Optional[paddle.Tensor], - weight_t_fp8: Optional[paddle.Tensor], - inp: paddle.Tensor, - bias: paddle.Tensor, - use_bias: bool, - fp8_enabled: bool, - fp8_calibration: bool, - fp8_meta: Dict[str, Any], - activation_dtype: paddle.dtype, - is_grad_enabled: bool, - parallel_mode: Union[str, None], - tensor_parallel: bool, - sequence_parallel: bool, - tp_group: Union[dist_group_type, None], - tp_size: int, - fuse_wgrad_accumulation: bool, - is_first_microbatch: bool, - gather_output: bool, - ) -> paddle.Tensor: - # Make sure input dimensions are compatible - in_features = weight.shape[-1] - assert inp.shape[-1] == in_features, "GEMM not possible" - inputmat = inp.reshape((-1, in_features)) - if fp8_enabled: - assert_dim_for_fp8_forward_exec(inputmat) - assert_dim_for_fp8_forward_exec(weight) - - inputmat_no_fp8 = inputmat - - # FP8 casting - inputmat_t = None - if fp8_enabled: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and not sequence_parallel - ): - inputmat, inputmat_t = cast_transpose( - inputmat, - fp8_meta["scaling_fwd"], - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - else: - inputmat = cast_to_fp8( - inputmat, - fp8_meta["scaling_fwd"], - FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - - # GEMM Fwd - out, weight_t_fp8 = _linear_fwd( - inputmat, - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_fp8, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - bias, - use_bias, - fp8_enabled, - fp8_calibration, - fp8_meta, - activation_dtype, - parallel_mode, - tensor_parallel, - sequence_parallel, - tp_group, - is_grad_enabled, - is_first_microbatch, - gather_output, - ) - - if is_grad_enabled: - saved_inputmat = None - if fp8_enabled and sequence_parallel: - saved_inputmat = inputmat - else: - saved_inputmat = inputmat_no_fp8 - save_for_backward_allow_none( - ctx, - saved_inputmat, - inputmat_t, - weight, - weight_t_fp8 if fp8_enabled else None, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8_enabled else None, - ) - ctx.activation_dtype = activation_dtype - ctx.fp8_enabled = fp8_enabled - ctx.fp8_meta = fp8_meta - ctx.use_bias = use_bias - ctx.inp_shape = inp.shape - ctx.parallel_mode = parallel_mode - ctx.tensor_parallel = tensor_parallel - ctx.sequence_parallel = sequence_parallel - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.requires_dgrad = not inp.stop_gradient - ctx.requires_wgrad = not weight.stop_gradient - ctx.requires_bgrad = use_bias and not bias.stop_gradient - ctx.is_first_microbatch = is_first_microbatch - ctx.reduce_scatter_output = gather_output - - return out.reshape((-1, *inp.shape[1:-1], out.shape[-1])) - - @staticmethod - def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - with TransformerEngineBaseLayer.prepare_backward( - ctx.fp8_enabled, ctx.fp8_meta, ctx.tp_group, ctx.tp_size, name="_Linear" - ): - - ( # pylint: disable=unbalanced-tuple-unpacking - inputmat, - inputmat_t, - weight, - weight_t_fp8, - fwd_scale_inverses, - ) = saved_tensor_allow_none(ctx) - - ( - grad_output, - grad_output_c, - grad_output_t, - bgrad, - ) = TransformerEngineBaseLayer.grad_output_preprocess( - ctx, grad_output, ctx.parallel_mode == "row" - ) - if ctx.is_first_microbatch is not None: - accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch - ) - else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - - dgrad, wgrad, bgrad_ = _linear_bwd( - inputmat, - inputmat_t, - FP8FwdTensors.GEMM1_INPUT, - weight, - weight_t_fp8, - FP8FwdTensors.GEMM1_WEIGHT, - grad_output, - grad_output_c, - grad_output_t, - FP8BwdTensors.GRAD_OUTPUT1, - fwd_scale_inverses, - ctx.requires_bgrad, - ctx.fp8_enabled, - ctx.fp8_meta, - ctx.requires_dgrad, - ctx.requires_wgrad, - ctx.activation_dtype, - ctx.parallel_mode, - ctx.tensor_parallel, - ctx.sequence_parallel, - ctx.tp_group, - ctx.fuse_wgrad_accumulation, - accumulate_wgrad_into_param_main_grad, - ) - - if not ctx.fp8_enabled: - # bgrad is fused with gemm for non-FP8 path - bgrad = bgrad_ - - if ctx.reduce_scatter_output: - wgrad, _ = reduce_scatter(wgrad, ctx.tp_group) - bgrad, _ = reduce_scatter(bgrad, ctx.tp_group) - - if not ctx.fp8_enabled or ctx.is_first_microbatch is None: - weight_cache_grad = () - else: - # weight_fp8 and weight_t_fp8 are stop_gradient tensors - weight_cache_grad = (None, None) - - dgrad_return = dgrad.reshape(ctx.inp_shape) if ctx.requires_dgrad else None - if not ctx.use_bias: - bgrad_return = () - elif ctx.requires_bgrad: - bgrad_return = (bgrad,) - else: - bgrad_return = (None,) - - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - wgrad = None - - return ( - wgrad if ctx.requires_wgrad else None, - *weight_cache_grad, - dgrad_return, - *bgrad_return, - ) - - -class Linear(TransformerEngineBaseLayer): - """ - Applies a linear transformation to the incoming data :math:`y = xA^T + b` - - Parameters - ---------- - in_features : int - size of each input sample. - out_features : int - size of each output sample. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - parallel_mode : {None, 'Column', 'Row'}, default = `None` - used to decide whether this Linear layer is Column Parallel Linear or Row - Parallel Linear as described `here `_. - When set to `None`, no communication is performed. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - - """ - - def __init__( - self, - in_features: int, - out_features: int, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - parallel_mode: Optional[str] = None, - sequence_parallel: bool = False, - tp_group: Union[dist_group_type, None] = None, - fuse_wgrad_accumulation: bool = False, - gather_output: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.backend = backend - self._weight_attr = weight_attr - self._bias_attr = bias_attr - self._dtype = self._helper.get_default_dtype() - self.gather_output = gather_output - - # Set parallel configs - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=parallel_mode is not None - ) - self.tensor_parallel = self.tp_size > 1 - self.parallel_mode = parallel_mode - assert ( - self.parallel_mode in GemmParallelModes - ), f"parallel_mode {parallel_mode} not supported" - - if self.parallel_mode == "column": - self.out_features = divide(self.out_features, self.tp_size) - elif self.parallel_mode == "row": - self.in_features = divide(self.in_features, self.tp_size) - - self.sequence_parallel = self.tensor_parallel and sequence_parallel - - self.fuse_wgrad_accumulation = fuse_wgrad_accumulation - - # Initialize weight parameter - with track_rng_state(enable=self.tensor_parallel): - # TE linear weight is in column major - self.weight = self.create_parameter( - shape=( - [self.out_features, self.in_features] - if self.backend == "transformer_engine" - else [self.in_features, self.out_features] - ), - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - set_weight_tensor_dist_attr( - self.weight, self.tensor_parallel, self.parallel_mode, self.backend - ) - - # Initialize bias parameter - self.has_bias = self._bias_attr is not False - use_default_bias = self._bias_attr is None or self._bias_attr is True - if self.has_bias: - self.bias = self.create_parameter( - shape=[self.out_features], - attr=( - self._bias_attr - if not use_default_bias - else paddle.ParamAttr(initializer=Constant(value=0.0)) - ), - dtype=self._dtype, - is_bias=True, - ) - if parallel_mode == "column": - set_tensor_dist_attr(self.bias, self.tensor_parallel, axis=0) - if parallel_mode == "row" and self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.bias) - else: - self.bias = None - - self.fp8_weights.append(self.weight) - - # For RPL, bias has to be added after TP collectives - # So it cannot be fused with the GEMM - if self.parallel_mode == "row" and self.tensor_parallel and self.has_bias: - self.gemm_bias_fused_add = False - else: - self.gemm_bias_fused_add = True - - def _te_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """ - Apply the linear transformation to the input. - """ - with self.prepare_forward(inp, is_first_microbatch=is_first_microbatch) as inp: - # Layer input should be casted outside PyLayer, as performing - # inplace cast to input tensors may cause problems when used - # together with Paddle native layers. - inp = cast_if_needed(inp, self.activation_dtype) - - # Get persistent fp8 weight buffer. None if buffer does not exist. - weight_fp8, weight_t_fp8 = self.get_fp8_weights_scratchpad_and_cast(is_first_microbatch) - - out = _Linear.apply( - self.weight, - weight_fp8, - weight_t_fp8, - inp, - self.bias if self.gemm_bias_fused_add else None, - self.has_bias and self.gemm_bias_fused_add, - self.fp8_enabled, - self.fp8_calibration, - self.fp8_meta, - self.activation_dtype, - paddle.is_grad_enabled(), - self.parallel_mode, - self.tensor_parallel, - self.sequence_parallel, - self.tp_group, - self.tp_size, - self.fuse_wgrad_accumulation, - is_first_microbatch, - self.gather_output, - ) - - if not self.gemm_bias_fused_add: - out = out + cast_if_needed_inplace(self.bias, self.activation_dtype) - - return out - - def _pd_forward( - self, - inp: paddle.Tensor, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """Calls Paddle OP""" - if is_first_microbatch is not None: - warnings.warn( - "`is_first_microbatch` is not supported for paddle backend and is ignored." - ) - if self.parallel_mode == "column" and self.tensor_parallel: - inp = identity(inp, self.tp_group) - out = F.linear(inp, self.weight, self.bias if self.gemm_bias_fused_add else None) - if self.parallel_mode == "row" and self.tensor_parallel: - out, _ = allreduce(out, self.tp_group) - out = out + self.bias if self.bias is not None else out - return out - - def forward(self, *args, **kwargs): - """ - Apply the linear transformation to the input. - - Parameters - ---------- - inp : paddle.Tensor - Input tensor. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} is not supported.") diff --git a/transformer_engine/paddle/layer/rmsnorm.py b/transformer_engine/paddle/layer/rmsnorm.py deleted file mode 100644 index 1afc3d9759..0000000000 --- a/transformer_engine/paddle/layer/rmsnorm.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""RMSNorm API""" -import os -from typing import Union, Tuple - -import paddle -from paddle.nn.initializer import Constant - -from ..constants import TE_DType -from ..cpp_extensions import rmsnorm_fwd, rmsnorm_bwd -from ..distributed import mark_as_sequence_parallel_parameter - -__all__ = ["RMSNorm"] - - -class _RMSNorm(paddle.autograd.PyLayer): - """functional RMSNorm""" - - @staticmethod - def forward( - ctx, - inp: paddle.Tensor, - rmsnorm_weight: paddle.Tensor, - eps: float, - fwd_rmsnorm_sm_margin: int, - bwd_rmsnorm_sm_margin: int, - zero_centered_gamma: bool, - ) -> paddle.Tensor: - # Make sure input dimensions are compatible - in_features = rmsnorm_weight.shape[0] - assert inp.shape[-1] == in_features, "RMSNorm not possible" - inputmat = inp.reshape((-1, in_features)) - - rmsnorm_out, rsigma = rmsnorm_fwd( - inputmat, - rmsnorm_weight, - eps, - TE_DType[inp.dtype], - fwd_rmsnorm_sm_margin, - zero_centered_gamma, - ) - - ctx.save_for_backward(inputmat, rmsnorm_weight, rsigma) - ctx.inp_shape = inp.shape - ctx.bwd_rmsnorm_sm_margin = bwd_rmsnorm_sm_margin - ctx.zero_centered_gamma = zero_centered_gamma - ctx.requires_dx = not inp.stop_gradient - ctx.requires_dw = not rmsnorm_weight.stop_gradient - - return rmsnorm_out.reshape(inp.shape) - - @staticmethod - def backward(ctx, grad_output: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - inputmat, rmsnorm_weight, rsigma = ctx.saved_tensor() - d_rmsnorm_out = grad_output.reshape(inputmat.shape) - dxmat, dgamma = rmsnorm_bwd( - d_rmsnorm_out, - inputmat, - rsigma, - rmsnorm_weight, - ctx.bwd_rmsnorm_sm_margin, - ctx.zero_centered_gamma, - ) - return ( - dxmat.reshape(ctx.inp_shape) if ctx.requires_dx else None, - dgamma if ctx.requires_dw else None, - ) - - -class RMSNorm(paddle.nn.Layer): - r""" - Applies Root Mean Square Layer Normalization over a mini-batch of inputs as described in - the paper `Root Mean Square Layer Normalization `__ - - .. math:: - y = \frac{x}{RMS_\varepsilon(x)} * \gamma - - where - - .. math:: - RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^nx_i^2 + \varepsilon} - - :math:`\gamma` is a learnable affine transform parameter of size :attr:`hidden_size` - - Parameters - ---------- - hidden_size : int - size of each input sample. - eps : float, default = 1e-5 - a value added to the denominator of layer normalization for numerical stability. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in RMSNorm is initialized to 0 and - the RMSNorm formula changes to - - .. math:: - y = \frac{x}{RMS(x) + \varepsilon} * (1 + \gamma) - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - backend to use for rmsnorm operation. - - Parallelism parameters - ---------------------- - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - """ - - def __init__( - self, - hidden_size: int, - eps: float = 1e-5, - weight_attr: Union[paddle.ParamAttr, None] = None, - zero_centered_gamma: bool = False, - sequence_parallel: bool = False, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - self.eps = eps - self.zero_centered_gamma = zero_centered_gamma - self.sequence_parallel = sequence_parallel - self.backend = backend - self._dtype = self._helper.get_default_dtype() - - self._weight_attr = weight_attr - if not self._weight_attr: - self._weight_attr = paddle.ParamAttr(initializer=Constant(1.0)) - - self.weight = self.create_parameter( - shape=[hidden_size], - attr=self._weight_attr, - dtype=self._dtype, - is_bias=False, - ) - - if self.sequence_parallel: - mark_as_sequence_parallel_parameter(self.weight) - - # These many SMs are subtracted from the total SM count when calling forward - # and backward RMSNorm C APIs. These envvars can be used to prevent the LN - # kernels from using all SMs in the device. This is useful for cases such as - # communication overlap with RMSNorm. - self.fwd_rmsnorm_sm_margin = int(os.getenv("NVTE_FWD_LAYERNORM_SM_MARGIN", "0")) - self.bwd_rmsnorm_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) - - def _te_forward(self, inp: paddle.Tensor) -> paddle.Tensor: - return _RMSNorm.apply( - inp, - self.weight, - self.eps, - self.fwd_rmsnorm_sm_margin, - self.bwd_rmsnorm_sm_margin, - self.zero_centered_gamma, - ) - - def _pd_forward( - self, - inp: paddle.Tensor, - ) -> paddle.Tensor: - if self.zero_centered_gamma: - raise NotImplementedError( - "Paddle backend does not support RMSNorm with zero_centered_gamma." - ) - norm = paddle.rsqrt(paddle.mean(inp**2, axis=-1, keepdim=True) + self.eps) - y = inp * norm * self.weight - return y - - def forward(self, *args, **kwargs): - if self.backend == "transformer_engine": - return self._te_forward(*args, **kwargs) - if self.backend == "paddle": - return self._pd_forward(*args, **kwargs) - raise AttributeError(f"Backend {self.backend} not supported.") diff --git a/transformer_engine/paddle/layer/softmax.py b/transformer_engine/paddle/layer/softmax.py deleted file mode 100644 index 11549364fe..0000000000 --- a/transformer_engine/paddle/layer/softmax.py +++ /dev/null @@ -1,254 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Fused scaled masked softmax functions""" - -import os -import warnings -from typing import Callable, Tuple, Union, Optional - -import paddle - -from transformer_engine.paddle.cpp_extensions import ( - scaled_upper_triang_masked_softmax_forward, - scaled_upper_triang_masked_softmax_backward, - scaled_masked_softmax_forward, - scaled_masked_softmax_backward, - scaled_softmax_forward, - scaled_softmax_backward, -) - - -__all__ = ["FusedScaleMaskSoftmax"] - - -THREADS_PER_WARP = 32 -THREADS_PER_BLOCK = 128 - - -_default_causal_mask = {} - - -def _get_default_causal_mask(seqlen: int) -> paddle.Tensor: - """Return the causal upper triangular mask for softmax input""" - if seqlen not in _default_causal_mask: - _default_causal_mask[seqlen] = paddle.triu(paddle.ones((seqlen, seqlen)), diagonal=1).cast( - "bool" - ) - return _default_causal_mask[seqlen] - - -class ScaledUpperTriangMaskedSoftmax(paddle.autograd.PyLayer): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs: paddle.Tensor, scale: float) -> paddle.Tensor: - """ScaledUpperTriangMaskedSoftmax fwd""" - scale_t = paddle.Tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) - - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - """ScaledUpperTriangMaskedSoftmax bwd""" - softmax_results, scale_t = ctx.saved_tensor() - input_grads = scaled_upper_triang_masked_softmax_backward( - output_grads, softmax_results, scale_t[0] - ) - - return input_grads, None - - -class ScaledMaskedSoftmax(paddle.autograd.PyLayer): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply the mask. - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs: paddle.Tensor, mask: paddle.Tensor, scale: float) -> paddle.Tensor: - """ScaledMaskedSoftmax fwd""" - scale_t = paddle.Tensor([scale]) - - softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - """ScaledMaskedSoftmax bwd""" - softmax_results, scale_t = ctx.saved_tensor() - - input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class ScaledSoftmax(paddle.autograd.PyLayer): - """ - Fused operation which performs following two operations in sequence - 1. Scale the tensor. - 2. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs: paddle.Tensor, scale: float) -> paddle.Tensor: - """ScaledSoftmax fwd""" - scale_t = paddle.Tensor([scale]) - - softmax_results = scaled_softmax_forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads: paddle.Tensor) -> Tuple[Union[paddle.Tensor, None], ...]: - """ScaledSoftmax bwd""" - softmax_results, scale_t = ctx.saved_tensor() - - input_grads = scaled_softmax_backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class FusedScaleMaskSoftmax(paddle.nn.Layer): - """ - Scaled and masked softmax module for paddle with fused optimizations. - - Parameters - ---------- - attn_mask_type : str, default = `causal` - type of attention mask, can be 'causal', 'padding', or 'no_mask'. - mask_func : callable - custom callable for applying the mask to the softmax input. - `masked_input=mask_func(inp, mask)`. - softmax_in_fp32 : bool, default = True - perform softmax computation in fp32. - layernorm_epsilon : float, default = 1e-5 - a value added to the denominator of layer normalization - for numerical stability. - backend: {'transformer_engine', 'paddle'}, default = `transformer_engine` - backend to use for operation. - """ - - def __init__( - self, - attn_mask_type: str, - mask_func: Callable, - softmax_in_fp32: bool = True, - backend: str = "transformer_engine", - ) -> None: - super().__init__() - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = bool(int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))) - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.backend = backend - - def forward( - self, - inp: paddle.Tensor, - mask: paddle.Tensor, - scale: Optional[float] = None, - ) -> paddle.Tensor: - """FusedScaleMaskSoftmax fprop""" - # [batch_size, num_heads, s_q, s_kv] - assert inp.dim() == 4 - self.input_is_fp16 = inp.dtype == paddle.float16 - self.input_is_bf16 = inp.dtype == paddle.bfloat16 - self.input_in_16bit_float = self.input_is_fp16 or self.input_is_bf16 - - assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled" - - if self.backend == "transformer_engine" and not self.is_kernel_available(*inp.shape): - warnings.warn( - "fused kernel is not available for this input shape, fall back to paddle backend" - ) - self.backend = "paddle" - - if self.backend == "transformer_engine": - return self._te_forward(inp, mask, scale) - if self.backend == "paddle": - return self._pd_forward(inp, mask, scale) - raise AttributeError(f"Backend {self.backend} is not supported.") - - def is_kernel_available(self, b: int, h: int, s_q: int, s_kv: int) -> bool: - """Check FusedScaleMaskSoftmax kernel availability based on size""" - attn_batches = b * h - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_16bit_float # input must be fp16 - and 16 < s_kv <= 4096 # s_kv must be 16 ~ 2048 - and s_q % 4 == 0 # s_q must be a multiple of 4 - and attn_batches % 4 == 0 # b * h must be a multiple of 4 - ): - if 0 <= s_kv <= 4096: - batch_per_block = self.get_batch_per_block(int(s_kv)) - - if self.attn_mask_type == "causal": - if attn_batches % batch_per_block == 0: - return True - else: - if s_q % batch_per_block == 0: - return True - return False - - def _te_forward( - self, inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None - ) -> paddle.Tensor: - """Fused masked softmax kernel""" - b, h, s_q, s_kv = inp.size() - scale = 1.0 if scale is None else scale - - if self.attn_mask_type == "causal": - assert s_q == s_kv, "causal mask is only for self attention" - - # input is 3D tensor (attn_batches, s_q, s_kv) - inp = inp.reshape((-1, s_q, s_kv)) - probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale) - return probs.reshape((b, h, s_q, s_kv)) - # input is 4D tensor (b, h, s_q, s_kv) - if mask is not None: - return ScaledMaskedSoftmax.apply(inp, mask, scale) - return ScaledSoftmax.apply(inp, scale) - - def _pd_forward( - self, inp: paddle.Tensor, mask: paddle.Tensor, scale: Optional[float] = None - ) -> paddle.Tensor: - """Call Paddle OP""" - if self.input_in_16bit_float and self.softmax_in_fp32: - inp = paddle.cast(inp, "float32") - - if scale is not None: - inp = inp * scale - - if self.attn_mask_type == "causal": - mask = _get_default_causal_mask(inp.shape[2]) - - mask_output = self.mask_func(inp, mask) if mask is not None else inp - probs = paddle.nn.functional.softmax(mask_output, axis=-1) - - if self.input_in_16bit_float and self.softmax_in_fp32: - if self.input_is_fp16: - probs = paddle.cast(probs, "float16") - else: - probs = paddle.cast(probs, "bfloat16") - - return probs - - @staticmethod - def get_batch_per_block(key_seq_len: int) -> int: - """Softmax utility""" - pow2 = 1 << (key_seq_len - 1).bit_length() - warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP - batches_per_warp = 2 if pow2 <= 128 else 1 - warps_per_block = THREADS_PER_BLOCK // warp_size - batches_per_block = warps_per_block * batches_per_warp - return batches_per_block diff --git a/transformer_engine/paddle/layer/transformer.py b/transformer_engine/paddle/layer/transformer.py deleted file mode 100644 index 4a9c2c38dc..0000000000 --- a/transformer_engine/paddle/layer/transformer.py +++ /dev/null @@ -1,375 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Transformer""" - -from typing import Optional, Tuple, Union -import warnings - -import paddle -from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd - -from .layernorm_mlp import LayerNormMLP -from .layernorm import LayerNorm -from .attention import MultiHeadAttention -from ..constants import AttnMaskTypes, LayerTypes, dist_group_type -from ..distributed import get_tp_group_and_world_size, track_rng_state - - -class TransformerLayer(paddle.nn.Layer): - r""" - TransformerLayer is made up of an attention block and a feedforward network (MLP). - This standard layer is based on the paper "Attention Is All You Need". - - Parameters - ---------- - hidden_size : int - size of each input sample. - ffn_hidden_size : int - intermediate size to which input samples are projected. - num_attention_heads : int - number of attention heads in the transformer layer. - num_gqa_groups : Optional[int], default = `None` - number of GQA groups in the transformer layer. - Grouped Query Attention is described in - `this paper `_. - This only affects the keys and values, not the queries. - GQA-1 is equivalent to Multi-Query Attention - (`MQA `_), while GQA-H - is equivalent to MHA, i.e. `num_gqa_groups = num_attention_heads`. - layernorm_epsilon : float, default = 1e-5 - a value added to the denominator of layer normalization - for numerical stability. - hidden_dropout: float, default = 0.1 - dropout probability for the dropout op after FC2 layer. - attention_dropout: float, default = 0.1 - dropout probability for the dropout op during multi-head attention. - weight_attr: Union[paddle.ParamAttr, None], default = None - optional `paddle.ParamAttr` for weight. - bias_attr: Union[paddle.ParamAttr, None, bool], default = None - optional `paddle.ParamAttr` for bias. - self_attn_mask_type: {'causal', 'padding'}, default = `causal` - type of attention mask passed into softmax operation. - apply_residual_connection_post_layernorm : bool, default = `False` - if set to `True`, residual connections are taken - from the output of layer norm (default is taken - from input of layer norm) - output_layernorm: bool, default = `False` - if set to `True`, layer normalization is applied on the output side, - after the final dropout-add. default behavior is to apply layer - normalization on the input side, before the QKV transformation. - layer_type: {'encoder', 'decoder'}, default = `encoder` - if set to `decoder`, an additional cross-attn block is added after self-attn. - This can be used for structures like `T5` Transformer in conjunction with the - `encoder` option. - normalization: {'LayerNorm', 'RMSNorm'}, default = `LayerNorm` - zero_centered_gamma : bool, default = 'False' - if set to 'True', gamma parameter in LayerNorm is initialized to 0 and - the LayerNorm formula changes to - - .. math:: - y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \varepsilon}} * - (1 + \gamma) + \beta - activation : str, default = 'gelu' - Type of activation used in MLP block. - Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'. - - params_dtype : paddle.dtype, default = `paddle.get_default_dtype()` - it controls the type used to allocate the initial parameters. Useful when - the model is trained with lower precision and the original FP32 parameters - would not fit in GPU memory. - backend: {'transformer_engine', 'paddle'}, default = 'transformer_engine' - if set to 'paddle', a framework only no-FP8 path is executed with limited optimization. - - Parallelism parameters - ---------------------- - set_parallel_mode : bool, default = `False` - if set to `True`, QKV and FC1 layers are used as Column Parallel - whereas PROJ and FC2 is used as Row Parallel as described - `here `_. - sequence_parallel : bool, default = `False` - if set to `True`, uses sequence parallelism. - tp_group : ProcessGroup, default = `None` - tensor parallel process group. - attention_dropout_rng_state_name : str, default = `local_seed` - Controls the rng state used for dropout on attention probs. The - specified rng should be set different seeds for different TP ranks. - It will be ignored if `set_parallel_mode` is False. - hidden_dropout_rng_state_name : str, default = `global_seed` - Controls the rng state used for dropout on hidden states. The - specified rng should be given the same seeds for different TP - ranks. It will be ignored if `set_parallel_mode` is False. The - specified name should be registered through - `paddle.distributed.fleet.meta_parallel.get_rng_state_tracker() - .add(rng_state_name, seed)`. - - Optimization parameters - ----------------------- - fuse_wgrad_accumulation : bool, default = 'False' - if set to `True`, enables fusing of creation and accumulation of - the weight gradient. When enabled, it is assumed that the weights - have an additional `main_grad` attribute (used instead of the - regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. - - """ - - def __init__( - self, - hidden_size: int, - ffn_hidden_size: int, - num_attention_heads: int, - num_gqa_groups: Optional[int] = None, - layernorm_epsilon: float = 1e-5, - hidden_dropout: float = 0.1, - attention_dropout: float = 0.1, - weight_attr: Union[paddle.ParamAttr, None] = None, - bias_attr: Union[paddle.ParamAttr, None, bool] = None, - max_sequence_length: Optional[int] = None, - self_attn_mask_type: str = "causal", - params_dtype: Optional[paddle.dtype] = None, - apply_residual_connection_post_layernorm: bool = False, - output_layernorm: bool = False, - layer_type: str = "encoder", - normalization: str = "LayerNorm", - zero_centered_gamma: bool = False, - activation: str = "gelu", - set_parallel_mode: bool = False, - sequence_parallel: bool = False, - tp_group: Optional[dist_group_type] = None, - fuse_wgrad_accumulation: bool = False, - attention_dropout_rng_state_name: str = "local_seed", - hidden_dropout_rng_state_name: str = "global_seed", - backend: str = "transformer_engine", - ) -> None: - super().__init__() - - params_dtype = paddle.get_default_dtype() if params_dtype is None else params_dtype - self.output_layernorm = output_layernorm - self.layer_type = layer_type - self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm - self.self_attn_mask_type = self_attn_mask_type - self.set_parallel_mode = set_parallel_mode - self.tp_group, self.tp_size = get_tp_group_and_world_size( - tp_group, enable_tp=set_parallel_mode - ) - self.tensor_parallel = self.tp_size > 1 - self.sequence_parallel = self.tensor_parallel and sequence_parallel - self.hidden_dropout_rng_state_name = hidden_dropout_rng_state_name - # SP needs local seed for hidden dropout - if self.sequence_parallel and self.hidden_dropout_rng_state_name == "global_seed": - warnings.warn( - "RNG state for hidden dropout needs to be different across TP ranks. " - "Forcing hidden_dropout_rng_state_name to 'local_seed'" - ) - self.hidden_dropout_rng_state_name = "local_seed" - - assert ( - self_attn_mask_type in AttnMaskTypes - ), f"self_attn_mask_type {self_attn_mask_type} not supported" - assert layer_type in LayerTypes, f"layer_type {layer_type} not supported" - - attention_args = ( - hidden_size, - num_attention_heads, - attention_dropout, - layernorm_epsilon, - weight_attr, - bias_attr, - ) - common_attention_kwargs = { - "params_dtype": params_dtype, - "return_layernorm_output": apply_residual_connection_post_layernorm, - "normalization": normalization, - "zero_centered_gamma": zero_centered_gamma, - "set_parallel_mode": set_parallel_mode, - "sequence_parallel": self.sequence_parallel, - "max_sequence_length": max_sequence_length, - "tp_group": tp_group, - "num_gqa_groups": num_gqa_groups, - "fuse_wgrad_accumulation": fuse_wgrad_accumulation, - "rng_state_name": attention_dropout_rng_state_name, - "backend": backend, - } - - self.self_attention = MultiHeadAttention( - *attention_args, - **common_attention_kwargs, - attn_mask_type=self_attn_mask_type, - input_layernorm=not output_layernorm, - attention_type="self", - ) - - if layer_type == "decoder": - self.inter_attention = MultiHeadAttention( - *attention_args, - **common_attention_kwargs, - attn_mask_type="padding", - input_layernorm=True, - attention_type="cross", - ) - - self.layernorm_mlp = LayerNormMLP( - hidden_size, - ffn_hidden_size, - eps=layernorm_epsilon, - weight_attr=weight_attr, - bias_attr=bias_attr, - normalization=normalization, - activation=activation, - return_layernorm_output=apply_residual_connection_post_layernorm, - zero_centered_gamma=zero_centered_gamma, - set_parallel_mode=set_parallel_mode, - sequence_parallel=self.sequence_parallel, - tp_group=tp_group, - fuse_wgrad_accumulation=fuse_wgrad_accumulation, - backend=backend, - ) - - self.hidden_dropout = hidden_dropout - - if self.output_layernorm: - self.layernorm = LayerNorm( - hidden_size, - layernorm_epsilon, - weight_attr, - bias_attr, - zero_centered_gamma=zero_centered_gamma, - sequence_parallel=self.sequence_parallel, - backend=backend, - ) - - self.fused_dropout_add1 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train") - if self.layer_type == "decoder": - self.fused_dropout_add2 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train") - self.fused_dropout_add3 = FusedDropoutAdd(self.hidden_dropout, mode="upscale_in_train") - - def forward( - self, - hidden_states: paddle.Tensor, - attention_mask: Optional[paddle.Tensor] = None, - encoder_output: Optional[paddle.Tensor] = None, - enc_dec_attn_mask: Optional[paddle.Tensor] = None, - rotary_pos_emb: Optional[Tuple[paddle.Tensor, paddle.Tensor]] = None, - core_attention_bias_type: str = "no_bias", - core_attention_bias: Optional[paddle.Tensor] = None, - set_zero: bool = True, - recompute_core_attention: bool = False, - is_first_microbatch: Optional[bool] = None, - ) -> paddle.Tensor: - """ - Transformer Layer: attention block and a feedforward network (MLP) - - .. note:: - - Argument :attr:`attention_mask` will be ignored when :attr:`self_attn_mask_type` - is set to `"causal"`. - - Parameters - ---------- - hidden_states : paddle.Tensor - Input tensor. - attention_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out self-attention softmax input. - encoder_output : Optional[paddle.Tensor], default = `None` - Output of the encoder block to be fed into the decoder block if using - `layer_type="decoder"`. - enc_dec_attn_mask : Optional[paddle.Tensor], default = `None` - Boolean tensor used to mask out inter-attention softmax input if using - `layer_type="decoder"`. - rotary_pos_emb : Optional[Tuple[paddle.Tensor, paddle.Tensor]], default = `None` - Embeddings for query and key tensors for applying rotary position - embedding. By default no input embedding is applied - core_attention_bias_type: str, default = `no_bias` - core_attention_bias: Optional[paddle.Tensor], default = `None` - Bias tensor for Q * K.T - set_zero: bool, default = `True` - Whether to set output tensors to 0 or not before use. - recompute_core_attention: bool, default = `False` - If true, forward activations for core attention are recomputed - during the backward pass in order to save memory that would - otherwise be occupied to store the forward activations until - backprop. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - """ - - if self.self_attn_mask_type != "causal" and attention_mask is not None: - assert attention_mask.dtype == paddle.bool, "Attention mask must be a boolean tensor" - - assert core_attention_bias_type in ["no_bias"], ( - "Only no_bias is supported currently, " - f"but receive core_attention_bias_type = {core_attention_bias_type}" - ) - - # Self attention. - self_attention_outputs = self.self_attention( - hidden_states, - attention_mask, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - rotary_pos_emb=rotary_pos_emb, - recompute_core_attention=recompute_core_attention, - is_first_microbatch=is_first_microbatch, - ) - - if self.apply_residual_connection_post_layernorm and not self.output_layernorm: - attention_output, residual = self_attention_outputs - else: - attention_output = self_attention_outputs - residual = hidden_states - - # dropoout add. - with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name): - bda_output = self.fused_dropout_add1(attention_output, residual) - - # Cross attention. - if self.layer_type == "decoder": - inter_attention_outputs = self.inter_attention( - bda_output, - enc_dec_attn_mask, - encoder_output=encoder_output, - core_attention_bias_type=core_attention_bias_type, - core_attention_bias=core_attention_bias, - set_zero=set_zero, - recompute_core_attention=recompute_core_attention, - is_first_microbatch=is_first_microbatch, - ) - if self.apply_residual_connection_post_layernorm: - attention_output, residual = inter_attention_outputs - else: - attention_output = inter_attention_outputs - residual = bda_output - - with track_rng_state( - enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name - ): - bda_output = self.fused_dropout_add2(attention_output, residual) - - # MLP. - mlp_outputs = self.layernorm_mlp(bda_output, is_first_microbatch=is_first_microbatch) - if self.apply_residual_connection_post_layernorm: - mlp_output, residual = mlp_outputs - else: - mlp_output = mlp_outputs - residual = bda_output - - # dropoout add. - with track_rng_state(enable=self.tensor_parallel, name=self.hidden_dropout_rng_state_name): - output = self.fused_dropout_add3(mlp_output, residual) - - # For BERT like architectures. - if self.output_layernorm: - output = self.layernorm(output) - - # output: [b, s, hidden] - return output diff --git a/transformer_engine/paddle/profile.py b/transformer_engine/paddle/profile.py deleted file mode 100644 index 67d9afcb6f..0000000000 --- a/transformer_engine/paddle/profile.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Utils for profiling""" - -from contextlib import contextmanager - -try: - from paddle.base import core -except ImportError: - from paddle.fluid import core - - -@contextmanager -def nvtx_range(msg): - """Context to insert NVTX""" - core.nvprof_nvtx_push(msg) - yield - core.nvprof_nvtx_pop() diff --git a/transformer_engine/paddle/recompute.py b/transformer_engine/paddle/recompute.py deleted file mode 100644 index 1d64ad0de0..0000000000 --- a/transformer_engine/paddle/recompute.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Methods needed for recompute.""" - -import os -import inspect - -from paddle.distributed import fleet - -from .constants import RecomputeFunctionNames -from .fp8 import get_global_fp8_state - - -__all__ = ["recompute"] - - -_DISABLE_RECOMPUTE = int(os.getenv("NVTE_DISABLE_RECOMPUTE", "0")) - - -def is_in_recompute_phase(): - """Inspect call stack to determine if this is called from - backward phase. Paddle has two recompute methods: - (1) Use RecomputeFunction. The recomputed function is called from `RecomputeFunction.backward`; - (2) Use paddle.autograd.saved_tensors_hooks. The recompute function is called from `unpack`.""" - if _DISABLE_RECOMPUTE: - return False - frame = inspect.currentframe().f_back - while frame: - if frame.f_code.co_name in RecomputeFunctionNames: - return True - frame = frame.f_back - return False - - -def recompute(function, *args, **kwargs): - """ - This is a wrapper of paddle.distributed.fleet.utils.recompute. It provides necessary - state information for fp8 layers. - - Parameters - ---------- - function: Callable - paddle module used to run the forward and backward passes using - the specified :attr:`args` and :attr:`kwargs`. - args : tuple - tuple of torch tensors for inputs to :attr:`function`. - kwargs : dict - dictionary of string keys for keyword arguments to :attr:`function`. - """ - assert ( - not _DISABLE_RECOMPUTE - ), f"Recompute is disabled. Got NVTE_DISABLE_RECOMPUTE={_DISABLE_RECOMPUTE}." - - global_fp8_state = get_global_fp8_state() - - try: - global_fp8_state._fp8_recompute_enabled = True - outputs = fleet.utils.recompute(function, *args, **kwargs) - finally: - global_fp8_state._fp8_recompute_enabled = False - - return outputs diff --git a/transformer_engine/paddle/setup.py b/transformer_engine/paddle/setup.py deleted file mode 100644 index 5b1d1a1e04..0000000000 --- a/transformer_engine/paddle/setup.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Installation script for TE paddle-paddle extensions.""" - -# pylint: disable=wrong-import-position,wrong-import-order - -import sys -import os -import shutil -from pathlib import Path - -import setuptools -from paddle.utils.cpp_extension import BuildExtension - -try: - import paddle # pylint: disable=unused-import -except ImportError as e: - raise RuntimeError("This package needs Paddle Paddle to build.") from e - - -current_file_path = Path(__file__).parent.resolve() -build_tools_dir = current_file_path.parent.parent / "build_tools" -if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_dir): - build_tools_copy = current_file_path / "build_tools" - if build_tools_copy.exists(): - shutil.rmtree(build_tools_copy) - shutil.copytree(build_tools_dir, build_tools_copy) - - -from build_tools.build_ext import get_build_ext -from build_tools.utils import copy_common_headers -from build_tools.te_version import te_version -from build_tools.paddle import setup_paddle_extension - - -os.environ["NVTE_PROJECT_BUILDING"] = "1" -CMakeBuildExtension = get_build_ext(BuildExtension) - - -if __name__ == "__main__": - # Extensions - common_headers_dir = "common_headers" - copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir)) - ext_modules = [ - setup_paddle_extension( - "csrc", current_file_path / "csrc", current_file_path / common_headers_dir - ) - ] - - # Configure package - setuptools.setup( - name="transformer_engine_paddle", - version=te_version(), - description="Transformer acceleration library - Paddle Paddle Lib", - ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension}, - install_requires=["paddlepaddle-gpu>=2.6.1"], - tests_require=["numpy"], - ) - if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): - shutil.rmtree(common_headers_dir) - shutil.rmtree("build_tools") diff --git a/transformer_engine/paddle/utils.py b/transformer_engine/paddle/utils.py deleted file mode 100644 index 7b9aabbf5a..0000000000 --- a/transformer_engine/paddle/utils.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Utility functions for Transformer Engine modules""" - -from typing import Optional, Tuple, Union - -import paddle -import paddle.nn.functional as F -from .cpp_extensions import swiglu_pd - - -def cast_if_needed( - tensor: Union[paddle.Tensor, None], dtype: paddle.dtype -) -> Union[paddle.Tensor, None]: - """Cast tensor to dtype""" - return tensor if tensor is None or tensor.dtype == dtype else paddle.cast(tensor, dtype) - - -def cast_if_needed_inplace( - tensor: Union[paddle.Tensor, None], dtype: paddle.dtype -) -> Union[paddle.Tensor, None]: - """Cast tensor to dtype (inplace), not to be used on layer inputs""" - return tensor if tensor is None or tensor.dtype == dtype else tensor._to(dtype=dtype) - - -def check_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> bool: - """For fp8 fprop (TN layout), inputs and weights must be such - that dim0 is divisible by 8 and dim1 is divisible by 16. - """ - return not tensor.shape[0] % 8 and not tensor.shape[1] % 16 - - -def assert_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> None: - """For fp8 fprop (TN layout), inputs and weights must be such - that dim0 is divisible by 8 and dim1 is divisible by 16. - """ - # single tensor check so it's clear which tensor is triggering the assertion - assert check_dim_for_fp8_forward_exec(tensor), ( - "Tensor dimensions are not compatible for FP8 execution: " - f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)" - ) - - -def get_bias_dtype(activation_dtype: paddle.dtype): - """Get bias dtype given activation_dtype""" - return paddle.bfloat16 if activation_dtype == paddle.float32 else activation_dtype - - -def get_paddle_act_func(activation): - """Get paddle activation function""" - funcs = { - "gelu": F.gelu, - "relu": F.relu, - "silu": F.silu, - "swiglu": swiglu_pd, - } - if activation not in funcs: - raise "Activation type " + activation + " is not supported." - return funcs[activation] - - -def attention_mask_func( - attention_scores: paddle.Tensor, attention_mask: paddle.Tensor -) -> paddle.Tensor: - """Get attention mask""" - - def _masked_fill(x, mask, value): - y = paddle.full(x.shape, value, x.dtype) - return paddle.where(mask, y, x) - - attention_scores = _masked_fill(attention_scores, attention_mask, -10000.0) - return attention_scores - - -def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Tensor: - """Convert mask to cu_seqlens""" - assert "bool" in str(mask.dtype), "mask must be bool dtype" - assert len(mask.shape) == 4 and mask.shape[1] == 1, "mask must be [b, 1, s_q, s_kv]" - q_actual_seqlens = paddle.sum(mask[:, :, :, 0].logical_not(), axis=(-1, -2), dtype="int32") - q_cu_seqlens = paddle.cumsum(q_actual_seqlens) - q_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), q_cu_seqlens], axis=0) - if not need_kv: - return q_cu_seqlens, None - kv_actual_seqlens = paddle.sum(mask[:, :, 0, :].logical_not(), axis=(-1, -2), dtype="int32") - kv_cu_seqlens = paddle.cumsum(kv_actual_seqlens) - kv_cu_seqlens = paddle.concat([paddle.zeros([1], dtype=paddle.int32), kv_cu_seqlens], axis=0) - return q_cu_seqlens, kv_cu_seqlens - - -def divide(numerator: int, denominator: int) -> int: - """Ensure that numerator is divisible by the denominator and return - the division value.""" - assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}" - return numerator // denominator - - -def save_for_backward_allow_none(ctx, *args) -> None: - """Save tensors for backward. Args could be None""" - indices_mapping = [] - tensors_to_save = [] - for x in args: - if isinstance(x, paddle.Tensor): - indices_mapping.append(len(tensors_to_save)) - tensors_to_save.append(x) - elif x is None: - indices_mapping.append(-1) - else: - raise ValueError(f"Type {type(x)} is not allowed.") - - ctx._indices_mapping = indices_mapping - ctx.save_for_backward(*tensors_to_save) - - -def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]: - """Used with `save_for_backward_allow_none` in pair. Get saved tensors from ctx.""" - assert hasattr( - ctx, "_indices_mapping" - ), "`saved_tensor_allow_none` must be used with `save_for_backward_allow_none` in pair." - - indices_mapping = ctx._indices_mapping - outputs = [] - saved_tensors = ctx.saved_tensor() - - for index in indices_mapping: - if index < 0: - outputs.append(None) - else: - outputs.append(saved_tensors[index]) - - return tuple(outputs) - - -def clear_tensor_data(*tensors: Tuple[Optional[paddle.Tensor], ...]) -> None: - """ - Free tensor buffer - """ - - def can_free(t): - return ( - t is not None - and isinstance(t, paddle.Tensor) - and t._is_initialized() - and t.inplace_version == 0 - ) - - for t in tensors: - if can_free(t): - t._clear_dataptr() diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 781f9d42fd..5f20dbff85 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -7,15 +7,26 @@ # pylint: disable=wrong-import-position,wrong-import-order import logging +import functools +import sys import importlib import importlib.util -import sys -import torch from importlib.metadata import version +from packaging.version import Version as PkgVersion + +import torch from transformer_engine.common import get_te_path, is_package_installed from transformer_engine.common import _get_sys_extension +_logger = logging.getLogger(__name__) + + +@functools.lru_cache(maxsize=None) +def torch_version() -> tuple[int, ...]: + """Get PyTorch version""" + return PkgVersion(str(torch.__version__)).release + def _load_library(): """Load shared library with Transformer Engine C extensions""" @@ -34,15 +45,15 @@ def _load_library(): "TransformerEngine package version mismatch. Found" f" {module_name} v{version(module_name)}, transformer-engine" f" v{version('transformer-engine')}, and transformer-engine-cu12" - f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install" - " transformer-engine[pytorch]==VERSION'" + f" v{version('transformer-engine-cu12')}. Install transformer-engine using " + "'pip3 install transformer-engine[pytorch]==VERSION'" ) if is_package_installed("transformer-engine-cu12"): if not is_package_installed(module_name): - logging.info( - "Could not find package %s. Install transformer-engine using 'pip" - " install transformer-engine[pytorch]==VERSION'", + _logger.info( + "Could not find package %s. Install transformer-engine using " + "'pip3 install transformer-engine[pytorch]==VERSION'", module_name, ) @@ -51,8 +62,12 @@ def _load_library(): so_dir = get_te_path() / "transformer_engine" so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) except StopIteration: - so_dir = get_te_path() - so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) + try: + so_dir = get_te_path() / "transformer_engine" / "wheel_lib" + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) + except StopIteration: + so_dir = get_te_path() + so_path = next(so_dir.glob(f"{module_name}.*.{extension}")) spec = importlib.util.spec_from_file_location(module_name, so_path) solib = importlib.util.module_from_spec(spec) @@ -60,6 +75,9 @@ def _load_library(): spec.loader.exec_module(solib) +assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." + + _load_library() from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import Linear @@ -71,33 +89,25 @@ def _load_library(): from transformer_engine.pytorch.module import initialize_ub from transformer_engine.pytorch.module import destroy_ub from transformer_engine.pytorch.attention import DotProductAttention -from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import MultiheadAttention +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams from transformer_engine.pytorch.transformer import TransformerLayer -from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute +from transformer_engine.pytorch.permutation import ( + moe_permute, + moe_permute_with_probs, + moe_unpermute, + moe_sort_chunks_by_index, + moe_sort_chunks_by_index_with_probs, +) from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.fp8 import fp8_model_init from transformer_engine.pytorch.graph import make_graphed_callables -from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import CudaRNGStatesTracker from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context from transformer_engine.pytorch import ops from transformer_engine.pytorch import optimizers - -# Register custom op symbolic ONNX functions -from transformer_engine.pytorch.te_onnx_extensions import ( - onnx_cast_to_fp8, - onnx_cast_to_fp8_noalloc, - onnx_cast_from_fp8, - onnx_fp8_gelu, - onnx_fp8_relu, - onnx_te_gemm, - onnx_layernorm_fwd_fp8, - onnx_layernorm_fwd, - onnx_rmsnorm_fwd, - onnx_rmsnorm_fwd_fp8, -) +from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy try: torch._dynamo.config.error_on_nested_jit_trace = False diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 6b153fd3c1..afb6b92f04 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -12,41 +12,24 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings import logging -import functools -from dataclasses import dataclass, fields import numpy as np from packaging.version import Version as PkgVersion import torch -import torch.nn.functional as F import transformer_engine_torch as tex -import transformer_engine as te -from transformer_engine.pytorch.utils import get_cudnn_version -from transformer_engine.pytorch.cpp_extensions import ( - cast_to_fp8, - cast_from_fp8, +from transformer_engine.pytorch.utils import ( + get_cudnn_version, + nvtx_range_pop, + nvtx_range_push, ) from transformer_engine.pytorch.cpp_extensions.fused_attn import ( - fused_attn_fwd_qkvpacked, - fused_attn_bwd_qkvpacked, - fused_attn_fwd_kvpacked, - fused_attn_bwd_kvpacked, fused_attn_fwd, fused_attn_bwd, - QKVLayout, - AttnBiasType, - AttnMaskType, FusedAttnBackend, META_QKV, - META_DQKV, META_O, - META_DO, - META_S, - META_DP, - META_O_CP, - META_DQKV_CP, ) from transformer_engine.pytorch.fp8 import ( FP8GlobalStateManager, @@ -54,6 +37,7 @@ get_fp8_torch_dtype, ) from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.utils import ( @@ -82,222 +66,103 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.jit import jit_fuser, no_torch_dynamo from transformer_engine.pytorch.graph import is_graph_capturing +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams +from transformer_engine.pytorch.tensor.quantized_tensor import ( + QuantizedTensor, + prepare_for_saving, + restore_from_saved, +) +# Import attention utils +import transformer_engine.pytorch.dot_product_attention.utils as dpa_utils +from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils as fa_utils +from transformer_engine.pytorch.dot_product_attention.utils import AttentionLogging as attn_log +from transformer_engine.pytorch.dot_product_attention.rope import apply_rotary_pos_emb -# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 -_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) -# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 -_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) -_log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL -_log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} -_log_level = _log_levels[_log_level if _log_level in [0, 1, 2] else 2] -_formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") -_stream_handler = logging.StreamHandler() -_stream_handler.setFormatter(_formatter) -fa_logger = logging.getLogger() -fa_logger.setLevel(_log_level) -if not fa_logger.hasHandlers(): - fa_logger.addHandler(_stream_handler) - - -@functools.lru_cache(maxsize=None) -def _get_supported_versions(version_min, version_max): - return ">= " + str(version_min) + ", " + "<= " + str(version_max) - - -_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) -_NVTE_FUSED_ATTN = int(os.getenv("NVTE_FUSED_ATTN", "1")) -_NVTE_UNFUSED_ATTN = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - -# Detect flash-attn v2 in the environment -_flash_attn_is_installed = False -_flash_attn_version = PkgVersion("0") -_flash_attn_version_required = PkgVersion("2.1.1") -_flash_attn_max_version = PkgVersion("2.6.3") -_flash_attn_2_plus = False -_flash_attn_2_1_plus = False -_flash_attn_2_3_plus = False -_flash_attn_2_4_plus = False -_flash_attn_2_4_1_plus = False -_flash_attn_2_5_7_plus = False -_flash_attn_2_6_0_plus = False +# Setup Attention Logging +attn_log.setup_logging() + +# Global vars for flash attn v2 and v3 imports +flash_attn_cuda_bwd = None flash_attn_func = None flash_attn_varlen_func = None -flash_attn_varlen_fwd = None -flash_attn_varlen_bwd = None -flash_attn_cuda_bwd = None - +_flash_attn_fwd = None +_flash_attn_bwd = None +_flash_attn_varlen_fwd = None +_flash_attn_varlen_bwd = None try: - _flash_attn_version = PkgVersion(get_pkg_version("flash-attn")) + fa_utils.version = PkgVersion(get_pkg_version("flash-attn")) except PackageNotFoundError: - if torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN: - fa_logger.debug( - "flash-attn v2 is not installed. To use, please install it by" - """ "pip install flash-attn".""", - ) + pass # only print warning if use_flash_attention_2 = True in get_attention_backend else: - if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version: + if torch.cuda.is_available() and get_device_compute_capability() >= (10, 0): + if fa_utils.version_required_blackwell <= fa_utils.version <= fa_utils.max_version: + fa_utils.is_installed = True + elif fa_utils.version_required <= fa_utils.version <= fa_utils.max_version: + fa_utils.is_installed = True + + if fa_utils.is_installed: + from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func + from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd + from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd from flash_attn.flash_attn_interface import ( - _flash_attn_varlen_forward as flash_attn_varlen_fwd, + _flash_attn_varlen_forward as _flash_attn_varlen_fwd, ) from flash_attn.flash_attn_interface import ( - _flash_attn_varlen_backward as flash_attn_varlen_bwd, + _flash_attn_varlen_backward as _flash_attn_varlen_bwd, ) - from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd - _flash_attn_is_installed = True - _flash_attn_2_plus = _flash_attn_version >= PkgVersion("2") - _flash_attn_2_1_plus = _flash_attn_version >= PkgVersion("2.1") - _flash_attn_2_3_plus = _flash_attn_version >= PkgVersion("2.3") - _flash_attn_2_4_plus = _flash_attn_version >= PkgVersion("2.4") - _flash_attn_2_4_1_plus = _flash_attn_version >= PkgVersion("2.4.1") - _flash_attn_2_5_7_plus = _flash_attn_version >= PkgVersion("2.5.7") - _flash_attn_2_6_0_plus = _flash_attn_version >= PkgVersion("2.6.0") + # Setup Flash attention utils + fa_utils.set_flash_attention_version() elif ( - torch.cuda.is_available() and get_device_compute_capability() >= (8, 0) and _NVTE_FLASH_ATTN + torch.cuda.is_available() + and get_device_compute_capability() >= (8, 0) + and dpa_utils._NVTE_FLASH_ATTN ): - fa_logger.warning( + attn_log.fa_logger.warning( "Supported flash-attn versions are %s. Found flash-attn %s.", - _get_supported_versions( - _flash_attn_version_required, - _flash_attn_max_version, + dpa_utils._get_supported_versions( + ( + fa_utils.version_required + if get_device_compute_capability() < (10, 0) + else fa_utils.version_required_blackwell + ), + fa_utils.max_version, ), - _flash_attn_version, + fa_utils.version, ) - -# Detect flash-attn v3 in the environment -# This section will be removed when FA3 is released as a regular FA package, -# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0 -_flash_attn_3_is_installed = False -_flash_attn_3_version = PkgVersion("0") -_flash_attn_3_0_0_beta = False -_use_flash_attn_3 = False -_flash_attn_3_installation_steps = """\ -(1) pip install "git+https://github.com/Dao-AILab/flash-attention.git#egg=flashattn-hopper&subdirectory=hopper" -(2) python_path=`python -c "import site; print(site.getsitepackages()[0])"` -(3) mkdir -p $python_path/flashattn_hopper -(4) wget -P $python_path/flashattn_hopper https://raw.githubusercontent.com/Dao-AILab/flash-attention/main/hopper/flash_attn_interface.py""" try: - _flash_attn_3_version = PkgVersion(get_pkg_version("flashattn-hopper")) + fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3")) except PackageNotFoundError: - if torch.cuda.is_available() and get_device_compute_capability() >= (9, 0) and _NVTE_FLASH_ATTN: - fa_logger.debug( - "flash-attn v3 is not installed. To use, please install it by \n%s", - _flash_attn_3_installation_steps, - ) + pass # only print warning if use_flash_attention_3 = True in get_attention_backend else: - from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3 - from flashattn_hopper.flash_attn_interface import ( + from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3 + from flash_attn_3.flash_attn_interface import ( flash_attn_varlen_func as flash_attn_varlen_func_v3, ) - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_varlen_forward as flash_attn_varlen_fwd_v3, - ) - from flashattn_hopper.flash_attn_interface import ( - _flash_attn_varlen_backward as flash_attn_varlen_bwd_v3, + from flash_attn_3.flash_attn_interface import ( + flash_attn_with_kvcache as flash_attn_with_kvcache_v3, ) + from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3 + from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3 - _flash_attn_3_is_installed = True - _flash_attn_3_0_0_beta = PkgVersion("3.0.0b") < _flash_attn_3_version < PkgVersion("3.0.0") - _use_flash_attn_3 = True + fa_utils.set_flash_attention_3_params() +# Global vars for available attention backends and ALiBi cache _attention_backends = { "attention_params": None, "use_flash_attention": None, + "flash_attention_backend": None, "use_fused_attention": None, "fused_attention_backend": None, "use_unfused_attention": None, "backend_selection_requires_update": False, } - -@dataclass(eq=True) -class AttentionParams: - """ - Attention parameters used to determine which backend to be used. - - Parameters - ---------- - qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor` - Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}. - qkv_dtype: torch.dtype, default = `torch.bfloat16` - Data type of query/key/value tensors. - qkv_layout: str, default = "sbh3d" - Query/key/value tensor memory layout. - batch_size: int, default = 1 - Batch size. - num_heads: int, default = 16 - Number of attention heads in the query tensor. - num_gqa_groups: int, default = 16 - Number of attention heads in key and value tensors. - max_seqlen_q: int, default = 128 - Maximum sequence length of the query tensor. - max_seqlen_kv: int, default = 128 - Maximum sequence length of the key and value tensors. - head_dim_qk: int, default = 64 - The size of each attention head in query and key tensors. - head_dim_v: int, default = 64 - The size of each attention head in the value tensor. - attn_mask_type: str, default = `no_mask` - Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`, - `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} - window_size: Tuple[int, int], default = None - Sliding window attention size. - alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None` - Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. - core_attention_bias_type: str, default = `no_bias` - Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}. - core_attention_bias_shape: str, default = `1hss` - Attention bias shape, {`1hss`, `b1ss`, `bhss`}. - core_attention_bias_requires_grad: bool, default = `True` - Whether attention bias requires gradient. - pad_between_seqs: bool, default = `False` - Whether there is padding between sequences in a batch. - This only applies to `qkv_format=thd`. - attention_dropout: float, default = 0.0 - Attention dropout. - context_parallel: bool, default = `False` - Whether context parallelism is used or not. - deterministic: bool, default = `False` - Whether to run `DotProductAttention` with determinism or not. - is_training: bool, default = `True` - Whether in training mode (`True`) or inference mode (`False`) - fp8: bool, default = `False` - Whether `DotProductAttention` is in an `fp8_autocast` region. - fp8_meta: Optional[Dict[str Any]], default = `None` - The FP8 metadata tensor of `DotProductAttention`. - """ - - qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor - qkv_dtype: torch.dtype = torch.bfloat16 - qkv_layout: str = "sbh3d" - batch_size: int = 1 - num_heads: int = 16 - num_gqa_groups: int = 16 - max_seqlen_q: int = 128 - max_seqlen_kv: int = 128 - head_dim_qk: int = 64 - head_dim_v: int = 64 - attn_mask_type: str = "no_mask" - window_size: Union[Tuple[int, int], None] = None - alibi_slopes_shape: Union[torch.Size, List, None] = None - core_attention_bias_type: str = "no_bias" - core_attention_bias_shape: str = "1hss" - core_attention_bias_requires_grad: bool = True - pad_between_seqs: bool = False - attention_dropout: float = 0.0 - context_parallel: bool = False - deterministic: bool = False - is_training: bool = True - fp8: bool = False - fp8_meta: Union[Dict[str, Any], None] = None - - _alibi_cache = { "_num_heads": None, "_alibi_slopes": None, @@ -309,8 +174,7 @@ class AttentionParams: "_alibi_bias_require_update": False, } - -__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"] +__all__ = ["DotProductAttention", "MultiheadAttention"] def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor: @@ -318,1119 +182,6 @@ def maybe_contiguous(tensor: torch.Tensor) -> torch.Tensor: return tensor.contiguous() if tensor.stride(-1) != 1 else tensor -def get_attention_backend( - attention_params: AttentionParams = None, -): - """ - Select the appropriate attention backend/sub-backend based on user input and runtime environment. - - Parameters - ---------- - See `AttentionParams`. - - Returns - ---------- - use_flash_attention: bool - Whether the `FlashAttention` backend has been selected. - use_fused_attention: bool - Whether the `FusedAttention` backend has been selected. - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`. - use_unfused_attention: bool - Whether the `UnfusedDotProductAttention` backend has been selected. - available_backends: List[bool] - All available backends that could support the provided input. A list of Booleans - in the form of [use_flash_attention, use_fused_attention, use_unfused_attention]. - """ - qkv_type = attention_params.qkv_type - qkv_dtype = attention_params.qkv_dtype - qkv_layout = attention_params.qkv_layout - batch_size = attention_params.batch_size - num_heads = attention_params.num_heads - num_gqa_groups = attention_params.num_gqa_groups - max_seqlen_q = attention_params.max_seqlen_q - max_seqlen_kv = attention_params.max_seqlen_kv - head_dim_qk = attention_params.head_dim_qk - head_dim_v = attention_params.head_dim_v - attn_mask_type = attention_params.attn_mask_type - window_size = attention_params.window_size - alibi_slopes_shape = attention_params.alibi_slopes_shape - core_attention_bias_type = attention_params.core_attention_bias_type - core_attention_bias_shape = attention_params.core_attention_bias_shape - core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad - pad_between_seqs = attention_params.pad_between_seqs - attention_dropout = attention_params.attention_dropout - context_parallel = attention_params.context_parallel - deterministic = attention_params.deterministic - is_training = attention_params.is_training - fp8 = attention_params.fp8 - fp8_meta = attention_params.fp8_meta - - # Run config - logger = logging.getLogger("DotProductAttention") - logger.setLevel(_log_level) - if not logger.hasHandlers(): - logger.addHandler(_stream_handler) - device_compute_capability = get_device_compute_capability() - cudnn_version = get_cudnn_version() - run_config = { - "transformer_engine_version": te.__version__, - "compute_capability": "sm" - + str(10 * device_compute_capability[0] + device_compute_capability[1]), - "flash_attn_version": ( - str(_flash_attn_version) if _flash_attn_is_installed else "not installed" - ), - "flash_attn_3_version": ( - str(_flash_attn_3_version) if _flash_attn_3_is_installed else "not installed" - ), - "cudnn_version": ".".join([str(i) for i in cudnn_version]), - } - attention_params_dict = { - field.name: getattr(attention_params, field.name) for field in fields(attention_params) - } - run_config.update(attention_params_dict) - if fp8: - run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - logger.debug("Running with config=%s", run_config) - - # The following sections check if `FlashAttention` supports the provided attention params, - # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is - # necessary for performance/functionality, a warning will be issued to prompt users to - # install an appropriate FA version. - global _flash_attn_version_required, _flash_attn_max_version, _use_flash_attn_3 - - # Filter: Environment variables - use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) - use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) - use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) - if not use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention due to NVTE_FLASH_ATTN=0") - if not use_fused_attention: - logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") - if not use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") - - # Filter: ONNX mode - if is_in_onnx_export_mode(): - if use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention due to ONNX mode") - use_flash_attention = False - if use_fused_attention: - logger.debug("Disabling FusedAttention due to ONNX mode") - use_fused_attention = False - - # Filter: Compute capability - if device_compute_capability < (8, 0): - if use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention as it requires compute capability sm80+") - use_flash_attention = False - if use_fused_attention: - logger.debug("Disabling FusedAttention as it requires compute capability sm80+") - use_fused_attention = False - if device_compute_capability < (9, 0): - if use_flash_attention and _flash_attn_3_is_installed: - logger.debug("Disabling FlashAttention 3 as it requires compute capability sm90+") - _use_flash_attn_3 = False - - # Filter: Data type - if qkv_dtype not in [torch.bfloat16, torch.float16] or qkv_type not in [ - torch.Tensor, - Float8Tensor, - ]: - if use_flash_attention and _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_dtype = %s.", - qkv_dtype, - ) - use_flash_attention = False - if use_fused_attention: - logger.debug( - "Disabling FusedAttention due to unsupported QKV data type. " - "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. " - "Found: qkv_dtype = %s.", - qkv_dtype, - ) - use_fused_attention = False - - # Filter: Execution type - if fp8 and fp8_meta["recipe"].fp8_dpa: - if use_flash_attention and not _use_flash_attn_3: - if _flash_attn_is_installed: - logger.debug("Disabling FlashAttention as FlashAttention 2 does not support FP8") - use_flash_attention = False - if use_flash_attention and _use_flash_attn_3 and is_training: - logger.debug( - "Disabling FlashAttention as FlashAttention 3 does not support FP8 training" - ) - use_flash_attention = False - if use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8") - use_unfused_attention = False - - # Filter: Head dimension - if use_flash_attention and head_dim_qk != head_dim_v: - if _flash_attn_is_installed: - logger.debug("Disabling FlashAttention as it does not support MLA.") - use_flash_attention = False - if use_flash_attention and ( - head_dim_qk > 256 - or head_dim_qk % 8 != 0 - or (head_dim_qk > 192 and device_compute_capability not in ((8, 0), (9, 0))) - ): - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention due to unsupported head_dim_qk and head_dim_v. " - "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " - "head_dim_qk <= 256 (>192 requires sm80/90). " - "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", - head_dim_qk, - head_dim_v, - ".".join([str(i) for i in device_compute_capability]), - ) - use_flash_attention = False - qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") - if use_fused_attention and head_dim_qk != head_dim_v and qkv_layout_group != "hd_hd_hd": - logger.debug( - "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", - qkv_layout, - ) - use_fused_attention = False - - # Filter: QKV layout - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - if qkv_format == "thd": - if use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") - use_unfused_attention = False - if use_flash_attention and pad_between_seqs: - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention for qkv_format = thd when there is " - "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" - ) - use_flash_attention = False - - # Filter: Dropout - if attention_dropout != 0.0 and use_flash_attention and _use_flash_attn_3: - logger.debug("Disabling FlashAttention 3 for dropout") - _use_flash_attn_3 = False - - # Filter: Context parallelism - # qkv_format | attn_mask_type | attn_bias_type | supported backends - # ---------------------------------------------------------------------------------------------------- - # bshd, sbhd | self-attention: | no_bias, post_scale_bias | FlashAttention, FusedAttention - # | no_mask, causal | | - # | cross-attention: | | - # | no_mask | | - # thd | self-attention: | no_bias | FlashAttention, FusedAttention - # | padding, padding_causal | | if no padding between sequences, - # | cross-attention: | | FusedAttention - # | padding | | if there is padding between sequences - # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v. - if context_parallel and use_unfused_attention: - logger.debug( - "Disabling UnfusedDotProductAttention as it does not support context parallelism" - ) - use_unfused_attention = False - if context_parallel and use_flash_attention: - if fp8 and fp8_meta["recipe"].fp8_dpa: - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with FP8" - ) - use_flash_attention = False - if "bottom_right" in attn_mask_type: - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with" - " causal_bottom_right masking" - ) - use_flash_attention = False - elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with" - " causal masking for cross-attention" - ) - use_flash_attention = False - elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with bias" - " type of %s", - core_attention_bias_type, - ) - use_flash_attention = False - elif qkv_format == "thd" and core_attention_bias_type != "no_bias": - if _flash_attn_is_installed: - logger.debug( - "Disabling FlashAttention as it does not support context parallelism with" - " attention bias for THD format" - ) - use_flash_attention = False - - if context_parallel and use_fused_attention: - if "bottom_right" in attn_mask_type: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with" - " causal_bottom_right masking" - ) - use_fused_attention = False - elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with causal" - " masking for cross-attention" - ) - use_fused_attention = False - elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with bias type" - " of %s", - core_attention_bias_type, - ) - use_fused_attention = False - elif qkv_format == "thd" and core_attention_bias_type != "no_bias": - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with attention" - " bias for THD format" - ) - use_fused_attention = False - elif head_dim_qk != head_dim_v: - logger.debug( - "Disabling FusedAttention as it does not support context parallelism with MLA" - ) - use_fused_attention = False - - # Filter: Attention mask - # attn_mask_type | attention_mask | supported backends - # ---------------------------------------------------------------------------------------- - # no_mask | None | All - # padding | | All - # self-attention | One tensor in shape [b, 1, 1, sq] | - # cross-attention | Tuple of two tensors in shapes | - # | [b, 1, 1, sq] and [b, 1, 1, skv] | - # causal | None | - # self-attention | | All - # cross-attention | | FusedAttention, UnfusedDotProductAttention - # padding_causal | Same as "padding" | - # self-attention | | All - # cross-attention | | FusedAttention, UnfusedDotProductAttention - # causal_bottom_right | None | All - # padding_causal_bottom_right | Same as "padding" | - # self-attention | | All - # cross-attention | | FlashAttention, UnfusedDotProductAttention - # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention - # | [b, h, sq, skv] | - if attn_mask_type == "arbitrary": - if use_flash_attention and _flash_attn_is_installed: - logger.debug("Disabling FlashAttention for arbitrary mask") - use_flash_attention = False - if use_fused_attention: - logger.debug("Disabling FusedAttention for arbitrary mask") - use_fused_attention = False - if ( - use_flash_attention - and _use_flash_attn_3 - and attn_mask_type in ["causal", "padding_causal"] - and max_seqlen_q != max_seqlen_kv - ): - logger.warning( - "Disabling FlashAttention 3 as it only supports bottom-right-diagonal " - "causal mask since flash-attn 2.1. See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - _use_flash_attn_3 = False - if ( - use_flash_attention - and attn_mask_type in ["causal", "padding_causal"] - and max_seqlen_q != max_seqlen_kv - ): - if _flash_attn_2_1_plus: - logger.warning( - "Disabling FlashAttention as it only supports bottom-right-diagonal " - "causal mask since flash-attn 2.1. See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - use_flash_attention = False - if not _flash_attn_is_installed: - _flash_attn_max_version = PkgVersion("2.1") - if ( - use_flash_attention - and attn_mask_type in ["causal_bottom_right", "padding_causal_bottom_right"] - and max_seqlen_q != max_seqlen_kv - ): - if not _flash_attn_is_installed: - _flash_attn_version_required = PkgVersion("2.1") - elif not _flash_attn_2_1_plus and not _use_flash_attn_3: - logger.warning( - "Disabling FlashAttention as it only supports top-left-diagonal " - "causal mask before flash-attn 2.1. See " - "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" - ) - use_flash_attention = False - if ( - use_flash_attention - and _use_flash_attn_3 - and fp8 - and fp8_meta["recipe"].fp8_dpa - and "padding" in attn_mask_type - ): - logger.debug("Disabling FlashAttention 3 for FP8 and padding masks") - _use_flash_attn_3 = False - - # Filter: Sliding window attention - # backend | window_size | diagonal alignment - # --------------------------------------------------------------------------------- - # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right - # FusedAttention | (-1, 0) or (>=0, 0) | top left - # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; - # | | converts window_size to an 'arbitrary' mask - if window_size is None: - window_size = check_set_window_size(attn_mask_type, window_size) - else: - if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention" - " for FP8" - ) - use_fused_attention = False - elif window_size[1] != 0 or attention_dropout != 0.0 or qkv_format == "thd": - logger.debug( - "Disabling FusedAttention as it only supports sliding window attention " - "with causal mask, no dropout, and qkv_format = bshd/sbhd" - ) - use_fused_attention = False - elif max_seqlen_q != max_seqlen_kv and attn_mask_type in [ - "no_mask", - "padding", - "causal_bottom_right", - "padding_causal_bottom_right", - ]: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with attn_mask_type = %s for cross-attention", - attn_mask_type, - ) - use_fused_attention = False - elif "padding" in attn_mask_type: - logger.debug( - "Disabling FusedAttention as it does not support sliding window attention " - "with attn_mask_type = %s", - attn_mask_type, - ) - use_fused_attention = False - if use_flash_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): - if _use_flash_attn_3: - logger.debug( - "Disabling FlashAttention 3 as it does not support sliding window attention" - ) - _use_flash_attn_3 = False - if not _flash_attn_is_installed: - _flash_attn_version_required = PkgVersion("2.3") - elif not _flash_attn_2_3_plus: - logger.debug( - "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" - ) - use_flash_attention = False - - # Filter: Attention bias - # backend | bias types | ALiBi diagonal alignment - # --------------------------------------------------------------------------------- - # FlashAttention | no_bias, alibi/alibi_slopes | bottom right - # FusedAttention | no_bias, post_scale_bias | - # | alibi/alibi_slopes | top left, - # | | bottom_right (converts to a 'post_scale_bias' bias) - # UnfusedDotProductAttention | no_bias, pre/post_scale_bias | - # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias - if use_flash_attention and core_attention_bias_type == "alibi": - if _use_flash_attn_3: - logger.debug("Disabling FlashAttention 3 for ALiBi") - _use_flash_attn_3 = False - if not _flash_attn_is_installed: - _flash_attn_version_required = PkgVersion("2.4") - elif not _flash_attn_2_4_plus: - logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") - use_flash_attention = False - - if use_flash_attention and ( - core_attention_bias_type not in ["no_bias", "alibi"] - or core_attention_bias_shape is not None - ): - if _flash_attn_is_installed: - logger.debug("Disabling FlashAttention for pre/post_scale_bias") - use_flash_attention = False - - fu_core_attention_bias_type = core_attention_bias_type - fu_core_attention_bias_shape = core_attention_bias_shape - fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad - if ( - use_fused_attention - and core_attention_bias_type == "alibi" - and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv) - ): - fu_core_attention_bias_type = "post_scale_bias" - fu_core_attention_bias_requires_grad = False - if alibi_slopes_shape is None: - fu_core_attention_bias_shape = "1hss" - elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: - fu_core_attention_bias_shape = "1hss" - elif ( - len(alibi_slopes_shape) == 2 - and alibi_slopes_shape[0] == batch_size - and alibi_slopes_shape[1] == num_heads - ): - fu_core_attention_bias_shape = "bhss" - - if ( - use_fused_attention - and fu_core_attention_bias_type == "post_scale_bias" - and fu_core_attention_bias_shape != "1hss" - ): - if fu_core_attention_bias_requires_grad: - # remove this line when cuDNN adds bwd support for - # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] - logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape") - use_fused_attention = False - else: - # max512 backend will only support [1, h, s, s] - os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" - - # Filter: cuDNN support - fused_attention_backend = None - if use_fused_attention: - q_type = TE_DType[qkv_dtype] - kv_type = q_type - if fp8 and fp8_meta["recipe"].fp8_dpa: - q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - kv_type = q_type - fused_attention_backend = tex.get_fused_attn_backend( - q_type, - kv_type, - QKVLayout[qkv_layout], - AttnBiasType[fu_core_attention_bias_type], - AttnMaskType[attn_mask_type], - attention_dropout, - num_heads, - num_gqa_groups, - max_seqlen_q, - max_seqlen_kv, - head_dim_qk, - head_dim_v, - window_size[0], - window_size[1], - ) - if fused_attention_backend == FusedAttnBackend["No_Backend"]: - logger.debug("Disabling FusedAttention as no backend supports the provided input") - use_fused_attention = False - fused_attention_backend = None - if ( - use_fused_attention - and window_size is not None - and window_size[0] != -1 - and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] - ): - logger.debug( - "Disabling FusedAttention as only sub-backend %s does not support " - "slidng window attention", - int(fused_attention_backend), - ) - use_fused_attention = False - fused_attention_backend = None - if ( - use_fused_attention - and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - and fu_core_attention_bias_type == "post_scale_bias" - and fu_core_attention_bias_shape != "1hss" - ): - logger.debug( - "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in" - " [1, H, S, S] shape" - ) - use_fused_attention = False - fused_attention_backend = None - - # Filter: Determinism - # backend | deterministic - # --------------------------------------------- - # FlashAttention | - # flash-attn >=2.0, <2.4.1 | no - # flash-attn >=2.4.1 | yes - # FusedAttention | - # sub-backend 0 | yes - # sub-backend 1 | workspace optimization path and sm90+: yes; - # | otherwise: no - # sub-backend 2 | no - # UnfusedDotProductAttention | yes - if use_flash_attention and deterministic: - if not _flash_attn_is_installed: - _flash_attn_version_required = PkgVersion("2.4.1") - elif not _flash_attn_2_4_1_plus and not _use_flash_attn_3: - logger.warning( - "Disabling FlashAttention as version <2.4.1 does not support deterministic " - "execution. To use FlashAttention with deterministic behavior, " - "please install flash-attn >= 2.4.1." - ) - use_flash_attention = False - if use_fused_attention and deterministic: - if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: - logger.debug("Disabling FusedAttention for determinism reasons") - use_fused_attention = False - if ( - fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - and is_training - and ( - device_compute_capability < (9, 0) - or core_attention_bias_requires_grad - or cudnn_version < (8, 9, 5) - ) - ): - logger.debug("Disabling FusedAttention for determinism reasons") - use_fused_attention = False - - # All available backends - available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] - - # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`. - # When `FusedAttention` does not support the provided attention params, and `FlashAttention` - # does, we recommend users to install flash-attn if not installed already. - if not use_fused_attention and use_flash_attention and not _flash_attn_is_installed: - logger.warning( - "flash-attn may provide important feature support or performance improvement." - " Please install flash-attn %s.", - _get_supported_versions( - _flash_attn_version_required, - _flash_attn_max_version, - ), - ) - if use_flash_attention and not _flash_attn_is_installed: - use_flash_attention = False - available_backends[0] = False - - logger.debug( - "Available backends = {FlashAttention=%s, FusedAttention=%s%s," - " UnfusedDotProductAttention=%s}", - bool(available_backends[0]), - bool(available_backends[1]), - ( - f" (sub-backend {int(fused_attention_backend)})" - if fused_attention_backend is not None - else "" - ), - bool(available_backends[2]), - ) - - # Select FusedAttention for performance - if ( - use_flash_attention - and use_fused_attention - and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - ): - if device_compute_capability == (9, 0): - logger.debug( - "Disabling FlashAttention to give FusedAttention preference on Hopper+ " - "for performance reasons" - ) - use_flash_attention = False - if ( - use_flash_attention - and use_fused_attention - and fused_attention_backend == FusedAttnBackend["FP8"] - and _use_flash_attn_3 - ): - logger.debug( - "Disabling FlashAttention 3 to give FusedAttention preference for performance reasons " - "in FP8 execution" - ) - use_flash_attention = False - - # Selected backend - if use_flash_attention: - use_fused_attention = False - use_unfused_attention = False - elif use_fused_attention: - use_unfused_attention = False - selected_backend = "NoBackend" - if use_flash_attention: - selected_backend = "FlashAttention" - elif use_fused_attention: - selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})" - elif use_unfused_attention: - selected_backend = "UnfusedDotProductAttention" - logger.debug("Selected backend = %s", selected_backend) - - global _attention_backends - _attention_backends["use_flash_attention"] = use_flash_attention - _attention_backends["use_fused_attention"] = use_fused_attention - _attention_backends["fused_attention_backend"] = fused_attention_backend - _attention_backends["use_unfused_attention"] = use_unfused_attention - _attention_backends["backend_selection_requires_update"] = False - - return ( - use_flash_attention, - use_fused_attention, - fused_attention_backend, - use_unfused_attention, - available_backends, - ) - - -class InferenceParams: # pylint: disable=too-few-public-methods - """ - Inference parameters that are passed to the main model in order - to efficiently calculate and store the context during inference. - - Parameters - ---------- - max_batch_size : int - maximum batch size during inference. - max_sequence_length : int - maximum sequence length during inference. - """ - - def __init__(self, max_batch_size, max_sequence_length): - self.max_sequence_length = max_sequence_length - self.max_batch_size = max_batch_size - self.sequence_len_offset = 0 - self.batch_size_offset = 0 - self.key_value_memory_dict = {} - - def swap_key_value_dict(self, batch_indices): - """ - Reorders the KV cache using the specified batch indices. - - Parameters - ---------- - batch_indices : List[int] - Sequence of indices to reorder along the batch dimensions of - the KV cache. Must have a length equal to the batch size. - """ - if len(self.key_value_memory_dict) == 0: - raise ValueError("should not swap when dict in empty") - - for layer_number, inference_memory in self.key_value_memory_dict.items(): - inference_key_memory, inference_value_memory = inference_memory - assert ( - len(batch_indices) == inference_key_memory.shape[1] - ) # make sure batch size is the same - new_inference_key_memory = inference_key_memory[:, batch_indices] - new_inference_value_memory = inference_value_memory[:, batch_indices] - self.key_value_memory_dict[layer_number] = ( - new_inference_key_memory, - new_inference_value_memory, - ) - - -@torch.no_grad() -def get_swa_mask( - window_size: Tuple[int, int], - max_seqlen_q: int, - max_seqlen_kv: int, - attn_mask_type: str = "no_mask", - attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, -) -> torch.Tensor: - """ - Convert sliding window `window_size` to an equivalent "`arbitrary`" mask. - For "`causal`" mask type, the sliding window diagonal is aligned to the top left corner, - and for other mask types, the bottom right corner. - - Parameters - ---------- - window_size: Tuple[int, int] - Sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. Both `causal` and `causal_bottom_right` masks - map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on - `attn_mask_type`. - max_seqlen_q: int - Maximum sequence length for queries. - max_seqlen_kv: int - Maximum sequence length for keys and values. - attn_mask_type: str, default = `no_mask` - Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`", - "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"} - attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]], - default = `None` - Boolean tensor(s) used to mask out attention softmax input. - - Returns - ---------- - attention_mask: torch.Tensor - Combined `attention_mask` (input) and sliding window attention mask. - The shape is [max_seqlen_q, max_seqlen_kv] when input `attention_mask` is None; - else, the same shape as input `attention_mask`. - """ - mask = torch.ones(max_seqlen_q, max_seqlen_kv, dtype=torch.bool, device="cuda") - if attn_mask_type in ["causal"]: - left = window_size[0] if window_size[0] != -1 else max_seqlen_q - right = window_size[1] if window_size[1] != -1 else max_seqlen_q - mask_upper = torch.triu(mask, diagonal=-left) - mask_lower = torch.tril(mask_upper, diagonal=right) - else: - left = window_size[0] if window_size[0] != -1 else max_seqlen_kv - right = window_size[1] if window_size[1] != -1 else max_seqlen_kv - mask_upper = torch.triu(mask, diagonal=max_seqlen_kv - max_seqlen_q - left) - mask_lower = torch.tril(mask_upper, diagonal=max_seqlen_kv - max_seqlen_q + right) - attn_mask_type = "arbitrary" - mask = mask_lower.logical_not() - if attention_mask is not None: - mask = torch.logical_and(attention_mask, mask) - return attn_mask_type, mask - - -@torch.no_grad() -def get_alibi( - num_heads: int, - max_seqlen_q: int, - max_seqlen_kv: int, - actual_seqlens_q: Optional[torch.Tensor] = None, - actual_seqlens_kv: Optional[torch.Tensor] = None, - alibi_slopes: Optional[torch.Tensor] = None, - bias_dtype: Optional[torch.dtype] = None, - bottom_right_alignment: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Parameters - ---------- - num_heads: int - Number of heads. - max_seqlen_q: int - Maximum sequence length for queries. - max_seqlen_kv: int - Maximum sequence length for keys and values. - actual_seqlens_q: Optional[torch.Tensor], default = `None` - Actual sequence lengths for queries, in shape [batch_size]. - actual_seqlens_kv: Optional[torch.Tensor], default = `None` - Actual sequence lengths for keys and values, in shape [batch_size]. - alibi_slopes: Optional[torch.Tensor], default = `None` - Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. - bias_dtype: Optional[torch.dtype], default = `None` - Dtype of the generated ALiBi bias. If None, use torch.float32. - bottom_right_alignment: bool, default = `True` - Whether to align the diagonal of the ALiBi bias to the bottom right corner of - the matrix (`True`) or top left (`False`). - - Returns - ---------- - alibi_slopes: torch.Tensor - ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads]. - alibi_bias: torch.Tensor - ALiBi bias in FP32 or `bias_dtype`. Its shape is - (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape, - and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or - (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in - [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and - `actual_seqlens_q` and `actual_seqlens_kv` are not `None`. - """ - global _alibi_cache - if _alibi_cache["_alibi_slopes_require_update"]: - if alibi_slopes is not None: - _alibi_cache["_alibi_slopes"] = alibi_slopes - else: - n = 2 ** math.floor(math.log2(num_heads)) - m_0 = 2.0 ** (-8.0 / n) - m = torch.pow(m_0, torch.arange(1, 1 + n)) - - if n < num_heads: - m_hat_0 = 2.0 ** (-4.0 / n) - m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2)) - m = torch.cat([m, m_hat]) - - _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda") - _alibi_cache["_num_heads"] = num_heads - _alibi_cache["_alibi_slopes_require_update"] = False - - if _alibi_cache["_alibi_bias_require_update"]: - assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!" - if _alibi_cache["_alibi_slopes"].dim() == 1: - slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) - elif _alibi_cache["_alibi_slopes"].dim() == 2: - slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) - else: - raise ValueError("ALiBi slopes cannot exceed 2 dimensions.") - - bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( - 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv - ) - if actual_seqlens_q is None and actual_seqlens_kv is None: - if bottom_right_alignment: - bias = bias + max_seqlen_kv - max_seqlen_q - elif actual_seqlens_q is not None and actual_seqlens_kv is not None: - batch_size = actual_seqlens_q.shape[0] - bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) - if bottom_right_alignment: - bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) - else: - assert ( - False - ), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!" - bias = bias.abs().mul(-1) - bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape) - _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv - _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment - bias_dtype = torch.float32 if bias_dtype is None else bias_dtype - _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda") - _alibi_cache["_alibi_bias_require_update"] = False - - return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"] - - -def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor: - """ - Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32 - tensor of shape [batch_size + 1] containing the cumulative sequence lengths of - the samples in a batch. - """ - mask = mask.squeeze(1).squeeze(1) - reduced_mask = mask.logical_not().sum(dim=1) - cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) - zero = torch.zeros(1, dtype=torch.int32, device="cuda") - cu_seqlens = torch.cat((zero, cu_seqlens)) - - return cu_seqlens - - -def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32 - tensor of shape [batch_size + 1] containing the cumulative sequence lengths of - the samples in a batch, and another int32 tensor of shape [batch_size * max_seqlen, 1, 1] - containing the indices for the valid tokens. - """ - mask = mask.squeeze(1).squeeze(1) - bs, seqlen = mask.shape - - reduced_mask = mask.logical_not().sum(dim=1) - cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) - zero = torch.zeros(1, dtype=torch.int32, device="cuda") - cu_seqlens = torch.cat((zero, cu_seqlens)) - - mask = mask.reshape(-1) - indices = mask.logical_not().nonzero() - indices = indices.unsqueeze(-1) - - num_nonzeros = indices.shape[0] - pad_amount = bs * seqlen - num_nonzeros - indices = F.pad( - input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen) - ) - - return cu_seqlens, indices - - -def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor: - """ - Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int32 - tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for - the valid tokens in a batch. - """ - bs = len(cu_seqlens) - 1 - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - indices = [i * max_seqlen + ii for i, j in enumerate(seqlens) for ii in range(j)] - indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(dtype=torch.int64, device="cuda") - - num_nonzeros = indices.shape[0] - pad_amount = bs * max_seqlen - num_nonzeros - indices = F.pad( - input=indices, - pad=(0, 0, 0, 0, 0, pad_amount), - mode="constant", - value=float(bs * max_seqlen), - ) - - return indices - - -_cu_seqlens_cache = {} - - -def _get_full_cu_seqlens( - batch_size: int, - max_seqlen: int, - device: torch.device, -) -> torch.Tensor: - """Cumulative sequence lengths in full data batch - - All sequences in batch have the maximum sequence length. - - """ - global _cu_seqlens_cache - if (batch_size, max_seqlen) not in _cu_seqlens_cache: - _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange( - 0, - (batch_size + 1) * max_seqlen, - step=max_seqlen, - dtype=torch.int32, - device=device, - ) - return _cu_seqlens_cache[(batch_size, max_seqlen)] - - -@torch.compile -def pack_tensor( - indices: torch.Tensor, - tensor: torch.Tensor, -) -> torch.Tensor: - """ - Packs the given tensor using the `indices`. - """ - padding_indice = torch.zeros( - 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device - ) - indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) - if isinstance(tensor, Float8Tensor): - tensor_data = torch.cat((tensor._data, padding_indice), dim=0) - - packed = Float8Tensor.make_like(tensor, data=torch.gather(tensor_data, 0, indices)) - else: - tensor = torch.cat((tensor, padding_indice), dim=0) - - packed = torch.gather(tensor, 0, indices) - return packed - - -@torch.compile -def pack_2_tensors( - indices: torch.Tensor, - t1: torch.Tensor, - t2: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Packs the given 2 tensors using the `indices`. - """ - t1_packed = pack_tensor(indices, t1) - t2_packed = pack_tensor(indices, t2) - return t1_packed, t2_packed - - -@torch.compile -def pack_3_tensors( - indices: torch.Tensor, - t1: torch.Tensor, - t2: torch.Tensor, - t3: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Packs the given 3 tensors using the `indices`. - """ - t1_packed = pack_tensor(indices, t1) - t2_packed = pack_tensor(indices, t2) - t3_packed = pack_tensor(indices, t3) - return t1_packed, t2_packed, t3_packed - - -@torch.compile -def unpack_tensor( - indices: torch.Tensor, - dim0: int, - tensor: torch.Tensor, -) -> torch.Tensor: - """ - Inverse of `pack_tensor`. - """ - indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) - unpacked = torch.zeros( - dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device - ) - if isinstance(tensor, Float8Tensor): - unpacked.scatter_(0, indices, tensor._data) - unpacked = Float8Tensor.make_like(tensor, data=unpacked[0:-1, :, :]) - else: - unpacked.scatter_(0, indices, tensor) - unpacked = unpacked[0:-1, :, :] - return unpacked - - -@torch.compile -def unpack_2_tensors( - indices: torch.Tensor, - dim0: int, - t1: torch.Tensor, - t2: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Inverse of `pack_2_tensors`. - """ - t1_unpacked = unpack_tensor(indices, dim0, t1) - t2_unpacked = unpack_tensor(indices, dim0, t2) - return t1_unpacked, t2_unpacked - - -@torch.compile -def unpack_3_tensors( - indices: torch.Tensor, - dim0: int, - t1: torch.Tensor, - t2: torch.Tensor, - t3: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Inverse of `pack_3_tensors`. - """ - t1_unpacked = unpack_tensor(indices, dim0, t1) - t2_unpacked = unpack_tensor(indices, dim0, t2) - t3_unpacked = unpack_tensor(indices, dim0, t3) - return t1_unpacked, t2_unpacked, t3_unpacked - - -class PackTensors(torch.autograd.Function): - """ - Autograd function to pack tensors. - """ - - @staticmethod - def forward( - ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...] - ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: - # pylint: disable=missing-function-docstring - assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported." - ctx.save_for_backward(indices) - ctx.dim0 = tensors[0].shape[0] - if len(tensors) == 1: - return pack_tensor(indices, *tensors) - if len(tensors) == 2: - return pack_2_tensors(indices, *tensors) - return pack_3_tensors(indices, *tensors) - - @staticmethod - def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]): - # pylint: disable=missing-function-docstring - (indices,) = ctx.saved_tensors - if len(grad_outputs) == 1: - return None, unpack_tensor(indices, ctx.dim0, *grad_outputs) - if len(grad_outputs) == 2: - return None, *unpack_2_tensors(indices, ctx.dim0, *grad_outputs) - return None, *unpack_3_tensors(indices, ctx.dim0, *grad_outputs) - - -class UnpackTensor(torch.autograd.Function): - """ - Autograd function to unpack a tensor. - """ - - @staticmethod - def forward( - ctx, - indices: torch.Tensor, - dim0: int, - tensor: torch.Tensor, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - ctx.save_for_backward(indices) - return unpack_tensor(indices, dim0, tensor) - - @staticmethod - def backward(ctx, grad_output): - # pylint: disable=missing-function-docstring - (indices,) = ctx.saved_tensors - return None, None, pack_tensor(indices, grad_output) - - def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm ): @@ -1473,24 +224,52 @@ def flash_attn_p2p_communicate( return send_recv_reqs +@jit_fuser +def flash_attn_fwd_out_correction_init( + out_init_step: torch.Tensor, + softmax_lse: torch.Tensor, + softmax_lse_init_step: torch.Tensor, + seq_dim: int, +): + """Merge partial outputs of the first step in Attention with context parallelism""" + softmax_lse_corrected_exp = torch.exp(softmax_lse_init_step - softmax_lse).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_init_step * softmax_lse_corrected_exp + return out_corrected.to(out_init_step.dtype) + + @jit_fuser def flash_attn_fwd_out_correction( out: torch.Tensor, out_per_step: torch.Tensor, softmax_lse: torch.Tensor, softmax_lse_per_step: torch.Tensor, - movedim_src: int, - movedim_dst: int, + seq_dim: int, ): """Merge partial outputs of each step in Attention with context parallelism""" - softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim( - movedim_src, movedim_dst - ) + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse).movedim(2, seq_dim) softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) out_corrected = out_per_step * softmax_lse_corrected_exp out.add_(out_corrected) +@jit_fuser +def flash_attn_fwd_second_half_out_correction( + out: torch.Tensor, + out_per_step: torch.Tensor, + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, + seq_dim: int, +): + """Merge second half of partial outputs of each step in Attention with context parallelism""" + out_ = out.select(seq_dim, 1) + softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1)[..., 1, :] + softmax_lse_corrected_exp = torch.exp(softmax_lse_per_step - softmax_lse_).movedim(2, seq_dim) + softmax_lse_corrected_exp = softmax_lse_corrected_exp.unsqueeze(-1) + out_corrected = out_per_step * softmax_lse_corrected_exp + out_.add_(out_corrected) + + @jit_fuser def flash_attn_fwd_softmax_lse_correction( softmax_lse: torch.Tensor, @@ -1499,10 +278,23 @@ def flash_attn_fwd_softmax_lse_correction( """Merge softmax stats of each step in Attention with context parallelism""" max_scale = torch.max(softmax_lse, softmax_lse_per_step) min_scale = torch.min(softmax_lse, softmax_lse_per_step) - new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale)) softmax_lse.copy_(new_scale) +@jit_fuser +def flash_attn_fwd_second_half_softmax_lse_correction( + softmax_lse: torch.Tensor, + softmax_lse_per_step: torch.Tensor, +): + """Merge second half of softmax stats of each step in Attention with context parallelism""" + softmax_lse_ = softmax_lse[..., 1, :] + max_scale = torch.max(softmax_lse_, softmax_lse_per_step) + min_scale = torch.min(softmax_lse_, softmax_lse_per_step) + new_scale = max_scale + torch.log1p(torch.exp(min_scale - max_scale)) + softmax_lse_.copy_(new_scale) + + @jit_fuser def get_cu_seqlens_on_cp_rank( cu_seqlens: torch.Tensor, @@ -1529,47 +321,60 @@ def get_cu_seqlens_on_cp_rank( return cu_seqlens_on_cp_rank -@torch.compile -def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous): +@jit_fuser +def get_seq_chunk_ids_for_reordering_before_attn(cp_size, device): """ Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. - To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks - before or after CP communications (e.g., all-gather, all-to-all). This function is to compute - sequence chunk ids for reordering. + To make sure tokens are ordered correctly for compute, we need to reorder sequence chunks to + be contigupus before attention compute. This function is to compute sequence chunk ids for + reordering. """ chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) - if to_contiguous: - for rank in range(cp_size): - chunk_ids[rank] = 2 * rank - chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 - else: - for rank in range(cp_size): - chunk_ids[2 * rank] = rank - chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + for rank in range(cp_size): + chunk_ids[rank] = 2 * rank + chunk_ids[rank + cp_size] = 2 * cp_size - 2 * rank - 1 return chunk_ids -@torch.compile -def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn): - """Reorder sequence chunk for A2A communication.""" - if before_attn: - # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] - x = x.movedim(0, seq_dim).contiguous() - # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) - # reorder the sequence chunks - x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) - else: - # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] - x = x.movedim(seq_dim, 0).contiguous() - # reorder the sequence chunks - x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) - # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] - x = x.view(cp_size, 2, *x.shape[1:]) +@jit_fuser +def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): + """ + Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing. + We need to reorder sequence chunks back to discontiguous after attention compute. This function + is to compute sequence chunk ids for reordering. + """ + chunk_ids = torch.empty(2 * cp_size, dtype=torch.int32, device=device) + for rank in range(cp_size): + chunk_ids[2 * rank] = rank + chunk_ids[2 * rank + 1] = 2 * cp_size - rank - 1 + return chunk_ids + + +@jit_fuser +def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): + """Reorder sequence chunk for A2A communication before attention compute.""" + # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] + # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + x = x.movedim(0, seq_dim).contiguous() + # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] + # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) + # reorder the sequence chunks + x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) + return x + + +@jit_fuser +def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): + """Reorder sequence chunk for A2A communication after attention compute.""" + # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + x = x.movedim(seq_dim, 0).contiguous() + # reorder the sequence chunks + x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) + # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] + # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + x = x.view(cp_size, 2, *x.shape[1:]) return x @@ -1597,8 +402,8 @@ def flash_attn_a2a_communicate( a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] # reorder the sequence chunks - x = reorder_seq_chunks_for_a2a( - x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + x = reorder_seq_chunks_for_a2a_before_attn( + x, chunk_ids_for_a2a, seq_dim, cp_size ) # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] @@ -1624,8 +429,8 @@ def flash_attn_a2a_communicate( # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks - a2a_inputs[i] = reorder_seq_chunks_for_a2a( - x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn + a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( + x, chunk_ids_for_a2a, seq_dim, cp_size ) if i > 1: with torch.cuda.stream(cp_stream): @@ -1641,6 +446,108 @@ def flash_attn_a2a_communicate( return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs +_cu_seqlens_info_with_cp_cache = {} + + +def _get_cu_seqlens_info_with_cp( + batch_size: int, + max_seqlen: int, + cp_size: int, + cu_seqlens: torch.Tensor, +): + """Cumulative sequence lengths with CP being considered.""" + global _cu_seqlens_info_with_cp_cache + if (batch_size, max_seqlen, cp_size) not in _cu_seqlens_info_with_cp_cache: + _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)] = ( + cu_seqlens // cp_size, + cu_seqlens // (cp_size * 2), + ) + return _cu_seqlens_info_with_cp_cache[(batch_size, max_seqlen, cp_size)] + + +def get_fa_args( + forward: bool, + use_flash_attn_3: bool, + qkv_format: str, + cu_seqlens_q=None, + cu_seqlens_kv=None, + max_seqlen_q=None, + max_seqlen_kv=None, + dq=None, + dk=None, + dv=None, +): + """Get forward/backward arguments for flash-attn v2 and v3.""" + if use_flash_attn_3: + if forward: + if qkv_format == "thd": + return [ + *[None] * 4, # k_new, v_new, qv, out + cu_seqlens_q, + cu_seqlens_kv, + *[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k + max_seqlen_q, + max_seqlen_kv, + *[None] + * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + ] + return [ + *[None] + * 9, # k_new, v_new, qv, out, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_k_new, seqused_q, seqused_k + max_seqlen_q, + max_seqlen_kv, + *[None] + * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + ] + if qkv_format == "thd": + return [ + cu_seqlens_q, + cu_seqlens_kv, + None, # sequed_q + None, # sequed_k + max_seqlen_q, + max_seqlen_kv, + dq, + dk, + dv, + ] + return [ + None, # cu_seqlens_q + None, # cu_seqlens_kv + None, # sequed_q + None, # sequed_k + max_seqlen_q, + max_seqlen_kv, + dq, + dk, + dv, + ] + if forward: + if qkv_format == "thd": + return [ + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] + return [] + if qkv_format == "thd": + return [ + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + ] + return [ + dq, + dk, + dv, + ] + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. Exchange KV between CP ranks @@ -1679,8 +586,12 @@ def forward( cp_group, cp_global_ranks, cp_stream, + quantizers, + pad_between_seqs, + use_flash_attn_3, ): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -1710,72 +621,86 @@ def forward( causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type + batch_dim = None seq_dim = None + cu_seqlens_q_half, cu_seqlens_kv_half = None, None if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] + cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None + if use_fused_attention: + batch_dim = qkv_format.index("b") + cu_seqlens_q, cu_seqlens_q_half = _get_cu_seqlens_info_with_cp( + q.shape[batch_dim], max_seqlen_q, cp_size, cu_seqlens_q + ) + cu_seqlens_kv, cu_seqlens_kv_half = _get_cu_seqlens_info_with_cp( + q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv + ) else: qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format + cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size + cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size - pad_between_seqs_q = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) - pad_between_seqs_kv = not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) max_seqlen_q = max_seqlen_q // cp_size max_seqlen_kv = max_seqlen_kv // cp_size - cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size - cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size cu_seqlens_q_per_step = [None for _ in range(cp_size)] cu_seqlens_kv_per_step = [None for _ in range(cp_size)] - fused_attn_qkv_dtype = None fused_attn_backend = None - amax_per_step = None qkv_dtype = q.dtype + amax_per_step = None + S_quantizer_per_step = [None for _ in range(cp_size)] + O_CP_quantizer_per_step = [None for _ in range(cp_size)] # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + is_output_fp8 = False + + ( + QKV_quantizer, + O_quantizer, + O_CP_quantizer, + S_quantizer, + dQKV_quantizer, + dQKV_CP_quantizer, + dO_quantizer, + dP_quantizer, + ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) + if fp8: if use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_backend = FusedAttnBackend["FP8"] + assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, and v must have the same type." is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - q_fp8, k_fp8, v_fp8 = q, k, v - q, k, v = q_fp8._data, k_fp8._data, v_fp8._data + QKV_quantizer = q._quantizer + q, k, v = q._data, k._data, v._data else: q_f16, k_f16, v_f16 = q, k, v if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + q = QKV_quantizer(q_f16)._data if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [k_f16, v_f16] - ] - fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_o_offset"] = META_O_CP + k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # partial result quantizer + for i in range(cp_size): + S_quantizer_per_step[i] = S_quantizer.copy() + S_quantizer_per_step[i].amax = amax_per_step[0][i] + O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() + O_CP_quantizer_per_step[i].amax = amax_per_step[1][i] else: assert False, "FP8 is only supported with Fused Attention!" else: q_f16 = q if use_fused_attention: - fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if cp_size_a2a > 1: - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, True) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) + q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True ) @@ -1783,7 +708,7 @@ def forward( q_f16 = q elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16 = q - q = cast_to_fp8(q_f16, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) + q = QKV_quantizer(q_f16)._data assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 @@ -1795,9 +720,6 @@ def forward( elif qkv_format == "sbhd": # [s, b, np, hn] -> [2, s//2, b, np, hn] q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] - total_tokens_kv = None if qkv_format != "thd" else k.shape[0] - # remove padded tokens at the end - k, v = [x if qkv_format != "thd" else x[: cu_seqlens_kv_padded[-1]] for x in [k, v]] if attn_bias is not None: assert len(attn_bias.shape) == 4, ( "Only support bias shape of [b, h, sq, sk] for forward, " @@ -1820,25 +742,39 @@ def forward( ) assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" - softmax_lse_in_packed_format = not use_fused_attention and ( - _flash_attn_2_6_0_plus or _use_flash_attn_3 - ) + softmax_lse_in_packed_format = False + if qkv_format == "thd": + if use_fused_attention: + softmax_lse_in_packed_format = get_cudnn_version() >= (9, 6, 0) + else: + softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 + flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if _use_flash_attn_3: - flash_attn_fwd = flash_attn_varlen_fwd_v3 + if use_flash_attn_3: + flash_attn_fwd = ( + _flash_attn_fwd_v3 # pylint: disable=possibly-used-before-assignment + ) fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) else: - flash_attn_fwd = flash_attn_varlen_fwd + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd + else: + flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _flash_attn_2_3_plus: + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size"] = (-1, 0) if causal else (-1, -1) - if _flash_attn_2_4_plus: + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = 0 if causal else -1 + if fa_utils.v2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: + if fa_utils.v2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None + if fa_utils.v2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 # Flash Attn inputs q_inputs = [None, None] @@ -1856,13 +792,12 @@ def forward( fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] - if use_fused_attention and qkv_format in ["bshd", "sbhd"]: + if qkv_format in ["bshd", "sbhd"]: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) else: p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) send_recv_reqs = [[], []] - softmax_lse_ = None out = None for i in range(cp_size + 1): if i < cp_size: @@ -1887,48 +822,39 @@ def forward( kv_inputs[i % 2] = p2p_comm_buffers[i] else: # KV exchange is in BF16/FP16, cast received KV in each step - kv_inputs[i % 2] = cast_to_fp8( - p2p_comm_buffers[i], - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - ) - if fp8 and use_fused_attention: - fp8_meta_kwargs["amax_s"] = amax_per_step - fp8_meta_kwargs["amax_s_offset"] = i - fp8_meta_kwargs["amax_o"] = amax_per_step - fp8_meta_kwargs["amax_o_offset"] = cp_size + i + kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data if causal: if i == 0: - if pad_between_seqs_q: + if pad_between_seqs: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True ) - else: + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) + elif qkv_format == "thd": + q_inputs[i % 2] = q if use_fused_attention: - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - q_inputs[i % 2] = q if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = torch.cat( @@ -1938,25 +864,43 @@ def forward( ), dim=-1, ).contiguous() + + q_part = q_inputs[i % 2] + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] + out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, - fused_attn_backend, + q_part, + k_part, + v_part, + fake_dtype=qkv_dtype, + fused_attention_backend=fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, qkv_layout=qkv_layout, @@ -1973,33 +917,46 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) fa_outputs = flash_attn_fwd( q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=True, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] elif i <= rank: - if pad_between_seqs_q: + if pad_between_seqs: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, @@ -2008,46 +965,68 @@ def forward( True, False, ) - else: + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2][0] + elif qkv_format == "thd": + q_inputs[i % 2] = q + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_inputs[i % 2] = tex.thd_read_half_tensor( + kv_inputs[i % 2], cu_seqlens_kv_padded, 0 + ) if use_fused_attention: - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...].contiguous() - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][0].contiguous() - elif qkv_format == "thd": - q_inputs[i % 2] = q - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) + kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() + + q_part = q_inputs[i % 2] + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv // 2, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -2069,43 +1048,53 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) - if qkv_format == "thd": - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) - else: - # [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, :, 0, ...].contiguous() - # [2, b, sk//2, np, hn] -> [2, b*sk//2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) - if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv // 2, + ) + if use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv // 2, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=False, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] else: - if pad_between_seqs_q: + if pad_between_seqs: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True ) - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) - if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, @@ -2114,28 +1103,33 @@ def forward( True, True, ) - else: + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q_half + cu_seqlens_kv_per_step[i] = cu_seqlens_kv + if qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_inputs[i % 2] = q[:, 1, ...] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + k.shape[0], -1, 2, *k.shape[-2:] + ) + elif qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_inputs[i % 2] = q[1] + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_inputs[i % 2] = kv_inputs[i % 2].view( + -1, k.shape[2], 2, *k.shape[-2:] + ) + elif qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_inputs[i % 2] = tex.thd_read_half_tensor( + q, cu_seqlens_q_padded, 1 + ) if use_fused_attention: - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_inputs[i % 2] = q[:, 1, ...].contiguous() - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_inputs[i % 2] = q[1].contiguous() - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor( - q, cu_seqlens_q_padded, 1 - ) + q_inputs[i % 2] = q_inputs[i % 2].contiguous() if attn_bias is not None: idx = (rank - i) % cp_size attn_bias_inputs[i % 2] = torch.cat( @@ -2145,24 +1139,41 @@ def forward( ), dim=-1, ).contiguous() + + q_part = q_inputs[i % 2] + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q // 2, max_seqlen_kv, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -2184,43 +1195,53 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor( - q, cu_seqlens_q_padded, 1 - ) - else: - # [b, 2, sq//2, np, hn]->[b, sq//2, np, hn]->[b*sq//2, np, hn] - q_inputs[i % 2] = ( - q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) - ) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) - if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q // 2, + max_seqlen_kv=max_seqlen_kv, + ) + if use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 fa_outputs = flash_attn_fwd( q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q // 2, - max_seqlen_kv, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=False, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] else: - if pad_between_seqs_q: + if pad_between_seqs: cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True ) - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - if pad_between_seqs_kv: cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, @@ -2229,8 +1250,12 @@ def forward( True, True, ) - else: + elif qkv_format == "thd": + cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step[i] = cu_seqlens_q + cu_seqlens_kv_per_step[i] = cu_seqlens_kv if use_fused_attention: if attn_bias is not None: idx = (rank - i) % cp_size @@ -2241,24 +1266,41 @@ def forward( ), dim=-1, ).contiguous() + + q_part = q + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) + fp8_meta_kwargs = {} + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=qkv_dtype, internal=True + ) + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] + fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] out_per_step[i], aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q_per_step[i], cu_seqlens_kv_per_step[i], - q, - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -2276,25 +1318,41 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: - # [b, sq, np, hn] -> [b*sq, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-2:]) - # [2, b, sk, np, hn] -> [2, b*sk, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view(2, -1, *k.shape[-2:]) + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[i], + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - kv_inputs[i % 2][0], - kv_inputs[i % 2][1], - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv, + q, + ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ), + ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ), + *fa_forward_args_thd, causal=False, **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] if i > 0: # wait until fwd restuls correction of last step is done @@ -2302,32 +1360,21 @@ def forward( flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) if use_fused_attention: - # [b, np, sq, 1] -> [b, np, sq] + # [b, np, sq, 1] -> [b, np, sq] or + # [t, np, 1] -> [t, np] softmax_lse_per_step[i - 1].squeeze_(-1) - if qkv_format != "thd" and softmax_lse_in_packed_format: - # [np, t] -> [np, b, sq] - softmax_lse_per_step[i - 1] = softmax_lse_per_step[i - 1].view( - q.shape[-2], q.shape[0], -1 - ) + if softmax_lse_in_packed_format: + softmax_lse_per_step[i - 1] = ( + softmax_lse_per_step[i - 1].transpose(0, 1).contiguous() + ) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if fp8: - out_per_step[i - 1] = cast_from_fp8( - out_per_step[i - 1], - fp8_meta["scaling_fwd"], - META_O_CP, - fp8_dtype_forward, - TE_DType[torch.float32], - ) + out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) if i == 1: - out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) softmax_lse = torch.clone(softmax_lse_per_step[0]).to(torch.double) - if causal and qkv_format != "thd": - # [b, np, sq] -> [b, np, 2, sq//2] lse not in packed format - # [np, b, sq] -> [np, b, 2, sq//2] lse in packed format - softmax_lse_ = softmax_lse.view( - *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 - ) + if qkv_format == "thd": + out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) elif (i - 1) <= rank or not causal: flash_attn_fwd_softmax_lse_correction( softmax_lse, softmax_lse_per_step[i - 1] @@ -2341,8 +1388,9 @@ def forward( softmax_lse_in_packed_format, ) else: - flash_attn_fwd_softmax_lse_correction( - softmax_lse_[..., 1, :], softmax_lse_per_step[i - 1] + flash_attn_fwd_second_half_softmax_lse_correction( + softmax_lse.view(*softmax_lse.shape[:-1], 2, -1), + softmax_lse_per_step[i - 1], ) if i < cp_size: @@ -2350,28 +1398,30 @@ def forward( torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) + second_half_lse_seqlen = None + if causal and rank < (cp_size - 1): + second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1] + softmax_lse = softmax_lse.to(torch.float) for i in range(cp_size): - out_ = None - if qkv_format == "bshd": - out_per_step[i] = out_per_step[i].view( - out.shape[0], -1, *out.shape[-2:] - ) # pylint: disable=used-before-assignment - out_ = out[:, 1, ...] - elif qkv_format == "sbhd": - out_per_step[i] = out_per_step[i].view(-1, *out.shape[-3:]) - out_ = out[1] - if i <= rank or not causal: if qkv_format in ["bshd", "sbhd"]: - flash_attn_fwd_out_correction( - out.view(*out_per_step[i].shape), - out_per_step[i], - softmax_lse, - softmax_lse_per_step[i], - 0 if softmax_lse_in_packed_format else 2, - 2 if softmax_lse_in_packed_format else seq_dim, - ) + if i == 0: + out = flash_attn_fwd_out_correction_init( + out_per_step[0], + softmax_lse, + softmax_lse_per_step[0], + seq_dim, + ) + out = out.view(q.shape) + else: + flash_attn_fwd_out_correction( + out.view(*out_per_step[i].shape), + out_per_step[i], + softmax_lse, + softmax_lse_per_step[i], + seq_dim, + ) elif qkv_format == "thd": tex.thd_out_correction( out, @@ -2384,13 +1434,12 @@ def forward( ) else: if qkv_format in ["bshd", "sbhd"]: - flash_attn_fwd_out_correction( - out_, + flash_attn_fwd_second_half_out_correction( + out, out_per_step[i], - softmax_lse_[..., 1, :], + softmax_lse, softmax_lse_per_step[i], - 0 if softmax_lse_in_packed_format else 2, - 2 if softmax_lse_in_packed_format else seq_dim, + seq_dim, ) elif qkv_format == "thd": tex.thd_out_correction( @@ -2403,9 +1452,6 @@ def forward( softmax_lse_in_packed_format, ) - if qkv_format != "thd" and softmax_lse_in_packed_format: - # [np, b, sq] -> [np, t] - softmax_lse = softmax_lse.view(softmax_lse.shape[0], -1) kv = p2p_comm_buffers[-1] if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) @@ -2415,7 +1461,7 @@ def forward( ctx.batch_size = out.shape[1] if cp_size_a2a > 1: - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, False) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, out.device) out = flash_attn_a2a_communicate( out, chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, False ) @@ -2431,68 +1477,50 @@ def forward( if fp8 and use_fused_attention: amax_cp_fwd = amax_per_step.amax(dim=1) - fp8_meta["scaling_fwd"].amax_history[0][META_S] = amax_cp_fwd[0] - fp8_meta["scaling_fwd"].amax_history[0][META_O_CP] = amax_cp_fwd[1] + S_quantizer.amax = amax_cp_fwd[0] + O_CP_quantizer.amax = amax_cp_fwd[1] out_fp8 = None out_f16 = out.to(qkv_dtype) + if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): - out_fp8 = cast_to_fp8(out_f16, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward) - - if fp8 and is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=qkv_dtype, - ) - else: - out_ret = out_f16 + out_fp8 = O_quantizer(out_f16) # final result + + out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, kv_save, out_save = q, kv, out_fp8 - fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() - fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() + q_save, kv_save, out_save = q, kv, out_fp8._data elif fp8 and is_input_fp8: - q_fp8 = Float8Tensor( - data=q, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_QKV, - fp8_dtype=fp8_dtype_forward, - dtype=q_fp8.dtype, - ) - kv_fp8 = Float8Tensor( - data=kv, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_QKV, - fp8_dtype=fp8_dtype_forward, - dtype=k_fp8.dtype, - ) - q_save, kv_save, out_save = q_fp8, kv_fp8, out_f16 - fp8_fwd_scales, fp8_fwd_scale_invs = None, None + q_save, kv_save, out_save = q, kv, out_f16 else: q_f16 = q_f16.view(q.shape) q_save, kv_save, out_save = q_f16, kv, out_f16 - fp8_fwd_scales, fp8_fwd_scale_invs = None, None - ctx.save_for_backward( + tensors_to_save, tensor_objects = prepare_for_saving( q_save, kv_save, out_save, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_fwd_scales, - fp8_fwd_scale_invs, *cu_seqlens_q_per_step, *cu_seqlens_kv_per_step, *rng_states, *attn_biases, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.qkv_dtype = qkv_dtype + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.O_CP_quantizer = O_CP_quantizer + ctx.S_quantizer = S_quantizer + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dQKV_CP_quantizer = dQKV_CP_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.cp_group_a2a = cp_group_a2a ctx.cp_size_a2a = cp_size_a2a ctx.rank_a2a = rank_a2a @@ -2500,7 +1528,6 @@ def forward( ctx.cp_global_ranks = cp_global_ranks ctx.cp_stream = cp_stream ctx.dropout_p = dropout_p - ctx.total_tokens_kv = total_tokens_kv ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_kv = max_seqlen_kv ctx.softmax_scale = softmax_scale @@ -2510,15 +1537,21 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format + ctx.second_half_lse_seqlen = second_half_lse_seqlen ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + ctx.use_flash_attn_3 = use_flash_attn_3 + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + return out_ret @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") cp_size_a2a = ctx.cp_size_a2a rank_a2a = ctx.rank_a2a @@ -2528,12 +1561,13 @@ def backward(ctx, dout): recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (cp_size == 2) - (q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded) = ctx.saved_tensors[:6] - (fp8_fwd_scales, fp8_fwd_scale_invs) = ctx.saved_tensors[6:8] - cu_seqlens_q_per_step = ctx.saved_tensors[8 : 8 + cp_size] - cu_seqlens_kv_per_step = ctx.saved_tensors[8 + cp_size : 8 + cp_size * 2] - rng_states = ctx.saved_tensors[8 + cp_size * 2 : 8 + cp_size * 3] - attn_biases = ctx.saved_tensors[8 + cp_size * 3 : 8 + cp_size * 4] + q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( + restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + ) + cu_seqlens_q_per_step = other_tensors[:cp_size] + cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] + rng_states = other_tensors[cp_size * 2 : cp_size * 3] + attn_biases = other_tensors[cp_size * 3 : cp_size * 4] causal = "causal" in ctx.attn_mask_type padding = "padding" in ctx.attn_mask_type @@ -2558,77 +1592,91 @@ def backward(ctx, dout): attn_dbias = None attn_dbias_ = None - softmax_lse_in_packed_format = not ctx.use_fused_attention and ( - _flash_attn_2_6_0_plus or _use_flash_attn_3 - ) - - if causal: - if ctx.qkv_format == "thd" or softmax_lse_in_packed_format: + softmax_lse_ = None + if causal and ctx.second_half_lse_seqlen is not None: + if ctx.qkv_format == "thd": softmax_lse_ = tex.thd_read_second_half_lse( - softmax_lse, cu_seqlens_q_padded, softmax_lse_in_packed_format + softmax_lse, + cu_seqlens_q_padded, + ctx.softmax_lse_in_packed_format, + ctx.second_half_lse_seqlen, ) else: # [b, np, sq] -> [b, np, 2, sq//2] - softmax_lse_ = softmax_lse.view( - *softmax_lse.shape[:-1], 2, softmax_lse.shape[-1] // 2 - ) + softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1) softmax_lse_ = softmax_lse_[..., 1, :].contiguous() - if ctx.use_fused_attention: - # [b, np, sq//2] -> [b, np, sq//2, 1] - softmax_lse_.unsqueeze_(-1) + if ctx.use_fused_attention: + if ctx.softmax_lse_in_packed_format: + softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() + # [b, np, sq//2] -> [b, np, sq//2, 1] or + # [t//2, np] -> [t//2, np, 1] + softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: - # [b, np, sq] -> [b, np, sq, 1] + if ctx.softmax_lse_in_packed_format: + softmax_lse = softmax_lse.transpose(0, 1).contiguous() + # [b, np, sq] -> [b, np, sq, 1] or + # [t, np] -> [t, np, 1] softmax_lse.unsqueeze_(-1) + dout = dout.contiguous() + dq = None dout_dtype = dout.dtype fused_attn_backend = None - fused_attn_qkv_dtype = None fused_attn_dqkv_dtype = None amax_per_step = None - dout_fp8_dtype = None + dP_quantizer_per_step = [None for _ in range(cp_size)] + dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)] if ctx.fp8: if ctx.use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - fused_attn_qkv_dtype = fp8_dtype_forward - fused_attn_dqkv_dtype = fp8_dtype_backward fused_attn_backend = FusedAttnBackend["FP8"] - dq_fp8 = torch.empty((cp_size, *q.shape), dtype=q.dtype, device=q.device) - dkv_fp8 = torch.empty((cp_size, *kv.shape), dtype=kv.dtype, device=kv.device) + + dqkv_fp8_torch_dtype = get_fp8_torch_dtype( + ctx.fp8_meta["recipe"], fprop_tensor=False + ) + dq_fp8 = torch.empty( + (cp_size, *q.shape), dtype=dqkv_fp8_torch_dtype, device=q.device + ) + dkv_fp8 = torch.empty( + (cp_size, *kv.shape), dtype=dqkv_fp8_torch_dtype, device=kv.device + ) dkv_fp8_ = torch.empty_like(dkv_fp8) if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv - dout = dout._data + ctx.dO_quantizer = dout._quantizer else: - dout = cast_to_fp8( - dout, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) + dout = ctx.dO_quantizer(dout) + fused_attn_dqkv_dtype = dout._fp8_dtype + dout = dout._data p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] - fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] - fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] - fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] - fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] - fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] - fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] - fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV_CP] + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + for i in range(cp_size): + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() + dP_quantizer_per_step[i].amax = amax_per_step[0][i] + dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy() + dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i] else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None and ctx.is_input_fp8: - q, kv = [x.from_float8(x.dtype) for x in [q, kv]] - if cp_size_a2a == 1: - dout = dout.from_float8(dout_dtype) - else: - dout_fp8_dtype = dout._fp8_dtype - dout_scale_inv = dout._scale_inv - dout = dout._data + if ctx.fp8_meta is not None: + if ctx.is_input_fp8: + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + kv = ctx.QKV_quantizer.create_tensor_from_data( + kv, fake_dtype=ctx.qkv_dtype, internal=True + ) + q = q.dequantize(dtype=ctx.qkv_dtype) + kv = kv.dequantize(dtype=ctx.qkv_dtype) + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + if cp_size_a2a == 1: + dout = dout.dequantize(dtype=dout_dtype) + else: + ctx.dO_quantizer = dout._quantizer + dout = dout._data dq = torch.empty_like(q) - if ctx.qkv_format == "thd" and causal: - dq[cu_seqlens_q_padded[-1] :].fill_(0) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), @@ -2636,7 +1684,6 @@ def backward(ctx, dout): p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_dqkv_dtype = TE_DType[dout_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] @@ -2644,7 +1691,9 @@ def backward(ctx, dout): if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(*out.shape) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, out.device, True) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn( + cp_size_a2a, out.device + ) out, dout = flash_attn_a2a_communicate( [out, dout], chunk_ids_for_a2a, @@ -2655,14 +1704,10 @@ def backward(ctx, dout): True, ) if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - dout = cast_from_fp8( - dout, - None, - None, - dout_fp8_dtype, - TE_DType[dout_dtype], - scale_inv=dout_scale_inv, # pylint: disable=used-before-assignment + dout = ctx.dO_quantizer.create_tensor_from_data( + dout, fake_dtype=dout_dtype, internal=True ) + dout = dout.dequantize(dtype=dout_dtype) out = out.view(*q.shape) dout = dout.view(*q.shape) @@ -2671,16 +1716,23 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if _use_flash_attn_3: - flash_attn_bwd = flash_attn_varlen_bwd_v3 + if ctx.use_flash_attn_3: + flash_attn_bwd = ( + _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment + ) fa_backward_kwargs["deterministic"] = ctx.deterministic else: - flash_attn_bwd = flash_attn_varlen_bwd + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd + else: + flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _flash_attn_2_4_plus: + if fa_utils.v2_4_plus: fa_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: + if fa_utils.v2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic + if fa_utils.v2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 for i in range(cp_size): # wait until KV is received @@ -2720,32 +1772,26 @@ def backward(ctx, dout): ) kv = p2p_comm_buffers[i % 2][0] - dk_, dv_ = None, None - if ctx.fp8 and ctx.use_fused_attention: - fp8_meta_kwargs["amax_dp"] = amax_per_step[0][i] - fp8_meta_kwargs["amax_dqkv"] = amax_per_step[0][i] + q_, kv_, out_, dout_ = None, None, None, None + dq_, dk_, dv_ = None, None, None # In reversed order of fwd if causal: if i == (cp_size - 1): + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_, out_, dout_ = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] + ] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) + elif ctx.qkv_format == "thd": + q_, kv_, out_, dout_ = q, kv, out, dout if ctx.use_fused_attention: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - out_ = out.view(out.shape[0], -1, *out.shape[-2:]) - dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_ = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - out_ = out.view(-1, *out.shape[-3:]) - dout_ = dout.view(-1, *dout.shape[-3:]) - elif ctx.qkv_format == "thd": - q_, kv_, out_, dout_ = q, kv, out, dout if ctx.fp8: aux_ctx_tensors = [ softmax_lse, @@ -2756,17 +1802,41 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q_ + k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], - out_, - dout_, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -2780,59 +1850,72 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_ = kv.view(2, -1, *kv.shape[-2:]) + dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - out_ = out.view(-1, *out.shape[-2:]) - dout_ = dout.view(-1, *dout.shape[-2:]) - if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, 0) - if not _use_flash_attn_3: + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = 0 + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse, - dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv, + *fa_backward_args_thd, causal=True, **fa_backward_kwargs, ) elif i >= (cp_size - rank - 1): + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] + q_, out_, dout_ = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] + ] + # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] + kv_ = kv[:, 0] + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq, b, np, hn] + q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] + # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] + kv_ = kv[0] + elif ctx.qkv_format == "thd": + q_, out_, dout_ = q, out, dout + # [2, t, np, hn] -> [2, t/2, np, hn] + kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) if ctx.use_fused_attention: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_ = q.view(q.shape[0], -1, *q.shape[-2:]) - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_ = kv[:, 0, ...].contiguous() - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - out_ = out.view(out.shape[0], -1, *out.shape[-2:]) - dout_ = dout.view(dout.shape[0], -1, *dout.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_ = q.view(-1, *q.shape[-3:]) - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_ = kv[0].contiguous() - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - out_ = out.view(-1, *out.shape[-3:]) - dout_ = dout.view(-1, *dout.shape[-3:]) - elif ctx.qkv_format == "thd": - q_, out_, dout_ = q, out, dout - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) + kv_ = kv_.contiguous() if ctx.fp8: aux_ctx_tensors = [ softmax_lse, @@ -2843,17 +1926,41 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q_ + k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], - out_, - dout_, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -2869,65 +1976,73 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data else: - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - if ctx.qkv_format == "thd": - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) - else: - # [2, b, 2, sk//2, np, hn]->[2, b, sk//2, np, hn]->[2, b*sk//2, np, hn] - kv_ = kv[:, :, 0, ...].contiguous().view(2, -1, *kv.shape[-2:]) + dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) - # [b, 2, sq//2, np, hn] -> [b*sq, np, hn] - out_ = out.view(-1, *out.shape[-2:]) - dout_ = dout.view(-1, *dout.shape[-2:]) - if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv // 2, + dq=dq_, + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, -1) - if not _use_flash_attn_3: + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse, - dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv // 2, + *fa_backward_args_thd, causal=False, **fa_backward_kwargs, ) else: + if ctx.qkv_format == "bshd": + # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] + q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] + # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] + kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) + elif ctx.qkv_format == "sbhd": + # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + q_, out_, dout_ = q[1], out[1], dout[1] + # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] + kv_ = kv.view(-1, *kv.shape[-4:]) + elif ctx.qkv_format == "thd": + # [t, np, hn] -> [t/2, np, hn] + q_, out_, dout_ = [ + tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) + for x in [q, out, dout] + ] + kv_ = kv if ctx.use_fused_attention: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_ = q[:, 1, ...].contiguous() - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - out_ = out[:, 1, ...].contiguous() - dout_ = dout[:, 1, ...].contiguous() - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_ = q[1].contiguous() - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - out_ = out[1].contiguous() - dout_ = dout[1].contiguous() - elif ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1) - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) - kv_ = kv + q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]] if ctx.fp8: aux_ctx_tensors = [ softmax_lse_, @@ -2938,17 +2053,42 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + + q_part = q_ + k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + out_part = out_ + dout_part = dout_ + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], - out_, - dout_, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -2964,42 +2104,50 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data else: - if ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_ = tex.thd_read_half_tensor(q, cu_seqlens_q_padded, 1) - else: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - q_ = q[:, 1, ...].contiguous().view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - # [2, b, 2, sk//2, np, hn] -> [2, b*sk, np, hn] - kv_ = kv.view(2, -1, *kv.shape[-2:]) + dq_ = torch.empty_like(q_) dkv_ = torch.empty_like(kv_) - if ctx.qkv_format == "thd": - out_ = tex.thd_read_half_tensor(out, cu_seqlens_q_padded, 1) - dout_ = tex.thd_read_half_tensor(dout, cu_seqlens_q_padded, 1) - else: - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] -> [b*sq//2, np, hn] - out_ = out[:, 1, ...].contiguous().view(-1, *out.shape[-2:]) - dout_ = dout[:, 1, ...].contiguous().view(-1, *dout.shape[-2:]) - if _use_flash_attn_3 or _flash_attn_2_3_plus: + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q // 2, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ), + dv=( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ), + ) + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): fa_backward_kwargs["window_size"] = (-1, -1) - if not _use_flash_attn_3: + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( dout_, q_, - kv_[0], - kv_[1], + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], out_, softmax_lse_, - dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q // 2, - ctx.max_seqlen_kv, + *fa_backward_args_thd, causal=False, **fa_backward_kwargs, ) @@ -3011,17 +2159,41 @@ def backward(ctx, dout): aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] if attn_dbias is not None: aux_ctx_tensors += [attn_biases[cp_size - i - 1]] + q_part = q + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + out_part = out + dout_part = dout + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] dq_, dk_, dv_, dbias_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], cu_seqlens_kv_per_step[cp_size - i - 1], - q, - kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], - kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], - out, - dout, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -3035,51 +2207,52 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + + if ctx.fp8: + dq_ = dq_._data + dk_ = dk_._data + dv_ = dv_._data + else: - # [b, sq, np, hn] -> [b*sq, np, hn] - q_ = q.view(-1, *q.shape[-2:]) - dq_ = torch.zeros_like(q_) - # [2, b, sk, np, hn] -> [2, b*sk, np, hn] - kv_ = kv.view(2, -1, *kv.shape[-2:]) - dkv_ = torch.empty_like(kv_) - # [b, sq, np, hn] -> [b*sq, np, hn] - out_ = out.view(-1, *out.shape[-2:]) - dout_ = dout.view(-1, *dout.shape[-2:]) - if _use_flash_attn_3 or _flash_attn_2_3_plus: + dq_ = torch.empty_like(q) + dkv_ = torch.empty_like(kv) + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq_, + dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], + dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + ) + if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) - if not _use_flash_attn_3: + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] flash_attn_bwd( - dout_, - q_, - kv_[0], - kv_[1], - out_, + dout, + q, + kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], + kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], + out, softmax_lse, - dq_, - dkv_[0], - dkv_[1], - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - ctx.max_seqlen_q, - ctx.max_seqlen_kv, + *fa_backward_args_thd, causal=False, **fa_backward_kwargs, ) if ctx.fp8: dq = dq_fp8[(rank + i + 1) % cp_size] - if i >= (cp_size - rank - 1) or not causal: - # [b*sq, np, hn] -> [b, 2, sq//2, np, hn] if causal - # [b*sq, np, hn] -> [b, sq, np, hn] if not causal + if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1): + # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or + # [sq, b, np, hn] -> [2, sq//2, b, np, hn] dq_ = dq_.view(*dq.shape) - else: - if ctx.qkv_format == "bshd": - # [b*sq//2, np, hn] -> [b, sq//2, np, hn] - dq_ = dq_.view(dq.shape[0], *dq.shape[2:]) - elif ctx.qkv_format == "sbhd": - # [b*sq//2, np, hn] -> [sq//2, b, np, hn] - dq_ = dq_.view(-1, *dq.shape[-3:]) if ctx.fp8: if i >= (cp_size - rank - 1) or not causal: @@ -3154,24 +2327,21 @@ def backward(ctx, dout): else: dkv = p2p_comm_buffers[(i + 1) % 2][1] if ctx.use_fused_attention: - dkv_ = torch.cat( - (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 - ) # pylint: disable=used-before-assignment if ctx.qkv_format in ["bshd", "sbhd"]: - # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] - dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): - if ctx.qkv_format == "bshd": - # [2, b*sk//2, np, hn] -> [2, b, sk//2, np, hn] - dkv_ = dkv_.view(*dkv.shape[0:2], *dkv.shape[3:]) - elif ctx.qkv_format == "sbhd": - # [2, b*sk//2, np, hn] -> [2, sk//2, b, np, hn] - dkv_ = dkv_.view(dkv.shape[0], -1, *dkv.shape[-3:]) - else: - # [2, b*sk, np, hn] -> [2, b, 2, sk//2, np, hn] if causal - # [2, b*sk, np, hn] -> [2, b, sk, np, hn] if not causal - dkv_ = dkv_.view(*dkv.shape) + dkv_ = _combine_tensors([dk_, dv_], -2) + elif ctx.qkv_format == "thd": + dkv_ = torch.cat( + (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 + ) # pylint: disable=used-before-assignment + if ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] + dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) + dkv_ = dkv_.movedim(-3, 0) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or + # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] + dkv_ = dkv_.view(*dkv.shape) if ctx.fp8: if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): @@ -3223,22 +2393,19 @@ def backward(ctx, dout): if ctx.fp8 and ctx.use_fused_attention: amax_cp_bwd = amax_per_step.amax(dim=1) - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] = amax_cp_bwd[0] - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV_CP] = amax_cp_bwd[1] + ctx.dP_quantizer.amax = amax_cp_bwd[0] + ctx.dQKV_CP_quantizer.amax = amax_cp_bwd[1] if ctx.qkv_format in ["bshd", "sbhd"]: # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) - dq, dkv = [ - cast_from_fp8( - x, - ctx.fp8_meta["scaling_bwd"], - META_DQKV_CP, - fp8_dtype_backward, - TE_DType[torch.float32], - ) - for x in [dq_fp8, dkv_fp8] - ] + dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dq_fp8, fake_dtype=torch.float32, internal=True + ) + dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( + dkv_fp8, fake_dtype=torch.float32, internal=True + ) + dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] if causal: @@ -3253,23 +2420,17 @@ def backward(ctx, dout): # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) - if ctx.qkv_format == "thd": - dkv_ = torch.empty( - 2, ctx.total_tokens_kv, *dkv.shape[-2:], dtype=dkv.dtype, device=dkv.device - ) - dkv_[:, : cu_seqlens_kv_padded[-1]].copy_(dkv) - dkv_[:, cu_seqlens_kv_padded[-1] :].fill_(0) - dkv = dkv_ + if ctx.qkv_format == "thd" and not ctx.use_fused_attention: + dq[cu_seqlens_q_padded[-1] :].fill_(0) + dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - dq, dkv = [ - cast_to_fp8(x, ctx.fp8_meta["scaling_bwd"], META_DQKV, fp8_dtype_backward) - for x in [dq, dkv] - ] + assert torch.uint8 not in [dq.dtype, dkv.dtype] + dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] dk, dv = dkv[0], dkv[1] if cp_size_a2a > 1: - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size_a2a, q.device, False) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, @@ -3284,22 +2445,15 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] - if ctx.fp8 and ctx.is_input_fp8: - dq, dk, dv = [ - Float8Tensor( - data=x, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=dout_dtype, - ) - for x in [dq, dk, dv] - ] - if attn_dbias is not None: # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) + # converting torch.uint8 to float8tensor + if ctx.fp8 and ctx.is_input_fp8: + dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) + dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) + dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") return ( None, @@ -3325,6 +2479,9 @@ def backward(ctx, dout): None, None, None, + None, + None, + None, ) @@ -3383,14 +2540,18 @@ def forward( window_size, cp_group, cp_stream, + use_flash_attn_3, ): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) + qkv_dtype = q.dtype + causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type assert not padding, f"{attn_mask_type} mask type is not supported!" @@ -3399,22 +2560,27 @@ def forward( assert attn_bias_type == "no_bias", f"{attn_bias_type} bias type is not supported!" assert q.shape[-1] % 8 == 0, "Hidden size per attention head should be multiple of 8!" assert ( - use_fused_attention or _flash_attn_2_3_plus + use_fused_attention or fa_utils.v2_3_plus ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if _use_flash_attn_3: - flash_attn_fwd = flash_attn_varlen_fwd_v3 + if use_flash_attn_3: + flash_attn_fwd = _flash_attn_fwd_v3 else: - flash_attn_fwd = flash_attn_varlen_fwd + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd + else: + flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _flash_attn_2_4_plus: + if fa_utils.v2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: + if fa_utils.v2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None + if fa_utils.v2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 assert qkv_format != "thd", f"{qkv_format} format is not supported!" qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format @@ -3426,8 +2592,12 @@ def forward( max_seqlen_q = max_seqlen_q // (2 * cp_size) max_seqlen_kv = max_seqlen_kv // (2 * cp_size) - cu_seqlens_q = cu_seqlens_q // (2 * cp_size) - cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) + if use_fused_attention or qkv_format == "thd": + cu_seqlens_q = cu_seqlens_q // (2 * cp_size) + if cu_seqlens_q_padded is not None and qkv_format == "thd": + cu_seqlens_q_padded = cu_seqlens_q_padded // (2 * cp_size) + else: + cu_seqlens_q_padded = None # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) @@ -3441,7 +2611,7 @@ def forward( # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] @@ -3482,9 +2652,10 @@ def forward( kv_seq_range_per_step[i][1], ) max_seqlen_kv_ = seq_end_idx - seq_start_idx - cu_seqlens_kv_per_step[i] = _get_full_cu_seqlens( - k.shape[1], max_seqlen_kv_, k.device - ) + if use_fused_attention or qkv_format == "thd": + cu_seqlens_kv_per_step[i] = dpa_utils.get_full_cu_seqlens( + k.shape[1], max_seqlen_kv_, k.device + ) k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] @@ -3498,7 +2669,7 @@ def forward( q_, k_, v_, - TE_DType[q.dtype], + qkv_dtype, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, attn_scale=softmax_scale, dropout=dropout_p, @@ -3511,30 +2682,45 @@ def forward( window_size=window_size_per_step[i], ) else: - q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv_, + ) + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_forward_kwargs["window_size"] = window_size_per_step[i] + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = window_size_per_step[i][0] + fa_forward_kwargs["window_size_right"] = window_size_per_step[i][1] fa_outputs = flash_attn_fwd( q_, k_, v_, - cu_seqlens_q, - cu_seqlens_kv_per_step[i], - max_seqlen_q, - max_seqlen_kv_, + *fa_forward_args_thd, causal=causal, - window_size=window_size_per_step[i], **fa_forward_kwargs, ) - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not _use_flash_attn_3: - rng_states[i] = fa_outputs[7] + if not fa_utils.v2_7_0_plus: + out_per_step[i] = fa_outputs[4] + softmax_lse_per_step[i] = fa_outputs[5] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[7] + else: + out_per_step[i] = fa_outputs[0] + softmax_lse_per_step[i] = fa_outputs[1] + if not use_flash_attn_3: + rng_states[i] = fa_outputs[3] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": - out[:, i - 1].copy_(out_per_step[i - 1].view(out[:, i - 1].shape)) + out[:, i - 1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": - out[i - 1].copy_(out_per_step[i - 1].view(out[i - 1].shape)) + out[i - 1].copy_(out_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) @@ -3557,6 +2743,8 @@ def forward( *softmax_lse_per_step, *rng_states, ) + + ctx.qkv_dtype = qkv_dtype ctx.kv_seq_range_per_step = kv_seq_range_per_step ctx.window_size_per_step = window_size_per_step ctx.cp_group = cp_group @@ -3569,19 +2757,23 @@ def forward( ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + ctx.use_flash_attn_3 = use_flash_attn_3 + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") return out @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) - (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = ctx.saved_tensors[:5] - cu_seqlens_kv_per_step = ctx.saved_tensors[5:7] - out_per_step = ctx.saved_tensors[7:9] - softmax_lse_per_step = ctx.saved_tensors[9:11] - rng_states = ctx.saved_tensors[11:13] + (*saved_tensors,) = ctx.saved_tensors + (q, k, v, cu_seqlens_q, cu_seqlens_q_padded) = saved_tensors[:5] + cu_seqlens_kv_per_step = saved_tensors[5:7] + out_per_step = saved_tensors[7:9] + softmax_lse_per_step = saved_tensors[9:11] + rng_states = saved_tensors[11:13] kv_seq_range_per_step = ctx.kv_seq_range_per_step window_size_per_step = ctx.window_size_per_step @@ -3608,7 +2800,7 @@ def backward(ctx, dout): # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, k.device, True) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] @@ -3621,16 +2813,21 @@ def backward(ctx, dout): flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if _use_flash_attn_3: - flash_attn_bwd = flash_attn_varlen_bwd_v3 + if ctx.use_flash_attn_3: + flash_attn_bwd = _flash_attn_bwd_v3 fa_backward_kwargs["deterministic"] = ctx.deterministic else: - flash_attn_bwd = flash_attn_varlen_bwd + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd + else: + flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _flash_attn_2_4_plus: + if fa_utils.v2_4_plus: fa_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: + if fa_utils.v2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic + if fa_utils.v2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): @@ -3660,7 +2857,7 @@ def backward(ctx, dout): v_, out_, dout_, - TE_DType[q.dtype], + ctx.qkv_dtype, TE_DType[dout.dtype], aux_ctx_tensors, tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, @@ -3675,13 +2872,30 @@ def backward(ctx, dout): deterministic=ctx.deterministic, ) else: - batch_size = k_.shape[0] - q_, k_, v_ = [x.view(-1, *x.shape[-2:]) for x in [q_, k_, v_]] dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ torch.empty_like(x) for x in [q_, k_, v_] ] - if not _use_flash_attn_3: + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv_per_step[i], + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + dq=dq_per_step[i], + dk=dk_per_step[i], + dv=dv_per_step[i], + ) + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[i] + if ctx.use_flash_attn_3 or ( + fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus + ): + fa_backward_kwargs["window_size"] = window_size_per_step[i] + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = window_size_per_step[i][0] + fa_backward_kwargs["window_size_right"] = window_size_per_step[i][1] flash_attn_bwd( dout_, q_, @@ -3689,24 +2903,10 @@ def backward(ctx, dout): v_, out_, softmax_lse_per_step[i], - dq_per_step[i], - dk_per_step[i], - dv_per_step[i], - cu_seqlens_q, - cu_seqlens_kv_per_step[i], - ctx.max_seqlen_q, - max_seqlen_kv, + *fa_backward_args_thd, causal="causal" in ctx.attn_mask_type, - window_size=window_size_per_step[i], **fa_backward_kwargs, ) - # [b*sq//2, np, hn] -> [b, sq//2, np, hn] - dq_per_step[i] = dq_per_step[i].view(dq[:, i].shape) - # [b*s_range, np, hn] -> [b, s_range, np, hn] - dk_per_step[i], dv_per_step[i] = [ - x.view(batch_size, -1, *x.shape[-2:]) - for x in [dk_per_step[i], dv_per_step[i]] - ] if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): @@ -3736,7 +2936,7 @@ def backward(ctx, dout): # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) - chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering(cp_size, dk.device, False) + chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] @@ -3748,6 +2948,7 @@ def backward(ctx, dout): dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) dk = dk.movedim(0, seq_dim).contiguous() dv = dv.movedim(0, seq_dim).contiguous() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( None, @@ -3769,6 +2970,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -3804,12 +3006,16 @@ def forward( fp8_meta, cp_group, cp_stream, + quantizers, + use_flash_attn_3, ): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) + qkv_dtype = q.dtype causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -3820,25 +3026,33 @@ def forward( window_size == (-1, 0) or window_size == (-1, -1) or use_fused_attention - or _flash_attn_2_3_plus + or fa_utils.v2_3_plus ), "Sliding window attention only can work with FusedAttention or FlashAttention >= 2.3!" flash_attn_fwd = None if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} - if _use_flash_attn_3: - flash_attn_fwd = flash_attn_varlen_fwd_v3 + if use_flash_attn_3: + flash_attn_fwd = _flash_attn_fwd_v3 fa_forward_kwargs["window_size"] = window_size else: - flash_attn_fwd = flash_attn_varlen_fwd + if qkv_format == "thd": + flash_attn_fwd = _flash_attn_varlen_fwd + else: + flash_attn_fwd = _flash_attn_fwd fa_forward_kwargs["dropout_p"] = dropout_p fa_forward_kwargs["return_softmax"] = False - if _flash_attn_2_3_plus: + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size"] = window_size - if _flash_attn_2_4_plus: + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = window_size[0] + fa_forward_kwargs["window_size_right"] = window_size[1] + if fa_utils.v2_4_plus: fa_forward_kwargs["alibi_slopes"] = None - if _flash_attn_2_5_7_plus: + if fa_utils.v2_5_7_plus and qkv_format == "thd": fa_forward_kwargs["block_table"] = None + if fa_utils.v2_6_0_plus: + fa_forward_kwargs["softcap"] = 0.0 assert ( q.shape[-2] % cp_size == 0 and k.shape[-2] % cp_size == 0 @@ -3853,76 +3067,71 @@ def forward( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" - qkv_dtype = q.dtype fused_attn_backend = None - fused_attn_qkv_dtype = None # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha + is_output_fp8 = False + + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + ) if fp8: if use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - fused_attn_qkv_dtype = fp8_dtype_forward fused_attn_backend = FusedAttnBackend["FP8"] assert isinstance(k, q.__class__) and isinstance( v, q.__class__ ), "q, k, and v must have the same type." is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv + QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16, k_f16, v_f16 = q, k, v - q, k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [q_f16, k_f16, v_f16] - ] + q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_qkv_offset"] = META_QKV - fp8_meta_kwargs["d_scale_s"] = fp8_meta["scaling_fwd"].scale_inv - fp8_meta_kwargs["d_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_s"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_s_offset"] = META_S - fp8_meta_kwargs["q_scale_o"] = fp8_meta["scaling_fwd"].scale - fp8_meta_kwargs["q_scale_o_offset"] = META_O - fp8_meta_kwargs["amax_s"] = fp8_meta["scaling_fwd"].amax_history - fp8_meta_kwargs["amax_s_offset"] = META_S - fp8_meta_kwargs["amax_o"] = fp8_meta["scaling_fwd"].amax_history - fp8_meta_kwargs["amax_o_offset"] = META_O + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer # partial result quantizer else: assert False, "FP8 is only supported with Fused Attention!" else: if use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, True) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, q.device) q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True ) if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16, k_f16, v_f16 = q, k, v - q, k, v = [ - cast_to_fp8(x, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward) - for x in [q_f16, k_f16, v_f16] - ] + q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] batch_size = q.shape[batch_dim] if use_fused_attention: + q_part, k_part, v_part = q, k, v + if fp8: + q_part = QKV_quantizer.create_tensor_from_data( + q, fake_dtype=qkv_dtype, internal=True + ) + k_part = QKV_quantizer.create_tensor_from_data( + k, fake_dtype=qkv_dtype, internal=True + ) + v_part = QKV_quantizer.create_tensor_from_data( + v, fake_dtype=qkv_dtype, internal=True + ) out, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, - k, - v, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + qkv_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -3935,27 +3144,35 @@ def forward( window_size=window_size, **fp8_meta_kwargs, ) + if fp8: + out = out._data else: - # [b, cp*s, np//cp, hn] -> [b*cp*s, np//cp, hn] - q, k, v = [x.view(-1, *x.shape[-2:]) for x in [q, k, v]] + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + ) fa_outputs = flash_attn_fwd( q, k, v, - cu_seqlens_q, - cu_seqlens_kv, - max_seqlen_q, - max_seqlen_kv, + *fa_forward_args_thd, causal=causal, **fa_forward_kwargs, ) - out, softmax_lse = fa_outputs[4], fa_outputs[5] - rng_state = fa_outputs[7] if not _use_flash_attn_3 else None + if not fa_utils.v2_7_0_plus: + out, softmax_lse = fa_outputs[4], fa_outputs[5] + rng_state = fa_outputs[7] if not use_flash_attn_3 else None + else: + out, softmax_lse = fa_outputs[0], fa_outputs[1] + rng_state = fa_outputs[3] if not use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] - # [b*cp*s, np//cp, hn] -> [b, cp*s, np//cp, hn] - out = out.view(batch_size, -1, *out.shape[-2:]) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, False) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device) out = flash_attn_a2a_communicate( out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) @@ -3970,56 +3187,33 @@ def forward( if fp8: if is_output_fp8: - out_fp8 = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=qkv_dtype, + out_fp8 = O_quantizer.create_tensor_from_data( + out, fake_dtype=qkv_dtype, internal=False ) - out = out_fp8._data out_ret = out_fp8 + out = out_fp8._data else: - out_f16 = cast_from_fp8( - out, - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - TE_DType[q_f16.dtype], + out_fp8 = O_quantizer.create_tensor_from_data( + out, fake_dtype=qkv_dtype, internal=True ) + out_f16 = out_fp8.dequantize(dtype=qkv_dtype) out_ret = out_f16 else: out_ret = out - if fp8: - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, k_save, v_save, out_save = q, k, v, out - elif is_input_fp8: - q_fp8, k_fp8, v_fp8 = [ - Float8Tensor( - data=x, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_QKV, - fp8_dtype=fp8_dtype_forward, - dtype=out_fp8.dtype, - ) - for x in [q, k, v] - ] - q_save, k_save, v_save, out_save = q_fp8, k_fp8, v_fp8, out_fp8 - else: - q_save, k_save, v_save, out_save = q_f16, k_f16, v_f16, out_f16 - else: + if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_save, k_save, v_save, out_save = q, k, v, out - - if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - fp8_fwd_scales = fp8_meta["scaling_fwd"].scale.clone() - fp8_fwd_scale_invs = fp8_meta["scaling_fwd"].scale_inv.clone() else: - fp8_fwd_scales, fp8_fwd_scale_invs = None, None + if is_input_fp8: + q_save, k_save, v_save = q, k, v + else: + q_save, k_save, v_save = q_f16, k_f16, v_f16 + if is_output_fp8: + out_save = out + else: + out_save = out_f16 - ctx.save_for_backward( + tensors_to_save, tensor_objects = prepare_for_saving( q_save, k_save, v_save, @@ -4028,10 +3222,19 @@ def forward( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fp8_fwd_scales, - fp8_fwd_scale_invs, *aux_ctx_tensors, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.qkv_dtype = qkv_dtype + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer + ctx.S_quantizer = S_quantizer + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.batch_size = batch_size ctx.cp_group = cp_group ctx.cp_stream = cp_stream @@ -4049,108 +3252,154 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + ctx.use_flash_attn_3 = use_flash_attn_3 + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") return out_ret @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") cp_size = get_distributed_world_size(ctx.cp_group) - q, k, v, out = ctx.saved_tensors[:4] - cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded = ctx.saved_tensors[ - 4:8 - ] - fp8_fwd_scales, fp8_fwd_scale_invs = ctx.saved_tensors[8:10] - aux_ctx_tensors = ctx.saved_tensors[10:] + ( + q, + k, + v, + out, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *aux_ctx_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type seq_dim = ctx.qkv_format.index("s") + dout_dtype = dout.dtype fused_attn_backend = None fused_attn_dqkv_dtype = None - fused_attn_qkv_dtype = None - dout_dtype = dout.dtype if ctx.fp8: if ctx.use_fused_attention: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - fused_attn_qkv_dtype = fp8_dtype_forward - fused_attn_dqkv_dtype = fp8_dtype_backward fused_attn_backend = FusedAttnBackend["FP8"] if ctx.is_output_fp8: assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = dout._scale_inv - dout_fp8 = dout - dout = dout_fp8._data + ctx.dO_quantizer = dout._quantizer else: - dout_f16 = dout - dout = cast_to_fp8( - dout_f16, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward - ) + dout = ctx.dO_quantizer(dout) + fused_attn_dqkv_dtype = dout._fp8_dtype + dout = dout._data fp8_meta_kwargs = {} - fp8_meta_kwargs["d_scale_qkv"] = fp8_fwd_scale_invs[META_QKV] - fp8_meta_kwargs["d_scale_s"] = fp8_fwd_scale_invs[META_S] - fp8_meta_kwargs["d_scale_o"] = fp8_fwd_scale_invs[META_O] - fp8_meta_kwargs["d_scale_do"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] - fp8_meta_kwargs["d_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP] - fp8_meta_kwargs["q_scale_s"] = fp8_fwd_scales[META_S] - fp8_meta_kwargs["q_scale_dp"] = ctx.fp8_meta["scaling_bwd"].scale[META_DP] - fp8_meta_kwargs["q_scale_dqkv"] = ctx.fp8_meta["scaling_bwd"].scale[META_DQKV] - fp8_meta_kwargs["amax_dp"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP] - fp8_meta_kwargs["amax_dqkv"] = ctx.fp8_meta["scaling_bwd"].amax_history[0][ - META_DQKV - ] + fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer + fp8_meta_kwargs["dp_quantizer"] = ctx.dP_quantizer + fp8_meta_kwargs["dqkv_quantizer"] = ctx.dQKV_quantizer + else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None and ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - q, k, v, out, dout = [x.from_float8(x.dtype) for x in [q, k, v, out, dout]] + if ctx.fp8_meta is not None: + if ctx.is_output_fp8: + assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" + ctx.dO_quantizer = dout._quantizer + dout = dout._data + if ctx.is_input_fp8: + q = ctx.QKV_quantizer.create_tensor_from_data( + q, fake_dtype=ctx.qkv_dtype, internal=True + ) + k = ctx.QKV_quantizer.create_tensor_from_data( + k, fake_dtype=ctx.qkv_dtype, internal=True + ) + v = ctx.QKV_quantizer.create_tensor_from_data( + v, fake_dtype=ctx.qkv_dtype, internal=True + ) + q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]] if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_qkv_dtype = TE_DType[q.dtype] - fused_attn_dqkv_dtype = TE_DType[dout.dtype] + fused_attn_dqkv_dtype = TE_DType[dout_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) dout = dout.view(*out.shape) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, out.device, True) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device) out, dout = flash_attn_a2a_communicate( [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True ) + if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: + out = ctx.O_quantizer.create_tensor_from_data( + out, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout = ctx.dO_quantizer.create_tensor_from_data( + dout, fake_dtype=dout_dtype, internal=True + ) + out = out.dequantize(dtype=ctx.qkv_dtype) + dout = dout.dequantize(dtype=dout_dtype) flash_attn_bwd = None if not ctx.use_fused_attention: fa_backward_kwargs = {"softmax_scale": ctx.softmax_scale} - if _use_flash_attn_3: - flash_attn_bwd = flash_attn_varlen_bwd_v3 + if ctx.use_flash_attn_3: + flash_attn_bwd = ( + _flash_attn_bwd_v3 # pylint: disable=possibly-used-before-assignment + ) fa_backward_kwargs["window_size"] = ctx.window_size fa_backward_kwargs["deterministic"] = ctx.deterministic else: - flash_attn_bwd = flash_attn_varlen_bwd + if ctx.qkv_format == "thd": + flash_attn_bwd = _flash_attn_varlen_bwd + else: + flash_attn_bwd = _flash_attn_bwd fa_backward_kwargs["dropout_p"] = ctx.dropout_p - if _flash_attn_2_3_plus: + if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = ctx.window_size - if _flash_attn_2_4_plus: + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = ctx.window_size[0] + fa_backward_kwargs["window_size_right"] = ctx.window_size[1] + if fa_utils.v2_4_plus: fa_backward_kwargs["alibi_slopes"] = None - if _flash_attn_2_4_1_plus: + if fa_utils.v2_4_1_plus: fa_backward_kwargs["deterministic"] = ctx.deterministic + if fa_utils.v2_6_0_plus: + fa_backward_kwargs["softcap"] = 0.0 if ctx.use_fused_attention: + q_part = q + k_part = k + v_part = v + out_part = out + dout_part = dout + + if ctx.fp8: + q_part = ctx.QKV_quantizer.create_tensor_from_data( + q_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + k_part = ctx.QKV_quantizer.create_tensor_from_data( + k_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + v_part = ctx.QKV_quantizer.create_tensor_from_data( + v_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + out_part = ctx.O_quantizer.create_tensor_from_data( + out_part, fake_dtype=ctx.qkv_dtype, internal=True + ) + dout_part = ctx.dO_quantizer.create_tensor_from_data( + dout_part, fake_dtype=dout_dtype, internal=True + ) + dq, dk, dv, _ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv, - q, - k, - v, - out, - dout, - fused_attn_qkv_dtype, + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_dtype, fused_attn_dqkv_dtype, aux_ctx_tensors, fused_attn_backend, @@ -4165,11 +3414,26 @@ def backward(ctx, dout): deterministic=ctx.deterministic, **fp8_meta_kwargs, ) + if ctx.fp8: + dq = dq._data + dk = dk._data + dv = dv._data else: softmax_lse, rng_state = aux_ctx_tensors - out, dout = [x.view(-1, *x.shape[-2:]) for x in [out, dout]] dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] - if not _use_flash_attn_3: + fa_backward_args_thd = get_fa_args( + False, + ctx.use_flash_attn_3, + ctx.qkv_format, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=ctx.max_seqlen_q, + max_seqlen_kv=ctx.max_seqlen_kv, + dq=dq, + dk=dk, + dv=dv, + ) + if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state flash_attn_bwd( dout, @@ -4178,19 +3442,12 @@ def backward(ctx, dout): v, out, softmax_lse, - dq, - dk, - dv, - cu_seqlens_q, - cu_seqlens_kv, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, + *fa_backward_args_thd, causal=causal, **fa_backward_kwargs, ) - dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering(cp_size, q.device, False) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False ) @@ -4201,29 +3458,18 @@ def backward(ctx, dout): dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if ctx.fp8: - if ctx.is_input_fp8: - dq, dk, dv = [ - Float8Tensor( - data=x, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=dout_dtype, - ) - for x in [dq, dk, dv] - ] - else: - dq, dk, dv = [ - cast_from_fp8( - x, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - TE_DType[dout_dtype], - ) - for x in [dq, dk, dv] - ] + dq = ctx.dQKV_quantizer.create_tensor_from_data( + dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) + dk = ctx.dQKV_quantizer.create_tensor_from_data( + dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) + dv = ctx.dQKV_quantizer.create_tensor_from_data( + dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 + ) + if not ctx.is_input_fp8: + dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]] + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( None, @@ -4250,6 +3496,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -4279,6 +3526,9 @@ def attn_forward_func_with_cp( window_size=None, fp8=False, fp8_meta=None, + quantizers=None, + pad_between_seqs=False, + use_flash_attn_3=False, ) -> torch.Tensor: """ Attention implementation with context parallelism. @@ -4308,30 +3558,21 @@ def attn_forward_func_with_cp( assert ( qkv_format != "sbhd" or use_fused_attention ), "FlashAttention does not support sbhd format!" - assert ( - qkv_format != "thd" - or not use_fused_attention - or attn_mask_type in ["padding", "padding_causal"] - ), ( - f"Context parallelism is not supported for {attn_mask_type} mask type and " - f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!" - ) assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( """Attention bias is only supported with FusedAttention and "causal" """ """or "no_mask" mask types!""" ) - assert ( + assert qkv_format != "thd" or ( cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None - ), "cu_seqlens_q_padded and cu_seqlens_kv_padded cannot be None with context parallelism!" + ), "cu_seqlens_padded cannot be None with context parallelism + THD format!" sliding_window_attn = ( window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) ) - assert ( - not sliding_window_attn - or cp_comm_type == "a2a" - or (cp_comm_type == "all_gather" and not use_fused_attention) - ), "The context parallel running configs cannot support sliding window attetnion!" + assert not sliding_window_attn or cp_comm_type in [ + "a2a", + "all_gather", + ], "The context parallel running configs cannot support sliding window attetnion!" args = [ is_training, @@ -4355,15 +3596,24 @@ def attn_forward_func_with_cp( ] if cp_comm_type in ["p2p", "a2a+p2p"]: - args += [fp8, fp8_meta, cp_group, cp_global_ranks, cp_stream] + args += [ + fp8, + fp8_meta, + cp_group, + cp_global_ranks, + cp_stream, + quantizers, + pad_between_seqs, + use_flash_attn_3, + ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": args.pop(5) args.pop(8) - args += [window_size, cp_group, cp_stream] + args += [window_size, cp_group, cp_stream, use_flash_attn_3] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": - args += [window_size, fp8, fp8_meta, cp_group, cp_stream] + args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") @@ -4371,221 +3621,6 @@ def attn_forward_func_with_cp( return out -class RotaryPositionEmbedding(torch.nn.Module): - """ - Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864. - """ - - def __init__( - self, - dim: int, - rotary_percent: float = 1.0, - seq_len_interpolation_factor: Optional[int] = None, - pretrained_max_position_embeddings: Optional[int] = None, - rotary_base: float = 10000.0, - ): - """ - Parameters - ---------- - dim: int - rotary embedding dimension - rotary_percent: float - Percent of rotary dimension to use for rotary position embeddings. - seq_len_interpolation_factor: int - if not None, discrete positions will be interpolated by this factor via the trick in - https://arxiv.org/abs/2306.15595 - pretrained_max_position_embeddings: int - pre-trained max_position_embeddings before position interpolation - """ - super().__init__() - if rotary_percent < 1.0: - dim = int(dim * rotary_percent) - self.seq_len_interpolation_factor = seq_len_interpolation_factor - self.rotary_base = rotary_base - inv_freq = 1.0 / ( - self.rotary_base - ** ( - torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) - / dim - ) - ) - self.register_buffer("inv_freq", inv_freq) - self.pretrained_max_position_embeddings = pretrained_max_position_embeddings - - def forward(self, max_seq_len: int, offset: int = 0): - """ - Create rotary position embedding frequencies - - Parameters - ---------- - max_seq_len: int - sequence length of a sample - offset: int, default = 0 - fixed offset for freqencies - """ - seq = ( - torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - + offset - ) - - if ( - self.pretrained_max_position_embeddings is not None - and self.seq_len_interpolation_factor is not None - ): - if ( - max_seq_len - > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor - ): - # dynamic linear scaling (length > position we have learned) - seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings) - else: - # fixed linear scaling - seq *= 1 / self.seq_len_interpolation_factor - - freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) - # first part even vector components, second part odd vector components, - # 2 * dim in dimension size - emb = torch.cat((freqs, freqs), dim=-1) - # emb [seq_length, .., dim] - return emb.reshape(emb.size(0), 1, 1, emb.size(1)) - - -class FusedRoPEFunc(torch.autograd.Function): - """ - Function for FusedRoPE - - This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and - the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid - the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern. - """ - - @staticmethod - def forward( - ctx, - t: torch.Tensor, - freqs: torch.Tensor, - tensor_format: str = "sbhd", - cu_seqlens: Union[torch.Tensor, None] = None, - cp_size: int = 1, - cp_rank: int = 0, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - if freqs.dtype != torch.float32: - freqs = freqs.float() - if tensor_format == "sbhd": - output = tex.fused_rope_forward(t, freqs, False) - elif tensor_format == "bshd": - output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1) - elif tensor_format == "thd": - output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank) - else: - raise ValueError(f"Unsupported tensor_format: {tensor_format}.") - ctx.save_for_backward(freqs, cu_seqlens) - ctx.tensor_format = tensor_format - ctx.cp_size = cp_size - ctx.cp_rank = cp_rank - - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: - # pylint: disable=missing-function-docstring - freqs, cu_seqlens = ctx.saved_tensors - if ctx.tensor_format == "sbhd": - grad_input = tex.fused_rope_backward(grad_output, freqs, False) - elif ctx.tensor_format == "bshd": - grad_input = tex.fused_rope_backward( - grad_output.transpose(0, 1), freqs, True - ).transpose(0, 1) - elif ctx.tensor_format == "thd": - grad_input = tex.fused_rope_thd_backward( - grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank - ) - else: - raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.") - - return grad_input, None, None, None, None, None - - -def _rotate_half(x: torch.Tensor) -> torch.Tensor: - """ - change sign so the last dimension becomes [-odd, +even] - """ - x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) - x1, x2 = x.unbind(dim=-2) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, - tensor_format: str = "sbhd", - fused: bool = False, - cu_seqlens: Union[torch.Tensor, None] = None, - cp_size: int = 1, - cp_rank: int = 0, -) -> torch.Tensor: - """ - Apply rotary positional embedding tensor to the input tensor. - - Parameters - ---------- - t: torch.Tensor - Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which - rotary positional embedding will be applied. - freqs: torch.Tensor - Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', - with `s2 >= s` and `d2 <= d`. - fused: bool, default = False - Whether to use a fused applying RoPE implementation. - tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd' - is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is - of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True. - cu_seqlens: torch.Tensor, default = None. - Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and - dtype torch.int32. Only valid when `tensor_format` is 'thd'. - Should be `cu_seqlens_padded` when cp_size > 1. - cp_size: int, default = 1. - Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True. - cp_rank: int, default = 0. - Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. - """ - if fused: - assert ( - tensor_format != "thd" or cu_seqlens is not None - ), "cu_seqlens must not be None when tensor_format is 'thd'." - return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank) - - assert tensor_format in ("sbhd", "bshd"), ( - "Only formats `sbhd` or `bshd` are supported for input tensor `t` " - f"when fused is False, got {tensor_format}." - ) - - max_seq_len = freqs.shape[0] - cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] - - # Only apply the rotary embeddings up to the sequence length of the running - # input. - assert ( - cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" - freqs = freqs[:cur_seq_len] - if tensor_format == "bshd": - freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] - # cos/sin first then dtype conversion for better precision - cos_ = torch.cos(freqs).to(t.dtype) - sin_ = torch.sin(freqs).to(t.dtype) - - rot_dim = freqs.shape[-1] - # ideally t_pass is empty so rotary pos embedding is applied to all tensor t - t, t_pass = t[..., :rot_dim], t[..., rot_dim:] - - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - t = (t * cos_) + (_rotate_half(t) * sin_) - return torch.cat((t, t_pass), dim=-1) - - class _SplitAlongDim(torch.autograd.Function): """""" @@ -4595,15 +3630,34 @@ def forward( mixed_x_layer: torch.Tensor, split_dim: int, split_size_or_sections: Union[int, List[int], Tuple[int]], + squeeze=False, ) -> Tuple[torch.Tensor, ...]: # pylint: disable=missing-function-docstring ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections + if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance( + mixed_x_layer, Float8Tensor + ): + return tuple( + Float8TensorBase( + fp8_scale_inv=mixed_x_layer._scale_inv, + fp8_dtype=mixed_x_layer._fp8_dtype, + data=x.squeeze(split_dim) if squeeze else x, + shape=x.squeeze(split_dim).shape if squeeze else x.shape, + quantizer=mixed_x_layer._quantizer, + ) + for x in torch.split( + mixed_x_layer._data, + split_size_or_sections=split_size_or_sections, + dim=split_dim, + ) + ) if isinstance(mixed_x_layer, Float8Tensor): return tuple( Float8Tensor.make_like( mixed_x_layer, - data=x, + data=x.squeeze(split_dim) if squeeze else x, + shape=x.squeeze(split_dim).shape if squeeze else x.shape, ) for x in torch.split( mixed_x_layer._data, @@ -4611,7 +3665,10 @@ def forward( dim=split_dim, ) ) - return torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim) + out_list = torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim) + if squeeze: + out_list = [x.squeeze(split_dim) for x in out_list] + return out_list @staticmethod def backward(ctx, *grad_outputs): @@ -4657,13 +3714,17 @@ def backward(ctx, *grad_outputs): new_shape, strides, ) - return Float8Tensor.make_like(grad_outputs[0], data=ret), None, None + return ( + Float8Tensor.make_like(grad_outputs[0], data=ret, shape=ret.shape), + None, + None, + ) grad_outputs_data = [x._data for x in grad_outputs] + data = torch.cat(grad_outputs_data, dim=split_dim) return ( - Float8Tensor.make_like( - grad_outputs[0], data=torch.cat(grad_outputs_data, dim=split_dim) - ), + Float8Tensor.make_like(grad_outputs[0], data=data, shape=data.shape), + None, None, None, ) @@ -4740,72 +3801,62 @@ def forward( cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + window_size: Optional[Tuple[int, int]] = None, core_attention_bias_type: str = "no_bias", core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, ) -> torch.Tensor: """Unfused attention fprop""" assert ( qkv_layout in QKVLayouts ), f"UnfusedDotProductAttention does not support qkv_layout = {qkv_layout}!" - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + + # get q_format and kv_format for training and inference + qkv_format, q_format, _ = dpa_utils.get_qkv_format(qkv_layout, inference_params) + if inference_params is not None and inference_params.is_paged: + key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number) + if qkv_format == "bshd": # convert to sbhd and use sbhd implementation for now query_layer, key_layer, value_layer = [ x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] ] + if qkv_format == "sbhd_2bshd": + key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]] + + total_tokens, batch_size = None, None + if qkv_format == "thd_2bshd": + total_tokens, batch_size = query_layer.shape[0], key_layer.shape[0] + query_layer = tex.convert_thd_to_bshd( + query_layer, + cu_seqlens_q, + batch_size, + inference_params.max_ctx_len, + ) + query_layer, key_layer, value_layer = [ + x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] + ] batch_size, max_seqlen_q, max_seqlen_kv = ( query_layer.shape[1], query_layer.shape[0], key_layer.shape[0], ) - if "padding" in attn_mask_type: - if self.attention_type == "self": - assert attention_mask.shape == ( - batch_size, - 1, - 1, - max_seqlen_q, - ), "attention_mask should be a single tensor with [b, 1, 1, sq] shape!" - attention_mask = torch.logical_or( - attention_mask.squeeze(1).unsqueeze(3), attention_mask - ) - else: - assert ( - len(attention_mask) == 2 - and attention_mask[0].shape == (batch_size, 1, 1, max_seqlen_q) - and attention_mask[1].shape == (batch_size, 1, 1, max_seqlen_kv) - ), ( - "attention_mask should be a tuple of two tensors with shapes " - "[b, 1, 1, sq] and [b, 1, 1, skv]!" - ) - attention_mask = torch.logical_or( - attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] - ) - mask = attention_mask.squeeze(1).logical_not() - actual_seqlens_q = mask[:, :, 0].sum(dim=1) - actual_seqlens_kv = mask[:, 0, :].sum(dim=1) - mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( - 1, 1, max_seqlen_q, 1 - ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( - 1, 1, 1, max_seqlen_kv + + if "padding" in attn_mask_type and attention_mask is None: + attention_mask = dpa_utils.get_padding_mask( + batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv ) - if attn_mask_type == "padding_causal": - attention_mask = torch.logical_or( - torch.where(mask.view(1, 1, max_seqlen_q, max_seqlen_kv) < 0, 1, 0), - attention_mask, - ) - if attn_mask_type == "padding_causal_bottom_right": - attention_mask = torch.logical_or( - torch.where( - mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) - + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) - < 0, - 1, - 0, - ), - attention_mask, - ) + attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = ( + dpa_utils.get_full_mask( + max_seqlen_q, + max_seqlen_kv, + attn_mask_type=attn_mask_type, + attention_mask=attention_mask, + window_size=window_size, + attention_type=self.attention_type, + ) + ) batch_size, seqlen = query_layer.shape[1], query_layer.shape[0] apply_qk_layer_scaling = self.apply_qk_layer_scaling and key_layer.dtype == torch.float16 @@ -4835,19 +3886,14 @@ def forward( key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1) # preallocting result tensor: [b * np, sq, sk] - # WAR to set dtype to FP32 as ONNX lacks BF16 support for ConstantOfShape operator - is_bf16 = query_layer.dtype == torch.bfloat16 matmul_result = torch.empty( output_size[0] * output_size[1], output_size[2], output_size[3], - dtype=torch.float32 if is_in_onnx_export_mode() and is_bf16 else query_layer.dtype, + dtype=query_layer.dtype, device=torch.cuda.current_device(), ) - if is_in_onnx_export_mode() and is_bf16: - matmul_result = matmul_result.bfloat16() - scale = self.softmax_scale if apply_qk_layer_scaling: scale /= self.layer_number @@ -4875,7 +3921,8 @@ def forward( if core_attention_bias_type == "post_scale_bias": assert core_attention_bias is not None, "core_attention_bias should not be None!" if core_attention_bias_type == "alibi": - _, core_attention_bias = get_alibi( + _, core_attention_bias = dpa_utils.get_alibi( + _alibi_cache, output_size[1], output_size[2], output_size[3], @@ -4932,20 +3979,34 @@ def forward( # change view [b, np, sq, hn] context_layer = context_layer.view(*output_size) - if qkv_format == "sbhd": + if q_format == "sbhd": # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] context_layer = context_layer.view(seqlen, batch_size, -1) - if qkv_format == "bshd": + if q_format == "bshd": # [b, np, sq, hn] --> [b, sq, np, hn] context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [b, sq, np, hn] --> [b, sq, hp] context_layer = context_layer.view(batch_size, seqlen, -1) + if q_format == "thd": + # [b, np, sq, hn] --> [b, sq, np, hn] + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + + # [b, sq, np, hn] --> [tq, np, hn] + context_layer = tex.convert_bshd_to_thd( + context_layer, + cu_seqlens_q, + total_tokens, + ) + + # [tq, np, hn] --> [tq, hp] + context_layer = context_layer.view(total_tokens, -1) + return context_layer @@ -4984,205 +4045,9 @@ def backward( return dq, dk, dv -def get_qkv_layout( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - qkv_format: str = "sbhd", -) -> str: - """Get qkv layout. - - Parameters - ---------- - q: torch.Tensor - Query tensor. - k: torch.Tensor - Key tensor. - v: torch.Tensor - Value tensor. - qkv_format: str, default = `sbhd` - Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for - the sequence length dimension, `b` batch size, `h` the number of attention heads, - `d` head size, and `t` the total number of tokens in a batch, i.e. - `t = sum(s_i) for i = 0...b-1`. - - Returns - ---------- - qkv_layout: str - Memory layout of `q`, `k` and `v`. Each `qkv_format` can be mapped to one of five - memory layouts. For example, `sb3hd` means `q`, `k`, `v` are created as one chunk - of memory and that they are interleaved in the `2`nd dimension. `sbhd_sbh2d` means - `q` and `kv` are created in two chunks and that `q` itself is contiguous and `k`, `v` - are interleaved with each other in the `3`rd dimension, `k = kv[:,:,:,0,:]` and - `v = kv[:,:,:,1,:]`. - Mapping: - `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`} - `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`} - `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`} - q: torch.Tensor - Query tensor. It may be different from input `q` as we try to fit tensors to - a supported layout. - k: torch.Tensor - Key tensor. It may be different from input `k` as we try to fit tensors to - a supported layout. - v: torch.Tensor - Value tensor. It may be different from input `v` as we try to fit tensors to - a supported layout. - """ - - check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) - assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" - - def run_iteratively(q, k, v): - # check data pointers - data_ptr = q.untyped_storage().data_ptr() - check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) - check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k]) - data_ptr = k.untyped_storage().data_ptr() - check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) - - # check tensor shapes - shape = q.shape - check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) - shape = k.shape - check_shapes_kv = shape[:-1] == v.shape[:-1] - - # check tensor strides - stride = q.stride() - check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) - check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple( - sv / v.shape[-1] for sv in v.stride()[:-1] - ) - - # check tensor offsets for h3d and 3hd layouts - prod_h_d = q.shape[-1] * q.shape[-2] - check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v])) - check_h3d_offsets = all( - x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v]) - ) - - # check tensor offsets for hd_h2d and hd_2hd layouts - prod_all_dims = [np.prod(x.shape) for x in [q, k]] - offset = prod_all_dims[0] if check_ptrs_qkv else 0 - prod_h_d = k.shape[-1] * k.shape[-2] - check_2hd_offsets = all( - x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v]) - ) - check_h2d_offsets = all( - x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v]) - ) - - # check tensor offsets for hd_hd_hd layouts - check_hd_offsets_qkv = ( - all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v])) - if check_ptrs_qkv - else all(x.storage_offset() == 0 for i, x in enumerate([q, k, v])) - ) - check_hd_offsets_qk = ( - all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k])) - if not check_ptrs_qkv and check_ptrs_qk - else all(x.storage_offset() == 0 for i, x in enumerate([q, k])) - ) - check_hd_offsets_kv = ( - all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v])) - if not check_ptrs_qkv and check_ptrs_kv - else all(x.storage_offset() == 0 for i, x in enumerate([k, v])) - ) - - if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets: - # sb3hd, bs3hd, t3hd - # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv - qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:] - elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets: - # sbh3d, bsh3d, th3d - # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv - qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:] - elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets: - # sbhd_sb2hd, bshd_bs2hd, thd_t2hd - # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv - # q and kv may be disjoint or consecutive in memory, and when consecutive, they may - # have the same data pointer, i.e. check_ptrs_qkv=True - qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] - elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets: - # sbhd_sbh2d, bshd_bsh2d, thd_th2d - # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv - # q and kv may be disjoint or consecutive in memory, and when consecutive, they may - # have the same data pointer, i.e. check_ptrs_qkv=True - qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:] - elif ( - check_strides_kv - and check_shapes_kv - and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk) - ): - # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd - # three chunks of memory, q, k and v, which may be disjoint or consecutive, and - # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or - # check_ptrs_qk=True or check_ptrs_kv=True - qkv_layout = "_".join(list([qkv_format]) * 3) - else: - qkv_layout = "not_supported" - - return qkv_layout - - qkv_layout = run_iteratively(q, k, v) - if qkv_layout == "not_supported": - # force q,k,v to be contiguous and run get_layout again - q, k, v = [x.contiguous() for x in [q, k, v]] - qkv_layout = run_iteratively(q, k, v) - if qkv_layout == "not_supported": - raise RuntimeError("The provided qkv memory layout is not supported!") - - return qkv_layout, q, k, v - - -def check_set_window_size( - attn_mask_type: str, - window_size: Tuple[int, int] = None, -): - """Check if sliding window size is compliant with attention mask type. - If not, set it to the appropriate size. - - attn_mask_type | window_size - ------------------------------------------------------------------------- - no_mask, padding, arbitrary | (-1, -1) or (>=0, >=0) - causal, padding_causal | (-1, 0) or (>=0, 0) - causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0) - """ - orig_window_size = window_size - if "causal" in attn_mask_type: - if orig_window_size is None: - window_size = (-1, 0) - elif orig_window_size == (-1, -1) or ( - orig_window_size[0] >= 0 and orig_window_size[1] != 0 - ): - window_size = (orig_window_size[0], 0) - warnings.warn( - "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type - ) - elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0): - assert False, ( - "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type - ) - elif attn_mask_type in ["no_mask", "padding", "arbitrary"]: - if orig_window_size is None: - window_size = (-1, -1) - elif orig_window_size == (-1, 0): - window_size = (-1, -1) - warnings.warn( - "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type - ) - elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0): - assert False, ( - "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type - ) - else: - assert False, "Invalid attn_mask_type: " + attn_mask_type - return window_size - - -class FlashAttention(torch.nn.Module): - """Dot product attention, using HazyResearch flash-attn package: - https://github.com/Dao-AILab/flash-attention +class FlashAttention(torch.nn.Module): + """Dot product attention, using HazyResearch flash-attn package: + https://github.com/Dao-AILab/flash-attention """ def __init__( @@ -5196,13 +4061,13 @@ def __init__( ) -> None: super().__init__() - if _flash_attn_is_installed: + if fa_utils.is_installed: assert ( - _flash_attn_version >= _flash_attn_version_required - ), f"FlashAttention minimum version {_flash_attn_version_required} is required." + fa_utils.version >= fa_utils.version_required + ), f"FlashAttention minimum version {fa_utils.version_required} is required." assert ( - _flash_attn_version <= _flash_attn_max_version - ), f"FlashAttention maximum version {_flash_attn_max_version} is supported." + fa_utils.version <= fa_utils.max_version + ), f"FlashAttention maximum version {fa_utils.max_version} is supported." self.softmax_scale = softmax_scale self.attention_dropout_ctx = attention_dropout_ctx @@ -5211,9 +4076,9 @@ def __init__( self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic self.logger = logging.getLogger("FlashAttention") - self.logger.setLevel(_log_level) + self.logger.setLevel(attn_log._log_level) if not self.logger.hasHandlers(): - self.logger.addHandler(_stream_handler) + self.logger.addHandler(attn_log._stream_handler) def forward( self, @@ -5235,6 +4100,9 @@ def forward( cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, + inference_params: Optional[InferenceParams] = None, + flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), ) -> torch.Tensor: """flash-attn fprop""" @@ -5257,8 +4125,10 @@ def forward( cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + # get q_format and kv_format for training and inference + qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params) + # convert q, k, v to bshd if they are in sbhd; qkv_format doesn't change if all(not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): if qkv_format == "sbhd": # For now just 128, will make it more general in the future @@ -5272,8 +4142,11 @@ def forward( ) else: query_layer, key_layer, value_layer = [ - x.transpose(0, 1) for x in (query_layer, key_layer, value_layer) + x.transpose(0, 1).contiguous() + for x in (query_layer, key_layer, value_layer) ] + elif q_format == "sbhd" and kv_format == "bshd": + query_layer = query_layer.transpose(0, 1).contiguous() if context_parallel: query_layer, key_layer, value_layer = [ x.contiguous() for x in (query_layer, key_layer, value_layer) @@ -5281,85 +4154,129 @@ def forward( else: if qkv_format == "sbhd": query_layer._data, key_layer._data, value_layer._data = [ - x.transpose(0, 1) + x.transpose(0, 1).contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] query_layer, key_layer, value_layer = [ - Float8Tensor.make_like(x, data=x._data) + Float8Tensor.make_like(x, data=x._data, shape=x._data.shape) for x in (query_layer, key_layer, value_layer) ] + elif q_format == "sbhd" and kv_format == "bshd": + query_layer._data = query_layer._data.transpose(0, 1).contiguous() + query_layer = Float8Tensor.make_like( + query_layer, data=query_layer._data, shape=query_layer._data.shape + ) if context_parallel: query_layer._data, key_layer._data, value_layer._data = [ x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] - batch_size = query_layer.shape[0] + # get batch_size, max_seqlen and cu_seqlens + batch_size, context_len = None, None + if inference_params is None: + if qkv_format in ["sbhd", "bshd"]: + batch_size = query_layer.shape[0] + max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size - if qkv_format in ["sbhd", "bshd"]: - max_seqlen_q, max_seqlen_kv = query_layer.shape[1], key_layer.shape[1] - max_seqlen_q *= cp_size - max_seqlen_kv *= cp_size + if "padding" in attn_mask_type: + assert ( + not context_parallel + ), "Padding mask not supported with context parallelism!" - if "padding" in attn_mask_type: - assert not context_parallel, "Padding mask not supported with context parallelism!" - # [b * s, h, d] - query_layer, key_layer, value_layer = [ - x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) - for x in [query_layer, key_layer, value_layer] - ] + # [b * s, h, d] + query_layer, key_layer, value_layer = [ + x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) + for x in [query_layer, key_layer, value_layer] + ] - if self.attention_type == "self": - assert ( - max_seqlen_q == max_seqlen_kv - ), "Maximum sequence length for Q and KV should be the same." - if cu_seqlens_q is None: + if self.attention_type == "self": assert ( - attention_mask is not None - ), "Please provide attention_mask for padding!" - cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask) + max_seqlen_q == max_seqlen_kv + ), "Maximum sequence length for Q and KV should be the same." + if cu_seqlens_q is None: + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" + cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices( + attention_mask + ) + else: + indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q) + cu_seqlens_kv = cu_seqlens_q + query_layer, key_layer, value_layer = dpa_utils.PackTensors.apply( + indices_q, query_layer, key_layer, value_layer + ) else: - indices_q = get_indices(max_seqlen_q, cu_seqlens_q) - cu_seqlens_kv = cu_seqlens_q - query_layer, key_layer, value_layer = PackTensors.apply( - indices_q, query_layer, key_layer, value_layer - ) + if cu_seqlens_q is None or cu_seqlens_kv is None: + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" + cu_seqlens_q, indices_q = dpa_utils.get_cu_seqlens_and_indices( + attention_mask[0] + ) + cu_seqlens_kv, indices_kv = dpa_utils.get_cu_seqlens_and_indices( + attention_mask[1] + ) + else: + indices_q = dpa_utils.get_indices(max_seqlen_q, cu_seqlens_q) + indices_kv = dpa_utils.get_indices(max_seqlen_kv, cu_seqlens_kv) + query_layer = dpa_utils.PackTensors.apply(indices_q, query_layer) + key_layer, value_layer = dpa_utils.PackTensors.apply( + indices_kv, key_layer, value_layer + ) else: - if cu_seqlens_q is None or cu_seqlens_kv is None: - assert ( - attention_mask is not None - ), "Please provide attention_mask for padding!" - cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(attention_mask[0]) - cu_seqlens_kv, indices_kv = get_cu_seqlens_and_indices(attention_mask[1]) - else: - indices_q = get_indices(max_seqlen_q, cu_seqlens_q) - indices_kv = get_indices(max_seqlen_kv, cu_seqlens_kv) - query_layer = PackTensors.apply(indices_q, query_layer) - key_layer, value_layer = PackTensors.apply(indices_kv, key_layer, value_layer) - else: - # Cumulative sequence lengths for unpadded data - if cu_seqlens_q is None: - cu_seqlens_q = _get_full_cu_seqlens( - batch_size, - max_seqlen_q, - query_layer.device, + # Cumulative sequence lengths for unpadded data + if cu_seqlens_q is None: + cu_seqlens_q = dpa_utils.get_full_cu_seqlens( + batch_size, + max_seqlen_q, + query_layer.device, + ) + if cu_seqlens_kv is None: + cu_seqlens_kv = dpa_utils.get_full_cu_seqlens( + batch_size, + max_seqlen_kv, + key_layer.device, + ) + elif qkv_format == "thd": + assert ( + cu_seqlens_q is not None and cu_seqlens_kv is not None + ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" + if max_seqlen_q is None: + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + max_seqlen_q = seqlens_q.max().item() + if max_seqlen_kv is None: + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + max_seqlen_kv = seqlens_kv.max().item() + else: + if qkv_format in ["sbhd_2bshd", "bshd"]: + # q is in bshd in both cases from conversion above or the original input + batch_size, context_len = query_layer.shape[:2] + cu_seqlens_q = cu_seqlens_q[: batch_size + 1] + cu_seqlens_kv = cu_seqlens_kv[: batch_size + 1] + # convert from bshd to thd_2bshd for flash_attn_varlen_func/_with_kvcache; + # kernel assumes tensor is contiguous + if isinstance(query_layer, Float8Tensor): + query_layer._data = tex.convert_bshd_to_thd( + query_layer._data, + cu_seqlens_q, + batch_size * context_len, ) - if cu_seqlens_kv is None: - cu_seqlens_kv = _get_full_cu_seqlens( - batch_size, - max_seqlen_kv, - key_layer.device, + query_layer = Float8Tensor.make_like( + query_layer, data=query_layer._data, shape=query_layer._data.shape + ) + else: + query_layer = tex.convert_bshd_to_thd( + query_layer, + cu_seqlens_q, + batch_size * context_len, ) - elif qkv_format == "thd": - assert ( - cu_seqlens_q is not None and cu_seqlens_kv is not None - ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" - if max_seqlen_q is None: - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_q = seqlens_q.max().item() - if max_seqlen_kv is None: - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - max_seqlen_kv = seqlens_kv.max().item() + use_flash_attn_3 = False + if flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b"): + use_flash_attn_3 = True if context_parallel and all( not isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer] ): @@ -5376,8 +4293,8 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, + cu_seqlens_q if qkv_format == "thd" else None, + cu_seqlens_kv if qkv_format == "thd" else None, self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, @@ -5388,6 +4305,9 @@ def forward( attn_mask_type=attn_mask_type, deterministic=self.deterministic, window_size=window_size, + quantizers=quantizers, + pad_between_seqs=False, + use_flash_attn_3=use_flash_attn_3, ) else: @@ -5400,36 +4320,81 @@ def forward( tensor.activation_offloading = True with self.attention_dropout_ctx(): - fa_optional_forward_kwargs = {} - if _flash_attn_2_3_plus: - fa_optional_forward_kwargs["window_size"] = window_size - if _flash_attn_2_4_plus: - fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes - if _flash_attn_2_4_1_plus: - fa_optional_forward_kwargs["deterministic"] = self.deterministic + # | API | use cases + # ---------------------------------------------------------------------- + # FA v2 | flash_attn_func | bshd/sbhd + not padding + # | flash_attn_varlen_func | bshd/sbhd + padding + # | | thd + padding + # | | KV cache (not-paged/paged), i.e. + # | | bshd/sbhd/thd + padding + # FA v3 | flash_attn_func | bshd/sbhd + not padding + # | flash_attn_varlen_func | bshd/sbhd + padding + # | | thd + padding + # | flash_attn_with_kvcache | KV cache (not-paged/paged), i.e. + # | | bshd/sbhd/thd + padding fa_optional_forward_args_thd = [] if qkv_format in ["bshd", "sbhd"] and "padding" not in attn_mask_type: - func = flash_attn_func if not _use_flash_attn_3 else flash_attn_func_v3 - else: - if _flash_attn_2_5_7_plus: - fa_optional_forward_kwargs["block_table"] = None func = ( - flash_attn_varlen_func - if not _use_flash_attn_3 - else flash_attn_varlen_func_v3 + flash_attn_func if not use_flash_attn_3 else flash_attn_func_v3 + ) # pylint: disable=possibly-used-before-assignment + else: + if not use_flash_attn_3: + func = flash_attn_varlen_func + elif inference_params is None: + func = flash_attn_varlen_func_v3 # pylint: disable=possibly-used-before-assignment + else: + func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment + if not use_flash_attn_3 or inference_params is None: + fa_optional_forward_args_thd.append(cu_seqlens_q) + fa_optional_forward_args_thd.append(cu_seqlens_kv) + fa_optional_forward_args_thd.append(max_seqlen_q) + fa_optional_forward_args_thd.append(max_seqlen_kv) + if not use_flash_attn_3: + fa_optional_forward_kwargs = {} + if fa_utils.v2_3_plus: + fa_optional_forward_kwargs["window_size"] = window_size + if fa_utils.v2_4_plus: + fa_optional_forward_kwargs["alibi_slopes"] = alibi_slopes + if fa_utils.v2_4_1_plus: + fa_optional_forward_kwargs["deterministic"] = self.deterministic + if inference_params is not None: + # use block_table kwarg to support thd_2bshd for non-paged + fa_optional_forward_kwargs["block_table"] = ( + inference_params.cache_manager.page_table[:batch_size] + if inference_params.is_paged + else inference_params.cache_manager.batch_indices_post_step.unsqueeze( + 1 + )[:batch_size] + ) + output = func( + query_layer, + key_layer, + value_layer, + *fa_optional_forward_args_thd, + self.attention_dropout if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal="causal" in attn_mask_type, + **fa_optional_forward_kwargs, ) - fa_optional_forward_args_thd.append(cu_seqlens_q) - fa_optional_forward_args_thd.append(cu_seqlens_kv) - fa_optional_forward_args_thd.append(max_seqlen_q) - fa_optional_forward_args_thd.append(max_seqlen_kv) - if _use_flash_attn_3: + else: fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size - fa_3_optional_forward_kwargs["deterministic"] = self.deterministic - activation_dtype = query_layer.dtype + if inference_params is None: + fa_3_optional_forward_kwargs["deterministic"] = self.deterministic + else: + fa_3_optional_forward_kwargs["cu_seqlens_q"] = cu_seqlens_q + fa_3_optional_forward_kwargs["max_seqlen_q"] = max_seqlen_q + cache_seqlens = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_3_optional_forward_kwargs["cache_seqlens"] = cache_seqlens + # flash_attn_with_kvcache accepts thd_2bshd for non-paged + if inference_params.is_paged: + fa_3_optional_forward_kwargs["page_table"] = ( + inference_params.cache_manager.page_table[:batch_size] + ) if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] torch_dtype = get_fp8_torch_dtype(fp8_meta["recipe"], fprop_tensor=True) + torch_orig_dtype = query_layer.dtype def convert_to_torch_float8(tensor, dtype): out = torch.Tensor().to(device=tensor.device, dtype=dtype) @@ -5446,22 +4411,27 @@ def convert_to_torch_float8(tensor, dtype): assert isinstance(key_layer, query_layer.__class__) and isinstance( value_layer, query_layer.__class__ ), "q, k, and v must have the same type." - if isinstance(query_layer, Float8Tensor): - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = query_layer._scale_inv - else: + if not isinstance(query_layer, Float8Tensor): query_layer, key_layer, value_layer = ( - Float8Tensor.to_float8(x, fp8_dtype=fp8_dtype_forward) - for x in [query_layer, key_layer, value_layer] + QKV_quantizer(x) for x in [query_layer, key_layer, value_layer] ) - fa_3_optional_forward_kwargs["descale_q"] = query_layer._scale_inv - fa_3_optional_forward_kwargs["descale_k"] = key_layer._scale_inv - fa_3_optional_forward_kwargs["descale_v"] = value_layer._scale_inv + batch_size = cu_seqlens_q.shape[0] - 1 + num_heads_k = key_layer.shape[-2] + fa_3_optional_forward_kwargs["q_descale"] = ( + query_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k) + ) + fa_3_optional_forward_kwargs["k_descale"] = key_layer._scale_inv.unsqueeze( + 0 + ).repeat(batch_size, num_heads_k) + fa_3_optional_forward_kwargs["v_descale"] = ( + value_layer._scale_inv.unsqueeze(0).repeat(batch_size, num_heads_k) + ) query_layer, key_layer, value_layer = ( convert_to_torch_float8(x, torch_dtype) for x in [query_layer, key_layer, value_layer] ) try: - output, _ = func( + output = func( query_layer, key_layer, value_layer, @@ -5470,61 +4440,65 @@ def convert_to_torch_float8(tensor, dtype): causal="causal" in attn_mask_type, **fa_3_optional_forward_kwargs, ) + if isinstance(output, (List, Tuple)): + output = output[0] except TypeError as e: - if _flash_attn_3_0_0_beta: + if fa_utils.v3_0_0_beta: e.args = ( e.args[0] + ". Please update your flash-attn v3 (beta) installation as it " + "may have added more supported arguments to its API. \n" - + _flash_attn_3_installation_steps, + + fa_utils.v3_installation_steps, ) + e.args[1:] raise + if fp8: + output = output.to(dtype=torch_orig_dtype) if fp8 and fp8_meta["recipe"].fp8_mha: - output = cast_to_fp8( - output, - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - ) - output = Float8Tensor( - data=output, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, - ) - else: - output = func( - query_layer, - key_layer, - value_layer, - *fa_optional_forward_args_thd, - self.attention_dropout if self.training else 0.0, - softmax_scale=self.softmax_scale, - causal="causal" in attn_mask_type, - **fa_optional_forward_kwargs, - ) - - if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: - output = UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) + O_quantizer = quantizers["scaling_fwd"][META_O] + output = O_quantizer(output) + + if inference_params is None: + if qkv_format in ["sbhd", "bshd"] and "padding" in attn_mask_type: + output = dpa_utils.UnpackTensor.apply(indices_q, batch_size * max_seqlen_q, output) + elif qkv_format in ["bshd", "sbhd_2bshd"]: + # all KV caching cases use thd_2bshd for calculation + # convert results back to bshd from thd_2bshd + if isinstance(query_layer, Float8Tensor): + output._data = tex.convert_thd_to_bshd( + output._data, + cu_seqlens_q, + batch_size, + context_len, + ) + output = Float8Tensor.make_like(output, data=output._data, shape=output._data.shape) + else: + output = tex.convert_thd_to_bshd( + output, + cu_seqlens_q, + batch_size, + context_len, + ) - if qkv_format == "sbhd": + if q_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) if fp8 and fp8_meta["recipe"].fp8_mha: + output_data = ( + output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) + .transpose(0, 1) + .contiguous() + ) output = Float8Tensor.make_like( output, - data=output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) - .transpose(0, 1) - .contiguous(), + data=output_data, + shape=output_data.shape, ) else: output = output.view(batch_size, max_seqlen_q // cp_size, -1).transpose(0, 1) - elif qkv_format == "bshd": + elif q_format == "bshd": # (bs)hd -> bs(hd) output = output.reshape(batch_size, max_seqlen_q // cp_size, -1) - elif qkv_format == "thd": + elif q_format == "thd": # thd -> t(hd) output = output.reshape(output.shape[0], -1) @@ -5540,9 +4514,9 @@ def _combine_tensors( num_tensors = len(tensors) new_shape = list(tensors[0].shape) new_shape.insert(dim, num_tensors) - new_stride = list(tensors[0].stride()) - new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors)) if isinstance(tensors[0], Float8Tensor): + new_stride = list(tensors[0]._data.stride()) + new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors)) combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype) combined_tensor.set_( tensors[0]._data.untyped_storage(), @@ -5550,8 +4524,10 @@ def _combine_tensors( new_shape, new_stride, ) - combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor) + combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor, shape=new_shape) else: + new_stride = list(tensors[0].stride()) + new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors)) combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype) combined_tensor.set_( tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride @@ -5560,18 +4536,24 @@ def _combine_tensors( return combined_tensor -class FusedAttnFunc_qkvpacked(torch.autograd.Function): - """Function for FusedAttention with packed QKV input""" +class FusedAttnFunc(torch.autograd.Function): + """Function for FusedAttention with separate Q, K, V tensors""" @staticmethod def forward( ctx, is_training, - max_seqlen, - cu_seqlens, - cu_seqlens_padded, - qkv, - qkv_dtype, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + q, + k, + v, attn_bias, attn_scale, dropout_p, @@ -5585,51 +4567,74 @@ def forward( use_FAv2_bwd, fp8, fp8_meta, + quantizers, deterministic, ): # pylint: disable=missing-function-docstring # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha + is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False + + # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 + # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 + # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn + fake_dtype = q.dtype + + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + ) if fp8: - is_input_fp8 = isinstance(qkv, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = qkv._scale_inv fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - assert ( - qkv_group == 1 - ), f"qkv layout should conform to 3hd or h3d, e.g. sb3hd, but found {qkv_layout}." + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, and v must have the same type." + + is_input_fp8 = isinstance(q, Float8Tensor) + q_fp8, k_fp8, v_fp8 = None, None, None if is_input_fp8: - qkv_fp8 = qkv._data + q_fp8, k_fp8, v_fp8 = q, k, v else: - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = cast_to_fp8( - qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(qkv.shape) - out_fp8, aux_ctx_tensors = fused_attn_fwd_qkvpacked( + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv = _combine_tensors([q, k, v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_fp8 = QKV_quantizer(qkv) + q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True) + case 2: + q_fp8 = QKV_quantizer(q) + dim = qkv_layout.split("_")[1].find("2") + kv = _combine_tensors([k, v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_fp8 = QKV_quantizer(kv_c) + k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1], True) + case 3: + q_fp8 = QKV_quantizer(q) + k_fp8 = QKV_quantizer(k) + v_fp8 = QKV_quantizer(v) + case _: + raise "Invalid qkv_layout " + qkv_layout + # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn + out_fp8, aux_ctx_tensors = fused_attn_fwd( is_training, - max_seqlen, - cu_seqlens, - qkv_fp8, - fp8_dtype_forward, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q_fp8, + k_fp8, + v_fp8, + fake_dtype, fused_attention_backend, attn_bias, - cu_seqlens_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + None, + None, + S_quantizer, + O_quantizer, attn_scale, dropout_p, fast_zero_fill, @@ -5640,69 +4645,58 @@ def forward( rng_gen, ) if is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=qkv.dtype, - ) + out_ret = out_fp8 else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) + out_ret = out_fp8.dequantize().view(out_fp8.shape) + # is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16 + # is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn out_save = out_ret + if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): + # 1: qkv packed, 2: kv packed, 3: qkv separate if is_input_fp8: - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) + qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) + if qkv_group == 1: + dim = qkv_layout.find("3") + qkv = _combine_tensors([q, k, v], dim) + qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) + qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape) + q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True) + if qkv_group == 2: + q = q.dequantize() + dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2") + kv = _combine_tensors([k, v], dim) + kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) + kv_no_fp8 = kv.dequantize() + k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True) + if qkv_group == 3: + q = q.dequantize() + k = k.dequantize() + v = v.dequantize() if is_output_fp8: - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - fp8_tensors = ( - qkv_fp8, - out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ) + out_save = out_fp8.dequantize() + + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: - out_ret, aux_ctx_tensors = fused_attn_fwd_qkvpacked( + # q, k, v, out_ret: torch.float16 or torch.bfloat16 + out_ret, aux_ctx_tensors = fused_attn_fwd( is_training, - max_seqlen, - cu_seqlens, - qkv, - qkv_dtype, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q, + cu_seqlens_kv, + q, + k, + v, + fake_dtype, fused_attention_backend, attn_bias, - cu_seqlens_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + None, # s_quantizer + None, # o_quantizer attn_scale, dropout_p, fast_zero_fill, @@ -5712,19 +4706,49 @@ def forward( window_size, rng_gen, ) - fp8_tensors = (None, None, None, None) out_save = out_ret + fp8_tensors = (None, None, None, None) ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + + from .cpu_offload import CPUOffloadEnabled + + if CPUOffloadEnabled: + if ctx.fp8: + tensor_list = fp8_tensors + else: + tensor_list = [q, k, v, out_save] + + tensor_list.extend(aux_ctx_tensors) + + qkv_layout = "sbhd_sbhd_sbhd" + for tensor in tensor_list: + if tensor is not None: + tensor.activation_offloading = True + ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None) - ctx.save_for_backward( - *qkvo_tensors, cu_seqlens, cu_seqlens_padded, *fp8_tensors, *aux_ctx_tensors + qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) + tensors_to_save, tensor_objects = prepare_for_saving( + *fp8_tensors, + *qkvo_tensors, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *aux_ctx_tensors, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects ctx.fp8_meta = fp8_meta - ctx.max_seqlen = max_seqlen - ctx.qkv_dtype = qkv_dtype + + ctx.dQKV_quantizer = dQKV_quantizer + ctx.dO_quantizer = dO_quantizer + ctx.dP_quantizer = dP_quantizer + ctx.S_quantizer = S_quantizer + + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_kv = max_seqlen_kv ctx.attn_scale = attn_scale ctx.dropout_p = dropout_p ctx.fast_zero_fill = fast_zero_fill @@ -5747,969 +4771,18 @@ def backward(ctx, d_out): assert isinstance( d_out, Float8Tensor ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." - d_out_f8tensor = d_out - d_out = d_out._data + + # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 + # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 + # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2 + fake_dtype = d_out.dtype d_out = d_out.contiguous() ( - qkv, - out, - cu_seqlens, - cu_seqlens_padded, - qkv_fp8, + q_fp8, + k_fp8, + v_fp8, out_fp8, - fwd_scales, - fwd_scale_invs, - *aux_ctx_tensors, - ) = ctx.saved_tensors - rest = [None] - if not aux_ctx_tensors[0].is_contiguous(): - aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() - if ctx.use_FAv2_bwd: - softmax_lse, rng_state = aux_ctx_tensors - dqkv = torch.empty_like(qkv) - d_out, q, k, v, out = [ - maybe_contiguous(x) for x in (d_out, qkv[:, 0], qkv[:, 1], qkv[:, 2], out) - ] - flash_attn_cuda_bwd( - d_out, - q, - k, - v, - out, - softmax_lse, - dqkv[:, 0], - dqkv[:, 1], - dqkv[:, 2], - cu_seqlens, - cu_seqlens, - ctx.max_seqlen, - ctx.max_seqlen, - ctx.dropout_p, - ctx.attn_scale, - False, - "causal" in ctx.attn_mask_type, - None, - rng_state, - ) - dqkv = dqkv[..., : d_out.shape[-1]] - else: - with torch.cuda.nvtx.range("_FusedAttn_qkvpacked"): - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) - if ctx.is_output_fp8: - d_out_fp8 = d_out - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv - else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DO, - fp8_dtype_backward, - ).view(d_out.shape) - dqkv_fp8, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, - cu_seqlens, - qkv_fp8, - out_fp8, - d_out_fp8, - fp8_dtype_forward, - fp8_dtype_backward, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_padded, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, - ) - if ctx.is_input_fp8: - dqkv = Float8Tensor( - data=dqkv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: - dqkv_c_fp8 = dqkv_fp8.view( - -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1] - ) - dqkv = cast_from_fp8( - dqkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dqkv_fp8.shape) - else: - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(qkv.dtype) - dqkv, *rest = fused_attn_bwd_qkvpacked( - ctx.max_seqlen, - cu_seqlens, - qkv, - out, - d_out, - ctx.qkv_dtype, - ctx.qkv_dtype, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_padded, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, - ) - - # if no_bias or alibi, return dqkv - if ctx.attn_bias_type in ["no_bias", "alibi"]: - return ( - None, - None, - None, - None, - dqkv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - # else, return (dqkv, dbias) - return ( - None, - None, - None, - None, - dqkv, - None, - rest[0], - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class FusedAttnFunc_kvpacked(torch.autograd.Function): - """Function for FusedAttention with packed KV input""" - - @staticmethod - def forward( - ctx, - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - q, - kv, - qkv_dtype, - attn_bias, - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - fused_attention_backend, - use_FAv2_bwd, - fp8, - fp8_meta, - deterministic, - ): - # pylint: disable=missing-function-docstring - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha - if fp8: - assert isinstance(kv, q.__class__), "q and kv must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if is_input_fp8: - q_fp8, kv_fp8 = q._data, kv._data - else: - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - assert qkv_group == 2, ( - "qkv layout should conform to hd_2hd or hd_h2d, e.g. sbhd_sb2hd, " - f"but found {qkv_layout}." - ) - q_fp8 = cast_to_fp8(q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward).view( - q.shape - ) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = cast_to_fp8( - kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(kv.shape) - out_fp8, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q_fp8, - kv_fp8, - fp8_dtype_forward, - fused_attention_backend, - attn_bias, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - if is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=q.dtype, - ) - else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - out_save = out_ret - if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - if is_input_fp8: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - if is_output_fp8: - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - fp8_tensors = ( - q_fp8, - kv_fp8, - out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ) - else: - out_ret, aux_ctx_tensors = fused_attn_fwd_kvpacked( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - qkv_dtype, - fused_attention_backend, - attn_bias, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - out_save = out_ret - fp8_tensors = (None, None, None, None, None) - - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - ctx.is_input_fp8 = is_input_fp8 - ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None) - ctx.save_for_backward( - *qkvo_tensors, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - *fp8_tensors, - *aux_ctx_tensors, - ) - ctx.fp8_meta = fp8_meta - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.fast_zero_fill = fast_zero_fill - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.window_size = window_size - ctx.fused_attention_backend = ( - fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] - ) - ctx.use_FAv2_bwd = use_FAv2_bwd - ctx.deterministic = deterministic - - return out_ret - - @staticmethod - def backward(ctx, d_out): - # pylint: disable=missing-function-docstring - if ctx.is_output_fp8: - assert isinstance( - d_out, Float8Tensor - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." - d_out_f8tensor = d_out - d_out = d_out._data - - d_out = d_out.contiguous() - ( - q, - kv, - out, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - q_fp8, - kv_fp8, - out_fp8, - fwd_scales, - fwd_scale_invs, - *aux_ctx_tensors, - ) = ctx.saved_tensors - rest = [None] - if not aux_ctx_tensors[0].is_contiguous(): - aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() - if ctx.use_FAv2_bwd: - softmax_lse, rng_state = aux_ctx_tensors - dq = torch.empty_like(q) - dkv = torch.empty_like(kv) - d_out, q, k, v, out = [maybe_contiguous(x) for x in (d_out, q, kv[:, 0], kv[:, 1], out)] - flash_attn_cuda_bwd( - d_out, - q, - k, - v, - out, - softmax_lse, - dq, - dkv[:, 0], - dkv[:, 1], - cu_seqlens_q, - cu_seqlens_kv, - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - ctx.dropout_p, - ctx.attn_scale, - False, - "causal" in ctx.attn_mask_type, - None, - rng_state, - ) - dq = dq[..., : d_out.shape[-1]] - dkv = dkv[..., : d_out.shape[-1]] - else: - with torch.cuda.nvtx.range("_FusedAttn_kvpacked"): - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) - if ctx.is_output_fp8: - d_out_fp8 = d_out - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv - else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DO, - fp8_dtype_backward, - ).view(d_out.shape) - dq_fp8, dkv_fp8, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q_fp8, - kv_fp8, - out_fp8, - d_out_fp8, - fp8_dtype_forward, - fp8_dtype_backward, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, - ) - if ctx.is_input_fp8: - dq = Float8Tensor( - data=dq_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dkv = Float8Tensor( - data=dkv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dq_fp8.shape) - dkv_c_fp8 = dkv_fp8.view( - -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] - ) - dkv = cast_from_fp8( - dkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dkv_fp8.shape) - else: - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(q.dtype) - dq, dkv, *rest = fused_attn_bwd_kvpacked( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - out, - d_out, - ctx.qkv_dtype, - ctx.qkv_dtype, - aux_ctx_tensors, - ctx.fused_attention_backend, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ctx.attn_scale, - ctx.dropout_p, - ctx.fast_zero_fill, - ctx.qkv_layout, - ctx.attn_bias_type, - ctx.attn_mask_type, - ctx.window_size, - ctx.deterministic, - ) - - # if no_bias or alibi, return dqkv - if ctx.attn_bias_type in ["no_bias", "alibi"]: - return ( - None, - None, - None, - None, - None, - None, - None, - dq, - dkv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - # else, return (dqkv, dbias) - return ( - None, - None, - None, - None, - None, - None, - None, - dq, - dkv, - None, - rest[0], - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - - -class FusedAttnFunc(torch.autograd.Function): - """Function for FusedAttention with separate Q, K, V tensors""" - - @staticmethod - def forward( - ctx, - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - q, - k, - v, - qkv_dtype, - attn_bias, - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - fused_attention_backend, - use_FAv2_bwd, - fp8, - fp8_meta, - deterministic, - ): - # pylint: disable=missing-function-docstring - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha - if fp8: - fused_attention_backend = FusedAttnBackend["FP8"] - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - if is_input_fp8: - fp8_meta["scaling_fwd"].scale_inv[META_QKV] = q._scale_inv - q_fp8, k_fp8, v_fp8 = q._data, k._data, v._data - else: - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = _combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = cast_to_fp8( - qkv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(qkv.shape) - q_fp8, k_fp8, v_fp8 = _SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1]) - q_fp8, k_fp8, v_fp8 = [x.squeeze(dim) for x in [q_fp8, k_fp8, v_fp8]] - if qkv_group == 2: - q_fp8 = cast_to_fp8( - q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(q.shape) - dim = qkv_layout.split("_")[1].find("2") - kv = _combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = cast_to_fp8( - kv_c, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(kv.shape) - k_fp8, v_fp8 = _SplitAlongDim.apply(kv_fp8, dim, [1, 1]) - k_fp8, v_fp8 = [x.squeeze(dim) for x in [k_fp8, v_fp8]] - if qkv_group == 3: - q_fp8 = cast_to_fp8( - q, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(q.shape) - k_fp8 = cast_to_fp8( - k, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(k.shape) - v_fp8 = cast_to_fp8( - v, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward - ).view(v.shape) - out_fp8, aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q_fp8, - k_fp8, - v_fp8, - fp8_dtype_forward, - fused_attention_backend, - attn_bias, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - fp8_meta["scaling_fwd"].scale_inv, # d_scale_qkv - META_QKV, # d_scale_qkv_offset - fp8_meta["scaling_fwd"].scale_inv, # d_scale_s - META_S, # d_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_s - META_S, # q_scale_s_offset - fp8_meta["scaling_fwd"].scale, # q_scale_o - META_O, # q_scale_o_offset - fp8_meta["scaling_fwd"].amax_history, # amax_s - META_S, # amax_s_offset - fp8_meta["scaling_fwd"].amax_history, # amax_o - META_O, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - if is_output_fp8: - out_ret = Float8Tensor( - data=out_fp8, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=META_O, - fp8_dtype=fp8_dtype_forward, - dtype=q.dtype, - ) - else: - out_ret = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - out_save = out_ret - - if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - # 1: qkv packed, 2: kv packed, 3: qkv separate - if is_input_fp8: - qkv_group = len(qkv_layout.split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = _combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = cast_from_fp8( - qkv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[qkv.dtype], - ).view(qkv.shape) - q, k, v = _SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1]) - q, k, v = [x.squeeze(dim) for x in [q, k, v]] - if qkv_group == 2: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - dim = qkv_layout.split("_")[1].find("2") - kv = _combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = cast_from_fp8( - kv_c._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[kv.dtype], - ).view(kv.shape) - k, v = _SplitAlongDim.apply(kv_no_fp8, dim, [1, 1]) - k, v = [x.squeeze(dim) for x in [k, v]] - if qkv_group == 3: - q = cast_from_fp8( - q._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[q.dtype], - ).view(q.shape) - k = cast_from_fp8( - k._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[k.dtype], - ).view(k.shape) - v = cast_from_fp8( - v._data, - fp8_meta["scaling_fwd"], - META_QKV, - fp8_dtype_forward, - TE_DType[v.dtype], - ).view(v.shape) - if is_output_fp8: - out_save = cast_from_fp8( - out_fp8.view(-1, out_fp8.shape[-2] * out_fp8.shape[-1]), - fp8_meta["scaling_fwd"], - META_O, - fp8_dtype_forward, - qkv_dtype, - ).view(out_fp8.shape) - - fp8_tensors = ( - q_fp8, - k_fp8, - v_fp8, - out_fp8, - fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), - ) - else: - out_ret, aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q, - cu_seqlens_kv, - q, - k, - v, - qkv_dtype, - fused_attention_backend, - attn_bias, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - None, # d_scale_qkv - 0, # d_scale_qkv_offset - None, # d_scale_s - 0, # d_scale_s_offset - None, # q_scale_s - 0, # q_scale_s_offset - None, # q_scale_o - 0, # q_scale_o_offset - None, # amax_s - 0, # amax_s_offset - None, # amax_o - 0, # amax_o_offset - attn_scale, - dropout_p, - fast_zero_fill, - qkv_layout, - attn_bias_type, - attn_mask_type, - window_size, - rng_gen, - ) - out_save = out_ret - fp8_tensors = (None, None, None, None, None, None) - - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - - from .cpu_offload import CPUOffloadEnabled - - if CPUOffloadEnabled: - if ctx.fp8: - tensor_list = fp8_tensors - else: - tensor_list = [q, k, v, out_save] - - tensor_list.extend(aux_ctx_tensors) - - qkv_layout = "sbhd_sbhd_sbhd" - for tensor in tensor_list: - if tensor is not None: - tensor.activation_offloading = True - - ctx.is_input_fp8 = is_input_fp8 - ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) - ctx.save_for_backward( - *qkvo_tensors, - cu_seqlens_q, - cu_seqlens_kv, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - *fp8_tensors, - *aux_ctx_tensors, - ) - ctx.fp8_meta = fp8_meta - ctx.max_seqlen_q = max_seqlen_q - ctx.max_seqlen_kv = max_seqlen_kv - ctx.qkv_dtype = qkv_dtype - ctx.attn_scale = attn_scale - ctx.dropout_p = dropout_p - ctx.fast_zero_fill = fast_zero_fill - ctx.qkv_layout = qkv_layout - ctx.attn_bias_type = attn_bias_type - ctx.attn_mask_type = attn_mask_type - ctx.window_size = window_size - ctx.fused_attention_backend = ( - fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] - ) - ctx.use_FAv2_bwd = use_FAv2_bwd - ctx.deterministic = deterministic - - return out_ret - - @staticmethod - def backward(ctx, d_out): - # pylint: disable=missing-function-docstring - if ctx.is_output_fp8: - assert isinstance( - d_out, Float8Tensor - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." - d_out_f8tensor = d_out - d_out = d_out._data - - d_out = d_out.contiguous() - ( q, k, v, @@ -6718,14 +4791,11 @@ def backward(ctx, d_out): cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - q_fp8, - k_fp8, - v_fp8, - out_fp8, - fwd_scales, - fwd_scale_invs, - *aux_ctx_tensors, - ) = ctx.saved_tensors + *other_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + + aux_ctx_tensors = other_tensors + if not aux_ctx_tensors[0].is_contiguous(): aux_ctx_tensors[0] = aux_ctx_tensors[0].contiguous() rest = [None] @@ -6762,20 +4832,13 @@ def backward(ctx, d_out): else: with torch.cuda.nvtx.range("_FusedAttn"): if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype( - ctx.fp8_meta["recipe"], fprop_tensor=False - ) if ctx.is_output_fp8: d_out_fp8 = d_out - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO] = d_out_f8tensor._scale_inv else: - d_out_fp8 = cast_to_fp8( - d_out.view(-1, d_out.shape[-2] * d_out.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DO, - fp8_dtype_backward, - ).view(d_out.shape) + d_out_fp8 = ctx.dO_quantizer(d_out) + dqkv_dtype = TE_DType[d_out_fp8._data.dtype] + # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn + # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2 dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -6786,22 +4849,15 @@ def backward(ctx, d_out): v_fp8, out_fp8, d_out_fp8, - fp8_dtype_forward, - fp8_dtype_backward, + fake_dtype, + dqkv_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, cu_seqlens_kv_padded, - fwd_scale_invs[META_QKV], # d_scale_qkv, - fwd_scale_invs[META_S], # d_scale_s, - fwd_scale_invs[META_O], # d_scale_o, - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do - ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp - fwd_scales[META_S], # q_scale_s - ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp - ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp - ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv + ctx.S_quantizer, + ctx.dP_quantizer, + ctx.dQKV_quantizer, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, @@ -6812,95 +4868,40 @@ def backward(ctx, d_out): ctx.deterministic, ) - if ctx.is_input_fp8: - dq = Float8Tensor( - data=dq_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dk = Float8Tensor( - data=dk_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - dv = Float8Tensor( - data=dv_fp8, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=META_DQKV, - fp8_dtype=fp8_dtype_backward, - dtype=d_out_f8tensor.dtype, - ) - else: - qkv_group = len(ctx.qkv_layout.split("_")) + # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16 + # is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2 + if not ctx.is_input_fp8: + qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_")) if qkv_group == 1: dim = ctx.qkv_layout.find("3") - dqkv_fp8 = _combine_tensors([dq_fp8, dk_fp8, dv_fp8], dim) - dqkv_c_fp8 = dqkv_fp8.view( - -1, dqkv_fp8.shape[-3] * dqkv_fp8.shape[-2] * dqkv_fp8.shape[-1] + dqkv_fp8_data = _combine_tensors( + [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim + ) + dqkv_fp8 = dq_fp8.make_like( + tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape ) - dqkv = cast_from_fp8( - dqkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dqkv_fp8.shape) - dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1]) - dq, dk, dv = [x.squeeze(dim) for x in [dq, dk, dv]] + dqkv = dqkv_fp8.dequantize() + dq, dk, dv = _SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True) if qkv_group == 2: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dq_fp8.shape) + dq = dq_fp8.dequantize() dim = ctx.qkv_layout.split("_")[1].find("2") dkv_fp8 = _combine_tensors([dk_fp8, dv_fp8], dim) dkv_c_fp8 = dkv_fp8.view( -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] ) - dkv = cast_from_fp8( - dkv_c_fp8, - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dkv_fp8.shape) - dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1]) - dk, dv = [x.squeeze(dim) for x in [dk, dv]] + dkv = dkv_c_fp8.dequantize() + dk, dv = _SplitAlongDim.apply(dkv, dim, [1, 1], True) if qkv_group == 3: - dq = cast_from_fp8( - dq_fp8.view(-1, dq_fp8.shape[-2] * dq_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dq_fp8.shape) - dk = cast_from_fp8( - dk_fp8.view(-1, dk_fp8.shape[-2] * dk_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dk_fp8.shape) - dv = cast_from_fp8( - dv_fp8.view(-1, dv_fp8.shape[-2] * dv_fp8.shape[-1]), - ctx.fp8_meta["scaling_bwd"], - META_DQKV, - fp8_dtype_backward, - ctx.qkv_dtype, - ).view(dv_fp8.shape) + dq = dq_fp8.dequantize() + dk = dk_fp8.dequantize() + dv = dv_fp8.dequantize() + else: + dq, dk, dv = dq_fp8, dk_fp8, dv_fp8 else: - if d_out.dtype == torch.uint8: - d_out = d_out_f8tensor.from_float8(q.dtype) + if isinstance(d_out, QuantizedTensor): + d_out = d_out.dequantize() + dqkv_dtype = TE_DType[d_out.dtype] + # q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -6911,8 +4912,8 @@ def backward(ctx, d_out): v, out, d_out, - ctx.qkv_dtype, - ctx.qkv_dtype, + fake_dtype, + dqkv_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -6920,13 +4921,6 @@ def backward(ctx, d_out): None, None, None, - None, - None, - None, - None, - None, - None, - None, ctx.attn_scale, ctx.dropout_p, ctx.fast_zero_fill, @@ -6947,6 +4941,8 @@ def backward(ctx, d_out): None, None, None, + None, + None, dq, dk, dv, @@ -6977,10 +4973,11 @@ def backward(ctx, d_out): None, None, None, + None, + None, dq, dk, dv, - None, rest[0], None, None, @@ -6997,6 +4994,7 @@ def backward(ctx, d_out): None, None, None, + None, ) @@ -7096,6 +5094,9 @@ def forward( cp_comm_type: str = "p2p", fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, + pad_between_seqs: bool = False, + inference_params: Optional[InferenceParams] = None, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -7120,64 +5121,65 @@ def forward( cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) + # get q_format and kv_format for training and inference + qkv_format, q_format, kv_format = dpa_utils.get_qkv_format(qkv_layout, inference_params) - if qkv_format in ["sbhd", "bshd"]: - if qkv_format == "sbhd": - batch_size, max_seqlen_q, max_seqlen_kv = ( - query_layer.shape[1], - query_layer.shape[0], - key_layer.shape[0], - ) - if qkv_format == "bshd": - batch_size, max_seqlen_q, max_seqlen_kv = ( - query_layer.shape[0], - query_layer.shape[1], - key_layer.shape[1], - ) - max_seqlen_q *= cp_size - max_seqlen_kv *= cp_size - if "padding" in attn_mask_type: - assert not context_parallel, "Padding mask not supported with context parallelism!" - - if cu_seqlens_q is None or cu_seqlens_kv is None: - if attention_mask is None: - raise RuntimeError( - "Please provide attention_mask or cu_seqlens for padding!" + page_table = None + if inference_params is None: + if qkv_format in ["sbhd", "bshd"]: + if qkv_format == "sbhd": + batch_size = query_layer.shape[1] + max_seqlen_q = query_layer.shape[0] + max_seqlen_kv = key_layer.shape[0] + if qkv_format == "bshd": + batch_size = query_layer.shape[0] + max_seqlen_q = query_layer.shape[1] + max_seqlen_kv = key_layer.shape[1] + max_seqlen_q *= cp_size + max_seqlen_kv *= cp_size + if "padding" in attn_mask_type: + assert ( + not context_parallel + ), "Padding mask not supported with context parallelism!" + if cu_seqlens_q is None or cu_seqlens_kv is None: + if attention_mask is None: + raise RuntimeError( + "Please provide attention_mask or cu_seqlens for padding!" + ) + if self.attention_type == "self": + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask) + cu_seqlens_kv = cu_seqlens_q + else: + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0]) + cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1]) + else: + if cu_seqlens_q is None: + cu_seqlens_q = dpa_utils.get_full_cu_seqlens( + batch_size, + max_seqlen_q, + query_layer.device, ) - if self.attention_type == "self": - cu_seqlens_q = get_cu_seqlens(attention_mask) - cu_seqlens_kv = cu_seqlens_q - else: - cu_seqlens_q = get_cu_seqlens(attention_mask[0]) - cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) - else: - if cu_seqlens_q is None: - cu_seqlens_q = _get_full_cu_seqlens( - batch_size, - max_seqlen_q, - query_layer.device, - ) - if cu_seqlens_kv is None: - cu_seqlens_kv = _get_full_cu_seqlens( - batch_size, - max_seqlen_kv, - key_layer.device, - ) - if qkv_format == "thd": - assert ( - max_seqlen_q is not None - and max_seqlen_kv is not None - and cu_seqlens_q is not None - and cu_seqlens_kv is not None - ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" - - if cu_seqlens_q_padded is None or cu_seqlens_kv_padded is None: + if cu_seqlens_kv is None: + cu_seqlens_kv = dpa_utils.get_full_cu_seqlens( + batch_size, + max_seqlen_kv, + key_layer.device, + ) + if qkv_format == "thd": + assert ( + max_seqlen_q is not None + and max_seqlen_kv is not None + and cu_seqlens_q is not None + and cu_seqlens_kv is not None + ), "max_seqlen_q/kv and cu_seqlens_q/kv can not be None when qkv_format is thd!" + elif inference_params.is_paged: + page_table = inference_params.cache_manager.page_table + + if (q_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_q_padded is None: cu_seqlens_q_padded = cu_seqlens_q + if (kv_format == "thd" or "padding" in attn_mask_type) and cu_seqlens_kv_padded is None: cu_seqlens_kv_padded = cu_seqlens_kv - qkv_dtype = TE_DType[query_layer.dtype] - use_FAv2_bwd = ( self.use_FAv2_bwd and (core_attention_bias_type == "no_bias") @@ -7233,6 +5235,8 @@ def forward( window_size=window_size, fp8=fp8, fp8_meta=fp8_meta, + quantizers=quantizers, + pad_between_seqs=pad_between_seqs, ) else: with self.attention_dropout_ctx(): @@ -7244,10 +5248,11 @@ def forward( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, + page_table, + page_table, query_layer, key_layer, value_layer, - qkv_dtype, core_attention_bias, self.softmax_scale, self.attention_dropout if self.training else 0.0, @@ -7261,6 +5266,7 @@ def forward( use_FAv2_bwd, fp8, fp8_meta, + quantizers, self.deterministic, ) @@ -7420,15 +5426,15 @@ def __init__( super().__init__() self.logger = logging.getLogger("DotProductAttention") - self.logger.setLevel(_log_level) + self.logger.setLevel(attn_log._log_level) if not self.logger.hasHandlers(): - self.logger.addHandler(_stream_handler) + self.logger.addHandler(attn_log._stream_handler) self.qkv_format = qkv_format attn_mask_type = attn_mask_type.replace(",", "_") if attn_mask_type == "causal_padding": attn_mask_type = "padding_causal" self.attn_mask_type = attn_mask_type - self.window_size = check_set_window_size(attn_mask_type, window_size) + self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -7632,14 +5638,14 @@ def forward( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, - attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, - qkv_format: Optional[str] = None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_kv: Optional[torch.Tensor] = None, - cu_seqlens_q_padded: Optional[torch.Tensor] = None, - cu_seqlens_kv_padded: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - max_seqlen_kv: Optional[int] = None, + attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, + qkv_format: str = None, + cu_seqlens_q: torch.Tensor = None, + cu_seqlens_kv: torch.Tensor = None, + cu_seqlens_q_padded: torch.Tensor = None, + cu_seqlens_kv_padded: torch.Tensor = None, + max_seqlen_q: int = None, + max_seqlen_kv: int = None, attn_mask_type: Optional[str] = None, window_size: Optional[Tuple[int, int]] = None, checkpoint_core_attention: bool = False, @@ -7648,7 +5654,7 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, - is_first_microbatch: Optional[bool] = None, + pad_between_seqs: Optional[bool] = None, ) -> torch.Tensor: """ Dot Product Attention Layer. @@ -7818,27 +5824,26 @@ def forward( Adjustments of the sequence_len_offset should be done after a complete forward pass. If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand. Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient. - is_first_microbatch : {True, False, None}, default = None - During training using either gradient accumulation or - pipeline parallelism a minibatch of data is further split - into microbatches. Between the microbatches of the same minibatch - the model weights are not updated. Setting this parameter indicates - whether the current microbatch is the first in a minibatch or not. - When set, this parameter enables additional optimizations: - - * during FP8 training, it allows caching of the FP8 versions of - the weights - * it also allows skipping gradient accumulation during the - first microbatch (since it is the first gradient being - produced) + pad_between_seqs: Optional[bool], default = `None` + If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. + If true, there are padding tokens between individual sequences in a packed batch. """ + with self.prepare_forward( query_layer, - is_first_microbatch, num_gemms=3, allow_non_contiguous=True, ) as query_layer: + # checks for RNG + if self.rng_states_tracker is not None and is_graph_capturing(): + assert isinstance( + self.rng_states_tracker, CudaRNGStatesTracker + ), "Unsupported RNG states tracker." + assert ( + graph_safe_rng_available() + ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." + # checks for FP8 if self.fp8: if self.fp8_meta["recipe"].fp8_mha: if not self.fp8_meta["recipe"].fp8_dpa: @@ -7847,7 +5852,6 @@ def forward( """Forcing fp8_meta["recipe"].fp8_dpa=True due to """ """fp8_meta["recipe"].fp8_mha=True""" ) - if self.fp8 and self.fp8_meta["recipe"].fp8_dpa: forward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=True) backward_dtype = get_fp8_te_dtype(self.fp8_meta["recipe"], fprop_tensor=False) @@ -7859,6 +5863,7 @@ def forward( tex.DType.kFloat8E5M2, ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types.""" + # checks for q/k/v shapes assert ( query_layer.is_cuda and key_layer.is_cuda and value_layer.is_cuda ), "DotProductAttention only supports CUDA tensors." @@ -7868,15 +5873,26 @@ def forward( assert ( key_layer.shape[:-1] == value_layer.shape[:-1] ), "Keys and values must have the same batch size, sequence length and number of heads!" + num_attention_heads = query_layer.shape[-2] + num_gqa_groups = key_layer.shape[-2] + assert ( + query_layer.shape[-1] == key_layer.shape[-1] + ), "Queries and keys must have the same head dimension!" + head_dim_qk, head_dim_v = query_layer.shape[-1], value_layer.shape[-1] assert ( - key_layer.shape[-1] == self.hidden_size_per_attention_head_k - ), f"Keys have head_dim = {key_layer.shape[-1]}, " + head_dim_qk == self.hidden_size_per_attention_head_k + ), f"Keys have head_dim = {head_dim_qk}, " "but expected head_dim = {self.hidden_size_per_attention_head_k}!" assert ( - value_layer.shape[-1] == self.hidden_size_per_attention_head_v - ), f"Values have head_dim = {value_layer.shape[-1]}, " + head_dim_v == self.hidden_size_per_attention_head_v + ), f"Values have head_dim = {head_dim_v}, " "but expected head_dim = {self.hidden_size_per_attention_head_v}!" + assert num_gqa_groups == self.num_gqa_groups_per_partition, ( + "Keys and values must have num_gqa_group =" + f" {self.num_gqa_groups_per_partition} heads! Found {num_gqa_groups}." + ) + # checks for attention mask if attn_mask_type is None: attn_mask_type = self.attn_mask_type else: @@ -7886,82 +5902,40 @@ def forward( assert ( attn_mask_type in AttnMaskTypes ), f"Attention mask type {attn_mask_type} is not supported!" - if qkv_format == "thd": - assert ( - "padding" in attn_mask_type - ), "Attention mask type must be padding or padding_causal for qkv_format=thd!" + # checks for sliding window if window_size is None: window_size = self.window_size - window_size = check_set_window_size(attn_mask_type, window_size) - - if self.rng_states_tracker is not None and is_graph_capturing(): - assert isinstance( - self.rng_states_tracker, CudaRNGStatesTracker - ), "Unsupported RNG states tracker." - assert ( - graph_safe_rng_available() - ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture." + window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + # checks for qkv_format if qkv_format is None: qkv_format = self.qkv_format - - if inference_params is not None: - assert self.layer_number is not None, "Layer number must be set!" - - # convert causal to causal_bottom_right in inference when KV-caching is in use - # so users can run with the same attn_mask_type for training and inference - if attn_mask_type in ["causal", "padding_causal"]: - attn_mask_type = attn_mask_type + "_bottom_right" - - if qkv_format == "bshd": - key_layer = key_layer.transpose(0, 1) - value_layer = value_layer.transpose(0, 1) - - ( - inference_key_memory, - inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] - - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) - - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) - - # Copy keys and values into KV-cache - inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( - key_layer - ) - inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = ( - value_layer - ) - key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] - - if qkv_format == "bshd": - key_layer = key_layer.transpose(0, 1) - value_layer = value_layer.transpose(0, 1) - - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - - assert ( - key_layer.shape[-2] == self.num_gqa_groups_per_partition - and value_layer.shape[-2] == self.num_gqa_groups_per_partition - ), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!" assert qkv_format in [ "sbhd", "bshd", "thd", ], "DotProductAttention only supports qkv_format = {'sbhd', 'bshd', 'thd'}!" - + batch_size = None + if qkv_format in ["sbhd", "bshd"]: + assert all( + len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) + ), f"Queries, keys and values must be 4D tensors when {qkv_format=}!" + if qkv_format == "sbhd": + batch_size = query_layer.shape[1] + max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv + else: + batch_size = query_layer.shape[0] + max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q + max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv if qkv_format == "thd": assert all( len(x.shape) == 3 for x in (query_layer, key_layer, value_layer) ), "Queries, keys and values must be 3D tensors when qkv_format = thd!" + assert ( + "padding" in attn_mask_type + ), "Attention mask type must be padding or padding_causal for qkv_format=thd!" assert ( cu_seqlens_q is not None and cu_seqlens_kv is not None ), "cu_seqlens_q and cu_seqlens_kv can not be None when qkv_format = thd!" @@ -7987,6 +5961,76 @@ def forward( seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64) + # update KV cache and retrieve saved tokens from cache for inference + if inference_params is not None: + assert self.layer_number is not None, "Layer number must be set!" + + # convert top-left causal to bottom-right causal due to KV caching + # users can still use the same attention mask for inference as for training + assert "padding" in attn_mask_type, "KV caching requires padding mask!" + if attn_mask_type == "padding_causal": + attn_mask_type = attn_mask_type + "_bottom_right" + + self.attention_type = "cross" + self.flash_attention.attention_type = self.attention_type + self.fused_attention.attention_type = self.attention_type + self.unfused_attention.attention_type = self.attention_type + + query_layer, key_layer, value_layer = [ + x.contiguous() if not x.is_contiguous() else x + for x in [query_layer, key_layer, value_layer] + ] + + # get full K/V tensors from cache and adjust cu_seqlens, qkv_format based on the cache + ( + key_layer, + value_layer, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_kv, + qkv_format, + ) = inference_params.step( + self.layer_number, + key_layer, + value_layer, + qkv_format, + ) + cu_seqlens_q_padded = None + cu_seqlens_kv_padded = None + + # get qkv's memory layout + if all(isinstance(x, Float8Tensor) for x in [query_layer, key_layer, value_layer]): + ( + qkv_layout, + query_layer._data, + key_layer._data, + value_layer._data, + q_format, + kv_format, + ) = dpa_utils.get_qkv_layout( + query_layer._data, + key_layer._data, + value_layer._data, + qkv_format=qkv_format, + inference_params=inference_params, + ) + else: + ( + qkv_layout, + query_layer, + key_layer, + value_layer, + q_format, + kv_format, + ) = dpa_utils.get_qkv_layout( + query_layer, + key_layer, + value_layer, + qkv_format=qkv_format, + inference_params=inference_params, + ) + + # adjust max_seqlen and cu_seqlens for CP cp_size = 1 if isinstance(self.cp_group, dist_group_type): cp_size = get_distributed_world_size(self.cp_group) @@ -7994,69 +6038,42 @@ def forward( for group in self.cp_group: cp_size *= get_distributed_world_size(group) context_parallel = cp_size > 1 - - if qkv_format in ["sbhd", "bshd"]: - assert all( - len(x.shape) == 4 for x in (query_layer, key_layer, value_layer) - ), f"Queries, keys and values must be 4D tensors when qkv_format = {qkv_format}!" - if qkv_format == "sbhd": - max_seqlen_q = query_layer.shape[0] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[0] if max_seqlen_kv is None else max_seqlen_kv - batch_size = query_layer.shape[1] - else: - max_seqlen_q = query_layer.shape[1] if max_seqlen_q is None else max_seqlen_q - max_seqlen_kv = key_layer.shape[1] if max_seqlen_kv is None else max_seqlen_kv - batch_size = query_layer.shape[0] + if q_format in ["sbhd", "bshd"]: max_seqlen_q *= cp_size - max_seqlen_kv *= cp_size - if cu_seqlens_q is not None: - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - assert all( - seqlens_q <= max_seqlen_q - ), """Sequence lengths indicated by cu_seqlens_q must be no greater than - the sequence dimension in 'query_layer'!""" - if cu_seqlens_kv is not None: - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - assert all( - seqlens_kv <= max_seqlen_kv - ), """Sequence lengths indicated by cu_seqlens_kv must be no greater than - the sequence dimension in 'key_layer' and 'value_layer'!""" - if cu_seqlens_q is None or cu_seqlens_kv is None: + if cu_seqlens_q is None: if "padding" in attn_mask_type: assert ( attention_mask is not None ), "Please provide attention_mask for padding!" if self.attention_type == "self": - cu_seqlens_q = get_cu_seqlens(attention_mask) - cu_seqlens_kv = cu_seqlens_q + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask) else: - cu_seqlens_q = get_cu_seqlens(attention_mask[0]) - cu_seqlens_kv = get_cu_seqlens(attention_mask[1]) + cu_seqlens_q = dpa_utils.get_cu_seqlens(attention_mask[0]) else: - cu_seqlens_q = _get_full_cu_seqlens( + cu_seqlens_q = dpa_utils.get_full_cu_seqlens( batch_size, max_seqlen_q, query_layer.device, ) - cu_seqlens_kv = _get_full_cu_seqlens( + if kv_format in ["sbhd", "bshd"]: + max_seqlen_kv *= cp_size + if cu_seqlens_kv is None: + if "padding" in attn_mask_type: + assert ( + attention_mask is not None + ), "Please provide attention_mask for padding!" + if self.attention_type == "self": + cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask) + else: + cu_seqlens_kv = dpa_utils.get_cu_seqlens(attention_mask[1]) + else: + cu_seqlens_kv = dpa_utils.get_full_cu_seqlens( batch_size, max_seqlen_kv, key_layer.device, ) - if ( - isinstance(query_layer, Float8Tensor) - and isinstance(key_layer, Float8Tensor) - and isinstance(value_layer, Float8Tensor) - ): - qkv_layout, query_layer._data, key_layer._data, value_layer._data = get_qkv_layout( - query_layer._data, key_layer._data, value_layer._data, qkv_format=qkv_format - ) - else: - qkv_layout, query_layer, key_layer, value_layer = get_qkv_layout( - query_layer, key_layer, value_layer, qkv_format=qkv_format - ) - + # set ALiBi attributes global _alibi_cache if alibi_slopes is not None: assert ( @@ -8080,6 +6097,7 @@ def forward( _alibi_cache["_alibi_slopes_require_update"] = True _alibi_cache["_alibi_bias_require_update"] = True + # detect bias shape core_attention_bias_shape = None if core_attention_bias is not None: if ( @@ -8103,25 +6121,30 @@ def forward( False ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" - pad_between_seqs = ( - cu_seqlens_q_padded is not None - and not torch.equal(cu_seqlens_q_padded, cu_seqlens_q) - ) or ( - cu_seqlens_kv_padded is not None - and not torch.equal(cu_seqlens_kv_padded, cu_seqlens_kv) - ) + if pad_between_seqs is None: + if qkv_format == "thd": + pad_between_seqs = ( + cu_seqlens_q_padded is not None + and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1]) + ) or ( + cu_seqlens_kv_padded is not None + and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1]) + ) + else: + pad_between_seqs = False - attention_params = AttentionParams( + # gather attention params for get_attention_backend + attention_params = dpa_utils.AttentionParams( qkv_type=type(query_layer), qkv_dtype=query_layer.dtype, qkv_layout=qkv_layout, batch_size=batch_size, - num_heads=query_layer.shape[-2], - num_gqa_groups=key_layer.shape[-2], + num_heads=num_attention_heads, + num_gqa_groups=num_gqa_groups, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, - head_dim_qk=query_layer.shape[-1], - head_dim_v=value_layer.shape[-1], + head_dim_qk=head_dim_qk, + head_dim_v=head_dim_v, attn_mask_type=attn_mask_type, window_size=window_size, alibi_slopes_shape=alibi_slopes.shape if alibi_slopes is not None else None, @@ -8137,8 +6160,9 @@ def forward( is_training=self.training, fp8=self.fp8, fp8_meta=self.fp8_meta, + inference_params=inference_params, ) - global _attention_backends, _use_flash_attn_3 + global _attention_backends if ( _attention_backends["attention_params"] is None or attention_params != _attention_backends["attention_params"] @@ -8146,18 +6170,26 @@ def forward( _attention_backends["attention_params"] = attention_params _attention_backends["backend_selection_requires_update"] = True if _attention_backends["backend_selection_requires_update"]: - _use_flash_attn_3 = _flash_attn_3_is_installed ( use_flash_attention, + flash_attention_backend, use_fused_attention, fused_attention_backend, use_unfused_attention, _, - ) = get_attention_backend(attention_params) + ) = dpa_utils.get_attention_backend(attention_params) + # Set global _attention_backends var using return value + # from get_attention_backend() + _attention_backends["use_flash_attention"] = use_flash_attention + _attention_backends["flash_attention_backend"] = flash_attention_backend + _attention_backends["use_fused_attention"] = use_fused_attention + _attention_backends["fused_attention_backend"] = fused_attention_backend + _attention_backends["use_unfused_attention"] = use_unfused_attention + _attention_backends["backend_selection_requires_update"] = False if use_flash_attention: self.logger.info( "Running with FlashAttention backend (version %s)", - _flash_attn_version if not _use_flash_attn_3 else _flash_attn_3_version, + flash_attention_backend, ) elif use_fused_attention: self.logger.info( @@ -8168,13 +6200,20 @@ def forward( self.logger.info("Running with UnfusedDotProductAttention backend") else: use_flash_attention = _attention_backends["use_flash_attention"] + flash_attention_backend = _attention_backends["flash_attention_backend"] use_fused_attention = _attention_backends["use_fused_attention"] fused_attention_backend = _attention_backends["fused_attention_backend"] use_unfused_attention = _attention_backends["use_unfused_attention"] + # raise exception if no backend is available + if sum([use_flash_attention, use_fused_attention, use_unfused_attention]) == 0: + raise ValueError("No dot product attention support for the provided inputs!") + + # run attention if use_flash_attention: if core_attention_bias_type == "alibi": - alibi_slopes, _ = get_alibi( + alibi_slopes, _ = dpa_utils.get_alibi( + _alibi_cache, query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, @@ -8199,6 +6238,9 @@ def forward( max_seqlen_kv=max_seqlen_kv, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + inference_params=inference_params, + flash_attention_backend=flash_attention_backend, ) if use_fused_attention: @@ -8208,7 +6250,8 @@ def forward( alibi_slopes is not None or max_seqlen_q != max_seqlen_kv ): fu_core_attention_bias_type = "post_scale_bias" - _, fu_core_attention_bias = get_alibi( + _, fu_core_attention_bias = dpa_utils.get_alibi( + _alibi_cache, query_layer.shape[-2], max_seqlen_q, max_seqlen_kv, @@ -8216,6 +6259,7 @@ def forward( bias_dtype=query_layer.dtype, bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) + # checkpoint_core_attention=False if checkpoint_core_attention: return self._checkpointed_attention_forward( self.fused_attention, @@ -8242,6 +6286,9 @@ def forward( cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + pad_between_seqs=pad_between_seqs, + inference_params=inference_params, ) return self.fused_attention( query_layer, @@ -8267,6 +6314,9 @@ def forward( cp_comm_type=self.cp_comm_type, fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + pad_between_seqs=pad_between_seqs, + inference_params=inference_params, ) from .cpu_offload import CPUOffloadEnabled @@ -8278,12 +6328,6 @@ def forward( ) if use_unfused_attention: - if window_size is not None and ( - window_size[0] != -1 or window_size[1] not in [-1, 0] - ): - attn_mask_type, attention_mask = get_swa_mask( - window_size, max_seqlen_q, max_seqlen_kv, attn_mask_type, attention_mask - ) if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -8295,9 +6339,11 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + window_size=window_size, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + inference_params=inference_params, ) return self.unfused_attention( query_layer, @@ -8308,12 +6354,13 @@ def forward( cu_seqlens_kv=cu_seqlens_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, + window_size=window_size, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + inference_params=inference_params, ) - - raise ValueError("No dot product attention support for the provided inputs!") + return None class MultiheadAttention(torch.nn.Module): @@ -8482,11 +6529,11 @@ def __init__( fuse_qkv_params: bool = False, zero_centered_gamma: bool = False, qkv_weight_interleaved: bool = True, - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, - ub_overlap_rs_dgrad: bool = False, - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, bias: bool = True, normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", @@ -8496,8 +6543,8 @@ def __init__( self.qkv_format = qkv_format self.attn_mask_type = attn_mask_type - self.window_size = check_set_window_size(attn_mask_type, window_size) - self.layer_number = layer_number + self.window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) + self.layer_number = 1 if layer_number is None else layer_number self.input_layernorm = input_layernorm self.attention_type = attention_type self.get_rng_state_tracker = get_rng_state_tracker @@ -8666,19 +6713,6 @@ def __init__( **common_gemm_kwargs, ) - def _allocate_memory( - self, inference_max_sequence_len: int, batch_size: int, dtype: torch.dtype - ) -> torch.Tensor: - """Allocates memory for KV cache.""" - return torch.empty( - inference_max_sequence_len, - batch_size, - self.num_gqa_groups_per_partition, - self.hidden_size_per_attention_head, - dtype=dtype, - device=torch.cuda.current_device(), - ) - def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ Set the tensor parallel group for the given @@ -8767,6 +6801,7 @@ def forward( max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, + pad_between_seqs: Optional[bool] = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """ Forward propagation for MultiheadAttention layer. @@ -8845,6 +6880,9 @@ def forward( Calculated from `cu_seqlens_kv` if not provided. fast_zero_fill: bool, default = `True` Whether to set output tensors to 0 or not before use. + pad_between_seqs: Optional[bool], default = `None` + If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. + If true, there are padding tokens between individual sequences in a packed batch. """ # hidden_states: [sq, b, h] @@ -8852,7 +6890,7 @@ def forward( attn_mask_type = self.attn_mask_type if window_size is None: window_size = self.window_size - window_size = check_set_window_size(attn_mask_type, window_size) + window_size = dpa_utils.check_set_window_size(attn_mask_type, window_size) if "padding" in attn_mask_type and attention_mask is not None: for mask in attention_mask: @@ -8863,31 +6901,14 @@ def forward( ), f"core_attention_bias_type {core_attention_bias_type} is not supported!" # ================================================= - # Pre-allocate memory for key-values for inference + # Pre-allocate memory for key-value cache for inference # ================================================= - if inference_params and self.layer_number is not None: - assert ( - self.qkv_format != "thd" - ), "qkv_format == thd is not supported for an inference with KV-cache!" - if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_len = inference_params.max_sequence_length - inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_value_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, hidden_states.dtype - ) - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, - inference_value_memory, - ) - else: - ( - inference_key_memory, - inference_value_memory, - ) = inference_params.key_value_memory_dict[self.layer_number] + if ( + inference_params is not None + and self.layer_number not in inference_params.cache_manager.cache + ): + inference_params.allocate_memory(self.layer_number) # ====================== # Query, Key, and Value @@ -8948,16 +6969,9 @@ def forward( # not qkv_weight_interleaved: # [sq, b, (np/ng + 2), ng, hn] # --> [sq, b, np/ng, np, hn], [sq, b, 1, ng, hn], [sq, b, 1, ng, hn] - if not is_in_onnx_export_mode(): - query_layer, key_layer, value_layer = _SplitAlongDim.apply( - mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1) - ) - else: - query_layer, key_layer, value_layer = torch.split( - mixed_x_layer, - (num_queries_per_key_value, 1, 1), - dim=split_dim, - ) + query_layer, key_layer, value_layer = _SplitAlongDim.apply( + mixed_x_layer, split_dim, (num_queries_per_key_value, 1, 1) + ) if self.qkv_format == "thd": query_layer, key_layer, value_layer = ( @@ -8999,18 +7013,11 @@ def forward( mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) # mixed_kv_layer --> 2 [sk, b, ng, hn] - if not is_in_onnx_export_mode(): - key_layer, value_layer = _SplitAlongDim.apply( - mixed_kv_layer, - split_dim, - mixed_kv_layer.shape[split_dim] // 2, - ) - else: - key_layer, value_layer = torch.split( - mixed_kv_layer, - mixed_kv_layer.shape[split_dim] // 2, - dim=split_dim, - ) + key_layer, value_layer = _SplitAlongDim.apply( + mixed_kv_layer, + split_dim, + mixed_kv_layer.shape[split_dim] // 2, + ) key_layer, value_layer = ( x.reshape( x.size(0), @@ -9067,9 +7074,12 @@ def forward( elif self.qkv_format == "bshd": sequence_length = key_layer.size(1) else: - raise ValueError(f"QKV format {self.qkv_format} not supported for KV caching.") + raise ValueError( + f"qkv_format={self.qkv_format} not supported for KV caching and RoPE." + ) - sequence_start = inference_params.sequence_len_offset + sequence_start = inference_params.get_seqlens_pre_step() + # sequence_start = inference_params.seqlens[0] sequence_end = sequence_start + sequence_length q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] @@ -9116,15 +7126,16 @@ def forward( alibi_slopes=alibi_slopes, fast_zero_fill=fast_zero_fill, inference_params=inference_params, + pad_between_seqs=pad_between_seqs, ) # =================== # Output. [sq, b, h] # =================== - projection_output = self.proj( context_layer, is_first_microbatch=is_first_microbatch, + fp8_grad=isinstance(context_layer, QuantizedTensor), ) if self.return_bias: diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index bf5ca4d98e..3d807960ca 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -16,12 +16,24 @@ """ TE_DType = { torch.uint8: tex.DType.kByte, + torch.float8_e4m3fn: tex.DType.kFloat8E4M3, + torch.float8_e5m2: tex.DType.kFloat8E5M2, torch.int32: tex.DType.kInt32, torch.float32: tex.DType.kFloat32, torch.half: tex.DType.kFloat16, torch.bfloat16: tex.DType.kBFloat16, } +TE_DType_To_Torch = { + tex.DType.kByte: torch.uint8, + tex.DType.kFloat8E4M3: torch.float8_e4m3fn, + tex.DType.kFloat8E5M2: torch.float8_e5m2, + tex.DType.kInt32: torch.int32, + tex.DType.kFloat32: torch.float32, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, +} + AttnMaskTypes = ( "no_mask", "padding", @@ -52,6 +64,16 @@ "thd_t2hd", "thd_th2d", "thd_thd_thd", + "sbhd_bshd_bshd", + "bshd_sbhd_sbhd", + "thd_bshd_bshd", + "thd_sbhd_sbhd", + "paged_kv_bshd_bshd_bshd", + "paged_kv_bshd_sbhd_sbhd", + "paged_kv_sbhd_bshd_bshd", + "paged_kv_sbhd_sbhd_sbhd", + "paged_kv_thd_bshd_bshd", + "paged_kv_thd_sbhd_sbhd", ) LayerTypes = ("encoder", "decoder") @@ -59,3 +81,5 @@ GemmParallelModes = ("row", "column", None) dist_group_type = torch.distributed.ProcessGroup + +MXFP8_BLOCK_SCALING_SIZE = 32 diff --git a/transformer_engine/pytorch/cpp_extensions/__init__.py b/transformer_engine/pytorch/cpp_extensions/__init__.py index 9f3c1b2424..944d1849bf 100644 --- a/transformer_engine/pytorch/cpp_extensions/__init__.py +++ b/transformer_engine/pytorch/cpp_extensions/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -7,8 +7,3 @@ from .fused_attn import * from .gemm import * -from .transpose import * -from .activation import * -from .normalization import * -from .cast import * -from .padding import * diff --git a/transformer_engine/pytorch/cpp_extensions/_common.py b/transformer_engine/pytorch/cpp_extensions/_common.py deleted file mode 100644 index 8f7e72e268..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/_common.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Helper functions for C++ extensions""" -import functools -from typing import Dict, Optional, Tuple, Union - -import torch - -import transformer_engine_torch as tex - - -@functools.lru_cache(maxsize=None) -def empty_tensor() -> torch.Tensor: - """Get tensor with no entries and no data""" - return torch.Tensor() - - -def canonicalize_fp8_scales( - *, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - fp8_meta: Optional[tex.FP8TensorMeta] = None, - fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None, - allow_multiple_offsets: bool = True, -) -> Tuple[Dict[str, torch.Tensor], Dict[str, int]]: - """Canonicalize FP8 scaling factors (scale, amax, scale-inverse) - - If a scaling factor is not provided, try to access it within the - FP8 meta tensors. Returns dict with tensors and dict with tensor - offsets. - - """ - - # Default: use provided scales with no offsets - scale_offset = 0 - amax_offset = 0 - scale_inv_offset = 0 - - # Get scales from FP8 meta tensors if needed - if (fp8_meta is not None) and any(arg is None for arg in (scale, amax, scale_inv)): - if fp8_meta_index is None: - raise ValueError("Provided `fp8_meta` without corresponding `fp8_meta_index`") - fp8_meta_index = int(fp8_meta_index) - if scale is None: - scale = fp8_meta.scale - scale_offset = fp8_meta_index - if amax is None: - amax = fp8_meta.amax_history - amax_offset = fp8_meta_index - if scale_inv is None: - scale_inv = fp8_meta.scale_inv - scale_inv_offset = fp8_meta_index - - # Construct empty tensors if needed - if scale is None: - scale = empty_tensor() - scale_offset = 0 - if amax is None: - amax = empty_tensor() - amax_offset = 0 - if scale_inv is None: - scale_inv = empty_tensor() - scale_inv_offset = 0 - - # Force offsets to be the same if needed - if not allow_multiple_offsets and not scale_offset == amax_offset == scale_inv_offset: - if scale_offset != 0: - scale = scale[scale_offset:] - scale_offset = 0 - if amax_offset != 0: - amax = amax[:, amax_offset:] - amax_offset = 0 - if scale_inv_offset != 0: - scale_inv = scale_inv[scale_inv_offset:] - scale_inv_offset = 0 - - # Pack tensors and offsets into dicts - tensors = {"scale": scale, "amax": amax, "scale_inv": scale_inv} - offsets = { - "scale_offset": scale_offset, - "amax_offset": amax_offset, - "scale_inv_offset": scale_inv_offset, - } - return tensors, offsets diff --git a/transformer_engine/pytorch/cpp_extensions/activation.py b/transformer_engine/pytorch/cpp_extensions/activation.py deleted file mode 100644 index f204982aa0..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/activation.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Python interface for activation extensions""" -from typing import Optional, Union - -import torch - -import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales - -__all__ = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] - - -def gelu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """GeLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.gelu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def relu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ReLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.relu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def geglu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """GeGLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.geglu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def reglu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ReGLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.reglu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def swiglu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """SwiGLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.swiglu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def qgelu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """QuickGELU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.qgelu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - - -def srelu( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ReLU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - return torch.ops.tex_ts.srelu_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py deleted file mode 100644 index cd3c01c785..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Python interface for cast extensions""" -from typing import Optional, Union - -import torch - -import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales - -__all__ = ["cast_to_fp8", "cast_from_fp8"] - - -def cast_to_fp8( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Cast input to FP8""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch FP8 cast kernel - if inp.nelement() == 0: - if out is None: - out = torch.empty_like(inp, dtype=torch.uint8) - elif out is None: - out = torch.ops.tex_ts.cast_to_fp8_ts( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - else: - torch.ops.tex_ts.cast_to_fp8_noalloc_ts( - inp, - fp8_scales["scale"], - out, - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - ) - return out - - -def cast_from_fp8( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - itype: tex.DType, - otype: tex.DType, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """Cast input from FP8""" - - # Get scaling factors from FP8 meta tensors if needed - scale_inv_offset = 0 - if (fp8_meta_tensor is not None) and (scale_inv is None): - if fp8_tensor is None: - raise ValueError("Provided `fp8_meta_tensor` without corresponding `fp8_tensor`") - scale_inv = fp8_meta_tensor.scale_inv - scale_inv_offset = int(fp8_tensor) - - # Construct empty tensors if needed - if scale_inv is None: - raise ValueError("Did not provide either `scale_inv` or `fp8_meta_tensor`") - - # Launch FP8 cast kernel - return torch.ops.tex_ts.cast_from_fp8_ts( - inp, - scale_inv, - scale_inv_offset, - itype, - otype, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 1932e9feb2..b9810bf861 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -1,25 +1,23 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Python interface for fused attention extensions""" import math -from typing import Tuple, List, Union +from typing import Tuple, List, Union, Optional import torch import transformer_engine_torch as tex from transformer_engine_torch import ( NVTE_QKV_Layout, + NVTE_QKV_Format, NVTE_Bias_Type, NVTE_Mask_Type, NVTE_Fused_Attn_Backend, ) +from ..tensor.quantized_tensor import Quantizer __all__ = [ - "fused_attn_fwd_qkvpacked", - "fused_attn_bwd_qkvpacked", - "fused_attn_fwd_kvpacked", - "fused_attn_bwd_kvpacked", "fused_attn_fwd", "fused_attn_bwd", ] @@ -34,6 +32,16 @@ tex.DType.kInt32: torch.int32, } +QKVFormat = { + "bshd": NVTE_QKV_Format.NVTE_BSHD, + "sbhd": NVTE_QKV_Format.NVTE_SBHD, + "thd": NVTE_QKV_Format.NVTE_THD, + "sbhd_2bshd": NVTE_QKV_Format.NVTE_SBHD_2BSHD, + "bshd_2sbhd": NVTE_QKV_Format.NVTE_BSHD_2SBHD, + "thd_2bshd": NVTE_QKV_Format.NVTE_THD_2BSHD, + "thd_2sbhd": NVTE_QKV_Format.NVTE_THD_2SBHD, +} + QKVLayout = { "sb3hd": NVTE_QKV_Layout.NVTE_SB3HD, "sbh3d": NVTE_QKV_Layout.NVTE_SBH3D, @@ -50,6 +58,16 @@ "thd_t2hd": NVTE_QKV_Layout.NVTE_THD_T2HD, "thd_th2d": NVTE_QKV_Layout.NVTE_THD_TH2D, "thd_thd_thd": NVTE_QKV_Layout.NVTE_THD_THD_THD, + "sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_SBHD_BSHD_BSHD, + "bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_BSHD_SBHD_SBHD, + "thd_bshd_bshd": NVTE_QKV_Layout.NVTE_THD_BSHD_BSHD, + "thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_THD_SBHD_SBHD, + "paged_kv_bshd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_BSHD_BSHD, + "paged_kv_bshd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_BSHD_SBHD_SBHD, + "paged_kv_sbhd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_BSHD_BSHD, + "paged_kv_sbhd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_SBHD_SBHD_SBHD, + "paged_kv_thd_bshd_bshd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_BSHD_BSHD, + "paged_kv_thd_sbhd_sbhd": NVTE_QKV_Layout.NVTE_Paged_KV_THD_SBHD_SBHD, } AttnBiasType = { @@ -89,803 +107,6 @@ META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 -def fused_attn_fwd_qkvpacked( - is_training: bool, - max_seqlen: int, - cu_seqlens: torch.Tensor, - qkv: torch.Tensor, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - attn_bias: torch.Tensor = None, - cu_seqlens_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_qkv_offset: int = META_QKV, - d_scale_s: torch.Tensor = None, - d_scale_s_offset: int = META_S, - q_scale_s: torch.Tensor = None, - q_scale_s_offset: int = META_S, - q_scale_o: torch.Tensor = None, - q_scale_o_offset: int = META_O, - amax_s: torch.Tensor = None, - amax_s_offset: int = META_S, - amax_o: torch.Tensor = None, - amax_o_offset: int = META_O, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbh3d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - rng_gen: torch.Generator = None, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention FWD for packed QKV input. - - Parameters - ---------- - is_training: bool - if True, runs training and produces auxiliary tensors aux_ctx_tensors - for the backward; if False, runs inference and doesn't produce aux_ctx_tensors - max_seqlen: int - max sequence length for QKV, used for padding; may be larger than max(seqlens), - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - cu_seqlens: torch.Tensor - cumulative sequence lengths for QKV; shape [batch_size + 1] - qkv: torch.Tensor - input tensor QKV; shape 3hd or h3d (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of QKV; in tex.DType, not torch.dtype - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. - attn_bias: torch.Tensor, default = None - input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; - shape [1, num_heads, max_seqlen, max_seqlen], same data type as qkv - cu_seqlens_padded: torch.Tensor, default = None - cumulative sequence offsets for QKV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_qkv_offset: int, default = META_QKV - offset in d_scale_qkv for QKV - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_s_offset: int, default = META_S - offset in d_scale_s for S - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s_offset: int, default = META_S - offset in q_scale_s for S - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - q_scale_o_offset: int, default = META_O - offset in q_scale_o for O - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_s_offset: int, default = META_S - offset in amax_s for S - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations - amax_o_offset: int, default = META_O - offset in amax_o for O - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbh3d" - layout of QKV; {"sbh3d", "sb3hd", "bsh3d", "bs3hd", "th3d", "t3hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - rng_gen: torch.Generator, default = None - random number generator; - if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen - - Returns - ---------- - o: torch.Tensor - output tensor O, of the attention calculation; same data type as QKV; - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors used for the backward; - if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state] - if is_training is False, aux_ctx_tensors = None - - softmax-related tensors: - 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - softmax: torch.Tensor - Softmax(Q*K.T) - shape [batch_size, num_heads, max_seqlen, max_seqlen], dtype float32 - 2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - softmaxStats: torch.Tensor - log(sum(e^(x - max(x)))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - 3. if fused_attention_backend == FusedAttnBackend["FP8"] - M: torch.Tensor - max(Q*K.T) - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - ZInv: torch.Tensor - 1/sum(e^(x - max(x))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen, 1], dtype float32 - rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen - state of the random number generator; - [seed, offset], dtype uint64 - """ - - if attn_scale is None: - d = qkv.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - if attn_bias_type not in ["no_bias", "alibi"]: - assert ( - attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert attn_bias.dtype == qkv.dtype, "attn_bias tensor must be in the same dtype as qkv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - # FP8 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["FP8"]: - rng_elts_per_thread = ( - max_seqlen * max_seqlen + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." - else: - raise ValueError(f"Unsupported backend {fused_attention_backend}") - - # execute kernel - output_tensors = tex.fused_attn_fwd_qkvpacked( - max_seqlen, - is_training, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - cu_seqlens, - qkv, - qkv_dtype, - cu_seqlens_padded, - d_scale_qkv, - d_scale_qkv_offset, - d_scale_s, - d_scale_s_offset, - q_scale_s, - q_scale_s_offset, - q_scale_o, - q_scale_o_offset, - amax_s, - amax_s_offset, - amax_o, - amax_o_offset, - attn_bias, - rng_gen, - rng_elts_per_thread, - ) - - # out, aux_ctx_tensors - return output_tensors[0], output_tensors[1:] - - -def fused_attn_bwd_qkvpacked( - max_seqlen: int, - cu_seqlens: torch.Tensor, - qkv: torch.Tensor, - o: torch.Tensor, - d_o: torch.Tensor, - qkv_dtype: tex.DType, - dqkv_dtype: tex.DType, - aux_ctx_tensors: List[torch.Tensor], - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - cu_seqlens_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - d_scale_o: torch.Tensor = None, - d_scale_do: torch.Tensor = None, - d_scale_dp: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_dp: torch.Tensor = None, - q_scale_dqkv: torch.Tensor = None, - amax_dp: torch.Tensor = None, - amax_dqkv: torch.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbh3d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - deterministic: bool = False, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention BWD for packed QKV input. - - Parameters - ---------- - max_seqlen: int - max sequence length for QKV, used for padding; may be larger than max(seqlens) - seqlens = cu_seqlens[1:] - cu_seqlens[:-1] - cu_seqlens: torch.Tensor - cumulative sequence lengths for QKV; shape [batch_size + 1] - qkv: torch.Tensor - input tensor QKV; shape 3hd or h3d (see `qkv_layout` for details) - o: torch.Tensor - input tensor O (output of forward); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - d_o: torch.Tensor - input tensor dO (gradient of O); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of QKV; in tex.DType, not torch.dtype - dqkv_dtype: tex.DType - data type of dQKV; in tex.DType, not torch.dtype - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors of the forward pass when its is_training is True, - e.g. aux_ctx_tensors = [M, ZInv, rng_state] - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. - cu_seqlens_padded: torch.Tensor, default = None - cumulative sequence offsets for QKV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_o: torch.Tensor, default = None - input tensor for the dequantization of O in FP8 computations - d_scale_do: torch.Tensor, default = None - input tensor for the dequantization of dO in FP8 computations - d_scale_dp: torch.Tensor, default = None - input tensor for the dequantization of dP in FP8 computations - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations - q_scale_dp: torch.Tensor, default = None - input tensor for the quantization of dP in FP8 computations, P = Q * K.T - q_scale_dqkv: torch.Tensor, default = None - input tensor for the quantization of dQKV in FP8 computations - amax_dp: torch.Tensor, default = None - output tensor, amax of dP, used by the next iteration in FP8 computations - amax_dqkv: torch.Tensor, default = None - output tensor, amax of dQKV, used by the next iteration in FP8 computations - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbh3d" - layout of QKV; {"sbh3d", "sb3hd", "bsh3d", "bs3hd", "th3d", "t3hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - deterministic: bool, default = False - whether to execute the backward pass with deterministic behaviours. - - Returns - ---------- - d_qkv: torch.Tensor - gradient tensor of QKV; same data type and shape as QKV - d_bias: torch.Tensor, optional - gradient tensor of Bias when attn_bias_type is "pre_scale_bias" - or "post_scale_bias"; same data type and shape as Bias - """ - - if attn_scale is None: - d = qkv.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert ( - len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." - - if fused_attention_backend == FusedAttnBackend["FP8"]: - assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." - assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." - assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." - assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." - assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." - assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." - assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." - assert amax_dp is not None, "amax_dp is required for FP8 fused attention." - assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." - assert ( - len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." - - # execute kernel - output_tensors = tex.fused_attn_bwd_qkvpacked( - max_seqlen, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - deterministic, - cu_seqlens, - qkv, - o, - d_o, - qkv_dtype, - dqkv_dtype, - aux_ctx_tensors, - cu_seqlens_padded, - d_scale_qkv, - d_scale_s, - d_scale_o, - d_scale_do, - d_scale_dp, - q_scale_s, - q_scale_dp, - q_scale_dqkv, - amax_dp, - amax_dqkv, - ) - - return output_tensors - - -def fused_attn_fwd_kvpacked( - is_training: bool, - max_seqlen_q: int, - max_seqlen_kv: int, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, - q: torch.Tensor, - kv: torch.Tensor, - qkv_dtype: tex.DType, - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - attn_bias: torch.Tensor = None, - cu_seqlens_q_padded: torch.Tensor = None, - cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_qkv_offset: int = META_QKV, - d_scale_s: torch.Tensor = None, - d_scale_s_offset: int = META_S, - q_scale_s: torch.Tensor = None, - q_scale_s_offset: int = META_S, - q_scale_o: torch.Tensor = None, - q_scale_o_offset: int = META_O, - amax_s: torch.Tensor = None, - amax_s_offset: int = META_S, - amax_o: torch.Tensor = None, - amax_o_offset: int = META_O, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbhd_sbh2d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - rng_gen: torch.Generator = None, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention FWD for packed KV input. - - Parameters - ---------- - is_training: bool - if True, runs training and produces auxiliary tensors aux_ctx_tensors - for the backward; if False, runs inference and doesn't produce aux_ctx_tensors - max_seqlen_q: int - max sequence length for Q, used for padding; may be larger than max(seqlens_q), - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_kv: int - max sequence length for KV, used for padding; may be larger than max(seqlens_kv), - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - cu_seqlens_q: torch.Tensor - cumulative sequence lengths for Q; shape [batch_size + 1] - cu_seqlens_kv: torch.Tensor - cumulative sequence lengths for KV; shape [batch_size + 1] - q: torch.Tensor - input tensor Q; shape thd, sbhd or bshd (see `qkv_layout` for details) - kv: torch.Tensor - packed input tensor KV; shape 2hd or h2d (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of Q and KV; in tex.DType, not torch.dtype - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. - attn_bias: torch.Tensor, default = None - input tensor Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; - shape [1, num_heads, max_seqlen_q, max_seqlen_kv], same data type as q and kv - cu_seqlens_q_padded: torch.Tensor, default = None - cumulative sequence offsets for Q; shape [batch_size + 1] - cu_seqlens_kv_padded: torch.Tensor, default = None - cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_qkv_offset: int, default = META_QKV - offset in d_scale_qkv for QKV - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_s_offset: int, default = META_S - offset in d_scale_s for S - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s_offset: int, default = META_S - offset in q_scale_s for S - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - q_scale_o_offset: int, default = META_O - offset in q_scale_o for O - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_s_offset: int, default = META_S - offset in amax_s for S - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations - amax_o_offset: int, default = META_O - offset in amax_o for O - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbhd_sbh2d" - layout of QKV; - {"sbhd_sbh2d", "sbhd_sb2hd", "bshd_bsh2d", "bshd_bs2hd", "thd_th2d", "thd_t2hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - rng_gen: torch.Generator, default = None - random number generator; - if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen - - Returns - ---------- - o: torch.Tensor - output tensor O, of the attention calculation; same data type as QKV; - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors used for the backward; - if is_training is True, aux_ctx_tensors = [softmax-related tensors, rng_state] - if is_training is False, aux_ctx_tensors = None - - softmax-related tensors: - 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - softmax: torch.Tensor - Softmax(Q*K.T) - shape [batch_size, num_heads, max_seqlen_q, max_seqlen_kv], dtype float32 - 2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] - softmaxStats: torch.Tensor - log(sum(e^(x - max(x)))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - 3. if fused_attention_backend == FusedAttnBackend["FP8"] - M: torch.Tensor - max(Q*K.T) - shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - ZInv: torch.Tensor - 1/sum(e^(x - max(x))), where x=Q*K.T - shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen - state of the random number generator; - [seed, offset], dtype uint64 - """ - - if attn_scale is None: - d = q.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - if attn_bias_type not in ["no_bias", "alibi"]: - assert ( - attn_bias is not None - ), "attn_bias tensor cannot be None when attn_bias_type is not no_bias or alibi." - assert attn_bias.dtype == q.dtype, "attn_bias tensor must be in the same dtype as q and kv." - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: - rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS - # FP8 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["FP8"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - - assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." - else: - raise ValueError(f"Unsupported backend {fused_attention_backend}") - - # execute kernel - output_tensors = tex.fused_attn_fwd_kvpacked( - max_seqlen_q, - max_seqlen_kv, - is_training, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - qkv_dtype, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_qkv_offset, - d_scale_s, - d_scale_s_offset, - q_scale_s, - q_scale_s_offset, - q_scale_o, - q_scale_o_offset, - amax_s, - amax_s_offset, - amax_o, - amax_o_offset, - attn_bias, - rng_gen, - rng_elts_per_thread, - ) - - # out, aux_ctx_tensors - return output_tensors[0], output_tensors[1:] - - -def fused_attn_bwd_kvpacked( - max_seqlen_q: int, - max_seqlen_kv: int, - cu_seqlens_q: torch.Tensor, - cu_seqlens_kv: torch.Tensor, - q: torch.Tensor, - kv: torch.Tensor, - o: torch.Tensor, - d_o: torch.Tensor, - qkv_dtype: tex.DType, - dqkv_dtype: tex.DType, - aux_ctx_tensors: List[torch.Tensor], - fused_attention_backend: tex.NVTE_Fused_Attn_Backend, - cu_seqlens_q_padded: torch.Tensor = None, - cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - d_scale_o: torch.Tensor = None, - d_scale_do: torch.Tensor = None, - d_scale_dp: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_dp: torch.Tensor = None, - q_scale_dqkv: torch.Tensor = None, - amax_dp: torch.Tensor = None, - amax_dqkv: torch.Tensor = None, - attn_scale: float = None, - dropout: float = 0.0, - fast_zero_fill: bool = True, - qkv_layout: str = "sbhd_sbh2d", - attn_bias_type: str = "no_bias", - attn_mask_type: str = "padding", - window_size: Tuple[int, int] = (-1, -1), - deterministic: bool = False, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Fused Attention BWD for packed KV input. - - Parameters - ---------- - max_seqlen_q: int - max sequence length for Q, used for padding; may be larger than max(seqlens_q), - seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] - max_seqlen_kv: int - max sequence length for KV, used for padding; may be larger than max(seqlens_kv), - seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] - cu_seqlens_q: torch.Tensor - cumulative sequence lengths for Q; shape [batch_size + 1] - cu_seqlens_kv: torch.Tensor - cumulative sequence lengths for KV; shape [batch_size + 1] - q: torch.Tensor - input tensor Q; shape thd, sbhd or bshd (see `qkv_layout` for details) - kv: torch.Tensor - packed input tensor KV; shape h2d or 2hd (see `qkv_layout` for details) - o: torch.Tensor - input tensor O (output of forward); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - d_o: torch.Tensor - input tensor dO (gradient of O); - same shape as Q, i.e. thd, sbhd or bshd (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of Q and KV; in tex.DType, not torch.dtype - dqkv_dtype: tex.DType - data type of dQ and dKV; in tex.DType, not torch.dtype - aux_ctx_tensors: List[torch.Tensor] - auxiliary output tensors of the forward pass when its is_training is True, - e.g. aux_ctx_tensors = [M, ZInv, rng_state] - fused_attention_backend: tex.NVTE_Fused_Attn_Backend - please see FusedAttention module for details on supported backends. - cu_seqlens_q_padded: torch.Tensor, default = None - cumulative sequence offsets for Q; shape [batch_size + 1] - cu_seqlens_kv_padded: torch.Tensor, default = None - cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_o: torch.Tensor, default = None - input tensor for the dequantization of O in FP8 computations - d_scale_do: torch.Tensor, default = None - input tensor for the dequantization of dO in FP8 computations - d_scale_dp: torch.Tensor, default = None - input tensor for the dequantization of dP in FP8 computations - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations - q_scale_dp: torch.Tensor, default = None - input tensor for the quantization of dP in FP8 computations, P = Q * K.T - q_scale_dqkv: torch.Tensor, default = None - input tensor for the quantization of dQKV in FP8 computations - amax_dp: torch.Tensor, default = None - output tensor, amax of dP, used by the next iteration in FP8 computations, - P = Q * K.T - amax_dqkv: torch.Tensor, default = None - output tensor, amax of dQKV, used by the next iteration in FP8 computations - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default - dropout: float, default = 0.0 - dropout probability, 0.0 means no dropout, 1.0 means no output; - dropout must be 0.0 if is_training is False - fast_zero_fill: bool, default = True - if True, initializes the output tensor O to zero using the fast filling method; - if False, uses PyTorch's .fill_() method - qkv_layout: str, default = "sbhd_sbh2d" - layout of QKV; - {"sbhd_sbh2d", "sbhd_sb2hd", "bshd_bsh2d", "bshd_bs2hd", "thd_th2d", "thd_t2hd"} - attn_bias_type: str, default = "no_bias" - type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} - attn_mask_type: str, default = "padding" - type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} - window_size: Tuple[int, int], default = (-1, -1) - sliding window size for local attention, where query at position i attends to keys - in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q - + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding - window and causal mask specifically. - deterministic: bool, default = False - whether to execute the backward pass with deterministic behaviours. - - Returns - ---------- - d_q: torch.Tensor - gradient tensor of Q; same data type and shape as Q - d_kv: torch.Tensor - gradient tensor of KV; same data type and shape as KV - d_bias: torch.Tensor, optional - gradient tensor of Bias when attn_bias_type is "pre_scale_bias" - or "post_scale_bias"; same data type and shape as Bias - """ - - if attn_scale is None: - d = q.size(-1) - attn_scale = 1.0 / math.sqrt(d) - - assert ( - fused_attention_backend != FusedAttnBackend["No_Backend"] - ), "Fused attention does not support this input combination." - - if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - assert ( - len(aux_ctx_tensors) >= 1 - ), "aux_ctx_tensors must contain rng_state as its last element." - - if fused_attention_backend == FusedAttnBackend["FP8"]: - assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." - assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." - assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." - assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." - assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." - assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." - assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." - assert amax_dp is not None, "amax_dp is required for FP8 fused attention." - assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." - assert ( - len(aux_ctx_tensors) == 3 - ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." - - # execute kernel - output_tensors = tex.fused_attn_bwd_kvpacked( - max_seqlen_q, - max_seqlen_kv, - attn_scale, - dropout, - fast_zero_fill, - QKVLayout[qkv_layout], - AttnBiasType[attn_bias_type], - AttnMaskType[attn_mask_type], - window_size, - deterministic, - cu_seqlens_q, - cu_seqlens_kv, - q, - kv, - o, - d_o, - qkv_dtype, - dqkv_dtype, - aux_ctx_tensors, - cu_seqlens_q_padded, - cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_s, - d_scale_o, - d_scale_do, - d_scale_dp, - q_scale_s, - q_scale_dp, - q_scale_dqkv, - amax_dp, - amax_dqkv, - ) - - return output_tensors - - def fused_attn_fwd( is_training: bool, max_seqlen_q: int, @@ -895,23 +116,15 @@ def fused_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - qkv_dtype: tex.DType, + fake_dtype: torch.dtype, fused_attention_backend: tex.NVTE_Fused_Attn_Backend, attn_bias: torch.Tensor = None, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_qkv_offset: int = META_QKV, - d_scale_s: torch.Tensor = None, - d_scale_s_offset: int = META_S, - q_scale_s: torch.Tensor = None, - q_scale_s_offset: int = META_S, - q_scale_o: torch.Tensor = None, - q_scale_o_offset: int = META_O, - amax_s: torch.Tensor = None, - amax_s_offset: int = META_S, - amax_o: torch.Tensor = None, - amax_o_offset: int = META_O, + page_table_k: torch.Tensor = None, + page_table_v: torch.Tensor = None, + s_quantizer: Quantizer = None, + o_quantizer: Quantizer = None, attn_scale: float = None, dropout: float = 0.0, fast_zero_fill: bool = True, @@ -946,8 +159,9 @@ def fused_attn_fwd( input tensor K; shape sbhd, bshd or thd (see `qkv_layout` for details) v: torch.Tensor input tensor V; shape sbhd, bshd or thd (see `qkv_layout` for details) - qkv_dtype: tex.DType - data type of Q, K and V; in tex.DType, not torch.dtype + fake_dtype: tex.DType + data type of Q, K and V - in case of high precision, fake dtype in case of FP8; + in torch.dtype fused_attention_backend: tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. attn_bias: torch.Tensor, default = None @@ -957,30 +171,14 @@ def fused_attn_fwd( cumulative sequence offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of QKV in FP8 computations - d_scale_qkv_offset: int, default = META_QKV - offset in d_scale_qkv for QKV - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_s_offset: int, default = META_S - offset in d_scale_s for S - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations, S = Softmax(Q * K.T) - q_scale_s_offset: int, default = META_S - offset in q_scale_s for S - q_scale_o: torch.Tensor, default = None - input tensor for the quantization of O in FP8 computations - q_scale_o_offset: int, default = META_O - offset in q_scale_o for O - amax_s: torch.Tensor, default = None - output tensor, amax of S, used by the next iteration in FP8 computations - amax_s_offset: int, default = META_S - offset in amax_s for S - amax_o: torch.Tensor, default = None - output tensor, amax of O, used by the next iteration in FP8 computations - amax_o_offset: int, default = META_O - offset in amax_o for O + page_table_k: torch.Tensor, default = None + page table for K cache; shape [batch_size, max_pages_per_seq_k] + page_table_v: torch.Tensor, default = None + page table for V cache; shape [batch_size, max_pages_per_seq_v] + s_quantizer: Quantizer, default = None + Quantizer object for the intermediate value S. + o_quantizer: Quantizer, default = None + Quantizer object for the output of the attention. attn_scale: float, default = None if not None, use attn_scale as the attention scale for Q*K.T BMM; if None, use 1.0/sqrt(head_dim_qk) as the default @@ -1068,17 +266,16 @@ def fused_attn_fwd( ) // BACKEND_F16m512_FP8_THREADS_PER_CTA assert ( - d_scale_qkv is not None - ), "d_scale_qkv is required as an input for FP8 fused attention." - assert d_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required as an input for FP8 fused attention." - assert q_scale_o is not None, "q_scale_o is required as an input for FP8 fused attention." - assert amax_s is not None, "amax_s is required as an input for FP8 fused attention." - assert amax_o is not None, "amax_o is required as an input for FP8 fused attention." + s_quantizer is not None + ), "s_quantizer is required as an input for FP8 fused attention." + assert ( + o_quantizer is not None + ), "o_quantizer is required as an input for FP8 fused attention." else: raise ValueError(f"Unsupported backend {fused_attention_backend}") # execute kernel + output_tensors = tex.fused_attn_fwd( max_seqlen_q, max_seqlen_kv, @@ -1095,21 +292,13 @@ def fused_attn_fwd( q, k, v, - qkv_dtype, + fake_dtype, cu_seqlens_q_padded, cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_qkv_offset, - d_scale_s, - d_scale_s_offset, - q_scale_s, - q_scale_s_offset, - q_scale_o, - q_scale_o_offset, - amax_s, - amax_s_offset, - amax_o, - amax_o_offset, + page_table_k, + page_table_v, + s_quantizer, + o_quantizer, attn_bias, rng_gen, rng_elts_per_thread, @@ -1129,23 +318,16 @@ def fused_attn_bwd( v: torch.Tensor, o: torch.Tensor, d_o: torch.Tensor, - qkv_dtype: tex.DType, + fake_dtype: torch.dtype, dqkv_dtype: tex.DType, aux_ctx_tensors: List[torch.Tensor], fused_attention_backend: tex.NVTE_Fused_Attn_Backend, cu_seqlens_q_padded: torch.Tensor = None, cu_seqlens_kv_padded: torch.Tensor = None, - d_scale_qkv: torch.Tensor = None, - d_scale_s: torch.Tensor = None, - d_scale_o: torch.Tensor = None, - d_scale_do: torch.Tensor = None, - d_scale_dp: torch.Tensor = None, - q_scale_s: torch.Tensor = None, - q_scale_dp: torch.Tensor = None, - q_scale_dqkv: torch.Tensor = None, - amax_dp: torch.Tensor = None, - amax_dqkv: torch.Tensor = None, - attn_scale: float = None, + s_quantizer: Quantizer = None, + dp_quantizer: Quantizer = None, + dqkv_quantizer: Quantizer = None, + attn_scale: Optional[float] = None, dropout: float = 0.0, fast_zero_fill: bool = True, qkv_layout: str = "sbh3d", @@ -1181,8 +363,9 @@ def fused_attn_bwd( d_o: torch.Tensor input tensor dO (gradient of O); same data type as Q, K and V; same shape as Q - qkv_dtype: tex.DType - data type of Q, K and V; in tex.DType, not torch.dtype + fake_dtype: tex.DType + data type of Q, K and V - in case of high precision, fake dtype in case of FP8; + in torch.dtype dqkv_dtype: tex.DType data type of dQ, dK and dV; in tex.DType, not torch.dtype aux_ctx_tensors: List[torch.Tensor] @@ -1194,30 +377,12 @@ def fused_attn_bwd( cumulative sequence offsets for Q; shape [batch_size + 1] cu_seqlens_kv_padded: torch.Tensor, default = None cumulative sequence offsets for KV; shape [batch_size + 1] - d_scale_qkv: torch.Tensor, default = None - input tensor for the dequantization of Q, K and V in FP8 computations - d_scale_s: torch.Tensor, default = None - input tensor for the dequantization of S in FP8 computations, S = Softmax(Q * K.T) - d_scale_o: torch.Tensor, default = None - input tensor for the dequantization of O in FP8 computations - d_scale_do: torch.Tensor, default = None - input tensor for the dequantization of dO in FP8 computations - d_scale_dp: torch.Tensor, default = None - input tensor for the dequantization of dP in FP8 computations - q_scale_s: torch.Tensor, default = None - input tensor for the quantization of S in FP8 computations - q_scale_dp: torch.Tensor, default = None - input tensor for the quantization of dP in FP8 computations, P = Q * K.T - q_scale_dqkv: torch.Tensor, default = None - input tensor for the quantization of dQ, dK and dV in FP8 computations - amax_dp: torch.Tensor, default = None - output tensor, amax of dP, used by the next iteration in FP8 computations, - P = Q * K.T - amax_dqkv: torch.Tensor, default = None - output tensor, amax of dQ, dK and dV, used by the next iteration in FP8 computations - attn_scale: float, default = None - if not None, use attn_scale as the attention scale for Q*K.T BMM; - if None, use 1.0/sqrt(head_dim_qk) as the default + s_quantizer: Quantizer, default = None + Quantizer object for the intermediate value S. + dp_quantizer: Quantizer, default = None + Quantizer object for the intermediate value dP. + dqkv_quantizer: Quantizer, default = None + Quantizer object for the output values of the fused_attn_bwd. dropout: float, default = 0.0 dropout probability, 0.0 means no dropout, 1.0 means no output; dropout must be 0.0 if is_training is False @@ -1253,7 +418,6 @@ def fused_attn_bwd( gradient tensor of Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; same data type and shape as Bias """ - if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -1268,21 +432,19 @@ def fused_attn_bwd( ), "aux_ctx_tensors must contain rng_state as its last element." if fused_attention_backend == FusedAttnBackend["FP8"]: - assert d_scale_qkv is not None, "d_scale_qkv is required for FP8 fused attention." - assert d_scale_s is not None, "d_scale_s is required for FP8 fused attention." - assert d_scale_o is not None, "d_scale_o is required for FP8 fused attention." - assert d_scale_do is not None, "d_scale_do is required for FP8 fused attention." - assert d_scale_dp is not None, "d_scale_dp is required for FP8 fused attention." - assert q_scale_s is not None, "q_scale_s is required for FP8 fused attention." - assert q_scale_dp is not None, "q_scale_dp is required for FP8 fused attention." - assert q_scale_dqkv is not None, "q_scale_dqkv is required for FP8 fused attention." - assert amax_dp is not None, "amax_dp is required for FP8 fused attention." - assert amax_dqkv is not None, "amax_dqkv is required for FP8 fused attention." + assert ( + s_quantizer is not None + ), "s_quantizer is required as an input for FP8 fused attention backward." + assert ( + dp_quantizer is not None + ), "dp_quantizer is required as an input for FP8 fused attention backward." + assert ( + dqkv_dtype is not None + ), "dqkv_dtype is required as an input for FP8 fused attention backward." assert ( len(aux_ctx_tensors) == 3 ), "aux_ctx_tensors is required to be [M, ZInv, rng_state] for FP8 fused attention." - # execute kernel output_tensors = tex.fused_attn_bwd( max_seqlen_q, max_seqlen_kv, @@ -1301,21 +463,14 @@ def fused_attn_bwd( v, o, d_o, - qkv_dtype, + fake_dtype, dqkv_dtype, aux_ctx_tensors, cu_seqlens_q_padded, cu_seqlens_kv_padded, - d_scale_qkv, - d_scale_s, - d_scale_o, - d_scale_do, - d_scale_dp, - q_scale_s, - q_scale_dp, - q_scale_dqkv, - amax_dp, - amax_dqkv, + s_quantizer, + dp_quantizer, + dqkv_quantizer, ) return output_tensors diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 932bb3cafa..948a13a03e 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -1,502 +1,226 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Python interface for GEMM extensions""" import functools -from typing import Optional, Tuple, Union, List +from typing import Iterable, Optional, Tuple, Union, List +import os import torch import transformer_engine_torch as tex from ..constants import TE_DType -from ..utils import assert_dim_for_fp8_exec +from ..utils import assert_dim_for_fp8_exec, get_sm_count +from ..tensor.quantized_tensor import Quantizer +from ..tensor._internal.float8_tensor_base import Float8TensorBase +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase __all__ = [ - "gemm", - "fp8_gemm", - "grouped_gemm", - "fp8_grouped_gemm", + "general_gemm", + "general_grouped_gemm", ] @functools.lru_cache(maxsize=None) def _empty_tensor() -> torch.Tensor: """Get tensor with no entries and no data""" - return torch.Tensor() + return torch.Tensor().cuda() -def fp8_gemm( - A: torch.Tensor, - A_scale_inv: torch.Tensor, - A_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], - A_dtype: tex.DType, - B: torch.Tensor, - B_scale_inv: torch.Tensor, - B_fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], - B_dtype: tex.DType, - out_dtype: torch.dtype, - workspace: torch.Tensor, - gelu: bool = False, - accumulate: bool = False, - out: Optional[torch.Tensor] = None, - out_index=None, - fp8_meta_tensor: tex.FP8TensorMeta = None, - bias: Optional[torch.Tensor] = None, - use_bias: bool = False, - use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, - ub_algo: tex.CommOverlapAlgo = None, - ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, - extra_output_tensor: torch.Tensor = None, -) -> torch.Tensor: - """TN layout GEMM with fp8 inputs.""" +def swizzle_inputs(A: torch.Tensor, B: torch.Tensor, layout: str): + """Swizzle gemm inputs and return original scaling factor inverses.""" + if not isinstance(A, MXFP8TensorBase) or not isinstance(B, MXFP8TensorBase): + return None - empty_tensor = _empty_tensor() - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert fp8_meta_tensor is not None and out_index is not None - assert_dim_for_fp8_exec(A) - assert_dim_for_fp8_exec(B) - assert A.dtype == torch.uint8 - assert B.dtype == torch.uint8 - - if out is None: - out = torch.empty( - B.shape[0], - A.shape[0], - dtype=out_dtype, - device="cuda", - ) + original_scale_inverses = ( + A._rowwise_scale_inv, + A._columnwise_scale_inv, + B._rowwise_scale_inv, + B._columnwise_scale_inv, + ) + + if layout[0] == "T": + A._rowwise_scale_inv = tex.rowwise_swizzle(A._rowwise_data, A._rowwise_scale_inv) else: - if not out.is_contiguous(): - raise ValueError("Output tensor is not contiguous.") + A._columnwise_scale_inv = tex.columnwise_swizzle( + A._columnwise_data, A._columnwise_scale_inv + ) - # Use bfloat16 as default bias_dtype - bias_dtype = torch.bfloat16 if bias is None else bias.dtype - if gelu: - gelu_input = torch.empty_like(out, dtype=bias_dtype) + if layout[1] == "N": + B._rowwise_scale_inv = tex.rowwise_swizzle(B._rowwise_data, B._rowwise_scale_inv) else: - gelu_input = empty_tensor - bias_dtype = TE_DType[bias_dtype] + B._columnwise_scale_inv = tex.columnwise_swizzle( + B._columnwise_data, B._columnwise_scale_inv + ) - out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype + return original_scale_inverses - args = ( - A, - A_scale_inv, - A_fp8_tensor, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor, - B_dtype, - False, # transb - out, - empty_tensor if out_index is None else fp8_meta_tensor.scale[out_index], - out_dtype, - empty_tensor if out_index is None else fp8_meta_tensor.amax_history[0][out_index], - bias if use_bias else empty_tensor, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspace, - workspace.shape[0], - accumulate, - use_split_accumulator, - ) - fn = torch.ops.tex_ts.te_gemm_ts - if ub_algo is not None: - assert ub is not None, "ub object is None!" - if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: - fn = ub.bulk_overlap - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple( - args - + ( - tex.CommOverlapType.AG, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: - fn = ub.bulk_overlap - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple( - args - + ( - tex.CommOverlapType.RS, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: - fn = ub.split_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P: - fn = ub.atomic_gemm_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: - fn = ub.split_overlap_rs - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS requires extra output tensor" - args = tuple( - args - + ( - True, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: - fn = ub.split_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS: - fn = ub.atomic_gemm_overlap_rs - assert extra_output_tensor is not None, "ATOMIC_GEMM_RS requires extra output tensor" - args = tuple( - args - + ( - True, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P: - fn = ub.atomic_gemm_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "ATOMIC_GEMM_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - _ = fn(*args) - - return out, gelu_input - - -def gemm( + +def reset_swizzled_inputs(A, B, scale_inverses): + """Reset the swizzled scale inverses after GEMM.""" + if scale_inverses is not None: + ( + A._rowwise_scale_inv, + A._columnwise_scale_inv, + B._rowwise_scale_inv, + B._columnwise_scale_inv, + ) = scale_inverses + + +def general_gemm( A: torch.Tensor, B: torch.Tensor, - dtype: torch.dtype, workspace: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, + quantization_params: Optional[Quantizer] = None, gelu: bool = False, - gelu_input: Optional[torch.Tensor] = None, - grad: bool = False, + gelu_in: torch.Tensor = None, accumulate: bool = False, layout: str = "TN", out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, - use_bias: bool = False, - ub_algo: tex.CommOverlapAlgo = None, + use_split_accumulator: bool = False, + grad: bool = False, ub: Union[tex.CommOverlap, tex.CommOverlapP2P] = None, - extra_output_tensor: torch.Tensor = None, -) -> Tuple[Union[torch.Tensor, None], ...]: - """Non FP8 GEMM.""" + ub_type: tex.CommOverlapType = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, +) -> Iterable[Optional[torch.Tensor]]: + """GEMM supporting fp8 inputs.""" assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." transa = layout[0] == "T" transb = layout[1] == "T" - empty_tensor = _empty_tensor() - fp8_index = -1 # dummy index - - if out is None: - out = torch.empty( - B.shape[1] if transb else B.shape[0], - A.shape[0] if transa else A.shape[1], - dtype=dtype, - device="cuda", + # assert quantization_params is None, "FP8 output not supported yet" + + if ub_type is not None: + assert ub is not None, ( + f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires" + + "a valid `ub` communicator object." ) - else: - if not out.is_contiguous(): - raise ValueError("Output tensor is not contiguous.") - if gelu and not grad: - gelu_input = torch.empty_like(out, dtype=dtype) - elif not gelu: - gelu_input = empty_tensor + if ub is not None: + assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument." + if ub_type == tex.CommOverlapType.RS: + if not (bulk_overlap and not ub.is_fp8_ubuf()): + assert extra_output is not None, "GEMM+RS overlap requires extra output tensor." - if grad and use_bias: - grad_bias = torch.empty(B.shape[1], dtype=out.dtype, device="cuda") - else: - grad_bias = empty_tensor - - bias = bias if use_bias else empty_tensor + if out is not None: + if not out.is_contiguous(): + raise ValueError("Output tensor is not contiguous.") - assert ( - A.dtype == dtype and B.dtype == dtype - ), f"Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}" - input_dtype = TE_DType[dtype] - output_dtype = TE_DType[out.dtype] - if use_bias: - bias_dtype = TE_DType[grad_bias.dtype] if grad else TE_DType[bias.dtype] - else: - bias_dtype = output_dtype + # Use bfloat16 as default bias_dtype + bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] args = ( A, - empty_tensor, - fp8_index, - input_dtype, - transa, + transa, # transa B, - empty_tensor, - fp8_index, - input_dtype, - transb, + transb, # transb out, - empty_tensor, # out_scale - output_dtype, - empty_tensor, # out_amax - grad_bias if grad else bias, + quantization_params, + TE_DType[out_dtype] if out_dtype is not None else None, + bias, bias_dtype, - gelu_input, - grad, + gelu, + gelu_in, + grad, # grad workspace, workspace.shape[0], accumulate, - False, # use_split_accumulator + use_split_accumulator, ) - fn = torch.ops.tex_ts.te_gemm_ts - if ub_algo is not None: - assert ub is not None, "ub object is None!" - if ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_AG: - fn = ub.bulk_overlap - args = tuple(args + (tex.CommOverlapType.AG, empty_tensor)) - elif ub_algo == tex.CommOverlapAlgo.BULK_OVERLAP_RS: - fn = ub.bulk_overlap - args = tuple(args + (tex.CommOverlapType.RS, empty_tensor)) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P: - fn = ub.split_overlap_ag_p2p - extra_output_tensor = ( - empty_tensor if extra_output_tensor is None else extra_output_tensor - ) - args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS: - fn = ub.split_overlap_rs - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS requires extra output tensor" - args = tuple( - args - + ( - False, - extra_output_tensor, - ) - ) - elif ub_algo == tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P: - fn = ub.split_overlap_rs_p2p - assert ( - extra_output_tensor is not None - ), "SPLIT_PIPELINED_RS_P2P requires extra output tensor" - args = tuple(args + (extra_output_tensor,)) - _ = fn(*args) - - return out, grad_bias, gelu_input - - -def grouped_gemm( + kwargs = { + "comm_overlap": ub, + "comm_type": ub_type, + "extra_output": extra_output, + "bulk_overlap": bulk_overlap, + } + + original_scale_inverses = swizzle_inputs(A, B, layout) + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + reset_swizzled_inputs(A, B, original_scale_inverses) + + return out, bias_grad, gelu_input, extra_output + + +def general_grouped_gemm( A: List[torch.Tensor], B: List[torch.Tensor], out: List[torch.Tensor], - dtype: torch.dtype, + out_dtype: torch.dtype, workspaces: List[torch.Tensor], + layout: str = "TN", + m_splits: Optional[List[int]] = None, gelu: bool = False, - gelu_input: Optional[List[torch.Tensor]] = None, - grad: bool = False, + grad=False, accumulate: bool = False, - layout: str = "TN", bias: Optional[List[torch.Tensor]] = None, use_bias: bool = False, + use_split_accumulator: bool = False, + D_dtype: Optional[tex.DType] = None, + single_output=False, ) -> Tuple[List[torch.Tensor], ...]: - """Non FP8 Grouped GEMM.""" + """ + TN layout Grouped GEMM with fp8 inputs. + """ + num_gemms = len(A) - assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." transa = layout[0] == "T" transb = layout[1] == "T" - num_gemms = len(A) + + # assert [a.is_contiguous() for a in A] + # assert [b.is_contiguous() for b in B] + + if isinstance(A[0], Float8TensorBase): + for a, b in zip(A, B): + assert_dim_for_fp8_exec(a._data) + assert_dim_for_fp8_exec(b._data) + empty_tensor = _empty_tensor() empty_tensors = [empty_tensor] * num_gemms - if gelu and not grad: - gelu_input = [ - torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out - ] - elif not gelu: - gelu_input = empty_tensors + # Use bfloat16 as default bias_dtype + gelu_input = empty_tensors + out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype + sm_count = get_sm_count() if grad and use_bias: grad_bias = [ torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms) ] else: grad_bias = empty_tensors - bias = bias if use_bias else empty_tensors - - assert ( - A[0].dtype == dtype and B[0].dtype == dtype - ), f"Expected dtype={dtype}, but found A.dtype={A[0].dtype} and B.dtype={B[0].dtype}" - input_dtype = TE_DType[dtype] - output_dtype = TE_DType[out[0].dtype] if use_bias: bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype] else: - bias_dtype = output_dtype + bias_dtype = TE_DType[torch.bfloat16] - torch.ops.tex_ts.te_grouped_gemm_ts( + if gelu: + gelu_input = [ + torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) + for o in out + ] # this should differ with respect to single output + + bias = tex.te_general_grouped_gemm( A, - empty_tensor, - 0, # A_offset - input_dtype, transa, B, - empty_tensor, - 0, # B_offset - input_dtype, transb, out, - 0, # out_offset - empty_tensor, # out_scale - output_dtype, - empty_tensor, # out_amax + out_dtype, + m_splits, grad_bias if grad else bias, bias_dtype, - gelu_input, # gelu_input - grad, + single_output, + gelu_input, # this is pre_gelu_out + grad, # grad workspaces, workspaces[0].shape[0], accumulate, - False, # use_split_accumulator + use_split_accumulator, + sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))), ) - return out, grad_bias, gelu_input - - -def fp8_grouped_gemm( - A: List[torch.Tensor], - A_scale_inv: List[torch.Tensor], - A_fp8_tensor_offset: int, - A_dtype: tex.DType, - B: List[torch.Tensor], - B_scale_inv: torch.Tensor, - B_fp8_tensor_offset: int, - B_dtype: tex.DType, - out: List[torch.Tensor], - out_dtype: torch.dtype, - workspaces: List[torch.Tensor], - m_splits: Optional[List[int]] = None, - out_offset: Optional[int] = None, - fp8_meta_tensor: tex.FP8TensorMeta = None, - gelu: bool = False, - accumulate: bool = False, - bias: Optional[List[torch.Tensor]] = None, - use_bias: bool = False, - use_split_accumulator: bool = False, - D_dtype: Optional[tex.DType] = None, -) -> Tuple[List[torch.Tensor], ...]: - """ - TN layout Grouped GEMM with fp8 inputs. - Input requirements: - 1. If len(A_scale_inv) == num_gemms, len(out) must be 1, and m_splits is not None. - This is used for the calculation of output (fwd) and dgrad (bwd). - 2. if len(A_scale_inv) == 1, len(out) must be num_gemms. This is used for the - calculation of wgrad. - """ - num_gemms = len(A) - if num_gemms > 1 and len(A_scale_inv) == num_gemms: - assert len(out) == 1 and m_splits is not None - elif num_gemms > 1 and len(A_scale_inv) == 1: - assert len(out) == num_gemms - elif num_gemms == 1: - assert len(A_scale_inv) == 1 and len(out) == 1 - else: - raise ValueError("Invalid input combinations of A_scale_inv and out.") - - empty_tensor = _empty_tensor() - empty_tensors = [empty_tensor] * num_gemms - if D_dtype is not None and D_dtype in [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]: - assert fp8_meta_tensor is not None and out_offset is not None - for a, b in zip(A, B): - assert_dim_for_fp8_exec(a) - assert_dim_for_fp8_exec(b) - assert A[0].dtype == torch.uint8 - assert B[0].dtype == torch.uint8 - - # Use bfloat16 as default bias_dtype - bias_dtype = torch.bfloat16 if bias is None else bias[0].dtype - bias_dtype = TE_DType[bias_dtype] - gelu_input = empty_tensors - out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype - - if len(A_scale_inv) == 1: - if gelu: - gelu_input = [ - torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format) - for o in out - ] - - torch.ops.tex_ts.te_grouped_gemm_ts( - A, - A_scale_inv[0], - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - out, - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) - else: - if gelu: - gelu_input = [torch.empty((m, A[0].size(0)), dtype=bias_dtype) for m in m_splits] - - torch.ops.tex_ts.te_grouped_gemm_single_output_ts( - A, - A_scale_inv, - A_fp8_tensor_offset, - A_dtype, - True, # transa - B, - B_scale_inv, - B_fp8_tensor_offset, - B_dtype, - False, # transb - m_splits, - out[0], - 0 if out_offset is None else out_offset, - empty_tensor if out_offset is None else fp8_meta_tensor.scale, - out_dtype, - empty_tensor if out_offset is None else fp8_meta_tensor.amax_history, - bias if use_bias else empty_tensors, - bias_dtype, - gelu_input, # this is pre_gelu_out - False, # grad - workspaces, - workspaces[0].shape[0], - accumulate, - use_split_accumulator, - ) - - return out, gelu_input + return out, bias, gelu_input diff --git a/transformer_engine/pytorch/cpp_extensions/normalization.py b/transformer_engine/pytorch/cpp_extensions/normalization.py deleted file mode 100644 index 50fd6b7709..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/normalization.py +++ /dev/null @@ -1,260 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Python interface for normalization extensions""" -from typing import Optional, Tuple, Union - -import torch - -import transformer_engine_torch as tex -from ._common import canonicalize_fp8_scales - - -__all__ = [ - "layernorm_fwd_fp8", - "layernorm_fwd_fp8_inf", - "layernorm_fwd_inf", - "rmsnorm_fwd_fp8", - "rmsnorm_fwd_fp8_inf", - "rmsnorm_fwd_inf", -] - - -def layernorm_fwd_fp8( - inp: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - ln_out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """LayerNorm with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - if ln_out is not None: - return tex.layernorm_fwd_fp8_noalloc( - inp, - weight, - bias, - eps, - fp8_scales["scale"], - ln_out, - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - return tex.layernorm_fwd_fp8( - inp, - weight, - bias, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - - -def layernorm_fwd_fp8_inf( - inp: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """LayerNorm with FP8 output. - - This version of layernorm_fwd_fp8 is specialized for inference, and returns - only the normalized output. - """ - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - ret = torch.ops.tex_ts.layernorm_fwd_fp8_inf_ts( - inp, - weight, - bias, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - sm_margin, - zero_centered_gamma, - ) - return ret - - -def layernorm_fwd_inf( - inp: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - eps: float, - sm_margin: int, - zero_centered_gamma: bool, -) -> torch.Tensor: - """LayerNorm with FP8 output""" - return torch.ops.tex_ts.layernorm_fwd_inf_ts( - inp, - weight, - bias, - eps, - sm_margin, - zero_centered_gamma, - ) - - -def rmsnorm_fwd_fp8( - inp: torch.Tensor, - weight: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma: bool, - rmsnorm_out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """RMSNorm with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - if rmsnorm_out is not None: - return tex.rmsnorm_fwd_fp8_noalloc( - inp, - weight, - eps, - fp8_scales["scale"], - rmsnorm_out, - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - return tex.rmsnorm_fwd_fp8( - inp, - weight, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - sm_margin, - zero_centered_gamma, - **fp8_scales_offsets, - ) - - -def rmsnorm_fwd_fp8_inf( - inp: torch.Tensor, - weight: torch.Tensor, - eps: float, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - sm_margin: int, - zero_centered_gamma, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """RMSNorm with FP8 output. - - This version of rmsnorm_fwd_fp8 is specialized for inference, and returns - only the normalized output. - """ - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - allow_multiple_offsets=False, - ) - - # Launch kernel - ret = torch.ops.tex_ts.rmsnorm_fwd_fp8_inf_ts( - inp, - weight, - eps, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - fp8_scales_offsets["scale_offset"], - otype, - sm_margin, - zero_centered_gamma, - ) - return ret - - -def rmsnorm_fwd_inf( - inp: torch.Tensor, - weight: torch.Tensor, - eps: float, - sm_margin: int, - zero_centered_gamma: bool, -) -> torch.Tensor: - """RMSNorm with FP8 output""" - return torch.ops.tex_ts.rmsnorm_fwd_inf_ts( - inp, - weight, - eps, - sm_margin, - zero_centered_gamma, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/padding.py b/transformer_engine/pytorch/cpp_extensions/padding.py deleted file mode 100644 index 41dfbe2466..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/padding.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Python interface for transpose extensions""" -from typing import List, Tuple, Union -import torch -import transformer_engine_torch as tex - - -__all__ = [ - "multi_padding_fused", -] - - -def multi_padding_fused( - inp: torch.Tensor, - row_list: List[int], - padded_row_list: List[int], - out: torch.Tensor, -) -> Union[Tuple[List[torch.Tensor], List[torch.Tensor]], None]: - """Padding""" - - tex.fused_multi_row_padding( - inp, - out, - row_list, - padded_row_list, - ) diff --git a/transformer_engine/pytorch/cpp_extensions/transpose.py b/transformer_engine/pytorch/cpp_extensions/transpose.py deleted file mode 100644 index ddc3b67e9e..0000000000 --- a/transformer_engine/pytorch/cpp_extensions/transpose.py +++ /dev/null @@ -1,191 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Python interface for transpose extensions""" -from typing import List, Optional, Tuple, Union - -import torch - -import transformer_engine_torch as tex -from ..constants import TE_DType -from ._common import canonicalize_fp8_scales, empty_tensor - - -__all__ = [ - "fp8_cast_transpose_fused", - "fp8_cast_transpose_bgrad_fused", - "fp8_cast_transpose_bgrad_dgelu_fused", - "fp8_multi_cast_transpose_fused", - "fp8_transpose_bgrad_fused", -] - - -def fp8_cast_transpose_fused( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - cast_out: Optional[torch.Tensor] = None, - transpose_out: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - noop_flag: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Cast + Transpose with FP8 output""" - - # Allocate outputs if needed - if transpose_out is None: - transpose_out = torch.empty(inp.shape[1], inp.shape[0], device="cuda", dtype=torch.uint8) - if cast_out is None: - cast_out = torch.empty_like(inp, dtype=torch.uint8) - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Construct no-op flag if needed - if noop_flag is None: - noop_flag = empty_tensor() - - # Launch kernel if needed - if inp.nelement() > 0: - tex.fused_cast_transpose_noop( - inp, - noop_flag, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - cast_out, - transpose_out, - otype, - **fp8_scales_offsets, - ) - - return cast_out, transpose_out - - -def fp8_cast_transpose_bgrad_fused( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Cast + Transpose + BGRAD with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - return tex.fused_cast_transpose_bgrad( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - **fp8_scales_offsets, - ) - - -def fp8_transpose_bgrad_fused( - inp: torch.Tensor, - fp8_meta_tensor: Optional[tex.FP8TensorMeta], - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None], - otype: tex.DType, - grad_bias_type: torch.dtype, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Transpose + BGRAD with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - return tex.fused_fp8_transpose_bgrad( - inp, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - TE_DType[grad_bias_type], - **fp8_scales_offsets, - ) - - -def fp8_cast_transpose_bgrad_dgelu_fused( - grad_output: torch.Tensor, - gelu_input: torch.Tensor, - fp8_meta_tensor: tex.FP8TensorMeta, - fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors], - otype: tex.DType, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Cast + Transpose + BGRAD + DGELU with FP8 output""" - - # Get FP8 scaling factors - fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales( - scale=scale, - amax=amax, - scale_inv=scale_inv, - fp8_meta=fp8_meta_tensor, - fp8_meta_index=fp8_tensor, - ) - - # Launch kernel - return tex.fused_cast_transpose_bgrad_dgelu( - grad_output, - gelu_input, - fp8_scales["scale"], - fp8_scales["amax"], - fp8_scales["scale_inv"], - otype, - **fp8_scales_offsets, - ) - - -def fp8_multi_cast_transpose_fused( - input_list: List[torch.Tensor], - fp8_meta_tensor: tex.FP8TensorMeta, - scale_indices: List[int], - amax_indices: List[int], - scale_inv_indices: List[int], - otype: tex.DType, - scale_inv: Optional[torch.Tensor] = None, -) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: - """Cast + Transpose with FP8 output""" - - return tex.fused_multi_cast_transpose_alloc( - input_list, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - scale_inv if scale_inv is not None else fp8_meta_tensor.scale_inv, - scale_indices, - amax_indices, - scale_inv_indices, - otype, - ) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 123758b0da..93df512ac6 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -9,13 +9,27 @@ import torch -from .float8_tensor import Float8Tensor +from .tensor.float8_tensor import Float8Tensor __all__ = ["get_cpu_offload_context"] CPUOffloadEnabled = False +def set_offloading_param(tensor, param_name, value): + """Set the type of the offloading needed for a tensor.""" + assert param_name in ["weight_offloading", "activation_offloading"] + if tensor is None: + return + if type(tensor) in [torch.Tensor, torch.nn.Parameter]: + setattr(tensor, param_name, value) + else: + data_tensors = tensor.get_data_tensors() + for tensor in data_tensors: + if tensor is not None: + setattr(tensor, param_name, value) + + def is_cpu_offload_enabled() -> bool: """Check if CPU offloading is currently enabled.""" return CPUOffloadEnabled @@ -219,19 +233,15 @@ def on_group_commit_backward(self): @staticmethod def offload(src_tensor, pin_memory=True): """Offload.""" - fp8_offload = isinstance(src_tensor, Float8Tensor) cpu_backup = torch.empty( src_tensor.size(), - dtype=torch.uint8 if fp8_offload else src_tensor.dtype, + dtype=src_tensor.dtype, layout=src_tensor.layout, device="cpu", pin_memory=pin_memory, ) - if fp8_offload: - cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) - cpu_backup.copy_(src_tensor, non_blocking=pin_memory) state = (src_tensor.device, cpu_backup) return state @@ -258,6 +268,7 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs): else: # will be offloaded together after group commit self.tensor_tag_to_state[tensor_tag] = tensor + return tensor_tag def tensor_pop(self, tensor_tag, **kwargs): @@ -294,6 +305,9 @@ def __init__( self.num_layers = num_model_group # Data Structure to maintain reference to activation tensors self.tensor_tag_to_buf = {} + # Data structure to hold the FP8/MXFP8 tensor objects + self.fp8_tensor_object_map = {} + self.float8_transpose_cache_valid = {} # Tracking the number of layers offloaded self.offloaded_group_count = 0 # Core data structure that decides the window for offloading @@ -324,18 +338,46 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: ), ) + is_quantized_tensor = callable(getattr(tensor, "prepare_for_saving", None)) + if not torch_stray_tensor: + # obtain a unique tensor tag tensor_tag = (self.current_group, self.tensor_count_current_group) self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state - self.tensor_tag_to_state[tensor_tag] = tensor + if is_quantized_tensor: + tensor_list, _ = tensor.prepare_for_saving() - if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( - tensor - ): - self.tensor_tag_to_buf[tensor_tag] = tensor + self.tensor_tag_to_state[tensor_tag] = [] + self.tensor_tag_to_buf[tensor_tag] = [] + + self.fp8_tensor_object_map[tensor_tag] = tensor + if isinstance(tensor, Float8Tensor): + self.float8_transpose_cache_valid[tensor_tag] = getattr( + tensor, "_transpose_invalid" + ) + else: + tensor_list = [tensor] + + for t in tensor_list: + if is_quantized_tensor: + self.tensor_tag_to_state[tensor_tag].append(t) + else: + self.tensor_tag_to_state[tensor_tag] = t + + if ( + self.current_group < self.num_offload_group + and self.tensor_need_offloading_checker(t) + ): + if is_quantized_tensor: + self.tensor_tag_to_buf[tensor_tag].append(t) + # Need to clear the internal data reference for the quantized tensors + tensor.clear() + else: + self.tensor_tag_to_buf[tensor_tag] = t else: tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 @@ -347,7 +389,14 @@ def tensor_pop(self, tensor_tag, **kwargs): """Tensor pop.""" assert tensor_tag in self.tensor_tag_to_state tensor = self.tensor_tag_to_state.pop(tensor_tag) + + # Handling the quantized tensor case specially here + if isinstance(tensor, list): + self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) + tensor = self.fp8_tensor_object_map.pop(tensor_tag) + self.tensor_tag_to_buf.pop(tensor_tag, None) + # the tensor should have been copied back in on_group_commit_backward() # which invokes bulk_reload_group. assert not isinstance(tensor, tuple) @@ -360,12 +409,37 @@ def bulk_offload_group(self, group_to_offload): group_id, _ = tensor_tag if group_id == group_to_offload: assert not isinstance(state, tuple) - tensor_on_device = state - # if offload, return the reference to cpu copy - if self.tensor_need_offloading_checker(tensor_on_device): - state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) - self.tensor_tag_to_state[tensor_tag] = state + is_quantized_tensor = isinstance(state, list) + + if is_quantized_tensor: + tensor_list = state + self.tensor_tag_to_state[tensor_tag] = [] + else: + tensor_list = [state] + + for tensor_on_device in tensor_list: + # `tensor_offloaded` is a hacky way of dealing with columnwise-only + # quantized tensors for CPU offloading. The complication is due to + # the `rowwise_data` being `None`. The offloading checker incorrectly + # returns `False` and the entire `state` ([None, columnwise_tensor]) + # is added to the tensor tag state dict. A better design would change + # how quantized tensors are kept track of in the offload handler. + # Currently at every stage it is ensured that a quantized tensor is a + # list whereas a non-quantized tensor is standalone object, which is + # not good! TODO(@sanandaraj5597) + tensor_offloaded = False + # if offload, return the reference to cpu copy + if self.tensor_need_offloading_checker(tensor_on_device): + tensor_offloaded = True + state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) + if is_quantized_tensor: + if tensor_offloaded: + self.tensor_tag_to_state[tensor_tag].append(state) + else: + self.tensor_tag_to_state[tensor_tag].append(tensor_on_device) + else: + self.tensor_tag_to_state[tensor_tag] = state def synchronize_on_group_commit_forward(self, current_group): """Synchronize on group commit forward.""" @@ -415,6 +489,23 @@ def bulk_reload_group(self, group_to_reload): if isinstance(state, tuple): recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) self.tensor_tag_to_state[tensor_label] = recovered_tensor + elif isinstance(state, list): + tensor_list = [] + for state_tuple in state: + if isinstance(state_tuple, tuple): + tensor_list.append( + SynchronizedGroupOffloadHandler.reload(state_tuple) + ) + else: + tensor_list.append(state_tuple) + _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(tensor_list) + if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): + self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( + self.float8_transpose_cache_valid.pop(tensor_label) + ) + self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( + tensor_label + ) def on_group_commit_backward(self): # first decrement the current group. diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py new file mode 100644 index 0000000000..e5da32164d --- /dev/null +++ b/transformer_engine/pytorch/cross_entropy.py @@ -0,0 +1,72 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Cross Entropy Loss API""" + +import torch + +import transformer_engine.pytorch.triton.cross_entropy as triton_cross_entropy + +__all__ = [ + "parallel_cross_entropy", +] + + +class CrossEntropyFunction(torch.autograd.Function): + """ + This class implements a custom autograd function for the Cross Entropy loss. The input tensor can be in BF16/FP32, the + loss and gradient calculation happens in FP32 only. The returned loss is always in FP32, the input gradients are upcasted + to the dataype of the input. + """ + + @staticmethod + def forward( + ctx, _input, target, label_smoothing=0.0, reduce_loss=False, dist_process_group=None + ): + """ + The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each + distributed rank should be (*,V/world_size). Note that each of the ranks should get equal shards along the V dimension. + + Parameters: + ctx : The context object. + _input (tensor): The input tensor of shape (B, SQ, V) or (SQ, B, V) where B is batch size, SQ is sequence length, V is vocab size. + target (tensor): The target tensor of shape (B,SQ) or (SQ, B) where each value is in [0, V-1]. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + reduce_loss (bool): If true, returns the averaged loss across the B*SQ dimension. + dist_process_group (torch.dist.ProcessGroup): The distributed process group the loss computation is split across, None if on 1 device. + + Returns: + tensor: The computed loss. + """ + loss, _input = triton_cross_entropy.cross_entropy_forward( + _input, target, label_smoothing, reduce_loss, dist_process_group + ) + + ctx.save_for_backward(_input.detach()) + return loss + + @staticmethod + def backward(ctx, grad_output): + """ + The backward pass of the Cross Entropy loss. + + Parameters: + ctx : The context object with saved tensors. + grad_output (tensor): The tensor containing the gradient of the loss with respect to the output. + + Returns: + tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. + """ + (_input,) = ctx.saved_tensors + _input = triton_cross_entropy.cross_entropy_backward(_input, grad_output) + return ( + _input, + None, + None, + None, + None, + ) + + +parallel_cross_entropy = CrossEntropyFunction.apply diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 2ac190863c..23137a1003 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -1,12 +1,38 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include "common.h" +#include "c10/util/ArrayRef.h" +#include "pybind.h" #include "transformer_engine/transformer_engine.h" +namespace transformer_engine::pytorch { + +std::vector getTensorShape(at::Tensor t) { + std::vector shape; + for (auto s : t.sizes()) { + shape.push_back(s); + } + return shape; +} + +std::unique_ptr convert_quantizer(py::handle quantizer) { + init_extension(); + if (quantizer.is_none()) { + return std::make_unique(quantizer); + } + for (auto [_check_type, check_quantizer_type, _create_tensor, create_quantizer] : + detail::custom_types_converters) { + if (check_quantizer_type(quantizer.ptr())) { + return create_quantizer(quantizer); + } + } + + NVTE_ERROR("Unexpected type for quantizer"); +} transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe) { @@ -17,6 +43,41 @@ transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, return transformer_engine::DType::kFloat8E5M2; } +TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer) { + NVTE_CHECK(!tensor.is_none(), "Tensor is not allocated!"); + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + // check for both quantizer & tensor type: + // mxfp8 tensor -> mxfp8 quantizer + // float8 tensor -> delayed scaling quantizer OR current scaling quantizer + // also during dequantize, the quantizer param is unknown -> so quantizer is NoneQuantizer + for (auto [check_type, check_quantizer_type, create_tensor, _] : + detail::custom_types_converters) { + if (check_type(tensor.ptr())) { + if (!(quantizer.is_none() || check_quantizer_type(quantizer.ptr()))) { + continue; + } + auto x = create_tensor(tensor, my_quantizer.get()); + return x; + } + } + NVTE_CHECK(dynamic_cast(my_quantizer.get()) != nullptr, + "Unexpected quantization params type."); + + // Regular pyTorch tensor + at::Tensor torch_tensor = tensor.cast(); + + // #TODO (pgadzinski) - needed in attention for non-contiguous tensors. + //if (!torch_tensor.is_contiguous()) { + // torch_tensor = torch_tensor.contiguous(); + //} + auto ret = TensorWrapper(my_quantizer->get_scaling_mode()); + ret.set_rowwise_data(torch_tensor.data_ptr(), + GetTransformerEngineDType(torch_tensor.scalar_type()), + getTensorShape(torch_tensor)); + my_quantizer->set_quantization_params(&ret); + return ret; +} + transformer_engine::TensorWrapper makeTransformerEngineTensor( void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type) { return transformer_engine::TensorWrapper(data_ptr, shape, type); @@ -30,48 +91,95 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); std::vector shape; - for (auto s : tensor.sizes()) { shape.push_back(s); } return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype); } -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, - void* scale_inv_ptr) { - return transformer_engine::TensorWrapper( - data_ptr, shape, type, reinterpret_cast(amax_ptr), - reinterpret_cast(scale_ptr), reinterpret_cast(scale_inv_ptr)); +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape, + NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + const std::vector meta_shape{1}; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + return ret; +} + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, + const std::vector& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const std::vector& scale_inv_shape, + const std::vector& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) { + TensorWrapper ret(scaling_mode); + ret.set_rowwise_data(data_ptr, type, shape); + ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape); + const std::vector meta_shape{1}; + ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); + ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); + auto scale_inv_dtype = + (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); + ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, + columnwise_scale_inv_shape); + return ret; } transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, const at::Tensor scale, - at::Tensor scale_inv) { + at::Tensor scale_inv, + NVTEScalingMode scaling_mode) { transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type()); - std::vector shape; - for (auto s : tensor.sizes()) { - shape.push_back(s); - } + auto tensor_shape = getTensorShape(tensor); + auto scale_inv_shape = getTensorShape(scale_inv); + NVTE_CHECK(amax.scalar_type() == at::kFloat); NVTE_CHECK(scale.scalar_type() == at::kFloat); NVTE_CHECK(scale_inv.scalar_type() == at::kFloat); - return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); + return makeTransformerEngineTensor(tensor.data_ptr(), tensor_shape, dtype, amax.data_ptr(), + scale.data_ptr(), scale_inv.data_ptr(), scale_inv_shape, + scaling_mode); } -size_t product(const std::vector& shape) { - size_t ret = 1; +template +T product(const std::vector& shape) { + T ret = 1; for (auto s : shape) { ret *= s; } return ret; } +template size_t product(const std::vector& shape); +template int64_t product(const std::vector& shape); + +size_t product(const NVTEShape& shape, size_t begin, size_t end) { + NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end, + " in a shape with ", shape.ndim, " entries"); + size_t ret = 1; + for (size_t i = begin; i < end; ++i) { + ret *= shape.data[i]; + } + return ret; +} + +std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape) { + std::vector shape; + for (size_t i = 0; i < nvte_shape.ndim; i++) { + shape.push_back(nvte_shape.data[i]); + } + return shape; +} + at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, bool init_to_zeros) { std::vector shape_int64(shape.begin(), shape.end()); @@ -121,3 +229,14 @@ void* getDataPtr(at::Tensor tensor, int offset) { } return dptr; } + +std::vector convertShape(const NVTEShape& shape) { + return std::vector(shape.data, shape.data + shape.ndim); +} + +int roundup(const int value, const int multiple) { + assert(multiple > 0); + return ((value + multiple - 1) / multiple) * multiple; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 175a7b0e90..88a983d6f3 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -14,6 +14,8 @@ #include #include #include +#include +#include #include #include #include @@ -28,29 +30,30 @@ #include #include #include -#include +#include #include #include #include -#include #include +#include #include #include #include #include #include -#include #include #include -#include -#include #include #include +#include "c10/util/ArrayRef.h" #include "common/util/logging.h" -namespace transformer_engine { +namespace transformer_engine::pytorch { + +// in python we have: dist_group_type = torch.distributed.ProcessGroup +using dist_group_type = c10d::ProcessGroup; // Each tensor here is shape (N, ) holding all scaling // data for a single FP8 block, e.g. LayerNormLinear @@ -86,7 +89,99 @@ enum FP8BwdTensors { GRAD_INPUT3 = 5 }; -} // namespace transformer_engine +class Quantizer { + public: + virtual NVTEScalingMode get_scaling_mode() const = 0; + + virtual void set_quantization_params(TensorWrapper* tensor) const = 0; + + virtual std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const = 0; + + virtual ~Quantizer() = default; + + bool rowwise_usage = true; + bool columnwise_usage = true; + bool internal = false; + py::handle quantizer; + + protected: + explicit Quantizer(const py::handle& quantizer); +}; + +class NoneQuantizer : public Quantizer { + public: + explicit NoneQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {} + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override {} + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class Float8Quantizer : public Quantizer { + public: + at::Tensor scale; + at::Tensor scale_inv; + at::Tensor amax; + DType dtype; + + explicit Float8Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class Float8CurrentScalingQuantizer : public Quantizer { + public: + at::Tensor scale; + at::Tensor scale_inv; + at::Tensor amax; + DType dtype; + bool with_amax_reduction; + c10::intrusive_ptr amax_reduction_group; + int amax_reduction_size; + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + + explicit Float8CurrentScalingQuantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_DELAYED_TENSOR_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +class MXFP8Quantizer : public Quantizer { + public: + DType dtype; + + explicit MXFP8Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_MXFP8_1D_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + +std::unique_ptr convert_quantizer(py::handle quantizer); + +std::vector getTensorShape(at::Tensor t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); @@ -104,9 +199,11 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) { case transformer_engine::DType::kBFloat16: return at::kBFloat16; case transformer_engine::DType::kByte: + return at::kByte; case transformer_engine::DType::kFloat8E4M3: + return at::kFloat8_e4m3fn; case transformer_engine::DType::kFloat8E5M2: - return at::kByte; + return at::kFloat8_e5m2; default: NVTE_ERROR("Invalid type"); } @@ -114,6 +211,10 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) { inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { switch (t) { + case at::kFloat8_e4m3fn: + return transformer_engine::DType::kFloat8E4M3; + case at::kFloat8_e5m2: + return transformer_engine::DType::kFloat8E5M2; case at::kHalf: return transformer_engine::DType::kFloat16; case at::kFloat: @@ -129,6 +230,7 @@ inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) { case torch::kInt64: return transformer_engine::DType::kInt64; default: + std::cout << "Type: " << static_cast(t) << std::endl; NVTE_ERROR("Invalid type"); } } @@ -141,11 +243,18 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const std::vector& shape, const transformer_engine::DType type); -transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, - const std::vector& shape, - const transformer_engine::DType type, - void* amax_ptr, void* scale_ptr, - void* scale_inv_ptr); +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, const std::vector& shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector scale_inv_shape = {1}, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + void* data_ptr, void* columnwise_data_ptr, const std::vector& shape, + const std::vector& columnwise_shape, const transformer_engine::DType type, + void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr, + const std::vector& scale_inv_shape = {1}, + const std::vector& columnwise_scale_inv_shape = {1}, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, const NVTEShape& shape, @@ -153,11 +262,18 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr, transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor); -transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax, - const at::Tensor scale, - at::Tensor scale_inv); +TensorWrapper makeTransformerEngineTensor(py::handle tensor, py::handle quantizer); + +transformer_engine::TensorWrapper makeTransformerEngineTensor( + at::Tensor tensor, at::Tensor amax, const at::Tensor scale, at::Tensor scale_inv, + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING); + +template +T product(const std::vector& shape); -size_t product(const std::vector& shape); +size_t product(const NVTEShape& shape, size_t begin, size_t end); + +std::vector nvte_shape_to_vector(const NVTEShape& nvte_shape); at::Tensor allocateSpace(const std::vector& shape, const transformer_engine::DType type, bool init_to_zeros); @@ -171,4 +287,54 @@ at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype); void* getDataPtr(at::Tensor tensor, int offset = 0); +std::vector convertShape(const NVTEShape& shape); + +int roundup(const int value, const int multiple); + +} // namespace transformer_engine::pytorch + +namespace std { +template +string to_string(const vector& vec) { + string ret = "["; + for (const auto& val : vec) { + ret += to_string(val) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} + +// Torch shape -> string +template +string to_string(const c10::ArrayRef& vec) { + string ret = "["; + for (const auto& val : vec) { + ret += to_string(val) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} + +inline string to_string(const NVTEShape& s) { + string ret = "["; + for (size_t i = 0; i < s.ndim; ++i) { + ret += to_string(s.data[i]) + ","; + } + if (ret.size() > 1) { + ret[ret.size() - 1] = ']'; + } else { + ret += "]"; + } + return ret; +} +} // namespace std + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index b039bf2d1b..e430be0782 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -10,7 +10,6 @@ #include #include "common.h" -#include "common/common.h" /*************************************************************************************************** * Permutation @@ -45,326 +44,183 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); -std::vector fused_attn_fwd_qkvpacked( - size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, - const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread); - -std::vector fused_attn_bwd_qkvpacked( - size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, - const c10::optional descale_S, const c10::optional descale_O, - const c10::optional descale_dO, const c10::optional descale_dP, - const c10::optional scale_S, const c10::optional scale_dP, - const c10::optional scale_dQKV, c10::optional amax_dP, - c10::optional amax_dQKV); - -std::vector fused_attn_fwd_kvpacked( - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor KV, const transformer_engine::DType qkv_type, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread); - -std::vector fused_attn_bwd_kvpacked( - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV); - -std::vector fused_attn_fwd( +std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, + const c10::optional page_table_k, const c10::optional page_table_v, + py::handle s_quantizer, py::handle o_quantizer, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread); -std::vector fused_attn_bwd( +std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, - const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV); + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle dp_quantizer, py::handle dqkv_quantizer); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); +at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len); +at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t); +void copy_to_kv_cache(torch::Tensor new_k, torch::Tensor new_v, torch::Tensor k_cache, + torch::Tensor v_cache, torch::Tensor page_table, torch::Tensor cu_new_lens, + torch::Tensor cu_cached_lens, NVTE_QKV_Format kv_format, int b, + int max_ctx_len, int max_seq_len, int max_pages_per_seq, bool is_non_paged); + /*************************************************************************************************** * GEMM **************************************************************************************************/ -void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count); +using MaybeTensor = std::optional; void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, + std::vector A_scaling_mode, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, transformer_engine::DType B_type, + std::vector B_scaling_mode, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, at::Tensor counter); -void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, - bool transb, std::vector D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, - std::vector bias, transformer_engine::DType bias_type, - std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count); - -void te_grouped_gemm_single_output( - std::vector A, std::vector A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, - std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, - transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, +std::optional> te_general_grouped_gemm( + std::vector A, bool transa, std::vector B, bool transb, + std::optional> D, transformer_engine::DType D_type, + std::vector m_splits, std::vector bias, + transformer_engine::DType bias_type, bool single_output, std::vector pre_gelu_out, + bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); /*************************************************************************************************** * Transpose **************************************************************************************************/ -void fused_cast_transpose(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - at::Tensor input_cast, at::Tensor input_transpose, - transformer_engine::DType otype); - -void fused_cast_transpose_noop(at::Tensor input, at::Tensor noop, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, at::Tensor input_cast, - at::Tensor input_transpose, transformer_engine::DType otype, - int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0); - -std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset = 0, int amax_offset = 0, - int scale_inv_offset = 0); - -std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - transformer_engine::DType grad_bias_type, - int scale_offset = 0, int amax_offset = 0, - int scale_inv_offset = 0); - -std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, - at::Tensor gelu_input, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset = 0, int amax_offset = 0, - int scale_inv_offset = 0); - -void fused_multi_cast_transpose(std::vector input_list, - std::vector scale_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_output_list, - std::vector scale_inv_output_list, - transformer_engine::DType otype); - -std::tuple, std::vector> fused_multi_cast_transpose_alloc( - std::vector input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - std::vector scale_indices, std::vector amax_indices, - std::vector scale_inv_indices, transformer_engine::DType otype); - -at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype); - -void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype); - -void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop, - transformer_engine::DType otype); +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, + std::vector quantizer_list, + transformer_engine::DType otype); + +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, + std::optional output = std::nullopt); + +namespace transformer_engine::pytorch { /*************************************************************************************************** * Activations **************************************************************************************************/ -at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object gelu(const at::Tensor &input, py::handle quantizer); + +py::object relu(const at::Tensor &input, py::handle quantizer); -at::Tensor relu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object geglu(const at::Tensor &input, py::handle quantizer); -at::Tensor geglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object qgeglu(const at::Tensor &input, py::handle quantizer); -at::Tensor reglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object reglu(const at::Tensor &input, py::handle quantizer); -at::Tensor swiglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object swiglu(const at::Tensor &input, py::handle quantizer); -at::Tensor qgelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object qgelu(const at::Tensor &input, py::handle quantizer); -at::Tensor srelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype); +py::object srelu(const at::Tensor &input, py::handle quantizer); -at::Tensor dgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor drelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dgeglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dreglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dswiglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dqgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -at::Tensor dsrelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype); +py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +} // namespace transformer_engine::pytorch /*************************************************************************************************** * LayerNorm **************************************************************************************************/ -std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, +std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &mu, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma); -std::vector layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); - -std::vector layernorm_fwd_fp8_noalloc( - const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps, - at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma, - const int scale_offset = 0, const int amax_offset = 0, const int scale_inv_offset = 0); - -at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset = 0, - const int amax_offset = 0, const int scale_inv_offset = 0); - -std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, +std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, + float eps, py::object ln_out, py::handle quantizer, + transformer_engine::DType out_dtype, const int sm_margin, const bool zero_centered_gamma); -std::vector layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, at::Tensor ln_out, float eps, - const int sm_margin, const bool zero_centered_gamma); - -at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma); - /*************************************************************************************************** * RMSNorm **************************************************************************************************/ -std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, +std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma); -std::vector rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - float eps, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); - -std::vector rmsnorm_fwd_fp8_noalloc( - const at::Tensor &input, const at::Tensor &weight, float eps, at::Tensor scale, - at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, const int scale_offset = 0, - const int amax_offset = 0, const int scale_inv_offset = 0); - -at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset = 0, - const int amax_offset = 0, const int scale_inv_offset = 0); - -std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma); +std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, + py::object ln_out, py::handle quantizer, + transformer_engine::DType otype, const int sm_margin, + const bool zero_centered_gamma); + +/*************************************************************************************************** + * Cast + **************************************************************************************************/ + +namespace transformer_engine::pytorch { + +py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, + std::optional noop); -std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - at::Tensor ln_out, float eps, const int sm_margin, - const bool zero_centered_gamma); +py::object dequantize(const py::handle &input, transformer_engine::DType otype); -at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma); +std::vector bgrad_quantize(const at::Tensor &input, py::handle py_quantizer); + +std::vector gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, + py::handle quantizer, std::optional out_dtype, MaybeTensor bias, + DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr, + std::optional comm_type = std::nullopt, + MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false); /*************************************************************************************************** - * Cast + * Cast fusions **************************************************************************************************/ -at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); +std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); -void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset = 0, const int amax_offset = 0, - const int scale_inv_offset = 0); +std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); -at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype, - const int scale_inv_offset = 0); +std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer); + +} // namespace transformer_engine::pytorch /*************************************************************************************************** * Softmax @@ -399,7 +255,6 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, std::vector amax_histories, std::vector scales, - std::vector scale_invs, const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin); @@ -438,7 +293,7 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st const at::Tensor &cu_seqlens, bool lse_packed); at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, - bool lse_packed); + bool lse_packed, int second_half_lse_seqlen); void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, @@ -473,6 +328,12 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, const int step, const int mode, const int bias_correction, const float weight_decay); +void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, + const int bias_correction, const float weight_decay); + void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, @@ -506,6 +367,16 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list); +/*************************************************************************************************** + * swizzle + **************************************************************************************************/ + +void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans); + +at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv); + +at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv); + /*************************************************************************************************** * Comm+GEMM Overlap Wrappers **************************************************************************************************/ @@ -527,8 +398,7 @@ class CommOverlapHelper : torch::CustomClassHolder { CommOverlapHelper(); CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_node_group, - std::optional inter_node_group); + std::optional intra_node_group); ~CommOverlapHelper(); @@ -539,151 +409,44 @@ class CommOverlapHelper : torch::CustomClassHolder { }; class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase { - private: - torch::Tensor _ubuf_torch; - torch::Tensor _ubuf_counter; - public: CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits = 3, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false); - - void set_ubuf_scale_inv(torch::Tensor scale_inv) { - assert(scale_inv.numel()); - assert(scale_inv.scalar_type() == torch::kFloat32); - transformer_engine::CommOverlapBase::set_ubuf_scale_inv( - reinterpret_cast(scale_inv.data_ptr())); - } - - void copy_input_to_ubuf(torch::Tensor input, int comm_type); - - torch::Tensor get_ubuf_output(int comm_type); - - /* - ** Bulk GEMM + COMM - ** This function assumes the communication input is pre-copied to _ubuf - */ - std::vector bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - transformer_engine::CommOverlapType comm_type, at::Tensor rs_output); - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output); - - /* - ** Split FPROP GEMM + ReduceScatter - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - bool gemm_overlap, at::Tensor rs_output); + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16, + bool set_sm_margin = true, bool atomic_gemm = false, + bool rs_overlap_first_gemm = false); + + ~CommOverlap() {} + + void set_buffer_params(py::handle quantizer); + + void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false); + + py::object get_buffer(py::handle quantizer, bool local_chunk = false, + std::optional> shape = std::nullopt); + }; // CommOverlap class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { - private: - torch::Tensor _ubuf_torch; - torch::Tensor _ubuf_counter; - public: CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, transformer_engine::CommOverlapType comm_type, int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2, - int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false, - bool use_ce = true, bool aggregate = false); - - void set_ubuf_scale_inv(torch::Tensor scale_inv) { - assert(scale_inv.numel()); - assert(scale_inv.scalar_type() == torch::kFloat32); - transformer_engine::CommOverlapP2PBase::set_ubuf_scale_inv( - reinterpret_cast(scale_inv.data_ptr())); - } - - void copy_input_to_ubuf(torch::Tensor input, bool chunk); - - torch::Tensor get_ubuf_output(int comm_type); - - /* - ** Split AllGather + AtomicGEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. - */ - void atomic_gemm_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor B_copy); - - /* - ** Split AllGather + GEMM using P2P communication - ** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is - *needed to have AG outputs - ** in each rank to be in the contiguous memory space after all ring exchange - *phases. - */ - void split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor B_copy); - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, - at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, - bool grad, at::Tensor workspace, size_t workspaceSize, - bool accumulate, bool use_split_accumulator, at::Tensor rs_output); - - /* - ** Split ReduceScatter + GEMM using P2P communication - */ - void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - transformer_engine::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - transformer_engine::DType B_type, bool transb, at::Tensor D, - at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, - at::Tensor bias, transformer_engine::DType bias_type, - at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - at::Tensor rs_output); + int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3, + bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true, + bool aggregate = false); + + ~CommOverlapP2P() {} + + void set_buffer_params(py::handle quantizer); + + void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false); + + py::object get_buffer(py::handle quantizer, bool local_chunk = false, + std::optional> shape = std::nullopt); + }; // CommOverlapP2P #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 7f8cff5584..1ef6f5258d 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -1,276 +1,147 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +#include "common.h" #include "extensions.h" - -at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_gelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +#include "pybind.h" + +namespace transformer_engine::pytorch { + +template +py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = input.contiguous(); + + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + input_shape[input_shape.size() - 1] /= shape_divisor; + auto fake_tensor_type = input.scalar_type(); + + auto [te_output, out] = + my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); + + // for current scaling, we need to compute amax first and then quantize + // because cache cannot fit in the entire tensor to compute amax and quantize + // the quantizer should not need amax reduction, no process group needed here + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // activation function might change the input data range, we need to first call the activation function + // and then find the amax and scale of that and then do the quantization + // get a NoneQuantizer to calculate amax of activation output + auto my_quantizer_none = std::make_unique(py::none()); + auto [te_output_act, out_act] = + my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); + act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream()); + // use te_output_act as input to the compute amax and find the amax of activated tensor + nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(my_quantizer.get()); + if (my_quantizer_cs->with_amax_reduction) { + NVTE_ERROR( + "per-tensor current scaling amax reduction is not supported in activation functions."); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + nvte_quantize(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + } else { + act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + } + + return out; } -at::Tensor dgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; +template +py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input, + py::handle quantizer) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = input.contiguous(); + auto grad_tensor = grad.contiguous(); - auto output = allocateTorchTensor(M, N, otype); + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + auto fake_tensor_type = input.scalar_type(); - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); + auto [te_output, out] = + my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - nvte_dgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - return output; + return out; } -at::Tensor relu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = static_cast(input.numel()) / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_relu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object gelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -at::Tensor drelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_drelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor geglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N / 2, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = - makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_geglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object relu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -at::Tensor dgeglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dgeglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor reglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N / 2, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = - makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_reglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object geglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -at::Tensor dreglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dreglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object qgeglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -at::Tensor swiglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N / 2, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = - makeTransformerEngineTensor(output.data_ptr(), {M, N / 2}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_swiglu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor dswiglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N / 2}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor qgelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_qgelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object reglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -at::Tensor dqgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); - - nvte_dqgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -at::Tensor srelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = static_cast(input.numel()) / N; - - auto output = allocateTorchTensor(M, N, otype); - - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_srelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; +py::object swiglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -at::Tensor dsrelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t N = static_cast(input.size(-1)); - size_t M = input.numel() / N; +py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} - auto output = allocateTorchTensor(M, N, otype); +py::object qgelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} - auto itype = GetTransformerEngineDType(input.scalar_type()); - auto gtype = GetTransformerEngineDType(grad.scalar_type()); - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype); - auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype); +py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} - nvte_dsrelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); +py::object srelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} - return output; +py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index c0cd2e9920..c323e7b6c1 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -8,7 +8,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const bool transpose_output_memory) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(input.size(0) <= freqs.size(0), @@ -66,7 +66,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const bool transpose_output_memory) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(output_grads.size(0) <= freqs.size(0), @@ -122,7 +122,7 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, const at::Tensor &freqs, const int cp_size, const int cp_rank) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); @@ -174,7 +174,7 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, const at::Tensor &freqs, const int cp_size, const int cp_rank) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cu b/transformer_engine/pytorch/csrc/extensions/attention.cu index 8088a2b8f1..da82120f4a 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cu +++ b/transformer_engine/pytorch/csrc/extensions/attention.cu @@ -1,10 +1,11 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ - #include "extensions.h" +#include "kv_cache.cuh" +#include "thd_utils.cuh" constexpr int block_size = 512; constexpr int ctas_per_sm = 4; @@ -37,22 +38,27 @@ __global__ void __launch_bounds__(block_size) } // fast zero-fills of tensors -void mha_fill(const at::Tensor &self, const at::Tensor &start_index) { - auto max_tokens = self.size(0); - auto self_2d = self.view({max_tokens, -1}); - auto fcd_size = self_2d.size(1); - TORCH_CHECK(self.is_contiguous(), "input not contiguous"); +void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &start_index) { + std::vector shape = transformer_engine::pytorch::convertShape(self.shape()); + + auto max_tokens = shape[0]; + auto fcd_size = 1; + for (int i = 1; i <= shape.size(); i++) { + fcd_size *= shape[i]; + } TORCH_CHECK(fcd_size % block_size == 0, "input size not aligned to block size"); const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; uint64_t num_blk_y = (uint64_t)(fcd_size / block_size); uint64_t num_blk_x = (uint64_t)((num_mp * ctas_per_sm + num_blk_y - 1) / num_blk_y); dim3 dim_grid(num_blk_x, num_blk_y); dim3 dim_block(block_size); + // trzeba jakos przekonwertowac DType na scalar_type + at::ScalarType scalar_type = transformer_engine::pytorch::GetATenDType(self.dtype()); AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( - at::ScalarType::Half, at::ScalarType::BFloat16, self_2d.scalar_type(), "mha_fill", [&]() { + at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "mha_fill", [&]() { mha_fill_kernel<<>>( - self_2d.data_ptr(), static_cast(start_index.data_ptr()), - max_tokens); + static_cast(self.get_rowwise_data().data_ptr), + static_cast(start_index.data_ptr()), max_tokens); C10_CUDA_KERNEL_LAUNCH_CHECK(); }); } @@ -77,426 +83,82 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_pe return philox_args; } -// fused attention FWD with packed QKV -std::vector fused_attn_fwd_qkvpacked( - size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, const at::Tensor cu_seqlens, const at::Tensor QKV, - const transformer_engine::DType qkv_type, const c10::optional cu_seqlens_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread) { - using namespace transformer_engine; - - auto qkv_sizes = QKV.sizes().vec(); - std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; - std::vector q_shape; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - int loc_3 = 0; - switch (layout_group) { - case NVTE_3HD: - loc_3 = qkv_sizes.size() - 3; - break; - case NVTE_H3D: - loc_3 = qkv_sizes.size() - 2; - break; - default: - NVTE_ERROR("Invalid QKV layout group."); - } - for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { - if (it - qkv_shape.begin() != loc_3) { - q_shape.push_back(*it); - } - } - std::vector o_shape{q_shape.begin(), q_shape.end()}; - - // create output tensor O - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto O = torch::empty(o_shape, options); - - // construct NVTE tensors - TensorWrapper te_QKV, te_S, te_O, te_Bias, te_cu_seqlens, te_cu_seqlens_padded; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(O, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - O.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - getDataPtr(amax_S.value(), amax_S_offset), - getDataPtr(scale_S.value(), scale_S_offset), - getDataPtr(descale_S.value(), descale_S_offset)); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax_O.value(), amax_O_offset), - getDataPtr(scale_O.value(), scale_O_offset), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - O.fill_(0); - } - // BF16 or FP16 - te_QKV = - makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); - } - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, - nullptr, nullptr, nullptr); - } - auto cu_seqlens_sizes = cu_seqlens.sizes().vec(); - std::vector cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()}; - te_cu_seqlens = makeTransformerEngineTensor(cu_seqlens.data_ptr(), cu_seqlens_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - if (cu_seqlens_padded.has_value()) { - auto cu_seqlens_padded_sizes = cu_seqlens_padded.value().sizes().vec(); - std::vector cu_seqlens_padded_shape{cu_seqlens_padded_sizes.begin(), - cu_seqlens_padded_sizes.end()}; - te_cu_seqlens_padded = - makeTransformerEngineTensor(cu_seqlens_padded.value().data_ptr(), cu_seqlens_padded_shape, - DType::kInt32, nullptr, nullptr, nullptr); - } - - // extract random number generator seed and offset - auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); - auto te_rng_state = makeTransformerEngineTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - // allocate memory for nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - output_tensor = (i < nvte_aux_tensor_pack.size - 1) - ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) - : rng_state; - } - } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } - output_tensors.push_back(output_tensor); - tensor->data.dptr = output_tensor.data_ptr(); - } - - // execute the kernel - nvte_fused_attn_fwd_qkvpacked( - te_QKV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens.data(), te_cu_seqlens_padded.data(), te_rng_state.data(), max_seqlen, - is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - // if training, [O, softmax-related tensors, rng_state]; if inference, [O] - return output_tensors; -} - -// fused attention BWD with packed QKV -std::vector fused_attn_bwd_qkvpacked( - size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_padded, const c10::optional descale_QKV, - const c10::optional descale_S, const c10::optional descale_O, - const c10::optional descale_dO, const c10::optional descale_dP, - const c10::optional scale_S, const c10::optional scale_dP, - const c10::optional scale_dQKV, c10::optional amax_dP, - c10::optional amax_dQKV) { - using namespace transformer_engine; - - auto qkv_sizes = QKV.sizes().vec(); - std::vector qkv_shape{qkv_sizes.begin(), qkv_sizes.end()}; - std::vector q_shape; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - int loc_3 = 0; - switch (layout_group) { - case NVTE_3HD: - loc_3 = qkv_sizes.size() - 3; - break; - case NVTE_H3D: - loc_3 = qkv_sizes.size() - 2; - break; - default: - NVTE_ERROR("Invalid QKV layout group."); - } - for (auto it = qkv_shape.begin(); it != qkv_shape.end(); ++it) { - if (it - qkv_shape.begin() != loc_3) { - q_shape.push_back(*it); - } - } - auto h = q_shape[q_shape.size() - 2]; - - // create output tensor dQKV - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - at::Tensor dQKV = torch::empty_like(QKV, options); - - // construct NVTE tensors - TensorWrapper te_QKV, te_O, te_dO, te_S, te_dP, te_dQKV; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(dQKV, cu_seqlens.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQKV.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || - (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || - (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || - (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; - err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); - err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_QKV = makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, - descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), - scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); - te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type, - amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dQKV.fill_(0); - } - // BF16 or FP16 - te_QKV = - makeTransformerEngineTensor(QKV.data_ptr(), qkv_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_dO = - makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dQKV = makeTransformerEngineTensor(dQKV.data_ptr(), qkv_shape, dqkv_type, nullptr, nullptr, - nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); - } - - // convert auxiliary tensors from forward into NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); - std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - tensor->data.shape = std::vector(tmp.begin(), tmp.end()); - tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); - } - - // create dBias the same shape as Bias - at::Tensor dBias; - TensorWrapper te_dBias; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - if (nvte_aux_tensor_pack.size >= 2) { - std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); - dBias = torch::empty(bias_shape, options); - te_dBias = makeTransformerEngineTensor(dBias); - } else { - dBias = torch::empty({1, static_cast(h), static_cast(max_seqlen), - static_cast(max_seqlen)}, - options); - te_dBias = makeTransformerEngineTensor(dBias); - } - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dBias.fill_(0); - } - } - - // create cu_seqlens tensorwrappers - auto cu_seqlens_sizes = cu_seqlens.sizes().vec(); - std::vector cu_seqlens_shape{cu_seqlens_sizes.begin(), cu_seqlens_sizes.end()}; - TensorWrapper te_cu_seqlens = makeTransformerEngineTensor( - cu_seqlens.data_ptr(), cu_seqlens_shape, DType::kInt32, nullptr, nullptr, nullptr); - - TensorWrapper te_cu_seqlens_padded; - if (cu_seqlens_padded.has_value()) { - auto cu_seqlens_padded_sizes = cu_seqlens_padded.value().sizes().vec(); - std::vector cu_seqlens_padded_shape{cu_seqlens_padded_sizes.begin(), - cu_seqlens_padded_sizes.end()}; - te_cu_seqlens_padded = - makeTransformerEngineTensor(cu_seqlens_padded.value().data_ptr(), cu_seqlens_padded_shape, - DType::kInt32, nullptr, nullptr, nullptr); - } - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), - max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_qkvpacked( - te_QKV.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, - te_dQKV.data(), te_dBias.data(), te_cu_seqlens.data(), te_cu_seqlens_padded.data(), - max_seqlen, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - return {dQKV, dBias}; -} - -// fused attention FWD with packed KV -std::vector fused_attn_fwd_kvpacked( +// fused attention FWD with separate Q, K and V tensors +std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor KV, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const at::ScalarType fake_dtype, const c10::optional cu_seqlens_q_padded, const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, + const c10::optional page_table_k, const c10::optional page_table_v, + py::handle s_quantizer, py::handle o_quantizer, const c10::optional Bias, const c10::optional rng_gen, size_t rng_elts_per_thread) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; + TensorWrapper te_Q, te_K, te_V, te_O, te_S; - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto kv_sizes = KV.sizes().vec(); - std::vector kv_shape{kv_sizes.begin(), kv_sizes.end()}; - std::vector o_shape{q_shape.begin(), q_shape.end()}; + auto none = py::none(); + std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); + std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); - // create output tensor O + te_Q = makeTransformerEngineTensor(Q, none); + te_K = makeTransformerEngineTensor(K, none); + te_V = makeTransformerEngineTensor(V, none); + + // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types. + const transformer_engine::DType qkv_type = te_Q.dtype(); + const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + + std::vector q_shape = convertShape(te_Q.shape()); + std::vector k_shape = convertShape(te_K.shape()); + std::vector v_shape = convertShape(te_V.shape()); auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto O = torch::empty(o_shape, options); + // create output tensor O + + auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; + o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; + py::object o_python, s_python; + std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_S, te_O, te_Bias, te_cu_seqlens_q, te_cu_seqlens_kv; + TensorWrapper te_Bias; + TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; + TensorWrapper te_page_table_k, te_page_table_v; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; if (set_zero && ((h * d) % block_size == 0) && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { - O.fill_(0); + te_O.zero_(at::cuda::getCurrentCUDAStream()); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - getDataPtr(amax_S.value(), amax_S_offset), - getDataPtr(scale_S.value(), scale_S_offset), - getDataPtr(descale_S.value(), descale_S_offset)); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax_O.value(), amax_O_offset), - getDataPtr(scale_O.value(), scale_O_offset), nullptr); } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - O.fill_(0); + te_O.zero_(at::cuda::getCurrentCUDAStream()); } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_KV = - makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { auto bias_sizes = Bias.value().sizes().vec(); std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, - nullptr, nullptr, nullptr); + te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32); } auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); + te_cu_seqlens_q = + makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, DType::kInt32); + te_cu_seqlens_kv = + makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, DType::kInt32); if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); @@ -506,11 +168,22 @@ std::vector fused_attn_fwd_kvpacked( std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); + cu_seqlens_q_padded_shape, DType::kInt32); + te_cu_seqlens_kv_padded = makeTransformerEngineTensor( + cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); + } + + if ((page_table_k.has_value()) && (page_table_v.has_value())) { + auto page_table_k_sizes = page_table_k.value().sizes().vec(); + std::vector page_table_k_shape{page_table_k_sizes.begin(), page_table_k_sizes.end()}; + auto page_table_v_sizes = page_table_v.value().sizes().vec(); + std::vector page_table_v_shape{page_table_v_sizes.begin(), page_table_v_sizes.end()}; + te_page_table_k = + makeTransformerEngineTensor(page_table_k.value().data_ptr(), page_table_k_shape, + DType::kInt32, nullptr, nullptr, nullptr); + te_page_table_v = + makeTransformerEngineTensor(page_table_v.value().data_ptr(), page_table_v_shape, + DType::kInt32, nullptr, nullptr, nullptr); } // extract rng seed and offset @@ -530,10 +203,11 @@ std::vector fused_attn_fwd_kvpacked( TensorWrapper workspace; // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, + nvte_fused_attn_fwd( + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), + &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), + te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); @@ -543,38 +217,51 @@ std::vector fused_attn_fwd_kvpacked( makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); + std::vector output_tensors; + output_tensors.push_back(o_python); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); // allocate memory for nvte_aux_tensor_pack.tensors at::Tensor output_tensor; if (nvte_aux_tensor_pack.size >= 2) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); + output_tensor = allocateSpace( + nvte_shape_to_vector(temp_shape), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); } else if (i == nvte_aux_tensor_pack.size - 2) { output_tensor = rng_state; } else if (i == nvte_aux_tensor_pack.size - 1) { output_tensor = Bias.value(); } } else { - output_tensor = (i < nvte_aux_tensor_pack.size - 1) - ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) - : rng_state; + NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); + output_tensor = + (i < nvte_aux_tensor_pack.size - 1) + ? allocateSpace( + nvte_shape_to_vector(temp_shape), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false) + : rng_state; } } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); + NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); + output_tensor = allocateSpace( + nvte_shape_to_vector(temp_shape), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); } - output_tensors.push_back(output_tensor); - tensor->data.dptr = output_tensor.data_ptr(); + output_tensors.push_back(py::cast(output_tensor)); + NVTEBasicTensor temp_data = {output_tensor.data_ptr(), + nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]), + nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])}; + nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } // execute the kernel - nvte_fused_attn_fwd_kvpacked( - te_Q.data(), te_KV.data(), te_Bias.data(), te_S.data(), te_O.data(), &nvte_aux_tensor_pack, - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, + nvte_fused_attn_fwd( + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), + &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), + te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); @@ -585,414 +272,56 @@ std::vector fused_attn_fwd_kvpacked( return output_tensors; } -// fused attention BWD with packed KV -std::vector fused_attn_bwd_kvpacked( - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O, - const at::Tensor dO, const transformer_engine::DType qkv_type, - const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV) { - using namespace transformer_engine; - - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto kv_sizes = KV.sizes().vec(); - std::vector kv_shape{kv_sizes.begin(), kv_sizes.end()}; - std::vector k_shape; - for (auto i : kv_shape) { - if (i != 2) { - k_shape.push_back(i); - } - } - auto h_q = q_shape[q_shape.size() - 2]; - auto h_kv = k_shape[k_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - - // create output tensors dQ and dKV - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - at::Tensor dQ = torch::empty_like(Q, options); - at::Tensor dKV = torch::empty_like(KV, options); - - // construct NVTE tensors - TensorWrapper te_Q, te_KV, te_O, te_dO, te_S, te_dP, te_dQ, te_dKV; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - if (set_zero && ((h_q * d) % block_size == 0) && ((h_kv * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(dKV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQ.fill_(0); - dKV.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || - (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || - (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || - (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; - err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); - err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_KV = makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, - descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), - scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - te_dKV = makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type, - amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dQ.fill_(0); - dKV.fill_(0); - } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_KV = - makeTransformerEngineTensor(KV.data_ptr(), kv_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_dO = - makeTransformerEngineTensor(dO.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_dKV = - makeTransformerEngineTensor(dKV.data_ptr(), kv_shape, dqkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); - } - - // create cu_seqlens tensorwrappers - auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; - auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; - if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { - auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; - auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - } - - // convert auxiliary tensors from forward to NVTETensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); - std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - tensor->data.shape = std::vector(tmp.begin(), tmp.end()); - tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); - } - - // create dBias the same shape as Bias - at::Tensor dBias; - TensorWrapper te_dBias; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - if (nvte_aux_tensor_pack.size >= 2) { - std::vector bias_shape(Aux_CTX_Tensors[nvte_aux_tensor_pack.size - 1].sizes().vec()); - dBias = torch::empty(bias_shape, options); - te_dBias = makeTransformerEngineTensor(dBias); - } else { - dBias = torch::empty({1, static_cast(h_q), static_cast(max_seqlen_q), - static_cast(max_seqlen_kv)}, - options); - te_dBias = makeTransformerEngineTensor(dBias); - } - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - dBias.fill_(0); - } - } - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, - bias_type, attn_mask_type, window_size[0], window_size[1], - deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // execute kernel - nvte_fused_attn_bwd_kvpacked(te_Q.data(), te_KV.data(), te_O.data(), te_dO.data(), te_S.data(), - te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dKV.data(), - te_dBias.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), - max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, qkv_layout, - bias_type, attn_mask_type, window_size[0], window_size[1], - deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - return {dQ, dKV, dBias}; -} - -// fused attention FWD with separate Q, K and V tensors -std::vector fused_attn_fwd( - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, - bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q, - const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type, - const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const int descale_QKV_offset, - const c10::optional descale_S, const int descale_S_offset, - const c10::optional scale_S, const int scale_S_offset, - const c10::optional scale_O, const int scale_O_offset, - c10::optional amax_S, const int amax_S_offset, c10::optional amax_O, - const int amax_O_offset, const c10::optional Bias, - const c10::optional rng_gen, size_t rng_elts_per_thread) { - using namespace transformer_engine; - - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto k_sizes = K.sizes().vec(); - std::vector k_shape{k_sizes.begin(), k_sizes.end()}; - auto v_sizes = V.sizes().vec(); - std::vector v_shape{v_sizes.begin(), v_sizes.end()}; - - // create output tensor O - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - auto o_shape = std::vector{q_sizes.begin(), q_sizes.end()}; - o_shape[o_shape.size() - 1] = v_sizes[v_sizes.size() - 1]; - std::vector o_shape_tmp{o_shape.begin(), o_shape.end()}; - auto O = torch::empty(c10::IntArrayRef(o_shape_tmp), options); - - // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_S, te_O, te_Bias; - TensorWrapper te_cu_seqlens_q, te_cu_seqlens_kv; - TensorWrapper te_cu_seqlens_q_padded, te_cu_seqlens_kv_padded; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // FP8 - auto h = q_shape[q_shape.size() - 2]; - auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - O.fill_(0); - } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!scale_S.has_value()) || - (!scale_O.has_value()) || (!amax_S.has_value()) || (!amax_O.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, scale_S, scale_O, amax_S and amax_O "; - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - getDataPtr(descale_QKV.value(), descale_QKV_offset)); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, - getDataPtr(amax_S.value(), amax_S_offset), - getDataPtr(scale_S.value(), scale_S_offset), - getDataPtr(descale_S.value(), descale_S_offset)); - te_O = makeTransformerEngineTensor(O.data_ptr(), q_shape, qkv_type, - getDataPtr(amax_O.value(), amax_O_offset), - getDataPtr(scale_O.value(), scale_O_offset), nullptr); - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { - if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { - O.fill_(0); - } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr); - } else { - NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); - } - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - auto bias_sizes = Bias.value().sizes().vec(); - std::vector bias_shape{bias_sizes.begin(), bias_sizes.end()}; - te_Bias = makeTransformerEngineTensor(Bias.value().data_ptr(), bias_shape, DType::kFloat32, - nullptr, nullptr, nullptr); - } - auto cu_seqlens_q_sizes = cu_seqlens_q.sizes().vec(); - std::vector cu_seqlens_q_shape{cu_seqlens_q_sizes.begin(), cu_seqlens_q_sizes.end()}; - auto cu_seqlens_kv_sizes = cu_seqlens_kv.sizes().vec(); - std::vector cu_seqlens_kv_shape{cu_seqlens_kv_sizes.begin(), cu_seqlens_kv_sizes.end()}; - te_cu_seqlens_q = makeTransformerEngineTensor(cu_seqlens_q.data_ptr(), cu_seqlens_q_shape, - DType::kInt32, nullptr, nullptr, nullptr); - te_cu_seqlens_kv = makeTransformerEngineTensor(cu_seqlens_kv.data_ptr(), cu_seqlens_kv_shape, - DType::kInt32, nullptr, nullptr, nullptr); - - if ((cu_seqlens_q_padded.has_value()) && (cu_seqlens_kv_padded.has_value())) { - auto cu_seqlens_q_padded_sizes = cu_seqlens_q_padded.value().sizes().vec(); - std::vector cu_seqlens_q_padded_shape{cu_seqlens_q_padded_sizes.begin(), - cu_seqlens_q_padded_sizes.end()}; - auto cu_seqlens_kv_padded_sizes = cu_seqlens_kv_padded.value().sizes().vec(); - std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), - cu_seqlens_kv_padded_sizes.end()}; - te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - } - - // extract rng seed and offset - auto gen = at::get_generator_or_default( - rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); - at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( - philox_args, static_cast(rng_state.data_ptr())); - auto te_rng_state = makeTransformerEngineTensor(rng_state); - - // create auxiliary output tensors - NVTETensorPack nvte_aux_tensor_pack; - nvte_tensor_pack_create(&nvte_aux_tensor_pack); - - // create workspace - TensorWrapper workspace; - - // populate tensors with appropriate shapes and dtypes - nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, - max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, - attn_mask_type, window_size[0], window_size[1], workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // allocate memory for workspace and auxiliary output tensors - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // output_tensors = [O, nvte_aux_tensor_pack.tensors] - std::vector output_tensors; - output_tensors.push_back(O); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - // allocate memory for nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - output_tensor = (i < nvte_aux_tensor_pack.size - 1) - ? allocateSpace(tensor->data.shape, tensor->data.dtype, false) - : rng_state; - } - } else { - output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false); - } - output_tensors.push_back(output_tensor); - tensor->data.dptr = output_tensor.data_ptr(); - } - - // execute the kernel - nvte_fused_attn_fwd(te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), - te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), - te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), te_rng_state.data(), max_seqlen_q, - max_seqlen_kv, is_training, attn_scale, p_dropout, qkv_layout, bias_type, - attn_mask_type, window_size[0], window_size[1], workspace.data(), - at::cuda::getCurrentCUDAStream()); - - // destroy tensor wrappers, but not allocated memory - nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - - // if training, [O, softmax-related tensors, rng_state]; if inference, [O] - return output_tensors; -} - // fused attention BWD with separate Q, K and V -std::vector fused_attn_bwd( +std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V, - const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const transformer_engine::DType dqkv_type, const std::vector Aux_CTX_Tensors, const c10::optional cu_seqlens_q_padded, - const c10::optional cu_seqlens_kv_padded, - const c10::optional descale_QKV, const c10::optional descale_S, - const c10::optional descale_O, const c10::optional descale_dO, - const c10::optional descale_dP, const c10::optional scale_S, - const c10::optional scale_dP, const c10::optional scale_dQKV, - c10::optional amax_dP, c10::optional amax_dQKV) { + const c10::optional cu_seqlens_kv_padded, py::handle s_quantizer, + py::handle dp_quantizer, py::handle dqkv_quantizer) { using namespace transformer_engine; - - auto q_sizes = Q.sizes().vec(); - std::vector q_shape{q_sizes.begin(), q_sizes.end()}; - auto k_sizes = K.sizes().vec(); - std::vector k_shape{k_sizes.begin(), k_sizes.end()}; - auto v_sizes = V.sizes().vec(); - std::vector v_shape{v_sizes.begin(), v_sizes.end()}; + using namespace transformer_engine::pytorch; + auto none = py::none(); + TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; + te_Q = makeTransformerEngineTensor(Q, none); + te_K = makeTransformerEngineTensor(K, none); + te_V = makeTransformerEngineTensor(V, none); + te_O = makeTransformerEngineTensor(O, none); + te_dO = makeTransformerEngineTensor(dO, none); + // qkv type from the te_Q + std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); + const transformer_engine::DType qkv_type = te_Q.dtype(); + const transformer_engine::DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + + py::object s_python, dp_python; + std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); + std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + + std::vector q_shape = convertShape(te_Q.shape()); + std::vector k_shape = convertShape(te_K.shape()); + std::vector v_shape = convertShape(te_V.shape()); auto h_q = q_shape[q_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2]; auto d_qk = q_shape[q_shape.size() - 1]; auto d_v = v_shape[v_shape.size() - 1]; auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - std::vector o_shape{q_sizes.begin(), q_sizes.end()}; + std::vector o_shape{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = d_v; - at::Tensor dQ; - at::Tensor dK; - at::Tensor dV; - at::Tensor dQKV, dKV; + at::Tensor dQ, dK, dV, dQKV, dKV; + py::object py_dQ, py_dK, py_dV; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); std::vector tmp_shape; + switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: - tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1009,7 +338,7 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_H3D: - tmp_shape = std::vector{q_sizes.begin(), q_sizes.end()}; + tmp_shape = std::vector{q_shape.begin(), q_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(3)); dQKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dQ = dQKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1023,8 +352,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_2HD: - dQ = torch::empty_like(Q, options); - tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 2, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1037,8 +367,9 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 3); break; case NVTE_QKV_Layout_Group::NVTE_HD_H2D: - dQ = torch::empty_like(Q, options); - tmp_shape = std::vector{k_sizes.begin(), k_sizes.end()}; + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector{k_shape.begin(), k_shape.end()}; tmp_shape.insert(tmp_shape.begin() + tmp_shape.size() - 1, int64_t(2)); dKV = torch::empty(c10::IntArrayRef(tmp_shape), options); dK = dKV.index({"...", torch::indexing::Slice(0, 1, 1), @@ -1049,82 +380,41 @@ std::vector fused_attn_bwd( .squeeze(tmp_shape.size() - 2); break; case NVTE_QKV_Layout_Group::NVTE_HD_HD_HD: - dQ = torch::empty_like(Q, options); - dK = torch::empty_like(K, options); - dV = torch::empty_like(V, options); + tmp_shape = std::vector(q_shape.begin(), q_shape.end()); + dQ = torch::empty(tmp_shape, options); + tmp_shape = std::vector(k_shape.begin(), k_shape.end()); + dK = torch::empty(tmp_shape, options); + tmp_shape = std::vector(v_shape.begin(), v_shape.end()); + dV = torch::empty(tmp_shape, options); break; default: NVTE_ERROR("QKV layout not supported!"); } + std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); + std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK); + std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV); // construct NVTE tensors - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { // FP8 if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); } else { dQ.fill_(0); dK.fill_(0); dV.fill_(0); } - if ((!descale_QKV.has_value()) || (!descale_S.has_value()) || (!descale_O.has_value()) || - (!descale_dO.has_value()) || (!descale_dP.has_value()) || (!scale_S.has_value()) || - (!scale_dP.has_value()) || (!scale_dQKV.has_value()) || (!amax_dP.has_value()) || - (!amax_dQKV.has_value())) { - std::string err_tensors = "descale_QKV, descale_S, descale_O, descale_dO, descale_dP, "; - err_tensors = err_tensors + std::string("scale_S, scale_dP, scale_dQKV, "); - err_tensors = err_tensors + std::string("amax_dP and amax_dQKV "); - NVTE_ERROR(err_tensors + std::string("are required for FP8 operation. \n")); - } - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, - descale_QKV.value().data_ptr()); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, - descale_O.value().data_ptr()); - te_dO = makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, - descale_dO.value().data_ptr()); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, - scale_S.value().data_ptr(), descale_S.value().data_ptr()); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, amax_dP.value().data_ptr(), - scale_dP.value().data_ptr(), descale_dP.value().data_ptr()); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - te_dK = - makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); - te_dV = - makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, amax_dQKV.value().data_ptr(), - scale_dQKV.value().data_ptr(), nullptr); + } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); dV.fill_(0); } - // BF16 or FP16 - te_Q = makeTransformerEngineTensor(Q.data_ptr(), q_shape, qkv_type, nullptr, nullptr, nullptr); - te_K = makeTransformerEngineTensor(K.data_ptr(), k_shape, qkv_type, nullptr, nullptr, nullptr); - te_V = makeTransformerEngineTensor(V.data_ptr(), v_shape, qkv_type, nullptr, nullptr, nullptr); - te_O = makeTransformerEngineTensor(O.data_ptr(), o_shape, qkv_type, nullptr, nullptr, nullptr); - te_dO = - makeTransformerEngineTensor(dO.data_ptr(), o_shape, dqkv_type, nullptr, nullptr, nullptr); - te_S = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dP = makeTransformerEngineTensor(nullptr, {0}, DType::kFloat32, nullptr, nullptr, nullptr); - te_dQ = - makeTransformerEngineTensor(dQ.data_ptr(), q_shape, dqkv_type, nullptr, nullptr, nullptr); - te_dK = - makeTransformerEngineTensor(dK.data_ptr(), k_shape, dqkv_type, nullptr, nullptr, nullptr); - te_dV = - makeTransformerEngineTensor(dV.data_ptr(), v_shape, dqkv_type, nullptr, nullptr, nullptr); } else { NVTE_ERROR("Fused attention only supports FP8 and BF16/FP16 data types. \n"); } @@ -1149,11 +439,9 @@ std::vector fused_attn_bwd( std::vector cu_seqlens_kv_padded_shape{cu_seqlens_kv_padded_sizes.begin(), cu_seqlens_kv_padded_sizes.end()}; te_cu_seqlens_q_padded = makeTransformerEngineTensor(cu_seqlens_q_padded.value().data_ptr(), - cu_seqlens_q_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); - te_cu_seqlens_kv_padded = makeTransformerEngineTensor(cu_seqlens_kv_padded.value().data_ptr(), - cu_seqlens_kv_padded_shape, DType::kInt32, - nullptr, nullptr, nullptr); + cu_seqlens_q_padded_shape, DType::kInt32); + te_cu_seqlens_kv_padded = makeTransformerEngineTensor( + cu_seqlens_kv_padded.value().data_ptr(), cu_seqlens_kv_padded_shape, DType::kInt32); } // convert auxiliary tensors from forward to NVTETensors @@ -1161,11 +449,14 @@ std::vector fused_attn_bwd( nvte_tensor_pack_create(&nvte_aux_tensor_pack); nvte_aux_tensor_pack.size = Aux_CTX_Tensors.size(); for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - auto tensor = reinterpret_cast(nvte_aux_tensor_pack.tensors[i]); - tensor->data.dptr = Aux_CTX_Tensors[i].data_ptr(); std::vector tmp(Aux_CTX_Tensors[i].sizes().vec()); - tensor->data.shape = std::vector(tmp.begin(), tmp.end()); - tensor->data.dtype = GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type()); + auto temp_vec = std::vector(tmp.begin(), tmp.end()); + const NVTEShape temp_shape = {temp_vec.data(), temp_vec.size()}; + NVTEBasicTensor temp_data = { + Aux_CTX_Tensors[i].data_ptr(), + static_cast(GetTransformerEngineDType(Aux_CTX_Tensors[i].scalar_type())), + temp_shape}; + nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); } // create dBias the same shape as Bias @@ -1216,7 +507,7 @@ std::vector fused_attn_bwd( // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - return {dQ, dK, dV, dBias}; + return {py_dQ, py_dK, py_dV, py::cast(dBias)}; } namespace flash_attention { @@ -1359,64 +650,10 @@ at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) { return qkv; } -/*************************************************************************************************** - * Support THD format for Context Parallel: Binary search - **************************************************************************************************/ - -__forceinline__ __device__ int binary_search(int target, int *array, int len) { - int left = 1, right = len - 1; - while (left < right) { - int mid = (left + right) / 2; - if (array[mid] <= target) { - left = mid + 1; - } else { - right = mid; - } - } - return left - 1; -} - /*************************************************************************************************** * Support THD format for Context Parallel: Read the half of a THD tensor **************************************************************************************************/ -__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, - int hidden_size_in_bytes, int half_idx, - int dim_size_of_token) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } - __syncthreads(); - - int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; - int laneid = threadIdx.x % 32; - int num_warps = (blockDim.x * gridDim.x) / 32; - int num_total_tokens = cu_seqlens_s[batch]; - int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); - - size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; - half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); - tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); - - for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { - int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); - - size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; - float4 *cur_half_token = - reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); - - offset_in_bytes = - (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; - float4 *cur_token = - reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); - - for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { - cur_half_token[idx] = cur_token[idx]; - } - } -} - at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens, int half_idx) { NVTE_CHECK(tensor.dim() == 3 || tensor.dim() == 4); @@ -1452,8 +689,8 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s grid_y *= tensor.size(i); } dim3 grid = {grid_x, grid_y}; - thd_read_half_tensor_kernel<<>>( + transformer_engine::fused_attn::thd_read_half_tensor_kernel<<< + grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>( half.data_ptr(), tensor.data_ptr(), cu_seqlens.data_ptr(), batch, hidden_size_in_bytes, half_idx, tensor.size(seq_dim)); @@ -1464,51 +701,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s * Support THD format for Context Parallel: softmax_lse related operations **************************************************************************************************/ -template -__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, - int num_heads, int total_tokens) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - int num_total_tokens = cu_seqlens_s[batch]; - - for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { - size_t idx, half_idx; - if constexpr (lse_packed) { - idx = head_id * total_tokens + token_id + cu_seqlens_s[seq_id + 1]; - half_idx = head_id * total_tokens / 2 + token_id; - } else { - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - - idx = row * total_tokens + col + seq_len; - half_idx = row * total_tokens / 2 + col; - } - - Functor::run(lse, half_lse, idx, half_idx); - } - } -} - -struct LseCorrectionFunctor { - __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, - size_t half_idx) { - double val = lse[idx]; - float val_per_step = half_lse[half_idx]; - double max_scale = max(val, val_per_step); - double min_scale = min(val, val_per_step); - lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); - } -}; - void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens, bool lse_packed) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Double); @@ -1516,7 +708,7 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.dim() == 1); - int batch, num_heads, total_tokens; + int batch, num_heads, lse_seqlen, second_half_lse_seqlen; if (lse_packed) { NVTE_CHECK(lse.dim() == 2); @@ -1524,55 +716,51 @@ void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_st batch = cu_seqlens.size(0) - 1; num_heads = lse.size(0); - total_tokens = lse.size(1); + lse_seqlen = lse.size(1); + second_half_lse_seqlen = lse_per_step.size(1); NVTE_CHECK(lse_per_step.size(0) == num_heads); - NVTE_CHECK(lse_per_step.size(1) == total_tokens / 2); + NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2); } else { NVTE_CHECK(lse.dim() == 3); NVTE_CHECK(lse_per_step.dim() == 3); batch = lse.size(0); num_heads = lse.size(1); - total_tokens = lse.size(2); + lse_seqlen = lse.size(2); + second_half_lse_seqlen = lse_per_step.size(2); NVTE_CHECK(lse_per_step.size(0) == batch); NVTE_CHECK(lse_per_step.size(1) == num_heads); - NVTE_CHECK(lse_per_step.size(2) == total_tokens / 2); + NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2); NVTE_CHECK(cu_seqlens.size(0) == batch + 1); } constexpr unsigned int block = 256; - unsigned int grid_x = (total_tokens / 2 + block - 1) / block; + unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; + if (lse_packed) { - thd_lse_kernel + transformer_engine::fused_attn::thd_lse_kernel <<>>( lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), - batch, num_heads, total_tokens); + batch, num_heads, lse_seqlen, second_half_lse_seqlen); } else { - thd_lse_kernel + transformer_engine::fused_attn::thd_lse_kernel <<>>( lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), - batch, num_heads, total_tokens); + batch, num_heads, lse_seqlen, second_half_lse_seqlen); } } -struct ReadLseFunctor { - __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, - size_t half_idx) { - half_lse[half_idx] = lse[idx]; - } -}; - at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens, - bool lse_packed) { + bool lse_packed, int second_half_lse_seqlen) { NVTE_CHECK(lse.scalar_type() == at::ScalarType::Float); NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); NVTE_CHECK(cu_seqlens.dim() == 1); - int batch, num_heads, total_tokens; + int batch, num_heads, lse_seqlen; std::vector shape; if (lse_packed) { @@ -1580,37 +768,41 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ batch = cu_seqlens.size(0) - 1; num_heads = lse.size(0); - total_tokens = lse.size(1); + lse_seqlen = lse.size(1); + + NVTE_CHECK(second_half_lse_seqlen >= lse_seqlen / 2); - shape = {num_heads, total_tokens / 2}; + shape = {num_heads, second_half_lse_seqlen}; } else { NVTE_CHECK(lse.dim() == 3); batch = lse.size(0); num_heads = lse.size(1); - total_tokens = lse.size(2); + lse_seqlen = lse.size(2); NVTE_CHECK(cu_seqlens.size(0) == batch + 1); + NVTE_CHECK(second_half_lse_seqlen == lse_seqlen / 2); - shape = {batch, num_heads, total_tokens / 2}; + shape = {batch, num_heads, second_half_lse_seqlen}; } at::Tensor half_lse = at::zeros(shape, at::CUDA(lse.scalar_type())); constexpr unsigned int block = 256; - unsigned int grid_x = (total_tokens / 2 + block - 1) / block; + unsigned int grid_x = (lse_seqlen / 2 + block - 1) / block; unsigned int grid_y = num_heads; dim3 grid = {grid_x, grid_y}; + if (lse_packed) { - thd_lse_kernel + transformer_engine::fused_attn::thd_lse_kernel <<>>( lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, - num_heads, total_tokens); + num_heads, lse_seqlen, second_half_lse_seqlen); } else { - thd_lse_kernel + transformer_engine::fused_attn::thd_lse_kernel <<>>( lse.data_ptr(), half_lse.data_ptr(), cu_seqlens.data_ptr(), batch, - num_heads, total_tokens); + num_heads, lse_seqlen, second_half_lse_seqlen); } return half_lse; @@ -1620,59 +812,6 @@ at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_ * Support THD format for Context Parallel: Out correction in forward **************************************************************************************************/ -template -__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, - float *lse_per_step, int *cu_seqlens, int batch, - int num_heads, int dim_per_head, int lse_seqlen) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); - } - __syncthreads(); - - int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; - int lane_id = threadIdx.x % tile_size; - int num_tiles = (blockDim.x * gridDim.x) / tile_size; - int num_total_tokens = cu_seqlens_s[batch]; - int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); - - for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { - size_t idx, idx_per_step; - - if constexpr (lse_packed) { - idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; - idx_per_step = head_id * lse_seqlen / (only_second_half + 1) + token_id; - } else { - size_t row = static_cast(seq_id) * num_heads + head_id; - int col = token_id - cu_seqlens_s[seq_id]; - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - idx = row * lse_seqlen + col + seq_len * only_second_half; - idx_per_step = row * lse_seqlen / (only_second_half + 1) + col; - } - float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); - - idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; - idx = (idx * num_heads + head_id) * dim_per_head; - idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; - dtype *cur_out = out + idx; - dtype *cur_out_per_step = out_per_step + idx_per_step; - - for (int j = lane_id; j < num_loops_per_head; j += tile_size) { - float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; - float4 data = reinterpret_cast(cur_out)[j]; - dtype *p_per_step = reinterpret_cast(&data_per_step); - dtype *p = reinterpret_cast(&data); - for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { - p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp); - } - reinterpret_cast(cur_out)[j] = data; - } - } - } -} - template static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse, const at::Tensor &lse_per_step, @@ -1690,23 +829,25 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ NVTE_CHECK(out_per_step.size(1) == num_heads); NVTE_CHECK(out_per_step.size(2) == dim_per_head); - int batch, lse_seqlen; + int batch, lse_seqlen, lse_per_step_seqlen; if (lse_packed) { batch = cu_seqlens.size(0) - 1; - lse_seqlen = total_tokens; + lse_seqlen = lse.size(1); + lse_per_step_seqlen = lse_per_step.size(1); NVTE_CHECK(lse.size(0) == num_heads); - NVTE_CHECK(lse.size(1) == lse_seqlen); + NVTE_CHECK(lse_seqlen >= total_tokens); NVTE_CHECK(lse_per_step.size(0) == num_heads); - NVTE_CHECK(lse_per_step.size(1) == lse_seqlen / (only_second_half + 1)); + NVTE_CHECK(lse_per_step_seqlen >= lse_seqlen / (only_second_half + 1)); } else { batch = lse.size(0); lse_seqlen = lse.size(2); + lse_per_step_seqlen = lse_per_step.size(2); NVTE_CHECK(lse.size(1) == num_heads); NVTE_CHECK(lse_per_step.size(0) == batch); NVTE_CHECK(lse_per_step.size(1) == num_heads); - NVTE_CHECK(lse_per_step.size(2) == lse_seqlen / (only_second_half + 1)); + NVTE_CHECK(lse_per_step_seqlen == lse_seqlen / (only_second_half + 1)); NVTE_CHECK(cu_seqlens.size(0) == batch + 1); } @@ -1717,17 +858,17 @@ static void thd_out_correction_helper(at::Tensor out, const at::Tensor &out_per_ dim3 grid = {grid_x, (unsigned int)num_heads}; if (lse_packed) { - thd_out_correction_kernel + transformer_engine::fused_attn::thd_out_correction_kernel <<>>( out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, - dim_per_head, lse_seqlen); + dim_per_head, lse_seqlen, lse_per_step_seqlen); } else { - thd_out_correction_kernel + transformer_engine::fused_attn::thd_out_correction_kernel <<>>( out.data_ptr(), out_per_step.data_ptr(), lse.data_ptr(), lse_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, num_heads, - dim_per_head, lse_seqlen); + dim_per_head, lse_seqlen, lse_per_step_seqlen); } } @@ -1773,87 +914,6 @@ void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at * Support THD format for Context Parallel: Gradients correction in backward **************************************************************************************************/ -template -__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens, - int batch, int hidden_size, int dim_size_of_token) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - if constexpr (functor_idx < 2) { - cu_seqlens_s[i] = cu_seqlens[i] / 2; - } else { - cu_seqlens_s[i] = cu_seqlens[i]; - } - } - __syncthreads(); - - int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; - int lane_id = threadIdx.x % group_size; - int num_groups = (blockDim.x * gridDim.x) / group_size; - int num_total_tokens = cu_seqlens_s[batch]; - int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); - - size_t offset = static_cast(dim_size_of_token) * hidden_size; - if constexpr (functor_idx < 2) { - grad_per_step = grad_per_step + offset / 2 * blockIdx.y; - } else { - grad_per_step = grad_per_step + offset * blockIdx.y; - } - grad = grad + offset * blockIdx.y; - - for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - - int token_offset; - bool is_first_half; - if constexpr (functor_idx < 2) { - token_offset = cu_seqlens_s[seq_id + functor_idx]; - is_first_half = (functor_idx == 0); - } else { - token_offset = 0; - int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); - } - - dtype *token = &grad[(token_id + token_offset) * static_cast(hidden_size)]; - dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; - for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { - if (is_first_half) { - Functor_0::run(token, token_per_step, idx); - } else { - Functor_1::run(token, token_per_step, idx); - } - } - } -} - -struct EmptyFunctor { - __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} -}; - -struct CopyFunctor { - __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { - reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; - } -}; - -template -struct AddFunctor { - __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { - float4 d_ = reinterpret_cast(token)[idx]; - dtype *p_ = reinterpret_cast(&d_); - - float4 d = reinterpret_cast(token_per_step)[idx]; - dtype *p = reinterpret_cast(&d); - -#pragma unroll - for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { - p_[i] += p[i]; - } - - reinterpret_cast(token)[idx] = d_; - } -}; - template static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_per_step, const at::Tensor &cu_seqlens) { @@ -1894,7 +954,8 @@ static void thd_grad_correction_helper(at::Tensor grad, const at::Tensor &grad_p } dim3 grid = {grid_x, grid_y}; - thd_grad_correction_kernel + transformer_engine::fused_attn::thd_grad_correction_kernel <<>>( grad.data_ptr(), grad_per_step.data_ptr(), cu_seqlens.data_ptr(), batch, hidden_size, total_tokens); @@ -1945,31 +1006,6 @@ void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step, * Support THD format for Context Parallel: Generate partitioned indices for input tokens **************************************************************************************************/ -__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, - int total_tokens, int world_size, int rank) { - extern __shared__ int cu_seqlens_s[]; - for (int i = threadIdx.x; i <= batch; i += blockDim.x) { - int seqlen = cu_seqlens[i]; - // Currently we assume that each sequence length is divisible by (world_size*2) since we have - // to distribute each sequence evenly to different GPUs. - assert(seqlen % (world_size * 2) == 0); - cu_seqlens_s[i] = seqlen / world_size; - } - __syncthreads(); - - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int num_threads = blockDim.x * gridDim.x; - - for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { - int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); - int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; - int index = token_id - cu_seqlens_s[seq_id]; - int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; - index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; - output[token_id] = index; - } -} - at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens, int world_size, int rank) { NVTE_CHECK(cu_seqlens.scalar_type() == at::ScalarType::Int); @@ -1986,9 +1022,180 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t constexpr unsigned int block = 256; unsigned int grid = (output.size(0) + block - 1) / block; - thd_partition_indices_kernel<<>>( + transformer_engine::fused_attn::thd_partition_indices_kernel<<< + grid, block, sizeof(int) * (batch + 1), at::cuda::getCurrentCUDAStream()>>>( output.data_ptr(), cu_seqlens.data_ptr(), batch, total_tokens, world_size, rank); return output; } + +/*************************************************************************************************** + * KV Cache: Convert a tensor from qkv_format = thd to qkv_format = bshd + **************************************************************************************************/ + +template +void convert_thd_to_bshd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens, + int b, int max_seq_len, int h, int d) { + transformer_engine::fused_attn:: + convert_thd_to_bshd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(tensor.data_ptr()), + reinterpret_cast(new_tensor.data_ptr()), cu_seqlens.data_ptr(), + b, max_seq_len, h, d); +} + +at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len) { + int h = tensor.size(1); + int d = tensor.size(2); + std::vector shape = {b, max_seq_len, h, d}; + at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); + if (new_tensor.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + convert_thd_to_bshd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (new_tensor.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + convert_thd_to_bshd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (new_tensor.scalar_type() == at::ScalarType::Float) { + using dtype = float; + convert_thd_to_bshd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (new_tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) { + using dtype = at::Float8_e4m3fn; + convert_thd_to_bshd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (new_tensor.scalar_type() == at::ScalarType::Float8_e5m2) { + using dtype = at::Float8_e5m2; + convert_thd_to_bshd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else { + NVTE_ERROR("Unsupported dtype for KV cache.\n"); + } + return new_tensor; +} + +/*************************************************************************************************** + * KV Cache: Convert a tensor from qkv_format = bshd to qkv_format = thd + **************************************************************************************************/ + +template +void convert_bshd_to_thd_launcher(at::Tensor tensor, at::Tensor new_tensor, at::Tensor cu_seqlens, + int b, int max_seq_len, int h, int d) { + transformer_engine::fused_attn:: + convert_bshd_to_thd_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(tensor.data_ptr()), + reinterpret_cast(new_tensor.data_ptr()), cu_seqlens.data_ptr(), + b, max_seq_len, h, d); +} + +at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) { + int b = tensor.size(0); + int max_seq_len = tensor.size(1); + int h = tensor.size(2); + int d = tensor.size(3); + std::vector shape = {t, h, d}; + at::Tensor new_tensor = at::zeros(shape, at::CUDA(tensor.scalar_type())); + if (tensor.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + convert_bshd_to_thd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (tensor.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + convert_bshd_to_thd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (tensor.scalar_type() == at::ScalarType::Float) { + using dtype = float; + convert_bshd_to_thd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (tensor.scalar_type() == at::ScalarType::Float8_e4m3fn) { + using dtype = at::Float8_e4m3fn; + convert_bshd_to_thd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else if (tensor.scalar_type() == at::ScalarType::Float8_e5m2) { + using dtype = at::Float8_e5m2; + convert_bshd_to_thd_launcher(tensor, new_tensor, cu_seqlens, b, max_seq_len, h, d); + } else { + NVTE_ERROR("Unsupported dtype for KV cache.\n"); + } + return new_tensor; +} + +/*************************************************************************************************** + * KV Cache: Copy new KV tokens to the KV cache + * 1. new_k and new_v are in qkv_format; k_cache and v_cache are in 'bshd' format + * 2. cu_new_lens and cu_cached_lens are in shape [b + 1]; cu_cached_lens include the added lens + * in current step + * 3. Non-paged KV cache is a special case of paged KV cache, with page_table = [b, 1] and + * max_pages_per_seq = 1. We use the same underlying kernel for both non-paged and paged. + * Set is_non_paged = True/False to indicate as such. + * 4. is_non_paged = True also re-indexes the KV cache, e.g. the initial batch indices [0, 3, 1, 2] + * becomes [0, 1, 1, 2]. The page_table = batch_indices.unsqueeze(1) is however unchanged. + * batch_indices_post can be used for monotonical indexing, i.e. [0, 1, 2, 3]. batch_indices is + * preserved for the next layer in the same iteration. + * 5. Only supports same page_table for k_cache and v_cache + * 6. Only pad_between_seqs = False when qkv_format = thd, i.e. there should be no pad tokens + * between sequences in new_k and new_v such as [a a a 0..0 b b 0..0 c 0..0]. + **************************************************************************************************/ + +template +void copy_to_kv_cache_launcher(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, + at::Tensor v_cache, at::Tensor page_table, at::Tensor cu_new_lens, + at::Tensor cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, + int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, + int max_pages_per_seq, bool is_non_paged) { + if (new_k.data_ptr() != nullptr && new_v.data_ptr() != nullptr && k_cache.data_ptr() != nullptr && + v_cache.data_ptr() != nullptr) { + if (is_non_paged) { + transformer_engine::fused_attn:: + reindex_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + page_table.data_ptr(), cu_new_lens.data_ptr(), + cu_cached_lens.data_ptr(), h_kv, d_k, d_v, b, max_seq_len); + } + transformer_engine::fused_attn:: + copy_to_kv_cache_kernel<<<16, 256, 0, at::cuda::getCurrentCUDAStream()>>>( + reinterpret_cast(new_k.data_ptr()), + reinterpret_cast(new_v.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), page_table.data_ptr(), + cu_new_lens.data_ptr(), cu_cached_lens.data_ptr(), qkv_format, h_kv, d_k, d_v, + b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); + } +} + +void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache, + at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens, + NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len, + int max_pages_per_seq, bool is_non_paged) { + int h_kv = new_k.size(-2); + int d_k = new_k.size(-1); + int d_v = new_v.size(-1); + NVTE_CHECK(k_cache.scalar_type() == v_cache.scalar_type() && + new_k.scalar_type() == new_v.scalar_type() && + new_k.scalar_type() == k_cache.scalar_type(), + "new_k, new_v, k_cache and v_cache must be of the same data type."); + NVTE_CHECK(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD || + qkv_format == NVTE_QKV_Format::NVTE_THD, + "qkv_format must be {BSHD, SBHD, THD}."); + if (k_cache.scalar_type() == at::ScalarType::Half) { + using dtype = at::Half; + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, + cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + + } else if (k_cache.scalar_type() == at::ScalarType::BFloat16) { + using dtype = at::BFloat16; + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, + cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + } else if (k_cache.scalar_type() == at::ScalarType::Float) { + using dtype = float; + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, + cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + } else if (k_cache.scalar_type() == at::ScalarType::Float8_e4m3fn) { + using dtype = at::Float8_e4m3fn; + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, + cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + } else if (k_cache.scalar_type() == at::ScalarType::Float8_e5m2) { + using dtype = at::Float8_e5m2; + copy_to_kv_cache_launcher(new_k, new_v, k_cache, v_cache, page_table, cu_new_lens, + cu_cached_lens, qkv_format, h_kv, d_k, d_v, b, max_ctx_len, + max_seq_len, max_pages_per_seq, is_non_paged); + } else { + NVTE_ERROR("Unsupported dtype for KV cache.\n"); + } +} diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp new file mode 100644 index 0000000000..5ff10f6efb --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -0,0 +1,74 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common.h" +#include "pybind.h" +#include "transformer_engine/cast.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine::pytorch { + +std::vector bgrad_quantize(const at::Tensor& input, py::handle py_quantizer) { + auto quantizer = convert_quantizer(py_quantizer); + + auto input_tensor = makeTransformerEngineTensor(input); + + auto dbias = allocateTorchTensor(input.size(-1), input_tensor.dtype()); + + std::vector output_shape; + for (auto s : input.sizes()) { + output_shape.emplace_back(static_cast(s)); + } + auto [out_tensor, out] = quantizer->create_tensor(output_shape, input_tensor.dtype()); + + // Return immediately if tensors are empty + if (product(output_shape) == 0) { + return {py::cast(dbias.zero_()), out}; + } + + auto dbias_tensor = makeTransformerEngineTensor(dbias); + // Query workspace size and allocate workspace + transformer_engine::TensorWrapper workspace; + nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(), + at::cuda::getCurrentCUDAStream()); + + void* workspace_data_ptr = nullptr; + if (workspace.shape().ndim > 0) { + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace_data_ptr = workspace_data.data_ptr(); + } + workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype()); + + // Launch kernel + if (detail::IsFloat8CurrentScalingQuantizers(py_quantizer.ptr())) { + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(quantizer.get()); + nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream()); + // check if we need to do amax reudction (depending on model parallel configs) + if (my_quantizer_cs->with_amax_reduction) { + c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; + // construct torch tesnor from NVTEBasicTensor without reallocating memory + at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + std::vector tensors = {amax_tensor_torch}; + // allreduce amax tensor + c10d::AllreduceOptions allreduce_opts; + allreduce_opts.reduceOp = c10d::ReduceOp::MAX; + process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(out_tensor.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + out_tensor.set_amax(nullptr, DType::kFloat32, out_tensor.defaultShape); + } + nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), workspace.data(), + at::cuda::getCurrentCUDAStream()); + + return {py::cast(dbias), out}; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 47f5825866..2c3ccff154 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1,72 +1,152 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +#include "transformer_engine/cast.h" + +#include "common.h" #include "extensions.h" +#include "pybind.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine::pytorch { + +py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output, + std::optional noop) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + auto input_tensor = tensor.contiguous(); + + const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); + const auto& te_input_shape = te_input.shape(); + std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); + auto fake_tensor_type = tensor.scalar_type(); + if (!detail::IsFloatingPointType(fake_tensor_type)) { + fake_tensor_type = at::kFloat; + } + + TensorWrapper te_output; + py::object out; + if (output.is_none()) { + DType fake_te_type = GetTransformerEngineDType(fake_tensor_type); + std::tie(te_output, out) = my_quantizer->create_tensor(input_shape, fake_te_type); + } else { + out = output; + te_output = makeTransformerEngineTensor(output, quantizer); + } + + TensorWrapper te_noop; + if (noop.has_value()) { + te_noop = makeTransformerEngineTensor(*noop); + } else { + te_noop = TensorWrapper(); + } + + if (te_output.numel() == 0) return out; + + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // my_quantizer here has to be a Float8CurrentScalingQuantizer + auto my_quantizer_cs = static_cast(my_quantizer.get()); + nvte_compute_amax(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + // check if we need to do amax reudction (depending on model parallel configs) + if (my_quantizer_cs->with_amax_reduction) { + c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; + // construct torch tesnor from NVTEBasicTensor without reallocating memory + at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; + std::vector tensors = {amax_tensor_torch}; + // allreduce amax tensor + c10d::AllreduceOptions allreduce_opts; + allreduce_opts.reduceOp = c10d::ReduceOp::MAX; + process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); + } + QuantizationConfigWrapper quant_config; + quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); + quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); + nvte_compute_scale_from_amax(te_output.data(), quant_config, at::cuda::getCurrentCUDAStream()); + // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel + te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); + } + nvte_quantize_noop(te_input.data(), te_output.data(), te_noop.data(), + at::cuda::getCurrentCUDAStream()); + + return out; +} + +py::object dequantize(const py::handle& input, transformer_engine::DType otype) { + init_extension(); -at::Tensor cast_to_fp8(const at::Tensor& input, const at::Tensor& scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset, const int amax_offset, const int scale_inv_offset) { - using namespace transformer_engine; - auto input_shape = input.sizes().vec(); - std::vector shape{input_shape.begin(), input_shape.end()}; + const auto none = py::none(); - auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); + const auto& input_tensor = makeTransformerEngineTensor(input, none); - if (input.numel() == 0) return output; + NoneQuantizer q(none); - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + const auto& shape = convertShape(input_tensor.shape()); - auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax_dptr, - scale_dptr, scale_inv_dptr); + auto [out_tensor, out] = q.create_tensor(shape, otype); - nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + nvte_dequantize(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream()); - return output; + return out; } -void cast_to_fp8_noalloc(const at::Tensor& input, const at::Tensor& scale, at::Tensor output, - at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype, - const int scale_offset, const int amax_offset, - const int scale_inv_offset) { - using namespace transformer_engine; - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); +template +std::vector dbias_dact(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + init_extension(); + auto my_quantizer = convert_quantizer(quantizer); + + auto grad_tensor = makeTransformerEngineTensor(grad_output); + + auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype()); + auto act_input_tensor = makeTransformerEngineTensor(act_input); + + const auto& shape = convertShape(grad_tensor.shape()); + auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype()); - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); + auto dbias_tensor = makeTransformerEngineTensor(grad_bias); - auto input_cu = makeTransformerEngineTensor(input); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); + // Query workspace size and allocate workspace + transformer_engine::TensorWrapper workspace; + func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), + workspace.data(), at::cuda::getCurrentCUDAStream()); + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - nvte_fp8_quantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + // Launch kernel + func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), + workspace.data(), at::cuda::getCurrentCUDAStream()); - return; + return {py::cast(grad_bias), dact}; } -at::Tensor cast_from_fp8(const at::Tensor& input, const at::Tensor& scale_inv, - transformer_engine::DType itype, transformer_engine::DType otype, - const int scale_inv_offset) { - using namespace transformer_engine; - auto input_shape = input.sizes().vec(); - std::vector shape{input_shape.begin(), input_shape.end()}; +std::vector dbias_dgelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); +std::vector dbias_dsilu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), shape, itype, nullptr, nullptr, - getDataPtr(scale_inv, scale_inv_offset)); - auto output_cu = makeTransformerEngineTensor(output); +std::vector dbias_drelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - nvte_fp8_dequantize(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); +std::vector dbias_dqgelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); +} - return output; +std::vector dbias_dsrelu(const at::Tensor& grad_output, const at::Tensor& act_input, + py::handle quantizer) { + return dbias_dact(grad_output, act_input, quantizer); } + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index d212d13516..6d05869c36 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -1,10 +1,11 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include "../extensions.h" +#include "transformer_engine/transformer_engine.h" #define HALF_BYTES 2 #define UB_MAX_SM 32 @@ -14,50 +15,6 @@ using namespace std::placeholders; namespace te = transformer_engine; -#define MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inv, A_fp8_index, A_type, B, B_scale_inv, \ - B_fp8_index, B_type, D, D_amax, D_scale, D_type, bias, \ - bias_type, pre_gelu_out, workspace) \ - A = A.contiguous(); \ - void *A_scale_inv_ptr = nullptr; \ - if (te::is_fp8_dtype(A_type)) { \ - assert(A_scale_inv.numel()); \ - A_scale_inv_ptr = A_scale_inv[A_fp8_index].data_ptr(); \ - } \ - auto A_ = makeTransformerEngineTensor( \ - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, \ - nullptr, nullptr, A_scale_inv_ptr); \ - B = B.contiguous(); \ - void *B_scale_inv_ptr = nullptr; \ - if (te::is_fp8_dtype(B_type)) { \ - assert(B_scale_inv.numel()); \ - B_scale_inv_ptr = B_scale_inv[B_fp8_index].data_ptr(); \ - } \ - auto B_ = makeTransformerEngineTensor( \ - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, \ - nullptr, nullptr, B_scale_inv_ptr); \ - void *D_amax_ptr = nullptr; \ - void *D_scale_ptr = nullptr; \ - if (te::is_fp8_dtype(D_type)) { \ - assert(D_amax.numel()); \ - D_amax_ptr = D_amax.data_ptr(); \ - assert(D_scale.numel()); \ - D_scale_ptr = D_scale.data_ptr(); \ - } \ - auto D_ = makeTransformerEngineTensor( \ - D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, \ - D_amax_ptr, D_scale_ptr, nullptr); \ - auto bias_ = makeTransformerEngineTensor( \ - bias.data_ptr(), std::vector{static_cast(bias.size(0))}, bias_type); \ - const auto gelu_shape = (pre_gelu_out.data_ptr() == nullptr) \ - ? std::vector{static_cast(pre_gelu_out.size(0))} \ - : std::vector{static_cast(pre_gelu_out.size(0)), \ - static_cast(pre_gelu_out.size(1))}; \ - auto pre_gelu_out_ = makeTransformerEngineTensor( \ - pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); \ - auto workspace_ = makeTransformerEngineTensor( \ - workspace.data_ptr(), std::vector{static_cast(workspace.size(0))}, \ - te::DType::kByte); - /*************************************************************************************************** * CommOverlapHelper **************************************************************************************************/ @@ -69,8 +26,7 @@ CommOverlapHelper::CommOverlapHelper() { } // empty constructor for NVTE_UB_WITH_MPI=1 CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, - std::optional intra_domain_group, - std::optional inter_domain_group) { + std::optional intra_domain_group) { #ifndef NVTE_UB_WITH_MPI pgs.insert({"world", world_group}); myrank = pgs["world"]->getRank(); @@ -96,20 +52,9 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, mynode = 0; numnodes = 1; } else { - // Intra-node group is different than the world group so there must be multiple nodes - NVTE_CHECK( - inter_domain_group.has_value(), - "Internal TE error: Inter-node group cannot be `None` when intra-node group is not ", - "identical to the world_group!"); - // Get node ID and number of nodes - NVTE_CHECK( - inter_domain_group.value()->getBackendType() == backend, - "Internal TE error: Inter-node group must be on the same backend (%s) as the world ", - "group!", pgs["world"]->getBackendName()); - pgs.insert({"inter", inter_domain_group.value()}); - mynode = pgs["inter"]->getRank(); - numnodes = pgs["inter"]->getSize(); + mynode = myrank / numlocal; + numnodes = numranks / numlocal; } } else { // Intra-node group is not set so we assume there is only 1 node @@ -185,145 +130,92 @@ void CommOverlapHelper::ub_barrier(ExtComm group) { CommOverlap::CommOverlap(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, int num_splits, - int num_max_streams, int comm_cga_size, int num_comm_sm, - bool set_sm_margin, bool atomic_gemm) - : te::CommOverlapBase(buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, - helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, - helper->numnodes, tp_size, + int num_max_streams, int comm_cga_size, int gemm_priority, + int comm_priority, int num_comm_sm, bool set_sm_margin, bool atomic_gemm, + bool rs_overlap_first_gemm) + : te::CommOverlapBase(buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), + helper->myrank, helper->numranks, helper->mylocal, helper->numlocal, + helper->mynode, helper->numnodes, tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits, - num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) { - // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to - // for PyTorch to factor externally allocated memory into its memory pool and garbage collection - // threshold calculation. - _ubuf_torch = torch::from_blob( - _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, - at::device(torch::kCUDA).dtype(buffer_dtype)); - if (_atomic_gemm) { - _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, - at::device(torch::kCUDA).dtype(torch::kInt32)); - } -} - -/* -** Bulk GEMM + COMM -** This function assumes the communication input is pre-copied to _ubuf -*/ -std::vector CommOverlap::bulk_overlap( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, - te::CommOverlapType comm_type, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::bulk_overlap(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, workspace_, - grad, accumulate, use_split_accumulator, comm_type, rs_out_, - stream_main); - - // Get the current userbuf offset - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - if (comm_type == te::CommOverlapType::RS) { - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - } - - // Generate output tensor from userbuf data pointer - int output_c_dim0 = - (comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - auto output_tensor = - torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); + num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, + set_sm_margin, atomic_gemm, rs_overlap_first_gemm) {} - return {D, output_tensor}; -} // CommOverlap::bulk_overlap - -/* -** Split FPROP GEMM + ReduceScatter -*/ -void CommOverlap::atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - gemm_overlap, rs_out_, stream_main); -} // CommOverlap::split_overlap_rs - -/* -** Split FPROP GEMM + ReduceScatter -*/ -void CommOverlap::split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - te::DType A_type, bool transa, at::Tensor B, - at::Tensor B_scale_inverse, int64_t B_fp8_tensor, - te::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, bool gemm_overlap, - at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - gemm_overlap, rs_out_, stream_main); -} // CommOverlap::split_overlap_rs +void CommOverlap::set_buffer_params(py::handle quantizer) { + std::unique_ptr my_quantizer = te::pytorch::convert_quantizer(quantizer); + my_quantizer->set_quantization_params(&_ubuf); + _ubuf_scale_inv_initialized = true; +} /* ** Helper function to copy input to _ubuf */ -void CommOverlap::copy_input_to_ubuf(torch::Tensor input, int comm_type) { +void CommOverlap::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) { + auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer); + auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr(); + NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!"); + char *ubuf_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type == te::CommOverlapType::AG) { - if ((input.numel() * _tp_size) != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + if (local_chunk) { + if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the local communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + ubuf_ptr += (_ubuf.numel() / _tp_size) * _tp_id * _ubuf.element_size(); } else { - if (input.numel() != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } + if (input_tensor.numel() > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the global communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); } + // Copy either row or columnwise data into the communication buffer's columnwise data + // NOTE: _ubuf.columnwise_dptr() is not a valid copy target because it is not registered with + // the Userbuffers communicator. at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); NVTE_CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0)); - NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(), + NVTE_CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input_tensor.dptr(), + input_tensor.numel() * input_tensor.element_size(), cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm)); } -torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { +py::object CommOverlap::get_buffer(py::handle quantizer, bool local_chunk, + std::optional> shape) { + using namespace te::pytorch; char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) - NVTE_ERROR("Invalid comm_type"); - if (_comm_type == te::CommOverlapType::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); - int output_c_dim0 = - (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, - torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype()))); + if (local_chunk) ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size(); + + std::vector torch_shape; + if (shape.has_value()) { + torch_shape = shape.value(); + auto requested = product(torch_shape); + auto expected = local_chunk ? _ubuf.numel() / _tp_size : _ubuf.numel(); + NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, + ") does not match allocated buffer size (", expected, ")!"); + } else { + int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0); + int64_t output_c_dim1 = _ubuf.size(1); + torch_shape = {output_c_dim0, output_c_dim1}; + } + + auto ubuf_tensor = torch::from_blob(reinterpret_cast(ubuf_wt_ptr), torch_shape, + at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA)); + + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + std::vector te_shape; + for (auto s : torch_shape) te_shape.emplace_back(static_cast(s)); + + // Always output a rowwise-only QuantizedTensor + // TODO (Alp): This needs to produce an un-interleaved transpose when required. + auto is_internal = my_quantizer->internal; + auto uses_columnwise = my_quantizer->columnwise_usage; + my_quantizer->internal = false; + my_quantizer->columnwise_usage = false; + auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor); + my_quantizer->internal = is_internal; + my_quantizer->columnwise_usage = uses_columnwise; + return py_tensor; } /*************************************************************************************************** @@ -333,148 +225,85 @@ torch::Tensor CommOverlap::get_ubuf_output(int comm_type) { CommOverlapP2P::CommOverlapP2P(const std::vector &buffer_shape, at::ScalarType buffer_dtype, CommOverlapHelper *helper, int tp_size, te::CommOverlapType comm_type, int num_max_streams, - int comm_cga_size, int num_comm_sm, bool set_sm_margin, - bool atomic_gemm, bool use_ce, bool aggregate) + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool atomic_gemm, bool use_ce, + bool aggregate) : te::CommOverlapP2PBase( - buffer_shape, GetTransformerEngineDType(buffer_dtype), helper->myrank, helper->numranks, - helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size, - std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), + buffer_shape, te::pytorch::GetTransformerEngineDType(buffer_dtype), helper->myrank, + helper->numranks, helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, + tp_size, std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5), std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams, - comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) { - // Even though we never use these PyTorch tensor wrappers directly, they're still necessary to - // for PyTorch to factor externally allocated memory into its memory pool and garbage collection - // threshold calculation. - _ubuf_torch = torch::from_blob( - _ubuf.dptr(), {static_cast(_ubuf.size(0)), static_cast(_ubuf.size(1))}, - at::device(torch::kCUDA).dtype(buffer_dtype)); - if (_atomic_gemm) { - _ubuf_counter = torch::from_blob(_counter.dptr(), {static_cast(_num_splits * 2)}, - at::device(torch::kCUDA).dtype(torch::kInt32)); - } -} - -/* -** Split AllGather + AtomicGEMM using P2P communication -** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is -*needed to have AG outputs -** in each rank to be in the contiguous memory space after all ring exchange -*phases. -*/ -void CommOverlapP2P::atomic_gemm_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto B_copy_ = makeTransformerEngineTensor(B_copy); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::atomic_gemm_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, - use_split_accumulator, B_copy_, stream_main); -} // atomic_gemm_overlap_ag - -/* -** Split AllGather + GEMM using P2P communication -** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is -*needed to have AG outputs -** in each rank to be in the contiguous memory space after all ring exchange -*phases. -*/ -void CommOverlapP2P::split_overlap_ag( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor B_copy) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto B_copy_ = makeTransformerEngineTensor(B_copy); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::split_overlap_ag(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - B_copy_, stream_main); -} // split_overlap_ag + comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, + atomic_gemm, aggregate) {} -/* -** Split ReduceScatter + GEMM using P2P communication -*/ -void CommOverlapP2P::atomic_gemm_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::atomic_gemm_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, - use_split_accumulator, rs_out_, stream_main); -} - -/* -** Split ReduceScatter + GEMM using P2P communication -*/ -void CommOverlapP2P::split_overlap_rs( - at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, te::DType A_type, bool transa, - at::Tensor B, at::Tensor B_scale_inverse, int64_t B_fp8_tensor, te::DType B_type, bool transb, - at::Tensor D, at::Tensor D_scale, te::DType D_type, at::Tensor D_amax, at::Tensor bias, - te::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, - size_t workspaceSize, bool accumulate, bool use_split_accumulator, at::Tensor rs_output) { - MAKE_TRANSFORMER_ENGINE_TENSORS(A, A_scale_inverse, A_fp8_tensor, A_type, B, B_scale_inverse, - B_fp8_tensor, B_type, D, D_amax, D_scale, D_type, bias, bias_type, - pre_gelu_out, workspace) - - auto rs_out_ = makeTransformerEngineTensor(rs_output); - cudaStream_t stream_main = static_cast(at::cuda::getCurrentCUDAStream()); - te::CommOverlapP2PBase::split_overlap_rs(A_, transa, B_, transb, D_, bias_, pre_gelu_out_, - workspace_, grad, accumulate, use_split_accumulator, - rs_out_, stream_main); +void CommOverlapP2P::set_buffer_params(py::handle quantizer) { + std::unique_ptr my_quantizer = te::pytorch::convert_quantizer(quantizer); + my_quantizer->set_quantization_params(&_ubuf); + for (size_t i = 0; i < _ubufs.size(); i++) my_quantizer->set_quantization_params(&_ubufs[i]); } /* ** Copy input to _ubufs[0] */ -void CommOverlapP2P::copy_input_to_ubuf(torch::Tensor input, bool chunk) { +void CommOverlapP2P::copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk) { + auto input_tensor = te::pytorch::makeTransformerEngineTensor(input, quantizer); + auto input_ptr = input_tensor.dptr() ? input_tensor.dptr() : input_tensor.columnwise_dptr(); + NVTE_CHECK(input_ptr, "Input tensor does not have rowwise or columnwise data!"); + at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream(); - if (chunk) { + if (local_chunk) { // Copy input to the target ubuf chunk by rank offset - if (input.numel() != (int64_t)_ubufs[0].numel() || - input.element_size() != (int64_t)_ubufs[0].element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + if (input_tensor.numel() * _tp_size > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the local communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].dptr(), input_ptr, + input_tensor.numel() * input_tensor.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); + } else { - if (input.numel() != (int64_t)_ubuf.numel() || - input.element_size() != (int64_t)_ubuf.element_size()) { - NVTE_ERROR("input and ubuf size do not match!"); - } - NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input.data_ptr(), - input.numel() * input.element_size(), cudaMemcpyDeviceToDevice, - (cudaStream_t)stream_main)); + if (input_tensor.numel() > (int64_t)_ubuf.numel()) + NVTE_ERROR("input is larger than the global communication buffer!"); + if (input_tensor.element_size() != (int64_t)_ubuf.element_size()) + NVTE_ERROR("input data type does not match communication buffer!"); + NVTE_CHECK_CUDA(cudaMemcpyAsync(_ubuf.dptr(), input_ptr, + input_tensor.numel() * input_tensor.element_size(), + cudaMemcpyDeviceToDevice, (cudaStream_t)stream_main)); } } -torch::Tensor CommOverlapP2P::get_ubuf_output(int comm_type) { - char *ubuf_wt_ptr = reinterpret_cast(_ubuf.dptr()); - te::CommOverlapType _comm_type = static_cast(comm_type); - if (_comm_type != te::CommOverlapType::AG && _comm_type != te::CommOverlapType::RS) - NVTE_ERROR("Invalid comm_type"); - if (_comm_type == te::CommOverlapType::RS) - ubuf_wt_ptr += _ubuf.numel() / _tp_size * _self_chunk_id * _ubuf.element_size(); - int output_c_dim0 = - (_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size; - int output_c_dim1 = _ubuf.size(1); - return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options()); +py::object CommOverlapP2P::get_buffer(py::handle quantizer, bool local_chunk, + std::optional> shape) { + using namespace te::pytorch; + char *ubuf_wt_ptr = reinterpret_cast(local_chunk ? _ubufs[_tp_id].dptr() : _ubuf.dptr()); + + std::vector torch_shape; + if (shape.has_value()) { + torch_shape = shape.value(); + auto requested = product(torch_shape); + auto expected = local_chunk ? _ubufs[_tp_id].numel() : _ubuf.numel(); + NVTE_CHECK(requested == expected, "Number of elements in the requested shape (", requested, + ") does not match allocated buffer size (", expected, ")!"); + } else { + int64_t output_c_dim0 = (local_chunk) ? _ubuf.size(0) / _tp_size : _ubuf.size(0); + int64_t output_c_dim1 = _ubuf.size(1); + torch_shape = {output_c_dim0, output_c_dim1}; + } + auto ubuf_tensor = torch::from_blob(reinterpret_cast(ubuf_wt_ptr), torch_shape, + at::dtype(GetATenDType(_ubuf.dtype())).device(torch::kCUDA)); + + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + std::vector te_shape; + for (auto s : torch_shape) te_shape.emplace_back(static_cast(s)); + + // Always output a rowwise-only QuantizedTensor + // TODO (Alp): This needs to produce an un-interleaved transpose when required. + auto is_internal = my_quantizer->internal; + auto uses_columnwise = my_quantizer->columnwise_usage; + my_quantizer->internal = false; + my_quantizer->columnwise_usage = false; + auto [te_tensor, py_tensor] = my_quantizer->create_tensor(te_shape, _ubuf.dtype(), ubuf_tensor); + my_quantizer->internal = is_internal; + my_quantizer->columnwise_usage = uses_columnwise; + return py_tensor; } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 40b96a057f..53fed04735 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -1,77 +1,275 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +#include +#include + +#include +#include + +#include "../common.h" +#include "common.h" #include "common/util/cuda_runtime.h" +#include "common/util/system.h" #include "extensions.h" +#include "pybind.h" +#include "transformer_engine/transformer_engine.h" -void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count) { - using namespace transformer_engine; - if (A.numel() == 0 || B.numel() == 0) { - if (D.numel() != 0 && !accumulate) D.zero_(); - if (bias.numel() != 0 && grad) { - if (B.numel() == 0) { - bias.zero_(); - } else { - bias.copy_(B.sum(0)); +namespace { + +void* get_data_ptr(MaybeTensor tensor) { + if (tensor.has_value()) return tensor->data_ptr(); + return nullptr; +} + +size_t get_size(MaybeTensor tensor, int dim) { + if (tensor.has_value()) return static_cast(tensor->size(dim)); + return 0; +} + +} // namespace + +namespace transformer_engine::pytorch { + +namespace detail { + +std::vector getGemmOutputShape(const NVTEShape& A_shape, const bool transa, + const NVTEShape& B_shape, const bool transb) { + // Flatten outer dims to get 2D matrices + const size_t A0 = product(A_shape, 0, A_shape.ndim - 1); + const size_t A1 = A_shape.data[A_shape.ndim - 1]; + const size_t B0 = product(B_shape, 0, B_shape.ndim - 1); + const size_t B1 = B_shape.data[B_shape.ndim - 1]; + + // Check matrix dims + NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(", + A0, ",", A1, "), transa=", transa, ", B=(", B0, ",", B1, "), transb=", transb, ")"); + + // Construct output dims + std::vector ret; + if (transb) { + ret.emplace_back(B1); + } else { + // Unflatten B0 + for (size_t i = 0; i < B_shape.ndim - 1; ++i) { + ret.emplace_back(B_shape.data[i]); + } + } + if (transa) { + ret.emplace_back(A0); + } else { + ret.emplace_back(A1); + } + return ret; +} + +bool checkGemmShape(const std::vector& expected, const NVTEShape& actual) { + if (expected.size() != actual.ndim) return false; + for (size_t i = 0; i < expected.size(); ++i) { + if (expected[i] != actual.data[i]) return false; + } + return true; +} + +} // namespace detail + +std::pair createOutputTensor(const std::vector& shape, + DType dtype, py::handle quantizer) { + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape, dtype); +} + +std::vector gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D, + py::handle quantizer, std::optional out_dtype, MaybeTensor bias, + DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad, + at::Tensor workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, CommOverlapCore* comm_overlap, + std::optional comm_type, MaybeTensor extra_output, + bool bulk_overlap) { + // Input tensors + NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); + NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); + auto none = py::none(); + TensorWrapper A_tensor = makeTransformerEngineTensor(A, none); + TensorWrapper B_tensor = makeTransformerEngineTensor(B, none); + + // Check tensor dimensions + const auto& A_shape = A_tensor.shape(); + const auto& B_shape = B_tensor.shape(); + const auto& D_shape = detail::getGemmOutputShape(A_shape, transa, B_shape, transb); + NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); + NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); + + // Output tensor + TensorWrapper D_tensor; + if (D.is_none()) { + DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); + std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); + } else { + D_tensor = makeTransformerEngineTensor(D, quantizer); + NVTE_CHECK(detail::checkGemmShape(D_shape, D_tensor.shape()), + "GEMM output has invalid dims (expected ", std::to_string(D_shape), ", got ", + std::to_string(D_tensor.shape()), ")"); + if (out_dtype) { + NVTE_CHECK(*out_dtype == D_tensor.dtype(), "GEMM output has invalid dtype (expected ", + static_cast(*out_dtype), ", found ", static_cast(D_tensor.dtype()), ")"); + } + } + + // Bias tensor + TensorWrapper bias_tensor; + MaybeTensor bias_grad = std::nullopt; + if (bias.has_value()) { + if (grad) { + auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); + bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); + bias_tensor = makeTransformerEngineTensor(*bias_grad); + } else { + if (!bias->is_contiguous()) { + bias = bias->contiguous(); } + bias_tensor = makeTransformerEngineTensor(*bias); } - if (pre_gelu_out.numel() != 0) pre_gelu_out.zero_(); - return; } - A = A.contiguous(); - B = B.contiguous(); + // Activation input tensor + MaybeTensor pre_gelu_out = std::nullopt; + DType gelu_type = bias_type; + if (gelu) { + if (!grad) { + auto dtype = GetATenDType(gelu_type); + auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); + std::vector torch_shape; + for (auto v : D_shape) { + torch_shape.push_back(v); + } + pre_gelu_out = at::empty(torch_shape, opts); + } else { + if (gelu_in.has_value()) { + pre_gelu_out = *gelu_in; + } + } + } + const auto gelu_shape = gelu ? D_shape : std::vector{0}; - auto te_A = makeTransformerEngineTensor( - A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr()); - auto te_B = makeTransformerEngineTensor( - B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr()); - auto te_D = makeTransformerEngineTensor( - D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, - D_amax.data_ptr(), D_scale.data_ptr(), nullptr); - auto te_bias = - makeTransformerEngineTensor(bias.data_ptr(), {static_cast(bias.size(0))}, bias_type); + auto te_pre_gelu_out = + makeTransformerEngineTensor(get_data_ptr(pre_gelu_out), gelu_shape, gelu_type); - const auto gelu_shape = pre_gelu_out.data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out.size(0))} - : std::vector{static_cast(pre_gelu_out.size(0)), - static_cast(pre_gelu_out.size(1))}; - auto te_pre_gelu_out = makeTransformerEngineTensor( - pre_gelu_out.data_ptr(), gelu_shape, GetTransformerEngineDType(pre_gelu_out.scalar_type())); + // Workspace auto te_workspace = makeTransformerEngineTensor(workspace.data_ptr(), {workspaceSize}, DType::kByte); - nvte_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), te_pre_gelu_out.data(), - transa, transb, grad, te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); + // Set an external SM Margin to all the GEMMs. + // This comes in handy when DP is overlapped with GEMMs + const int device_id = at::cuda::current_device(); + const int sm_count = transformer_engine::cuda::sm_count(device_id); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + + auto main_stream = at::cuda::getCurrentCUDAStream(); + if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { + if (comm_overlap) { + // Prepare extra output tensor + TensorWrapper extra_output_tensor; + if (extra_output.has_value()) { + extra_output_tensor = makeTransformerEngineTensor(*extra_output); + } else { + extra_output_tensor = + makeTransformerEngineTensor(nullptr, std::vector{0}, DType::kByte); + } + + // Direct GEMM call to the correct overlap + if (bulk_overlap) { + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, comm_type.value(), extra_output_tensor, + main_stream); + } else if (comm_type.value() == CommOverlapType::AG) { + if (comm_overlap->is_atomic_gemm()) { + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + } else { + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, extra_output_tensor, main_stream); + } + } else { + if (comm_overlap->is_atomic_gemm()) { + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + bias_tensor, te_pre_gelu_out, te_workspace, grad, + accumulate, use_split_accumulator, + extra_output_tensor, main_stream); + } else { + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + te_pre_gelu_out, te_workspace, grad, accumulate, + use_split_accumulator, extra_output_tensor, main_stream); + } + } + } else { + // Launch GEMM + nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), + te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), + accumulate, use_split_accumulator, num_math_sms, main_stream); + } + } else { + if (D_tensor.numel() != 0 && !accumulate) { + D_tensor.zero_(main_stream); + } + if (bias.has_value()) { + if (bias->numel() != 0 && grad) { + bias_grad->zero_(); + } + } + } + + // Pack outputs + std::vector out; + out.emplace_back(std::move(D)); + out.emplace_back(py::cast(bias_grad)); + if (gelu && !grad) { + out.emplace_back(py::cast(*pre_gelu_out)); + } else { + out.emplace_back(py::none()); + } + if (extra_output.has_value()) { + out.emplace_back(py::cast(extra_output)); + } else { + out.emplace_back(py::none()); + } + return out; } +} // namespace transformer_engine::pytorch + void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type, - bool transa, at::Tensor B, at::Tensor B_scale_inverse, - transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias, - transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, - at::Tensor workspace, size_t workspaceSize, bool accumulate, + std::vector A_scaling_mode, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, transformer_engine::DType B_type, + std::vector B_scaling_mode, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out, + bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, at::Tensor counter) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; + + // TODO: Handle scaling modes + NVTEScalingMode nvte_scaling_modeA = NVTE_DELAYED_TENSOR_SCALING; + NVTEScalingMode nvte_scaling_modeB = NVTE_DELAYED_TENSOR_SCALING; + auto te_A = makeTransformerEngineTensor( A.data_ptr(), {static_cast(A.size(0)), static_cast(A.size(1))}, A_type, - nullptr, nullptr, A_scale_inverse.data_ptr()); + nullptr, nullptr, A_scale_inverse.data_ptr(), getTensorShape(A_scale_inverse), + nvte_scaling_modeA); auto te_B = makeTransformerEngineTensor( B.data_ptr(), {static_cast(B.size(0)), static_cast(B.size(1))}, B_type, - nullptr, nullptr, B_scale_inverse.data_ptr()); + nullptr, nullptr, B_scale_inverse.data_ptr(), getTensorShape(B_scale_inverse), + nvte_scaling_modeB); + // TODO: D_scale_inv cannot be nullptr when D_type is FP8. auto te_D = makeTransformerEngineTensor( D.data_ptr(), {static_cast(D.size(0)), static_cast(D.size(1))}, D_type, D_amax.data_ptr(), D_scale.data_ptr(), nullptr); @@ -95,134 +293,121 @@ void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine gemm_producer, te_counter.data(), at::cuda::getCurrentCUDAStream()); } -void te_grouped_gemm(std::vector A, at::Tensor A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, - bool transb, std::vector D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, - std::vector bias, transformer_engine::DType bias_type, - std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count) { +std::optional> te_general_grouped_gemm( + std::vector A, bool transa, std::vector B, bool transb, + std::optional> D, transformer_engine::DType D_type, + std::vector m_splits, std::vector bias, + transformer_engine::DType bias_type, bool single_output, std::vector pre_gelu_out, + bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, + bool use_split_accumulator, int math_sm_count) { using namespace transformer_engine; - std::vector te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace; - std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, - transformer_engine::DType dtype, void* amax_dptr, - void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { - tensor_wrappers.emplace_back( - makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); - return tensor_wrappers.back().data(); - }; + using namespace transformer_engine::pytorch; + std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, + te_pre_gelu_out_vector, te_workspace_vector; + std::vector wrappers; + std::vector D_vectors; + + auto none = py::none(); + + std::vector single_output_begins; + std::vector single_output_ends; + int slicing_dim; + if (single_output && D == std::nullopt) { + NVTE_ERROR("not implemented, D should be allocated for single output case."); + } + + void* output_data_ptr; + if (single_output) { + output_data_ptr = (*D)[0].data_ptr(); + } + for (size_t i = 0; i < A.size(); i++) { - if (A[i].numel() == 0 || B[i].numel() == 0) { - if (D[i].numel() != 0 && !accumulate) D[i].zero_(); - if (bias[i].numel() != 0 && grad) { - if (B[i].numel() == 0) { - bias[i].zero_(); - } else { - bias[i].copy_(B[i].sum(0)); + auto te_A = makeTransformerEngineTensor(A[i], none); + auto te_B = makeTransformerEngineTensor(B[i], none); + + // if there is single output + at::Tensor out_tensor; + auto size_t_shape = + pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb); + bool D_numel_is_zero = false; + std::vector D_shape; + for (size_t t : size_t_shape) { + D_shape.push_back(t); + if (t == 0) { + D_numel_is_zero = true; + } + } + auto dtype = GetATenDType(D_type); + auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); + if (single_output) { + if (output_data_ptr == nullptr) { + out_tensor = at::empty(D_shape, opts); + } else { + // We need to check !D_numel_is_zero because if the final input portion has zero elements, + // output_data_ptr would point beyond the allocated memory of D. This would cause + // at::from_blob to fail as it would reference memory not allocated by CUDA. + if (!D_numel_is_zero) { + out_tensor = at::from_blob(output_data_ptr, D_shape, opts); } } + char* char_ptr = reinterpret_cast(output_data_ptr); + char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size(); + output_data_ptr = reinterpret_cast(char_ptr); + D_vectors.emplace_back(out_tensor); + } else { + if (D == std::nullopt) { + auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA); + out_tensor = at::empty(D_shape, opts); + D_vectors.emplace_back(out_tensor); + } else { + out_tensor = (*D)[i]; + } + } + + if (te_A.numel() == 0 || te_B.numel() == 0) { + if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_(); + if (bias[i].numel() != 0 && grad) { + bias[i].zero_(); + } if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_(); continue; } - NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); - NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); - NVTE_CHECK(D[i].is_contiguous(), "D[", i, "] must be contiguous."); - - te_A.emplace_back(make_tensor( - A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, - A_type, nullptr, nullptr, getDataPtr(A_scale_inverse, A_offset + i))); - te_B.emplace_back(make_tensor( - B[i].data_ptr(), {static_cast(B[i].size(0)), static_cast(B[i].size(1))}, - B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i))); - te_D.emplace_back(make_tensor( - D[i].data_ptr(), {static_cast(D[i].size(0)), static_cast(D[i].size(1))}, - D_type, getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr)); - te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast(bias[i].size(0))}, - bias_type, nullptr, nullptr, nullptr)); + auto te_D = makeTransformerEngineTensor(out_tensor); + auto te_bias = makeTransformerEngineTensor(bias[i]); + auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out[i].size(0))} - : std::vector{static_cast(pre_gelu_out[i].size(0)), - static_cast(pre_gelu_out[i].size(1))}; - te_pre_gelu_out.emplace_back(make_tensor( - pre_gelu_out[i].data_ptr(), gelu_shape, - GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); - } - for (size_t i = 0; i < workspace.size(); i++) { - te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, - nullptr, nullptr, nullptr)); - } + ? std::vector{static_cast(te_pre_gelu_out.size(0))} + : std::vector{static_cast(te_pre_gelu_out.size(0)), + static_cast(te_pre_gelu_out.size(1))}; - // For now, we only have multi-stream cublas backend. - nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), - te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, - te_workspace.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); -} + DType gelu_type = bias_type; + te_pre_gelu_out = + makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); -void te_grouped_gemm_single_output( - std::vector A, std::vector A_scale_inverse, int A_offset, - transformer_engine::DType A_type, bool transa, std::vector B, - at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb, - std::vector m_splits, at::Tensor D, int D_offset, at::Tensor D_scale, - transformer_engine::DType D_type, at::Tensor D_amax, std::vector bias, - transformer_engine::DType bias_type, std::vector pre_gelu_out, bool grad, - std::vector workspace, size_t workspaceSize, bool accumulate, - bool use_split_accumulator, int math_sm_count) { - using namespace transformer_engine; - std::vector te_A, te_B, te_D, te_bias, te_pre_gelu_out, te_workspace; - std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, - transformer_engine::DType dtype, void* amax_dptr, - void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { - tensor_wrappers.emplace_back( - makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); - return tensor_wrappers.back().data(); - }; - NVTE_CHECK(D.is_contiguous(), "D must be contiguous."); - void* d_i_ptr = reinterpret_cast(D.data_ptr()); - for (size_t i = 0; i < A.size(); i++) { - if (m_splits[i] == 0) continue; - NVTE_CHECK(A[i].data_ptr() != nullptr, "A[", i, "] must not be nullptr."); - NVTE_CHECK(B[i].data_ptr() != nullptr, "B[", i, "] must not be nullptr."); - NVTE_CHECK(A[i].is_contiguous(), "A[", i, "] must be contiguous."); - NVTE_CHECK(B[i].is_contiguous(), "B[", i, "] must be contiguous."); - te_A.emplace_back(make_tensor( - A[i].data_ptr(), {static_cast(A[i].size(0)), static_cast(A[i].size(1))}, - A_type, nullptr, nullptr, getDataPtr(A_scale_inverse[i], A_offset))); - te_B.emplace_back(make_tensor( - B[i].data_ptr(), {static_cast(B[i].size(0)), static_cast(B[i].size(1))}, - B_type, nullptr, nullptr, getDataPtr(B_scale_inverse, B_offset + i))); - te_D.emplace_back(make_tensor( - d_i_ptr, {static_cast(m_splits[i]), static_cast(A[i].size(0))}, D_type, - getDataPtr(D_amax, D_offset + i), getDataPtr(D_scale, D_offset + i), nullptr)); - te_bias.emplace_back(make_tensor(bias[i].data_ptr(), {static_cast(bias[i].size(0))}, - bias_type, nullptr, nullptr, nullptr)); + te_A_vector.emplace_back(te_A.data()); + te_B_vector.emplace_back(te_B.data()); + te_D_vector.emplace_back(te_D.data()); + te_bias_vector.emplace_back(te_bias.data()); + te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr - ? std::vector{static_cast(pre_gelu_out[i].size(0))} - : std::vector{static_cast(pre_gelu_out[i].size(0)), - static_cast(pre_gelu_out[i].size(1))}; - te_pre_gelu_out.emplace_back(make_tensor( - pre_gelu_out[i].data_ptr(), gelu_shape, - GetTransformerEngineDType(pre_gelu_out[i].scalar_type()), nullptr, nullptr, nullptr)); - // Move the D pointer to the next split. - char* char_ptr = reinterpret_cast(d_i_ptr); - char_ptr += m_splits[i] * A[i].size(0) * D.element_size(); - d_i_ptr = reinterpret_cast(char_ptr); + wrappers.emplace_back(std::move(te_A)); + wrappers.emplace_back(std::move(te_B)); + wrappers.emplace_back(std::move(te_D)); + wrappers.emplace_back(std::move(te_bias)); + wrappers.emplace_back(std::move(te_pre_gelu_out)); } for (size_t i = 0; i < workspace.size(); i++) { - te_workspace.emplace_back(make_tensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte, - nullptr, nullptr, nullptr)); + auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte); + te_workspace_vector.emplace_back(wsp.data()); + wrappers.emplace_back(std::move(wsp)); } - // For now, we only have multi-stream cublas backend. - nvte_multi_stream_cublas_gemm(te_A.data(), te_B.data(), te_D.data(), te_bias.data(), - te_pre_gelu_out.data(), te_A.size(), transa, transb, grad, - te_workspace.data(), accumulate, use_split_accumulator, + nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + te_bias_vector.data(), te_pre_gelu_out_vector.data(), + te_A_vector.size(), transa, transb, grad, + te_workspace_vector.data(), accumulate, use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); + return bias; } diff --git a/transformer_engine/pytorch/csrc/extensions/misc.cpp b/transformer_engine/pytorch/csrc/extensions/misc.cpp index 8200942643..9785602998 100644 --- a/transformer_engine/pytorch/csrc/extensions/misc.cpp +++ b/transformer_engine/pytorch/csrc/extensions/misc.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu index 7d49a0848b..548dd5a267 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_adam.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -179,6 +179,122 @@ struct AdamFunctorMaster { } }; +template +struct AdamFunctorMasterParamRemainder { + __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, + TensorListMetadata<5> &tl, // NOLINT(*) + const float beta1, const float beta2, + const float beta1_correction, + const float beta2_correction, const float epsilon, + const float lr, adamMode_t mode, const float decay) { + index_t tensor_loc = tl.block_to_tensor[blockIdx.x]; + + index_t chunk_idx = tl.block_to_chunk[blockIdx.x]; + index_t n = tl.sizes[tensor_loc]; + + GRAD_T *g = reinterpret_cast(tl.addresses[0][tensor_loc]); + g += chunk_idx * chunk_size; + + int16_t *p = reinterpret_cast(tl.addresses[1][tensor_loc]); + p += chunk_idx * chunk_size; + + FULL_T *m = reinterpret_cast(tl.addresses[2][tensor_loc]); + m += chunk_idx * chunk_size; + + FULL_T *v = reinterpret_cast(tl.addresses[3][tensor_loc]); + v += chunk_idx * chunk_size; + + int16_t *p_remainder = reinterpret_cast(tl.addresses[4][tensor_loc]); + p_remainder += chunk_idx * chunk_size; + + n -= chunk_idx * chunk_size; + + // see note in multi_tensor_scale_kernel.cu + for (index_t i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) { + union fp32_or_int162 { + float fp32; + int16_t int16[2]; + }; + fp32_or_int162 local_master_param[ILP]; + int16_t local_p[ILP]; + int16_t local_p_rem[ILP]; + MATH_T r_g[ILP]; + MATH_T r_m[ILP]; + MATH_T r_v[ILP]; +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + r_g[ii] = static_cast(g[i]); + r_m[ii] = static_cast(m[i]); + r_v[ii] = static_cast(v[i]); + + local_p[ii] = static_cast(p[i]); + local_p_rem[ii] = static_cast(p_remainder[i]); + } else { + r_g[ii] = MATH_T(0); + r_m[ii] = MATH_T(0); + r_v[ii] = MATH_T(0); + + local_p[ii] = int16_t(0); + local_p_rem[ii] = int16_t(0); + } + } +// Reconstruct FP32 params +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (local_p_rem[ii] < 0) local_p[ii]--; // Undo rounding + local_master_param[ii].int16[1] = local_p[ii]; + local_master_param[ii].int16[0] = local_p_rem[ii]; + } + + MATH_T *r_p = reinterpret_cast(local_master_param); + +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + if (mode == ADAM_MODE_0) { // L2 + r_g[ii] = r_g[ii] + (decay * r_p[ii]); + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = next_m_unbiased / denom; + r_p[ii] = r_p[ii] - (lr * update); + } else { // weight decay + r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii]; + r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii]; + MATH_T next_m_unbiased = r_m[ii] / beta1_correction; + MATH_T next_v_unbiased = r_v[ii] / beta2_correction; + MATH_T denom = sqrtf(next_v_unbiased) + epsilon; + MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]); + r_p[ii] = r_p[ii] - (lr * update); + } + } + +// Split into BF16 params (rounded-to-nearest) and remainders +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + local_p[ii] = local_master_param[ii].int16[1]; + local_p_rem[ii] = local_master_param[ii].int16[0]; + if (local_p_rem[ii] < 0) local_p[ii]++; // Round up + } + +#pragma unroll + for (int ii = 0; ii < ILP; ii++) { + int i = i_start + threadIdx.x + ii * blockDim.x; + if (i < n && i < chunk_size) { + p_remainder[i] = static_cast(local_p_rem[ii]); + p[i] = static_cast(local_p[ii]); + + m[i] = static_cast(r_m[ii]); + v[i] = static_cast(r_v[ii]); + } + } + } + } +}; + template struct AdamFunctor { __device__ __forceinline__ void operator()(index_t chunk_size, volatile int *noop_gmem, @@ -548,6 +664,42 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, AT_CUDA_CHECK(cudaGetLastError()); } +void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag, + std::vector> tensor_lists, + const float lr, const float beta1, const float beta2, + const float epsilon, const int step, const int mode, + const int bias_correction, const float weight_decay) { + using namespace at; + + // Handle bias correction mode + float bias_correction1 = 1.0f, bias_correction2 = 1.0f; + if (bias_correction == 1) { + bias_correction1 = 1 - std::pow(beta1, step); + bias_correction2 = 1 - std::pow(beta2, step); + } + + const auto g_in_type = tensor_lists[0][0].scalar_type(); + const auto p_in_type = tensor_lists[1][0].scalar_type(); + auto tl_size = tensor_lists.size(); + + // case 5: g, p, m, v, p_master + TORCH_CHECK(tl_size == 5, "tensor list must contain 5"); + TORCH_CHECK(p_in_type == at::ScalarType::BFloat16, + "Adam with BF16 param remainders requires BF16 params"); + + // g, p, m, v, p_master + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + p_in_type, 0, "adam", + DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT( + g_in_type, 1, "adam", + multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, + AdamFunctorMasterParamRemainder(), + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); + + AT_CUDA_CHECK(cudaGetLastError()); +} + void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.cu index 3626bce9c2..8bc03ae375 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_l2norm_kernel.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.cu index bd673b7d6e..d5d55c2872 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_scale_kernel.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu index 3009a82768..5ea5c1d3d1 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_sgd_kernel.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 04274ae2ef..bb011faf98 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -1,15 +1,35 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +#include "common/util/system.h" #include "extensions.h" -std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, +namespace transformer_engine::pytorch { +std::pair createOutputTensor(const NVTEShape &shape, DType dtype, + py::handle quantizer) { + std::vector shape_vec; + for (int i = 0; i < shape.ndim; i++) { + size_t t = shape.data[i]; + shape_vec.push_back(t); + } + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape_vec, dtype); +} +std::pair createOutputTensor(std::vector &shape, DType dtype, + py::handle quantizer) { + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + return my_quantizer->create_tensor(shape, dtype); +} +} // namespace transformer_engine::pytorch + +std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &mu, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; const auto &dz_ = dz.contiguous(); const auto &x_ = x.contiguous(); const auto &mu_ = mu.contiguous(); @@ -19,7 +39,7 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dx = at::empty_like(x_); auto dgamma = at::empty_like(gamma_); auto dbeta = at::empty_like(gamma_); - transformer_engine::TensorWrapper workspace, barrier, dgamma_part, dbeta_part; + transformer_engine::TensorWrapper workspace; auto dz_cu = makeTransformerEngineTensor(dz_); auto x_cu = makeTransformerEngineTensor(x_); @@ -31,162 +51,111 @@ std::vector layernorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dbeta_cu = makeTransformerEngineTensor(dbeta); // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_layernorm1p_bwd : nvte_layernorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); // Alloc space for Tensors. auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); - auto dbeta_part_data = allocateSpace(dbeta_part.shape(), dbeta_part.dtype()); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); - dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), dgamma_part.shape(), - dgamma_part.dtype()); - dbeta_part = makeTransformerEngineTensor(dbeta_part_data.data_ptr(), dbeta_part.shape(), - dbeta_part.dtype()); // Actual call to bwd kernel. - bwd_fun(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dbeta_cu.data(), dgamma_part.data(), dbeta_part.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); - - return {dx, dgamma, dbeta}; -} - -std::vector layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset, - const int amax_offset, const int scale_inv_offset) { - using namespace transformer_engine; + nvte_layernorm_bwd(dz_cu.data(), x_cu.data(), mu_cu.data(), rsigma_cu.data(), gamma_cu.data(), + dx_cu.data(), dgamma_cu.data(), dbeta_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - const auto &input_ = input.contiguous(); - - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype))); - return layernorm_fwd_fp8_noalloc(input_, weight, bias, eps, scale, ln_out, amax, scale_inv, otype, - sm_margin, zero_centered_gamma, scale_offset, amax_offset, - scale_inv_offset); + return {py::cast(dx), py::cast(dgamma), py::cast(dbeta)}; } -std::vector layernorm_fwd_fp8_noalloc( - const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps, - at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma, - const int scale_offset, const int amax_offset, const int scale_inv_offset) { +std::vector layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias, + float eps, py::object out, py::handle quantizer, + DType out_dtype, const int sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; using namespace transformer_engine; - const auto &input_ = input.contiguous(); - const auto &weight_ = weight.contiguous(); - const auto &bias_ = bias.contiguous(); - - // Choose kernel implementation - const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd; + // Input and param tensors + auto none = py::none(); + const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); + TensorWrapper bias_cu; + if (bias.has_value()) { + bias_cu = makeTransformerEngineTensor(*bias); + } // Tensor dimensions - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void *scale_dptr = getDataPtr(scale, scale_offset); - void *amax_dptr = getDataPtr(amax, amax_offset); - void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - DType itype = GetTransformerEngineDType(input.scalar_type()); - auto mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input_); - auto gamma_cu = makeTransformerEngineTensor(weight_); - auto beta_cu = makeTransformerEngineTensor(bias_); - auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr, - scale_inv_dptr); - auto mu_cu = makeTransformerEngineTensor(mu); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); - - // Query workspace sizes - transformer_engine::TensorWrapper workspace, barrier; - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); - - // Allocate workspaces + const size_t N = static_cast(input_cu.size(0)); + const size_t H = static_cast(input_cu.size(1)); + const std::vector size = {N, H}; + + // Tensors to save for backward pass + at::Tensor mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + at::Tensor rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); + TensorWrapper mu_cu = makeTransformerEngineTensor(mu); + TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); + + // Output tensor + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + TensorWrapper out_cu; + if (out.is_none()) { + std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + } else { + out_cu = makeTransformerEngineTensor(out, quantizer); + } + + // Determine whether to avoid fused kernel + bool force_unfused_kernel = false; + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + if (!transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { + // TE only supports MXFP8 norm with cuDNN backend + force_unfused_kernel = true; + } else if (N % 128 != 0 || H % 128 != 0) { + // cuDNN norm requires full tile for MXFP8 + force_unfused_kernel = true; + } + } + TensorWrapper unquantized_out_cu; + if (force_unfused_kernel) { + NoneQuantizer q{none}; + py::object unquantized_out; + std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + } + TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; + + // Query workspace size + transformer_engine::TensorWrapper workspace; + nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); + + // Allocate workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); // Launch kernel - func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(), mu_cu.data(), - rsigma_cu.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); - - return {ln_out, mu, rsigma}; -} - -at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset, - const int amax_offset, const int scale_inv_offset - -) { - // This is a specialized version of layernorm_fwd_fp8, optimized for inference, - // which only returns the normalized output. - std::vector out = - layernorm_fwd_fp8(input, weight, bias, eps, scale, amax, scale_inv, otype, sm_margin, - zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset); - return out[0]; -} - -std::vector layernorm_fwd(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); - const auto &input_ = input.contiguous(); - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype))); - - return layernorm_fwd_noalloc(input_, weight, bias, ln_out, eps, sm_margin, zero_centered_gamma); -} - -std::vector layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, at::Tensor ln_out, float eps, - const int sm_margin, const bool zero_centered_gamma) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); - - return layernorm_fwd_fp8_noalloc(input, weight, bias, eps, at::Tensor(), ln_out, at::Tensor(), - at::Tensor(), itype, sm_margin, zero_centered_gamma); -} - -at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, float eps, const int sm_margin, - const bool zero_centered_gamma) { - // This is a specialized version of layernorm_fwd, optimized for inference, - // which only returns the normalized output. - std::vector out = - layernorm_fwd(input, weight, bias, eps, sm_margin, zero_centered_gamma); - return out[0]; + nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), + mu_cu.data(), rsigma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); + + // Quantize output if using unfused kernel + if (force_unfused_kernel) { + nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, + at::cuda::getCurrentCUDAStream()); + } + + return {out, py::cast(mu), py::cast(rsigma)}; } -std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, +std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; const auto &dz_ = dz.contiguous(); const auto &x_ = x.contiguous(); const auto &rsigma_ = rsigma.contiguous(); @@ -194,7 +163,7 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dx = at::empty_like(x_); auto dgamma = at::empty_like(gamma_); - transformer_engine::TensorWrapper workspace, barrier, dgamma_part; + transformer_engine::TensorWrapper workspace; auto dz_cu = makeTransformerEngineTensor(dz_); auto x_cu = makeTransformerEngineTensor(x_); @@ -204,142 +173,97 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, auto dgamma_cu = makeTransformerEngineTensor(dgamma); // This call populates tensors with the required config. - const auto bwd_fun = zero_centered_gamma ? nvte_rmsnorm1p_bwd : nvte_rmsnorm_bwd; - bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); + nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); // Alloc space for Tensors. auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); - auto dgamma_part_data = allocateSpace(dgamma_part.shape(), dgamma_part.dtype()); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); - dgamma_part = makeTransformerEngineTensor(dgamma_part_data.data_ptr(), dgamma_part.shape(), - dgamma_part.dtype()); // Actual call to bwd kernel. - bwd_fun(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), - dgamma_cu.data(), dgamma_part.data(), at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); - - return {dx, dgamma}; -} - -std::vector rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight, - float eps, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, - const int scale_offset, const int amax_offset, - const int scale_inv_offset) { - using namespace transformer_engine; - - const auto &input_ = input.contiguous(); - const auto &weight_ = weight.contiguous(); + nvte_rmsnorm_bwd(dz_cu.data(), x_cu.data(), rsigma_cu.data(), gamma_cu.data(), dx_cu.data(), + dgamma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(otype))); - return rmsnorm_fwd_fp8_noalloc(input_, weight_, eps, scale, ln_out, amax, scale_inv, otype, - sm_margin, zero_centered_gamma, scale_offset, amax_offset, - scale_inv_offset); + return {py::cast(dx), py::cast(dgamma)}; } -std::vector rmsnorm_fwd_fp8_noalloc(const at::Tensor &input, const at::Tensor &weight, - float eps, at::Tensor scale, at::Tensor ln_out, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - const int sm_margin, const bool zero_centered_gamma, - const int scale_offset, const int amax_offset, - const int scale_inv_offset) { +std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, + py::object out, py::handle quantizer, + transformer_engine::DType out_dtype, const int sm_margin, + const bool zero_centered_gamma) { + using namespace transformer_engine::pytorch; using namespace transformer_engine; - // Choose kernel implementation - const auto func = zero_centered_gamma ? nvte_rmsnorm1p_fwd : nvte_rmsnorm_fwd; + // Input and param tensors + auto none = py::none(); + const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); // Tensor dimensions - size_t N = static_cast(input.size(0)); - size_t H = static_cast(input.size(1)); + const size_t N = static_cast(input_cu.shape().data[0]); + const size_t H = static_cast(input_cu.shape().data[1]); + const std::vector size = {N, H}; - // Get pointers for FP8 scale, amax, scale-inverse - void *scale_dptr = getDataPtr(scale, scale_offset); - void *amax_dptr = getDataPtr(amax, amax_offset); - void *scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - DType itype = GetTransformerEngineDType(input.scalar_type()); + // Tensors to save for backward pass auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto input_cu = makeTransformerEngineTensor(input); - auto gamma_cu = makeTransformerEngineTensor(weight); - auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype, amax_dptr, scale_dptr, - scale_inv_dptr); auto rsigma_cu = makeTransformerEngineTensor(rsigma); - // Query workspace sizes - transformer_engine::TensorWrapper workspace, barrier; - func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); - - // Allocate workspaces + // Output tensor + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + TensorWrapper out_cu; + if (out.is_none()) { + std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + } else { + out_cu = makeTransformerEngineTensor(out, quantizer); + } + + // Determine whether to avoid fused kernel + bool force_unfused_kernel = false; + if (my_quantizer->get_scaling_mode() == NVTE_MXFP8_1D_SCALING) { + if (!transformer_engine::getenv("NVTE_CUDNN_MXFP8_NORM", false)) { + // TE only supports MXFP8 norm with cuDNN backend + force_unfused_kernel = true; + } else if (N % 128 != 0 || H % 128 != 0) { + // cuDNN norm requires full tile for MXFP8 + force_unfused_kernel = true; + } + } + TensorWrapper unquantized_out_cu; + if (force_unfused_kernel) { + NoneQuantizer q{none}; + py::object unquantized_out; + std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + } + TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; + + // Query workspace size + transformer_engine::TensorWrapper workspace; + nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), + workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); + + // Allocate workspace auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - auto barrier_data = allocateSpace(barrier.shape(), barrier.dtype(), true); workspace = makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - barrier = makeTransformerEngineTensor(barrier_data.data_ptr(), barrier.shape(), barrier.dtype()); // Launch kernel - func(input_cu.data(), gamma_cu.data(), eps, z_cu.data(), rsigma_cu.data(), - at::cuda::getCurrentCUDAStream(), - at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, workspace.data(), - barrier.data()); - - return {ln_out, rsigma}; -} - -at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, const int sm_margin, - const bool zero_centered_gamma, const int scale_offset, - const int amax_offset, const int scale_inv_offset) { - // This is a specialized version of rmsnorm_fwd_fp8, optimized for inference, - // which only returns the normalized output. - std::vector out = - rmsnorm_fwd_fp8(input, weight, eps, scale, amax, scale_inv, otype, sm_margin, - zero_centered_gamma, scale_offset, amax_offset, scale_inv_offset); - return out[0]; -} - -std::vector rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma) { - using namespace transformer_engine; - - const auto &input_ = input.contiguous(); - const auto &weight_ = weight.contiguous(); - - DType itype = GetTransformerEngineDType(input.scalar_type()); - auto ln_out = at::empty_like(input_, at::CUDA(GetATenDType(itype))); - - return rmsnorm_fwd_noalloc(input_, weight_, ln_out, eps, sm_margin, zero_centered_gamma); -} - -std::vector rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight, - at::Tensor ln_out, float eps, const int sm_margin, - const bool zero_centered_gamma) { - using namespace transformer_engine; - - DType itype = GetTransformerEngineDType(input.scalar_type()); - - return rmsnorm_fwd_fp8_noalloc(input, weight, eps, at::Tensor(), ln_out, at::Tensor(), - at::Tensor(), itype, sm_margin, zero_centered_gamma); -} - -at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps, - const int sm_margin, const bool zero_centered_gamma) { - // This is a specialized version of rmsnorm_fwd, optimized for inference, - // which only returns the normalized output. - std::vector out = rmsnorm_fwd(input, weight, eps, sm_margin, zero_centered_gamma); - return out[0]; + nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), + workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); + + // Quantize output if using unfused kernel + if (force_unfused_kernel) { + nvte_quantize_noop(unquantized_out_cu.data(), out_cu.data(), nullptr, + at::cuda::getCurrentCUDAStream()); + } + + return {out, py::none(), py::cast(rsigma)}; } diff --git a/transformer_engine/pytorch/csrc/extensions/padding.cpp b/transformer_engine/pytorch/csrc/extensions/padding.cpp index d975ebeeef..b9972af7cb 100644 --- a/transformer_engine/pytorch/csrc/extensions/padding.cpp +++ b/transformer_engine/pytorch/csrc/extensions/padding.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -10,6 +10,7 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output, std::vector input_row_list, std::vector padded_input_row_list) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; NVTE_CHECK(input_row_list.size() == padded_input_row_list.size(), "Number of input row list and padded row list must match."); diff --git a/transformer_engine/pytorch/csrc/extensions/permutation.cu b/transformer_engine/pytorch/csrc/extensions/permutation.cu index 0c9bed45e0..47282da504 100644 --- a/transformer_engine/pytorch/csrc/extensions/permutation.cu +++ b/transformer_engine/pytorch/csrc/extensions/permutation.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -11,6 +11,7 @@ std::tuple> moe_permute_fwd( at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices, int64_t num_out_tokens, std::vector workspace, int64_t max_expanded_token_num) { + using namespace transformer_engine::pytorch; const int num_tokens = input.size(0); int num_cols = input.size(1); const int topK = indices.size(1); @@ -96,6 +97,7 @@ at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dty at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype, at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens, int64_t topK) { + using namespace transformer_engine::pytorch; int num_cols = input.size(1); // Activations type @@ -129,6 +131,7 @@ at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType d std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd, const transformer_engine::DType dtype, at::Tensor row_id_map, at::Tensor prob) { + using namespace transformer_engine::pytorch; const int topK = (prob.numel() > 0) ? prob.size(1) : 1; const int num_tokens = (prob.numel() > 0) ? prob.size(0) : row_id_map.size(0); int num_cols = input_bwd.size(1); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 39679ed669..a58fd3a6a4 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -1,17 +1,137 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +#include "pybind.h" + +#include +#include +#include +#include #include #include +#include + +#include "../common.h" #include "../extensions.h" +#include "common.h" + +namespace transformer_engine::pytorch { + +PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove +PyTypeObject *Float8TensorBasePythonClass = nullptr; +PyTypeObject *Float8QuantizerClass = nullptr; +PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; +PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove +PyTypeObject *MXFP8TensorBasePythonClass = nullptr; +PyTypeObject *MXFP8QuantizerClass = nullptr; + +void init_float8_extension() { + if (Float8TensorPythonClass) return; + auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); + Float8QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); + Float8CurrentScalingQuantizerClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8CurrentScalingQuantizer")); + Float8TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor")); + auto fp8_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.float8_tensor_base"); + Float8TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorBase")); + NVTE_CHECK(Float8TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch Float8 extension."); +} + +void init_mxfp8_extension() { + if (MXFP8TensorPythonClass) return; + auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor"); + MXFP8QuantizerClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer")); + MXFP8TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor")); + auto fp8_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base"); + MXFP8TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorBase")); + NVTE_CHECK(MXFP8TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch MXFP8 extension."); +} + +void init_extension() { + init_float8_extension(); + init_mxfp8_extension(); +} + +} // namespace transformer_engine::pytorch + #include "common/util/pybind_helper.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { NVTE_DECLARE_COMMON_PYBIND11_HANDLES(m) + m.def("quantize", transformer_engine::pytorch::quantize, py::arg("tensor"), py::arg("quantizer"), + py::arg("output") = py::none(), py::arg("noop") = py::none()); + m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), + py::arg("otype")); + m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, + "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); + m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", + py::arg("A"), py::arg("transA"), py::arg("B"), py::arg("transB"), py::arg("D"), + py::arg("quantizer"), py::arg("output_dtype"), py::arg("bias"), py::arg("bias_type"), + py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), + py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"), + py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, + py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false); + m.def("rowwise_swizzle", &rowwise_swizzle, "Swizzle rowwise scale inverses.", + py::call_guard()); + m.def("columnwise_swizzle", &columnwise_swizzle, "Swizzle columnwise scale inverses.", + py::call_guard()); + m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"), + py::arg("quantizer")); + m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_drelu", transformer_engine::pytorch::dbias_drelu, "DReLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dqgelu", transformer_engine::pytorch::dbias_dqgelu, "DQGeLU + DBias + Quantize", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dbias_dsrelu", transformer_engine::pytorch::dbias_dsrelu, + "DSquaredReLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), + py::arg("quantizer")); // Permutation functions m.def("moe_permute_fwd", moe_permute_fwd); @@ -42,114 +162,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); // Other granular functions - m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("bias"), py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), - py::arg("scale_inv"), py::arg("otype"), py::arg("sm_margin"), - py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, - py::arg("scale_inv_offset") = 0); - m.def("layernorm_bwd", &layernorm_bwd, "LN BWD", py::call_guard()); - m.def("layernorm_fwd", &layernorm_fwd, "LN FWD", py::call_guard()); - m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD", - py::call_guard()); - m.def("rmsnorm_fwd_fp8", &rmsnorm_fwd_fp8, "RMSNorm FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("eps"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), - py::arg("sm_margin"), py::arg("zero_centered_gamma"), py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("rmsnorm_fwd_fp8_noalloc", &rmsnorm_fwd_fp8_noalloc, "RMSNorm FWD FP8", - py::call_guard(), py::arg("input"), py::arg("weight"), - py::arg("eps"), py::arg("scale"), py::arg("ln_out"), py::arg("amax"), py::arg("scale_inv"), - py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("rmsnorm_bwd", &rmsnorm_bwd, "RMSNorm BWD", py::call_guard()); - m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm FWD", py::call_guard()); - m.def("rmsnorm_fwd_noalloc", &rmsnorm_fwd_noalloc, "RMSNorm FWD", - py::call_guard()); - m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose", - py::call_guard()); - m.def("fused_cast_transpose_noop", &fused_cast_transpose_noop, - "Cast + Transpose with noop option", py::call_guard(), - py::arg("input"), py::arg("noop"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), - py::arg("input_cast"), py::arg("input_transpose"), py::arg("otype"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, "Fused Cast + Transpose + BGRAD", - py::call_guard(), py::arg("grad_output"), py::arg("scale"), - py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_fp8_transpose_bgrad", &fused_fp8_transpose_bgrad, "Fused FP8 Transpose + BGRAD", - py::call_guard(), py::arg("grad_output"), py::arg("scale"), - py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), py::arg("grad_bias_type"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_cast_transpose_bgrad_dgelu", &fused_cast_transpose_bgrad_dgelu, - "Fused Cast + Transpose + BGRAD + DGELU", py::call_guard(), - py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"), - py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0, - py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, - "Fused Multi-tensor Cast + Transpose", py::call_guard()); - m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc, - "Fused Multi-tensor Cast + Transpose with allocating output tensors", - py::call_guard()); - m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8", py::call_guard(), - py::arg("input"), py::arg("scale"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8", - py::call_guard(), py::arg("input"), py::arg("scale"), - py::arg("output"), py::arg("amax"), py::arg("scale_inv"), py::arg("otype"), - py::arg("scale_offset") = 0, py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0); - m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8", py::call_guard(), - py::arg("input"), py::arg("scale_inv"), py::arg("itype"), py::arg("otype"), - py::arg("scale_inv_offset") = 0); - m.def("te_gemm", &te_gemm, "CublasLt GEMM"); /// TODO Think - m.def("te_grouped_gemm", &te_grouped_gemm, "Grouped GEMM"); - m.def("fused_attn_fwd_qkvpacked", &fused_attn_fwd_qkvpacked, - "Fused Attention FP8/BF16/FP16 FWD with packed QKV", - py::call_guard()); - m.def("fused_attn_bwd_qkvpacked", &fused_attn_bwd_qkvpacked, - "Fused Attention FP8/BF16/FP16 BWD with packed QKV", - py::call_guard()); - m.def("fused_attn_fwd_kvpacked", &fused_attn_fwd_kvpacked, - "Fused Attention FP8/BF16/FP16 FWD with packed KV", - py::call_guard()); - m.def("fused_attn_bwd_kvpacked", &fused_attn_bwd_kvpacked, - "Fused Attention FP8/BF16/FP16 BWD with packed KV", - py::call_guard()); - m.def("fused_attn_fwd", &fused_attn_fwd, - "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V", - py::call_guard()); - m.def("fused_attn_bwd", &fused_attn_bwd, - "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V", - py::call_guard()); - m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", - py::call_guard()); - m.def("fp8_transpose_noalloc", &fp8_transpose_noalloc, "Transpose with FP8 I/O", - py::call_guard()); - m.def("fp8_transpose_noalloc_noop", &fp8_transpose_noalloc_noop, - "Transpose with FP8 I/O with noop option.", py::call_guard()); - m.def("gelu", &gelu, "GeLU with FP8 output", py::call_guard()); - m.def("relu", &relu, "ReLU with FP8 output", py::call_guard()); - m.def("geglu", &geglu, "GeGLU with FP8 output", py::call_guard()); - m.def("reglu", ®lu, "ReGLU with FP8 output", py::call_guard()); - m.def("swiglu", &swiglu, "SwiGLU with FP8 output", py::call_guard()); - m.def("qgelu", &qgelu, "QuickGELU with FP8 output", py::call_guard()); - m.def("srelu", &srelu, "Squared ReLU with FP8 output", py::call_guard()); - m.def("dgelu", &dgelu, "Backward of GeLU", py::call_guard()); - m.def("drelu", &drelu, "Backward of ReLU", py::call_guard()); - m.def("dgeglu", &dgeglu, "Backward of GeGLU", py::call_guard()); - m.def("dreglu", &dreglu, "Backward of ReGLU", py::call_guard()); - m.def("dswiglu", &dswiglu, "Backward of SwiGLU", py::call_guard()); - m.def("dqgelu", &dqgelu, "Backward of QuickGELU", py::call_guard()); - m.def("dsrelu", &dsrelu, "Backward of Squared ReLU", py::call_guard()); - m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention", - py::call_guard()); - m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", - py::call_guard()); + m.def("layernorm_fwd", &layernorm_fwd, "LayerNorm", py::arg("input"), py::arg("weight"), + py::arg("bias"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), + py::arg("sm_margin"), py::arg("zero_centered_gamma")); + m.def("layernorm_bwd", &layernorm_bwd, "Backward of LayerNorm"); + m.def("rmsnorm_fwd", &rmsnorm_fwd, "RMSNorm", py::arg("input"), py::arg("weight"), py::arg("eps"), + py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), + py::arg("zero_centered_gamma")); + m.def("rmsnorm_bwd", &rmsnorm_bwd, "Backward of RMSNorm"); + m.def("fused_multi_quantize", &fused_multi_quantize, "Fused Multi-tensor Cast + Transpose", + py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype")); + + m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM"); + m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), + py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend", py::call_guard()); m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction, @@ -157,6 +183,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); m.def("fused_multi_row_padding", &fused_multi_row_padding, "Fused Multi-tensor padding", py::call_guard()); + + // attention kernels + m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention", + py::call_guard()); + m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention", + py::call_guard()); + m.def("fused_attn_fwd", &fused_attn_fwd, + "Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V"); + m.def("fused_attn_bwd", &fused_attn_bwd, + "Fused Attention FP8/BF16/FP16 BWD with separate Q, K and V"); + m.def("copy_to_kv_cache", ©_to_kv_cache, "Copy new KV tokens to KV cache"); + m.def("convert_thd_to_bshd", &convert_thd_to_bshd, "Convert a tensor from THD to BSHD"); + m.def("convert_bshd_to_thd", &convert_bshd_to_thd, "Convert a tesnor from BSHD to THD"); + // fused apply rope m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD", py::call_guard()); @@ -207,6 +247,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_adam", &multi_tensor_adam_cuda, "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); + m.def("multi_tensor_adam_param_remainder", &multi_tensor_adam_param_remainder_cuda, + "Compute and apply gradient update to parameters for Adam optimizer" + "where the master parameters only store the remainder bits", + py::call_guard()); m.def("multi_tensor_adam_fp8", &multi_tensor_adam_fp8_cuda, "Compute and apply gradient update to parameters for Adam optimizer", py::call_guard()); @@ -223,86 +267,68 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::call_guard()); // Data structures - py::class_(m, "FP8TensorMeta") + py::class_(m, "FP8TensorMeta") .def(py::init<>()) - .def_readwrite("scale", &transformer_engine::FP8TensorMeta::scale) - .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv) - .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history); + .def_readwrite("scale", &transformer_engine::pytorch::FP8TensorMeta::scale) + .def_readwrite("scale_inv", &transformer_engine::pytorch::FP8TensorMeta::scale_inv) + .def_readwrite("amax_history", &transformer_engine::pytorch::FP8TensorMeta::amax_history); - py::enum_(m, "FP8FwdTensors") - .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT) - .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT) - .value("GEMM1_OUTPUT", transformer_engine::FP8FwdTensors::GEMM1_OUTPUT) - .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT) - .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT) - .value("GEMM2_OUTPUT", transformer_engine::FP8FwdTensors::GEMM2_OUTPUT) - .value("GEMM3_INPUT", transformer_engine::FP8FwdTensors::GEMM3_INPUT) - .value("GEMM3_WEIGHT", transformer_engine::FP8FwdTensors::GEMM3_WEIGHT) - .value("GEMM3_OUTPUT", transformer_engine::FP8FwdTensors::GEMM3_OUTPUT); + py::enum_(m, "FP8FwdTensors") + .value("GEMM1_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_INPUT) + .value("GEMM1_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_WEIGHT) + .value("GEMM1_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM1_OUTPUT) + .value("GEMM2_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_INPUT) + .value("GEMM2_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_WEIGHT) + .value("GEMM2_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM2_OUTPUT) + .value("GEMM3_INPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_INPUT) + .value("GEMM3_WEIGHT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_WEIGHT) + .value("GEMM3_OUTPUT", transformer_engine::pytorch::FP8FwdTensors::GEMM3_OUTPUT); - py::enum_(m, "FP8BwdTensors") - .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1) - .value("GRAD_INPUT1", transformer_engine::FP8BwdTensors::GRAD_INPUT1) - .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2) - .value("GRAD_INPUT2", transformer_engine::FP8BwdTensors::GRAD_INPUT2) - .value("GRAD_OUTPUT3", transformer_engine::FP8BwdTensors::GRAD_OUTPUT3) - .value("GRAD_INPUT3", transformer_engine::FP8BwdTensors::GRAD_INPUT3); + py::enum_(m, "FP8BwdTensors") + .value("GRAD_OUTPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT1) + .value("GRAD_INPUT1", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT1) + .value("GRAD_OUTPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT2) + .value("GRAD_INPUT2", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT2) + .value("GRAD_OUTPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_OUTPUT3) + .value("GRAD_INPUT3", transformer_engine::pytorch::FP8BwdTensors::GRAD_INPUT3); py::class_(m, "CommOverlapHelper") .def(py::init<>(), py::call_guard()) - .def(py::init, - std::optional>(), + .def(py::init>(), py::call_guard(), py::arg("world_group"), - py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); + py::arg("intra_node_group") = py::none()); - py::class_(m, "CommOverlap") + py::class_, transformer_engine::CommOverlapBase, + transformer_engine::CommOverlapCore>(m, "CommOverlap") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, int, int, - int, int, bool, bool>(), + int, int, int, int, bool, bool, bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("num_splits") = 3, py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, - py::arg("comm_cga_size") = 2, py::arg("num_comm_sm") = 16, - py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false) - .def("bulk_overlap", &CommOverlap::bulk_overlap, py::call_guard()) - .def("split_overlap_rs", &CommOverlap::split_overlap_rs, - py::call_guard()) - .def("atomic_gemm_overlap_rs", &CommOverlap::atomic_gemm_overlap_rs, - py::call_guard()) - .def("copy_input_to_ubuf", &CommOverlap::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &CommOverlap::get_ubuf_output, - py::call_guard()) - .def("set_ubuf_scale_inv", &CommOverlap::set_ubuf_scale_inv, - py::call_guard()) - .def("is_atomic_gemm", &CommOverlap::is_atomic_gemm, py::call_guard()) - .def("is_p2p_overlap", &CommOverlap::is_p2p_overlap, py::call_guard()) - .def("is_fp8_ubuf", &CommOverlap::is_fp8_ubuf, py::call_guard()); + py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, + py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, + py::arg("atomic_gemm") = false, py::arg("rs_overlap_first_gemm") = false) + .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"), + py::arg("quantizer"), py::arg("local_chunk") = false) + .def("get_buffer", &CommOverlap::get_buffer, py::arg("quantizer"), + py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) + .def("set_buffer_params", &CommOverlap::set_buffer_params); - py::class_(m, "CommOverlapP2P") + py::class_, + transformer_engine::CommOverlapP2PBase, transformer_engine::CommOverlapCore>( + m, "CommOverlapP2P") .def(py::init &, at::ScalarType, CommOverlapHelper *, int, - transformer_engine::CommOverlapType, int, int, int, bool, bool, bool, bool>(), + transformer_engine::CommOverlapType, int, int, int, int, int, bool, bool, bool, + bool>(), py::call_guard(), py::arg("buffer_shape"), py::arg("buffer_dtype"), py::arg("helper"), py::arg("tp_size"), py::arg("comm_type"), py::arg("num_max_streams") = NVTE_COMM_OVERLAP_MAX_STREAMS, py::arg("comm_cga_size") = 1, - py::arg("num_comm_sm") = 1, py::arg("set_sm_margin") = false, - py::arg("atomic_gemm") = false, py::arg("use_ce") = true, py::arg("aggregate") = false) - .def("split_overlap_ag_p2p", &CommOverlapP2P::split_overlap_ag, - py::call_guard()) - .def("split_overlap_rs_p2p", &CommOverlapP2P::split_overlap_rs, - py::call_guard()) - .def("atomic_gemm_overlap_ag_p2p", &CommOverlapP2P::atomic_gemm_overlap_ag, - py::call_guard()) - .def("atomic_gemm_overlap_rs_p2p", &CommOverlapP2P::atomic_gemm_overlap_rs, - py::call_guard()) - .def("copy_input_to_ubuf", &CommOverlapP2P::copy_input_to_ubuf, - py::call_guard()) - .def("get_ubuf_output", &CommOverlapP2P::get_ubuf_output, - py::call_guard()) - .def("set_ubuf_scale_inv", &CommOverlapP2P::set_ubuf_scale_inv, - py::call_guard()) - .def("is_fp8_ubuf", &CommOverlapP2P::is_fp8_ubuf, py::call_guard()) - .def("is_atomic_gemm", &CommOverlapP2P::is_atomic_gemm, - py::call_guard()) - .def("is_p2p_overlap", &CommOverlapP2P::is_p2p_overlap, - py::call_guard()); + py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1, + py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, + py::arg("use_ce") = true, py::arg("aggregate") = false) + .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"), + py::arg("quantizer"), py::arg("local_chunk") = false) + .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("quantizer"), + py::arg("local_chunk") = false, py::arg("shape") = std::nullopt) + .def("set_buffer_params", &CommOverlapP2P::set_buffer_params); } diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp new file mode 100644 index 0000000000..3d55fc15d4 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -0,0 +1,343 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include "common.h" +#include "pybind.h" +#include "torch/torch.h" +#include "util.h" + +namespace transformer_engine::pytorch { + +constexpr size_t MXFP8_BLOCK_SIZE = 32; + +Quantizer::Quantizer(const py::handle& quantizer) { + if (quantizer.is_none()) { + this->rowwise_usage = true; + this->columnwise_usage = true; + this->internal = false; + } else { + this->rowwise_usage = quantizer.attr("rowwise_usage").cast(); + this->columnwise_usage = quantizer.attr("columnwise_usage").cast(); + this->internal = quantizer.attr("internal").cast(); + this->quantizer = quantizer; + } +} + +Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + const at::Tensor& scale = quantizer.attr("scale").cast(); + const at::Tensor& amax = quantizer.attr("amax").cast(); + const DType type = quantizer.attr("dtype").cast(); + + this->amax = amax; + this->scale = scale; + this->dtype = type; +} + +std::pair NoneQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + at::TensorOptions opts; + opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA); + std::vector torch_shape; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + } + at::Tensor ret; + if (rowwise_data.has_value()) { + ret = std::move(*rowwise_data); + } else { + ret = at::empty(torch_shape, opts); + } + + TensorWrapper tensor; + tensor.set_rowwise_data(ret.data_ptr(), dtype, shape); + return {std::move(tensor), py::cast(ret)}; +} + +void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { + tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); + tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair Float8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector rowwise_torch_shape; + std::vector columnwise_torch_shape; + + if (!shape.empty()) { + columnwise_torch_shape.emplace_back(static_cast(shape.back())); + } + for (size_t i = 0; i < shape.size(); ++i) { + if (i < shape.size() - 1) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } + rowwise_torch_shape.emplace_back(static_cast(shape[i])); + } + at::TensorOptions opts; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(rowwise_torch_shape, opts); + } + } + const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); + at::Tensor columnwise_data; + bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported(); + if (create_transpose) { + columnwise_data = at::empty(columnwise_torch_shape, opts); + } + const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); + opts = opts.dtype(torch::kFloat32); + at::Tensor scale_inv = at::reciprocal(scale); + py::object ret; + if (internal) { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } else { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), + "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } + TensorWrapper tensor(this->get_scaling_mode()); + if (rowwise_usage) { + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (create_transpose) { + std::vector transposed_shape; + for (auto s : columnwise_torch_shape) { + transposed_shape.emplace_back(static_cast(s)); + } + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); + tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + this->set_quantization_params(&tensor); + return {std::move(tensor), std::move(ret)}; +} + +Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer) + : Quantizer(quantizer) { + const at::Tensor& scale = quantizer.attr("scale").cast(); + const at::Tensor& amax = quantizer.attr("amax").cast(); + const DType type = quantizer.attr("dtype").cast(); + // For current scaling, need several other components: + // 1. with_amax_reduction: bool + // 2. amax_reduction_group: torch.distributed.ProcessGroup or None + // 3. amax_reduction_size: int + const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); + const py::object amax_reduction_group_obj = quantizer.attr("amax_reduction_group"); + const c10::intrusive_ptr amax_reduction_group = + amax_reduction_group_obj.is_none() + ? nullptr + : amax_reduction_group_obj.cast>(); + const int amax_reduction_size = quantizer.attr("amax_reduction_size").cast(); + + this->amax = amax; + this->scale = scale; + this->dtype = type; + this->with_amax_reduction = with_amax_reduction; + this->amax_reduction_group = amax_reduction_group; + this->amax_reduction_size = amax_reduction_size; + + // fp8 current scaling specific quantization params + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); +} + +void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tensor) const { + // transfer amax and scale pointer from quantizer to output tensor (only as gpu buffer, no meaningful data in them) + tensor->set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); + tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + // quantize output and its transpose + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair Float8CurrentScalingQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector rowwise_torch_shape; + std::vector columnwise_torch_shape; + std::vector scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv + + if (!shape.empty()) { + columnwise_torch_shape.emplace_back(static_cast(shape.back())); + } + for (size_t i = 0; i < shape.size(); ++i) { + if (i < shape.size() - 1) { + columnwise_torch_shape.emplace_back(static_cast(shape[i])); + } + rowwise_torch_shape.emplace_back(static_cast(shape[i])); + } + at::TensorOptions opts; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(rowwise_torch_shape, opts); + } + } + const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); + at::Tensor columnwise_data; + bool create_transpose = columnwise_usage && !non_tn_fp8_gemm_supported(); + if (create_transpose) { + columnwise_data = at::empty(columnwise_torch_shape, opts); + } + const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); + + // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. + at::Tensor scale_inv = at::reciprocal(scale); + + py::object ret; + if (internal) { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } else { + py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); + ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), + "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, + "quantizer"_a = this->quantizer); + } + TensorWrapper tensor(this->get_scaling_mode()); + if (rowwise_usage) { + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (create_transpose) { + std::vector transposed_shape; + for (auto s : columnwise_torch_shape) { + transposed_shape.emplace_back(static_cast(s)); + } + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); + tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } + this->set_quantization_params(&tensor); + return {std::move(tensor), std::move(ret)}; +} + +MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); +} + +void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair MXFP8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector torch_shape; + size_t numel = 1; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + numel *= s; + } + + TensorWrapper tensor(NVTE_MXFP8_1D_SCALING); + at::TensorOptions opts; + at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, + columnwise_scale_inv; // TODO(pgadzinski) - change + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + auto last_dim = static_cast(torch_shape.back()); + + NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, + "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + " (got shape=", torch_shape, ")"); + + at::Tensor data; + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data = std::move(*rowwise_data); + } else { + data = at::empty(torch_shape, opts); + } + auto sinv0 = roundup(numel / last_dim, 128); + auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4); + rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{sinv0, sinv1}); + } + + if (columnwise_usage) { + auto sinv0 = roundup(numel / (last_dim * MXFP8_BLOCK_SIZE), 4); + auto sinv1 = roundup(last_dim, 128); + columnwise_data = at::empty(torch_shape, opts); + columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts); + + tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); + tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, + std::vector{sinv0, sinv1}); + } + this->set_quantization_params(&tensor); + + py::object ret; + if (internal) { + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } else { + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); + ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, + "rowwise_scale_inv"_a = rowwise_scale_inv, + "columnwise_scale_inv"_a = columnwise_scale_inv, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } + + return {std::move(tensor), std::move(ret)}; +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index a130169fe7..e8a31da99a 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -9,20 +9,22 @@ #include +#include "common/common.h" #include "extensions.h" -void fused_amax_and_scale_update_after_reduction( - const at::Tensor &amax_reduction_buffer, std::vector amax_histories, - std::vector scales, std::vector scale_invs, - const std::string &amax_compute_algo, transformer_engine::DType fp8_dtype, float margin) { +void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer, + std::vector amax_histories, + std::vector scales, + const std::string &amax_compute_algo, + transformer_engine::DType fp8_dtype, + float margin) { using namespace transformer_engine; + using namespace transformer_engine::pytorch; size_t num_tensors = amax_histories.size(); std::vector t_amax_histories(num_tensors); std::vector t_scales(num_tensors); - std::vector t_scale_invs(num_tensors); std::vector te_amax_histories(num_tensors); std::vector te_scales(num_tensors); - std::vector te_scale_invs(num_tensors); for (size_t i = 0; i < num_tensors; i++) { t_amax_histories[i].data.dptr = amax_histories[i].data_ptr(); auto amax_sizes = amax_histories[i].sizes().vec(); @@ -36,18 +38,11 @@ void fused_amax_and_scale_update_after_reduction( t_scales[i].data.shape = scale_shape; t_scales[i].data.dtype = DType::kFloat32; - t_scale_invs[i].data.dptr = scale_invs[i].data_ptr(); - auto scale_inv_sizes = scale_invs[i].sizes().vec(); - std::vector scale_inv_shape{scale_inv_sizes.begin(), scale_inv_sizes.end()}; - t_scale_invs[i].data.shape = scale_inv_shape; - t_scale_invs[i].data.dtype = DType::kFloat32; - te_amax_histories[i] = reinterpret_cast(&t_amax_histories[i]); te_scales[i] = reinterpret_cast(&t_scales[i]); - te_scale_invs[i] = reinterpret_cast(&t_scale_invs[i]); } nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( makeTransformerEngineTensor(amax_reduction_buffer).data(), te_amax_histories, te_scales, - te_scale_invs, amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, + amax_compute_algo.c_str(), static_cast(fp8_dtype), margin, at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/extensions/softmax.cpp b/transformer_engine/pytorch/csrc/extensions/softmax.cpp index acb68543d8..02f8fcbdf6 100644 --- a/transformer_engine/pytorch/csrc/extensions/softmax.cpp +++ b/transformer_engine/pytorch/csrc/extensions/softmax.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ @@ -7,7 +7,7 @@ #include "extensions.h" at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), @@ -38,7 +38,7 @@ at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor) { at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grad_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -65,7 +65,7 @@ at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_r } at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || @@ -105,7 +105,7 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, floa at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grad_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -132,7 +132,7 @@ at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor so } at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 3, "expected 3D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || @@ -159,7 +159,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float sc at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grads_.contiguous(); auto softmax_results = softmax_results_.contiguous(); @@ -188,7 +188,7 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, } at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; AT_ASSERTM(input.dim() == 4, "expected 4D tensor"); AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) || (input.scalar_type() == at::ScalarType::BFloat16), @@ -220,7 +220,7 @@ at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_, float scale_factor) { - using namespace transformer_engine; + using namespace transformer_engine::pytorch; auto output_grads = output_grad_.contiguous(); auto softmax_results = softmax_results_.contiguous(); diff --git a/transformer_engine/pytorch/csrc/extensions/swizzle.cpp b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp new file mode 100644 index 0000000000..b127b5d75b --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/swizzle.cpp @@ -0,0 +1,120 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "extensions.h" +#include "transformer_engine/transformer_engine.h" + +void swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { + using namespace transformer_engine::pytorch; + + if (input.scaling_mode() == NVTE_INVALID_SCALING) { + NVTE_ERROR("Invalid scaling mode for swizzle."); + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { + return; + } + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + NVTEBasicTensor scale_inv; + if (rowwise) { + scale_inv = input.get_rowwise_scale_inv(); + } else { + scale_inv = input.get_columnwise_scale_inv(); + } + + auto input_shape = nvte_shape_to_vector(input.shape()); + auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + + // Allocate memory for swizzled output. + auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); + std::vector scale_inv_shape_int; + for (size_t i = 0; i < scale_inv_shape.size(); ++i) { + scale_inv_shape_int.push_back(static_cast(scale_inv_shape[i])); + } + auto swizzled_scale_inv = at::empty(scale_inv_shape_int, options); + void* scale_inv_dptr = scale_inv.data_ptr; + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + // Reconstruct input only to avoid swizzling both directions if not needed. + // Use any 8 bit type, it's irrelevant. + transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + input_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } else { + input_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + output_cu.set_columnwise_data(input.dptr(), DType::kFloat8E4M3, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, + scale_inv_shape); + } + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + if (rowwise) { + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } else { + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape); + } +} + +at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) { + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA); + auto swizzled_scale_inv = at::empty_like(scale_inv, options); + + void* scale_inv_dptr = getDataPtr(scale_inv, 0); + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + auto input_cu = makeTransformerEngineTensor(input.data_ptr(), getTensorShape(input), + DType::kFloat8E4M3, nullptr, nullptr, scale_inv_dptr, + getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING); + auto output_cu = makeTransformerEngineTensor( + input.data_ptr(), getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + swizzled_scale_inv_dptr, getTensorShape(swizzled_scale_inv), NVTE_MXFP8_1D_SCALING); + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return swizzled_scale_inv; +} + +at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) { + using namespace transformer_engine::pytorch; + + NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + + auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA); + auto swizzled_scale_inv = at::empty_like(scale_inv, options); + + // Return immediately if tensor is empty + if (scale_inv.numel() == 0) { + return swizzled_scale_inv; + } + + void* scale_inv_dptr = getDataPtr(scale_inv, 0); + void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + + auto input_cu = makeTransformerEngineTensor( + nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + nullptr, scale_inv_dptr, {1}, getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING); + auto output_cu = makeTransformerEngineTensor( + nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr, + nullptr, swizzled_scale_inv_dptr, {1}, getTensorShape(swizzled_scale_inv), + NVTE_MXFP8_1D_SCALING); + + // Launch kernel + nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); + + return swizzled_scale_inv; +} diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index 56f6b56769..37fbddcc18 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -1,368 +1,107 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ -#include "extensions.h" - -void fused_cast_transpose(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - at::Tensor input_cast, at::Tensor input_transpose, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); +#include - auto input_cu = makeTransformerEngineTensor(input); - auto output_cast_cu = - makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - auto output_transpose_cu = - makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, amax.data_ptr(), - scale.data_ptr(), scale_inv.data_ptr()); - - nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), - at::cuda::getCurrentCUDAStream()); -} - -void fused_cast_transpose_noop(at::Tensor input, at::Tensor noop, at::Tensor scale, at::Tensor amax, - at::Tensor scale_inv, at::Tensor input_cast, - at::Tensor input_transpose, transformer_engine::DType otype, - int scale_offset, int amax_offset, int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto input_cu = makeTransformerEngineTensor(input); - auto noop_cu = makeTransformerEngineTensor(noop); - auto output_cast_cu = makeTransformerEngineTensor(input_cast.data_ptr(), {M, N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto output_transpose_cu = makeTransformerEngineTensor(input_transpose.data_ptr(), {N, M}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - - // Launch kernel - nvte_cast_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cast_cu.data(), - output_transpose_cu.data(), at::cuda::getCurrentCUDAStream()); -} - -std::vector fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset, int amax_offset, - int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); +#include "ATen/core/TensorBody.h" +#include "extensions.h" - // Allocate output tensors - DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); - auto grad_output_cast = - allocateTorchTensor(grad_output.size(0), grad_output.size(1), DType::kByte); - auto grad_output_transpose = - allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); +std::vector fused_multi_quantize(std::vector input_list, + std::optional> output_list, + std::vector quantizer_list, + transformer_engine::DType otype) { + using namespace transformer_engine::pytorch; + std::vector nvte_tensor_input_list; + std::vector nvte_tensor_output_list; + std::vector py_output_objects_list; + std::vector tensor_wrappers; + auto none = py::none(); + + // create TE tensors from input + for (int i = 0; i < input_list.size(); i++) { + auto input_tensor = makeTransformerEngineTensor(input_list[i], none); + const NVTEShape input_shape = input_tensor.shape(); + + transformer_engine::TensorWrapper output_tensor; + + if (output_list == std::nullopt) { + std::unique_ptr quantizer = convert_quantizer(quantizer_list[i]); + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + py::object o; + std::tie(output_tensor, o) = quantizer->create_tensor(output_shape, otype); + py_output_objects_list.push_back(o); + } else { + output_tensor = makeTransformerEngineTensor((*output_list)[i], quantizer_list[i]); + } + if (input_tensor.numel() == 0) continue; - // Return immediately if tensors are empty - if (M == 0 || N == 0) { - return {grad_bias.zero_(), grad_output_cast, grad_output_transpose}; + nvte_tensor_output_list.emplace_back(output_tensor.data()); + nvte_tensor_input_list.emplace_back(input_tensor.data()); + tensor_wrappers.emplace_back(std::move(input_tensor)); + tensor_wrappers.emplace_back(std::move(output_tensor)); } - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto input_cu = makeTransformerEngineTensor(grad_output); - auto cast_output_cu = makeTransformerEngineTensor(grad_output_cast.data_ptr(), {M, N}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor( - grad_output_transpose.data_ptr(), {N, M}, otype, amax_dptr, scale_dptr, scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), transposed_output_cu.data(), - dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), transposed_output_cu.data(), - dbias_cu.data(), workspace.data(), at::cuda::getCurrentCUDAStream()); - - return {grad_bias, grad_output_cast, grad_output_transpose}; -} - -std::vector fused_fp8_transpose_bgrad(at::Tensor grad_output, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - transformer_engine::DType grad_bias_type, - int scale_offset, int amax_offset, - int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_bias_type); - auto grad_output_transpose = - allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); - auto input_cu = makeTransformerEngineTensor(grad_output.data_ptr(), {M, N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor( - grad_output_transpose.data_ptr(), {N, M}, otype, amax_dptr, scale_dptr, scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - nvte_fp8_transpose_dbias(input_cu.data(), transposed_output_cu.data(), dbias_cu.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - - return {grad_bias, grad_output_transpose}; -} - -std::vector fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output, - at::Tensor gelu_input, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, - transformer_engine::DType otype, - int scale_offset, int amax_offset, - int scale_inv_offset) { - using namespace transformer_engine; - - // Tensor dimensions - size_t M = static_cast(grad_output.size(0)); - size_t N = static_cast(grad_output.size(1)); - - // Get pointers for FP8 scale, amax, scale-inverse - void* scale_dptr = getDataPtr(scale, scale_offset); - void* amax_dptr = getDataPtr(amax, amax_offset); - void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset); - - // Construct Transformer Engine tensors - DType grad_output_type = GetTransformerEngineDType(grad_output.scalar_type()); - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_output_type); - auto dgelu = allocateTorchTensor(grad_output.size(0), grad_output.size(1), DType::kByte); - auto dgelu_transpose = - allocateTorchTensor(grad_output.size(1), grad_output.size(0), DType::kByte); - auto gelu_input_cu = makeTransformerEngineTensor(gelu_input); - auto input_cu = makeTransformerEngineTensor(grad_output); - auto cast_output_cu = makeTransformerEngineTensor(dgelu.data_ptr(), {M, N}, otype, amax_dptr, - scale_dptr, scale_inv_dptr); - auto transposed_output_cu = makeTransformerEngineTensor(dgelu_transpose.data_ptr(), {N, M}, otype, - amax_dptr, scale_dptr, scale_inv_dptr); - auto dbias_cu = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - at::cuda::getCurrentCUDAStream()); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(), - transposed_output_cu.data(), dbias_cu.data(), workspace.data(), - at::cuda::getCurrentCUDAStream()); - - return {grad_bias, dgelu, dgelu_transpose}; -} - -void fused_multi_cast_transpose_base(std::vector input_list, - std::vector scale_dptr_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_dptr_list, - std::vector scale_inv_dptr_list, - transformer_engine::DType otype) { - using namespace transformer_engine; - - // Extract properties from PyTorch tensors - std::vector input_dptr_list, cast_output_dptr_list, transposed_output_dptr_list; - std::vector> input_shape_list, cast_output_shape_list, - transposed_output_shape_list; - std::vector input_type_list, cast_output_type_list, - transposed_output_type_list; - auto extract_tensor_props_skip_dtype = [](at::Tensor& tensor, std::vector& dptr_list, - std::vector>& shape_list) { - dptr_list.push_back(tensor.data_ptr()); - shape_list.push_back({}); - for (int d = 0; d < tensor.dim(); ++d) { - shape_list.back().push_back(tensor.size(d)); + // Check tensor lists + NVTE_CHECK(nvte_tensor_output_list.size() == nvte_tensor_input_list.size(), + "Number of input and output tensors must match"); + + // Choose implementation + // Note: Currently only have fused kernel for FP8 cast-transpose + bool with_fused_kernel = true; + for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + const auto& tensor = nvte_tensor_output_list[i]; + if (nvte_tensor_scaling_mode(tensor) != NVTE_DELAYED_TENSOR_SCALING) { + with_fused_kernel = false; + break; } - }; - auto extract_tensor_props = [](at::Tensor& tensor, std::vector& dptr_list, - std::vector>& shape_list, - std::vector& type_list) { - dptr_list.push_back(tensor.data_ptr()); - shape_list.push_back({}); - for (int d = 0; d < tensor.dim(); ++d) { - shape_list.back().push_back(tensor.size(d)); + if (nvte_tensor_columnwise_data(tensor) == nullptr) { + with_fused_kernel = false; + break; } - type_list.push_back(GetTransformerEngineDType(tensor.scalar_type())); - }; - for (size_t tensor_id = 0; tensor_id < input_list.size(); ++tensor_id) { - extract_tensor_props(input_list[tensor_id], input_dptr_list, input_shape_list, input_type_list); - extract_tensor_props_skip_dtype(cast_output_list[tensor_id], cast_output_dptr_list, - cast_output_shape_list); - cast_output_type_list.push_back(otype); - extract_tensor_props_skip_dtype(transposed_output_list[tensor_id], transposed_output_dptr_list, - transposed_output_shape_list); - transposed_output_type_list.push_back(otype); } - // Construct TE tensors - std::vector nvte_input_list, nvte_cast_output_list, nvte_transposed_output_list; - std::vector tensor_wrappers; - auto make_tensor = [&tensor_wrappers](void* dptr, const std::vector& shape, - transformer_engine::DType dtype, void* amax_dptr, - void* scale_dptr, void* scale_inv_dptr) -> NVTETensor { - tensor_wrappers.emplace_back( - makeTransformerEngineTensor(dptr, shape, dtype, amax_dptr, scale_dptr, scale_inv_dptr)); - return tensor_wrappers.back().data(); - }; - for (size_t i = 0; i < input_dptr_list.size(); ++i) { - if (input_dptr_list[i] == nullptr) continue; - nvte_input_list.emplace_back(make_tensor(input_dptr_list[i], input_shape_list[i], - input_type_list[i], nullptr, nullptr, nullptr)); - nvte_cast_output_list.emplace_back( - make_tensor(cast_output_dptr_list[i], cast_output_shape_list[i], cast_output_type_list[i], - amax_dptr_list[i], scale_dptr_list[i], scale_inv_dptr_list[i])); - nvte_transposed_output_list.emplace_back( - make_tensor(transposed_output_dptr_list[i], transposed_output_shape_list[i], - transposed_output_type_list[i], amax_dptr_list[i], scale_dptr_list[i], - scale_inv_dptr_list[i])); - } - - // Check tensor lists - NVTE_CHECK(nvte_cast_output_list.size() == nvte_input_list.size(), - "Number of input and C output tensors must match"); - NVTE_CHECK(nvte_transposed_output_list.size() == nvte_input_list.size(), - "Number of input and T output tensors must match"); - // Launch TE kernel - nvte_multi_cast_transpose(nvte_input_list.size(), nvte_input_list.data(), - nvte_cast_output_list.data(), nvte_transposed_output_list.data(), - at::cuda::getCurrentCUDAStream()); -} - -void fused_multi_cast_transpose(std::vector input_list, - std::vector scale_list, - std::vector cast_output_list, - std::vector transposed_output_list, - std::vector amax_list, - std::vector scale_inv_list, - transformer_engine::DType otype) { - std::vector scale_dptr_list, amax_dptr_list, scale_inv_dptr_list; - for (size_t i = 0; i < scale_list.size(); ++i) { - scale_dptr_list.push_back(scale_list[i].data_ptr()); - amax_dptr_list.push_back(amax_list[i].data_ptr()); - scale_inv_dptr_list.push_back(scale_inv_list[i].data_ptr()); + if (with_fused_kernel) { + nvte_multi_cast_transpose(nvte_tensor_input_list.size(), nvte_tensor_input_list.data(), + nvte_tensor_output_list.data(), at::cuda::getCurrentCUDAStream()); + } else { + for (size_t i = 0; i < nvte_tensor_output_list.size(); i++) { + nvte_quantize(nvte_tensor_input_list[i], nvte_tensor_output_list[i], + at::cuda::getCurrentCUDAStream()); + } } - - fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list, - transposed_output_list, amax_dptr_list, scale_inv_dptr_list, - otype); + return py_output_objects_list; } -std::tuple, std::vector> fused_multi_cast_transpose_alloc( - std::vector input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - std::vector scale_indices, std::vector amax_indices, - std::vector scale_inv_indices, transformer_engine::DType otype) { - using namespace transformer_engine; +at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype, + std::optional output) { + using namespace transformer_engine::pytorch; - std::vector cast_output_list; - std::vector transposed_output_list; - std::vector scale_dptr_list, amax_dptr_list, scale_inv_dptr_list; - for (size_t i = 0; i < input_list.size(); ++i) { - auto input_i = input_list[i]; - // construct cast output tensors - auto cast_output_i = allocateTorchTensor(input_i.size(0), input_i.size(1), DType::kByte); - cast_output_list.push_back(cast_output_i); - // construct transposed output tensors - auto transposed_output_i = allocateTorchTensor(input_i.size(1), input_i.size(0), DType::kByte); - transposed_output_list.push_back(transposed_output_i); - // construct amax/scale/scale_inv dptr lists - amax_dptr_list.push_back(getDataPtr(amax, amax_indices[i])); - scale_dptr_list.push_back(getDataPtr(scale, scale_indices[i])); - scale_inv_dptr_list.push_back(getDataPtr(scale_inv, scale_inv_indices[i])); - } - - fused_multi_cast_transpose_base(input_list, scale_dptr_list, cast_output_list, - transposed_output_list, amax_dptr_list, scale_inv_dptr_list, - otype); - - return std::make_tuple(std::move(cast_output_list), std::move(transposed_output_list)); -} + const auto dim = input.dim(); + NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); -at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype) { - using namespace transformer_engine; + if (input.dim() > 2) { + input = input.view({-1, input.size(dim - 1)}); + } size_t M = static_cast(input.size(0)); size_t N = static_cast(input.size(1)); - auto output = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); - if (M == 0 || N == 0) return output; - - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); - - nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); - - return output; -} - -void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); + at::Tensor out; + if (output.has_value()) { + out = *output; + } else { + out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); + } + if (M == 0 || N == 0) return out; auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); + auto output_cu = makeTransformerEngineTensor(out.data_ptr(), {N, M}, otype); nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); -} - -void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop, - transformer_engine::DType otype) { - using namespace transformer_engine; - - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - - auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, otype); - auto noop_cu = makeTransformerEngineTensor(noop); - auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, M}, otype); - nvte_transpose_with_noop(input_cu.data(), noop_cu.data(), output_cu.data(), - at::cuda::getCurrentCUDAStream()); + return out; } diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp new file mode 100644 index 0000000000..d5654fb43a --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -0,0 +1,89 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include "common.h" +#include "pybind.h" + +namespace transformer_engine::pytorch { +namespace detail { + +TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer) { + auto ret = TensorWrapper(quantizer->get_scaling_mode()); + + bool data_exists = !tensor.attr("_data").is_none(); + bool transpose_exists = + !tensor.attr("_transpose_invalid").cast() && !tensor.attr("_transpose").is_none(); + + NVTE_CHECK(data_exists || transpose_exists, "No data found for FP8 Tensor."); + + // FP8 data + const DType fp8_dtype = tensor.attr("_fp8_dtype").cast(); + if (data_exists) { + const auto &data = tensor.attr("_data").cast(); + ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); + } + + // FP8 data transpose + if (transpose_exists) { + const auto &data_transpose = tensor.attr("_transpose").cast(); + ret.set_columnwise_data(data_transpose.data_ptr(), fp8_dtype, getTensorShape(data_transpose)); + } + + // Scale-inverse + { + const auto &scale_inv = tensor.attr("_scale_inv").cast(); + float *dptr = reinterpret_cast(scale_inv.data_ptr()); + const auto &dtype = GetTransformerEngineDType(scale_inv.scalar_type()); + const auto &shape = getTensorShape(scale_inv); + ret.set_rowwise_scale_inv(dptr, dtype, shape); + ret.set_columnwise_scale_inv(dptr, dtype, shape); + } + + // Quantizer state + quantizer->set_quantization_params(&ret); + + return ret; +} + +TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) { + auto ret = TensorWrapper(NVTE_MXFP8_1D_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for MXFP8 Tensor."); + + // Row-scaled data + const DType fp8_dtype = tensor.attr("_fp8_dtype").cast(); + if (rowwise_usage) { + const auto &data = tensor.attr("_rowwise_data").cast(); + const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); + ret.set_rowwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, getTensorShape(scale_inv)); + } + + // Column-scaled data + if (columnwise_usage) { + const auto &data = tensor.attr("_columnwise_data").cast(); + const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); + ret.set_columnwise_data(data.data_ptr(), fp8_dtype, getTensorShape(data)); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E8M0, + getTensorShape(scale_inv)); + } + + // Quantizer state + quantizer->set_quantization_params(&ret); + + return ret; +} + +} // namespace detail + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/util.cpp b/transformer_engine/pytorch/csrc/extensions/util.cpp new file mode 100644 index 0000000000..5f49383d11 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/util.cpp @@ -0,0 +1,14 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "util.h" + +#include "ATen/cuda/CUDAContextLight.h" + +bool non_tn_fp8_gemm_supported() { + int major = at::cuda::getCurrentDeviceProperties()->major; + return major >= 10; +} diff --git a/transformer_engine/pytorch/csrc/kv_cache.cuh b/transformer_engine/pytorch/csrc/kv_cache.cuh new file mode 100644 index 0000000000..e79690d215 --- /dev/null +++ b/transformer_engine/pytorch/csrc/kv_cache.cuh @@ -0,0 +1,145 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_KV_CACHE_CUH_ + +namespace transformer_engine { +namespace fused_attn { +template +__global__ void convert_thd_to_bshd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens, + int b, int max_seq_len, int h, int d) { + // tensor: thd; new_tensor: bshd + // cu_seqlens: [b + 1] + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int num_elts = (cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]) * h * d; + int thd_offset = cu_seqlens[batch_idx] * h * d; + int bshd_offset = batch_idx * max_seq_len * h * d; + scalar_t *thd_token = tensor + thd_offset; + scalar_t *bshd_token = new_tensor + bshd_offset; + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(bshd_token + i) = *(thd_token + i); + } + } +} + +template +__global__ void convert_bshd_to_thd_kernel(scalar_t *tensor, scalar_t *new_tensor, int *cu_seqlens, + int b, int max_seq_len, int h, int d) { + // tensor: bshd; new_tensor: thd + // cu_seqlens: [b + 1] + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int seqlen = cu_seqlens[batch_idx + 1] - cu_seqlens[batch_idx]; + int num_elts = seqlen * h * d; + int bshd_offset = batch_idx * max_seq_len * h * d; + int thd_offset = cu_seqlens[batch_idx] * h * d; + scalar_t *bshd_token = tensor + bshd_offset; + scalar_t *thd_token = new_tensor + thd_offset; + for (int i = threadIdx.x; i < num_elts; i += blockDim.x) { + *(thd_token + i) = *(bshd_token + i); + } + } +} + +template +__global__ void reindex_kv_cache_kernel(scalar_t *k_cache, scalar_t *v_cache, int *batch_indices, + int *cu_new_lens, int *cu_cached_lens, int h_kv, int d_k, + int d_v, int b, int max_seq_len) { + // k_cache, v_cache: bshd + // batch_indices: [b]; cu_new_lens, cu_cached_lens: [b + 1] + int actual_b = b; + for (int i = 0; i < b - 1; i++) { + if (batch_indices[i + 1] < batch_indices[i]) { + actual_b = i + 1; + } + } + for (int batch_idx = 0; batch_idx < actual_b; batch_idx++) { + int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; + for (int token_idx = blockIdx.x; token_idx < cached_len - new_len; token_idx += gridDim.x) { + int num_elts_k = h_kv * d_k; + int num_elts_v = h_kv * d_v; + int k_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_k; + int k_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_k; + int v_cache_src_offset = (batch_indices[batch_idx] * max_seq_len + token_idx) * h_kv * d_v; + int v_cache_des_offset = (batch_idx * max_seq_len + token_idx) * h_kv * d_v; + for (int i = threadIdx.x; i < num_elts_k; i += blockDim.x) { + *(k_cache + k_cache_des_offset + i) = *(k_cache + k_cache_src_offset + i); + } + for (int i = threadIdx.x; i < num_elts_v; i += blockDim.x) { + *(v_cache + v_cache_des_offset + i) = *(v_cache + v_cache_src_offset + i); + } + } + } +} + +template +__global__ void copy_to_kv_cache_kernel(scalar_t *new_k, scalar_t *new_v, scalar_t *k_cache, + scalar_t *v_cache, int *page_table, int *cu_new_lens, + int *cu_cached_lens, NVTE_QKV_Format qkv_format, int h_kv, + int d_k, int d_v, int b, int max_ctx_len, int max_seq_len, + int max_pages_per_seq, bool is_non_paged) { + // new_k, new_v: qkv_format; k_cache, v_cache: bshd + // cu_new_lens, cu_cached_lens: [b + 1] + // page_table: [b, max_pages_per_seq] + int page_size = max_seq_len / max_pages_per_seq; + if (qkv_format == NVTE_QKV_Format::NVTE_BSHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; + int new_token_offset = batch_idx * max_ctx_len; + int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; + for (int i = threadIdx.x; i < new_len; i += blockDim.x) { + int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; + int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; + for (int j = 0; j < h_kv * d_k; j++) { + *(k_cache + token_idx * h_kv * d_k + j) = + *(new_k + (new_token_offset + i) * h_kv * d_k + j); + } + for (int j = 0; j < h_kv * d_v; j++) { + *(v_cache + token_idx * h_kv * d_v + j) = + *(new_v + (new_token_offset + i) * h_kv * d_v + j); + } + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; + int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; + for (int i = threadIdx.x; i < new_len; i += blockDim.x) { + int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; + int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; + for (int j = 0; j < h_kv * d_k; j++) { + *(k_cache + token_idx * h_kv * d_k + j) = *(new_k + (i * b + batch_idx) * h_kv * d_k + j); + } + for (int j = 0; j < h_kv * d_v; j++) { + *(v_cache + token_idx * h_kv * d_v + j) = *(new_v + (i * b + batch_idx) * h_kv * d_v + j); + } + } + } + } else if (qkv_format == NVTE_QKV_Format::NVTE_THD) { + for (int batch_idx = blockIdx.x; batch_idx < b; batch_idx += gridDim.x) { + int *page_list = is_non_paged ? nullptr : page_table + batch_idx * max_pages_per_seq; + int cached_len = cu_cached_lens[batch_idx + 1] - cu_cached_lens[batch_idx]; + int new_len = cu_new_lens[batch_idx + 1] - cu_new_lens[batch_idx]; + for (int i = threadIdx.x; i < new_len; i += blockDim.x) { + int page_idx = is_non_paged ? batch_idx : page_list[(cached_len - new_len + i) / page_size]; + int token_idx = page_idx * page_size + (cached_len - new_len + i) % page_size; + for (int j = 0; j < h_kv * d_k; j++) { + *(k_cache + token_idx * h_kv * d_k + j) = + *(new_k + (cu_new_lens[batch_idx] + i) * h_kv * d_k + j); + } + for (int j = 0; j < h_kv * d_v; j++) { + *(v_cache + token_idx * h_kv * d_v + j) = + *(new_v + (cu_new_lens[batch_idx] + i) * h_kv * d_v + j); + } + } + } + } +} +} // namespace fused_attn +} // namespace transformer_engine +#endif diff --git a/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh b/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh index e85ec3afc2..f7598da45a 100644 --- a/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh +++ b/transformer_engine/pytorch/csrc/multi_tensor_apply.cuh @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h new file mode 100644 index 0000000000..b0f55d7598 --- /dev/null +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -0,0 +1,80 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#define PYBIND11_DETAILED_ERROR_MESSAGES // TODO remove + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ +#include +#include +#include +#include + +#include "common.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine::pytorch { + +extern PyTypeObject *Float8TensorPythonClass; +extern PyTypeObject *Float8TensorBasePythonClass; +extern PyTypeObject *Float8QuantizerClass; +extern PyTypeObject *Float8CurrentScalingQuantizerClass; +extern PyTypeObject *MXFP8TensorPythonClass; +extern PyTypeObject *MXFP8TensorBasePythonClass; +extern PyTypeObject *MXFP8QuantizerClass; + +void init_extension(); + +void init_float8_extension(); + +void init_mxfp8_extension(); + +namespace detail { + +inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } + +inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) { + return Py_TYPE(obj) == Float8CurrentScalingQuantizerClass; +} + +inline bool IsFloat8Tensor(PyObject *obj) { + return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass; +} + +inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } + +inline bool IsMXFP8Tensor(PyObject *obj) { + return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; +} + +TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); + +template +std::unique_ptr CreateQuantizer(const py::handle quantizer) { + return std::make_unique(quantizer); +} + +TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantization_params); + +std::unique_ptr CreateMXFP8Params(const py::handle params); + +inline bool IsFloatingPointType(at::ScalarType type) { + return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; +} + +constexpr std::array custom_types_converters = { + std::make_tuple(IsFloat8Tensor, IsFloat8Quantizers, NVTETensorFromFloat8Tensor, + CreateQuantizer), + std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor, + CreateQuantizer), + std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, + CreateQuantizer)}; + +} // namespace detail + +} // namespace transformer_engine::pytorch + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_PYBIND_H_ diff --git a/transformer_engine/pytorch/csrc/thd_utils.cuh b/transformer_engine/pytorch/csrc/thd_utils.cuh new file mode 100644 index 0000000000..1f1f0cfdfd --- /dev/null +++ b/transformer_engine/pytorch/csrc/thd_utils.cuh @@ -0,0 +1,302 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ +#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_ +#define TRANSFORMER_ENGINE_FUSED_ATTN_THD_UTILS_CUH_ + +#include +#include +#include + +struct LseCorrectionFunctor { + __forceinline__ __device__ static void run(double *lse, float *half_lse, size_t idx, + size_t half_idx) { + double val = lse[idx]; + float val_per_step = half_lse[half_idx]; + double max_scale = max(val, val_per_step); + double min_scale = min(val, val_per_step); + lse[idx] = max_scale + log(1.0 + exp(min_scale - max_scale)); + } +}; + +struct ReadLseFunctor { + __forceinline__ __device__ static void run(float *lse, float *half_lse, size_t idx, + size_t half_idx) { + half_lse[half_idx] = lse[idx]; + } +}; + +struct EmptyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) {} +}; + +struct CopyFunctor { + __forceinline__ __device__ static void run(void *token, void *token_per_step, int idx) { + reinterpret_cast(token)[idx] = reinterpret_cast(token_per_step)[idx]; + } +}; + +template +struct AddFunctor { + __forceinline__ __device__ static void run(dtype *token, dtype *token_per_step, int idx) { + float4 d_ = reinterpret_cast(token)[idx]; + dtype *p_ = reinterpret_cast(&d_); + + float4 d = reinterpret_cast(token_per_step)[idx]; + dtype *p = reinterpret_cast(&d); + +#pragma unroll + for (int i = 0; i < sizeof(float4) / sizeof(dtype); i++) { + p_[i] += p[i]; + } + + reinterpret_cast(token)[idx] = d_; + } +}; + +namespace transformer_engine { +namespace fused_attn { + +/*************************************************************************************************** + * Support THD format for Context Parallel: Binary search an array for a target value + **************************************************************************************************/ + +__forceinline__ __device__ int binary_search(int target, int *array, int len) { + int left = 1, right = len - 1; + while (left < right) { + int mid = (left + right) / 2; + if (array[mid] <= target) { + left = mid + 1; + } else { + right = mid; + } + } + return left - 1; +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Generate partitioned indices for input tokens + **************************************************************************************************/ +__global__ void thd_partition_indices_kernel(int *output, int *cu_seqlens, int batch, + int total_tokens, int world_size, int rank) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + int seqlen = cu_seqlens[i]; + // Currently we assume that each sequence length is divisible by (world_size*2) since we have + // to distribute each sequence evenly to different GPUs. + assert(seqlen % (world_size * 2) == 0); + cu_seqlens_s[i] = seqlen / world_size; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + + for (int token_id = tid; token_id < total_tokens / world_size; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + int index = token_id - cu_seqlens_s[seq_id]; + int offset = index < seq_len / 2 ? rank : (world_size - 1) * 2 - rank; + index += cu_seqlens_s[seq_id] * world_size + seq_len / 2 * offset; + output[token_id] = index; + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Read the half of a THD tensor + **************************************************************************************************/ +__global__ void thd_read_half_tensor_kernel(void *half, void *tensor, int *cu_seqlens, int batch, + int hidden_size_in_bytes, int half_idx, + int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int warpid = (blockIdx.x * blockDim.x + threadIdx.x) / 32; + int laneid = threadIdx.x % 32; + int num_warps = (blockDim.x * gridDim.x) / 32; + int num_total_tokens = cu_seqlens_s[batch]; + int num_float4s_per_token = hidden_size_in_bytes / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size_in_bytes; + half = reinterpret_cast(reinterpret_cast(half) + offset / 2 * blockIdx.y); + tensor = reinterpret_cast(reinterpret_cast(tensor) + offset * blockIdx.y); + + for (int token_id = warpid; token_id < num_total_tokens; token_id += num_warps) { + int seqid = binary_search(token_id, cu_seqlens_s, batch + 1); + + size_t offset_in_bytes = static_cast(token_id) * hidden_size_in_bytes; + float4 *cur_half_token = + reinterpret_cast(reinterpret_cast(half) + offset_in_bytes); + + offset_in_bytes = + (static_cast(token_id) + cu_seqlens_s[seqid + half_idx]) * hidden_size_in_bytes; + float4 *cur_token = + reinterpret_cast(reinterpret_cast(tensor) + offset_in_bytes); + + for (int idx = laneid; idx < num_float4s_per_token; idx += 32) { + cur_half_token[idx] = cur_token[idx]; + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: softmax_lse related operations + **************************************************************************************************/ + +template +__global__ void thd_lse_kernel(lse_dtype *lse, float *half_lse, int *cu_seqlens, int batch, + int num_heads, int lse_seqlen, int second_half_lse_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } + __syncthreads(); + + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int num_threads = blockDim.x * gridDim.x; + int num_total_tokens = cu_seqlens_s[batch]; + + for (int token_id = tid; token_id < num_total_tokens; token_id += num_threads) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, half_idx; + if constexpr (lse_packed) { + idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1]; + half_idx = head_id * second_half_lse_seqlen + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + + idx = row * lse_seqlen + col + seq_len; + half_idx = row * second_half_lse_seqlen + col; + } + + Functor::run(lse, half_lse, idx, half_idx); + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Out correction in forward + **************************************************************************************************/ + +template +__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse, + float *lse_per_step, int *cu_seqlens, int batch, + int num_heads, int dim_per_head, int lse_seqlen, + int lse_per_step_seqlen) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + cu_seqlens_s[i] = cu_seqlens[i] / (only_second_half + 1); + } + __syncthreads(); + + int tile_id = (blockIdx.x * blockDim.x + threadIdx.x) / tile_size; + int lane_id = threadIdx.x % tile_size; + int num_tiles = (blockDim.x * gridDim.x) / tile_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_loops_per_head = dim_per_head * sizeof(dtype) / sizeof(float4); + + for (int token_id = tile_id; token_id < num_total_tokens; token_id += num_tiles) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + for (int head_id = blockIdx.y; head_id < num_heads; head_id += gridDim.y) { + size_t idx, idx_per_step; + + if constexpr (lse_packed) { + idx = head_id * lse_seqlen + token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx_per_step = head_id * lse_per_step_seqlen + token_id; + } else { + size_t row = static_cast(seq_id) * num_heads + head_id; + int col = token_id - cu_seqlens_s[seq_id]; + int seq_len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + idx = row * lse_seqlen + col + seq_len * only_second_half; + idx_per_step = row * lse_per_step_seqlen + col; + } + float lse_corrected_exp = exp(lse_per_step[idx_per_step] - lse[idx]); + + idx = token_id + cu_seqlens_s[seq_id + 1] * only_second_half; + idx = (idx * num_heads + head_id) * dim_per_head; + idx_per_step = (static_cast(token_id) * num_heads + head_id) * dim_per_head; + dtype *cur_out = out + idx; + dtype *cur_out_per_step = out_per_step + idx_per_step; + + for (int j = lane_id; j < num_loops_per_head; j += tile_size) { + float4 data_per_step = reinterpret_cast(cur_out_per_step)[j]; + float4 data = reinterpret_cast(cur_out)[j]; + dtype *p_per_step = reinterpret_cast(&data_per_step); + dtype *p = reinterpret_cast(&data); + for (int k = 0; k < sizeof(float4) / sizeof(dtype); k++) { + p[k] += (p_per_step[k] == 0 ? 0 : p_per_step[k] * lse_corrected_exp); + } + reinterpret_cast(cur_out)[j] = data; + } + } + } +} + +/*************************************************************************************************** + * Support THD format for Context Parallel: Gradients correction in backward + **************************************************************************************************/ + +template +__global__ void thd_grad_correction_kernel(dtype *grad, dtype *grad_per_step, int *cu_seqlens, + int batch, int hidden_size, int dim_size_of_token) { + extern __shared__ int cu_seqlens_s[]; + for (int i = threadIdx.x; i <= batch; i += blockDim.x) { + if constexpr (functor_idx < 2) { + cu_seqlens_s[i] = cu_seqlens[i] / 2; + } else { + cu_seqlens_s[i] = cu_seqlens[i]; + } + } + __syncthreads(); + + int group_id = (blockIdx.x * blockDim.x + threadIdx.x) / group_size; + int lane_id = threadIdx.x % group_size; + int num_groups = (blockDim.x * gridDim.x) / group_size; + int num_total_tokens = cu_seqlens_s[batch]; + int num_inner_loops = hidden_size * sizeof(dtype) / sizeof(float4); + + size_t offset = static_cast(dim_size_of_token) * hidden_size; + if constexpr (functor_idx < 2) { + grad_per_step = grad_per_step + offset / 2 * blockIdx.y; + } else { + grad_per_step = grad_per_step + offset * blockIdx.y; + } + grad = grad + offset * blockIdx.y; + + for (int token_id = group_id; token_id < num_total_tokens; token_id += num_groups) { + int seq_id = binary_search(token_id, cu_seqlens_s, batch + 1); + + int token_offset; + bool is_first_half; + if constexpr (functor_idx < 2) { + token_offset = cu_seqlens_s[seq_id + functor_idx]; + is_first_half = (functor_idx == 0); + } else { + token_offset = 0; + int len = cu_seqlens_s[seq_id + 1] - cu_seqlens_s[seq_id]; + is_first_half = (token_id - cu_seqlens_s[seq_id]) < (len / 2); + } + + dtype *token = &grad[(token_id + token_offset) * static_cast(hidden_size)]; + dtype *token_per_step = &grad_per_step[token_id * static_cast(hidden_size)]; + for (int idx = lane_id; idx < num_inner_loops; idx += group_size) { + if (is_first_half) { + Functor_0::run(token, token_per_step, idx); + } else { + Functor_1::run(token, token_per_step, idx); + } + } + } +} + +} // namespace fused_attn +} // namespace transformer_engine +#endif diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp deleted file mode 100644 index 9f31dba669..0000000000 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ /dev/null @@ -1,414 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include - -#include "common/util/cuda_runtime.h" -#include "common/util/system.h" -#include "extensions.h" - -namespace { -transformer_engine::DType reverse_map_dtype(int64_t dtype) { - if (dtype >= 0 && dtype < static_cast(transformer_engine::DType::kNumTypes)) { - return static_cast(dtype); - } else { - NVTE_ERROR("Type not supported."); - } -} -} // namespace - -at::Tensor cast_to_fp8_ts(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax, - at::Tensor scale_inv, int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = - cast_to_fp8(input, scale, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, fp8_tensor); - return output; -} - -at::Tensor cast_to_fp8_noalloc_ts(const at::Tensor &input, const at::Tensor &scale, - at::Tensor output, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - cast_to_fp8_noalloc(input, scale, output, amax, scale_inv, otype_arg, fp8_tensor, fp8_tensor, - fp8_tensor); - return output; -} - -at::Tensor cast_from_fp8_ts(const at::Tensor &input, const at::Tensor &scale_inv, - int64_t fp8_tensor, int64_t itype, int64_t otype) { - transformer_engine::DType itype_arg = reverse_map_dtype(itype); - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - at::Tensor output = cast_from_fp8(input, scale_inv, itype_arg, otype_arg, fp8_tensor); - return output; -} - -at::Tensor gelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = gelu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor relu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = relu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor reglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = reglu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor geglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = geglu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor swiglu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = swiglu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor qgelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = qgelu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor srelu_ts(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - - at::Tensor s, a, s_inv; - if (scale.numel()) { - s = scale[fp8_tensor]; - } else { - s = scale; - } - - if (amax.numel()) { - a = amax[0][fp8_tensor]; - } else { - a = amax; - } - - if (scale_inv.numel()) { - s_inv = scale_inv[fp8_tensor]; - } else { - s_inv = scale_inv; - } - - at::Tensor output = srelu(input, s, a, s_inv, otype_arg); - return output; -} - -at::Tensor te_gemm_ts(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, - int64_t A_type, int64_t transa, at::Tensor B, at::Tensor B_scale_inverse, - int64_t B_fp8_tensor, int64_t B_type, int64_t transb, at::Tensor D, - at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, at::Tensor bias, - int64_t bias_type, at::Tensor pre_gelu_out, int64_t grad, - at::Tensor workspace, int64_t workspaceSize, int64_t accumulate, - int64_t use_split_accumulator) { - // cast inputs to types accepted by te_gemm - transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); - bool transa_arg = static_cast(transa); - transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); - bool transb_arg = static_cast(transb); - transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); - transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); - bool grad_arg = static_cast(grad); - size_t workspaceSize_arg = static_cast(workspaceSize); - bool accumulate_arg = static_cast(accumulate); - bool use_split_accumulator_arg = static_cast(use_split_accumulator); - - // Set an external SM Margin to all the GEMMs. - // This comes in handy when DP is overlapped with GEMMs - - const int device_id = at::cuda::current_device(); - const int sm_count = transformer_engine::cuda::sm_count(device_id); - int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - - if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; - - if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor]; - - te_gemm(A, A_scale_inverse, A_type_arg, transa_arg, B, B_scale_inverse, B_type_arg, transb_arg, D, - D_scale, D_type_arg, D_amax, bias, bias_type_arg, pre_gelu_out, grad_arg, workspace, - workspaceSize_arg, accumulate_arg, use_split_accumulator_arg, num_math_sms); - return D; -} - -std::vector te_grouped_gemm_ts( - std::vector A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type, - int64_t transa, std::vector B, at::Tensor B_scale_inverse, int64_t B_offset, - int64_t B_type, int64_t transb, std::vector D, int64_t D_offset, at::Tensor D_scale, - int64_t D_type, at::Tensor D_amax, std::vector bias, int64_t bias_type, - std::vector pre_gelu_out, int64_t grad, std::vector workspace, - int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator) { - // cast inputs to types accepted by te_gemm - transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); - bool transa_arg = static_cast(transa); - transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); - bool transb_arg = static_cast(transb); - transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); - transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); - bool grad_arg = static_cast(grad); - size_t workspaceSize_arg = static_cast(workspaceSize); - bool accumulate_arg = static_cast(accumulate); - bool use_split_accumulator_arg = static_cast(use_split_accumulator); - - // Set an external SM Margin to all the GEMMs. - // This comes in handy when DP is overlapped with GEMMs - - const int device_id = at::cuda::current_device(); - const int sm_count = transformer_engine::cuda::sm_count(device_id); - int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - - te_grouped_gemm(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, B_scale_inverse, - B_offset, B_type_arg, transb_arg, D, D_offset, D_scale, D_type_arg, D_amax, bias, - bias_type_arg, pre_gelu_out, grad_arg, workspace, workspaceSize_arg, - accumulate_arg, use_split_accumulator_arg, num_math_sms); - return D; -} - -at::Tensor te_grouped_gemm_single_output_ts( - std::vector A, std::vector A_scale_inverse, int64_t A_offset, - int64_t A_type, int64_t transa, std::vector B, at::Tensor B_scale_inverse, - int64_t B_offset, int64_t B_type, int64_t transb, std::vector m_splits, at::Tensor D, - int64_t D_offset, at::Tensor D_scale, int64_t D_type, at::Tensor D_amax, - std::vector bias, int64_t bias_type, std::vector pre_gelu_out, - int64_t grad, std::vector workspace, int64_t workspaceSize, int64_t accumulate, - int64_t use_split_accumulator) { - // cast inputs to types accepted by te_gemm - transformer_engine::DType A_type_arg = reverse_map_dtype(A_type); - bool transa_arg = static_cast(transa); - transformer_engine::DType B_type_arg = reverse_map_dtype(B_type); - bool transb_arg = static_cast(transb); - transformer_engine::DType D_type_arg = reverse_map_dtype(D_type); - transformer_engine::DType bias_type_arg = reverse_map_dtype(bias_type); - bool grad_arg = static_cast(grad); - size_t workspaceSize_arg = static_cast(workspaceSize); - bool accumulate_arg = static_cast(accumulate); - bool use_split_accumulator_arg = static_cast(use_split_accumulator); - - // Set an external SM Margin to all the GEMMs. - // This comes in handy when DP is overlapped with GEMMs - - const int device_id = at::cuda::current_device(); - const int sm_count = transformer_engine::cuda::sm_count(device_id); - int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); - - te_grouped_gemm_single_output(A, A_scale_inverse, A_offset, A_type_arg, transa_arg, B, - B_scale_inverse, B_offset, B_type_arg, transb_arg, m_splits, D, - D_offset, D_scale, D_type_arg, D_amax, bias, bias_type_arg, - pre_gelu_out, grad_arg, workspace, workspaceSize_arg, - accumulate_arg, use_split_accumulator_arg, num_math_sms); - return D; -} - -at::Tensor layernorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, double eps, at::Tensor scale, - at::Tensor amax, at::Tensor scale_inv, int64_t fp8_tensor, - int64_t otype, const int64_t sm_margin, - const bool zero_centered_gamma) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - float eps_float = static_cast(eps); - - at::Tensor output = layernorm_fwd_fp8_inf(input, weight, bias, eps_float, scale, amax, scale_inv, - otype_arg, sm_margin, zero_centered_gamma, - fp8_tensor, // scale_offset - fp8_tensor, // amax_offset - fp8_tensor); // scale_inv_offset - - return output; -} - -at::Tensor layernorm_fwd_inf_ts(const at::Tensor &input, const at::Tensor &weight, - const at::Tensor &bias, double eps, const int64_t sm_margin, - const bool zero_centered_gamma) { - float eps_float = static_cast(eps); - - at::Tensor output = - layernorm_fwd_inf(input, weight, bias, eps_float, sm_margin, zero_centered_gamma); - - return output; -} - -at::Tensor rmsnorm_fwd_fp8_inf_ts(const at::Tensor &input, const at::Tensor &weight, double eps, - at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, - int64_t fp8_tensor, int64_t otype, const int64_t sm_margin, - const bool zero_centered_gamma) { - transformer_engine::DType otype_arg = reverse_map_dtype(otype); - float eps_float = static_cast(eps); - - at::Tensor output = rmsnorm_fwd_fp8_inf(input, weight, eps_float, scale, amax, scale_inv, - otype_arg, sm_margin, zero_centered_gamma, - fp8_tensor, // scale_offset - fp8_tensor, // amax_offset - fp8_tensor); // scale_inv_offset - - return output; -} - -at::Tensor rmsnorm_fwd_inf_ts(const at::Tensor &input, const at::Tensor &weight, double eps, - const int64_t sm_margin, const bool zero_centered_gamma) { - float eps_float = static_cast(eps); - - at::Tensor output = rmsnorm_fwd_inf(input, weight, eps_float, sm_margin, zero_centered_gamma); - - return output; -} - -TORCH_LIBRARY(tex_ts, m) { - m.def("cast_to_fp8_ts", &cast_to_fp8_ts); - m.def("cast_to_fp8_noalloc_ts", &cast_to_fp8_noalloc_ts); - m.def("cast_from_fp8_ts", &cast_from_fp8_ts); - m.def("gelu_ts", &gelu_ts); - m.def("relu_ts", &relu_ts); - m.def("geglu_ts", &geglu_ts); - m.def("reglu_ts", ®lu_ts); - m.def("swiglu_ts", &swiglu_ts); - m.def("qgelu_ts", &qgelu_ts); - m.def("srelu_ts", &srelu_ts); - m.def("te_gemm_ts", &te_gemm_ts); - m.def("te_grouped_gemm_ts", &te_grouped_gemm_ts); - m.def("te_grouped_gemm_single_output_ts", &te_grouped_gemm_single_output_ts); - m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts); - m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts); - m.def("rmsnorm_fwd_fp8_inf_ts", &rmsnorm_fwd_fp8_inf_ts); - m.def("rmsnorm_fwd_inf_ts", &rmsnorm_fwd_inf_ts); -} diff --git a/transformer_engine/pytorch/csrc/type_shim.h b/transformer_engine/pytorch/csrc/type_shim.h index 5d5a91f9eb..8100f0e4a2 100644 --- a/transformer_engine/pytorch/csrc/type_shim.h +++ b/transformer_engine/pytorch/csrc/type_shim.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h new file mode 100644 index 0000000000..cbdf0833ed --- /dev/null +++ b/transformer_engine/pytorch/csrc/util.h @@ -0,0 +1,12 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ + +bool non_tn_fp8_gemm_supported(); + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 490ac3b160..2a614f67d7 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -7,6 +7,7 @@ from contextlib import contextmanager, AbstractContextManager, ContextDecorator from functools import lru_cache +import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union import warnings @@ -20,7 +21,11 @@ from .utils import safely_set_viewless_tensor_data from .constants import dist_group_type from .fp8 import FP8GlobalStateManager -from .float8_tensor import Float8Tensor +from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer +from .tensor.mxfp8_tensor import MXFP8Quantizer +from .tensor.quantized_tensor import QuantizedTensor, Quantizer +from .tensor._internal.float8_tensor_base import Float8TensorBase +from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase __all__ = ["checkpoint", "CudaRNGStatesTracker"] @@ -814,89 +819,313 @@ def fork(self, name: str = "model-parallel-rng"): def reduce_scatter_along_first_dim( - input_: torch.Tensor, tp_group: dist_group_type, async_op: bool = False -) -> Tuple[torch.Tensor, torch.Tensor]: + inp: torch.Tensor, tp_group: dist_group_type, async_op: bool = False +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """Reduce-scatter the input tensor across model parallel group.""" world_size = get_distributed_world_size(tp_group) # Bypass the function if we are using only 1 GPU. if world_size == 1: - return input_, None + return inp, None - dim_size = list(input_.size()) + dim_size = list(inp.size()) assert ( dim_size[0] % world_size == 0 ), "First dimension of the tensor should be divisible by tensor parallel size" dim_size[0] = dim_size[0] // world_size - output = torch.empty(dim_size, dtype=input_.dtype, device=torch.cuda.current_device()) + output = torch.empty(dim_size, dtype=inp.dtype, device=torch.cuda.current_device()) handle = torch.distributed.reduce_scatter_tensor( - output, input_.contiguous(), group=tp_group, async_op=async_op + output, inp.contiguous(), group=tp_group, async_op=async_op ) return output, handle +def _all_gather_fp8( + inp: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, + quantizer: Optional[Float8Quantizer] = None, + out_shape: Optional[list[int]] = None, +) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]: + """All-gather FP8 tensor along first dimension.""" + world_size = get_distributed_world_size(process_group) + + # Output tensor dims + if out_shape is None: + out_shape = list(inp.size()) + out_shape[0] *= world_size + + # Quantize input tensor if needed + if not isinstance(inp, Float8TensorBase): + assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) + # we cannot directly gather the transposed fp8 tensor + # so we need to disable columnwise usage for the quantizer + # and then set it back to the original value after quantizing + init_columnwise_usage = quantizer.columnwise_usage + quantizer.set_usage(columnwise=False) + inp = quantizer(inp) + quantizer.set_usage(columnwise=init_columnwise_usage) + + # Construct output tensor + out: Float8TensorBase + if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + dtype = torch.float32 + device = "cuda" + if isinstance(inp, Float8Tensor): + dtype = inp.dtype + device = inp.device + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + elif isinstance(inp, Float8Tensor): + out = inp.make_like(inp, shape=out_shape) + out._data = torch.empty_like( + out_shape, + dtype=torch.uint8, + device=inp.device, + ) + out._transpose = None + out._transpose_invalid = True + else: + raise RuntimeError("FP8TensorBase is not supported yet without Quantizer") + # For delayed scaling, scale_inv is from history, so we can pass it from inp to out + # For current scaling, scale_inv is from doing amax reduction in C++ code, so each rank should have same scale_inv, + # so we can just pass it from inp to out + out._scale_inv = inp._scale_inv + + # Perform communication + handle = torch.distributed.all_gather_into_tensor( + out._data, + inp._data.contiguous(), + group=process_group, + async_op=async_op, + ) + + # Make sure FP8 transpose is populated if needed + if out._transpose is not None: + if handle is not None: + handle.wait() + handle = None + if not isinstance(out, Float8Tensor): + raise RuntimeError("FP8TensorBase does not support FP8 transpose yet") + out._create_transpose() + + return out, handle + + +def _all_gather_mxfp8( + inp: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, + quantizer: MXFP8Quantizer, + out_shape: Optional[list[int]] = None, +) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]: + """All-gather MXFP8 tensor along first dimension.""" + + # Tensor dims + world_size = get_distributed_world_size(process_group) + in_shape = list(inp.size()) + if out_shape is None: + out_shape = [in_shape[0] * world_size] + in_shape[1:] + + # For cases where inp has dimensions that cannot be quantized, + # we gather in high precision followed by a cast to FP8. + if ( + not isinstance(inp, MXFP8TensorBase) + and quantizer is not None + and not quantizer.is_quantizable(inp) + ): + out = torch.empty( + out_shape, + dtype=inp.dtype, + device=inp.device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group) + out = quantizer(out) + return out, None + + inp_dtype = inp.dtype + inp_device = inp.device + + # Cast input tensor to MXFP8 with required data + if not isinstance(inp, MXFP8TensorBase): + inp = quantizer(inp) + elif ( + inp.rowwise_data is None + and quantizer.rowwise_usage + or inp.columnwise_data is None + and quantizer.columnwise_usage + ): + warnings.warn( + "Input and quantizer do not have matching usages. " + "Dequantizing and requantizing to MXFP8." + ) + inp = quantizer(inp.dequantize()) + + # Construct MXFP8 output tensor + out = quantizer.make_empty(out_shape, dtype=inp_dtype, device=inp_device) + + # Async op handle + handle = None + + # Gather MXFP8 data for row-wise usage + if quantizer.rowwise_usage: + + # Remove padding from MXFP8 scale-inverses + in_scale_inv = inp._rowwise_scale_inv + out_scale_inv = out._rowwise_scale_inv + flattened_in_shape0 = math.prod(in_shape[:-1]) + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + out_scale_inv[flattened_in_shape0 * world_size :].zero_() + out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] + + # Launch all-gathers + if handle is not None: + handle.wait() + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + handle = torch.distributed.all_gather_into_tensor( + out._rowwise_data, + inp._rowwise_data, + group=process_group, + async_op=async_op, + ) + + # Gather MXFP8 data for column-wise usage + if quantizer.columnwise_usage: + + # Remove padding from MXFP8 scale-inverses + in_scale_inv = inp._columnwise_scale_inv + out_scale_inv = out._columnwise_scale_inv + flattened_in_shape0 = math.prod(in_shape[:-1]) // 32 + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + out_scale_inv[flattened_in_shape0 * world_size :].zero_() + out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] + + # Launch all-gathers + if handle is not None: + handle.wait() + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + handle = torch.distributed.all_gather_into_tensor( + out._columnwise_data, + inp._columnwise_data, + group=process_group, + async_op=async_op, + ) + + return out, handle + + def gather_along_first_dim( - input_: torch.Tensor, + inp: torch.Tensor, process_group: dist_group_type, async_op: bool = False, -) -> tuple[torch.Tensor, Any]: + quantizer: Optional[Quantizer] = None, +) -> tuple[torch.Tensor, Optional[torch.distributed.Work]]: """All-gather tensors and concatenate along first dimension.""" # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) if world_size == 1: - return input_, None - - # Allocate output tensor - output_shape = list(input_.size()) - output_shape[0] *= world_size - if isinstance(input_, Float8Tensor): - output = Float8Tensor.make_like( - input_, - data=torch.empty( - output_shape, - dtype=torch.uint8, - device=input_.device, - ), + if quantizer is not None and not isinstance(inp, QuantizedTensor): + inp = quantizer(inp) + return inp, None + + # Output tensor dims + out_shape = list(inp.size()) + out_shape[0] *= world_size + + # FP8 case: delayed scaling or current scaling + if isinstance(inp, Float8TensorBase) or isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): + return _all_gather_fp8( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, ) - src = input_._data.contiguous() - dst = output._data - else: - output = torch.empty( - output_shape, - dtype=input_.dtype, - device=input_.device, + + # MXFP8 case + if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): + assert isinstance(quantizer, MXFP8Quantizer) + return _all_gather_mxfp8( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + + # High-precision communication for quantized tensors + if quantizer is not None: + warnings.warn( + "Attempting to all-gather an unsupported quantized tensor. " + "Falling back to high-precision all-gather." + ) + if isinstance(inp, QuantizedTensor): + inp = inp.dequantize() + out = torch.empty( + out_shape, + dtype=inp.dtype, + device=inp.device, memory_format=torch.contiguous_format, ) - src = input_.contiguous() - dst = output + torch.distributed.all_gather_into_tensor(out, inp, group=process_group) + out = quantizer(out) + return out, None - # Launch all-gather + # Dequantize quantized tensor if not supported + if isinstance(inp, QuantizedTensor): + warnings.warn( + "Attempting to all-gather an unsupported quantized tensor. " + "Falling back to high-precision all-gather." + ) + inp = inp.dequantize() + + # Communication for plain PyTorch tensors + out = torch.empty( + out_shape, + dtype=inp.dtype, + device=inp.device, + memory_format=torch.contiguous_format, + ) handle = torch.distributed.all_gather_into_tensor( - dst, - src, + out, + inp.contiguous(), group=process_group, async_op=async_op, ) - return output, handle + return out, handle def allreduce( - input_: torch.Tensor, + inp: torch.Tensor, tp_group: Optional[dist_group_type] = None, async_op: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, Optional[torch.distributed.Work]]: """All-reduce the input tensor across model parallel group.""" # Bypass the function if we are using only 1 GPU. if get_distributed_world_size(tp_group) == 1: - return input_, None + return inp, None # All-reduce. - handle = torch.distributed.all_reduce(input_, group=tp_group, async_op=async_op) + handle = torch.distributed.all_reduce(inp, group=tp_group, async_op=async_op) - return input_, handle + return inp, handle def _fsdp_scatter_tensors( @@ -907,12 +1136,13 @@ def _fsdp_scatter_tensors( if fsdp_group is not None: for t in tensors: if isinstance(t, torch.Tensor): - target = t._data if isinstance(t, Float8Tensor) else t - shapes.append(target.data.shape) - safely_set_viewless_tensor_data( - target, - split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True), - ) + targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t] + for target in targets: + shapes.append(target.data.shape) + safely_set_viewless_tensor_data( + target, + split_tensor_into_1d_equal_chunks(target.data, fsdp_group, new_buffer=True), + ) else: shapes.append(None) return shapes @@ -928,10 +1158,11 @@ def _fsdp_gather_tensors( for s, t in zip(shapes, tensors): if isinstance(t, torch.Tensor): assert s is not None, "Internal TE error." - target = t._data if isinstance(t, Float8Tensor) else t - safely_set_viewless_tensor_data( - target, gather_split_1d_tensor(target.data, fsdp_group).view(s) - ) + targets = t.get_data_tensors() if isinstance(t, QuantizedTensor) else [t] + for target in targets: + safely_set_viewless_tensor_data( + target, gather_split_1d_tensor(target.data, fsdp_group).view(s) + ) def _is_te_module(module): diff --git a/transformer_engine/pytorch/dot_product_attention/__init__.py b/transformer_engine/pytorch/dot_product_attention/__init__.py new file mode 100644 index 0000000000..6a4c84f47d --- /dev/null +++ b/transformer_engine/pytorch/dot_product_attention/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Python interface for dot product attention""" diff --git a/transformer_engine/pytorch/dot_product_attention/inference.py b/transformer_engine/pytorch/dot_product_attention/inference.py new file mode 100644 index 0000000000..ae220225e8 --- /dev/null +++ b/transformer_engine/pytorch/dot_product_attention/inference.py @@ -0,0 +1,798 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Inference""" +import logging +from collections import OrderedDict, defaultdict +from typing import Optional, List +from einops import rearrange + +import torch + +import transformer_engine_torch as tex +from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat + +__all__ = ["InferenceParams", "KVCacheManager", "NonPagedKVCacheManager", "PagedKVCacheManager"] + + +class KVCacheManager: + """Base KV cache manager""" + + def __init__(self): + """Initialize cache manager""" + self.cache = {} + self.sequences = OrderedDict() + + def reset(self): + """Reset cache manager state""" + self.sequences = OrderedDict() + + def allocate_memory(self, layer_number: int): + """Allocate memory for the cache""" + self.cache[layer_number] = (None, None) + + def pre_step( + self, + step_dict: OrderedDict, # pylint: disable=unused-argument + ): + """Update tracked sequences and prepare for step()""" + return self.sequences + + def step( + self, + layer_number: int, + new_k: torch.Tensor, # pylint: disable=unused-argument + new_v: torch.Tensor, # pylint: disable=unused-argument + cu_new_seqlens: torch.Tensor, # pylint: disable=unused-argument + cu_cached_seqlens: torch.Tensor, # pylint: disable=unused-argument + qkv_format: str, # pylint: disable=unused-argument + ): + """Copy the new tokens to KV cache""" + return self.cache[layer_number] + + +class InferenceParams: + """ + KV caching for inference. The memory allocation of the caches and the copying of new tokens + to the cache take place at the following locations.:: + + class TransformerLayer: + class MultiHeadAttention: + if self.layer_number not in inference_params.cache_manager.cache: + inference_params.allocate_memory(self.layer_number) + class DotProductAttention: + if inference_params is not None: + k_cache, v_cache, new_qkv_format = inference_params.step( + new_k, new_v, qkv_format) + output = attention(new_q, k_cache, v_cache, new_qkv_format) + + allocate_memory() can be called outside the model, independently. step() can take three formats, + qkv_format = {'bshd', 'sbhd', 'thd'}. It converts new_k and new_v to 'bshd' in both + NonPagedKVCacheManager and PagedKVCacheManager. The format of new_q may change depending on the + backend. If it is unchanged, we would have new_qkv_format = {'bshd', 'sbhd_2bshd', 'thd_2bshd'}. + A standard KV caching workflow for inference is as follows.:: + + model = [TransformerLayer() for _ in range(num_layers)] + # initialize InferenceParams, e.g. with PagedKVCacheManager + inference_params = InferenceParams(..., is_paged=True) + # inference loop + for i in range(num_iters): + # get info for iteration i, e.g. seq_ids = [0, 2, 3], step_lens = [10, 1, 1] + step_dict = OrderedDict(zip(seq_ids, step_lens)) + # update inference_params' state + inference_params.pre_step(step_dict) + # run iteration + output = model( + ..., + attn_mask_type="padding_causal", + cu_seqlens_q=cu_seqlens_new_q, + cu_seqlens_kv=cu_seqlens_new_kv, + inference_params=inference_params, + ) + # get output tokens based on qkv_format + # 'bshd': output = output[:,step_dict.values()-1] + # 'sbhd': output = output[step_dict.values()-1,:] + # 'thd' : output = output[cu_seqlens_new_q[j+1]-1], j=0,...b-1 + + + Parameters + ---------- + max_batch_size: int + Maximum batch size in inference + max_seqlen_kv: int + Maximum sequence length in inference + num_heads_kv: int + Number of attention heads in keys and values + head_dim_k: int + Head size for keys + dtype: torch.dtype + Data type of the KV cache + head_dim_v: int, default = None + Head size for values. If None, initialized as head_dim_k. + is_paged: bool, default = False + Whether the KV cache is paged (True) or non-paged (False) + total_num_pages: int, default = None + Total number of pages in the KV cache. Required for is_paged = True. + page_size: int, default = None + Page size of the KV cache. Required for is_paged = True. + max_ctx_len: int, default = None + Maximum context length in inference. 1 <= max_ctx_len <= max_seqlen_kv. + qkv_format: str, default = "bshd" + Format of the incoming query/key/value tensors in current iteration + custom_cache_manager: KVCacheManager, default = None + Custom cache manager, with KVCacheManager as the base class. + """ + + def __init__( + self, + max_batch_size: int, + max_seqlen_kv: int, + num_heads_kv: int = 16, + head_dim_k: int = 64, + dtype: torch.dtype = torch.bfloat16, + head_dim_v: int = None, + is_paged: bool = False, + total_num_pages: int = None, + page_size: int = None, + max_ctx_len: int = None, + qkv_format: str = "bshd", + custom_cache_manager: KVCacheManager = None, + ): + self.max_batch_size = max_batch_size + self.max_seqlen_kv = max_seqlen_kv + self.num_heads_kv = num_heads_kv + self.head_dim_k = head_dim_k + self.dtype = dtype + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + self.is_paged = is_paged + + if not self.is_paged: + cache_manager = ( + custom_cache_manager if custom_cache_manager is not None else NonPagedKVCacheManager + ) + self.cache_manager = cache_manager( + max_batch_size=self.max_batch_size, + max_seqlen=self.max_seqlen_kv, + num_heads=self.num_heads_kv, + head_dim_k=self.head_dim_k, + dtype=self.dtype, + head_dim_v=self.head_dim_v, + ) + else: + assert page_size is not None, "Paged KV cache requires page_size is not None." + self.page_size = page_size + assert ( + max_seqlen_kv % page_size == 0 + ), "Paged KV cache requires max_seqlen_kv % page_size = 0." + max_pages_per_seq = max_seqlen_kv // page_size + assert ( + total_num_pages == self.max_batch_size * max_pages_per_seq + ), "Paged KV cache requires total_num_pages = max_batch_size * max_pages_per_seq." + self.total_num_pages = total_num_pages + + cache_manager = ( + custom_cache_manager if custom_cache_manager is not None else PagedKVCacheManager + ) + self.cache_manager = cache_manager( + total_num_pages=self.total_num_pages, + page_size=self.page_size, + num_heads=self.num_heads_kv, + head_dim_k=self.head_dim_k, + dtype=self.dtype, + max_batch_size=self.max_batch_size, + max_seqlen=self.max_seqlen_kv, + head_dim_v=self.head_dim_v, + ) + + if qkv_format == "thd": + assert max_ctx_len is not None, "max_ctx_len is required when qkv_format=thd!" + self.max_ctx_len = max_ctx_len + + self.cache_qkv_format = "bshd" + self.input_qkv_format = qkv_format + if self.input_qkv_format == self.cache_qkv_format: + self.output_qkv_format = self.cache_qkv_format + else: + self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format + + self.sequences_pre_step = OrderedDict() + self.sequences = OrderedDict() + self.batch_size = 0 + + self.cu_seqlens_q = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + self.cu_seqlens_kv = torch.zeros( + self.max_batch_size + 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + + def reset(self): + """Reset InferenceParams state""" + self.sequences = OrderedDict() + self.cache_manager.reset() + + def __repr__(self) -> str: + if self.is_paged: + return ( + f"dtype={self.dtype}, " + f"is_paged={self.is_paged}, " + f"total_pages={self.total_num_pages}, " + f"page_size={self.page_size}, " + f"num_heads={self.num_heads_kv}, " + f"head_dim_k={self.head_dim_k}, " + f"head_dim_v={self.head_dim_v}" + ) + return ( + f"dtype={self.dtype}, " + f"is_paged={self.is_paged}, " + f"max_batch_size={self.max_batch_size}, " + f"max_seqlen={self.max_seqlen_kv}, " + f"num_heads={self.num_heads_kv}, " + f"head_dim_k={self.head_dim_k}, " + f"head_dim_v={self.head_dim_v}" + ) + + def allocate_memory(self, layer_number: int): + """ + Allocate memory for the cache. For layer layer_number, + - NonPagedKVCacheManager: + - K cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_k] + - V cache: [max_batch_size, max_seqlen_kv, num_heads_kv, head_dim_v] + - PagedKVCacheManager: + - K cache: [total_num_pages, page_size, num_heads_kv, head_dim_k] + - V cache: [total_num_pages, page_size, num_heads_kv, head_dim_v] + """ + self.cache_manager.allocate_memory(layer_number) + + def pre_step( + self, + step_dict: OrderedDict, + ): + """Update tracked sequences and prepare for step()""" + self.batch_size = len(step_dict) + + self.sequences = self.cache_manager.pre_step(step_dict) + # track the pre-step seqlens for the next layer in the model + self.sequences_pre_step = OrderedDict() + for k, v in self.sequences.items(): + self.sequences_pre_step[k] = v - step_dict[k] + + seqlens_q = list(step_dict.values()) + cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)] + cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size) + self.cu_seqlens_q.copy_(torch.Tensor(cu_seqlens_q).to(dtype=torch.int32, device="cpu")) + + seqlens_kv = list(self.sequences.values()) + cu_seqlens_kv = [0] + [sum(seqlens_kv[:i]) for i in range(1, self.batch_size + 1)] + cu_seqlens_kv = cu_seqlens_kv + [cu_seqlens_kv[-1]] * ( + self.max_batch_size - self.batch_size + ) + self.cu_seqlens_kv.copy_(torch.Tensor(cu_seqlens_kv).to(dtype=torch.int32, device="cpu")) + + def get_seqlens_pre_step(self): + """Get cached sequence lengths before the stepping""" + return torch.Tensor(list(self.sequences_pre_step.values())).to( + dtype=torch.int32, device="cpu" + ) + + def convert_paged_to_nonpaged(self, layer_number: int): + """ + Convert k_cache and v_cache from paged to non-paged format. + + Parameters + ---------- + layer_number: int + Layer number of attention in the model + + Returns + ------- + k_cache: torch.Tensor + Non-paged key cache tensor + v_cache: torch.Tensor + Non-paged value cache tensor + """ + k_cache, v_cache = self.cache_manager.cache[layer_number] + page_table = self.cache_manager.page_table + batch_size = page_table.shape[0] + new_k_cache = rearrange( + k_cache[page_table.flatten()], + "(b npages) page_size ... -> b (npages page_size) ...", + b=batch_size, + ) + new_v_cache = rearrange( + v_cache[page_table.flatten()], + "(b npages) page_size ... -> b (npages page_size) ...", + b=batch_size, + ) + + new_k_cache = new_k_cache[: self.batch_size].contiguous() + new_v_cache = new_v_cache[: self.batch_size].contiguous() + + return new_k_cache, new_v_cache + + def step( + self, + layer_number: int, + new_k: torch.Tensor, + new_v: torch.Tensor, + qkv_format: str, + ): + """ + Copy new KV tokens to the cache. + + Parameters + ---------- + layer_number: int + Layer number of attention in the model + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration + qkv_format: str + Format of new_q, new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + Full key tensor containing both previous and current key tokens + v_cache: torch.Tensor + Full value tensor containing both previous and current value tokens + cu_seqlens_q: torch.Tensor + Updated cumulative sequence lengths for query, [batch_size + 1] + cu_seqlens_kv: torch.Tensor + Updated cumulative sequence lengths for key and value, [batch_size + 1] + max_seqlen_q: int + Update maximum sequence length for query + max_seqlen_kv: int + Update maximum sequence length for key and value + qkv_format: str + Updated qkv_format, e.g. 'thd' format becomes 'thd_2bshd' after step() + """ + self.input_qkv_format = qkv_format + if self.input_qkv_format == self.cache_qkv_format: + self.output_qkv_format = self.cache_qkv_format + else: + self.output_qkv_format = self.input_qkv_format + "_2" + self.cache_qkv_format + + k_cache, v_cache = self.cache_manager.step( + layer_number, + new_k, + new_v, + self.cu_seqlens_q, + self.cu_seqlens_kv, + qkv_format, + ) + + return ( + k_cache, + v_cache, + self.cu_seqlens_q, + self.cu_seqlens_kv, + self.max_seqlen_kv, + self.output_qkv_format, + ) + + +class NonPagedKVCacheManager(KVCacheManager): + """Non-paged KV cache manager""" + + def __init__( + self, + max_batch_size: int, + max_seqlen: int, + num_heads: int, + head_dim_k: int, + dtype: torch.dtype, + head_dim_v: Optional[int] = None, + ): + super().__init__() + """Initialize cache manager""" + self.max_batch_size = max_batch_size + self.max_seqlen = max_seqlen + self.num_heads = num_heads + self.head_dim_k = head_dim_k + self.dtype = dtype + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + + # track sequences in the cache, {seq_id: seq_len} + self.sequences = OrderedDict() + # cache tensors, cache[layer_number] = (k_cache, v_cache) + self.cache = {} + # track sequence indices in the batch in order to re-index k_cache and v_cache + self.batch_indices = torch.zeros( + self.max_batch_size, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + # after re-indexing, batch indices are always [0, ..., b-1] + self.batch_indices_post_step = torch.range( + 0, + self.max_batch_size - 1, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + + def allocate_memory(self, layer_number): + """Allocate memory for the cache""" + k_cache = torch.zeros( + self.max_batch_size, + self.max_seqlen, + self.num_heads, + self.head_dim_k, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + v_cache = torch.zeros( + self.max_batch_size, + self.max_seqlen, + self.num_heads, + self.head_dim_v, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + self.cache[layer_number] = (k_cache, v_cache) + + def pre_step( + self, + step_dict: OrderedDict, + ): + """Update tracked sequences and prepare for step()""" + # Track unfinished sequences' indices in the batch, e.g. + # at t-1, seq_ids = [0, 1, 2, 3]; at t, seq_ids = [0, 2, 3] since seq_id 1 is finished + # step() re-indexes k_cache and v_cache using batch_indices = [0, 2, 3, 1] so that + # they are contiguous and match the indexing in q + prev_batch_size = len(self.sequences) + unfinished_seqs = self.sequences.keys() & step_dict.keys() + finished_seqs = self.sequences.keys() - unfinished_seqs + unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] + finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] + self.batch_indices.copy_( + torch.Tensor( + ( + unfinished_indices + + finished_indices + + list(range(prev_batch_size, self.max_batch_size)) + ) + ).to(dtype=torch.int32, device="cpu") + ) + + # Advance unfinished sequences + for i in unfinished_seqs: + self.sequences[i] += 1 + + # Remove finished sequences + for i in finished_seqs: + self.sequences.pop(i) + + # Add new sequences + new_seqs = step_dict.keys() - self.sequences.keys() + for i in new_seqs: + self.sequences[i] = step_dict[i] + + return self.sequences + + def step( + self, + layer_number, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_new_seqlens, + cu_cached_seqlens, + qkv_format: str, + ): + """ + Copy the new tokens to the non-paged KV cache. + + Parameters + ---------- + layer_number: int + Layer number of attention in the model + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration + cu_new_seqlens: torch.Tensor + Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] + cu_cached_seqlens: torch.Tensor + Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] + qkv_format: str + Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + Full key tensor containing both previous and current key tokens + v_cache: torch.Tensor + Full value tensor containing both previous and current value tokens + """ + k_cache, v_cache = self.cache[layer_number] + + batch_size = self.max_batch_size + ctx_len = 1 + if qkv_format == "bshd": + batch_size = new_k.shape[0] + ctx_len = new_k.shape[1] + if qkv_format == "sbhd": + batch_size = new_k.shape[1] + ctx_len = new_k.shape[0] + + tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + self.batch_indices, + cu_new_seqlens, + cu_cached_seqlens, + QKVFormat[qkv_format], + batch_size, + ctx_len, + self.max_seqlen, + 1, + True, + ) + + k_cache = k_cache[:batch_size] + v_cache = v_cache[:batch_size] + + return k_cache, v_cache + + +class Page: + """A single page""" + + def __init__(self, page_id: int): + """Initialize a page""" + self.page_id = page_id + self.allocated = 0 + + def allocate_page(self): + """Allocate a page""" + self.allocated = True + + def deallocate_page(self): + """Deallocate a page""" + self.allocated = False + + +class PagedKVCacheManager(KVCacheManager): + """Paged KV cache manager""" + + def __init__( + self, + total_num_pages: int, + page_size: int, + num_heads: int, + head_dim_k: int, + dtype: torch.dtype, + max_batch_size: int, + max_seqlen: int, + head_dim_v: Optional[int] = None, + ): + super().__init__() + """Initialize cache manager""" + self.total_num_pages = total_num_pages + self.page_size = page_size + self.num_heads = num_heads + self.head_dim_k = head_dim_k + self.dtype = dtype + self.max_batch_size = max_batch_size + self.max_seqlen = max_seqlen + self.max_pages_per_seq = max_seqlen // self.page_size + self.head_dim_v = head_dim_v if head_dim_v is not None else head_dim_k + + # track sequences in the cache, {seq_id: seq_len} + self.sequences = OrderedDict() + # cache tensors, cache[layer_number] = (k_cache, v_cache) + self.cache = {} + # available pages, [Page(),...] + self.free_pages = [] + for i in range(self.total_num_pages): + self.free_pages.append(Page(i)) + # allocated pages, {seq_id: [page_id,...]} + self.allocated_pages = defaultdict(list) + # page table, [batch_size, max_pages_per_seq] + self.page_table = torch.zeros( + self.max_batch_size, self.max_pages_per_seq, dtype=torch.int32, device="cuda" + ) + + def reset(self): + """Reset cache manager state""" + self.sequences = OrderedDict() + self.free_pages = [] + for i in range(self.total_num_pages): + self.free_pages.append(Page(i)) + self.allocated_pages = defaultdict(list) + self.page_table.fill_(0) + + def allocate_memory(self, layer_number): + """Allocate memory for the cache""" + k_cache = torch.zeros( + self.total_num_pages, + self.page_size, + self.num_heads, + self.head_dim_k, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + v_cache = torch.zeros( + self.total_num_pages, + self.page_size, + self.num_heads, + self.head_dim_v, + dtype=self.dtype, + device=torch.cuda.current_device(), + ) + self.cache[layer_number] = (k_cache, v_cache) + + def print_cache(self): + """Print KV cache status""" + used_pages = [self.get_page_count(seq) for seq in self.sequences] + logger = logging.getLogger("PagedKVCacheManager") + logger.debug("Cache status:") + logger.debug( + " total pages: %s (used %s, free %s)", + self.total_num_pages, + sum(used_pages), + len(self.free_pages), + ) + logger.debug(" total sequences: %s", self.get_sequence_count()) + for i, seq in enumerate(self.sequences): + logger.debug( + " >> batch index %s: seq_id %s, num_tokens %s, num_pages %s, page_list %s", + i, + seq, + self.get_sequence_lengths()[i], + self.get_page_count(seq), + self.get_page_list(seq), + ) + + def get_sequence_count(self): + """Get the total number of sequences in the KV cache""" + return len(self.sequences) + + def get_sequence_lengths(self): + """Get the list of sequence lengths in the KV cache""" + return list(self.sequences.values()) + + def has_free_page(self) -> bool: + """Whether the page pool has any free pages left""" + return len(self.free_pages) > 0 + + def get_page_count(self, seq: int): + """Get the number of pages allocated to a sequence""" + return len(self.allocated_pages[seq]) + + def get_page_list(self, seq: int): + """Get the list of pages allocated to a sequence""" + return [x.page_id for x in self.allocated_pages[seq]] + + def get_page_table(self, sequences: List[int]): + """Get the page table, in shape [batch_size, max_pages_per_seq]""" + page_table = torch.Tensor( + [ + self.get_page_list(seq) + [0] * (self.max_pages_per_seq - self.get_page_count(seq)) + for seq in sequences + ] + ).to(dtype=torch.int32, device="cpu") + self.page_table[: self.get_sequence_count()].copy_(page_table) + return self.page_table + + def allocate_page(self, seq: int): + """Allocate a new page to a sequence""" + if not self.has_free_page(): + raise RuntimeError("KV cache is full!") + page = self.free_pages.pop(0) + page.allocate_page() + self.allocated_pages[seq].append(page) + + def allocate_sequence(self, seq: int, context_len: int): + """Add a new sequence to the cache""" + num_pages = context_len // self.page_size + if context_len % self.page_size > 0: + num_pages = num_pages + 1 + for _ in range(num_pages): + self.allocate_page(seq) + + def deallocate_sequence(self, seq: int): + """Deallocate all the pages for a sequence""" + for page in self.allocated_pages[seq]: + page.deallocate_page() + if not page.allocated: + self.free_pages.append(page) + self.allocated_pages.pop(seq) + + def pre_step( + self, + step_dict: OrderedDict, + ): + """Update tracked sequences and prepare for step()""" + # Remove finished sequences and advance unfinished sequences + unfinished_seqs = self.sequences.keys() & step_dict.keys() + finished_seqs = self.sequences.keys() - unfinished_seqs + for seq in finished_seqs: + self.sequences.pop(seq) + self.deallocate_sequence(seq) + for seq in unfinished_seqs: + if self.sequences[seq] % self.page_size == 0 and self.sequences[seq] < self.max_seqlen: + self.allocate_page(seq) + self.sequences[seq] += 1 + + # Add new sequences + new_seqs = step_dict.keys() - self.sequences.keys() + for seq in new_seqs: + self.sequences[seq] = step_dict[seq] + self.allocate_sequence(seq, step_dict[seq]) + + # Get page table + self.page_table = self.get_page_table(list(self.sequences.keys())) + + return self.sequences + + def step( + self, + layer_number: int, + new_k: torch.Tensor, + new_v: torch.Tensor, + cu_new_seqlens, + cu_cached_seqlens, + qkv_format: str, + ): + """ + Copy the new tokens to the paged KV cache. + + Parameters + ---------- + layer_number: int + Layer number of attention in the model + new_k: torch.Tensor + New key tokens for layer_number in current inference iteration + new_v: torch.Tensor + New value tokens for layer_number in current inference iteration + cu_new_seqlens: torch.Tensor + Cumulative sequence lengths for new_k and new_v, in shape [batch_size + 1] + cu_cached_seqlens: torch.Tensor + Cumulative sequence lengths for k_cache and v_cache (after new tokens are copied in), in shape [batch_size + 1] + qkv_format: str + Format of new_k and new_v tensors, {'bshd', 'sbhd', 'thd'} + + Returns + ------- + k_cache: torch.Tensor + Full key tensor containing both previous and current key tokens + v_cache: torch.Tensor + Full value tensor containing both previous and current value tokens + """ + k_cache, v_cache = self.cache[layer_number] + + batch_size = self.max_batch_size + ctx_len = 1 + if qkv_format == "bshd": + batch_size = new_k.shape[0] + ctx_len = new_k.shape[1] + if qkv_format == "sbhd": + batch_size = new_k.shape[1] + ctx_len = new_k.shape[0] + + tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + self.page_table, + cu_new_seqlens, + cu_cached_seqlens, + QKVFormat[qkv_format], + batch_size, + ctx_len, + self.max_seqlen, + self.max_pages_per_seq, + False, + ) + + return k_cache, v_cache diff --git a/transformer_engine/pytorch/dot_product_attention/rope.py b/transformer_engine/pytorch/dot_product_attention/rope.py new file mode 100644 index 0000000000..83698c7bc6 --- /dev/null +++ b/transformer_engine/pytorch/dot_product_attention/rope.py @@ -0,0 +1,225 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Rotary Position Embedding implementation of different types along with helper functions +""" +from typing import Optional, Tuple, Union +import torch +import transformer_engine_torch as tex + + +class RotaryPositionEmbedding(torch.nn.Module): + """ + Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864. + """ + + def __init__( + self, + dim: int, + rotary_percent: float = 1.0, + seq_len_interpolation_factor: Optional[int] = None, + pretrained_max_position_embeddings: Optional[int] = None, + rotary_base: float = 10000.0, + ): + """ + Parameters + ---------- + dim: int + rotary embedding dimension + rotary_percent: float + Percent of rotary dimension to use for rotary position embeddings. + seq_len_interpolation_factor: int + if not None, discrete positions will be interpolated by this factor via the trick in + https://arxiv.org/abs/2306.15595 + pretrained_max_position_embeddings: int + pre-trained max_position_embeddings before position interpolation + """ + super().__init__() + if rotary_percent < 1.0: + dim = int(dim * rotary_percent) + self.seq_len_interpolation_factor = seq_len_interpolation_factor + self.rotary_base = rotary_base + inv_freq = 1.0 / ( + self.rotary_base + ** ( + torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device()) + / dim + ) + ) + self.register_buffer("inv_freq", inv_freq) + self.pretrained_max_position_embeddings = pretrained_max_position_embeddings + + def forward(self, max_seq_len: int, offset: int = 0): + """ + Create rotary position embedding frequencies + + Parameters + ---------- + max_seq_len: int + sequence length of a sample + offset: int, default = 0 + fixed offset for freqencies + """ + seq = ( + torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + + offset + ) + + if ( + self.pretrained_max_position_embeddings is not None + and self.seq_len_interpolation_factor is not None + ): + if ( + max_seq_len + > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor + ): + # dynamic linear scaling (length > position we have learned) + seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings) + else: + # fixed linear scaling + seq *= 1 / self.seq_len_interpolation_factor + + freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) + # first part even vector components, second part odd vector components, + # 2 * dim in dimension size + emb = torch.cat((freqs, freqs), dim=-1) + # emb [seq_length, .., dim] + return emb.reshape(emb.size(0), 1, 1, emb.size(1)) + + +class FusedRoPEFunc(torch.autograd.Function): + """ + Function for FusedRoPE + + This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and + the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid + the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern. + """ + + @staticmethod + def forward( + ctx, + t: torch.Tensor, + freqs: torch.Tensor, + tensor_format: str = "sbhd", + cu_seqlens: Union[torch.Tensor, None] = None, + cp_size: int = 1, + cp_rank: int = 0, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + if freqs.dtype != torch.float32: + freqs = freqs.float() + if tensor_format == "sbhd": + output = tex.fused_rope_forward(t, freqs, False) + elif tensor_format == "bshd": + output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1) + elif tensor_format == "thd": + output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank) + else: + raise ValueError(f"Unsupported tensor_format: {tensor_format}.") + ctx.save_for_backward(freqs, cu_seqlens) + ctx.tensor_format = tensor_format + ctx.cp_size = cp_size + ctx.cp_rank = cp_rank + + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: + # pylint: disable=missing-function-docstring + freqs, cu_seqlens = ctx.saved_tensors + if ctx.tensor_format == "sbhd": + grad_input = tex.fused_rope_backward(grad_output, freqs, False) + elif ctx.tensor_format == "bshd": + grad_input = tex.fused_rope_backward( + grad_output.transpose(0, 1), freqs, True + ).transpose(0, 1) + elif ctx.tensor_format == "thd": + grad_input = tex.fused_rope_thd_backward( + grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank + ) + else: + raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.") + + return grad_input, None, None, None, None, None + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """ + change sign so the last dimension becomes [-odd, +even] + """ + x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + t: torch.Tensor, + freqs: torch.Tensor, + tensor_format: str = "sbhd", + fused: bool = False, + cu_seqlens: Union[torch.Tensor, None] = None, + cp_size: int = 1, + cp_rank: int = 0, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input tensor. + + Parameters + ---------- + t: torch.Tensor + Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which + rotary positional embedding will be applied. + freqs: torch.Tensor + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + fused: bool, default = False + Whether to use a fused applying RoPE implementation. + tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd' + is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is + of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True. + cu_seqlens: torch.Tensor, default = None. + Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and + dtype torch.int32. Only valid when `tensor_format` is 'thd'. + Should be `cu_seqlens_padded` when cp_size > 1. + cp_size: int, default = 1. + Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True. + cp_rank: int, default = 0. + Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. + """ + if fused: + assert ( + tensor_format != "thd" or cu_seqlens is not None + ), "cu_seqlens must not be None when tensor_format is 'thd'." + return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank) + + assert tensor_format in ("sbhd", "bshd"), ( + "Only formats `sbhd` or `bshd` are supported for input tensor `t` " + f"when fused is False, got {tensor_format}." + ) + + max_seq_len = freqs.shape[0] + cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] + + # Only apply the rotary embeddings up to the sequence length of the running + # input. + assert ( + cur_seq_len <= max_seq_len + ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" + freqs = freqs[:cur_seq_len] + if tensor_format == "bshd": + freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim] + # cos/sin first then dtype conversion for better precision + cos_ = torch.cos(freqs).to(t.dtype) + sin_ = torch.sin(freqs).to(t.dtype) + + rot_dim = freqs.shape[-1] + # ideally t_pass is empty so rotary pos embedding is applied to all tensor t + t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + + # first part is cosine component + # second part is sine component, need to change signs with _rotate_half method + t = (t * cos_) + (_rotate_half(t) * sin_) + return torch.cat((t, t_pass), dim=-1) diff --git a/transformer_engine/pytorch/dot_product_attention/utils.py b/transformer_engine/pytorch/dot_product_attention/utils.py new file mode 100644 index 0000000000..bae237c592 --- /dev/null +++ b/transformer_engine/pytorch/dot_product_attention/utils.py @@ -0,0 +1,1786 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Utils/Helper classes and methods for attention +""" +import math +import os +from typing import Any, Dict, List, Optional, Tuple, Union +import warnings +import logging +import functools + +from dataclasses import dataclass, fields +import numpy as np +from packaging.version import Version as PkgVersion + +import torch +import torch.nn.functional as F +import transformer_engine_torch as tex +import transformer_engine as te +from transformer_engine.pytorch.cpp_extensions.fused_attn import ( + QKVLayout, + AttnBiasType, + AttnMaskType, + FusedAttnBackend, + META_QKV, + META_DQKV, + META_O, + META_DO, + META_S, + META_DP, + META_O_CP, + META_DQKV_CP, +) +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.constants import TE_DType + + +from transformer_engine.pytorch.utils import ( + get_device_compute_capability, + get_cudnn_version, +) + +from transformer_engine.pytorch.jit import jit_fuser + +# NVTE_DEBUG = 0/1 # disables/enables debug mode, default = 0 +_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0")) +# NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 +_NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) +_NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) + + +class AttentionLogging: + """ + Manage logging for attention module + """ + + _log_level = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL + _formatter = logging.Formatter("[%(levelname)-8s | %(name)-19s]: %(message)s") + _stream_handler = logging.StreamHandler() + fa_logger = logging.getLogger(__name__) + + @staticmethod + def setup_logging(): + """ + Set up log levels, logger and handlers + """ + _log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} + AttentionLogging._log_level = _log_levels[ + AttentionLogging._log_level if AttentionLogging._log_level in [0, 1, 2] else 2 + ] + AttentionLogging._stream_handler.setFormatter(AttentionLogging._formatter) + AttentionLogging.fa_logger.setLevel(AttentionLogging._log_level) + if not AttentionLogging.fa_logger.hasHandlers(): + AttentionLogging.fa_logger.addHandler(AttentionLogging._stream_handler) + + +@functools.lru_cache(maxsize=None) +def _get_supported_versions(version_min, version_max): + """ + Calculate version info based on min and max numbers + """ + return ">= " + str(version_min) + ", " + "<= " + str(version_max) + + +class FlashAttentionUtils: + """ + Manage Flash Attention versioning information + """ + + is_installed = False + version = PkgVersion("0") + version_required = PkgVersion("2.1.1") + version_required_blackwell = PkgVersion("2.7.3") + max_version = PkgVersion("2.7.4.post1") + v2_plus = False + v2_1_plus = False + v2_3_plus = False + v2_4_plus = False + v2_4_1_plus = False + v2_5_plus = False + v2_5_7_plus = False + v2_6_0_plus = False + v2_7_0_plus = False + warning_printed = False + + v3_is_installed = False + fa3_version = PkgVersion("0") + v3_0_0_beta = False + use_v3 = False + # FA3 from FA 2.7.3+/hopper has different APIs than FA3 from 2.7.2/hopper + # Please follow these instructions to install FA3 + v3_installation_steps = """\ +(1) git clone https://github.com/Dao-AILab/flash-attention.git +(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install +(3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` +(4) mkdir -p $python_path/flash_attn_3 +(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py""" + v3_warning_printed = False + + @staticmethod + def set_flash_attention_version(): + """ + Setup version info for FA v2.x + """ + FlashAttentionUtils.is_installed = True + FlashAttentionUtils.v2_plus = FlashAttentionUtils.version >= PkgVersion("2") + FlashAttentionUtils.v2_1_plus = FlashAttentionUtils.version >= PkgVersion("2.1") + FlashAttentionUtils.v2_3_plus = FlashAttentionUtils.version >= PkgVersion("2.3") + FlashAttentionUtils.v2_4_plus = FlashAttentionUtils.version >= PkgVersion("2.4") + FlashAttentionUtils.v2_4_1_plus = FlashAttentionUtils.version >= PkgVersion("2.4.1") + FlashAttentionUtils.v2_5_plus = FlashAttentionUtils.version >= PkgVersion("2.5.0") + FlashAttentionUtils.v2_5_7_plus = FlashAttentionUtils.version >= PkgVersion("2.5.7") + FlashAttentionUtils.v2_6_0_plus = FlashAttentionUtils.version >= PkgVersion("2.6.0") + FlashAttentionUtils.v2_7_0_plus = FlashAttentionUtils.version >= PkgVersion("2.7.0") + + @staticmethod + def set_flash_attention_3_params(): + """ + Setup version info for FA v3.x + """ + FlashAttentionUtils.v3_is_installed = True + FlashAttentionUtils.v3_0_0_beta = ( + PkgVersion("3.0.0b") < FlashAttentionUtils.fa3_version < PkgVersion("3.0.0") + ) + + +@dataclass(eq=True) +class AttentionParams: + """ + Attention parameters used to determine which backend to be used. + + Parameters + ---------- + qkv_type: Union[torch.Tensor, Float8Tensor], default = `torch.Tensor` + Type of query/key/value tensors, {`torch.Tensor`, `Float8Tensor`}. + qkv_dtype: torch.dtype, default = `torch.bfloat16` + Data type of query/key/value tensors. + qkv_layout: str, default = "sbh3d" + Query/key/value tensor memory layout. + batch_size: int, default = 1 + Batch size. + num_heads: int, default = 16 + Number of attention heads in the query tensor. + num_gqa_groups: int, default = 16 + Number of attention heads in key and value tensors. + max_seqlen_q: int, default = 128 + Maximum sequence length of the query tensor. + max_seqlen_kv: int, default = 128 + Maximum sequence length of the key and value tensors. + head_dim_qk: int, default = 64 + The size of each attention head in query and key tensors. + head_dim_v: int, default = 64 + The size of each attention head in the value tensor. + attn_mask_type: str, default = `no_mask` + Attention mask type, {`no_mask`, `padding`, `causal`, `padding_causal`, + `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`} + window_size: Tuple[int, int], default = None + Sliding window attention size. + alibi_slopes_shape: Optional[Union[torch.Size, List]], default = `None` + Tensor shape of :attr:`alibi_slopes` in `DotProductAttention`. + core_attention_bias_type: str, default = `no_bias` + Attention bias type, {`no_bias`, `pre_scale_bias`, `post_scale_bias`, `alibi`}. + core_attention_bias_shape: str, default = `1hss` + Attention bias shape, {`1hss`, `b1ss`, `bhss`}. + core_attention_bias_requires_grad: bool, default = `True` + Whether attention bias requires gradient. + pad_between_seqs: bool, default = `False` + Whether there is padding between sequences in a batch. + This only applies to `qkv_format=thd`. + attention_dropout: float, default = 0.0 + Attention dropout. + context_parallel: bool, default = `False` + Whether context parallelism is used or not. + deterministic: bool, default = `False` + Whether to run `DotProductAttention` with determinism or not. + is_training: bool, default = `True` + Whether in training mode (`True`) or inference mode (`False`) + fp8: bool, default = `False` + Whether `DotProductAttention` is in an `fp8_autocast` region. + fp8_meta: Optional[Dict[str Any]], default = `None` + The FP8 metadata tensor of `DotProductAttention`. + inference_params: Optional[InferenceParams], default = `None` + Inference-related parameters. See InferenceParams for details. + """ + + qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor + qkv_dtype: torch.dtype = torch.bfloat16 + qkv_layout: str = "sbh3d" + batch_size: int = 1 + num_heads: int = 16 + num_gqa_groups: int = 16 + max_seqlen_q: int = 128 + max_seqlen_kv: int = 128 + head_dim_qk: int = 64 + head_dim_v: int = 64 + attn_mask_type: str = "no_mask" + window_size: Union[Tuple[int, int], None] = None + alibi_slopes_shape: Union[torch.Size, List, None] = None + core_attention_bias_type: str = "no_bias" + core_attention_bias_shape: str = "1hss" + core_attention_bias_requires_grad: bool = True + pad_between_seqs: bool = False + attention_dropout: float = 0.0 + context_parallel: bool = False + deterministic: bool = False + is_training: bool = True + fp8: bool = False + fp8_meta: Union[Dict[str, Any], None] = None + inference_params: Optional[InferenceParams] = None + + def __eq__(self, other): + """ + Overwrite dataclass.__eq__ so that only fp8_meta["recipe"] is compared, + since all other entries of fp8_meta are unused in get_attention_backend. + """ + if not isinstance(other, self.__class__): + return NotImplemented + for field in fields(self): + fname = field.name + sf = getattr(self, fname) + of = getattr(other, fname) + if fname != "fp8_meta": + if sf != of: + return False + elif sf.get("recipe", None) != of.get("recipe", None): + return False + return True + + +def get_attention_backend( + attention_params: AttentionParams = None, +): + """ + Select the appropriate attention backend/sub-backend based on user input and runtime environment. + + Parameters + ---------- + See `AttentionParams`. + + Returns + ---------- + use_flash_attention: bool + Whether the `FlashAttention` backend has been selected. + use_fused_attention: bool + Whether the `FusedAttention` backend has been selected. + fused_attention_backend: tex.NVTE_Fused_Attn_Backend + If `use_fused_attention = True`, one of `FusedAttention` three sub-backends, else `None`. + use_unfused_attention: bool + Whether the `UnfusedDotProductAttention` backend has been selected. + available_backends: List[bool] + All available backends that could support the provided input. A list of Booleans + in the form of [use_flash_attention, use_fused_attention, use_unfused_attention]. + """ + # NOTE: As part of refactoring attention.py, populating the _attention_backends cache in attention + # is no longer performed at the end of get_attention_backend(), but the responsibility of doing so + # is shifted over to the caller of this function + qkv_type = attention_params.qkv_type + qkv_dtype = attention_params.qkv_dtype + qkv_layout = attention_params.qkv_layout + batch_size = attention_params.batch_size + num_heads = attention_params.num_heads + num_gqa_groups = attention_params.num_gqa_groups + max_seqlen_q = attention_params.max_seqlen_q + max_seqlen_kv = attention_params.max_seqlen_kv + head_dim_qk = attention_params.head_dim_qk + head_dim_v = attention_params.head_dim_v + attn_mask_type = attention_params.attn_mask_type + window_size = attention_params.window_size + alibi_slopes_shape = attention_params.alibi_slopes_shape + core_attention_bias_type = attention_params.core_attention_bias_type + core_attention_bias_shape = attention_params.core_attention_bias_shape + core_attention_bias_requires_grad = attention_params.core_attention_bias_requires_grad + pad_between_seqs = attention_params.pad_between_seqs + attention_dropout = attention_params.attention_dropout + context_parallel = attention_params.context_parallel + deterministic = attention_params.deterministic + is_training = attention_params.is_training + fp8 = attention_params.fp8 + fp8_meta = attention_params.fp8_meta + inference_params = attention_params.inference_params + + # Run config + logger = logging.getLogger("DotProductAttention") + logger.setLevel(AttentionLogging._log_level) + if not logger.hasHandlers(): + logger.addHandler(AttentionLogging._stream_handler) + device_compute_capability = get_device_compute_capability() + cudnn_version = get_cudnn_version() + run_config = { + "transformer_engine_version": te.__version__, + "compute_capability": "sm" + + str(10 * device_compute_capability[0] + device_compute_capability[1]), + "flash_attn_version": ( + str(FlashAttentionUtils.version) + if FlashAttentionUtils.is_installed + else "not installed" + ), + "flash_attn_3_version": ( + str(FlashAttentionUtils.fa3_version) + if FlashAttentionUtils.v3_is_installed + else "not installed" + ), + "cudnn_version": ".".join([str(i) for i in cudnn_version]), + } + attention_params_dict = { + field.name: getattr(attention_params, field.name) for field in fields(attention_params) + } + run_config.update(attention_params_dict) + if fp8: + run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + logger.debug("Running with config=%s", run_config) + + # The following sections check if `FlashAttention` supports the provided attention params, + # regardless of whether FA2 or FA3 is installed. If FA2 or FA3 is not installed but is + # necessary for performance/functionality, a warning will be issued to prompt users to + # install an appropriate FA version. + qkv_format, q_format, _ = get_qkv_format(qkv_layout, inference_params) + + # Filter: Environment variables + use_flash_attention = int(os.getenv("NVTE_FLASH_ATTN", "1")) + use_flash_attention_2 = use_flash_attention + use_flash_attention_3 = use_flash_attention + flash_attention_backend = None + use_fused_attention = int(os.getenv("NVTE_FUSED_ATTN", "1")) + use_unfused_attention = int(os.getenv("NVTE_UNFUSED_ATTN", "1")) + if not use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 due to NVTE_FLASH_ATTN=0") + if not use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 due to NVTE_FLASH_ATTN=0") + if not use_fused_attention: + logger.debug("Disabling FusedAttention due to NVTE_FUSED_ATTN=0") + if not use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0") + + # Filter: Compute capability + if device_compute_capability < (8, 0): + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 for compute capability < sm80") + use_flash_attention_2 = False + if use_fused_attention: + logger.debug("Disabling FusedAttention for compute capability < sm80") + use_fused_attention = False + if device_compute_capability != (9, 0): + if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for compute capability != sm90") + use_flash_attention_3 = False + + # Filter: Data type + if qkv_dtype not in [torch.bfloat16, torch.float16]: + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention 2 for unsupported qkv_dtype = %s. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16}. ", + qkv_dtype, + ) + use_flash_attention_2 = False + if qkv_dtype not in [torch.bfloat16, torch.float16, torch.float8_e4m3fn] or qkv_type not in [ + torch.Tensor, + Float8Tensor, + ]: + if use_flash_attention_3 and FlashAttentionUtils.v3_is_installed: + logger.debug( + "Disabling FlashAttention 3 for unsupported qkv_dtype = %s, qkv_type = %s. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " + "qkv_type = {torch.Tensor, Float8Tensor}. ", + qkv_dtype, + qkv_type, + ) + use_flash_attention_3 = False + if use_fused_attention: + logger.debug( + "Disabling FusedAttention for unsupported qkv_dtype = %s, qkv_type = %s. " + "Supported: qkv_dtype = {torch.bfloat16, torch.float16, torch.float8_e4m3fn}, " + "qkv_type = {torch.Tensor, Float8Tensor}. ", + qkv_dtype, + qkv_type, + ) + use_fused_attention = False + + # Filter: Execution type + if fp8 and fp8_meta["recipe"].fp8_dpa: + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 for FP8 attention") + use_flash_attention_2 = False + if use_flash_attention_3 and is_training: + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for FP8 training") + use_flash_attention_3 = False + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") + use_unfused_attention = False + + # Filter: KV cache + # backend | precision | KV cache | architecture | qkv_format | page_size + # --------------------------------------------------------------------------------------- + # Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1 + # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256 + # Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1 + # | FP8 | non-paged/paged | sm90 | thd | >= 1 + # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 + if inference_params is not None: + if context_parallel: + logger.debug("Disabling all backends for KV caching with context parallelism") + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + if fp8 and fp8_meta["recipe"].fp8_dpa: + if fp8_meta["recipe"].fp8_mha: + logger.debug("Disabling all backends for KV caching with FP8 MHA") + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + if use_flash_attention_3 and q_format != "thd": + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for FP8 KV caching and non-THD") + use_flash_attention_3 = False + if use_fused_attention: + logger.debug("Disabling FusedAttention for FP8 KV caching") + use_fused_attention = False + else: + if q_format == "thd" and pad_between_seqs: + logger.debug("Disabling all backends for pad_between_seqs = True and KV caching") + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + if inference_params.is_paged: + if use_flash_attention_2 and inference_params.page_size < 256: + if FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 for page size < 256") + use_flash_attention_2 = False + if use_flash_attention_2: + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.5") + elif not FlashAttentionUtils.v2_5_plus: + logger.debug( + "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+" + ) + use_flash_attention_2 = False + + # Filter: Head dimension + if head_dim_qk != head_dim_v: + if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( + use_flash_attention_3 and FlashAttentionUtils.v3_is_installed + ): + logger.debug("Disabling FlashAttention as it does not support MLA.") + use_flash_attention = False + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") + if use_fused_attention and qkv_layout_group != "hd_hd_hd": + logger.debug( + "Disabling FusedAttention as MLA is not supported with qkv_layout = %s", + qkv_layout, + ) + use_fused_attention = False + if use_flash_attention_2 and ( + head_dim_qk > 256 + or head_dim_qk % 8 != 0 + or ( + head_dim_qk > 192 + and device_compute_capability not in ((8, 0), (9, 0), (10, 0), (12, 0)) + ) + ): + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention 2 due to unsupported head_dim_qk and head_dim_v. " + "Supported: head_dim_qk = head_dim_v, head_dim_qk %%8 = 0, " + "head_dim_qk <= 256 (>192 requires sm80/90/100+). " + "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.", + head_dim_qk, + head_dim_v, + ".".join([str(i) for i in device_compute_capability]), + ) + use_flash_attention_2 = False + if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128): + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for head_dim > 128") + use_flash_attention_3 = False + + # Filter: QKV layout + if qkv_format == "thd": + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention for qkv_format = thd") + use_unfused_attention = False + if pad_between_seqs: + if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( + use_flash_attention_3 and FlashAttentionUtils.v3_is_installed + ): + logger.debug( + "Disabling FlashAttention for qkv_format = thd when there is " + "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" + ) + use_flash_attention = False + + # Filter: Dropout + if attention_dropout != 0.0 and use_flash_attention_3: + logger.debug("Disabling FlashAttention 3 for dropout") + use_flash_attention_3 = False + + # Filter: Context parallelism + # qkv_format | attn_mask_type | attn_bias_type | supported backends + # ---------------------------------------------------------------------------------------------------- + # bshd, sbhd | self-attention: | no_bias, post_scale_bias | FlashAttention, FusedAttention + # | no_mask, causal | | + # | cross-attention: | | + # | no_mask | | + # thd | self-attention: | no_bias | FlashAttention, FusedAttention + # | padding, padding_causal | | if no padding between sequences, + # | cross-attention: | | FusedAttention + # | padding | | if there is padding between sequences + # Note: context parallelism requires seq_len % (cp_size * 2) == 0 for each sequence in q, k, v. + if context_parallel and use_unfused_attention: + logger.debug( + "Disabling UnfusedDotProductAttention as it does not support context parallelism" + ) + use_unfused_attention = False + if context_parallel and (use_flash_attention_2 or use_flash_attention_3): + if FlashAttentionUtils.is_installed or FlashAttentionUtils.v3_is_installed: + if fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with FP8" + ) + use_flash_attention = False + if "bottom_right" in attn_mask_type: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " causal_bottom_right masking" + ) + use_flash_attention = False + elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " causal masking for cross-attention" + ) + use_flash_attention = False + elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with bias" + " type of %s", + core_attention_bias_type, + ) + use_flash_attention = False + elif qkv_format == "thd" and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FlashAttention as it does not support context parallelism with" + " attention bias for THD format" + ) + use_flash_attention = False + + if context_parallel and use_fused_attention: + if "bottom_right" in attn_mask_type: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with" + " causal_bottom_right masking" + ) + use_fused_attention = False + elif "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with causal" + " masking for cross-attention" + ) + use_fused_attention = False + elif core_attention_bias_type not in ["no_bias", "post_scale_bias"]: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with bias type" + " of %s", + core_attention_bias_type, + ) + use_fused_attention = False + elif qkv_format == "thd" and core_attention_bias_type != "no_bias": + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with attention" + " bias for THD format" + ) + use_fused_attention = False + elif head_dim_qk != head_dim_v: + logger.debug( + "Disabling FusedAttention as it does not support context parallelism with MLA" + ) + use_fused_attention = False + + # Filter: Attention mask + # attn_mask_type | attention_mask | supported backends + # ---------------------------------------------------------------------------------------- + # no_mask | None | All + # padding | | All + # self-attention | One tensor in shape [b, 1, 1, sq] | + # cross-attention | Tuple of two tensors in shapes | + # | [b, 1, 1, sq] and [b, 1, 1, skv] | + # causal | None | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # padding_causal | Same as "padding" | + # self-attention | | All + # cross-attention | | FusedAttention, UnfusedDotProductAttention + # causal_bottom_right | None | All + # padding_causal_bottom_right | Same as "padding" | All + # arbitrary | One tensor in shape broadcastable to | UnfusedDotProductAttention + # | [b, h, sq, skv] | + if attn_mask_type == "arbitrary": + if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( + use_flash_attention_3 and FlashAttentionUtils.v3_is_installed + ): + logger.debug("Disabling FlashAttention for arbitrary mask") + use_flash_attention = False + if use_fused_attention: + logger.debug("Disabling FusedAttention for arbitrary mask") + use_fused_attention = False + if ( + (use_flash_attention_2 or use_flash_attention_3) + and attn_mask_type in ["causal", "padding_causal"] + and max_seqlen_q != max_seqlen_kv + ): + logger.warning( + "Disabling FlashAttention as it only supports bottom-right-diagonal " + "causal mask since flash-attn 2.1 (our minimum supported version). See " + "https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag" + ) + use_flash_attention = False + + # Filter: Sliding window attention + # backend | window_size | diagonal alignment + # --------------------------------------------------------------------------------- + # FlashAttention | (-1, -1) or (>=0, >=0) | bottom right + # FusedAttention | (-1, 0) or (>=0, 0) | top left + # UnfusedDotProductAttention | (-1, -1) or (>=0, >=0) | both; + # | | converts window_size to an 'arbitrary' mask + if window_size is None: + window_size = check_set_window_size(attn_mask_type, window_size) + else: + if use_fused_attention and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if fp8 and (fp8_meta["recipe"].fp8_dpa or fp8_meta["recipe"].fp8_mha): + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention" + " for FP8" + ) + use_fused_attention = False + elif window_size[1] != 0 or attention_dropout != 0.0: + logger.debug( + "Disabling FusedAttention as it only supports sliding window attention " + "with (left, 0) and no dropout" + ) + use_fused_attention = False + elif max_seqlen_q > max_seqlen_kv: + logger.debug( + "Disabling FusedAttention as it does not support sliding window attention " + "with s_q > s_kv for cross-attention" + ) + use_fused_attention = False + if use_flash_attention_2 and (window_size[0] != -1 or window_size[1] not in [-1, 0]): + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.3") + elif not FlashAttentionUtils.v2_3_plus: + logger.debug( + "Disabling FlashAttention as sliding window attention requires flash-attn 2.3+" + ) + use_flash_attention_2 = False + + # Filter: Attention bias + # backend | bias types | ALiBi diagonal alignment + # --------------------------------------------------------------------------------- + # FlashAttention | no_bias, alibi/alibi_slopes | bottom right + # FusedAttention | no_bias, post_scale_bias | + # | alibi/alibi_slopes | top left, + # | | bottom_right (converts to a 'post_scale_bias' bias) + # UnfusedDotProductAttention | no_bias, pre/post_scale_bias | + # | alibi/alibi_slopes | both; converts to a 'post_scale_bias' bias + if core_attention_bias_type == "alibi": + if use_flash_attention_3: + if FlashAttentionUtils.v3_is_installed: + logger.debug("Disabling FlashAttention 3 for ALiBi") + use_flash_attention_3 = False + if use_flash_attention_2: + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.4") + elif not FlashAttentionUtils.v2_4_plus: + logger.debug("Disabling FlashAttention as ALiBi requires flash-attn 2.4+") + use_flash_attention_2 = False + + if ( + core_attention_bias_type not in ["no_bias", "alibi"] + or core_attention_bias_shape is not None + ): + if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( + use_flash_attention_3 and FlashAttentionUtils.v3_is_installed + ): + logger.debug("Disabling FlashAttention for pre/post_scale_bias") + use_flash_attention = False + + fu_core_attention_bias_type = core_attention_bias_type + fu_core_attention_bias_shape = core_attention_bias_shape + fu_core_attention_bias_requires_grad = core_attention_bias_requires_grad + if ( + use_fused_attention + and core_attention_bias_type == "alibi" + and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv) + ): + fu_core_attention_bias_type = "post_scale_bias" + fu_core_attention_bias_requires_grad = False + if alibi_slopes_shape is None: + fu_core_attention_bias_shape = "1hss" + elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads: + fu_core_attention_bias_shape = "1hss" + elif ( + len(alibi_slopes_shape) == 2 + and alibi_slopes_shape[0] == batch_size + and alibi_slopes_shape[1] == num_heads + ): + fu_core_attention_bias_shape = "bhss" + + if ( + use_fused_attention + and fu_core_attention_bias_type == "post_scale_bias" + and fu_core_attention_bias_shape != "1hss" + ): + if fu_core_attention_bias_requires_grad: + # remove this line when cuDNN adds bwd support for + # [1, 1, s, s], [b, 1, s, s] and [b, h, s, s] + logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape") + use_fused_attention = False + else: + # max512 backend will only support [1, h, s, s] + os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" + + # Filter: cuDNN support + fused_attention_backend = None + if use_fused_attention: + q_type = TE_DType[qkv_dtype] + kv_type = q_type + if fp8 and fp8_meta["recipe"].fp8_dpa: + q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) + kv_type = q_type + fused_attention_backend = tex.get_fused_attn_backend( + q_type, + kv_type, + QKVLayout[qkv_layout], + AttnBiasType[fu_core_attention_bias_type], + AttnMaskType[attn_mask_type], + attention_dropout, + num_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size[0], + window_size[1], + ) + if fused_attention_backend == FusedAttnBackend["No_Backend"]: + logger.debug("Disabling FusedAttention as no backend supports the provided input") + use_fused_attention = False + fused_attention_backend = None + if ( + use_fused_attention + and window_size is not None + and window_size[0] != -1 + and fused_attention_backend != FusedAttnBackend["F16_arbitrary_seqlen"] + ): + logger.debug( + "Disabling FusedAttention as only sub-backend %s does not support " + "slidng window attention", + int(fused_attention_backend), + ) + use_fused_attention = False + fused_attention_backend = None + if ( + use_fused_attention + and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] + and fu_core_attention_bias_type == "post_scale_bias" + and fu_core_attention_bias_shape != "1hss" + ): + logger.debug( + "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in" + " [1, H, S, S] shape" + ) + use_fused_attention = False + fused_attention_backend = None + + # Filter: Determinism + # backend | deterministic + # --------------------------------------------- + # FlashAttention | + # flash-attn >=2.0, <2.4.1 | no + # flash-attn >=2.4.1 | yes + # FusedAttention | + # sub-backend 0 | yes + # sub-backend 1 | workspace optimization path and sm90+: yes; + # | otherwise: no + # sub-backend 2 | no + # UnfusedDotProductAttention | yes + if use_flash_attention_2 and deterministic: + if not FlashAttentionUtils.is_installed: + FlashAttentionUtils.version_required = PkgVersion("2.4.1") + elif not FlashAttentionUtils.v2_4_1_plus: + logger.warning( + "Disabling FlashAttention as version <2.4.1 does not support deterministic " + "execution. To use FlashAttention with deterministic behavior, " + "please install flash-attn >= 2.4.1." + ) + use_flash_attention_2 = False + if use_fused_attention and deterministic: + if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: + logger.debug("Disabling FusedAttention for determinism reasons") + use_fused_attention = False + if ( + fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + and is_training + and ( + device_compute_capability < (9, 0) + or core_attention_bias_requires_grad + or cudnn_version < (8, 9, 5) + ) + ): + logger.debug("Disabling FusedAttention for determinism reasons") + use_fused_attention = False + + # use_flash_attention may have been set above + use_flash_attention_2 = use_flash_attention and use_flash_attention_2 + use_flash_attention_3 = use_flash_attention and use_flash_attention_3 + + # `FusedAttention` and `FlashAttention` are faster backends than `UnfusedDotProductAttention`. + # When `FusedAttention` does not support the provided attention params, and `FlashAttention` + # does, we recommend users to install flash-attn if not installed already. + if not use_fused_attention and _NVTE_FLASH_ATTN: + if ( + use_flash_attention_3 + and not FlashAttentionUtils.v3_is_installed + and not FlashAttentionUtils.v3_warning_printed + and torch.cuda.current_device() == 0 + ): + logger.warning( + "flash-attn v3 may provide important feature support or performance improvement." + " Please install flash-attn v3 by \n%s", + FlashAttentionUtils.v3_installation_steps, + ) + FlashAttentionUtils.v3_warning_printed = True + elif ( + use_flash_attention_2 + and not FlashAttentionUtils.is_installed + and not FlashAttentionUtils.warning_printed + and torch.cuda.current_device() == 0 + ): + logger.warning( + "flash-attn may provide important feature support or performance improvement." + " Please install flash-attn %s by pip3 install flash-attn==.", + _get_supported_versions( + FlashAttentionUtils.version_required, + FlashAttentionUtils.max_version, + ), + ) + FlashAttentionUtils.warning_printed = True + # All available backends + if use_flash_attention_2 and not FlashAttentionUtils.is_installed: + use_flash_attention_2 = False + if use_flash_attention_3 and not FlashAttentionUtils.v3_is_installed: + use_flash_attention_3 = False + use_flash_attention = use_flash_attention_2 or use_flash_attention_3 + available_backends = [use_flash_attention, use_fused_attention, use_unfused_attention] + if use_flash_attention_2: + flash_attention_backend = FlashAttentionUtils.version + if use_flash_attention_3: + flash_attention_backend = FlashAttentionUtils.fa3_version + + logger.debug( + "Available backends = {FlashAttention=%s%s, FusedAttention=%s%s," + " UnfusedDotProductAttention=%s}", + bool(available_backends[0]), + (f" ({str(flash_attention_backend)})" if flash_attention_backend is not None else ""), + bool(available_backends[1]), + ( + f" (sub-backend {int(fused_attention_backend)})" + if fused_attention_backend is not None + else "" + ), + bool(available_backends[2]), + ) + + # Select FusedAttention for performance + if use_flash_attention and use_fused_attention and device_compute_capability >= (9, 0): + logger.debug( + "Disabling FlashAttention to give FusedAttention preference on Hopper+ " + "for performance reasons" + ) + use_flash_attention = False + + # Selected backend + if use_flash_attention: + use_fused_attention = False + use_unfused_attention = False + elif use_fused_attention: + use_unfused_attention = False + selected_backend = "NoBackend" + if use_flash_attention: + selected_backend = f"FlashAttention ({str(flash_attention_backend)})" + elif use_fused_attention: + selected_backend = f"FusedAttention (sub-backend {int(fused_attention_backend)})" + elif use_unfused_attention: + selected_backend = "UnfusedDotProductAttention" + logger.debug("Selected backend = %s", selected_backend) + + return ( + use_flash_attention, + flash_attention_backend, + use_fused_attention, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) + + +@torch.no_grad() +def get_padding_mask( + batch_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_q: int, + max_seqlen_kv: int, +): + """Convert cu_seqlens to attention_mask""" + seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) + attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool) + for i in range(batch_size): + attention_mask_q = torch.cat( + [ + attention_mask_q, + torch.Tensor([False] * seqlens_q[i] + [True] * (max_seqlen_q - seqlens_q[i])) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) + attention_mask_kv = torch.cat( + [ + attention_mask_kv, + torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i])) + .to(dtype=torch.bool) + .unsqueeze(0) + .unsqueeze(0) + .unsqueeze(0), + ], + dim=0, + ) + attention_mask = ( + attention_mask_q.to(device="cuda"), + attention_mask_kv.to(device="cuda"), + ) + return attention_mask + + +@torch.no_grad() +def get_full_mask( + max_seqlen_q: int, + max_seqlen_kv: int, + attn_mask_type: str = "no_mask", + attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, + window_size: Tuple[int, int] = None, + attention_type: str = "self", + bottom_right_alignment: bool = True, +) -> torch.Tensor: + """ + Get full attention mask in [..., max_seqlen_q, max_seqlen_kv] shape, based on `attn_mask_type`, + `attention_mask`, and `window_size`. For sliding window attention, the diagonal alignment depends + on both `attn_mask_type` and `bottom_right_alignment`, as detailed below.:: + + attn_mask_type output shape diagonal alignment + -------------------------------------------------------------------------------------------- + no_mask [1, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment + causal [1, 1, max_seqlen_q, max_seqlen_kv] always top left + causal_bottom_right [1, 1, max_seqlen_q, max_seqlen_kv] always bottom right + padding [batch_size, 1, max_seqlen_q, max_seqlen_kv] follow bottom_right_alignment + padding_causal [batch_size, 1, max_seqlen_q, max_seqlen_kv] always top left + padding_causal_bottom_right [batch_size, 1, max_seqlen_q, max_seqlen_kv] always bottom right + arbitrary same as attention_mask follow bottom_right_alignment + + .. note:: + + For "padding_bottom_right" mask, or "padding" mask with `bottom_right_alignment` = True, the bottom right + diagonal comes from the bottom right corner of the [actual_seqlens_q[i], actual_seqlens_kv[i]] matrix, + i = 0,...,batch_size-1, not the [max_seqlen_q, max_seqlen_kv] matrix. For example, with max_seqlen_q = 4, + max_seqlen_kv = 4, attn_mask_type = "padding", attention_type = "cross", and attention_mask = ( + [[False, False, True, True], [False, False, False, False]], + [[False, False, False, True], [False, True, True, True]]), the returned full attention mask has [2, 4, 4] + shape and is,:: + + [[[False, False, False, True], + [False, False, False, True], + [ True, True, True, True], + [ True, True, True, True]], + [[False, True, True, True], + [False, True, True, True], + [False, True, True, True], + [False, True, True, True]]] + + Parameters + ---------- + max_seqlen_q: int + Maximum sequence length for queries. + max_seqlen_kv: int + Maximum sequence length for keys and values. + attn_mask_type: str, default = `no_mask` + Attention mask type, {"`no_mask`", "`padding`", "`causal`", "`padding_causal`", + "`causal_bottom_right`", "`padding_causal_bottom_right`", "`arbitrary`"} + attention_mask: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + default = `None` + Boolean tensor(s) used to mask out attention softmax input. Please see DotProductAttention + for the requirements of `attention_mask` for different `attn_mask_type`s. + window_size: Tuple[int, int], default = `None` + Sliding window size for local attention, where query at position i attends to keys + in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding + window and causal mask specifically. Both `causal` and `causal_bottom_right` masks + map to `window_size = (-1, 0)` and Transformer Engine distinguishes them based on + `attn_mask_type`. + attention_type: str, default = "self" + Attention type, {"self", "cross"} + bottom_right_alignment: bool, default = `True` + Whether to align the diagonal of the sliding window attention to the bottom right (`True`) + or top left (`False`) corner of the softmax matrix. Ignored if `attn_mask_type` explicitly + specifies "causal" or "causal_bottom_right". + + Returns + ---------- + attn_mask_type: str + For sliding window attention (>=0, >0), "arbitrary"; otherwise, the same as input `attn_mask_type` + attention_mask: torch.Tensor + The full attention mask based on `attn_mask_type`, `attention_mask` and `window_size` + actual_seqlens_q: torch.Tensor + For padding masks, the actual sequence lengths for queries, in shape [batch_size]. + For other masks, `None`. + actual_seqlens_kv: Optional[torch.Tensor], default = `None` + For padding masks, the actual sequence lengths for keys and values, in shape [batch_size]. + For other masks, `None`. + """ + # perform basic checks + change_type = window_size is not None and ( + window_size[0] != -1 or window_size[1] not in [-1, 0] + ) + if window_size is None: + window_size = (-1, -1) + if "causal" in attn_mask_type: + window_size = (window_size[0], 0) + window_size = ( + max_seqlen_kv if window_size[0] == -1 else window_size[0], + max_seqlen_q if window_size[1] == -1 else window_size[1], + ) + + # apply padding mask + actual_seqlens_q = None + actual_seqlens_kv = None + if "padding" in attn_mask_type: + if attention_type == "self": + attention_mask = torch.logical_or( + attention_mask.squeeze(1).unsqueeze(3), attention_mask + ) + else: + attention_mask = torch.logical_or( + attention_mask[0].squeeze(1).unsqueeze(3), attention_mask[1] + ) + m = attention_mask.logical_not() + actual_seqlens_q = m[:, 0, :, 0].sum(dim=1) + actual_seqlens_kv = m[:, 0, 0, :].sum(dim=1) + + # apply SWA mask + mask = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view(1, 1, 1, max_seqlen_kv) + swa_left = None + swa_right = None + if attn_mask_type == "causal_bottom_right" or ( + attn_mask_type in ["no_mask", "arbitrary"] and bottom_right_alignment + ): + swa_left = mask + max_seqlen_kv - max_seqlen_q - window_size[0] + swa_right = mask + max_seqlen_kv - max_seqlen_q + window_size[1] + elif attn_mask_type in ["causal", "padding_causal"] or ( + attn_mask_type in ["no_mask", "padding", "arbitrary"] and not bottom_right_alignment + ): + swa_left = mask - window_size[0] + swa_right = mask + window_size[1] + elif attn_mask_type == "padding_causal_bottom_right" or ( + attn_mask_type == "padding" and bottom_right_alignment + ): + batch_size = attention_mask.shape[0] + swa_left = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q - window_size[0] + ).view(batch_size, 1, 1, 1) + swa_right = mask.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + ( + actual_seqlens_kv - actual_seqlens_q + window_size[1] + ).view(batch_size, 1, 1, 1) + swa_mask = torch.logical_not( + torch.where(swa_left <= 0, 1, 0) - torch.where(swa_right < 0, 1, 0) + ) + if attention_mask is not None: + attention_mask = torch.logical_or(swa_mask, attention_mask) + else: + attention_mask = swa_mask + + # change mask type + if change_type: + attn_mask_type = "arbitrary" + + return attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv + + +@torch.no_grad() +def get_alibi( + _alibi_cache: Dict[str, Any], + num_heads: int, + max_seqlen_q: int, + max_seqlen_kv: int, + actual_seqlens_q: Optional[torch.Tensor] = None, + actual_seqlens_kv: Optional[torch.Tensor] = None, + alibi_slopes: Optional[torch.Tensor] = None, + bias_dtype: Optional[torch.dtype] = None, + bottom_right_alignment: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Parameters + ---------- + num_heads: int + Number of heads. + max_seqlen_q: int + Maximum sequence length for queries. + max_seqlen_kv: int + Maximum sequence length for keys and values. + actual_seqlens_q: Optional[torch.Tensor], default = `None` + Actual sequence lengths for queries, in shape [batch_size]. + actual_seqlens_kv: Optional[torch.Tensor], default = `None` + Actual sequence lengths for keys and values, in shape [batch_size]. + alibi_slopes: Optional[torch.Tensor], default = `None` + Custom ALiBi slopes, FP32, CUDA tensor, in shape [num_heads] or [batch_size, num_heads]. + bias_dtype: Optional[torch.dtype], default = `None` + Dtype of the generated ALiBi bias. If None, use torch.float32. + bottom_right_alignment: bool, default = `True` + Whether to align the diagonal of the ALiBi bias to the bottom right corner of + the matrix (`True`) or top left (`False`). + + Returns + ---------- + alibi_slopes: torch.Tensor + ALiBi slopes in FP32 and shape [num_heads] or [batch_size, num_heads]. + alibi_bias: torch.Tensor + ALiBi bias in FP32 or `bias_dtype`. Its shape is + (1) [1, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in [num_heads] shape, + and `actual_seqlens_q` and `actual_seqlens_kv` are `None`; or + (2) [batch_size, num_heads, max_seqlen_q, max_seqlen_kv] if `alibi_slopes` is in + [batch_size, num_heads] shape, or, if `alibi_slopes` is in [num_heads] shape and + `actual_seqlens_q` and `actual_seqlens_kv` are not `None`. + """ + # NOTE: As part of refactoring attention.py, get_alibi() now receives the alibi cache from the caller + # as an additional input arg + if _alibi_cache["_alibi_slopes_require_update"]: + if alibi_slopes is not None: + _alibi_cache["_alibi_slopes"] = alibi_slopes + else: + n = 2 ** math.floor(math.log2(num_heads)) + m_0 = 2.0 ** (-8.0 / n) + m = torch.pow(m_0, torch.arange(1, 1 + n)) + + if n < num_heads: + m_hat_0 = 2.0 ** (-4.0 / n) + m_hat = torch.pow(m_hat_0, torch.arange(1, 1 + 2 * (num_heads - n), 2)) + m = torch.cat([m, m_hat]) + + _alibi_cache["_alibi_slopes"] = m.to(dtype=torch.float32, device="cuda") + _alibi_cache["_num_heads"] = num_heads + _alibi_cache["_alibi_slopes_require_update"] = False + + if _alibi_cache["_alibi_bias_require_update"]: + assert _alibi_cache["_alibi_slopes"] is not None, "ALiBi slopes can not be None!" + if _alibi_cache["_alibi_slopes"].dim() == 1: + slopes_shape = torch.Size([1, _alibi_cache["_alibi_slopes"].shape[0], 1, 1]) + elif _alibi_cache["_alibi_slopes"].dim() == 2: + slopes_shape = torch.Size([*_alibi_cache["_alibi_slopes"].shape[:], 1, 1]) + else: + raise ValueError("ALiBi slopes cannot exceed 2 dimensions.") + + bias = torch.arange(max_seqlen_q, dtype=torch.int32, device="cuda").view( + 1, 1, max_seqlen_q, 1 + ) - torch.arange(max_seqlen_kv, dtype=torch.int32, device="cuda").view( + 1, 1, 1, max_seqlen_kv + ) + if actual_seqlens_q is None and actual_seqlens_kv is None: + if bottom_right_alignment: + bias = bias + max_seqlen_kv - max_seqlen_q + elif actual_seqlens_q is not None and actual_seqlens_kv is not None: + batch_size = actual_seqlens_q.shape[0] + bias = bias.expand(batch_size, 1, max_seqlen_q, max_seqlen_kv) + if bottom_right_alignment: + bias = bias + (actual_seqlens_kv - actual_seqlens_q).view(batch_size, 1, 1, 1) + else: + assert ( + False + ), "actual_seqlens_q and actual_seqlens_kv need to be both None or torch.Tensors!" + bias = bias.abs().mul(-1) + bias = bias * _alibi_cache["_alibi_slopes"].view(slopes_shape) + _alibi_cache["_max_seqlen_q"], _alibi_cache["_max_seqlen_kv"] = max_seqlen_q, max_seqlen_kv + _alibi_cache["_bottom_right_alignment"] = bottom_right_alignment + bias_dtype = torch.float32 if bias_dtype is None else bias_dtype + _alibi_cache["_alibi_bias"] = bias.contiguous().to(dtype=bias_dtype, device="cuda") + _alibi_cache["_alibi_bias_require_update"] = False + + return _alibi_cache["_alibi_slopes"], _alibi_cache["_alibi_bias"] + + +def get_cu_seqlens(mask: torch.Tensor) -> torch.Tensor: + """ + Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32 + tensor of shape [batch_size + 1] containing the cumulative sequence lengths of + the samples in a batch. + """ + mask = mask.squeeze(1).squeeze(1) + reduced_mask = mask.logical_not().sum(dim=1) + cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) + zero = torch.zeros(1, dtype=torch.int32, device="cuda") + cu_seqlens = torch.cat((zero, cu_seqlens)) + + return cu_seqlens + + +def get_cu_seqlens_and_indices(mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Given a padding mask of shape [batch_size, 1, 1, max_seqlen], returns an int32 + tensor of shape [batch_size + 1] containing the cumulative sequence lengths of + the samples in a batch, and another int32 tensor of shape [batch_size * max_seqlen, 1, 1] + containing the indices for the valid tokens. + """ + mask = mask.squeeze(1).squeeze(1) + bs, seqlen = mask.shape + + reduced_mask = mask.logical_not().sum(dim=1) + cu_seqlens = reduced_mask.cumsum(dim=0).to(torch.int32) + zero = torch.zeros(1, dtype=torch.int32, device="cuda") + cu_seqlens = torch.cat((zero, cu_seqlens)) + + mask = mask.reshape(-1) + indices = mask.logical_not().nonzero() + indices = indices.unsqueeze(-1) + + num_nonzeros = indices.shape[0] + pad_amount = bs * seqlen - num_nonzeros + indices = F.pad( + input=indices, pad=(0, 0, 0, 0, 0, pad_amount), mode="constant", value=float(bs * seqlen) + ) + + return cu_seqlens, indices + + +def get_indices(max_seqlen: int, cu_seqlens: torch.Tensor) -> torch.Tensor: + """ + Given max_seqlen and cu_seqlens of shape [batch_size + 1], returns an int32 + tensor of shape [batch_size * max_seqlen, 1, 1] containing the indices for + the valid tokens in a batch. + """ + bs = len(cu_seqlens) - 1 + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + indices = [i * max_seqlen + ii for i, j in enumerate(seqlens) for ii in range(j)] + indices = torch.Tensor(indices).unsqueeze(1).unsqueeze(1).to(dtype=torch.int64, device="cuda") + + num_nonzeros = indices.shape[0] + pad_amount = bs * max_seqlen - num_nonzeros + indices = F.pad( + input=indices, + pad=(0, 0, 0, 0, 0, pad_amount), + mode="constant", + value=float(bs * max_seqlen), + ) + + return indices + + +_cu_seqlens_cache = {} + + +def get_full_cu_seqlens( + batch_size: int, + max_seqlen: int, + device: torch.device, +) -> torch.Tensor: + """Cumulative sequence lengths in full data batch + + All sequences in batch have the maximum sequence length. + + """ + global _cu_seqlens_cache + if (batch_size, max_seqlen) not in _cu_seqlens_cache: + _cu_seqlens_cache[(batch_size, max_seqlen)] = torch.arange( + 0, + (batch_size + 1) * max_seqlen, + step=max_seqlen, + dtype=torch.int32, + device=device, + ) + return _cu_seqlens_cache[(batch_size, max_seqlen)] + + +@jit_fuser +def _pack_tensor( + indices: torch.Tensor, + tensor: torch.Tensor, +) -> torch.Tensor: + """ + Packs the given tensor using the `indices`. + """ + padding_indice = torch.zeros( + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device + ) + indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) + if isinstance(tensor, Float8Tensor): + tensor_data = torch.cat((tensor._data, padding_indice), dim=0) + gathered_data = torch.gather(tensor_data, 0, indices) + + packed = Float8Tensor.make_like(tensor, data=gathered_data, shape=gathered_data.shape) + else: + tensor = torch.cat((tensor, padding_indice), dim=0) + + packed = torch.gather(tensor, 0, indices) + return packed + + +@jit_fuser +def _pack_2_tensors( + indices: torch.Tensor, + t1: torch.Tensor, + t2: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Packs the given 2 tensors using the `indices`. + """ + t1_packed = _pack_tensor(indices, t1) + t2_packed = _pack_tensor(indices, t2) + return t1_packed, t2_packed + + +@jit_fuser +def _pack_3_tensors( + indices: torch.Tensor, + t1: torch.Tensor, + t2: torch.Tensor, + t3: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Packs the given 3 tensors using the `indices`. + """ + t1_packed = _pack_tensor(indices, t1) + t2_packed = _pack_tensor(indices, t2) + t3_packed = _pack_tensor(indices, t3) + return t1_packed, t2_packed, t3_packed + + +@jit_fuser +def _unpack_tensor( + indices: torch.Tensor, + dim0: int, + tensor: torch.Tensor, +) -> torch.Tensor: + """ + Inverse of `_pack_tensor`. + """ + indices = indices.repeat(1, tensor.shape[1], tensor.shape[2]) + unpacked = torch.zeros( + dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device + ) + if isinstance(tensor, Float8Tensor): + unpacked.scatter_(0, indices, tensor._data) + unpacked_data = unpacked[0:-1, :, :] + unpacked = Float8Tensor.make_like(tensor, data=unpacked_data, shape=unpacked_data.shape) + else: + unpacked.scatter_(0, indices, tensor) + unpacked = unpacked[0:-1, :, :] + return unpacked + + +@jit_fuser +def _unpack_2_tensors( + indices: torch.Tensor, + dim0: int, + t1: torch.Tensor, + t2: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Inverse of `_pack_2_tensors`. + """ + t1_unpacked = _unpack_tensor(indices, dim0, t1) + t2_unpacked = _unpack_tensor(indices, dim0, t2) + return t1_unpacked, t2_unpacked + + +@jit_fuser +def _unpack_3_tensors( + indices: torch.Tensor, + dim0: int, + t1: torch.Tensor, + t2: torch.Tensor, + t3: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Inverse of `_pack_3_tensors`. + """ + t1_unpacked = _unpack_tensor(indices, dim0, t1) + t2_unpacked = _unpack_tensor(indices, dim0, t2) + t3_unpacked = _unpack_tensor(indices, dim0, t3) + return t1_unpacked, t2_unpacked, t3_unpacked + + +class PackTensors(torch.autograd.Function): + """ + Autograd function to pack a tensor. + """ + + @staticmethod + def forward( + ctx, indices: torch.Tensor, *tensors: Tuple[torch.Tensor, ...] + ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: + # pylint: disable=missing-function-docstring + assert 1 <= len(tensors) <= 3, f"Packing {len(tensors)} tensors not supported." + ctx.save_for_backward(indices) + ctx.dim0 = tensors[0].shape[0] + if len(tensors) == 1: + return _pack_tensor(indices, *tensors) + if len(tensors) == 2: + return _pack_2_tensors(indices, *tensors) + return _pack_3_tensors(indices, *tensors) + + @staticmethod + def backward(ctx, *grad_outputs: Tuple[torch.Tensor, ...]): + # pylint: disable=missing-function-docstring + (indices,) = ctx.saved_tensors + if len(grad_outputs) == 1: + return None, _unpack_tensor(indices, ctx.dim0, *grad_outputs) + if len(grad_outputs) == 2: + return None, *_unpack_2_tensors(indices, ctx.dim0, *grad_outputs) + return None, *_unpack_3_tensors(indices, ctx.dim0, *grad_outputs) + + +class UnpackTensor(torch.autograd.Function): + """ + Autograd function to unpack a tensor. + """ + + @staticmethod + def forward( + ctx, + indices: torch.Tensor, + dim0: int, + tensor: torch.Tensor, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + ctx.save_for_backward(indices) + return _unpack_tensor(indices, dim0, tensor) + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + (indices,) = ctx.saved_tensors + return None, None, _pack_tensor(indices, grad_output) + + +def get_qkv_format( + qkv_layout: str = "bshd_bshd_bshd", + inference_params: InferenceParams = None, +) -> str: + """Get qkv format. + + Parameters + ---------- + qkv_layout: str + Memory layout of `q`, `k` and `v`. See get_qkv_layout() for more details. + inference_params: InferenceParams, default = `None` + InferenceParams related to KV caching. + + Returns + ---------- + qkv_format: str, default = `sbhd` + Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. + q_format: str + Format of the `q` tensor, {`bshd`, `sbhd`, `thd`}. + kv_format: str + Format of the `k` and `v` tensors, {`bshd`, `sbhd`, `thd`}. + """ + splited = qkv_layout.replace("paged_kv_", "").split("_") + if inference_params is not None: + q_format = "".join([i for i in splited[0] if i.isalpha()]) + kv_format = "".join([i for i in splited[1] if i.isalpha()]) + qkv_format = q_format + "_2" + kv_format if q_format != kv_format else q_format + else: + qkv_format = "".join([i for i in splited[0] if i.isalpha()]) + q_format = qkv_format + kv_format = qkv_format + return qkv_format, q_format, kv_format + + +def get_qkv_layout( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qkv_format: str = "sbhd", + inference_params: InferenceParams = None, +) -> str: + """Get qkv layout. + + Parameters + ---------- + q: torch.Tensor + Query tensor. + k: torch.Tensor + Key tensor. + v: torch.Tensor + Value tensor. + qkv_format: str, default = `sbhd` + Dimension format for `q`, `k` and `v`, {`sbhd`, `bshd`, `thd`}. `s` stands for + the sequence length dimension, `b` batch size, `h` the number of attention heads, + `d` head size, and `t` the total number of tokens in a batch, i.e. + `t = sum(s_i) for i = 0...b-1`. + inference_params: InferenceParams, default = `None` + InferenceParams related to KV caching. + + Returns + ---------- + qkv_layout: str + Memory layout of `q`, `k` and `v`. Each `qkv_layout` maps to a pair of `q_format` and + `kv_format` in {`bshd`, `sbhd`, `thd`}. The `paged_kv_` prefix is used to indicate that + paged KV caching is in play. A few examples of the layouts are as follows. + + (1) `sb3hd` means `q`, `k`, `v` are created as one chunk of memory and that they are + interleaved in the `2`nd dimension. (2) `sbhd_sbh2d` means `q` and `kv` are created in + two chunks and that `q` itself is contiguous and `k`, `v` are interleaved with each other + in the `3`rd dimension, `k = kv[:,:,:,0,:]` and `v = kv[:,:,:,1,:]`. `q_format` and + `kv_format` in this case are still both `sbhd`. (3) `paged_kv_thd_bshd_bshd` means `q` is + created in `thd` and `k`, `v` are in `sbhd`. This is likely due to the cache format in + paged KV caching. + + Mapping: + `sbhd`: {`sb3hd`, `sbh3d`, `sbhd_sb2hd`, `sbhd_sbh2d`, `sbhd_sbhd_sbhd`, `paged_kv_sbhd_sbhd_sbhd`} + `bshd`: {`bs3hd`, `bsh3d`, `bshd_bs2hd`, `bshd_bsh2d`, `bshd_bshd_bshd`, `paged_kv_bshd_bshd_bshd`} + `thd` : {`t3hd`, `th3d`, `thd_t2hd`, `thd_th2d`, `thd_thd_thd`} + `sbhd_2bshd`: {`sbhd_bshd_bshd`, `paged_kv_sbhd_bshd_bshd`} + `bshd_2sbhd`: {`bshd_sbhd_sbhd`, `paged_kv_bshd_sbhd_sbhd`} + `thd_2bshd`: {`thd_bshd_bshd`, `paged_kv_thd_bshd_bshd`} + `thd_2sbhd`: {`thd_sbhd_sbhd`, `paged_kv_thd_sbhd_sbhd`} + + q: torch.Tensor + Query tensor. It may be different from input `q` as we try to fit tensors to + a supported layout. + k: torch.Tensor + Key tensor. It may be different from input `k` as we try to fit tensors to + a supported layout. + v: torch.Tensor + Value tensor. It may be different from input `v` as we try to fit tensors to + a supported layout. + q_format: str + Format of the query tensor, {`bshd`, `sbhd`, `thd`}. + kv_format: str + Format of the key and value tensors, {`bshd`, `sbhd`, `thd`}. + """ + + check_last_dim_contiguous = all(x.stride(-1) == 1 for x in [q, k, v]) + assert check_last_dim_contiguous, "q, k and v must have stride 1 in their last dimension!" + if "_2" in qkv_format: + q_format, kv_format = qkv_format.split("_2") + is_same_q_kv_format = False + else: + q_format = qkv_format + kv_format = qkv_format + is_same_q_kv_format = True + + def run_iteratively(q, k, v): + # check data pointers + data_ptr = q.untyped_storage().data_ptr() + check_ptrs_qkv = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v]) + check_ptrs_qk = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k]) + data_ptr = k.untyped_storage().data_ptr() + check_ptrs_kv = all(x.untyped_storage().data_ptr() == data_ptr for x in [k, v]) + + # check tensor shapes + shape = q.shape + check_shapes_qkv = all(shape == x.shape for x in [q, k, v]) + shape = k.shape + check_shapes_kv = shape[:-1] == v.shape[:-1] + + # check tensor strides + stride = q.stride() + check_strides_qkv = all(stride == x.stride() for x in [q, k, v]) + check_strides_kv = tuple(sk / k.shape[-1] for sk in k.stride()[:-1]) == tuple( + sv / v.shape[-1] for sv in v.stride()[:-1] + ) + + # check tensor offsets for h3d and 3hd layouts + prod_h_d = q.shape[-1] * q.shape[-2] + check_3hd_offsets = all(x.storage_offset() == i * prod_h_d for i, x in enumerate([q, k, v])) + check_h3d_offsets = all( + x.storage_offset() == i * q.shape[-1] for i, x in enumerate([q, k, v]) + ) + + # check tensor offsets for hd_h2d and hd_2hd layouts + prod_all_dims = [np.prod(x.shape) for x in [q, k]] + offset = prod_all_dims[0] if check_ptrs_qkv else 0 + prod_h_d = k.shape[-1] * k.shape[-2] + check_2hd_offsets = all( + x.storage_offset() == (offset + i * prod_h_d) for i, x in enumerate([k, v]) + ) + check_h2d_offsets = all( + x.storage_offset() == (offset + i * k.shape[-1]) for i, x in enumerate([k, v]) + ) + + # check tensor offsets for hd_hd_hd layouts + check_hd_offsets_qkv = ( + all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k, v])) + if check_ptrs_qkv + else all(x.storage_offset() == 0 for i, x in enumerate([q, k, v])) + ) + check_hd_offsets_qk = ( + all(x.storage_offset() == sum(prod_all_dims[:i]) for i, x in enumerate([q, k])) + if not check_ptrs_qkv and check_ptrs_qk + else all(x.storage_offset() == 0 for i, x in enumerate([q, k])) + ) + check_hd_offsets_kv = ( + all(x.storage_offset() == sum(prod_all_dims[1 : i + 1]) for i, x in enumerate([k, v])) + if not check_ptrs_qkv and check_ptrs_kv + else all(x.storage_offset() == 0 for i, x in enumerate([k, v])) + ) + + if check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_3hd_offsets: + # sb3hd, bs3hd, t3hd + # one chunk of memory, qkv, with q, k, v interleaved at dim=-3 in qkv + qkv_layout = qkv_format[:-2] + "3" + qkv_format[-2:] + elif check_ptrs_qkv and check_strides_qkv and check_shapes_qkv and check_h3d_offsets: + # sbh3d, bsh3d, th3d + # one chunk of memory, qkv, with q, k, v interleaved at dim=-2 in qkv + qkv_layout = qkv_format[:-1] + "3" + qkv_format[-1:] + elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_2hd_offsets: + # sbhd_sb2hd, bshd_bs2hd, thd_t2hd + # two chunks of memory, q and kv, with k, v interleaved at dim=-3 in kv + # q and kv may be disjoint or consecutive in memory, and when consecutive, they may + # have the same data pointer, i.e. check_ptrs_qkv=True + qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] + elif check_ptrs_kv and check_strides_kv and check_shapes_kv and check_h2d_offsets: + # sbhd_sbh2d, bshd_bsh2d, thd_th2d + # two chunks of memory, q and kv, with k, v interleaved at dim=-2 in kv + # q and kv may be disjoint or consecutive in memory, and when consecutive, they may + # have the same data pointer, i.e. check_ptrs_qkv=True + qkv_layout = qkv_format + "_" + qkv_format[:-1] + "2" + qkv_format[-1:] + elif ( + check_strides_kv + and check_shapes_kv + and (check_hd_offsets_qkv or check_hd_offsets_kv or check_hd_offsets_qk) + ): + # sbhd_sbhd_sbhd, bshd_bshd_bshd, thd_thd_thd + # three chunks of memory, q, k and v, which may be disjoint or consecutive, and + # when consecutive, they may have the same data pointer, i.e. check_ptrs_qkv=True or + # check_ptrs_qk=True or check_ptrs_kv=True + if is_same_q_kv_format: + qkv_layout = "_".join(list([qkv_format]) * 3) + else: + qkv_layout = q_format + "_" + kv_format + "_" + kv_format + else: + qkv_layout = "not_supported" + + return qkv_layout + + qkv_layout = run_iteratively(q, k, v) + if qkv_layout == "not_supported": + # force q,k,v to be contiguous and run get_layout again + q, k, v = [x.contiguous() for x in [q, k, v]] + qkv_layout = run_iteratively(q, k, v) + if qkv_layout == "not_supported": + raise RuntimeError("The provided qkv memory layout is not supported!") + + if inference_params is not None and inference_params.is_paged: + qkv_layout = "paged_kv_" + qkv_layout + + return qkv_layout, q, k, v, q_format, kv_format + + +def check_set_window_size( + attn_mask_type: str, + window_size: Tuple[int, int] = None, +): + """Check if sliding window size is compliant with attention mask type. + If not, set it to the appropriate size. + + attn_mask_type | window_size + ------------------------------------------------------------------------- + no_mask, padding, arbitrary | (-1, -1) or (>=0, >=0) + causal, padding_causal | (-1, 0) or (>=0, 0) + causal_bottom_right, padding_causal_bottom_right | (-1, 0) or (>=0, 0) + """ + orig_window_size = window_size + if "causal" in attn_mask_type: + if orig_window_size is None: + window_size = (-1, 0) + elif orig_window_size == (-1, -1) or ( + orig_window_size[0] >= 0 and orig_window_size[1] != 0 + ): + window_size = (orig_window_size[0], 0) + warnings.warn( + "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type + ) + elif orig_window_size != (-1, 0) and (orig_window_size[0] < 0 or orig_window_size[1] != 0): + assert False, ( + "window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + attn_mask_type + ) + elif attn_mask_type in ["no_mask", "padding", "arbitrary"]: + if orig_window_size is None: + window_size = (-1, -1) + elif orig_window_size == (-1, 0): + window_size = (-1, -1) + warnings.warn( + "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type + ) + elif orig_window_size != (-1, -1) and (orig_window_size[0] < 0 or orig_window_size[1] < 0): + assert False, ( + "window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + attn_mask_type + ) + else: + assert False, "Invalid attn_mask_type: " + attn_mask_type + return window_size + + +def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): + """Get the list of quantizers used in attention from the quantizers list.""" + if not fp8: + num_of_nones = 8 if cp_specific_quantizers else 6 + return [None] * num_of_nones + QKV_quantizer = quantizers["scaling_fwd"][META_QKV] + QKV_quantizer.internal = True + QKV_quantizer.set_usage(rowwise=True, columnwise=False) + O_quantizer = quantizers["scaling_fwd"][META_O] + O_quantizer.set_usage(rowwise=True, columnwise=False) + S_quantizer = quantizers["scaling_fwd"][META_S] + S_quantizer.internal = True + S_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] + dQKV_quantizer.interal = True + dQKV_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer = quantizers["scaling_bwd"][META_DO] + dO_quantizer.set_usage(rowwise=True, columnwise=False) + dO_quantizer.internal = True + dP_quantizer = quantizers["scaling_bwd"][META_DP] + dP_quantizer.set_usage(rowwise=True, columnwise=False) + dP_quantizer.interal = True + dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP] + dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_CP_quantizer.internal = True + O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP] + O_CP_quantizer.set_usage(rowwise=True, columnwise=False) + + if cp_specific_quantizers: + return ( + QKV_quantizer, + O_quantizer, + O_CP_quantizer, + S_quantizer, + dQKV_quantizer, + dQKV_CP_quantizer, + dO_quantizer, + dP_quantizer, + ) + + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer diff --git a/transformer_engine/pytorch/export.py b/transformer_engine/pytorch/export.py deleted file mode 100755 index 5bc079711a..0000000000 --- a/transformer_engine/pytorch/export.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Export utilities for TransformerEngine""" -from contextlib import contextmanager - -_IN_ONNX_EXPORT_MODE = False - - -@contextmanager -def onnx_export( - enabled: bool = False, -) -> None: - """ - Context manager for exporting to ONNX. - - .. code-block:: python - - with onnx_export(enabled=True): - torch.onnx.export(model) - - Parameters - ---------- - enabled: bool, default = `False` - whether or not to enable export - """ - - global _IN_ONNX_EXPORT_MODE - onnx_export_state = _IN_ONNX_EXPORT_MODE - try: - _IN_ONNX_EXPORT_MODE = enabled - yield - finally: - _IN_ONNX_EXPORT_MODE = onnx_export_state - - -def is_in_onnx_export_mode() -> bool: - """Returns True if onnx export mode is enabled, False otherwise.""" - return _IN_ONNX_EXPORT_MODE diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index c3d8709925..a771e3bb75 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -1,9 +1,9 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Tensor class with FP8 data""" -from .tensor import Float8Tensor +from .tensor.float8_tensor import Float8Tensor __all__ = ["Float8Tensor"] diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index f95ba515cb..87298c2ec7 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1,8 +1,11 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """FP8 utilities for TransformerEngine""" +from __future__ import annotations + +import abc import os from contextlib import contextmanager from collections import deque @@ -10,7 +13,13 @@ import torch import transformer_engine_torch as tex -from transformer_engine.common.recipe import DelayedScaling, Format +from transformer_engine.common.recipe import ( + Recipe, + DelayedScaling, + Format, + MXFP8BlockScaling, + Float8CurrentScaling, +) from .constants import dist_group_type from .utils import get_device_compute_capability @@ -33,21 +42,30 @@ def check_fp8_support() -> Tuple[bool, str]: return True, "" -def get_default_fp8_recipe() -> DelayedScaling: +def check_mxfp8_support() -> Tuple[bool, str]: + """Return if fp8 support is available""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return True, "" + return False, "Device compute capability 10.0 or higher required for MXFP8 execution." + + +def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return MXFP8BlockScaling() return DelayedScaling() -def get_fp8_torch_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> torch.dtype: +def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor ): return torch.float8_e4m3fn - return torch.float8_e5m2fn + return torch.float8_e5m2 -def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: +def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get fp8 data type according to recipe and tensor""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -56,7 +74,7 @@ def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> t return tex.DType.kFloat8E5M2 -def get_fp8_max(fp8_recipe: DelayedScaling, fprop_tensor: bool = True) -> tex.DType: +def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor @@ -81,7 +99,6 @@ class FP8GlobalStateManager: global_amax_buffer = {} global_amax_history_buffer = {} global_scale_buffer = {} - global_scale_inv_buffer = {} fp8_tensors_recompute_buffer = [] fp8_available = None reason_for_no_fp8 = "" @@ -89,6 +106,8 @@ class FP8GlobalStateManager: autocast_to_fp8_params = {} fp8_param_to_autocast = {} skip_fp8_weight_update_tensor = None + mxfp8_available = None + reason_for_no_mxfp8 = "" @classmethod def reset(cls) -> None: @@ -104,7 +123,6 @@ def reset(cls) -> None: cls.global_amax_buffer = {} cls.global_amax_history_buffer = {} cls.global_scale_buffer = {} - cls.global_scale_inv_buffer = {} cls.fp8_tensors_recompute_buffer = [] cls.fp8_available = None cls.reason_for_no_fp8 = "" @@ -112,6 +130,8 @@ def reset(cls) -> None: cls.autocast_to_fp8_params = {} cls.fp8_param_to_autocast = {} cls.skip_fp8_weight_update_tensor = None + cls.mxfp8_available = None + cls.reason_for_no_mxfp8 = "" @classmethod def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: @@ -132,6 +152,13 @@ def is_fp8_available(cls) -> Tuple[bool, str]: cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() return cls.fp8_available, cls.reason_for_no_fp8 + @classmethod + def is_mxfp8_available(cls) -> Tuple[bool, str]: + """Return if MXFP8/current scaling support is available.""" + if cls.mxfp8_available is None: + cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() + return cls.mxfp8_available, cls.reason_for_no_mxfp8 + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -156,30 +183,29 @@ def get_buffer_info(cls) -> str: def get_key_in_buffer( cls, forward: bool, - fp8_weights: bool, - fp8_recipe: DelayedScaling, + fp8_recipe: Recipe, fp8_group: dist_group_type, ) -> str: """Returns a key into the global FP8 buffers.""" autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) fwd_bwd_key = cls.get_fwd_bwd_key(forward) - return f"{fwd_bwd_key}_{fp8_weights}_{autocast_key}" + return f"{fwd_bwd_key}_{autocast_key}" @classmethod - def split_key_in_buffer(cls, key: str) -> Tuple[bool, bool, str]: + def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: """Splits buffer key into relevant parts.""" - forward, fp8_weights, autocast_key = key.split("_", 2) + forward, autocast_key = key.split("_", 1) forward = forward == "forward" - fp8_weights = fp8_weights == "True" - return forward, fp8_weights, autocast_key + return forward, autocast_key @classmethod def add_fp8_tensors_to_global_buffer( cls, fp8_meta: Dict[str, Any], - fp8_weights: Optional[List[torch.Tensor]] = None, ) -> None: """ + Delayed scaling only. + The amax reduction process happens completely outside the FP8 modules. To participate in the reduction, the only role played by a module is to call this function in order to append it's FP8 tensor into a global @@ -193,6 +219,10 @@ def add_fp8_tensors_to_global_buffer( wrapper. For non CG case, it's called from within the module. """ + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): + return + # Every module must call this function exactly once since # the amax tensors are static. Ensures that compatibility # with non-graphed modules is maintained. @@ -202,46 +232,23 @@ def add_fp8_tensors_to_global_buffer( fp8_meta[index_in_buffer] = [] for forward in (True, False): - # This algorithm creates a two-way map with `autocast_to_fp8_params` and - # `fp8_param_to_autocast`. This is used for keeping track of FP8 weights - # in an autocasted region and cross reference them in `float8_tensor.py` - # to perform the forward amax reduction. fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) if fp8_meta_tensor_key not in fp8_meta: # Handles non-parameter FP8 modules, e.g. DPA. continue - if forward and fp8_weights is not None: - autocast_key = cls.get_unique_autocast_key( - fp8_meta["recipe"], fp8_meta["fp8_group"] - ) - fp8_weight_set = {id(w._data) for w in fp8_weights} - if autocast_key not in cls.autocast_to_fp8_params: - cls.autocast_to_fp8_params[autocast_key] = fp8_weight_set - else: - cls.autocast_to_fp8_params[autocast_key] = cls.autocast_to_fp8_params[ - autocast_key - ].union(fp8_weight_set) - # Identify correct autocast key for a given param. - for w in fp8_weight_set: - cls.fp8_param_to_autocast[w] = autocast_key - - key = cls.get_key_in_buffer( - forward, fp8_weights is not None, fp8_meta["recipe"], fp8_meta["fp8_group"] - ) + key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) if key not in cls.global_amax_buffer: cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] - cls.global_scale_inv_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale_inv] else: cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) cls.global_amax_history_buffer[key].append( fp8_meta[fp8_meta_tensor_key].amax_history ) cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - cls.global_scale_inv_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale_inv) fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) fp8_meta[index_in_buffer].append(key) @@ -275,7 +282,7 @@ def is_first_fp8_module(cls): return tmp @classmethod - def get_fp8_recipe(cls) -> DelayedScaling: + def get_fp8_recipe(cls) -> Recipe: """Return the fp8 recipe""" if cls.FP8_RECIPE is not None: return cls.FP8_RECIPE @@ -287,7 +294,7 @@ def get_fp8_group(cls) -> Union[dist_group_type, None]: return cls.FP8_DISTRIBUTED_GROUP @classmethod - def get_fp8_autocast_state(cls) -> Tuple[bool, bool, DelayedScaling, dist_group_type, bool]: + def get_fp8_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: """FP8 autocast state getter""" return ( cls.FP8_ENABLED, @@ -327,20 +334,14 @@ def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_ty def reduce_and_update_fp8_tensors( cls, forward: bool = True, - fp8_weights: bool = False, ) -> None: - """Concatenate, reduce, and split amaxes in the global buffer.""" + """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" + # global_amax_buffer should only be non-empty for fp8 delayed scaling for buffer_key, amax_buffer in cls.global_amax_buffer.items(): # Check for forward or backward reduction. - fwd_update, fp8_weights_update, autocast_key = cls.split_key_in_buffer(buffer_key) + fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) if fwd_update != forward: continue - # Only skip a forward update when `fp8_weights` is explicitly set to `True` - # (inside optimizer) and the current key is not an `fp8_weight_update` key. - # For other cases, we need to reduce because of activation tensors. - # TODO(ksivaman) consider separate weight and activation fp8_tensors. - if fwd_update and fp8_weights and not fp8_weights_update: - continue if len(amax_buffer) == 0: continue @@ -368,7 +369,6 @@ def reduce_and_update_fp8_tensors( contiguous_amax, cls.global_amax_history_buffer[buffer_key], cls.global_scale_buffer[buffer_key], - cls.global_scale_inv_buffer[buffer_key], recipe.amax_compute_algo, get_fp8_te_dtype(recipe, forward), recipe.margin, @@ -376,19 +376,18 @@ def reduce_and_update_fp8_tensors( else: split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) - for amax_history, scale, scale_inv in zip( + for amax_history, scale in zip( cls.global_amax_history_buffer[buffer_key], cls.global_scale_buffer[buffer_key], - cls.global_scale_inv_buffer[buffer_key], ): _amax_and_scale_update( - amax_history, scale, scale_inv, get_fp8_max(recipe, forward), recipe + amax_history, scale, get_fp8_max(recipe, forward), recipe ) @classmethod def get_unique_autocast_key( cls, - recipe: Optional[DelayedScaling] = None, + recipe: Optional[Recipe] = None, group: Optional[dist_group_type] = None, ): """ @@ -402,7 +401,7 @@ def fp8_autocast_enter( cls, enabled: bool = False, calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, _graph: bool = False, ) -> None: @@ -425,6 +424,9 @@ def fp8_autocast_enter( if enabled: fp8_available, reason_for_no_fp8 = cls.is_fp8_available() assert fp8_available, reason_for_no_fp8 + if isinstance(fp8_recipe, MXFP8BlockScaling): + mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() + assert mxfp8_available, reason_for_no_mxfp8 @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -434,19 +436,25 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: # FP8 weight modules are reduced at the end of the optimizer # step after the weight amax is populated. if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): - cls.reduce_and_update_fp8_tensors(forward=True, fp8_weights=False) + # delayed scaling only function, for other recipes (current scaling with any granularity), + # this is noop for other recipes because cls.global_amax_buffer is empty list + cls.reduce_and_update_fp8_tensors(forward=True) @classmethod def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: """Copy the scaling factors and amaxes for recompute forward phase to ensure both forward steps are numerically same. """ + + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): + return + buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" to_copy = [ fp8_meta["scaling_fwd"].amax_history.clone(), fp8_meta["scaling_fwd"].scale.clone(), - fp8_meta["scaling_fwd"].scale_inv.clone(), ] if buffer_position_key in fp8_meta: @@ -464,31 +472,35 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non """Switch to the copied scaling factors and amaxes from phase 1 forward for indentical numerical outputs. """ + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): + return # Store updated amaxes and scales from phase 1 post forward. fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale - fp8_meta["updated_scale_inv_fwd"] = fp8_meta["scaling_fwd"].scale_inv # Retrieve stashed amaxes and scales from phase 1 pre forward. buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history = stashed_fp8_meta[0] - fp8_meta["scaling_fwd"].scale = stashed_fp8_meta[1] - fp8_meta["scaling_fwd"].scale_inv = stashed_fp8_meta[2] + fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) + fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) @staticmethod def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" - fp8_meta["scaling_fwd"].amax_history = fp8_meta["updated_amax_history_fwd"] - fp8_meta["scaling_fwd"].scale = fp8_meta["updated_scale_fwd"] - fp8_meta["scaling_fwd"].scale_inv = fp8_meta["updated_scale_inv_fwd"] + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): + return + + fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) + fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) @contextmanager -def fp8_model_init(enabled: bool = True) -> None: +def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> None: """ Context manager for FP8 initialization of parameters. @@ -512,22 +524,27 @@ def fp8_model_init(enabled: bool = True) -> None: precision copies of weights are already present in the optimizer. * inference, where only the FP8 copies of the parameters are used. * LoRA-like fine-tuning, where the main parameters of the model do not change. + recipe: transformer_engine.common.recipe.Recipe, default = `None` + Recipe used to create the parameters. If left to None, it uses the default FP8 recipe. This functionality is *EXPERIMENTAL*. """ _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS + _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE FP8GlobalStateManager.FP8_PARAMETERS = enabled + FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe try: yield finally: FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters + FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe @contextmanager def fp8_autocast( enabled: bool = True, calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, _graph: bool = False, ) -> None: @@ -562,7 +579,7 @@ def fp8_autocast( data of fp8 tensors even when executing without fp8 enabled. This is useful for saving an inference ready fp8 checkpoint while training using a higher precision. - fp8_recipe: recipe.DelayedScaling, default = `None` + fp8_recipe: recipe.Recipe, default = `None` recipe used for FP8 training. fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` distributed group over which amaxes for the fp8 tensors @@ -672,7 +689,6 @@ def _compute_scaling_factor( def _amax_and_scale_update( amax_history: torch.Tensor, scale: torch.Tensor, - scale_inv: torch.Tensor, fp8_max: float, recipe: DelayedScaling, ) -> None: @@ -683,7 +699,6 @@ def _amax_and_scale_update( ) new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) scale.copy_(new_scale) - scale_inv.copy_(1.0 / new_scale) amax_history.copy_(new_amax_history) @@ -695,3 +710,193 @@ def split_and_copy( """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" splits = buffer.split(chunk_sizes) torch._foreach_copy_(outputs, splits) + + +class RecipeState(abc.ABC): + """Configuration and state for a quantization recipe. + + This is a builder class for quantizers, which are in turn builder + classes for quantized tensors. + + This class may pack together the state for multiple quantizers, + which is helpful for applying fused kernels with less overhead. + + """ + + @staticmethod + def create( + recipe: Recipe, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> RecipeState: + """Factory method to create the state for a quantization recipe + + Parameters + ---------- + recipe: Recipe + Quantization recipe. + mode: {"forward", "backward"} + Training stage where quantization will be performed. + num_quantizers: int, default = 1 + Number of quantizers to create state for. + device: torch.device, default = default CUDA device + Device for quantized tensors. + + Returns + ------- + RecipeState: + Quantization recipe state. + + """ + + cls = None + if recipe.delayed(): + cls = DelayedScalingRecipeState + elif recipe.mxfp8(): + cls = MXFP8BlockScalingRecipeState + elif recipe.float8_current_scaling(): + cls = Float8CurrentScalingRecipeState + else: + raise ValueError("{recipe.__class__.__name__} is not supported") + return cls( + recipe, + mode=mode, + num_quantizers=num_quantizers, + device=device, + ) + + @abc.abstractmethod + def make_quantizers(self) -> list: + """Convert recipe state to quantizers. + + Quantizers are builder classes for quantized tensors. They are + typically used to convert a high-precision tensor (e.g. in + FP32 or BF16) into a quantized tensor (e.g. in FP8). + + """ + + +class DelayedScalingRecipeState(RecipeState): + """State for FP8 quantization with per-tensor delayed scaling. + + Delayed scaling recipe requires a scaling factor (applied when + casting to FP8) and a history of max-abs values ("amax") from + recent FP8 casts for updating the scaling factor. The scale update + is handled externally by `FP8GlobalStateManager`. + + """ + + recipe: DelayedScaling + mode: str + dtype: tex.DType + scale: torch.Tensor + amax_history: torch.Tensor + + def __init__( + self, + recipe: DelayedScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) + self.amax_history = torch.zeros( + recipe.amax_history_len, + num_quantizers, + dtype=torch.float32, + device=device, + ) + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.float8_tensor import Float8Quantizer + + return [ + Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype) + for i in range(self.num_quantizers) + ] + + +class Float8CurrentScalingRecipeState(RecipeState): + """Configuration for Per-tensor current scaling quantization. + + Per-tensor current quantization does not require state. + + """ + + recipe: Float8CurrentScaling + mode: str + dtype: tex.DType + device: torch.device + + def __init__( + self, + recipe: Float8CurrentScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + from .tensor.float8_tensor import Float8CurrentScalingQuantizer + + return [ + Float8CurrentScalingQuantizer(self.dtype, device=self.device) + for i in range(self.num_quantizers) + ] + + +class MXFP8BlockScalingRecipeState(RecipeState): + """Configuration for MXFP8 quantization. + + MXFP8 quantization does not require state. + + """ + + recipe: MXFP8BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: MXFP8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.mxfp8_tensor import MXFP8Quantizer + + return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index ed0ed1c008..0479aebb4d 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,7 +11,8 @@ from torch.utils._pytree import tree_unflatten as _tree_unflatten from torch._C import _graph_pool_handle -from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.common.recipe import DelayedScaling, Recipe +from transformer_engine.pytorch.constants import dist_group_type from .fp8 import ( fp8_autocast, FP8GlobalStateManager, @@ -63,6 +64,7 @@ def _make_graphed_callables( sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, _order: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, + retain_graph_in_backward: bool = False, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -89,6 +91,14 @@ def _make_graphed_callables( sample_args = (sample_args,) sample_kwargs = (sample_kwargs,) + # Check training/inference + is_training = all(c.training for c in callables) + if not is_training and any(c.training for c in callables): + assert False, ( + "make_graphed_callables only supports when modules are all in training or all in" + " inference mode." + ) + # Check sizes of args if _order is None: assert len(sample_args) == len(callables) @@ -173,11 +183,14 @@ def _make_graphed_callables( ] else: per_callable_module_params = [] - for c in callables: - for i in range(num_microbatches): - per_callable_module_params.append( - tuple(c.parameters()) if isinstance(c, torch.nn.Module) else () - ) + for m_chunk in range(num_model_chunks): + for _ in range(num_microbatches): + for l_no in range(num_layers): + per_callable_module_params.append( + tuple(callables[m_chunk * num_layers + l_no].parameters()) + if isinstance(callables[m_chunk * num_layers + l_no], torch.nn.Module) + else () + ) assert len(per_callable_module_params) == len(flatten_sample_args) per_callable_static_input_surfaces = [ flatten_sample_args[i] + per_callable_module_params[i] @@ -201,21 +214,71 @@ def _make_graphed_callables( # Hopefully prevents cudnn benchmarking and other lazy-initialization cuda work # from ending up in any captures. torch.cuda.synchronize() - with torch.cuda.stream(torch.cuda.Stream()): + + # Get warmup func and func_idx. + warmup_func_idx = [] + warmup_func = [] + if _order is None: for func_idx, func in enumerate(callables): + warmup_func_idx.append(func_idx) + warmup_func.append(func) + else: + fwd_idx = [0] * num_model_chunks + for c_id in _order: + if c_id > 0: + m_chunk = c_id - 1 + for l_no in range(num_layers): + func = callables[m_chunk * num_layers + l_no] + func_idx = (m_chunk * num_microbatches * num_layers) + ( + fwd_idx[m_chunk] * num_layers + l_no + ) + warmup_func_idx.append(func_idx) + warmup_func.append(func) + fwd_idx[m_chunk] += 1 + assert len(warmup_func) == len( + sample_args + ), f"Warmup runs {len(warmup_func)} don't match args {len(sample_args)}." + assert len(warmup_func_idx) == len( + set(warmup_func_idx) + ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." + + # Filter the TE modules that cudagraph can access. + visited_te_modules = set() + + def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument + if isinstance(module, TransformerEngineBaseModule): + visited_te_modules.add(module) + + # Run warmup and do the above filtering. + with torch.cuda.stream(torch.cuda.Stream()): + for func_idx, func in zip(warmup_func_idx, warmup_func): args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] for _ in range(num_warmup_iters): + hooks = [] + for module in func.modules(): + hook = module.register_forward_hook(hook_fn) + hooks.append(hook) outputs, _ = _tree_flatten(func(*args, **kwargs)) - grad_inputs = torch.autograd.grad( - outputs=tuple(o for o in outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad), - only_inputs=True, - allow_unused=allow_unused_input, - ) + for hook in hooks: + hook.remove() + if is_training: + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(torch.empty_like(o) for o in outputs if o.requires_grad), + only_inputs=True, + allow_unused=allow_unused_input, + ) + else: + grad_inputs = None del outputs, grad_inputs + # The following code is added specifically for MCore's special requirements, + # aimed at preventing warmup from altering the control flow. + for module in func.modules(): + if hasattr(module, "is_first_microbatch"): + module.is_first_microbatch = True torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, @@ -262,21 +325,23 @@ def _make_graphed_callables( static_grad_outputs = tuple( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) - with torch.cuda.graph(bwd_graph, pool=mempool): - grad_inputs = torch.autograd.grad( - outputs=tuple(o for o in static_outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(o for o in static_grad_outputs if o is not None), - only_inputs=True, - allow_unused=allow_unused_input, - ) + if is_training: + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + retain_graph=retain_graph_in_backward, + ) # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs # that don't require grad. I couldn't think of a one-liner for this pattern. static_grad_inputs = [] grad_idx = 0 for arg in static_input_surface: - if arg.requires_grad: + if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad: static_grad_inputs.append(grad_inputs[grad_idx]) grad_idx += 1 else: @@ -313,21 +378,23 @@ def _make_graphed_callables( static_grad_outputs = tuple( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) - with torch.cuda.graph(bwd_graph, pool=mempool): - grad_inputs = torch.autograd.grad( - outputs=tuple(o for o in static_outputs if o.requires_grad), - inputs=tuple(i for i in static_input_surface if i.requires_grad), - grad_outputs=tuple(o for o in static_grad_outputs if o is not None), - only_inputs=True, - allow_unused=allow_unused_input, - ) + if is_training: + with torch.cuda.graph(bwd_graph, pool=mempool): + grad_inputs = torch.autograd.grad( + outputs=tuple(o for o in static_outputs if o.requires_grad), + inputs=tuple(i for i in static_input_surface if i.requires_grad), + grad_outputs=tuple(o for o in static_grad_outputs if o is not None), + only_inputs=True, + allow_unused=allow_unused_input, + retain_graph=retain_graph_in_backward, + ) # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs that # don't require grad. I couldn't think of a slick one-liner for this pattern. static_grad_inputs = [] grad_idx = 0 for arg in static_input_surface: - if arg.requires_grad: + if is_training and isinstance(arg, torch.Tensor) and arg.requires_grad: static_grad_inputs.append(grad_inputs[grad_idx]) grad_idx += 1 else: @@ -368,7 +435,10 @@ def forward(ctx, skip_fp8_weight_update, *inputs): # Copy values from new tensors into static tensors for i in range(len_user_args): - if static_input_surface[i].data_ptr() != inputs[i].data_ptr(): + if ( + isinstance(static_input_surface[i], torch.Tensor) + and static_input_surface[i].data_ptr() != inputs[i].data_ptr() + ): static_input_surface[i].copy_(inputs[i]) # Replay forward graph @@ -462,10 +532,23 @@ def new_fwd(*user_args, **user_kwargs): isinstance(m, TransformerEngineBaseModule) and FP8GlobalStateManager.is_fp8_enabled() ): + if m not in visited_te_modules: + # Only Set the FP8 meta for the modules included by forward + continue + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + from transformer_engine.pytorch.attention import DotProductAttention + + if ( + isinstance(m, DotProductAttention) + and not fp8_recipe.fp8_mha + and not fp8_recipe.fp8_dpa + ): + # Don't need to update FP8 meta for non-FP8 DPA + continue m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - m.fp8_meta, fp8_weights=m._get_fp8_params() + m.fp8_meta, ) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) @@ -489,12 +572,16 @@ def new_fwd(*user_args, **user_kwargs): def save_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_recipe: DelayedScaling, -) -> List[Any]: + fp8_recipe: Recipe, +) -> Optional[List[Any]]: """ Returns the FP8 tensors for all modules with adjusted amax history sizes. """ + + if not isinstance(fp8_recipe, DelayedScaling): + return None + fp8_tensors = [] for module in modules: for m in module.modules(): @@ -512,9 +599,13 @@ def save_fp8_tensors( def restore_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_tensors: List[Any], + fp8_tensors: Optional[List[Any]], ) -> None: """Restore FP8 tensors.""" + + if fp8_tensors is None: + return + for module in modules: for m in module.modules(): module_tensors = fp8_tensors.pop(0) @@ -538,9 +629,11 @@ def make_graphed_callables( fp8_enabled: bool = False, fp8_calibrating: bool = False, fp8_recipe: Optional[DelayedScaling] = None, + fp8_group: Optional[dist_group_type] = None, fp8_weight_caching: bool = False, _order: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, + retain_graph_in_backward: bool = False, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -567,6 +660,8 @@ def make_graphed_callables( pool: (tuple of) int, default = `None`, optional An instance returned from function `torch.cuda.graph_pool_handle` that hints this graph may share memory with the indicated pool. + retain_graph_in_backward: bool, default = `False` + Whether to set retain_graph=True in backward graph capture. FP8-related parameters ---------------------- @@ -579,6 +674,9 @@ def make_graphed_callables( using a higher precision. fp8_recipe: recipe.DelayedScaling, default = `None` recipe used for FP8 training. + fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` + distributed group over which amaxes for the fp8 tensors + are reduced at the end of each training step. fp8_weight_caching: bool, default = `False` Whether or not to cache FP8 weights across microbatches. if set to `True`, the `is_first_microbatch` boolean argument must be passed into the forward @@ -607,7 +705,11 @@ def wrap_autocast(block): def forward_func(*args, **kwargs): with fp8_autocast( - enabled=fp8_enabled, calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, _graph=True + enabled=fp8_enabled, + calibrating=fp8_calibrating, + fp8_recipe=fp8_recipe, + fp8_group=fp8_group, + _graph=True, ): outputs = old_forward(*args, **kwargs) return outputs @@ -644,6 +746,7 @@ def forward_func(*args, **kwargs): sample_kwargs=sample_kwargs, _order=_order, pool=pool, + retain_graph_in_backward=retain_graph_in_backward, ) # Ensures warmup does not affect numerics for ops such as dropout. diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index ed08627e95..aae35ded68 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,28 +10,20 @@ # pylint: disable=unnecessary-lambda-assignment -jit_fuser = torch.jit.script +jit_fuser = lambda func: func if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): jit_fuser = torch.compile + # See: https://github.com/NVIDIA/TransformerEngine/issues/597 dropout_fuser = torch.jit.script if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))): dropout_fuser = torch.compile + # Decorator to disable Torch Dynamo # See: https://github.com/NVIDIA/TransformerEngine/issues/308 -no_torch_dynamo = lambda recursive=True: lambda func: func -if torch.__version__ >= "2": - import torch._dynamo - - if torch.__version__ >= "2.1": - no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable( - f, recursive=recursive - ) - else: - # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True - no_torch_dynamo = lambda recursive=True: torch._dynamo.disable +no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive) def set_jit_fusion_options() -> None: diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index ba4755efe3..5074d32aa2 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 21365398f3..cd18808465 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -1,32 +1,30 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Internal function used by multiple modules.""" -from typing import Any, Dict, List, Optional, Tuple, Union, Callable +import os +from typing import Any, List, Optional, Tuple, Union, Callable from dataclasses import dataclass +from functools import reduce +from operator import mul as multiply_op import torch from .. import cpp_extensions as tex -from ..export import is_in_onnx_export_mode -from ..fp8 import get_fp8_te_dtype +from ..constants import TE_DType from ..utils import get_default_init_method +from ..tensor.float8_tensor import Float8Tensor +from ..tensor.mxfp8_tensor import MXFP8Quantizer +_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "0"))) -def _get_normalization_func( - normalization: str, fp8_output: bool, is_grad_enabled: bool, forward: bool -): + +def _get_normalization_func(normalization: str, forward: bool): fwd_normalization_funcs = { - ("LayerNorm", True, True): tex.layernorm_fwd_fp8, - ("LayerNorm", True, False): tex.layernorm_fwd_fp8_inf, - ("LayerNorm", False, True): tex.layernorm_fwd_noalloc, - ("LayerNorm", False, False): tex.layernorm_fwd_inf, - ("RMSNorm", True, True): tex.rmsnorm_fwd_fp8, - ("RMSNorm", True, False): tex.rmsnorm_fwd_fp8_inf, - ("RMSNorm", False, True): tex.rmsnorm_fwd_noalloc, - ("RMSNorm", False, False): tex.rmsnorm_fwd_inf, + "LayerNorm": tex.layernorm_fwd, + "RMSNorm": tex.rmsnorm_fwd, } bwd_normalization_funcs = { "LayerNorm": tex.layernorm_bwd, @@ -34,81 +32,79 @@ def _get_normalization_func( } if forward: - return fwd_normalization_funcs[(normalization, fp8_output, is_grad_enabled)] - assert not fp8_output, "FP8 output is not supported in backward normalization!" - assert is_grad_enabled, "Gradient has to be enabled to call backward normalization!" + return fwd_normalization_funcs[normalization] return bwd_normalization_funcs[normalization] -def _apply_normalization( +def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor: + """Reorder FP8 transposes after Userbuffers gather. + + The all-gather is performed in-place in the Float8Tensor's + row-wise data, and afterwards we need to do a transpose to get the + correct ordering. This misuses data fields in Float8Tensor and + should be considered an evil hack. It would be best to move + transpose logic into CommOverlap::get_buffer. + + Responsibility for fixing: adener, tmoon + + """ + assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor" + assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1" + assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data" + assert ( + fp8_tensor._data.shape[0] % tp_size == 0 + ), "Leading dimension of data is not divisble by TP size" + + data = fp8_tensor._data + batched_size = reduce(multiply_op, data.shape[1:]) + interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size] + transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size] + fp8_tensor._transpose = ( + data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape) + ) + + fp8_tensor._transpose_invalid = False + fp8_tensor._data = None + + return fp8_tensor + + +def apply_normalization( inputmat: torch.Tensor, ln_out: torch.Tensor, ln_weight: torch.Tensor, ln_bias: Union[torch.Tensor, None], eps: float, - fp8_out: bool, - fp8_meta: Dict[str, Any], + output_quantizer, + output_dtype, normalization: str, fwd_ln_sm_margin: int, zero_centered_gamma: bool, - is_grad_enabled: bool, - fp8_scale: Optional[torch.Tensor] = None, - fp8_amax: Optional[torch.Tensor] = None, - fp8_scale_inv: Optional[torch.Tensor] = None, ): - normalization_func = _get_normalization_func(normalization, fp8_out, is_grad_enabled, True) + """Apply normalization to input.""" + normalization_func = _get_normalization_func(normalization, True) inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias) - if fp8_out: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - - if is_grad_enabled: - output_key = "ln_out" if normalization == "LayerNorm" else "rmsnorm_out" - output_kwarg = {output_key: ln_out} - output = normalization_func( - *inputs, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - fwd_ln_sm_margin, - zero_centered_gamma, - scale=fp8_scale, - amax=fp8_amax, - scale_inv=fp8_scale_inv, - **output_kwarg, - ) - else: - return ( - normalization_func( - *inputs, - eps, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - fwd_ln_sm_margin, - zero_centered_gamma, - scale=fp8_scale, - amax=fp8_amax, - scale_inv=fp8_scale_inv, - ), - None, - None, - ) - else: - if is_grad_enabled: - output = normalization_func(*inputs, ln_out, eps, fwd_ln_sm_margin, zero_centered_gamma) - else: - return ( - normalization_func(*inputs, eps, fwd_ln_sm_margin, zero_centered_gamma), - None, - None, - ) - if normalization == "RMSNorm": - output = (ln_out, None, output[1]) - elif normalization == "LayerNorm": - output = (ln_out, output[1], output[2]) - return output + + split_mxfp8_cast = False + if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer): + split_mxfp8_cast = True + + output = normalization_func( + *inputs, + eps, + None if split_mxfp8_cast else ln_out, + None if split_mxfp8_cast else output_quantizer, + TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype, + fwd_ln_sm_margin, + zero_centered_gamma, + ) + + return ( + (output_quantizer.quantize(output[0], out=ln_out), *output[1:]) + if split_mxfp8_cast + else output + ) class _NoopCatFunc(torch.autograd.Function): @@ -202,7 +198,7 @@ def backward( return None, *grad_inputs -def _noop_cat( +def noop_cat( tensors: List[torch.Tensor], dim: int = 0, ) -> torch.Tensor: @@ -217,8 +213,6 @@ def _noop_cat( raise ValueError("Attempted to concatenate 0 tensors") if len(tensors) == 1: return tensors[0] - if is_in_onnx_export_mode(): - return torch.cat(tensors, dim=dim) return _NoopCatFunc.apply(dim, *tensors) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 534174380f..4b82054fec 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -7,9 +7,6 @@ import os import pickle import warnings -import socket -import fcntl -import struct from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager @@ -18,12 +15,15 @@ import torch.nn.functional as F import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe + from ._common import _ParameterInitMeta -from ..export import is_in_onnx_export_mode from ..fp8 import ( - get_default_fp8_recipe, - get_fp8_te_dtype, + MXFP8BlockScalingRecipeState, + DelayedScalingRecipeState, + Float8CurrentScalingRecipeState, FP8GlobalStateManager, + RecipeState, ) from ..distributed import ( gather_along_first_dim, @@ -31,13 +31,10 @@ in_fp8_activation_recompute_phase, _fsdp_gather_tensors, ) -from ..cpp_extensions import ( - fp8_cast_transpose_fused, - fp8_cast_transpose_bgrad_fused, - cast_to_fp8, -) from ..constants import dist_group_type -from ..float8_tensor import Float8Tensor +from ..tensor import QuantizedTensor, Quantizer +from ..tensor._internal.float8_tensor_base import Float8TensorBase +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase __all__ = ["initialize_ub", "destroy_ub"] @@ -48,6 +45,7 @@ _cublas_workspace = None _ub_communicators = None _NUM_MAX_UB_STREAMS = 3 +_MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = None, None layers_atomic_ring_exchange = [] @@ -177,85 +175,32 @@ def initialize_ub( world_rank = torch.distributed.get_rank(world_group) world_size = torch.distributed.get_world_size(world_group) - # We have single-node NVLink so we can color based on physical node hostnames. - # NOTE: Prefer a network interface defined via the NVTE_UB_SOCKET_IFNAME variable, and - # otherwise fall back on NCCL_SOCKET_IFNAME or GLOO_SOCKET_IFNAME depending on - # the chosen bootstrap backend. - mydomain = socket.gethostname() - ifname = os.getenv( - "NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME") - ) - if ifname is not None: - # Make sure the ifname found in the environment is a valid network interface - if ifname in [name for _, name in socket.if_nameindex()]: - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - mydomain = socket.inet_ntoa( - fcntl.ioctl( - s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8")) - )[20:24] - ) - except OSError as err: - raise OSError(f"Invalid network interface: {ifname}") from err - finally: - s.close() - else: - ifname_warning = ( - f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will" - + " attempt to detect ranks on the same node by matching " - + "'socket.gethostname()', which is known to fail on virtual clusters like " - + "Kubernetes. If Userbuffers initialization fails, please set the " - + "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network " - + "interface." - ) - warnings.warn(ifname_warning, UserWarning) - - # Allgather the domain colors across ranks and reduce to a list of unique domains - domain_per_rank_list = [None for _ in range(world_size)] - torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group) - unique_domains = [] - for domain in domain_per_rank_list: - if domain not in unique_domains: - unique_domains.append(domain) - num_domains = len(unique_domains) - + num_domains = world_size // tp_size + mydomain_idx = world_rank // tp_size if num_domains > 1: - # DP/TP model replicated on multiple NVLink domains - ranks_per_domain_list = [[] for _ in range(num_domains)] - mydomain_idx = -1 - for i, domain in enumerate(domain_per_rank_list): - domain_idx = unique_domains.index(domain) - ranks_per_domain_list[domain_idx].append(i) - if domain == mydomain: - mydomain_idx = domain_idx - assert mydomain_idx >= 0, "Internal TE error!" - - intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( + ranks_per_domain_list = [ + [i * tp_size + t for t in range(tp_size)] for i in range(num_domains) + ] + tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( ranks_per_domain_list, backend=bootstrap_backend ) - local_rank = torch.distributed.get_rank(intra_domain_group) - intra_domain_ranks = torch.distributed.get_process_group_ranks(intra_domain_group) - - inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration( - [list(ranks) for ranks in zip(*ranks_per_domain_list)], - backend=bootstrap_backend, - ) - - helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group) + local_rank = torch.distributed.get_rank(tp_domain_group) + tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group) + helper = tex.CommOverlapHelper(world_group, tp_domain_group) else: # TP model on single NVLink domain, no replication, no data-parallelism mydomain_idx = 0 local_rank = world_rank - intra_domain_ranks = list(range(world_size)) + tp_domain_ranks = list(range(world_size)) helper = tex.CommOverlapHelper(world_group) if world_rank == 0: - print(f"!!! [UB] Number of NVLink domains: {num_domains}\n", end="", flush=True) + print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True) if local_rank == 0: print( - f"!!! [UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n", + f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n", end="", flush=True, ) @@ -295,34 +240,43 @@ def get_method(name): raise KeyError(f"Given layer name {name} does not exist.") def get_default_config(name): + global _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY method = get_method(name) is_reduce_scatter = name in layers_reduce_scatter_overlap + if _MIN_STREAM_PRIORITY is None or _MAX_STREAM_PRIORITY is None: + _MIN_STREAM_PRIORITY, _MAX_STREAM_PRIORITY = tex.get_stream_priority_range() default_cfg = { "method": method, "is_reduce_scatter": is_reduce_scatter, "num_sm": 1 if method == "ring_exchange" else 16, "cga_size": 1 if method == "ring_exchange" else 2, - "set_sm_margin": False, - "num_splits": 4 if method == "pipeline" else tp_size, + "set_sm_margin": not method == "ring_exchange", + "num_splits": tp_size if method == "ring_exchange" else 4, "aggregate": False, "atomic_gemm": False, "use_ce": True, "fp8_buf": name in layers_all_gather_overlap, + "comm_priority": _MAX_STREAM_PRIORITY, + "gemm_priority": _MIN_STREAM_PRIORITY, + "pipeline_rs_overlap_first_gemm": False, } return default_cfg def add_ub( name: str, method: str, - is_reduce_scatter: int, + is_reduce_scatter: bool, num_sm: int = 16, cga_size: int = 2, - set_sm_margin: int = 0, + set_sm_margin: bool = False, num_splits: int = 0, - aggregate: int = 0, - atomic_gemm: int = 0, + aggregate: bool = False, + atomic_gemm: bool = False, use_ce: bool = True, fp8_buf: bool = False, + comm_priority: int = 0, + gemm_priority: int = 0, + pipeline_rs_overlap_first_gemm: bool = False, ) -> None: if atomic_gemm: warnings.warn( @@ -373,6 +327,8 @@ def add_ub( atomic_gemm=atomic_gemm, use_ce=use_ce, aggregate=aggregate, + gemm_priority=gemm_priority, + comm_priority=comm_priority, ) else: ub_obj = tex.CommOverlap( @@ -386,6 +342,9 @@ def add_ub( num_comm_sm=num_sm, set_sm_margin=set_sm_margin, atomic_gemm=atomic_gemm, + gemm_priority=gemm_priority, + comm_priority=comm_priority, + rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) _ub_communicators[name] = ub_obj @@ -439,8 +398,8 @@ def __init__(self) -> None: self.fp8_meta = {} self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_group"] = None - self.fp8_meta["recipe"] = get_default_fp8_recipe() self.fp8_meta_tensors_initialized = False + self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} self.tp_group = None self.tp_size = 1 self.sequence_parallel = False @@ -448,7 +407,7 @@ def __init__(self) -> None: self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.fsdp_wrapped = False self.fsdp_group = None - self._fp8_workspaces: Dict[str, Float8Tensor] = {} + self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self.activation_dtype: Optional[torch.dtype] = None # Names of attributes that can be set quickly (see __setattr__ @@ -472,7 +431,10 @@ def __setattr__(self, name: str, value: Any) -> None: super().__setattr__(name, value) def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> None: - """Increase or decrease size of amax history based on given `length`. + """ + Delayed scaling only. + + Increase or decrease size of amax history based on given `length`. .. warning:: This changes the underlying amax memory location. @@ -499,6 +461,9 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> self.fp8_meta[meta_key].amax_history, pad=(0, 0, 0, extra_rows) ) + # Update quantizers with new amax pointers. + self.quantizers[meta_key] = self.fp8_meta[meta_key].make_quantizers() + # Update the global buffers with new amax and history pointers. if FP8GlobalStateManager.get_buffer_info() in self.fp8_meta: fwd_pos, fwd_key, bwd_pos, bwd_key = self.fp8_meta[ @@ -516,37 +481,42 @@ def adjust_amax_history_length(self, length: int, fwd: Optional[bool] = None) -> self.fp8_meta[meta_key].amax_history ) - def set_meta_tensor(self, fwd: bool) -> None: + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + # Return early if recipe state matches recipe if self.fp8_meta_tensors_initialized: - # Handle changed amax history size. - self.adjust_amax_history_length(self.fp8_meta["recipe"].amax_history_len, fwd=fwd) - return + recipe_state = self.fp8_meta[fp8_meta_tensor_key] + if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState): + self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd) + return + if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): + return + if recipe.float8_current_scaling() and isinstance( + recipe_state, Float8CurrentScalingRecipeState + ): + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 - self.fp8_meta[fp8_meta_tensor_key] = tex.FP8TensorMeta() - self.fp8_meta[fp8_meta_tensor_key].scale = torch.ones( - num_fp8_tensors, dtype=torch.float32, device="cuda" - ) - self.fp8_meta[fp8_meta_tensor_key].scale_inv = torch.ones( - num_fp8_tensors, dtype=torch.float32, device="cuda" - ) - self.fp8_meta[fp8_meta_tensor_key].amax_history = torch.zeros( - self.fp8_meta["recipe"].amax_history_len, - num_fp8_tensors, - dtype=torch.float32, - device="cuda", + # Initialize recipe state and quantizers + recipe_state = RecipeState.create( + recipe, + mode=("forward" if fwd else "backward"), + num_quantizers=num_fp8_tensors, ) - def init_fp8_meta_tensors(self) -> None: + self.fp8_meta[fp8_meta_tensor_key] = recipe_state + self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() + + def init_fp8_meta_tensors(self, recipe: Recipe) -> None: """Init scales and amaxes.""" - self.set_meta_tensor(True) - self.set_meta_tensor(False) + self.set_meta_tensor(True, recipe) + self.set_meta_tensor(False, recipe) + self.fp8_meta_tensors_initialized = True def get_fp8_meta_tensors(self) -> None: @@ -559,7 +529,6 @@ def get_fp8_meta_tensors(self) -> None: with torch.no_grad(): for key in (fwd_key, bwd_key): fp8_meta_tensors[key].append(self.fp8_meta[key].scale.clone()) - fp8_meta_tensors[key].append(self.fp8_meta[key].scale_inv.clone()) fp8_meta_tensors[key].append(self.fp8_meta[key].amax_history.clone()) return fp8_meta_tensors @@ -570,17 +539,13 @@ def reset(key): if key in self.fp8_meta: if fp8_meta_tensors is None: self.fp8_meta[key].scale.copy_(torch.ones_like(self.fp8_meta[key].scale)) - self.fp8_meta[key].scale_inv.copy_( - torch.ones_like(self.fp8_meta[key].scale_inv) - ) self.fp8_meta[key].amax_history.copy_( torch.zeros_like(self.fp8_meta[key].amax_history) ) else: assert key in fp8_meta_tensors, "Cannot reset fp8 tensors." self.fp8_meta[key].scale.copy_(fp8_meta_tensors[key][0]) - self.fp8_meta[key].scale_inv.copy_(fp8_meta_tensors[key][1]) - self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][2]) + self.fp8_meta[key].amax_history.copy_(fp8_meta_tensors[key][1]) with torch.no_grad(): reset("scaling_fwd") @@ -588,20 +553,50 @@ def reset(key): def get_extra_state(self) -> torch.Tensor: """Save before checkpointing.""" - state = None + # This implementation is working around a few issues: + # + # (1) PyTorch's "extra state" infrastructure might be able to + # support any picklable type, but they make no guarantees. + # We have experienced problems (e.g. in ONNX export) with + # non-tensor extra state. + # (2) PyTorch's checkpointing infrastructure does not remap + # devices for "extra state" like it does for "state dict". + # Thus, we want to avoid putting extra state on the GPU + # since it may be loaded on the wrong device. + # (3) The extra state consists of many small tensors. If we + # want to copy them all to CPU, then we need to avoid the + # overhead of many GPU-CPU memory transfers. + # + # See: https://github.com/NVIDIA/TransformerEngine/pull/351 + # See: https://github.com/NVIDIA/TransformerEngine/pull/363 + + def to_cpu(src: torch.Tensor) -> torch.Tensor: + """Helper function to make CPU copy of tensor + + Memory transfer is asynchronous w.r.t. host, so GPU should + be synchronized before using result. + + """ + dst = torch.empty_like(src, device="cpu") + dst.copy_(src, non_blocking=True) + return dst + + # Store FP8 state if needed + state = None fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration - if fp8_checkpoint: + + # Copy tensors to CPU and store state = {} - state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale - state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv - state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history - state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale - state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv - state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history - - # Store other pickelable values. + state["recipe"] = self.fp8_meta["recipe"] + if state["recipe"].delayed(): + state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) + state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) + state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) + state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history) + + # Store other pickelable values extra = {} for k, v in self.fp8_meta.items(): if k != "buffer_index_and_autocast_key" and isinstance( @@ -610,12 +605,10 @@ def get_extra_state(self) -> torch.Tensor: extra[k] = v state["extra_fp8_variables"] = extra - if is_in_onnx_export_mode(): - state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8) - else: - state_serialized = io.BytesIO() - torch.save(state, state_serialized) - + # Serialize state into byte tensor + torch.cuda.synchronize() + state_serialized = bytearray(pickle.dumps(state)) + state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8) return state_serialized def set_extra_state(self, state: torch.Tensor) -> None: @@ -623,9 +616,12 @@ def set_extra_state(self, state: torch.Tensor) -> None: if state is None: return + # Load state if isinstance(state, torch.Tensor): + # Default format: byte tensor with pickled data state = pickle.loads(state.detach().cpu().numpy().tobytes()) elif isinstance(state, io.BytesIO): + # Deprecated format with io.BytesIO state.seek(0) state = torch.load(state, map_location="cuda") else: @@ -634,20 +630,31 @@ def set_extra_state(self, state: torch.Tensor) -> None: if state is None: return - # Load extra items. + # Load extra items self.fp8_meta.update(state["extra_fp8_variables"]) - self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0] + self.fp8_meta["recipe"] = state["recipe"] if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta: del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"] - # Initialize before loading. - self.init_fp8_meta_tensors() - self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"]) - self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"]) - self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"]) - self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"]) - self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"]) - self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"]) + # Initialize before loading + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + + def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: + """Helper function to copy tensor from CPU + + Memory transfer is asynchronous w.r.t. host, so GPU should + be synchronized before using result. + + """ + dst.copy_(src, non_blocking=True) + + # Load tensors + if self.fp8_meta["recipe"].delayed(): + copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) + copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) + copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) + copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history) + torch.cuda.synchronize() def set_activation_dtype(self, inp: torch.Tensor) -> None: """Get activation data type for AMP.""" @@ -686,7 +693,7 @@ def _get_fp8_params(self) -> Union[List[torch.Tensor], None]: """returns the FP8 weights.""" fp8_params = [] for param in self.parameters(recurse=False): - if isinstance(param, Float8Tensor) and param.requires_grad: + if isinstance(param, QuantizedTensor) and param.requires_grad: fp8_params.append(param) if len(fp8_params) == 0: return None @@ -699,22 +706,28 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() self.fp8 = FP8GlobalStateManager.is_fp8_enabled() self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + fp8_enabled = self.fp8 or self.fp8_calibration self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration - if self.fp8_parameters and not self.fp8_initialized: - self.fp8_meta["num_gemms"] = num_gemms - self.init_fp8_meta_tensors() - - if self.fp8 or self.fp8_calibration: - # FP8 init has already been run and recipe is the same, don't do anything. + if self.fp8_parameters or fp8_enabled: if ( self.fp8_initialized and FP8GlobalStateManager.get_fp8_recipe() == self.fp8_meta["recipe"] ): + # FP8 init has already been run and recipe is the same, don't do anything. return - - # Set FP8, recipe, and other FP8 metadata self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + else: + # If fp8 isn't enabled, turn off and return. + self.fp8_initialized = False + return + + if self.fp8_parameters and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + + if fp8_enabled: + # Set FP8 and other FP8 metadata self.fp8_meta["num_gemms"] = num_gemms self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() @@ -723,17 +736,15 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes - self.init_fp8_meta_tensors() + self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) self.fp8_initialized = True - else: - # If fp8 isn't enabled, turn off and return. - self.fp8_initialized = False + + self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() @contextmanager def prepare_forward( self, inp: torch.Tensor, - is_first_microbatch: Union[bool, None], # pylint: disable=unused-argument num_gemms: int = 1, allow_non_contiguous: bool = False, ) -> Generator[torch.Tensor, None, None]: @@ -755,16 +766,14 @@ def prepare_forward( self.set_activation_dtype(inp) self.init_fp8_metadata(num_gemms=num_gemms) - if self.fp8 and self.sequence_parallel: + if self.fp8 and self.sequence_parallel and self.fp8_meta["recipe"].delayed(): assert self.fp8_meta["recipe"].reduce_amax, ( "Amax reduction across tensor parallel group is " "necessary when using sequence parallelism with FP8." ) if self.fp8 and not FP8GlobalStateManager.fp8_graph_capturing(): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.fp8_meta, fp8_weights=self._get_fp8_params() - ) + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) # Activation recomputation is used and this is the first forward phase. if self.fp8 and self.training and is_fp8_activation_recompute_enabled(): @@ -797,110 +806,64 @@ def set_nccl_overlap_warning_if_tp(self) -> None: @staticmethod def grad_output_preprocess( - ctx, grad_output: torch.Tensor, row_parallel_mode: bool + ctx, + grad_output: torch.Tensor, + row_parallel_mode: bool, + quantizer: Optional[Quantizer], ) -> Tuple[Union[torch.Tensor, None], ...]: """Utility function for backward. Returns tuple in order (all optional/None based on training precion/recipe): - R1: gathered `grad_output` in higher precision. - R2: gathered `grad_output` in FP8. - R3: R2 transposed. - R4: bias gradient on R1. + R1: gathered `grad_output`. + R2: bias gradient on R1. """ - if isinstance(grad_output, Float8Tensor): - grad_output._data = grad_output._data.contiguous() - else: - grad_output = grad_output.contiguous() - grad_output_mat = grad_output.view(-1, grad_output.shape[-1]) + grad_output = grad_output.reshape((-1, grad_output.shape[-1])) + grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel - # No-FP8 case: bgrad is fused with wgrad for this case. + # Non-FP8 case: bgrad is fused with wgrad for this case. if not ctx.fp8: if gather_grad_output: if not ctx.ub_overlap_ag: - grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) + grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) else: - ctx.ub_obj_gradout.copy_input_to_ubuf(grad_output, True) - grad_output_mat = ctx.ub_obj_gradout.get_ubuf_output(1) - return grad_output_mat, None, None, None - - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - # FP8 case with non-FP8 wgrad - if gather_grad_output and ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - assert ( - not ctx.ub_overlap_ag - ), "override_linear_precision.wgrad not supported with UB AG overlap" - grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) - # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather - elif gather_grad_output: + ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True) + grad_output = ctx.ub_obj_gradout.get_buffer(quantizer) + return grad_output, None + + # FP8 with all-gather: unfused bgrad, fused cast + transpose + if gather_grad_output: + grad_bias = None if ctx.use_bias: - grad_bias = grad_output_mat.sum(dim=0) - else: - grad_bias = None + grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) if ctx.ub_overlap_ag: - grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0) + # Quantize the gradient if needed + if not isinstance( + grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase) + ): + grad_output = quantizer(grad_output) + + # Copy into communication buffer, and replace original gradient with it + ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True) + grad_output = ctx.ub_obj_gradout.get_buffer(quantizer) else: - grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8) - if not isinstance(grad_output_mat, Float8Tensor): - cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - out=grad_output_c, + grad_output, _ = gather_along_first_dim( + grad_output, + ctx.tp_group, + quantizer=quantizer, ) - else: - grad_output_c = grad_output_mat - if not ctx.ub_overlap_ag: - grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) - if not isinstance(grad_output_c, Float8Tensor): - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) - else: - grad_output_t = grad_output_c.transpose_2d() - else: - grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(1) - grad_output_t = None - - return grad_output_mat, grad_output_c, grad_output_t, grad_bias + return grad_output, grad_bias - # FP8 case without gather: cast, transpose, bgrad fused + # FP8 without all-gather: fused bgrad + cast + transpose + grad_bias = None if ctx.use_bias: - grad_output_mat_no_fp8 = grad_output_mat - if isinstance(grad_output_mat, Float8Tensor): - grad_output_mat_no_fp8 = grad_output_mat.from_float8(grad_output_mat.dtype) - grad_bias, grad_output_c, grad_output_t = fp8_cast_transpose_bgrad_fused( - grad_output_mat_no_fp8, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if isinstance(grad_output_mat, Float8Tensor): - grad_output_c = grad_output_mat - grad_output_t = grad_output_c.transpose_2d() - else: - grad_output_c, grad_output_t = fp8_cast_transpose_fused( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) + if isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - grad_output_t = None - if not isinstance(grad_output_mat, Float8Tensor): - grad_output_c = cast_to_fp8( - grad_output_mat, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ) - else: - grad_output_c = grad_output_mat - grad_bias = None - - return grad_output_mat, grad_output_c, grad_output_t, grad_bias + grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) + if not isinstance(grad_output, (QuantizedTensor, Float8TensorBase, MXFP8TensorBase)): + grad_output = quantizer(grad_output) + return grad_output, grad_bias def register_parameter(self, name, param, **kwargs): """ @@ -937,21 +900,15 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: with get_rng_state_tracker().fork(): init_fn(param) - # If primary weights are in fp8, wrap the parameter as Float8Tensor + # If primary weights are in fp8, wrap the parameter as FP8Tensor fp8_meta_index = self.param_init_meta[name].fp8_meta_index if self.primary_weights_in_fp8 and fp8_meta_index is not None: - dummy_amax = torch.empty( - (1, 1), - dtype=torch.float32, - device=param.device, - ) # Dummy buffer to avoid overwriting amax history - param = Float8Tensor.to_float8( - param, - fp8_meta=self.fp8_meta, - fp8_meta_index=fp8_meta_index, - amax=dummy_amax, - with_transpose_cache=torch.is_grad_enabled(), - ) + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] + assert ( + quantizer is not None + ) # to use primary fp8 weight one needs to use FP8 autocast with specific recipe. + quantizer.internal = False + param = quantizer(param) # Redo parameter wrap in case we broke it above # NOTE: Currently this can only be broken when primary weights are in Fp8 but @@ -963,17 +920,16 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: def forward(self): """Needs override.""" - def get_fp8_workspace( + def get_weight_workspace( self, *, tensor: Optional[torch.Tensor] = None, - fp8_meta_forward: Optional[bool] = None, - fp8_meta_index: Optional[int] = None, + quantizer: Optional[Quantizer] = None, cache_name: Optional[str] = None, update_workspace: bool = True, skip_update_flag: Optional[torch.Tensor] = None, - fsdp_group: dist_group_type = None, - ) -> Float8Tensor: + fsdp_group: Optional[dist_group_type] = None, + ) -> QuantizedTensor: """Get FP8 workspace buffer and maybe update its values The workspace buffer may be cached for future function calls. @@ -983,13 +939,9 @@ def get_fp8_workspace( tensor : torch.Tensor, optional Values to copy into workspace. Required if the workspace is being constructed or updated. - fp8_meta_forward: bool, optional - Whether to access FP8 meta tensors for the forward pass or - backward pass. Required if the workspace is being - constructed. - fp8_meta_index: int, optional - Index to access in FP8 meta tensors. Required if the - workspace is being constructed. + quantizer: Quantizer, optional + Quantizer used to cast the weights. Required if the + workspace is being constructed or updated. cache_name: str, optional Key for caching. update_workspace: bool, default = `True` @@ -1011,61 +963,24 @@ def get_fp8_workspace( # for models initialized with Fp8 primary weights. if ( out is not None - and not isinstance(out, Float8Tensor) + and tensor is not None and fsdp_group is not None - and out._data.shape != tensor.data.shape + and out.data.shape != tensor.data.shape ): _fsdp_gather_tensors(fsdp_group, [tensor.data.shape], out) # Construct workspace if needed if out is None: - - # FP8 data - if tensor is None or fp8_meta_forward is None or fp8_meta_index is None: + if tensor is None or quantizer is None: raise ValueError( - "tensor, fp8_meta_forward, and fp8_meta_index kwargs " - "must be provided to construct FP8 workspace" - ) - fp8_dtype = get_fp8_te_dtype( - self.fp8_meta["recipe"], - fprop_tensor=fp8_meta_forward, - ) - data = torch.empty_like(tensor, dtype=torch.uint8) - scale_inv = torch.empty([1], dtype=torch.float32, device=tensor.device) - - # Transpose cache - with_transpose_cache = torch.is_grad_enabled() - if ( - not with_transpose_cache - and is_fp8_activation_recompute_enabled() - and not in_fp8_activation_recompute_phase() - ): - with_transpose_cache = True - data_transpose = None - if with_transpose_cache: - data_transpose = torch.empty( - (tensor.size(-1), tensor.numel() // tensor.size(-1)), - dtype=torch.uint8, - device=tensor.device, + "tensor and quantizer kwargs must be provided to construct FP8 workspace" ) - - # Construct FP8 tensor - out = Float8Tensor( - data=data, - fp8_meta=self.fp8_meta, - fp8_meta_forward=fp8_meta_forward, - fp8_meta_index=fp8_meta_index, - fp8_dtype=fp8_dtype, - fp8_scale_inv=scale_inv, - dtype=tensor.dtype, - data_transpose=data_transpose, - ) + out = quantizer(tensor) # Update cache if cache_name is not None: self._fp8_workspaces[cache_name] = out - update_workspace = True - skip_update_flag = None + return out # Update workspace if needed if skip_update_flag is not None: @@ -1073,17 +988,10 @@ def get_fp8_workspace( if update_workspace: if tensor is None: raise ValueError("tensor kwarg must be provided to update FP8 workspace") - if is_in_onnx_export_mode(): - # ONNX export does not support fused cast-transpose - # kernel and requires that FP8 scales can be - # represented with constant ops. - transpose_cache = out._transpose - out._transpose = None - out.quantize_(tensor) - out._scale_inv.fill_(out._scale_inv.item()) - out._transpose = transpose_cache - else: + if hasattr(out, "quantize_"): out.quantize_(tensor, noop_flag=skip_update_flag) + else: + tex.quantize(tensor, quantizer, out, skip_update_flag) return out diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 16d40cf401..2549d45728 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -8,9 +8,8 @@ import torch -from ..cpp_extensions import ( - multi_padding_fused, -) +import transformer_engine_torch as tex + from ..jit import no_torch_dynamo @@ -36,7 +35,7 @@ def forward( total_row = sum(padded_m_splits) out = torch.empty([total_row, in_features], dtype=inp.dtype, device=inp.device) - multi_padding_fused(inp.view(-1, in_features), m_splits, padded_m_splits, out) + tex.fused_multi_row_padding(inp.view(-1, in_features), out, m_splits, padded_m_splits) if is_grad_enabled: ctx.m_splits = m_splits diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index d45abe0668..479b91d396 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -8,9 +8,8 @@ import torch -from ..cpp_extensions import ( - multi_padding_fused, -) +import transformer_engine_torch as tex + from ..jit import no_torch_dynamo @@ -56,8 +55,8 @@ def backward(ctx, grad_output: torch.Tensor): [total_row, in_features], dtype=grad_output.dtype, device=grad_output.device ) # FP8 pad input for forward, FP8 input transpose for backward wgrad - multi_padding_fused( - grad_output.view(-1, in_features), ctx.m_splits, ctx.padded_m_splits, grad_input + tex.fused_multi_row_padding( + grad_output.view(-1, in_features), grad_input, ctx.m_splits, ctx.padded_m_splits ) return (grad_input, None, None, None) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 08c5addcfc..8bf420ab0e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -1,9 +1,9 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """GroupedLinear API""" -from typing import Union, Optional, Callable, Tuple, List, Dict, Any +from typing import Union, Optional, Callable, Tuple, List import torch @@ -16,7 +16,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import FP8GlobalStateManager from ..utils import ( divide, cast_if_needed, @@ -28,21 +28,26 @@ from ..distributed import ( set_tensor_model_parallel_attributes, get_distributed_world_size, + is_fp8_activation_recompute_enabled, + in_fp8_activation_recompute_phase, ) from ..cpp_extensions import ( - cast_to_fp8, - fp8_cast_transpose_bgrad_fused, - fp8_multi_cast_transpose_fused, - fp8_grouped_gemm, - grouped_gemm, + general_grouped_gemm, ) -from ..constants import GemmParallelModes, dist_group_type +from ..constants import GemmParallelModes, dist_group_type, TE_DType from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..tensor import Float8Tensor, QuantizedTensor -from ..export import is_in_onnx_export_mode +from ..tensor.float8_tensor import Float8Tensor from ..cpu_offload import is_cpu_offload_enabled +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) + + __all__ = ["GroupedLinear"] @@ -60,202 +65,145 @@ def forward( is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], + input_quantizers: List[Quantizer], + weight_quantizers: List[Quantizer], + output_quantizers: List[Quantizer], + grad_output_quantizers: List[Quantizer], fuse_wgrad_accumulation: bool, cpu_offloading: bool, sequence_parallel: bool, activation_dtype: torch.dtype, - fp8_meta_offsets: Dict[str, int], is_grad_enabled: bool, - weights_fp8: List[Union[Float8Tensor, None]], - *weights_and_biases: Union[Float8Tensor, torch.Tensor, None], + module, + skip_fp8_weight_update, + *weights_and_biases, ) -> torch.Tensor: + # pylint: disable=missing-function-docstring num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] biases = weights_and_biases[num_gemms:] + device = inp.device + + # TODO Support MXFP8 # pylint: disable=fixme + if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8(): + raise NotImplementedError("GroupedLinear does not yet support MXFP8") + # TODO Support Float8 Current Scaling # pylint: disable=fixme + if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): + raise NotImplementedError("GroupedLinear does not yet support Float8 Current Scaling") # Make sure input dimensions are compatible in_features = weights[0].shape[-1] assert inp.shape[-1] == in_features, "GEMM not possible" inputmats = torch.split(inp.view(-1, in_features), m_splits) if fp8: - for i in range(num_gemms): - assert_dim_for_fp8_exec(inputmats[i]) - assert_dim_for_fp8_exec(weights[i]) + assert_dim_for_fp8_exec(*inputmats, *weights) # Cast input to expected dtype inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats] inputmats = [] - inputmats_t = [] - inputmat_scale_inv = None - if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - inputmat_scale_inv = torch.empty([num_gemms], dtype=torch.float32, device=inp.device) - if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weights[0].requires_grad - and not sequence_parallel - ): - # FP8 input for forward, FP8 input transpose for backward wgrad - indices = list( - range(fp8_meta_offsets["input"], fp8_meta_offsets["input"] + num_gemms) + weight_requires_grad = weights[0].requires_grad + + if input_quantizers[0] is not None: + for input_quantizer in input_quantizers: + input_quantizer.set_usage( + rowwise=True, + columnwise=(is_grad_enabled and weight_requires_grad), ) - inputmats, inputmats_t = fp8_multi_cast_transpose_fused( - inputmats_no_fp8, - fp8_meta["scaling_fwd"], - indices, # scale_indices - indices, # amax_indices - indices, # scale_inv_indices - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() ) - else: - # FP8 input for forward - inputmats = [ - cast_to_fp8( - inputmats_no_fp8[i], - fp8_meta["scaling_fwd"], - fp8_meta_offsets["input"] + i, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, + if weight_quantizers[0] is not None: + for weight_quantizer in weight_quantizers: + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + if output_quantizers[0] is not None: + for output_quantizer in output_quantizers: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + if fp8: + inputmats = tex.fused_multi_quantize( + inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype] + ) + weights_fp8 = [] + bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype + if not isinstance(weights[0], QuantizedTensor): + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + for i in range(num_gemms): + weight_fp8 = module.get_weight_workspace( + tensor=weights[i], + quantizer=weight_quantizers[i], + cache_name=(None if is_first_microbatch is None else f"weight{i}"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, ) - for i in range(num_gemms) - ] + weights_fp8.append(weight_fp8) + else: + weights_fp8 = weights - # Hack for ONNX export - # Note: ONNX models are represented as a graph of tensor - # operations, so the in-place scale-inv update doesn't fit - # very well. We work around this by making it look like - # the scale-inv tensor is initialized with a copy. - # Note: ONNX export expects FP8 scales can be represented - # with constant ops. However, copying into a buffer - # involves an expand op for array broadcasting. We work - # around this by filling the buffer instead. - if is_in_onnx_export_mode(): - inputmat_scale_inv.fill_(inputmat_scale_inv.item()) else: inputmats = inputmats_no_fp8 + bias_dtype = activation_dtype + weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights] - if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases - - # Use FP8 weights - if weights_fp8[0] is None: - weights_fp8 = weights - assert all(isinstance(w, Float8Tensor) for w in weights_fp8) + biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases - out = torch.empty( - [sum(m_splits), weights_fp8[0].size(0)], - dtype=activation_dtype, - device=inputmats[0].device, - ) + out = torch.empty( + [sum(m_splits), weights_fp8[0].size(0)], + dtype=activation_dtype, + device=device, + ) - _ = fp8_grouped_gemm( - [w._data for w in weights_fp8], - [w._scale_inv for w in weights_fp8], - 0, # weight offset is 0 for the newly created _scale_inv - fp8_dtype_forward, - inputmats, - inputmat_scale_inv, - 0, - fp8_dtype_forward, - [out], - activation_dtype, - get_multi_stream_cublas_workspace(), - m_splits=m_splits, - bias=biases, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - ) - else: - # Cast for native AMP - weights = [cast_if_needed(w, activation_dtype) for w in weights] - biases = ( - [cast_if_needed(bias, activation_dtype) for bias in biases] if use_bias else biases - ) + _ = general_grouped_gemm( + weights_fp8, + inputmats, + [out], + activation_dtype, + get_multi_stream_cublas_workspace(), + single_output=True, + m_splits=m_splits, + bias=biases, + use_bias=use_bias, + use_split_accumulator=_2X_ACC_FPROP, + ) - if fp8_calibration: + if fp8_calibration: + for i in range(num_gemms): + # amax of input for i in range(num_gemms): - # amax of input - amin, amax = inputmats[i].aminmax() - fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["input"] + i] = ( - torch.max(-amin, amax).float() - ) - # amax of weight - amin, amax = weights[i].aminmax() - fp8_meta["scaling_fwd"].amax_history[0][fp8_meta_offsets["weight"] + i] = ( - torch.max(-amin, amax).float() - ) + input_quantizers[i].calibrate(inputmats[i]) + for i in range(num_gemms): + weight_quantizers[i].calibrate(weights[i]) - out = torch.empty( - [sum(m_splits), weights[0].size(0)], - dtype=activation_dtype, - device=inputmats[0].device, - ) + if is_grad_enabled: - _ = grouped_gemm( - weights, - inputmats, - torch.split(out, m_splits), - activation_dtype, - get_multi_stream_cublas_workspace(), - bias=biases, - use_bias=use_bias, - ) + ctx.weights_shape_1 = weights[0].shape[1] - if is_grad_enabled: - saved_inputmats = [None] * num_gemms - saved_inputmats_t = [None] * num_gemms - if weights[0].requires_grad: - if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad: - if not inputmats_t: - saved_inputmats = inputmats - else: - saved_inputmats_t = inputmats_t - if cpu_offloading: - for t in saved_inputmats_t: - t.activation_offloading = True - else: - saved_inputmats = inputmats_no_fp8 - - if cpu_offloading: - if fp8: - for w in weights_fp8: - if w is not None: - w.weight_offloading = True - for w in weights: - w.weight_offloading = True - for t in saved_inputmats: - if t is not None: - t.activation_offloading = True - - ctx.save_for_backward( - inputmat_scale_inv, - *saved_inputmats, - *saved_inputmats_t, - *weights, - *weights_fp8, - *[ - w.main_grad if cpu_offloading and fuse_wgrad_accumulation else None - for w in weights - ], - ) + tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.weights_requires_grad = weights[0].requires_grad + if fuse_wgrad_accumulation and ctx.weights_requires_grad: + ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)] + else: + ctx.main_grads = [None] * num_gemms + ctx.device = device + ctx.grad_output_quantizers = grad_output_quantizers ctx.m_splits = m_splits ctx.num_gemms = num_gemms ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel ctx.inp_shape = inp.shape - ctx.fp8_meta_offsets = fp8_meta_offsets ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): @@ -271,66 +219,42 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_GroupedLinear_backward"): - ( - inputmat_scale_inv, - *saved_tensors, - ) = ctx.saved_tensors - inputmats = saved_tensors[: ctx.num_gemms] - inputmats_t = saved_tensors[ctx.num_gemms : 2 * ctx.num_gemms] - weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms] - weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms] - main_grads = saved_tensors[4 * ctx.num_gemms :] - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) + N = ctx.num_gemms + inputmats = saved_tensors[:N] + weights = saved_tensors[N : 2 * N] + biases = saved_tensors[2 * N : 3 * N] + main_grads = ctx.main_grads + + if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: # TOSO for i in ctx.num_gemms: w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w.main_grad = main_grads[i] weights[i] = w # preprocess grad_output + grad_output = grad_output.contiguous() grad_output_mats = torch.split( grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits ) - grad_output_c = [None] * ctx.num_gemms - grad_output_t = [None] * ctx.num_gemms + grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) if ctx.use_bias: for i in range(ctx.num_gemms): - grad_biases[i], grad_output_c[i], grad_output_t[i] = ( - fp8_cast_transpose_bgrad_fused( - grad_output_mats[i], - ctx.fp8_meta["scaling_bwd"], - ctx.fp8_meta_offsets["grad_output"] + i, - fp8_dtype_backward, - ) + grad_biases[i], grad_output[i] = tex.bgrad_quantize( + grad_output_mats[i], ctx.grad_output_quantizers[i] ) else: - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - indices = list( - range( - ctx.fp8_meta_offsets["grad_output"], - ctx.fp8_meta_offsets["grad_output"] + ctx.num_gemms, - ) - ) - grad_output_c, grad_output_t = fp8_multi_cast_transpose_fused( - grad_output_mats, - ctx.fp8_meta["scaling_bwd"], - indices, # scale_indices - indices, # amax_indices - indices, # scale_inv_indices - fp8_dtype_backward, - ) - else: - for i in range(ctx.num_gemms): - grad_output_c[i] = cast_to_fp8( - grad_output_mats[i], - ctx.fp8_meta["scaling_bwd"], - ctx.fp8_meta_offsets["grad_output"] + i, - fp8_dtype_backward, - ) + grad_output = tex.fused_multi_quantize( + grad_output_mats, + None, + ctx.grad_output_quantizers, + TE_DType[ctx.activation_dtype], + ) + else: + grad_output = grad_output_mats if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( @@ -340,111 +264,58 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation if ctx.requires_dgrad: - if ctx.fp8: - dgrad = torch.empty( - (sum(ctx.m_splits), weights_fp8[0].size(1)), - dtype=ctx.activation_dtype, - device=grad_output.device, - ) - fp8_grouped_gemm( - [w.transpose_2d() for w in weights_fp8], - [w._scale_inv for w in weights_fp8], - 0, # weight offset is 0 for the newly created _scale_inv - weights_fp8[0]._fp8_dtype, - grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - ctx.fp8_meta_offsets["grad_output"], - fp8_dtype_backward, - [dgrad], - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - m_splits=ctx.m_splits, - use_split_accumulator=_2X_ACC_DGRAD, - ) - else: - dgrad = torch.empty( - (sum(ctx.m_splits), weights[0].size(1)), - dtype=ctx.activation_dtype, - device=grad_output.device, - ) - grouped_gemm( - weights, - grad_output_mats, - torch.split(dgrad, ctx.m_splits), - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - layout="NN", - grad=True, - ) + dgrad = torch.empty( + (sum(ctx.m_splits), ctx.weights_shape_1), + dtype=ctx.activation_dtype, + device=ctx.device, + ) - if weights[0].requires_grad: + general_grouped_gemm( + weights, + grad_output, + [dgrad], + ctx.activation_dtype, + get_multi_stream_cublas_workspace(), + single_output=True, + layout="NN", + m_splits=ctx.m_splits, + grad=True, + use_split_accumulator=_2X_ACC_DGRAD, + ) + + if ctx.weights_requires_grad: if ctx.fuse_wgrad_accumulation: - wgrad_list = [w.main_grad for w in weights] + wgrad_list = main_grads else: wgrad_list = [ - torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device) + torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights ] - if ctx.fp8: - # WGRAD - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if inputmats_t[0] is None: - for i in range(ctx.num_gemms): - if isinstance(inputmats[i], Float8Tensor): - inputmats_t[i] = inputmats[i].transpose_2d() - else: - inputmats_t[i] = tex.fp8_transpose( - inputmats[i], fp8_dtype_backward - ) - fp8_grouped_gemm( - [ - inp._data if isinstance(inp, Float8Tensor) else inp - for inp in inputmats_t - ], - [inputmat_scale_inv], - 0, - fp8_dtype_forward, - grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - ctx.fp8_meta_offsets["grad_output"], - fp8_dtype_backward, - wgrad_list, - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - use_split_accumulator=_2X_ACC_WGRAD, - ) - else: - grouped_gemm( - inputmats, - grad_output_mats, - wgrad_list, - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - ) - else: - # WGRAD - _, grad_biases, _ = grouped_gemm( - inputmats, - grad_output_mats, - wgrad_list, - ctx.activation_dtype, - get_multi_stream_cublas_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_bias, - accumulate=accumulate_wgrad_into_param_main_grad, - ) + # WGRAD + _, grad_biases_, _ = general_grouped_gemm( + inputmats, + grad_output, + wgrad_list, + ctx.activation_dtype, + get_multi_stream_cublas_workspace(), + layout="NT", + grad=True, + m_splits=ctx.m_splits, + use_bias=ctx.use_bias if grad_biases[0] is None else None, + bias=biases, + use_split_accumulator=_2X_ACC_WGRAD, + accumulate=accumulate_wgrad_into_param_main_grad, + ) + for i in range(ctx.num_gemms): + if grad_biases[i] is None: + grad_biases[i] = grad_biases_[i] + del grad_biases_ # Deallocate input tensor clear_tensor_data(*inputmats) - clear_tensor_data(*inputmats_t) def handle_custom_ddp_from_mcore(w, wgrad): - if w.requires_grad: + if ctx.weights_requires_grad: if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"): w.grad_added_to_main_grad = True if getattr(w, "zero_out_wgrad", False): @@ -478,22 +349,24 @@ def handle_custom_ddp_from_mcore(w, wgrad): if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) - return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, - None, # m_splits - None, # use_bias - None, # is_first_microbatch - None, # fp8 - None, # fp8_calibration - None, # fp8_meta - None, # fuse_wgrad_accumulation - None, # cpu_offloading - None, # sequence_parallel - None, # activation_dtype - None, # fp8_meta_offsets + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, # is_grad_enabled None, # is_grad_enabled - None, # weights_fp8 *wgrad_list, *grad_biases, ) @@ -718,7 +591,7 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch, num_gemms=self.num_gemms) as inp: + with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] @@ -727,29 +600,32 @@ def forward( w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors ] - weight_tensors_fp8 = [None] * self.num_gemms + input_quantizers, weight_quantizers, output_quantizers = ( + [None] * self.num_gemms, + [None] * self.num_gemms, + [None] * self.num_gemms, + ) + grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms if self.fp8: + input_quantizers = [ + self.quantizers["scaling_fwd"][self._offsets["input"] + i] + for i in range(self.num_gemms) + ] for i in range(self.num_gemms): - if isinstance(weight_tensors[i], Float8Tensor): - # Make sure transpose cache is valid, if present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. - if weight_tensors[i]._transpose is not None: - weight_tensors[i].transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weight_tensors_fp8[i] = self.get_fp8_workspace( - tensor=weight_tensors[i], - fp8_meta_forward=True, - fp8_meta_index=self._offsets["weight"] + i, - cache_name=(None if is_first_microbatch is None else f"weight{i}"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) + input_quantizers[i].internal = True + weight_quantizers = [ + self.quantizers["scaling_fwd"][self._offsets["weight"] + i] + for i in range(self.num_gemms) + ] + for i in range(self.num_gemms): + weight_quantizers[i].internal = True + if torch.is_grad_enabled(): + grad_output_quantizers = [ + self.quantizers["scaling_bwd"][self._offsets["input"] + i] + for i in range(self.num_gemms) + ] + for i in range(self.num_gemms): + grad_output_quantizers[i].internal = True if torch.is_grad_enabled(): linear_fn = _GroupedLinear.apply @@ -764,14 +640,17 @@ def forward( is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, + input_quantizers, + weight_quantizers, + output_quantizers, + grad_output_quantizers, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.sequence_parallel, self.activation_dtype, - self._offsets, torch.is_grad_enabled(), - weight_tensors_fp8, + self, + skip_fp8_weight_update, *weight_tensors, *bias_tensors, ) diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index 32142cf48c..61aa69818a 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -61,15 +61,32 @@ class LayerNorm(_LayerNormOp): def __init__( self, - normalized_shape: Union[Iterable[int], int], + normalized_shape: Union[Iterable[int], int, None] = None, eps: float = 1e-5, sequence_parallel: Optional[bool] = None, # legacy params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, + hidden_size: Optional[int] = None, # deprecated **kwargs, ) -> None: # Handle deprecated options + if normalized_shape is None: + if hidden_size is None: + raise RuntimeError( + "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided" + ) + warnings.warn( + "`hidden_size` arg has been renamed to `normalized_shape` " + "for compatibility with `torch.nn.LayerNorm`.", + DeprecationWarning, + stacklevel=2, + ) + normalized_shape = hidden_size + elif hidden_size is not None: + raise RuntimeError( + "Both `normalized_shape` and `hidden_size` (deprecated) args are provided" + ) if params_dtype is not None: if "dtype" in kwargs: raise RuntimeError( @@ -87,6 +104,9 @@ def __init__( # Flag for sequence parallelism (custom Megatron-LM integration) self.sequence_parallel: Optional[bool] = sequence_parallel + if sequence_parallel is not None: + self.weight.sequence_parallel = sequence_parallel + self.bias.sequence_parallel = sequence_parallel def reset_layer_norm_parameters(self) -> None: """Init LN params""" diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fbf1b97704..4022924861 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1,17 +1,20 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """LayerNormLinear API""" import os import warnings -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union +from functools import reduce +from operator import mul as multiply_op import torch from torch.nn import init -from .. import cpp_extensions as tex +import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe from .base import ( get_workspace, get_ub, @@ -20,14 +23,16 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import FP8GlobalStateManager from ..utils import ( + assert_dim_for_fp8_exec, + cast_if_needed, + clear_tensor_data, divide, get_default_init_method, init_method_constant, - cast_if_needed, - assert_dim_for_fp8_exec, - clear_tensor_data, + nvtx_range_pop, + nvtx_range_push, requires_grad, ) from ..distributed import ( @@ -40,14 +45,23 @@ _fsdp_scatter_tensors, _fsdp_gather_tensors, ) -from ..constants import GemmParallelModes, dist_group_type, TE_DType +from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ._common import _apply_normalization, _noop_cat -from ..float8_tensor import Float8Tensor -from ..export import is_in_onnx_export_mode -from ..tensor import QuantizedTensor -from ..cpu_offload import is_cpu_offload_enabled +from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer +from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ..cpp_extensions import ( + general_gemm, +) __all__ = ["LayerNormLinear"] @@ -64,15 +78,18 @@ def forward( ln_weight: torch.Tensor, ln_bias: Union[torch.Tensor, None], weight: torch.Tensor, - weight_fp8: Optional[torch.Tensor], bias: torch.Tensor, use_bias: bool, eps: float, is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -87,257 +104,321 @@ def forward( bwd_ln_sm_margin: int, zero_centered_gamma: bool, normalization: str, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_rs_dgrad: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_overlap_rs_dgrad: bool, - ub_overlap_ag: bool, ub_name: str, - fp8_output: bool, fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring + + # NVTX label for profiling + nvtx_label = "transformer_engine._LayerNormLinear.forward" + if ub_name is not None: + nvtx_label = f"{nvtx_label}.{ub_name}" + # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) if fp8: - assert_dim_for_fp8_exec(inputmat) - assert_dim_for_fp8_exec(weight) + assert_dim_for_fp8_exec(inputmat, weight) # Cast for native AMP + nvtx_range_push(f"{nvtx_label}.norm_input_cast") inputmat = cast_if_needed(inputmat, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype) if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) + nvtx_range_pop(f"{nvtx_label}.norm_input_cast") - if ub_overlap_ag: - tp_world_size = get_distributed_world_size(tp_group) - if tp_world_size == 1 or (not is_grad_enabled): - ub_overlap_ag = False - if ub_overlap_ag: - dim_size = list(inputmat.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub(ub_name + "_fprop") - if return_layernorm_output: - # First prepare LN output in higher precision, - # which will be later copied to a FP8 UB - ln_out = torch.empty_like(inputmat, memory_format=torch.contiguous_format) + tp_world_size = get_distributed_world_size(tp_group) + ub_overlap_ag_fprop = ( + ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output + ) + + weight_requires_grad = weight.requires_grad + backward_needs_input = is_grad_enabled and weight_requires_grad + with_input_all_gather = parallel_mode == "column" and sequence_parallel + + if fp8: + if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( + FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" + ) + + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + + # Configure quantizer for normalization output + with_quantized_norm = fp8 and not return_layernorm_output + # for Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer + # so we need to set with_quantized_norm to False + if isinstance(input_quantizer, Float8CurrentScalingQuantizer): + with_quantized_norm = False + + if with_quantized_norm: + if with_input_all_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(input_quantizer, MXFP8Quantizer): + with_quantized_norm = False else: - ln_out = ub_obj_lnout.get_ubuf_output(0) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input, + ) + + # Reduce duplicated transpose in `_fix_gathered_fp8_transpose` + if ( + fp8 + and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() + and ub_bulk_dgrad + ): + input_quantizer.set_usage(rowwise=True, columnwise=False) + + ub_obj_fprop = None + ln_out = None + # For DelayScaling, output of normalization will be in fp8. + # For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8. + if ub_overlap_ag_fprop and not isinstance(input_quantizer, Float8CurrentScalingQuantizer): + ub_obj_fprop = get_ub(ub_name + "_fprop") + ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True) + elif with_quantized_norm: + if with_input_all_gather: + input_quantizer.set_usage(rowwise=True, columnwise=False) + ln_out = input_quantizer.make_empty(inputmat.shape, dtype=inputmat.dtype, device="cuda") else: - ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like( - inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format + inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda" ) - # Objects for FP8 cast - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - ln_out_scale_inv = None - if fp8: - ln_out_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) - - # Launch normalization kernel - ln_out, mu, rsigma = _apply_normalization( + # Apply normalization + nvtx_range_push(f"{nvtx_label}.norm") + ln_out, mu, rsigma = apply_normalization( inputmat, ln_out, ln_weight, ln_bias, eps, - fp8 and not return_layernorm_output, - fp8_meta, + input_quantizer if with_quantized_norm else None, + inp.dtype, normalization, fwd_ln_sm_margin, zero_centered_gamma, - is_grad_enabled, - fp8_scale_inv=ln_out_scale_inv, ) - - # Column Parallel Linear - ln_out_gathered = False - ub_algo = None - if ub_overlap_ag: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) - if not return_layernorm_output: - ln_out = torch.empty_like(ln_out) - if ub_obj_lnout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + ln_out_return = ln_out if return_layernorm_output else None + nvtx_range_pop(f"{nvtx_label}.norm") + + # For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer. + # So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer. + if ub_overlap_ag_fprop and isinstance(input_quantizer, Float8CurrentScalingQuantizer): + ub_obj_fprop = get_ub(ub_name + "_fprop") + ln_out_local = ln_out + ln_out = ub_obj_fprop.get_buffer(input_quantizer, local_chunk=True) + input_quantizer.quantize(ln_out_local, out=ln_out) + + # Prepare GEMM input + # Note: Cast to expected dtype and perform tensor-parallel communication + nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm") + if with_input_all_gather and not ub_overlap_ag_fprop: + with_quantized_all_gather = fp8 + if return_layernorm_output and return_layernorm_output_gathered: + with_quantized_all_gather = False + if fp8: + input_quantizer.set_usage(rowwise=True, columnwise=False) + # ln_out in this has two possibilities: + # 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel + # 2. in high precision, then we need to cast it and then gather in FP8 + # the output ln_out_total will be in FP8, and it's a full tensor + ln_out_total, _ = gather_along_first_dim( + ln_out, + tp_group, + quantizer=(input_quantizer if with_quantized_all_gather else None), + ) + if return_layernorm_output and return_layernorm_output_gathered: + ln_out_return = ln_out_total + if fp8 and not with_quantized_all_gather: + ln_out_total = input_quantizer(ln_out_total) + else: + if ub_overlap_ag_fprop: + ln_out_total = ub_obj_fprop.get_buffer(input_quantizer) else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - elif parallel_mode == "column" and sequence_parallel: - ln_out_gathered = True - ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) + if fp8: + if not isinstance(ln_out, QuantizedTensor): + input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + ln_out = input_quantizer(ln_out) + elif backward_needs_input: + ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) + ln_out_total = ln_out + nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") + + # Cast weight to expected dtype + weightmat = weight + quantized_weight = False + if not fp8: + weightmat = cast_if_needed(weightmat, activation_dtype) else: - ln_out_total = ln_out + if not isinstance(weight, QuantizedTensor): + quantized_weight = True + + # Configure quantizer + if weight_quantizer is not None: + weight_quantizer.set_usage(rowwise=True, columnwise=True) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) - # If residual connection is after LN, we need `ln_out_return` - # tensor in higher precision, this comes at the cost - # of an extra fp8 cast. - if return_layernorm_output: - ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out + # Cast bias to expected dtype + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias + + # Configure output quantizer + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if input_quantizer is not None: + input_quantizer.calibrate(ln_out_total) + if weight_quantizer is not None: + weight_quantizer.calibrate(weight) + + ub_obj = None + ub_type = None + rs_out = None + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.RS + out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features] + rs_out = torch.empty(out_shape, dtype=activation_dtype, device=ln_out_total.device) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.AG if fp8: - if ub_overlap_ag: - ln_out_fp8 = ub_obj_lnout.get_ubuf_output(0) - tex.cast_to_fp8( - ln_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - out=ln_out_fp8, - scale_inv=ln_out_scale_inv, - ) - ln_out = torch.empty_like(ln_out_fp8) - else: - ln_out_total = tex.cast_to_fp8( - ln_out_total, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=ln_out_scale_inv, - ) - if ln_out_gathered: - rank = torch.distributed.get_rank(tp_group) - slice_start = rank * ln_out.size(0) - slice_end = (rank + 1) * ln_out.size(0) - ln_out = ln_out_total[slice_start:slice_end, ...] - else: - ln_out = ln_out_total + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." + ln_out_total = ub_obj.get_buffer(input_quantizer) + nvtx_range_push(f"{nvtx_label}.gemm") + fprop_gemm_use_split_accumulator = _2X_ACC_FPROP if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - - # Use FP8 weights - if weight_fp8 is None: - weight_fp8 = weight - - assert isinstance(weight_fp8, Float8Tensor) - - # Hack for ONNX export - # Note: ONNX models are represented as a graph of tensor - # operations, so the in-place scale-inv update doesn't fit - # very well. We work around this by making it look like - # the scale-inv tensor is initialized with a copy. - # Note: ONNX export expects FP8 scales can be represented - # with constant ops. However, copying into a buffer - # involves an expand op for array broadcasting. We work - # around this by filling the buffer instead. - if is_in_onnx_export_mode(): - ln_out_scale_inv.fill_(ln_out_scale_inv.item()) - - if fp8_output: - out_index, meta_tensor, output_te_dtype, output_dtype = ( - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_meta["scaling_fwd"], - fp8_dtype_forward, - torch.uint8, - ) - else: - out_index, meta_tensor, output_te_dtype, output_dtype = ( - None, - None, - None, - activation_dtype, - ) - out, _ = tex.fp8_gemm( - weight_fp8._data, - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - ln_out_total, - ln_out_scale_inv, - 0, - fp8_dtype_forward, - output_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - ub_algo=ub_algo if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, - extra_output_tensor=ln_out if ub_overlap_ag else None, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=output_te_dtype, - ) - if output_dtype == torch.uint8: - out = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, - ) - else: - # Cast for native AMP - weight = cast_if_needed(weight, activation_dtype) - bias = cast_if_needed(bias, activation_dtype) if use_bias else bias - - if fp8_calibration: - # amax of input - amin, amax = ln_out_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of weight - amin, amax = weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() - - out, _, _ = tex.gemm( - weight, - ln_out_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, - extra_output_tensor=ln_out if ub_overlap_ag else None, - ) + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + + out, *_, rs_out = general_gemm( + weightmat, + ln_out_total, + get_workspace(), + quantization_params=output_quantizer, + out_dtype=activation_dtype, + bias=bias, + use_split_accumulator=fprop_gemm_use_split_accumulator, + ub=ub_obj, + ub_type=ub_type, + extra_output=rs_out, + ) + nvtx_range_pop(f"{nvtx_label}.gemm") + + if not weight.requires_grad: + if not return_layernorm_output: + ln_out = ln_out_total = None + clear_tensor_data(ln_out, ln_out_total) if is_grad_enabled: + ctx.ln_out_needs_gather = ( + weight.requires_grad and parallel_mode == "column" and sequence_parallel + ) + + # Input with column-wise usage is needed for dgrad GEMM. + if backward_needs_input: + if isinstance(ln_out, QuantizedTensor): + # For sequence parallel in vanilla FP8, rowwise data is + # to gather the input. For MXFP8, columnwise only data + # can be allgathered. + if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: + ln_out.update_usage(rowwise_usage=False) + if cpu_offloading: - if fp8 and weight_fp8 is not None: - weight_fp8.weight_offloading = True - ln_weight.weight_offloading = True - weight.weight_offloading = True + if fp8 and weightmat is not None: + set_offloading_param(weightmat, "weight_offloading", True) + set_offloading_param(ln_weight, "weight_offloading", True) + set_offloading_param(weight, "weight_offloading", True) - inputmat.activation_offloading = True - if normalization == "LayerNorm": - mu.activation_offloading = True - rsigma.activation_offloading = True - ln_out.activation_offloading = True + set_offloading_param(inputmat, "activation_offloading", True) + set_offloading_param(mu, "activation_offloading", True) + set_offloading_param(rsigma, "activation_offloading", True) + set_offloading_param(ln_out, "activation_offloading", True) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # shards/unshards the base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_scatter") ctx.fsdp_group = fsdp_group ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, mu, rsigma, - weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None, + weightmat if quantized_weight else None, ln_out if weight.requires_grad else None, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - ctx.save_for_backward( + if cpu_offloading: + ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + + if ctx.grad_added_to_main_grad: + # If you are passing torch.nn.Parameter through the Torch hooks, you will + # get back torch.Tensor. Torch rips off the Parameter wrapper. + # You need to preserve the weight object to have all the attributes user + # sets for the weights. Because of this, it is not recommended to offload + # weights if weights are externally touched outside this module + ctx.weight_object = weight + + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, + weightmat, + weight, + bias, ln_weight, + ln_out.clone() if ub_overlap_ag_fprop else ln_out, # avoid saving a UB buffer mu, rsigma, - weight, - weight_fp8, - weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, - ln_out if weight.requires_grad else None, - ln_out_scale_inv, ) - + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + ctx.requires_dgrad = inp.requires_grad + ctx.requires_wgrad = weight.requires_grad + ctx.quantized_weight = quantized_weight + if fuse_wgrad_accumulation and weight.requires_grad: + ctx.main_grad = weight.main_grad + ctx.grad_input_quantizer = grad_input_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.input_quantizer = input_quantizer + ctx.owns_input = inputmat is not inp + ctx.weight = weight ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -349,14 +430,13 @@ def forward( ctx.tp_group = tp_group ctx.tp_size = tp_size ctx.return_layernorm_output = return_layernorm_output - ctx.return_layernorm_output_gathered = ( - return_layernorm_output_gathered and ln_out_gathered - ) + ctx.return_layernorm_output_gathered = return_layernorm_output_gathered ctx.bwd_ln_sm_margin = bwd_ln_sm_margin ctx.zero_centered_gamma = zero_centered_gamma + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad - ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_name = ub_name ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization @@ -368,10 +448,15 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) + if ub_overlap_rs_fprop: + out = rs_out + elif parallel_mode == "row": + nvtx_range_push(f"{nvtx_label}.row_parallel_comm") + if sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + out, _ = allreduce(out, tp_group) + nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") # [*, in_features] -> [*, out_features] except first dimension changes for SP out = out.view(-1, *inp_shape[1:-1], out_features) @@ -389,82 +474,156 @@ def backward( ctx, *grad_outputs: Tuple[torch.Tensor, ...] ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring - if isinstance(grad_outputs[0], Float8Tensor): - ctx.fp8_meta["scaling_bwd"].scale_inv[tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_outputs[ - 0 - ]._scale_inv + + # NVTX label for profiling + nvtx_label = "transformer_engine._LayerNormLinear.backward" + if ctx.ub_name is not None: + nvtx_label = f"{nvtx_label}.{ctx.ub_name}" with torch.cuda.nvtx.range("_LayerNormLinear_backward"): - ( + if ( + ctx.fp8 + and any( + [ + ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad, + ] + ) + and (ctx.fp8_recipe is not None) + ): + if not ctx.fp8_recipe.float8_per_tensor_scaling(): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" + ) + + saved_tensors = ctx.saved_tensors + ( # pylint: disable=unbalanced-tuple-unpacking inputmat, + weight, + origin_weight, + bias, ln_weight, + ln_out, mu, rsigma, - weight, - weight_fp8, - main_grad, - ln_out, - ln_out_scale_inv, - ) = ctx.saved_tensors + ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None + + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ( + ctx.main_grad + if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad + else None + ) # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # shards/unshards the base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_gather") _fsdp_gather_tensors( ctx.fsdp_group, ctx.fsdp_shapes, mu, rsigma, - weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None, + weight if ctx.fp8 and ctx.quantized_weight else None, ln_out, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_gather") + + # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, + # we need to connect them into one. + if ctx.cpu_offloading: + if ctx.grad_added_to_main_grad: + origin_weight = ctx.weight_object + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + origin_weight.main_grad = main_grad + + ctx.ub_obj_gradout = None + ub_obj_dgrad = None + ub_obj_wgrad = None + ub_type_dgrad = None + ub_type_wgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] + rs_out = None + dgrad_bulk = None + if ctx.ub_overlap_ag: + # Overlap grad_output all-gather with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight, weight.requires_grad) - weight.main_grad = main_grad + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.RS + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device + ) - if ctx.ub_overlap_rs_dgrad: - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_rs_dgrad = False - if ctx.ub_bulk_dgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not weight.requires_grad: - ctx.ub_bulk_dgrad = False - if ctx.ub_bulk_dgrad: - dim_size = list(ln_out.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub(ctx.ub_name + "_dgrad") - ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + ub_obj_dgrad.copy_into_buffer(ln_out, ctx.input_quantizer, local_chunk=True) + + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_type_wgrad = tex.CommOverlapType.RS + ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer) + dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer) + + if ctx.grad_output_quantizer is not None: + # Reduce duplicated transpose, which is performed in grad_output.update_usage + if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling(): + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False) + else: + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) + nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") ( grad_output, - grad_output_c, - grad_output_t, grad_bias, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_outputs[0], ctx.parallel_mode == "row" + ctx, + grad_outputs[0], + ctx.parallel_mode == "row", + ctx.grad_output_quantizer, ) - - if ctx.ub_bulk_wgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not weight.requires_grad: - ctx.ub_bulk_wgrad = False - - # Column Parallel Linear - # Overlap input AG with dgrad - if ( - weight.requires_grad - and (not ctx.ub_bulk_dgrad) - and ctx.parallel_mode == "column" - and ctx.sequence_parallel - ): - ln_out_total, handle = gather_along_first_dim(ln_out, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") + + # Prepare GEMM input + # Note: Perform tensor-parallel communication if needed + ln_out_total = None + ln_out_total_work = None + if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad: + quantizer = None + if ctx.fp8: + quantizer = ctx.input_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + ln_out_total, ln_out_total_work = gather_along_first_dim( + ln_out, + ctx.tp_group, + async_op=True, + quantizer=quantizer, + ) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: ln_out_total = ln_out - handle = None + # Check whether to output wgrad GEMM directly into main grad if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch @@ -472,220 +631,152 @@ def backward( else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - dgrad_size = list(grad_output.size()) - dgrad_size[1] = weight.size(1) - if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub(ctx.ub_name + "_wgrad") - dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - elif ctx.ub_overlap_rs_dgrad: - ub_obj_dgrad = get_ub(ctx.ub_name + "_dgrad") - dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - else: - dgrad = torch.empty(dgrad_size, dtype=ctx.activation_dtype, device=weight.device) - - rs_out = None - if ctx.ub_bulk_dgrad: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - ub_obj = ub_obj_lnout - elif ctx.ub_overlap_rs_dgrad: - dim_size = list(grad_output.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = weight.size(1) - rs_out = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=grad_output.device - ) - if ub_obj_dgrad.is_p2p_overlap(): - if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ctx.fp8 and ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ub_obj = ub_obj_dgrad - else: - ub_algo = None - ub_obj = None + # dgrad GEMM + if ctx.grad_input_quantizer is not None: + ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - out_index, meta_tensor, out_te_type, out_type = ( - None, - None, - None, - ctx.activation_dtype, - ) - if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): - out_index = tex.FP8BwdTensors.GRAD_INPUT1 - meta_tensor = ctx.fp8_meta["scaling_bwd"] - out_te_type = fp8_dtype_backward - out_type = torch.uint8 - ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) - - # DGRAD: Evaluated unconditionally to feed into Linear backward - _ = tex.fp8_gemm( - weight_fp8.transpose_2d(), - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - ( - grad_output_c._data - if isinstance(grad_output_c, Float8Tensor) - else grad_output_c - ), - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - out_type, - get_workspace(), - out=dgrad, - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=out_te_type, - ) - clear_tensor_data(grad_output_c) - else: - # DGRAD: Evaluated unconditionally to feed into Linear backward - _, _, _ = tex.gemm( - weight, - grad_output, - ctx.activation_dtype, - get_workspace(), - out=dgrad, - layout="NN", - grad=True, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - ) - if ctx.ub_bulk_dgrad: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) - - # Overlap dgrad-RS/AR with wgrad - if ctx.parallel_mode == "column" and ctx.sequence_parallel: - if not ctx.ub_bulk_dgrad and handle is not None: - handle.wait() - if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator + + dgrad, *_ = general_gemm( + weight, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=ctx.grad_input_quantizer, + out=dgrad_bulk, + out_dtype=ctx.activation_dtype, + use_split_accumulator=dgrad_gemm_use_split_accumulator, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") + + # Launch tensor-parallel communication + dgrad_work = None + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") + if ctx.sequence_parallel: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) - dgrad, handle = reduce_scatter_along_first_dim( - dgrad, ctx.tp_group, async_op=True + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + ctx.tp_group, + async_op=True, ) - elif ctx.parallel_mode == "column" and ctx.tensor_parallel: - dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) + else: + dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") + # Compute grad weight tensor wgrad = None - if weight.requires_grad: - if ctx.fp8: - # WGRAD - extra_output_tensor = None - if ctx.ub_bulk_wgrad: - if ub_obj_dgrad.is_fp8_ubuf(): - dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output - extra_output_tensor = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=dgrad.device - ) - dgrad = extra_output_tensor + if ctx.requires_wgrad: + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer) + if ctx.fp8: + # FP8 GEMM on Hopper only supports TN layout so the gathered input must have + # a valid transpose. + if ln_out._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size) else: - dgrad = ub_obj_dgrad.get_ubuf_output(0) - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) - wgrad, _ = tex.fp8_gemm( - ln_out_total_t, - ln_out_scale_inv, - 0, - fp8_dtype_forward, - ( - grad_output_t._data - if isinstance(grad_output_t, Float8Tensor) - else grad_output_t - ), - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_t, grad_output_t) - else: - ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( - ln_out_total, - ln_out_scale_inv, - 0, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - wgrad, _, _ = tex.gemm( - ln_out_total_c, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_c) + # FP8 GEMM on Hopper only supports TN layout so the gathered input must + # have a valid transpose. + ln_out_total._create_transpose() + else: - # WGRAD - wgrad, grad_bias, _ = tex.gemm( - ln_out_total, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_bias, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + if ln_out_total_work is not None: + # Synchronize tensor-parallel communication + ln_out_total_work.wait() + ln_out_total_work = None + + if isinstance(grad_output, QuantizedTensor): + # This is a no-op if platform supports non-TN FP8 GEMM or the transpose + # already exists. + grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + + if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device ) + + # wgrad GEMM + # Note: Fuse with bgrad computation if needed + nvtx_range_push(f"{nvtx_label}.wgrad_gemm") + wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD + if ctx.fp8: + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + wgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_wgrad.use_split_accumulator + ) + + wgrad, grad_bias_, *_, rs_out = general_gemm( + ln_out_total, + grad_output, + get_workspace(), + layout="NT", + grad=True, + out_dtype=( + main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + ), + bias=(bias if (grad_bias is None and not ctx.fp8) else None), + out=main_grad if ctx.fuse_wgrad_accumulation else None, + use_split_accumulator=wgrad_gemm_use_split_accumulator, + accumulate=accumulate_wgrad_into_param_main_grad, + ub=ub_obj_wgrad, + ub_type=ub_type_wgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, + ) + nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") + + if ctx.ub_bulk_wgrad: + if ub_obj_wgrad.is_fp8_ubuf(): + dgrad = rs_out + else: + dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True) + + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ + + # Deallocate input tensor + if not ctx.return_layernorm_output: + # TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme clear_tensor_data(ln_out_total) - if ctx.ub_bulk_wgrad: - dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output - # Column Parallel Linear - if ( - (not ctx.ub_bulk_wgrad) - and ctx.parallel_mode == "column" - and ctx.tensor_parallel - and handle is not None - ): - handle.wait() + # Don't return grad bias if not needed + if not ctx.use_bias: + grad_bias = None - # LayerNorm gradient - if ctx.ub_overlap_rs_dgrad: - dgrad = rs_out.view(inputmat.shape) - else: - dgrad = dgrad.view(inputmat.shape) + # Synchronize tensor parallel communication + if ln_out_total_work is not None: + ln_out_total_work.wait() + ln_out_total_work = None + if dgrad_work is not None: + dgrad_work.wait() + dgrad_work = None # Residual gradient + dgrad = dgrad.view(inputmat.shape) if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) + # Norm gradient dgamma = None dbeta = None + nvtx_range_push(f"{nvtx_label}.norm") if ctx.normalization == "LayerNorm": dgrad, dgamma, dbeta = tex.layernorm_bwd( dgrad, @@ -696,6 +787,7 @@ def backward( ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, ) + dgrad = dgrad.reshape(inputmat.size()) elif ctx.normalization == "RMSNorm": dgrad, dgamma = tex.rmsnorm_bwd( dgrad, @@ -705,28 +797,27 @@ def backward( ctx.bwd_ln_sm_margin, ctx.zero_centered_gamma, ) + dgrad = dgrad.reshape(inputmat.size()) dbeta = None + nvtx_range_pop(f"{nvtx_label}.norm") clear_tensor_data(mu) clear_tensor_data(rsigma) - if not ctx.use_bias: - grad_bias = None - - if weight.requires_grad: + if ctx.requires_wgrad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): - weight.grad_added_to_main_grad = True - if getattr(weight, "zero_out_wgrad", False): + if ctx.fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"): + origin_weight.grad_added_to_main_grad = True + if getattr(origin_weight, "zero_out_wgrad", False): wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, + origin_weight.main_grad.shape, + dtype=origin_weight.dtype, device=torch.cuda.current_device(), requires_grad=False, ) else: wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, + origin_weight.main_grad.shape, + dtype=origin_weight.dtype, device=torch.cuda.current_device(), requires_grad=False, ) @@ -736,26 +827,31 @@ def backward( wgrad = None if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, Float8Tensor): - _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) + # if ctx.fp8 and not isinstance(weight, QuantizedTensor): + # _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, dbeta, wgrad, - None, # weight_fp8 grad_bias, None, # use_bias None, # eps None, # is_first_microbatch None, # fp8 None, # fp8_calibration - None, # fp8_meta None, # fuse_wgrad_accumulation + None, # input_quantizer + None, # weight_quantizer + None, # output_quantizer + None, # grad_output_quantizer + None, # grad_input_quantizer None, # cpu_offloading None, # tp_group None, # tp_size @@ -770,13 +866,16 @@ def backward( None, # bwd_ln_sm_margin None, # zero_centered_gamma None, # normalization - None, # ub_bulk_wgrad - None, # ub_bulk_dgrad + None, # ub_overlap_ag_fprop + None, # ub_overlap_rs_fprop + None, # ub_overlap_ag_dgrad None, # ub_overlap_rs_dgrad - None, # ub_overlap_ag + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # ub_name - None, # fp8_output None, # fsdp_group + None, # module + None, # skip_fp8_weight_update ) @@ -887,10 +986,11 @@ def __init__( parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, ub_overlap_rs_dgrad: bool = False, + ub_bulk_wgrad: bool = False, + ub_bulk_dgrad: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -907,13 +1007,6 @@ def __init__( self.return_layernorm_output = return_layernorm_output self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma - self.ub_bulk_wgrad = ub_bulk_wgrad - self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_overlap_ag = ub_overlap_ag - self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag, ub_overlap_rs_dgrad]): - assert ub_name is not None, "Userbuffer name [string] is not set." - self.ub_name = ub_name if tp_group is None: self.tp_size = tp_size @@ -939,9 +1032,49 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + # Column-parallel overlaps + self.ub_overlap_ag_fprop = ( + ub_overlap_ag and self.sequence_parallel and self.parallel_mode == "column" + ) + self.ub_overlap_rs_dgrad = ( + ub_overlap_rs_dgrad and self.sequence_parallel and self.parallel_mode == "column" + ) + self.ub_bulk_wgrad = ( + ub_bulk_wgrad + and self.sequence_parallel + and self.parallel_mode == "column" + and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + ub_bulk_dgrad + and self.sequence_parallel + and self.parallel_mode == "column" + and not self.ub_overlap_rs_dgrad + ) + + # Row-parallel overlaps + self.ub_overlap_rs_fprop = ( + ub_overlap_rs and self.sequence_parallel and self.parallel_mode == "row" + ) + self.ub_overlap_ag_dgrad = ( + ub_overlap_ag and self.sequence_parallel and self.parallel_mode == "row" + ) + if any( + [ + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + ] + ): + assert ub_name is not None, "Userbuffer name [string] is not set." + self.ub_name = ub_name + self.eps = eps layer_norm_weight = torch.nn.Parameter( - torch.empty(in_features, device=device, dtype=params_dtype) + torch.empty(self.in_features, device=device, dtype=params_dtype) ) self.register_parameter( "layer_norm_weight", @@ -950,7 +1083,7 @@ def __init__( ) if self.normalization != "RMSNorm": layer_norm_bias = torch.nn.Parameter( - torch.empty(in_features, device=device, dtype=params_dtype) + torch.empty(self.in_features, device=device, dtype=params_dtype) ) self.register_parameter( "layer_norm_bias", layer_norm_bias, init_fn=init_method_constant(0.0) @@ -1034,7 +1167,9 @@ def __init__( # Check if parameters are subviews of buffers is_subview = (split_start, split_end) != (0, self.out_features) if is_subview and with_fp8_params: - raise RuntimeError("Splitting Float8Tensor into multiple params is not supported") + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) # Construct weight parameter self.register_parameter( @@ -1082,6 +1217,16 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + # elif other recipes (mxfp8, etc) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1152,11 +1297,16 @@ def forward( produced) """ - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch) as inp: + with self.prepare_forward( + inp, allow_non_contiguous=False # removed .contiguous from inside the layer + ) as inp: # Get concatenated weight and bias tensors unfused_weights = [getattr(self, name) for name in self.weight_names] @@ -1168,35 +1318,20 @@ def forward( ) else: unfused_weights = [w.dequantize() for w in unfused_weights] - weight_tensor = _noop_cat(unfused_weights) + + weight_tensor = noop_cat(unfused_weights) if self.use_bias: - bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names]) + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) else: bias_tensor = getattr(self, self.bias_names[0]) # Unused - # Initialize FP8 weights if needed - weight_fp8 = None - if self.fp8: - if isinstance(weight_tensor, Float8Tensor): - # Make sure transpose cache is valid, if present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. - if weight_tensor._transpose is not None: - weight_tensor.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weight_fp8 = self.get_fp8_workspace( - tensor=weight_tensor, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) = self._get_quantizers(fp8_output) if torch.is_grad_enabled(): fwd_fn = _LayerNormLinear.apply @@ -1209,15 +1344,18 @@ def forward( self.layer_norm_weight, self.layer_norm_bias, weight_tensor, - weight_fp8, bias_tensor, self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, self.fuse_wgrad_accumulation, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1232,13 +1370,16 @@ def forward( self.bwd_ln_sm_margin, self.zero_centered_gamma, self.normalization, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_rs_dgrad, self.ub_bulk_wgrad, self.ub_bulk_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_overlap_ag, self.ub_name, - fp8_output, self.fsdp_group, + self, + skip_fp8_weight_update, ) out = fwd_fn(*args) @@ -1255,3 +1396,68 @@ def forward( if self.return_layernorm_output: return out, ln_out return out + + def _get_quantizers(self, fp8_output): + if not self.fp8: + return [None] * 5 + grad_input_quantizer = None + grad_output_quantizer = None + output_quantizer = None + input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer.internal = False + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + if fp8_output: + output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + if torch.is_grad_enabled(): + grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer.internal = True + + return ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_linear.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # also set weight quantizer with same amax_epsilon & power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # parallel related + if self.sequence_parallel and self.parallel_mode == "column": + # set input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_size = self.tp_size + else: + # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here) + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 64e8c9ce36..633690ba6a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1,16 +1,21 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """LayerNormMLP API""" import os import warnings -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union +from functools import reduce +from operator import mul as multiply_op import torch from torch.nn.parameter import Parameter from torch.nn import init +import transformer_engine_torch as tex + +from transformer_engine.common.recipe import Recipe from .base import ( get_workspace, _ub_communicators, @@ -20,7 +25,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ..fp8 import FP8GlobalStateManager from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -35,6 +40,7 @@ assert_dim_for_fp8_exec, clear_tensor_data, requires_grad, + non_tn_fp8_gemm_supported, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -45,31 +51,76 @@ use_reentrant_activation_recompute, in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, - _fsdp_gather_tensors, ) -from .. import cpp_extensions as tex - -from ..constants import dist_group_type, TE_DType +from ..constants import dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..float8_tensor import Float8Tensor -from ._common import _apply_normalization -from ..cpu_offload import is_cpu_offload_enabled +from ..tensor.float8_tensor import Float8Tensor +from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ._common import apply_normalization, _fix_gathered_fp8_transpose +from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from ..cpp_extensions import ( + general_gemm, +) __all__ = ["LayerNormMLP"] -def _act_func(activation: str): - funcs = { - "gelu": (tex.gelu, tex.dgelu), - "relu": (tex.relu, tex.drelu), - "geglu": (tex.geglu, tex.dgeglu), - "reglu": (tex.reglu, tex.dreglu), - "swiglu": (tex.swiglu, tex.dswiglu), - "qgelu": (tex.qgelu, tex.dqgelu), - "srelu": (tex.srelu, tex.dsrelu), +def _get_act_func_supported_list(recipe: Optional[Recipe] = None): + if recipe is None: + # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + return { + "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + } + if recipe.delayed() or recipe.mxfp8(): + # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + return { + "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + } + # no activation fusion written yet + # Per-tensor current scaling: [] + return { + "gelu": (tex.gelu, tex.dgelu, None), + "relu": (tex.relu, tex.drelu, None), + "geglu": (tex.geglu, tex.dgeglu, None), + "reglu": (tex.reglu, tex.dreglu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), + "qgelu": (tex.qgelu, tex.dqgelu, None), + "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "srelu": (tex.srelu, tex.dsrelu, None), } + + +def _act_func(activation: str, recipe: Optional[Recipe] = None): + # based on each quantization mode, we have different kernel fusion supported: + # bf16 (recipe is None): [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] + # Per-tensor current scaling: [] + funcs = _get_act_func_supported_list(recipe) if activation not in funcs: raise NotImplementedError("Activation type " + activation + " is not supported!") return funcs[activation] @@ -87,19 +138,24 @@ def forward( ln_weight: torch.Tensor, ln_bias: torch.Tensor, fc1_weight: torch.Tensor, - fc1_weight_fp8: Optional[torch.Tensor], fc1_bias: torch.Tensor, use_fc1_bias: bool, fc2_weight: torch.Tensor, - fc2_weight_fp8: Optional[torch.Tensor], fc2_bias: torch.Tensor, use_fc2_bias: bool, eps: float, is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], fuse_wgrad_accumulation: bool, + fc1_input_quantizer: Optional[Quantizer], + fc1_weight_quantizer: Optional[Quantizer], + fc2_input_quantizer: Optional[Quantizer], + fc2_weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_fc2_output_quantizer: Optional[Quantizer], + grad_fc1_output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], cpu_offloading: bool, tp_group: Union[dist_group_type, None], tp_size: int, @@ -108,7 +164,7 @@ def forward( activation_dtype: torch.dtype, return_layernorm_output: bool, return_layernorm_output_gathered: bool, - bias_gelu_nvfusion: bool, + bias_gelu_fusion: bool, set_parallel_mode: bool, is_grad_enabled: bool, fwd_ln_sm_margin: int, @@ -116,26 +172,36 @@ def forward( zero_centered_gamma: bool, activation: str, normalization: str, + ub_overlap_ag: bool, + ub_overlap_rs: bool, + ub_overlap_rs_dgrad: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_overlap_rs_dgrad: bool, - ub_overlap_rs: bool, - ub_overlap_ag: bool, gemm_gelu_fusion: bool, fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring + + in_features, inp_shape = ln_weight.numel(), inp.shape # Make sure input dimensions are compatible - in_features = ln_weight.numel() - inp_shape = inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" inputmat = inp.view((-1, in_features)) if fp8: - assert_dim_for_fp8_exec(inputmat) - assert_dim_for_fp8_exec(fc1_weight) - assert_dim_for_fp8_exec(fc2_weight) + assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) + if any([ub_overlap_ag, ub_overlap_rs]) and not ( + FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" + ) - activation_func = _act_func(activation)[0] + activation_func = _act_func( + activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + )[0] + device = inp.device # Cast for native AMP inputmat = cast_if_needed(inputmat, activation_dtype) @@ -143,314 +209,269 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) + # for fp8 DelayedScaling: layernorm output = FP8 + # only output of the linear is returned + # for return_layernorm_output: layernorm output = High precision, then cast to FP8 + # high precision layernorm output and output of the linear are returned + with_quantized_norm = fp8 and not return_layernorm_output + if isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer): + with_quantized_norm = False + tp_world_size = get_distributed_world_size(tp_group) - if ub_overlap_ag: - if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: - ub_overlap_ag = False - if ub_overlap_ag: + ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output + ub_overlap_rs = ub_overlap_rs and is_grad_enabled + with_input_all_gather_nccl = sequence_parallel and not ub_overlap_ag + backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad + + # Configure quantizer for normalization output + if fp8 and fc1_input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if with_quantized_norm: + if with_input_all_gather_nccl: + fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + if isinstance(fc1_input_quantizer, MXFP8Quantizer): + with_quantized_norm = False + else: + fc1_input_quantizer.set_usage( + rowwise=True, + columnwise=backwards_needs_fc1_input, + ) + + # Reduce duplicated transpose in `_fix_gathered_fp8_transpose` + if ( + fp8 + and FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() + and ub_bulk_dgrad + ): + fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + + ub_obj_lnout = None + ln_out = None + # For DelayScaling, output of normalization will be in fp8. + # For Float8CurrentScaling, we want the output of normalization in high precision, then quantize to fp8. + if ub_overlap_ag and not isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer): ub_obj_lnout = get_ub("fc1_fprop") - ln_out = ub_obj_lnout.get_ubuf_output(0) - else: - ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype + ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True) + elif not with_quantized_norm: ln_out = torch.empty_like( - inputmat, dtype=ln_out_dtype, memory_format=torch.contiguous_format + inputmat, dtype=inputmat.dtype, memory_format=torch.contiguous_format, device="cuda" ) - ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - - ln_out, mu, rsigma = _apply_normalization( + # Apply normalization + ln_out, mu, rsigma = apply_normalization( inputmat, ln_out, ln_weight, ln_bias, eps, - fp8 and not return_layernorm_output, - fp8_meta, + fc1_input_quantizer if with_quantized_norm else None, + inp.dtype, normalization, fwd_ln_sm_margin, zero_centered_gamma, - is_grad_enabled, ) - # Column Parallel Linear + ln_out_return = ln_out if return_layernorm_output else None + + # For Float8CurrentScalingQuantizer, layernorm/rmsnorm has not been fused with quantizer. + # So the output of normalization is in high precision, and we need to quantize it to FP8 and put in the buffer. + if ub_overlap_ag and isinstance(fc1_input_quantizer, Float8CurrentScalingQuantizer): + ub_obj_lnout = get_ub("fc1_fprop") + ln_out_local = ln_out + ln_out = ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True) + fc1_input_quantizer.quantize(ln_out_local, out=ln_out) + + # Prepare GEMM input + # Note: Cast to expected dtype and perform tensor-parallel communication ln_out_gathered = False - ub_algo_ag = None - if ub_overlap_ag: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) - ln_out = torch.empty_like(ln_out) - if ub_obj_lnout.is_atomic_gemm(): - ub_algo_ag = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo_ag = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - elif set_parallel_mode and sequence_parallel: + with_quantized_all_gather = fp8 + if with_input_all_gather_nccl: + if return_layernorm_output and return_layernorm_output_gathered: + with_quantized_all_gather = False + if fp8: + fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) + # ln_out in this has two possibilities: + # 1. in FP8 low precision, the cast was done by fusing quantization into layernorm kernel + # 2. in high precision, then we need to cast it and then gather in FP8 + # the output ln_out_total will be in FP8, and it's a full tensor + ln_out_total, _ = gather_along_first_dim( + ln_out, + tp_group, + quantizer=(fc1_input_quantizer if with_quantized_all_gather else None), + ) ln_out_gathered = True - ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) else: - ln_out_total = ln_out - - # If residual connection is after LN, we need `ln_out` - # tensor in higher precision, this comes at the cost - # of an extra fp8 cast. - if return_layernorm_output: - ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out - if fp8: - if ub_overlap_ag: - ln_out = tex.cast_to_fp8( - ln_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - else: - ln_out_total = tex.cast_to_fp8( - ln_out_total, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) - if ln_out_gathered: - rank = torch.distributed.get_rank(tp_group) - slice_start = rank * ln_out.size(0) - slice_end = (rank + 1) * ln_out.size(0) - ln_out = ln_out_total[slice_start:slice_end, ...] - else: - ln_out = ln_out_total + with_quantized_all_gather = False + if ub_overlap_ag: + ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer, False) + else: + if fp8: + if not isinstance(ln_out, QuantizedTensor): + fc1_input_quantizer.set_usage( + rowwise=True, columnwise=backwards_needs_fc1_input + ) + ln_out = fc1_input_quantizer(ln_out) + elif backwards_needs_fc1_input: + ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) + # here ln_out is in FP8 low precision, the cast was either done by fc1_input_quantizer + # or fused into the layernorm kernel + # ln_out_total represents the full fp8 tensor, in this case, it's the same as ln_out + ln_out_total = ln_out - if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - fc1_bias = cast_if_needed(fc1_bias, bias_dtype) if use_fc1_bias else fc1_bias - fc2_bias = cast_if_needed(fc2_bias, bias_dtype) if use_fc2_bias else fc2_bias - - # Use FP8 weights - if fc1_weight_fp8 is None: - fc1_weight_fp8 = fc1_weight - if fc2_weight_fp8 is None: - fc2_weight_fp8 = fc2_weight - - assert isinstance(fc1_weight_fp8, Float8Tensor) - assert isinstance(fc2_weight_fp8, Float8Tensor) - - # Perform FP8 GEMM - fp8_gemm_args = [ - fc1_weight_fp8._data, - fc1_weight_fp8._scale_inv, - 0, - fc1_weight_fp8._fp8_dtype, - ln_out_total, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - activation_dtype, - get_workspace(), - ] - fp8_gemm_kwargs = { - "bias": fc1_bias, - "use_bias": use_fc1_bias, - "use_split_accumulator": _2X_ACC_FPROP, - "ub_algo": ub_algo_ag if ub_overlap_ag else None, - "ub": ub_obj_lnout if ub_overlap_ag else None, - "extra_output_tensor": ln_out if ub_overlap_ag else None, - } - if gemm_gelu_fusion: - fp8_gemm_args[8] = torch.uint8 # out_dtype - fp8_gemm_kwargs.update( - { - "gelu": True, - "out_index": tex.FP8FwdTensors.GEMM2_INPUT, - "fp8_meta_tensor": fp8_meta["scaling_fwd"], - "D_dtype": fp8_dtype_forward, - } + # Cast weights to expected dtype + fc1_weight_final = fc1_weight + fc2_weight_final = fc2_weight + if not fp8: + fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype) + fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype) + else: + # If weights are not quantized, we call get_weight_workspace, + # which handles weight caching etc. + if not isinstance(fc1_weight, QuantizedTensor): + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + fc1_weight_final = module.get_weight_workspace( + tensor=fc1_weight, + quantizer=fc1_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc1_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, ) - fp8_gemm_out = tex.fp8_gemm(*fp8_gemm_args, **fp8_gemm_kwargs) - if not is_grad_enabled: - clear_tensor_data(ln_out_total) - - # Perform activation - if gemm_gelu_fusion: - gelu_out, fc1_out = fp8_gemm_out - else: - fc1_out, _ = fp8_gemm_out - gelu_out = activation_func( - fc1_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, + if not isinstance(fc2_weight, QuantizedTensor): + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True) + fc2_weight_final = module.get_weight_workspace( + tensor=fc2_weight, + quantizer=fc2_weight_quantizer, + cache_name=(None if is_first_microbatch is None else "fc2_weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, ) - if not is_grad_enabled: - clear_tensor_data(fc1_out) - - fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = ( - None, - None, - None, - activation_dtype, - ) - rs_out = None - ub_algo_rs = None - if ub_overlap_rs: - ub_obj_fc2out = get_ub("fc2_fprop") - fc2_out = ub_obj_fc2out.get_ubuf_output(1) - dim_size = list(gelu_out.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc2_weight_fp8.size(0) - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - if ub_obj_fc2out.is_p2p_overlap(): - if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj_fc2out.is_atomic_gemm(): - ub_algo_rs = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - - if ub_obj_fc2out.is_fp8_ubuf(): - fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT - fc2_meta_tensor = fp8_meta["scaling_fwd"] - fc2_te_type = fp8_dtype_forward - out_type = torch.uint8 - ub_obj_fc2out.set_ubuf_scale_inv(fc2_meta_tensor.scale_inv[fc2_out_index]) + # Cast biases to expected dtype + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + if fc1_bias is not None: + fc1_bias = cast_if_needed(fc1_bias, bias_dtype) + if fc2_bias is not None: + fc2_bias = cast_if_needed(fc2_bias, bias_dtype) + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if fc1_input_quantizer is not None: + fc1_input_quantizer.calibrate(ln_out_total) + if fc1_weight_quantizer is not None: + fc1_weight_quantizer.calibrate(fc1_weight) + + # FC1 GEMM + + # There are 2 fussions possible: + # - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion, + # - bias_gelu_fusion - only for full precision. + # If both gemm_gelu_fusion and bias_gelu_fusion are enabled, only bias_gelu_fusion will be performer + if activation != "gelu": + gemm_gelu_fusion = bias_gelu_fusion = False + else: + if fp8: + assert not bias_gelu_fusion, "Bias gelu fusion is supported only for full precision" else: - dim_size = list(gelu_out.size()) - dim_size[1] = fc2_weight_fp8.size(0) - fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - - _ = tex.fp8_gemm( - fc2_weight_fp8._data, - fc2_weight_fp8._scale_inv, - 0, - fc2_weight_fp8._fp8_dtype, - gelu_out, - fp8_meta["scaling_fwd"].scale_inv, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, - out_type, - get_workspace(), - bias=fc2_bias, - use_bias=use_fc2_bias, - use_split_accumulator=_2X_ACC_FPROP, - out=fc2_out, - ub_algo=ub_algo_rs if ub_overlap_rs else None, - ub=ub_obj_fc2out if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - out_index=fc2_out_index, - fp8_meta_tensor=fc2_meta_tensor, - D_dtype=fc2_te_type, - ) - if not is_grad_enabled: - clear_tensor_data(gelu_out) + gemm_gelu_fusion = True + if gemm_gelu_fusion and bias_gelu_fusion: + gemm_gelu_fusion = False + + fc1_outputs = general_gemm( + fc1_weight_final, + ln_out_total, + get_workspace(), + quantization_params=( + fc2_input_quantizer if gemm_gelu_fusion else None # fused gelu output is in fp8 + ), + out_dtype=activation_dtype, + bias=( + fc1_bias if not bias_gelu_fusion else None + ), # otherwise bias is added later (fused with gelu) + gelu=gemm_gelu_fusion, + accumulate=_2X_ACC_FPROP, + ub=ub_obj_lnout, + ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None, + ) + if not is_grad_enabled and (ln_out_total is not ln_out_return): + clear_tensor_data(ln_out_total) + + # ACTIVATION - sometimes activation is fused with the GEMM above. + + fc1_out_without_bias = None + + if bias_gelu_fusion: + fc1_out = None + fc1_out_without_bias, *_ = fc1_outputs + act_out = bias_gelu_fused(fc1_out_without_bias, fc1_bias) + elif gemm_gelu_fusion: + act_out, _, fc1_out, _ = fc1_outputs else: - # Cast for native AMP - fc1_weight = cast_if_needed(fc1_weight, activation_dtype) - fc2_weight = cast_if_needed(fc2_weight, activation_dtype) - fc1_bias = cast_if_needed(fc1_bias, activation_dtype) if use_fc1_bias else fc1_bias - fc2_bias = cast_if_needed(fc2_bias, activation_dtype) if use_fc2_bias else fc2_bias - - if fp8_calibration: - # amax of fc1 input - amin, amax = ln_out_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of fc1 weight - amin, amax = fc1_weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() - - fc1_outputs = tex.gemm( - fc1_weight, - ln_out_total, - activation_dtype, - get_workspace(), - bias=fc1_bias, - use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, - gelu=not bias_gelu_nvfusion and (activation == "gelu"), - ub_algo=tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, - ub=ub_obj_lnout if ub_overlap_ag else None, - extra_output_tensor=ln_out if ub_overlap_ag else None, - ) - if not is_grad_enabled: - clear_tensor_data(ln_out_total) + fc1_out, *_ = fc1_outputs + act_out = activation_func(fc1_out, fc2_input_quantizer) - if bias_gelu_nvfusion: - fc1_out, _, _ = fc1_outputs - gelu_out = bias_gelu_fused(fc1_out, fc1_bias) - else: - if activation == "gelu": - gelu_out, _, fc1_out = fc1_outputs - else: - fc1_out, _, _ = fc1_outputs - gelu_out = activation_func( - fc1_out, None, tex.FP8FwdTensors.GEMM2_INPUT, TE_DType[fc1_out.dtype] - ) - if not is_grad_enabled: - clear_tensor_data(fc1_out) - - if fp8_calibration: - # amax of fc2 input - amin, amax = gelu_out.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_INPUT] = torch.max( - -amin, amax - ).float() - # amax of fc2 weight - amin, amax = fc2_weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = torch.max( - -amin, amax - ).float() - - if ub_overlap_rs: - ub_obj_fc2out = get_ub("fc2_fprop") - fc2_out = ub_obj_fc2out.get_ubuf_output(1) - dim_size = list(gelu_out.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc2_weight.size(0) - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - if ub_obj_fc2out.is_p2p_overlap(): - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo_rs = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - else: - dim_size = list(gelu_out.size()) - dim_size[1] = fc2_weight.size(0) - fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - _ = tex.gemm( - fc2_weight, - gelu_out, - activation_dtype, - get_workspace(), - bias=fc2_bias, - use_bias=use_fc2_bias, - out=fc2_out, - ub_algo=ub_algo_rs if ub_overlap_rs else None, - ub=ub_obj_fc2out if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - ) - if not is_grad_enabled: - clear_tensor_data(gelu_out) + if not is_grad_enabled: + clear_tensor_data(fc1_out) + + if fp8_calibration: + fc2_input_quantizer.calibrate(act_out) + fc2_weight_quantizer.calibrate(fc2_weight) + + ub_obj_fc2out = None + rs_out = None + fc2_out = None + if ub_overlap_rs: + ub_obj_fc2out = get_ub("fc2_fprop") + dim_size = list(act_out.size()) + dim_size[0] = dim_size[0] // tp_world_size + dim_size[1] = fc2_weight.size(0) + rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device) + fc2_out = ub_obj_fc2out.get_buffer(output_quantizer) + else: + dim_size = list(act_out.size()) + dim_size[1] = fc2_weight.size(0) + fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=device) + + # FC2 GEMM + _ = general_gemm( + fc2_weight_final, + act_out, + get_workspace(), + out_dtype=activation_dtype, + bias=fc2_bias, + quantization_params=output_quantizer, + out=fc2_out, + use_split_accumulator=_2X_ACC_FPROP, + ub=ub_obj_fc2out, + ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None, + extra_output=rs_out, + ) + if not is_grad_enabled: + clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) if is_grad_enabled: if cpu_offloading: - if fp8 and fc1_weight_fp8 is not None: - fc1_weight_fp8.weight_offloading = True - if fp8 and fc2_weight_fp8 is not None: - fc2_weight_fp8.weight_offloading = True - ln_weight.weight_offloading = True - fc1_weight.weight_offloading = True - fc2_weight.weight_offloading = True - if fc1_bias is not None: - fc1_bias.weight_offloading = True - - inputmat.activation_offloading = True - if normalization == "LayerNorm": - mu.activation_offloading = True - rsigma.activation_offloading = True - ln_out.activation_offloading = True - fc1_out.activation_offloading = True - gelu_out.activation_offloading = True + if fp8 and fc1_weight_final is not None: + set_offloading_param(fc1_weight_final, "weight_offloading", True) + if fp8 and fc2_weight_final is not None: + set_offloading_param(fc2_weight_final, "weight_offloading", True) + set_offloading_param(ln_weight, "weight_offloading", True) + set_offloading_param(fc1_weight, "weight_offloading", True) + set_offloading_param(fc2_weight, "weight_offloading", True) + set_offloading_param(fc1_bias, "weight_offloading", True) + + set_offloading_param(inputmat, "activation_offloading", True) + set_offloading_param(mu, "activation_offloading", True) + set_offloading_param(rsigma, "activation_offloading", True) + set_offloading_param(mu, "activation_offloading", True) + set_offloading_param(ln_out, "activation_offloading", True) + set_offloading_param(fc1_out, "activation_offloading", True) + set_offloading_param(fc1_out_without_bias, "activation_offloading", True) + set_offloading_param(act_out, "activation_offloading", True) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -461,45 +482,69 @@ def forward( mu, rsigma, ln_out, - fc1_out, - gelu_out, - fc1_weight_fp8 if fp8 and not isinstance(fc1_weight, Float8Tensor) else None, - fc2_weight_fp8 if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, + fc1_out_without_bias if bias_gelu_fusion else fc1_out, + act_out, + fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) else None, + fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None, ) - ctx.save_for_backward( + if not fc1_weight.requires_grad: + if not return_layernorm_output: + clear_tensor_data(ln_out) + ln_out = None + if not fc2_weight.requires_grad: + clear_tensor_data(act_out) + act_out = None + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, + ln_out.clone() if ub_overlap_ag else ln_out, # avoid saving a UB buffer + fc1_weight_final, + fc1_bias, + fc1_out, + fc1_out_without_bias, + act_out, + fc2_weight_final, + fc2_bias, mu, rsigma, - ln_out if fc1_weight.requires_grad else None, - fc1_out, - gelu_out if fc2_weight.requires_grad else None, - fc1_weight, - fc1_weight_fp8, - fc1_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None, - fc2_weight, - fc2_weight_fp8, - fc2_weight.main_grad if (cpu_offloading and fuse_wgrad_accumulation) else None, - fc1_bias, - fp8_meta["scaling_fwd"].scale_inv.clone() if fp8 else None, ) + if fuse_wgrad_accumulation: + ctx.fc1_main_grad = fc1_weight.main_grad if fc1_weight.requires_grad else None + ctx.fc2_main_grad = fc2_weight.main_grad if fc2_weight.requires_grad else None + + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.grad_fc1_output_quantizer = grad_fc1_output_quantizer + ctx.grad_fc2_output_quantizer = grad_fc2_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer + ctx.fc2_input_quantizer = fc2_input_quantizer + ctx.fc1_input_quantizer = fc1_input_quantizer + + ctx.fc1_weight_requires_grad = fc1_weight.requires_grad + ctx.fc2_weight_requires_grad = fc2_weight.requires_grad + ctx.fc1_weight = fc1_weight + ctx.fc2_weight = fc2_weight + + ctx.device = device ctx.activation_dtype = activation_dtype ctx.activation = activation ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_fc1_bias = use_fc1_bias ctx.use_fc2_bias = use_fc2_bias + ctx.use_bias = ctx.use_fc1_bias ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp_shape ctx.tp_group = tp_group ctx.tp_size = tp_size - ctx.bias_gelu_nvfusion = bias_gelu_nvfusion + ctx.bias_gelu_fusion = bias_gelu_fusion ctx.return_layernorm_output = return_layernorm_output ctx.return_layernorm_output_gathered = ( return_layernorm_output_gathered and ln_out_gathered @@ -511,7 +556,10 @@ def forward( ctx.ub_bulk_dgrad = ub_bulk_dgrad ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad ctx.ub_overlap_ag = ub_overlap_ag - ctx.requires_dgrad = inp.requires_grad + + ctx.requires_dgrad = ( + inp.requires_grad or ln_weight.requires_grad or ln_bias.requires_grad + ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad( @@ -547,499 +595,380 @@ def backward( ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with torch.cuda.nvtx.range("_LayerNormMLP_backward"): - ( + if ( + ctx.fp8 + and any( + [ + ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad, + ] + ) + and (ctx.fp8_recipe is not None) + ): + if not ctx.fp8_recipe.float8_per_tensor_scaling(): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" + ) + + saved_tensors = ctx.saved_tensors + ( # pylint: disable=unbalanced-tuple-unpacking inputmat, ln_weight, - mu, - rsigma, ln_out, - fc1_out, - gelu_out, fc1_weight, - fc1_weight_fp8, - fc1_weight_main_grad, - fc2_weight, - fc2_weight_fp8, - fc2_weight_main_grad, fc1_bias, - fwd_scale_inverses, - ) = ctx.saved_tensors - - # Gather saved autograd context tensors when running with FSDP - # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already - # shards/unshards the base weights so we don't do it ourselves - _fsdp_gather_tensors( - ctx.fsdp_group, - ctx.fsdp_shapes, + fc1_out, + fc1_out_without_bias, + act_out, + fc2_weight, + fc2_bias, mu, rsigma, - ln_out, - fc1_out, - gelu_out, - fc1_weight_fp8 if ctx.fp8 and not isinstance(fc1_weight, Float8Tensor) else None, - fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None, + ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None + + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + fc1_weight_main_grad = ( + ctx.fc1_main_grad + if fc1_weight is not None + and ctx.fuse_wgrad_accumulation + and ctx.fc1_weight_requires_grad + else None + ) + fc2_weight_main_grad = ( + ctx.fc2_main_grad + if fc2_weight is not None + and ctx.fuse_wgrad_accumulation + and ctx.fc2_weight_requires_grad + else None ) - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - fc1_weight = Parameter(fc1_weight, fc1_weight.requires_grad) - fc2_weight = Parameter(fc2_weight, fc2_weight.requires_grad) - + # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, + # we need to connect them into one. + if ctx.fuse_wgrad_accumulation: fc1_weight.main_grad = fc1_weight_main_grad fc2_weight.main_grad = fc2_weight_main_grad - activation_func = _act_func(ctx.activation)[1] - - if ctx.ub_overlap_rs_dgrad: - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_rs_dgrad = False - if ctx.ub_bulk_dgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not fc1_weight.requires_grad: - ctx.ub_bulk_dgrad = False - if ctx.ub_bulk_dgrad: - dim_size = list(ln_out.size()) - dim_size[0] = dim_size[0] * tp_world_size - ub_obj_lnout = get_ub("fc1_dgrad") - ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) - if ctx.ub_overlap_ag: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_overlap_ag = False - - ub_algo = None - if ctx.ub_overlap_ag: - dim_size = list(grad_outputs[0].size()) - dim_size[0] = dim_size[0] * tp_world_size - ctx.ub_obj_gradout = get_ub("fc2_dgrad") - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P + # TODO: Fix this # pylint: disable=fixme + # Gather saved autograd context tensors when running with FSDP + # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already + # shards/unshards the base weights so we don't do it ourselves + # _fsdp_gather_tensors( + # ctx.fsdp_group, + # ctx.fsdp_shapes, + # mu, + # rsigma, + # ln_out, + # fc1_out_without_bias if bias_gelu_nvfusion else fc1_out,, + # gelu_out, + # fc1_weight_fp8 if ctx.fp8 and not isinstance(fc1_weight, Float8Tensor) else None, + # fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None, + # ) + + # No need to do bulk DGRAD/WGRAD overlap if WGRAD is not required + ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad + ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad + + # Prepare grad output tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + if ctx.grad_fc2_output_quantizer is not None: + # Reduce duplicated transpose, which is performed in grad_output.update_usage + if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling(): + ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=False) else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ctx.grad_fc2_output_quantizer.set_usage(rowwise=True, columnwise=True) - ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess + ub_obj_fc2_dgrad = None + if ctx.ub_overlap_ag: + ub_obj_fc2_dgrad = get_ub("fc2_dgrad") + ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, - grad_output_c, - grad_output_t, fc2_bias_grad, - ) = TransformerEngineBaseModule.grad_output_preprocess(ctx, grad_outputs[0], True) - - if ctx.ub_bulk_wgrad: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1 or not fc1_weight.requires_grad: - ctx.ub_bulk_wgrad = False - # Column Parallel Linear - # Overlap input AG with dgrad + ) = TransformerEngineBaseModule.grad_output_preprocess( + ctx, grad_outputs[0], True, ctx.grad_fc2_output_quantizer + ) + + # Prepare FC1 GEMM input + # Note: Perform tensor-parallel communication if needed + ln_out_total = None + ln_out_total_work = None if ( - fc1_weight.requires_grad - and (not ctx.ub_bulk_dgrad) - and ctx.set_parallel_mode + ctx.fc1_weight_requires_grad + and ctx.tensor_parallel and ctx.sequence_parallel + and not ctx.ub_bulk_dgrad ): - ln_out_total, handle = gather_along_first_dim(ln_out, ctx.tp_group, async_op=True) + quantizer = None + if ctx.fp8: + quantizer = ctx.fc1_input_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + ln_out_total, ln_out_total_work = gather_along_first_dim( + ln_out, + ctx.tp_group, + async_op=True, + quantizer=quantizer, + ) else: ln_out_total = ln_out - handle = None + # Check whether to output wgrad GEMM directly into main grad if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch ) else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + # There are 5 possible fusion paths + # 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu, + # 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize + # 3 fp8 activation+bias+quantize fusion: gemm, activation + FC1_bias + quantize + # 4 fp8 bias+quantize fusion: gemm, activation, FC1_bias + quantize + # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm + fc2_dgrad_gemm_gelu_fusion = ( + not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) + ) - fc2_wgrad = None - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - - # FC2 DGRAD; Unconditional - fc2_dgrad, _ = tex.fp8_gemm( - fc2_weight_fp8.transpose_2d(), - fc2_weight_fp8._scale_inv, - 0, - fc2_weight_fp8._fp8_dtype, - grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo if ctx.ub_overlap_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, - ) - if ctx.ub_overlap_ag: - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) - clear_tensor_data(grad_output_c) - - # FC2 WGRAD - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if fc2_weight.requires_grad: - gelu_out_t = tex.fp8_transpose(gelu_out, fp8_dtype_forward) - clear_tensor_data(gelu_out) - fc2_wgrad, _ = tex.fp8_gemm( - gelu_out_t, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, - grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ) - clear_tensor_data(gelu_out_t, grad_output_t) - - if ctx.activation == "gelu": - fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_dgelu_fused( - fc2_dgrad, - fc1_out, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ) - else: - dgelu = activation_func(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) - fc1_bias_grad, dgelu, dgelu_t = tex.fp8_cast_transpose_bgrad_fused( - dgelu, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ) - clear_tensor_data(fc1_out) - else: - if fc2_weight.requires_grad: - gelu_out_c = torch.ops.tex_ts.cast_from_fp8_ts( - gelu_out, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM2_INPUT, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - clear_tensor_data(gelu_out) - fc2_wgrad, _, _ = tex.gemm( - gelu_out_c, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=False, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ) - clear_tensor_data(gelu_out_c) - - if ctx.activation == "gelu": - fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused( - fc2_dgrad, fc1_out, fc1_bias - ) - else: - dgelu_no_fp8 = activation_func( - fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype] - ) - fc1_bias_grad = dgelu_no_fp8.sum(dim=0) - clear_tensor_data(fc1_out) - - dgelu = tex.cast_to_fp8( - dgelu_no_fp8, - ctx.fp8_meta["scaling_bwd"], - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ) - dgelu_t = None + # FC2 DGRAD; Unconditional + gemm_output, *_ = general_gemm( + fc2_weight, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=( + ctx.grad_fc1_output_quantizer if fc2_dgrad_gemm_gelu_fusion else None + ), # high precision to activation + out_dtype=ctx.activation_dtype, + gelu=fc2_dgrad_gemm_gelu_fusion, + gelu_in=fc1_out if fc2_dgrad_gemm_gelu_fusion else None, + use_split_accumulator=_2X_ACC_DGRAD, + ub=ub_obj_fc2_dgrad, + ub_type=tex.CommOverlapType.AG if ctx.ub_overlap_ag else None, + ) + if fc2_dgrad_gemm_gelu_fusion: + dact = gemm_output + fc2_dgrad = None + else: + fc2_dgrad = gemm_output - out_index, meta_tensor, out_te_type, out_type = ( - None, - None, - None, - ctx.activation_dtype, - ) - fc1_dgrad_size = list(dgelu.size()) - fc1_dgrad_size[1] = fc1_weight.size(1) - # Get/alloc fc1_dgrad - if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub("fc1_wgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - elif ctx.ub_overlap_rs_dgrad: - ub_obj_dgrad = get_ub("fc1_dgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - else: - fc1_dgrad = torch.empty( - fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device - ) + # FC2 WGRAD + if ctx.fc2_weight_requires_grad: + if isinstance(act_out, QuantizedTensor): + act_out.update_usage(rowwise_usage=True, columnwise_usage=True) - # FP8 RS - if (ctx.ub_bulk_wgrad or ctx.ub_overlap_rs_dgrad) and ub_obj_dgrad.is_fp8_ubuf(): - out_index = tex.FP8BwdTensors.GRAD_INPUT2 - meta_tensor = ctx.fp8_meta["scaling_bwd"] - out_te_type = fp8_dtype_backward - out_type = torch.uint8 - ub_obj_dgrad.set_ubuf_scale_inv(meta_tensor.scale_inv[out_index]) + if isinstance(grad_output, QuantizedTensor): + grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) - # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap - rs_out = None - if ctx.ub_bulk_dgrad: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - ub_obj = ub_obj_lnout - elif ctx.ub_overlap_rs_dgrad: - dim_size = list(dgelu.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc1_weight_fp8.size(1) - rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) - if ub_obj_dgrad.is_p2p_overlap(): - if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj_dgrad.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ub_obj = ub_obj_dgrad - else: - ub_algo = None - ub_obj = None - # FC1 DGRAD: Unconditional - _ = tex.fp8_gemm( - fc1_weight_fp8.transpose_2d(), - fc1_weight_fp8._scale_inv, - 0, - fc1_weight_fp8._fp8_dtype, - dgelu, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - out_type, - get_workspace(), - out=fc1_dgrad, - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=out_te_type, - ) - else: - # FC2 DGRAD; Unconditional - fc2_dgrad, _, _ = tex.gemm( - fc2_weight, + fc2_wgrad, fc2_bias_grad_, *_ = general_gemm( + act_out, grad_output, - ctx.activation_dtype, get_workspace(), - layout="NN", - gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == "gelu"), + out_dtype=ctx.activation_dtype, + quantization_params=None, # wgrad in high precision + layout="NT", grad=True, - gelu_input=fc1_out, - ub_algo=( - tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ctx.ub_overlap_ag else None - ), - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, + bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, + accumulate=accumulate_wgrad_into_param_main_grad, + use_split_accumulator=_2X_ACC_WGRAD, + out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, ) + if fc2_bias_grad is None: + fc2_bias_grad = fc2_bias_grad_ + del fc2_bias_grad_ + clear_tensor_data(act_out) + + # bias computation + fc1_bias_grad = None + fuse_gemm_and_bias_fc1_wgrad = False + if ctx.grad_fc1_output_quantizer is not None: + ctx.grad_fc1_output_quantizer.set_usage(rowwise=True, columnwise=True) + if ctx.bias_gelu_fusion: + # Fusion: gemm, bias + gelu + assert ctx.activation == "gelu" + assert not ctx.fp8 + fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) + if ctx.grad_fc1_output_quantizer is not None: + dact = ctx.grad_fc1_output_quantizer(dact) + elif ( + _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None + and ctx.fp8 + ): + # Fusion: gemm, bias + gelu + quantize + dbias_dact_quantize_func = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[2] + fc1_bias_grad, dact = dbias_dact_quantize_func( + fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.grad_fc1_output_quantizer + ) # quantize bgrad gelu fused + else: + # Fusion: gemm + gelu, + if not fc2_dgrad_gemm_gelu_fusion: + activation_func_bwd = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[1] + dact = activation_func_bwd( + fc2_dgrad, fc1_out.to(ctx.activation_dtype), None + ) # activation in high precision - # FC2 WGRAD - if fc2_weight.requires_grad: - fc2_wgrad, fc2_bias_grad, _ = tex.gemm( - gelu_out, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_fc2_bias, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ) - clear_tensor_data(gelu_out) - - if ctx.bias_gelu_nvfusion and ctx.activation == "gelu": - fc1_bias_grad, fc2_dgrad = bgrad_dgelu_fused(fc2_dgrad, fc1_out, fc1_bias) - else: - if ctx.activation != "gelu": - fc2_dgrad = activation_func(fc2_dgrad, fc1_out, TE_DType[fc2_dgrad.dtype]) - - # For non-fp8 execution, FC1 bias gradient is fused with FC1 wgrad GEMM - # and will not be calculated in case wgrad is not required. - if not fc1_weight.requires_grad: - fc1_bias_grad = fc2_dgrad.sum(dim=0) - - # Overwrite data. Deleting the tensor does not release underlying memory. - clear_tensor_data(fc1_out) - dgelu = fc2_dgrad - - fc1_dgrad_size = list(dgelu.size()) - fc1_dgrad_size[1] = fc1_weight.size(1) - if ctx.ub_bulk_wgrad: # allocate dgrad output - ub_obj_dgrad = get_ub("fc1_wgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output - elif ctx.ub_overlap_rs_dgrad: - ub_obj_dgrad = get_ub("fc1_dgrad") - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(1) # AllGather output + if ctx.fp8: + fc1_bias_grad, dact = tex.bgrad_quantize(dact, ctx.grad_fc1_output_quantizer) else: - fc1_dgrad = torch.empty( - fc1_dgrad_size, dtype=ctx.activation_dtype, device=fc1_weight.device + fuse_gemm_and_bias_fc1_wgrad = ( + True # fc1_bias_grad is computed later, fused with wgrad gemm for the FC1 ) + # it may not be calculated in case wgrad is not required. + if fc1_bias is not None: + if not ctx.fc1_weight_requires_grad and fc1_bias.requires_grad: + fc1_bias_grad = dact.sum(dim=0) + + # Overwrite data. Deleting the tensor does not release underlying memory. + clear_tensor_data(fc1_out, fc1_out_without_bias) + + # Set UB algo and UB obj for fc1_dgrad/wgrad bulk/pipelined overlap + ub_obj_fc1_dgrad = None + ub_obj_fc1_wgrad = None + ub_type_fc1_dgrad = None + fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] + fc1_dgrad_rs_out = None + fc1_dgrad_bulk = None + if ctx.ub_overlap_rs_dgrad: + # Overlap DGRAD+RS + ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_type_fc1_dgrad = tex.CommOverlapType.RS + fc1_dgrad_rs_out = torch.empty( + fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" + ) - # Set UB algo and UB obj for fc1_dgrad bulk/pipelined overlap + else: if ctx.ub_bulk_dgrad: - ub_algo = tex.CommOverlapAlgo.BULK_OVERLAP_AG - ub_obj = ub_obj_lnout - elif ctx.ub_overlap_rs_dgrad: - dim_size = list(dgelu.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = fc1_weight.size(1) - rs_out = torch.empty(dim_size, dtype=ctx.activation_dtype, device=dgelu.device) - if ub_obj_dgrad.is_p2p_overlap(): - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - ub_obj = ub_obj_dgrad - else: - ub_algo = None - ub_obj = None - # FC1 DGRAD: Unconditional - _ = tex.gemm( - fc1_weight, - dgelu, - ctx.activation_dtype, - get_workspace(), - out=fc1_dgrad, - layout="NN", - grad=True, - ub_algo=ub_algo, - ub=ub_obj, - extra_output_tensor=rs_out if ctx.ub_overlap_rs_dgrad else None, - ) + # Overlap ln_out all-gather with DGRAD compute + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. + ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_type_fc1_dgrad = tex.CommOverlapType.AG + ub_obj_fc1_dgrad.copy_into_buffer( + ln_out, ctx.fc1_input_quantizer, local_chunk=True + ) + + if ctx.ub_bulk_wgrad: + # Overlap FC1 DGRAD reduce-scatter with WGRAD compute + ub_obj_fc1_wgrad = get_ub("fc1_wgrad") + fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None) + + # FC1 DGRAD: Unconditional + fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm( + fc1_weight, + dact, + get_workspace(), + out=fc1_dgrad_bulk, + out_dtype=ctx.activation_dtype, + layout="NN", + grad=True, + ub=ub_obj_fc1_dgrad, + ub_type=ub_type_fc1_dgrad, + extra_output=fc1_dgrad_rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) - if ctx.ub_bulk_dgrad: - ln_out_total = ub_obj_lnout.get_ubuf_output(1) # Overlap dgrad-RS/AR with wgrad - if ctx.set_parallel_mode and ctx.sequence_parallel: - if not ctx.ub_bulk_dgrad and handle is not None: - handle.wait() - if not ctx.ub_bulk_wgrad and not ctx.ub_overlap_rs_dgrad: + fc1_dgrad_work = None + if ctx.ub_overlap_rs_dgrad: + fc1_dgrad = fc1_dgrad_rs_out + elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad: + if ctx.sequence_parallel: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad) - fc1_dgrad, handle = reduce_scatter_along_first_dim( - fc1_dgrad, ctx.tp_group, async_op=True + fc1_dgrad, fc1_dgrad_work = reduce_scatter_along_first_dim( + fc1_dgrad, + ctx.tp_group, + async_op=True, ) - elif ctx.set_parallel_mode and ctx.tensor_parallel: - fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) + elif ctx.tensor_parallel: + fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) + # FC1 WGRAD fc1_wgrad = None - if fc1_weight.requires_grad: - if ctx.fp8: - # FC1 WGRAD - extra_output_tensor = None - if ctx.ub_bulk_wgrad: - if ub_obj_dgrad.is_fp8_ubuf(): - dim_size = list(ub_obj_dgrad.get_ubuf_output(0).size()) # RS output - extra_output_tensor = torch.empty( - dim_size, dtype=ctx.activation_dtype, device=fc1_dgrad.device - ) - fc1_dgrad = extra_output_tensor - else: - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - ln_out_total_t = tex.fp8_transpose(ln_out_total, fp8_dtype_forward) - fc1_wgrad, _ = tex.fp8_gemm( - ln_out_total_t, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - dgelu_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT2, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_t, dgelu_t) - else: - ln_out_total_c = torch.ops.tex_ts.cast_from_fp8_ts( - ln_out_total, - fwd_scale_inverses, - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - TE_DType[ctx.activation_dtype], - ) - fc1_wgrad, _, _ = tex.gemm( - ln_out_total_c, - dgelu_no_fp8, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=( - tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None - ), - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, - extra_output_tensor=extra_output_tensor, - ) - clear_tensor_data(ln_out_total_c, dgelu_no_fp8) + if ctx.fc1_weight_requires_grad: + if ctx.ub_bulk_dgrad: + ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer) + if ctx.fp8: + if ln_out._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size) + elif not non_tn_fp8_gemm_supported(): + # FP8 GEMM on Hopper only supports TN layout so the gathered input must + # have a valid transpose. + ln_out_total._create_transpose() + else: - # FC1 WGRAD - fc1_wgrad_outputs = tex.gemm( - ln_out_total, - dgelu, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=not ctx.bias_gelu_nvfusion, - accumulate=accumulate_wgrad_into_param_main_grad, - out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ub_algo=tex.CommOverlapAlgo.BULK_OVERLAP_RS if ctx.ub_bulk_wgrad else None, - ub=ub_obj_dgrad if ctx.ub_bulk_wgrad else None, + if ln_out_total_work is not None: + # Synchronize tensor-parallel communication + ln_out_total_work.wait() + ln_out_total_work = None + + # Make sure GEMM inputs have expected data + if isinstance(ln_out_total, QuantizedTensor): + ln_out_total.update_usage(rowwise_usage=True, columnwise_usage=True) + if isinstance(dact, QuantizedTensor): + dact.update_usage(rowwise_usage=True, columnwise_usage=True) + + if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf(): + fc1_dgrad_rs_out = torch.empty( + fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" ) - clear_tensor_data(ln_out_total, dgelu) - if ctx.bias_gelu_nvfusion: - fc1_wgrad, _, _ = fc1_wgrad_outputs - else: - fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs - if ctx.ub_bulk_wgrad: - fc1_dgrad = ub_obj_dgrad.get_ubuf_output(0) # Reduce-scatter output + fc1_wgrad_outputs = general_gemm( + ln_out_total, + dact, + get_workspace(), + out_dtype=ctx.activation_dtype, + layout="NT", + grad=fuse_gemm_and_bias_fc1_wgrad, + bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, + accumulate=accumulate_wgrad_into_param_main_grad, + out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + ub=ub_obj_fc1_wgrad, + ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None, + extra_output=fc1_dgrad_rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, + ) - # Column Parallel Linear - if ( - (not ctx.ub_bulk_wgrad) - and ctx.set_parallel_mode - and ctx.tensor_parallel - and handle is not None - ): - handle.wait() + clear_tensor_data(ln_out_total, dact) - # LayerNorm gradient - if ctx.ub_overlap_rs_dgrad: - dgrad = rs_out.view(inputmat.shape) - else: - dgrad = fc1_dgrad.view(inputmat.shape) + if fuse_gemm_and_bias_fc1_wgrad: + fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs + else: + fc1_wgrad, *_ = fc1_wgrad_outputs + + if ctx.ub_bulk_wgrad: + if ub_obj_fc1_wgrad.is_fp8_ubuf(): + fc1_dgrad = fc1_dgrad_rs_out + else: + fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True) + + # Synchronize tensor parallel communication + if ln_out_total_work is not None: + ln_out_total_work.wait() + ln_out_total_work = None + if fc1_dgrad_work is not None: + fc1_dgrad_work.wait() + fc1_dgrad_work = None # Residual gradient + dgrad = fc1_dgrad.view(inputmat.shape) if ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) + # Norm gradient dgamma = None dbeta = None if ctx.normalization == "LayerNorm": @@ -1062,10 +991,9 @@ def backward( ctx.zero_centered_gamma, ) dbeta = None - clear_tensor_data(mu) - clear_tensor_data(rsigma) + clear_tensor_data(mu, rsigma) - if fc1_weight.requires_grad: + if ctx.fc1_weight_requires_grad: # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"): fc1_weight.grad_added_to_main_grad = True @@ -1077,18 +1005,13 @@ def backward( requires_grad=False, ) else: - fc1_wgrad = torch.empty( - fc1_weight.main_grad.shape, - dtype=fc1_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + fc1_wgrad = None elif ctx.fuse_wgrad_accumulation: fc1_wgrad = None else: fc1_wgrad = None - if fc2_weight.requires_grad: + if ctx.fc2_weight_requires_grad: # Handle custom DDP from mcore. if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"): fc2_weight.grad_added_to_main_grad = True @@ -1100,12 +1023,7 @@ def backward( requires_grad=False, ) else: - fc2_wgrad = torch.empty( - fc2_weight.main_grad.shape, - dtype=fc2_weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, - ) + fc2_wgrad = None elif ctx.fuse_wgrad_accumulation: fc2_wgrad = None else: @@ -1114,34 +1032,37 @@ def backward( if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + # FIX THIS # Scatter Fp8 tranposed-weight buffers - if ctx.fp8: - _fsdp_scatter_tensors( - ctx.fsdp_group, - fc1_weight_fp8 if not isinstance(fc1_weight, Float8Tensor) else None, - fc2_weight_fp8 if not isinstance(fc2_weight, Float8Tensor) else None, - ) - + # if ctx.fp8: + # _fsdp_scatter_tensors( + # ctx.fsdp_group, + # fc1_weight_fp8 if not isinstance(fc1_weight, Float8Tensor) else None, + # fc2_weight_fp8 if not isinstance(fc2_weight, Float8Tensor) else None, + # ) return ( dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, dgamma, dbeta, fc1_wgrad, - None, # fc1_weight_fp8 - # Due to bias gelu nvfusion available in the bf16 case, fc1_bias_grad is calculated at - # different paths and this confused the linter. - fc1_bias_grad if ctx.use_fc1_bias else None, # pylint: disable=used-before-assignment + fc1_bias_grad if ctx.use_fc1_bias else None, None, # use_fc1_bias - fc2_wgrad, - None, # fc2_weight_fp8 + fc2_wgrad, # pylint: disable=possibly-used-before-assignment fc2_bias_grad if ctx.use_fc2_bias else None, None, # use_fc2_bias None, # eps None, # is_first_microbatch None, # fp8 None, # fp8_calibration - None, # fp8_meta None, # fuse_wgrad_accumulation + None, # fc1_input_quantizer + None, # fc1_weight_quantizer + None, # fc2_input_quantizer + None, # fc2_weight_quantizer + None, # output_quantizer + None, # grad_fc2_output_quantizer + None, # grad_fc1_output_quantizer + None, # grad_input_quantizer None, # cpu_offloading None, # tp_group None, # tp_size @@ -1150,7 +1071,7 @@ def backward( None, # activation_dtype None, # return_layernorm_output None, # return_layernorm_output_gathered - None, # bias_gelu_nvfusion + None, # bias_gelu_fusion None, # set_parallel_mode None, # is_grad_enabled None, # fwd_ln_sm_margin @@ -1158,13 +1079,15 @@ def backward( None, # zero_centered_gamma None, # activation None, # normalization - None, # ub_bulk_wgrad - None, # ub_bulk_dgrad - None, # ub_overlap_rs_dgrad - None, # ub_overlap_rs None, # ub_overlap_ag + None, # ub_overlap_rs + None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # gemm_gelu_fusion None, # fsdp_group + None, # module + None, # skip_fp8_weight_update ) @@ -1285,11 +1208,11 @@ def __init__( set_parallel_mode: bool = False, zero_centered_gamma: bool = False, device: Union[torch.device, str] = "cuda", - ub_bulk_wgrad: bool = False, - ub_bulk_dgrad: bool = False, - ub_overlap_rs_dgrad: bool = False, - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, ) -> None: super().__init__() @@ -1308,11 +1231,7 @@ def __init__( ) self.set_parallel_mode = set_parallel_mode self.zero_centered_gamma = zero_centered_gamma - self.ub_bulk_wgrad = ub_bulk_wgrad - self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - self.ub_overlap_rs = ub_overlap_rs - self.ub_overlap_ag = ub_overlap_ag + # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap self.gemm_gelu_fusion = ( bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) @@ -1337,6 +1256,16 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel self.size_per_partition = divide(ffn_hidden_size, self.tp_size) + self.ub_overlap_ag = ub_overlap_ag and self.sequence_parallel + self.ub_overlap_rs = ub_overlap_rs and self.sequence_parallel + self.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad and self.sequence_parallel + self.ub_bulk_wgrad = ( + ub_bulk_wgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + ub_bulk_dgrad and self.sequence_parallel and not self.ub_overlap_rs_dgrad + ) + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -1357,7 +1286,7 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["reglu", "geglu", "swiglu"]: + if self.activation in ["reglu", "geglu", "qgeglu", "swiglu"]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition @@ -1426,6 +1355,15 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + if FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + # elif for other recipes (mxfp8, etc.) + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1484,65 +1422,37 @@ def forward( produced) """ - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, is_first_microbatch, num_gemms=2) as inp: + with self.prepare_forward(inp, num_gemms=2) as inp: + # Get quantizers + ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, + grad_input_quantizer, + ) = self._get_quantizers() # Get weight tensors fc1_weight = self.fc1_weight - fc1_bias = self.fc1_bias + fc1_bias = self.fc1_bias if self.use_bias else None fc2_weight = self.fc2_weight - fc2_bias = self.fc2_bias + fc2_bias = self.fc2_bias if self.use_bias else None if not self.fp8: if isinstance(fc1_weight, Float8Tensor): fc1_weight = fc1_weight.from_float8() if isinstance(fc2_weight, Float8Tensor): fc2_weight = fc2_weight.from_float8() - # Cast weights to FP8 if needed - fc1_weight_fp8 = None - fc2_weight_fp8 = None - if self.fp8: - update_workspace = is_first_microbatch is None or is_first_microbatch - if isinstance(fc1_weight, Float8Tensor): - if fc1_weight._transpose is not None: - fc1_weight.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - cache_name = None - if is_first_microbatch is not None: - cache_name = "fc1_weight" - fc1_weight_fp8 = self.get_fp8_workspace( - tensor=fc1_weight, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - cache_name=cache_name, - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) - if isinstance(fc2_weight, Float8Tensor): - if fc2_weight._transpose is not None: - fc2_weight.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - cache_name = None - if is_first_microbatch is not None: - cache_name = "fc2_weight" - fc2_weight_fp8 = self.get_fp8_workspace( - tensor=fc2_weight, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM2_WEIGHT, - cache_name=cache_name, - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - ) - # Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute(): self.bias_gelu_nvfusion = False @@ -1558,19 +1468,24 @@ def forward( self.layer_norm_weight, self.layer_norm_bias, fc1_weight, - fc1_weight_fp8, fc1_bias, self.use_bias, fc2_weight, - fc2_weight_fp8, fc2_bias, self.apply_bias and not self.gemm_bias_unfused_add, self.eps, is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, self.fuse_wgrad_accumulation, + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, is_cpu_offload_enabled(), self.tp_group, self.tp_size, @@ -1579,7 +1494,7 @@ def forward( self.activation_dtype, self.return_layernorm_output, self.return_layernorm_output_gathered, - self.bias_gelu_nvfusion, + self.bias_gelu_nvfusion and not self.fp8, self.set_parallel_mode, torch.is_grad_enabled(), self.fwd_ln_sm_margin if torch.is_grad_enabled() else self.inf_ln_sm_margin, @@ -1587,13 +1502,15 @@ def forward( self.zero_centered_gamma, self.activation, self.normalization, - self.ub_bulk_wgrad, - self.ub_bulk_dgrad, - self.ub_overlap_rs_dgrad, - self.ub_overlap_rs, self.ub_overlap_ag, + self.ub_overlap_rs, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, self.gemm_gelu_fusion, self.fsdp_group, + self, + skip_fp8_weight_update, ) out = fwd_fn(*args) @@ -1610,3 +1527,121 @@ def forward( if self.return_layernorm_output: return out, ln_out return out + + def _get_quantizers(self): + ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, + grad_input_quantizer, + ) = [None] * 8 + if self.fp8: + fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + fc1_input_quantizer.internal = False # temporary + fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + fc1_weight_quantizer.internal = True + fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] + fc2_input_quantizer.set_usage( + rowwise=True, columnwise=isinstance(fc2_input_quantizer, MXFP8Quantizer) + ) + fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] + fc2_weight_quantizer.internal = True + if torch.is_grad_enabled(): + grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ] + grad_fc2_output_quantizer.internal = True + grad_fc1_output_quantizer = self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_INPUT1 + ] + grad_fc1_output_quantizer.internal = True + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT2] + grad_input_quantizer.internal = True + + return ( + fc1_input_quantizer, + fc1_weight_quantizer, + fc2_input_quantizer, + fc2_weight_quantizer, + output_quantizer, + grad_fc1_output_quantizer, + grad_fc2_output_quantizer, + grad_input_quantizer, + ) + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_mlp.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # fc1_input_quantizer: set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # fc2_input_quantizer + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # fc1_weight_quantizer: also set numerical configs about weight + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # fc2_weight_quantizer + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM2_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # parallel related + if self.sequence_parallel and self.set_parallel_mode: + # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_size = self.tp_size + else: + # grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + # grad_fc1_output_quantizer: also set numerical configs for grad_fc1_output_quantizer + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_INPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_INPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + if self.sequence_parallel and self.set_parallel_mode: + # grad_fc2_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_size = self.tp_size diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1fed467210..f96355a678 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1,14 +1,17 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Linear API""" -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, Optional, Tuple, Union +from functools import reduce +from operator import mul as multiply_op import torch import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe from .base import ( get_workspace, get_ub, @@ -17,14 +20,17 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ._common import _noop_cat -from ..fp8 import get_fp8_te_dtype, FP8GlobalStateManager +from ._common import noop_cat, _fix_gathered_fp8_transpose +from ..fp8 import FP8GlobalStateManager from ..utils import ( - divide, cast_if_needed, - assert_dim_for_fp8_exec, clear_tensor_data, + divide, init_method_constant, + non_tn_fp8_gemm_supported, + assert_dim_for_fp8_exec, + nvtx_range_pop, + nvtx_range_push, requires_grad, ) from ..distributed import ( @@ -33,23 +39,26 @@ allreduce, reduce_scatter_along_first_dim, gather_along_first_dim, + is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, ) from ..cpp_extensions import ( - fp8_gemm, - gemm, - fp8_cast_transpose_fused, - cast_to_fp8, + general_gemm, ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..float8_tensor import Float8Tensor -from ..export import is_in_onnx_export_mode -from ..tensor import QuantizedTensor -from ..cpu_offload import is_cpu_offload_enabled +from ..tensor.quantized_tensor import ( + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase + +from ..cpu_offload import is_cpu_offload_enabled, set_offloading_param __all__ = ["Linear"] @@ -62,15 +71,17 @@ class _Linear(torch.autograd.Function): @staticmethod def forward( ctx, - weight: Union[Float8Tensor, torch.Tensor], - weight_fp8: Optional[Float8Tensor], + weight: torch.Tensor, inp: torch.Tensor, - bias: torch.Tensor, - use_bias: bool, + bias: Optional[torch.Tensor], is_first_microbatch: Union[bool, None], fp8: bool, fp8_calibration: bool, - fp8_meta: Dict[str, Any], + input_quantizer: Optional[Quantizer], + weight_quantizer: Optional[Quantizer], + output_quantizer: Optional[Quantizer], + grad_output_quantizer: Optional[Quantizer], + grad_input_quantizer: Optional[Quantizer], fuse_wgrad_accumulation: bool, cpu_offloading: bool, tp_group: Union[dist_group_type, None], @@ -80,275 +91,255 @@ def forward( activation_dtype: torch.dtype, parallel_mode: Union[str, None], is_grad_enabled: bool, - ub_overlap_rs: bool, - ub_overlap_ag: bool, + ub_overlap_rs_fprop: bool, + ub_overlap_ag_dgrad: bool, + ub_overlap_ag_fprop: bool, + ub_overlap_rs_dgrad: bool, + ub_bulk_dgrad: bool, + ub_bulk_wgrad: bool, ub_name: str, - fp8_output: bool, + fp8_output: bool, # pylint: disable=unused-argument fsdp_group: Union[dist_group_type, None], + module: torch.nn.Module, + skip_fp8_weight_update: bool, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - is_input_fp8 = isinstance(inp, Float8Tensor) + + # NVTX label for profiling + nvtx_label = "transformer_engine._Linear.forward" + if ub_name is not None: + nvtx_label = f"{nvtx_label}.{ub_name}" # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape assert inp_shape[-1] == in_features, "GEMM not possible" - inputmat = inp.view(-1, in_features) - if fp8: - assert_dim_for_fp8_exec(inputmat) - assert_dim_for_fp8_exec(weight) tp_world_size = get_distributed_world_size(tp_group) - ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs - - # Cast input to expected dtype - inputmat = cast_if_needed(inputmat, activation_dtype) - inputmat_t = None - inputmat_no_fp8 = inputmat - inputmat_scale_inv = None + backward_needs_input = is_grad_enabled and weight.requires_grad + # Prepare input tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + nvtx_range_push(f"{nvtx_label}.input_cast_comm") + inputmat = inp.view(-1, in_features) + inputmat_total = None + with_input_all_gather_nccl = ( + parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop + ) + own_quantized_input = False if fp8: - fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - if isinstance(inputmat, Float8Tensor): - inputmat_scale_inv = inputmat._scale_inv + assert_dim_for_fp8_exec(inputmat, weight) + if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( + FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() + ): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" + ) + + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + if with_input_all_gather_nccl: + assert not isinstance( + inputmat, QuantizedTensor + ), "All gather of fp8 input is not supported" + input_quantizer.set_usage(rowwise=True, columnwise=False) + inputmat_total, _ = gather_along_first_dim( + inputmat, + tp_group, + quantizer=input_quantizer, + ) else: - inputmat_scale_inv = torch.empty([1], dtype=torch.float32, device=inputmat.device) if ( - not fp8_meta["recipe"].override_linear_precision.wgrad - and is_grad_enabled - and weight.requires_grad - and not sequence_parallel + FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling() + and ub_bulk_dgrad ): - # FP8 input for forward, FP8 input transpose for backward wgrad - inputmat, inputmat_t = fp8_cast_transpose_fused( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, - ) + # reduce duplicated transpose in `_fix_gathered_fp8_transpose` + input_quantizer.set_usage(rowwise=True, columnwise=False) else: - # FP8 input for forward - inputmat = cast_to_fp8( - inputmat, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - scale_inv=inputmat_scale_inv, + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input, ) - - # Hack for ONNX export - # Note: ONNX models are represented as a graph of tensor - # operations, so the in-place scale-inv update doesn't fit - # very well. We work around this by making it look like - # the scale-inv tensor is initialized with a copy. - # Note: ONNX export expects FP8 scales can be represented - # with constant ops. However, copying into a buffer - # involves an expand op for array broadcasting. We work - # around this by filling the buffer instead. - if is_in_onnx_export_mode(): - inputmat_scale_inv.fill_(inputmat_scale_inv.item()) - - # Column Parallel Linear - if parallel_mode == "column" and sequence_parallel: - inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) + if not isinstance(inputmat, QuantizedTensor): + inputmat = input_quantizer(inputmat) + own_quantized_input = True + elif backward_needs_input: + inputmat.update_usage(rowwise_usage=True, columnwise_usage=True) + inputmat_total = inputmat else: - inputmat_total = inputmat - if fp8: - bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype - bias = cast_if_needed(bias, bias_dtype) if use_bias else bias - - # Use FP8 weights - if weight_fp8 is None: - weight_fp8 = weight - - assert isinstance(weight_fp8, Float8Tensor) - - if fp8_output: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( - tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_meta["scaling_fwd"], - fp8_dtype_forward, - torch.uint8, - ) + inputmat = cast_if_needed(inp, activation_dtype) + if with_input_all_gather_nccl: + inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: - proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( - None, - None, - None, - activation_dtype, - ) + inputmat_total = inputmat + nvtx_range_pop(f"{nvtx_label}.input_cast_comm") - ub_algo = None - rs_out = None - if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) - dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // tp_world_size - dim_size[1] = out_features - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj_projout.is_p2p_overlap(): - if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - if ub_obj_projout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_RS - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - if ub_obj_projout.is_fp8_ubuf(): - proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT - meta_tensor = fp8_meta["scaling_fwd"] - proj_out_tetype = fp8_dtype_forward - proj_out_pttype = torch.uint8 - ub_obj_projout.set_ubuf_scale_inv(meta_tensor.scale_inv[proj_out_index]) - else: - dim_size = list(inputmat_total.size()) - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=proj_out_pttype, device=inputmat_total.device) - - _ = fp8_gemm( - weight_fp8._data, - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - ( - inputmat_total._data - if isinstance(inputmat_total, Float8Tensor) - else inputmat_total - ), - inputmat_scale_inv, - 0, - fp8_dtype_forward, - proj_out_pttype, - get_workspace(), - bias=bias, - use_bias=use_bias, - use_split_accumulator=_2X_ACC_FPROP, - out=out, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - out_index=proj_out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=proj_out_tetype, - ) - if fp8_output: - out = Float8Tensor( - data=out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_OUTPUT, - fp8_dtype=fp8_dtype_forward, - dtype=activation_dtype, - ) + # Cast weight to expected dtype + weightmat = weight + if not fp8: + weightmat = cast_if_needed(weightmat, activation_dtype) else: - # Cast for native AMP - weight = cast_if_needed(weight, activation_dtype) - bias = cast_if_needed(bias, activation_dtype) if use_bias else bias - - if fp8_calibration: - # amax of input - amin, amax = inputmat_total.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_INPUT] = torch.max( - -amin, amax - ).float() - # amax of weight - amin, amax = weight.aminmax() - fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = torch.max( - -amin, amax - ).float() - - if ub_overlap_rs: - ub_obj_projout = get_ub(ub_name + "_fprop") - out = ub_obj_projout.get_ubuf_output(1) - dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group) - dim_size[1] = out_features - rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - if ub_obj_projout.is_p2p_overlap(): - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_RS - else: - dim_size = list(inputmat_total.size()) - dim_size[1] = out_features - out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + if not isinstance(weight, QuantizedTensor): + # Configure quantizer + if weight_quantizer is not None: + columnwise_usage = is_grad_enabled and inp.requires_grad + if not columnwise_usage: + columnwise_usage = ( + is_fp8_activation_recompute_enabled() + and not in_fp8_activation_recompute_phase() + ) + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + + # FP8 cast to workspace buffer + update_workspace = is_first_microbatch is None or is_first_microbatch + weightmat = module.get_weight_workspace( + tensor=weight, + quantizer=weight_quantizer, + cache_name=(None if is_first_microbatch is None else "weight"), + update_workspace=update_workspace, + skip_update_flag=skip_fp8_weight_update, + fsdp_group=fsdp_group, + ) - _ = gemm( - weight, - inputmat_total, - activation_dtype, - get_workspace(), - bias=bias, - use_bias=use_bias, - out=out, - ub_algo=ub_algo if ub_overlap_rs else None, - ub=ub_obj_projout if ub_overlap_rs else None, - extra_output_tensor=rs_out if ub_overlap_rs else None, - ) + # Cast bias to expected dtype + bias_dtype = activation_dtype + if fp8 and activation_dtype == torch.float32: + bias_dtype = torch.bfloat16 + bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias + + # Configure output quantizer + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) + + # Calibrate quantizers if needed + if not fp8 and fp8_calibration: + if input_quantizer is not None: + input_quantizer.calibrate(inputmat_total) + if weight_quantizer is not None: + weight_quantizer.calibrate(weight) + + ub_obj = None + ub_type = None + rs_out = None + out_dtype = activation_dtype + if ub_overlap_rs_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.RS + out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features] + rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device) + + elif ub_overlap_ag_fprop: + ub_obj = get_ub(ub_name + "_fprop") + ub_type = tex.CommOverlapType.AG + if fp8: + assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." + ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True) + inputmat_total = ub_obj.get_buffer(input_quantizer) + + nvtx_range_push(f"{nvtx_label}.gemm") + fprop_gemm_use_split_accumulator = _2X_ACC_FPROP + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if hasattr(recipe, "fp8_gemm_fprop"): + fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator + + out, *_, rs_out = general_gemm( + weightmat, + inputmat_total, + get_workspace(), + quantization_params=output_quantizer, + out_dtype=out_dtype, + bias=bias, + use_split_accumulator=fprop_gemm_use_split_accumulator, + ub=ub_obj, + ub_type=ub_type, + extra_output=rs_out, + ) + nvtx_range_pop(f"{nvtx_label}.gemm") if is_grad_enabled: saved_inputmat = None - saved_inputmat_t = None - if weight.requires_grad: - if fp8 and not fp8_meta["recipe"].override_linear_precision.wgrad: - if inputmat_t is None: - saved_inputmat = inputmat - else: - saved_inputmat_t = inputmat_t - if cpu_offloading: - saved_inputmat_t.activation_offloading = True - else: - saved_inputmat = inputmat_no_fp8 - if cpu_offloading: - if fp8 and weight_fp8 is not None: - weight_fp8.weight_offloading = True - weight.weight_offloading = True + ctx.backward_input_needs_gather = ( + weight.requires_grad and parallel_mode == "column" and sequence_parallel + ) - if saved_inputmat is not None: - saved_inputmat.activation_offloading = True + if backward_needs_input: + if own_quantized_input and isinstance(inputmat, QuantizedTensor): + # For sequence parallel in vanilla FP8, rowwise data is + # to gather the input. For MXFP8, columnwise only data + # can be allgathered. + if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: + inputmat.update_usage(rowwise_usage=False) + saved_inputmat = inputmat + + if cpu_offloading: + set_offloading_param(weight, "weight_offloading", True) + set_offloading_param(weightmat, "weight_offloading", True) + if saved_inputmat is not None: + set_offloading_param(saved_inputmat, "activation_offloading", True) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights + nvtx_range_push(f"{nvtx_label}.fsdp_scatter") ctx.fsdp_group = fsdp_group ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, - saved_inputmat, # None if fp8 == False - saved_inputmat_t, # None if fp8 == False AND not is_grad_enabled - weight_fp8 if fp8 and not isinstance(weight, Float8Tensor) else None, + saved_inputmat, + weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") + + if cpu_offloading: + ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + + if ctx.grad_added_to_main_grad: + # If you are passing torch.nn.Parameter through the Torch hooks, you will + # get back torch.Tensor. Torch rips off the Parameter wrapper. + # You need to preserve the weight object to have all the attributes user + # sets for the weights. Because of this, it is not recommended to offload + # weights if weights are externally touched outside this module + ctx.weight_object = weight - ctx.save_for_backward( + # TODO(ksivamani): Check memory usage + tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, - saved_inputmat_t, - inputmat_scale_inv, + weightmat, weight, - weight_fp8, - weight.main_grad if cpu_offloading and fuse_wgrad_accumulation else None, + bias, ) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects ctx.activation_dtype = activation_dtype + ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fp8 = fp8 - ctx.fp8_meta = fp8_meta + ctx.input_quantizer = input_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation + if fuse_wgrad_accumulation and weight.requires_grad: + ctx.main_grad = weight.main_grad + ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch - ctx.use_bias = use_bias + ctx.use_bias = bias is not None ctx.sequence_parallel = sequence_parallel ctx.tensor_parallel = tensor_parallel ctx.inp_shape = inp_shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group - ctx.ub_overlap_ag = ub_overlap_ag + ctx.ub_overlap_ag = ub_overlap_ag_dgrad + ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad + ctx.ub_bulk_dgrad = ub_bulk_dgrad + ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad - ctx.is_input_fp8 = is_input_fp8 + ctx.requires_wgrad = weight.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False + ctx.owns_input = saved_inputmat is not inp if ctx.fp8 and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() @@ -356,83 +347,163 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module # Row Parallel Linear - if ub_overlap_rs: + if ub_overlap_rs_fprop: out = rs_out - elif parallel_mode == "row" and sequence_parallel: - out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif parallel_mode == "row" and tensor_parallel: - out, _ = allreduce(out, tp_group) - - # [*, in_features] -> [*, out_features] except first dimension changes for SP - return out.view(-1, *inp_shape[1:-1], out_features) + elif parallel_mode == "row": + nvtx_range_push(f"{nvtx_label}.row_parallel_comm") + if sequence_parallel: + out, _ = reduce_scatter_along_first_dim(out, tp_group) + elif tensor_parallel: + out, _ = allreduce(out, tp_group) + nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") + + out = out.view(-1, *inp_shape[1:-1], out_features) + return out @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring - if isinstance(grad_output, Float8Tensor): - ctx.fp8_meta["scaling_bwd"].scale_inv[ - tex.FP8BwdTensors.GRAD_OUTPUT1 - ] = grad_output._scale_inv + + # NVTX label for profiling + nvtx_label = "transformer_engine._Linear.backward" + if ctx.ub_name is not None: + nvtx_label = f"{nvtx_label}.{ctx.ub_name}" with torch.cuda.nvtx.range("_Linear_backward"): - ( - inputmat, - inputmat_t, - inputmat_scale_inv, - weight, - weight_fp8, - main_grad, - ) = ctx.saved_tensors + if ( + ctx.fp8 + and any( + [ + ctx.ub_overlap_ag, + ctx.ub_overlap_rs_dgrad, + ctx.ub_bulk_dgrad, + ctx.ub_bulk_wgrad, + ] + ) + and (ctx.fp8_recipe is not None) + ): + if not ctx.fp8_recipe.float8_per_tensor_scaling(): + raise NotImplementedError( + "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" + " current scaling" + ) + + saved_tensors = ctx.saved_tensors + inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking + restore_from_saved(ctx.tensor_objects, saved_tensors) + ) + # Delete the references to tensor objects once they've been consumed + # by the `restore_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None + + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ( + ctx.main_grad + if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad + else None + ) + + if ctx.cpu_offloading: + if ctx.grad_added_to_main_grad: + weight = ctx.weight_object + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + weight.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # shards/unshards the base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_gather") _fsdp_gather_tensors( ctx.fsdp_group, ctx.fsdp_shapes, inputmat, - inputmat_t, - weight_fp8 if ctx.fp8 and not isinstance(weight, Float8Tensor) else None, + weight_fp8, ) - - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - weight = torch.nn.Parameter(weight, weight.requires_grad) - weight.main_grad = main_grad - - tp_world_size = get_distributed_world_size(ctx.tp_group) - ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag - ub_algo = None + nvtx_range_pop(f"{nvtx_label}.fsdp_gather") + + ctx.ub_obj_gradout = None + ub_obj_dgrad = None + ub_obj_wgrad = None + ub_type_dgrad = None + ub_type_wgrad = None + dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] + rs_out = None + dgrad_bulk = None if ctx.ub_overlap_ag: - dim_size = list(grad_output.size()) - dim_size[0] = dim_size[0] * tp_world_size + # Overlap grad_output all-gather with dgrad compute ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") - if ctx.ub_obj_gradout.is_atomic_gemm(): - ub_algo = tex.CommOverlapAlgo.ATOMIC_GEMM_AG_P2P - else: - ub_algo = tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + elif ctx.ub_overlap_rs_dgrad: + # Overlap dgrad reduce-scatter with dgrad compute + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.RS + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device + ) + + else: + if ctx.ub_bulk_dgrad: + # Overlap inputmat all-gather with dgrad compute + # NOTE: Copying into communication buffer will always prefer rowwise data, + # and will copy columnwise data if rowwise does not exist. In that case, + # the all-gather will apply to the leading dimension of the transpose, + # which then needs to be interleaved correctly before WGRAD. + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ub_obj_dgrad = ctx.ub_obj_gradout + ub_type_dgrad = tex.CommOverlapType.AG + ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True) + + if ctx.ub_bulk_wgrad: + # Overlap dgrad reduce-scatter with wgrad compute + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_type_wgrad = tex.CommOverlapType.RS + ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer) + dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer) + + # Prepare grad output tensor + # Note: Cast to expected dtype and perform tensor-parallel communication + if ctx.grad_output_quantizer is not None: + # Reduce duplicated transpose, which is performed in grad_output.update_usage + if ctx.ub_overlap_ag and ctx.fp8_recipe.float8_per_tensor_scaling(): + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=False) + else: + ctx.grad_output_quantizer.set_usage(rowwise=True, columnwise=True) + nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") ( grad_output, - grad_output_c, - grad_output_t, grad_bias, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, grad_output, ctx.parallel_mode == "row" + ctx, + grad_output, + ctx.parallel_mode == "row", + ctx.grad_output_quantizer, ) + nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") - # Column Parallel Linear - # Overlap input AG with dgrad + # Prepare input tensor + # Note: Perform tensor-parallel communication if needed inputmat_total = None - inputmat_t_total = None - handle = None - if weight.requires_grad and ctx.parallel_mode == "column" and ctx.sequence_parallel: - inputmat_total, handle = gather_along_first_dim( - inputmat, ctx.tp_group, async_op=ctx.requires_dgrad + inputmat_total_work = None + if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: + quantizer = None + if ctx.fp8: + quantizer = ctx.input_quantizer + quantizer.set_usage(rowwise=True, columnwise=True) + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") + inputmat_total, inputmat_total_work = gather_along_first_dim( + inputmat, + ctx.tp_group, + async_op=True, + quantizer=quantizer, ) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: inputmat_total = inputmat - inputmat_t_total = inputmat_t + # Check whether to output wgrad GEMM directly into main grad if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch @@ -440,154 +511,154 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation - if ctx.fp8: - fp8_dtype_forward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True) - fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) - + # Compute grad input tensor + dgrad = None + dgrad_work = None if ctx.requires_dgrad: + + # Update quantizer + if ctx.grad_input_quantizer is not None: + ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + + # dgrad GEMM + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD if ctx.fp8: - if ctx.is_input_fp8: - out_index, meta_tensor, output_te_dtype, output_dtype = ( - tex.FP8BwdTensors.GRAD_INPUT1, - ctx.fp8_meta["scaling_bwd"], - fp8_dtype_backward, - torch.uint8, - ) - else: - out_index, meta_tensor, output_te_dtype, output_dtype = ( - None, - None, - None, - ctx.activation_dtype, + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_dgrad"): + dgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_dgrad.use_split_accumulator ) - dgrad, _ = fp8_gemm( - weight_fp8.transpose_2d(), - weight_fp8._scale_inv, - 0, - weight_fp8._fp8_dtype, - grad_output_c, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - output_dtype, - get_workspace(), - use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo if ctx.ub_overlap_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, - out_index=out_index, - fp8_meta_tensor=meta_tensor, - D_dtype=output_te_dtype, - ) - if output_dtype == torch.uint8: - dgrad = Float8Tensor( - data=dgrad, - fp8_meta=ctx.fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=tex.FP8BwdTensors.GRAD_INPUT1, - fp8_dtype=fp8_dtype_backward, - dtype=ctx.activation_dtype, + + dgrad, *_, rs_out = general_gemm( + weight_fp8, + grad_output, + get_workspace(), + layout="NN", + grad=True, + quantization_params=ctx.grad_input_quantizer, + out=dgrad_bulk, + out_dtype=ctx.activation_dtype, + use_split_accumulator=dgrad_gemm_use_split_accumulator, + ub=ub_obj_dgrad, + ub_type=ub_type_dgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_dgrad, + ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") + + # Launch tensor-parallel communication + if ctx.ub_overlap_rs_dgrad: + dgrad = rs_out + elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") + if ctx.sequence_parallel: + dgrad, dgrad_work = reduce_scatter_along_first_dim( + dgrad, + ctx.tp_group, + async_op=True, ) - else: - dgrad, _, _ = gemm( - weight, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NN", - grad=True, - ub_algo=( - tex.CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P - if ctx.ub_overlap_ag - else None - ), - ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, - ) + else: + dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") + + # Compute grad weight tensor + wgrad = None + if ctx.requires_wgrad: + if ctx.ub_bulk_dgrad: + inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer) + if ctx.fp8: + if inputmat._data is None: + # All-gather executed on columnwise data and result is in rowwise data, + # so we need to fix the interleaving before WGRAD. + inputmat_total = _fix_gathered_fp8_transpose( + inputmat_total, ctx.tp_size + ) + elif not non_tn_fp8_gemm_supported(): + # FP8 GEMM on Hopper only supports TN layout so the gathered input must + # have a valid transpose. + inputmat_total._create_transpose() - # Overlap dgrad-RS/AR with wgrad - if ctx.parallel_mode == "column" and ctx.sequence_parallel: - if handle is not None: - handle.wait() - dgrad, handle = reduce_scatter_along_first_dim( - dgrad, ctx.tp_group, async_op=True + else: + if inputmat_total_work is not None: + # Synchronize tensor-parallel communication + inputmat_total_work.wait() + inputmat_total_work = None + + if isinstance(grad_output, QuantizedTensor): + # This is a no-op if platform supports non-TN FP8 GEMM or the transpose + # already exists. + grad_output.update_usage(rowwise_usage=True, columnwise_usage=True) + + if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + rs_out = torch.empty( + dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device ) - elif ctx.parallel_mode == "column" and ctx.tensor_parallel: - dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True) - wgrad = None - if weight.requires_grad: + # wgrad GEMM + # Note: Fuse with bgrad computation if needed + nvtx_range_push(f"{nvtx_label}.wgrad_gemm") + wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD if ctx.fp8: - # WGRAD - if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if ctx.ub_overlap_ag: - if isinstance(grad_output_c, Float8Tensor): - grad_output_t = grad_output_c.transpose_2d() - else: - grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) - if inputmat_t_total is None: - if isinstance(inputmat_total, Float8Tensor): - inputmat_t_total = inputmat_total.transpose_2d() - else: - inputmat_t_total = tex.fp8_transpose( - inputmat_total, fp8_dtype_backward - ) - wgrad, _ = fp8_gemm( - ( - inputmat_t_total._data - if isinstance(inputmat_t_total, Float8Tensor) - else inputmat_t_total - ), - inputmat_scale_inv, - 0, - fp8_dtype_forward, - grad_output_t, - ctx.fp8_meta["scaling_bwd"].scale_inv, - tex.FP8BwdTensors.GRAD_OUTPUT1, - fp8_dtype_backward, - ctx.activation_dtype, - get_workspace(), - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - use_split_accumulator=_2X_ACC_WGRAD, + recipe = ctx.fp8_recipe + if hasattr(recipe, "fp8_gemm_wgrad"): + wgrad_gemm_use_split_accumulator = ( + recipe.fp8_gemm_wgrad.use_split_accumulator ) + + wgrad, grad_bias_, _, rs_out = general_gemm( + inputmat_total, + grad_output, + get_workspace(), + layout="NT", + grad=True, + out_dtype=( + main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + ), + bias=(bias if (grad_bias is None and not ctx.fp8) else None), + out=main_grad if ctx.fuse_wgrad_accumulation else None, + use_split_accumulator=wgrad_gemm_use_split_accumulator, + accumulate=accumulate_wgrad_into_param_main_grad, + ub=ub_obj_wgrad, + ub_type=ub_type_wgrad, + extra_output=rs_out, + bulk_overlap=ctx.ub_bulk_wgrad, + ) + nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") + + if ctx.ub_bulk_wgrad: + if ub_obj_wgrad.is_fp8_ubuf(): + dgrad = rs_out else: - wgrad, _, _ = gemm( - inputmat_total, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ) - else: - # WGRAD - wgrad, grad_bias, _ = gemm( - inputmat_total, - grad_output, - ctx.activation_dtype, - get_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_bias, - accumulate=accumulate_wgrad_into_param_main_grad, - out=weight.main_grad if ctx.fuse_wgrad_accumulation else None, - ) + dgrad = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer, local_chunk=True) - # Deallocate input tensor - clear_tensor_data(inputmat_total) - clear_tensor_data(inputmat_t_total) + if grad_bias is None: + grad_bias = grad_bias_ + del grad_bias_ - # Column Parallel Linear - if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None: - handle.wait() + # Deallocate input tensor + if ctx.owns_input: + clear_tensor_data(inputmat_total) + # Don't return grad bias if not needed if not ctx.use_bias: grad_bias = None - if weight.requires_grad: + # Synchronize tensor parallel communication + if inputmat_total_work is not None: + inputmat_total_work.wait() + inputmat_total_work = None + if dgrad_work is not None: + dgrad_work.wait() + dgrad_work = None + + if ctx.requires_wgrad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(weight, "grad_added_to_main_grad"): + if ( + ctx.fuse_wgrad_accumulation + and weight is not None + and hasattr(weight, "grad_added_to_main_grad") + ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): wgrad = torch.zeros( @@ -609,22 +680,25 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], wgrad = None if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) + nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, Float8Tensor): + if ctx.fp8 and not isinstance(weight, QuantizedTensor): _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) - return ( wgrad, - None, # weight_fp8 dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, grad_bias, - None, # use_bias None, # is_first_microbatch None, # fp8 None, # fp8_calibration - None, # fp8_meta + None, # input_quantizer + None, # weight_quantizer + None, # output_quantizer + None, # grad_output_quantizer + None, # grad_input_quantizer None, # fuse_wgrad_accumulation None, # cpu_offloading None, # tp_group @@ -634,11 +708,17 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], None, # activation_dtype None, # parallel_mode None, # is_grad_enabled - None, # ub_overlap_rs - None, # ub_overlap_ag + None, # ub_overlap_rs_fprop + None, # ub_overlap_ag_dgrad + None, # ub_overlap_ag_fprop + None, # ub_overlap_rs_dgrad + None, # ub_bulk_dgrad + None, # ub_bulk_wgrad None, # ub_name None, # fp8_output None, # fsdp_group + None, # module + None, # skip_fp8_weight_update ) @@ -729,8 +809,11 @@ def __init__( parallel_mode: Optional[str] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, device: Union[torch.device, str] = "cuda", - ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = False, + ub_bulk_wgrad: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -742,11 +825,6 @@ def __init__( self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias - self.ub_overlap_rs = ub_overlap_rs - self.ub_overlap_ag = ub_overlap_ag - if ub_overlap_rs or ub_overlap_ag: - assert ub_name is not None, "Userbuffer name [string] is not set." - self.ub_name = ub_name self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name @@ -773,6 +851,47 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel + # Column parallel TP overlap options + self.ub_overlap_ag_fprop = ( + self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_ag + ) + self.ub_overlap_rs_dgrad = ( + self.parallel_mode == "column" and self.sequence_parallel and ub_overlap_rs_dgrad + ) + self.ub_bulk_dgrad = ( + self.parallel_mode == "column" + and self.sequence_parallel + and ub_bulk_dgrad + and not self.ub_overlap_rs_dgrad + ) + self.ub_bulk_wgrad = ( + self.parallel_mode == "column" + and self.sequence_parallel + and ub_bulk_wgrad + and not self.ub_overlap_rs_dgrad + ) + + # Row parallel TP overlap options + self.ub_overlap_rs_fprop = ( + self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_rs + ) + self.ub_overlap_ag_dgrad = ( + self.parallel_mode == "row" and self.sequence_parallel and ub_overlap_ag + ) + + if any( + [ + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, + ] + ): + assert ub_name is not None, f"Comm+GEMM overlap layer '{ub_name}' is not initialized." + self.ub_name = ub_name + # Initialize params in FP8 with_fp8_params = FP8GlobalStateManager.with_fp8_parameters() @@ -849,7 +968,9 @@ def __init__( # Check if parameters are subviews of buffers is_subview = (split_start, split_end) != (0, self.out_features) if is_subview and with_fp8_params: - raise RuntimeError("Splitting Float8Tensor into multiple params is not supported") + raise RuntimeError( + "Splitting QuantizedTensor into multiple params is not supported" + ) # Construct weight parameter self.register_parameter( @@ -889,6 +1010,16 @@ def __init__( else: self.gemm_bias_unfused_add = False + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: + """Init scales and amaxes for fwd | bwd.""" + super().set_meta_tensor(fwd, recipe) + + # customize quantizers based on each recipe & layer configs + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + self._customize_quantizers_float8_current_scaling(fwd, recipe) + # elif for other recipes (mxfp8, etc.) + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -916,6 +1047,7 @@ def forward( inp: torch.Tensor, is_first_microbatch: Optional[bool] = None, fp8_output: Optional[bool] = False, + fp8_grad: Optional[bool] = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ Apply the linear transformation to the input. @@ -938,14 +1070,15 @@ def forward( first microbatch (since it is the first gradient being produced) """ - - skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + if FP8GlobalStateManager.fp8_graph_capturing(): + skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() + else: + skip_fp8_weight_update = None if skip_fp8_weight_update is not None: is_first_microbatch = False with self.prepare_forward( inp, - is_first_microbatch, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: @@ -959,36 +1092,25 @@ def forward( ) else: unfused_weights = [w.dequantize() for w in unfused_weights] - weight_tensor = _noop_cat(unfused_weights) + weight_tensor = noop_cat(unfused_weights) if self.use_bias: - bias_tensor = _noop_cat([getattr(self, name) for name in self.bias_names]) + bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) else: - bias_tensor = getattr(self, self.bias_names[0]) # Unused - - # Initialize FP8 weights if needed - weight_fp8 = None - if self.fp8: - if isinstance(weight_tensor, Float8Tensor): - # Make sure transpose cache is valid, if present - # Note: Transpose cache may have been invalidated - # externally, e.g. by optimizer. - if weight_tensor._transpose is not None: - weight_tensor.transpose_2d( - fill_cache=True, - noop_flag=skip_fp8_weight_update, - ) - else: - # FP8 cast to workspace buffer - update_workspace = is_first_microbatch is None or is_first_microbatch - weight_fp8 = self.get_fp8_workspace( - tensor=weight_tensor, - fp8_meta_forward=True, - fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT, - cache_name=(None if is_first_microbatch is None else "weight"), - update_workspace=update_workspace, - skip_update_flag=skip_fp8_weight_update, - fsdp_group=self.fsdp_group, - ) + bias_tensor = None + + ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) = self._get_quantizers(fp8_output, fp8_grad) + + # Make sure weight tensor has correct quantizer + # Note: Quantizer might have changed if quantization + # recipe changed + if weight_quantizer is not None and isinstance(weight_tensor, QuantizedTensor): + weight_tensor._quantizer = weight_quantizer if torch.is_grad_enabled(): linear_fn = _Linear.apply @@ -998,14 +1120,16 @@ def forward( args = [None] args += ( weight_tensor, - weight_fp8, inp, - bias_tensor, - self.apply_bias and not self.gemm_bias_unfused_add, + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, is_first_microbatch, self.fp8, self.fp8_calibration, - self.fp8_meta, + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, self.fuse_wgrad_accumulation, is_cpu_offload_enabled(), self.tp_group, @@ -1015,17 +1139,100 @@ def forward( self.activation_dtype, self.parallel_mode, torch.is_grad_enabled(), - self.ub_overlap_rs, - self.ub_overlap_ag, + self.ub_overlap_rs_fprop, + self.ub_overlap_ag_dgrad, + self.ub_overlap_ag_fprop, + self.ub_overlap_rs_dgrad, + self.ub_bulk_dgrad, + self.ub_bulk_wgrad, self.ub_name, fp8_output, self.fsdp_group, + self, + skip_fp8_weight_update, ) out = linear_fn(*args) - if self.gemm_bias_unfused_add: out = out + cast_if_needed(bias_tensor, self.activation_dtype) if self.return_bias: return out, cast_if_needed(bias_tensor, self.activation_dtype) return out + + def _get_quantizers(self, fp8_output, fp8_grad): + if not self.fp8: + return [None] * 5 + grad_input_quantizer = None + grad_output_quantizer = None + output_quantizer = None + input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + input_quantizer.internal = False + weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + weight_quantizer.internal = True + if fp8_output: + output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + if torch.is_grad_enabled(): + grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + grad_output_quantizer.internal = True + if fp8_grad: + grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + return ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_output_quantizer, + grad_input_quantizer, + ) + + def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert ( + recipe.float8_current_scaling() + ), "current scaling recipe quantizer customization here" + if fwd: + # set configs about amax epsilon and power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_epsilon = recipe.fp8_quant_fwd_inp.amax_epsilon + # also set weight quantizer with same amax_epsilon & power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_WEIGHT + ].amax_epsilon = recipe.fp8_quant_fwd_weight.amax_epsilon + # paralle related + if self.sequence_parallel and self.parallel_mode == "column": + # customize input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_size = self.tp_size + else: + # set grad_output_quantizer with amax epsilon and power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon + # parallel related + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_size = self.tp_size diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index f3651ecc19..bc826edc2a 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -65,15 +65,32 @@ class RMSNorm(_RMSNormOp): def __init__( self, - normalized_shape: Union[Iterable[int], int], + normalized_shape: Union[Iterable[int], int, None] = None, eps: float = 1e-5, sequence_parallel: Optional[bool] = None, # legacy params_dtype: Optional[torch.dtype] = None, # deprecated zero_centered_gamma: bool = False, + hidden_size: Optional[int] = None, # deprecated **kwargs, ) -> None: # Handle deprecated options + if normalized_shape is None: + if hidden_size is None: + raise RuntimeError( + "Neither `normalized_shape` nor `hidden_size` (deprecated) args are provided" + ) + warnings.warn( + "`hidden_size` arg has been renamed to `normalized_shape` " + "for compatibility with `torch.nn.LayerNorm`.", + DeprecationWarning, + stacklevel=2, + ) + normalized_shape = hidden_size + elif hidden_size is not None: + raise RuntimeError( + "Both `normalized_shape` and `hidden_size` (deprecated) args are provided" + ) if params_dtype is not None: if "dtype" in kwargs: raise RuntimeError( @@ -91,6 +108,8 @@ def __init__( # Flag for sequence parallelism (custom Megatron-LM integration) self.sequence_parallel: Optional[bool] = sequence_parallel + if sequence_parallel is not None: + self.weight.sequence_parallel = sequence_parallel def reset_rms_norm_parameters(self) -> None: """Deprecated""" diff --git a/transformer_engine/pytorch/numerics_debug.py b/transformer_engine/pytorch/numerics_debug.py index bc9a5f89e0..5a73f5b61b 100644 --- a/transformer_engine/pytorch/numerics_debug.py +++ b/transformer_engine/pytorch/numerics_debug.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/__init__.py b/transformer_engine/pytorch/ops/__init__.py index f65433398e..156c33210a 100644 --- a/transformer_engine/pytorch/ops/__init__.py +++ b/transformer_engine/pytorch/ops/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index b1654add98..20e63e0e63 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,12 +10,13 @@ import torch from transformer_engine_torch import FP8TensorMeta +from .. import torch_version from ..fp8 import FP8GlobalStateManager -from ..tensor import Float8Tensor +from ..tensor.float8_tensor import Float8Tensor from ..utils import ( - canonicalize_device, # pylint: disable=unused-import - canonicalize_dtype, # pylint: disable=unused-import - devices_match, # pylint: disable=unused-import + canonicalize_device, + canonicalize_dtype, + devices_match, ) @@ -61,12 +62,9 @@ def convert_tensor( # Note: torch.Tensor.to ignores memory_format kwarg (see # https://github.com/pytorch/pytorch/issues/132020). data = data.contiguous(memory_format=memory_format) - return Float8Tensor.make_like( - tensor, - data=data, - fp8_attrs=tensor._fp8_attrs, - dtype=dtype, - ) + out = Float8Tensor.make_like(tensor, dtype=dtype) + out.data = data + return out # Convert standard PyTorch tensor tensor = tensor.to(device=device, dtype=dtype) @@ -85,46 +83,14 @@ def reshape( device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor | Float8Tensor: - """Reshape tensor, keeping same data if possible - - If the input is a Float8Tensor, this function attempts to preserve - the cached transpose if available and valid. If a cached transpose - is present, it is interpreted as the transpose of a 2D matrix - where the width matches the innermost tensor dimension. - - """ - - # Make sure tensor is in expected format + """Reshape tensor, keeping same data if possible""" tensor = convert_tensor( tensor, device=device, dtype=dtype, memory_format=torch.contiguous_format, ) - - # Return immediately if tensor already has desired shape - shape = list(shape) - if len(shape) == tensor.dim(): - if sum(1 for d in shape if d == -1) > 1: - raise ValueError( - "Attempted to reshape tensor with " - f"shape={tuple(tensor.size())} into shape={tuple(shape)}" - ) - if all(d1 == d2 for d1, d2 in zip(shape, tensor.size()) if d1 != -1): - return tensor - - # Reshape FP8 tensor - # Note: Preserve cached transpose if possible - if is_float8_tensor(tensor): - out = Float8Tensor.make_like( - tensor, - data=tensor._data.view(shape), - fp8_attrs=tensor._fp8_attrs, - ) - return out - - # Reshape standard PyTorch tensor - return tensor.view(shape) + return tensor.reshape(*shape) def maybe_autocast_dtype( @@ -133,8 +99,13 @@ def maybe_autocast_dtype( default_dtype: Optional[torch.dtype] = None, ) -> torch.dtype: """Get autocast dtype if enabled""" - if torch.is_autocast_enabled(device_type): - return torch.get_autocast_dtype(device_type) + + if torch_version() >= (2, 4, 3): + if torch.is_autocast_enabled(device_type): + return torch.get_autocast_dtype(device_type) + else: + if torch.is_autocast_enabled(): + return torch.get_autocast_gpu_dtype() return canonicalize_dtype(default_dtype) diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 3dd8f64229..ae635c956a 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -1,9 +1,10 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Single tensor operations supported by the operation fuser.""" +from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU from .add_in_place import AddInPlace from .all_gather import AllGather from .all_reduce import AllReduce diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py new file mode 100644 index 0000000000..45c78bea87 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -0,0 +1,281 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operations for activation functions.""" + +from __future__ import annotations +import abc +from typing import Optional + +import torch + +import transformer_engine_torch as tex +from ...fp8 import FP8GlobalStateManager +from ...tensor import QuantizedTensor +from ...utils import clear_tensor_data, devices_match +from ..op import BasicOperation, OperationContext +from .._common import reshape + + +class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): + r"""Apply activation function + + Activation functions are either element-wise unary functions or + variants of the gated linear unit (GLU). Recall that GLU is + computed by splitting the input tensor into chunks :math:`a` and + :math:`b` along the last dimension and computing + + .. math:: + \text{GLU}(a,b) = \sigma(a) * b + + .. warning:: + + Transformer Engine gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + """ + + @abc.abstractmethod + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + """Forward implementation + + Implementation from transformer_engine.pytorch.cpp_extensions. + + """ + + @abc.abstractmethod + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + """Backward implementation + + Implementation from transformer_engine_torch. + + """ + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op: Optional[BasicOperation] = None, + next_op: Optional[BasicOperation] = None, + ) -> torch.Tensor: + + # Compute dtype + dtype: torch.dtype + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = input_.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + raise RuntimeError(f"Unsupported dtype ({dtype})") + + # Check input tensor + x = input_ + if isinstance(x, QuantizedTensor): + x = x.dequantize() + if x.device.type != "cuda": + x = x.cuda() + if x.dtype != dtype: + x = x.to(dtype=dtype) + if not x.is_contiguous(): + x = x.contiguous() + + # Check if FP8 is enabled + fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() + if fp8_enabled and next_op is not None and next_op.num_quantizers("forward") > 0: + quantizer = next_op.get_quantizer("forward", 0) + else: + quantizer = None + + # Launch kernel + y = self._activation_forward_impl( + reshape(x, (-1, x.size(-1))), + quantizer, + ) + + # Check output tensor + if y.dim() != x.dim(): + y = y.reshape(list(x.shape[:-1]) + [-1]) + + # Save state for backward pass + ctx.save_for_backward(x.detach()) + ctx.fp8_enabled = fp8_enabled + ctx.prev_op = prev_op + + return y + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass + (x,) = ctx.saved_tensors + + # Check grad output tensor + dy = grad_output + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() + if not devices_match(dy.device, x.device) or dy.dtype != x.dtype: + dy = dy.to(device=x.device, dtype=x.dtype) + if not dy.is_contiguous(): + dy = dy.contiguous() + + # Launch kernel + dx = self._activation_backward_impl( + reshape(dy, (-1, dy.size(-1))), + reshape(x, (-1, x.size(-1))), + None, + ) + + # Check grad input tensor + if dx.size() != x.size(): + dx = dx.reshape(x.size()) + + # Clear input tensor if possible + if ctx.prev_op is not None: + clear_tensor_data(x) + + return dx, () + + +class GELU(_ActivationOperation): + r"""Gaussian Error Linear Unit + + This computes the "tanh" approximation to GELU: + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + See `Gaussian Error Linear Units (GELUs)`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.gelu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dgelu(*args, **kwargs) + + +class ReLU(_ActivationOperation): + r"""Rectified linear unit + + .. math:: + + \text{ReLU}(x) = \max(x,0) + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.relu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.drelu(*args, **kwargs) + + +class GEGLU(_ActivationOperation): + r"""Gaussian error gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GEGLU}(a,b) = \text{GELU}(a) * b + + where + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.geglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dgeglu(*args, **kwargs) + + +class ReGLU(_ActivationOperation): + r"""Rectified gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{ReGLU}(a,b) = \max(a,0) * b + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.reglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dreglu(*args, **kwargs) + + +class SwiGLU(_ActivationOperation): + r"""Swish gated linear unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GEGLU}(a,b) = \text{SiLU}(a) * b + + where + + .. math:: + + \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + The Sigmoid Linear Unit (SiLU) gating function is also known as + the swish function. See + `GLU Variants Improve Transformer`__ + and `Gaussian Error Linear Units (GELUs)`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.swiglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dswiglu(*args, **kwargs) diff --git a/transformer_engine/pytorch/ops/basic/add_in_place.py b/transformer_engine/pytorch/ops/basic/add_in_place.py index 041888f5d7..4ccbaef1c0 100644 --- a/transformer_engine/pytorch/ops/basic/add_in_place.py +++ b/transformer_engine/pytorch/ops/basic/add_in_place.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/all_gather.py b/transformer_engine/pytorch/ops/basic/all_gather.py index b914d1dc6f..15b1f65d85 100644 --- a/transformer_engine/pytorch/ops/basic/all_gather.py +++ b/transformer_engine/pytorch/ops/basic/all_gather.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -9,12 +9,9 @@ import torch -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from .._common import convert_tensor, is_float8_tensor +from ...distributed import gather_along_first_dim +from ...tensor import QuantizedTensor +from ..op import BasicOperation, OperationContext class AllGather(BasicOperation): @@ -45,47 +42,12 @@ def op_forward( prev_op: Optional[BasicOperation] = None, next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: - - # Trivial case + out: torch.Tensor if self.process_group_size == 1: - return input_ - - # Tensor dimensions - input_dims = input_.size() - if not input_dims: - raise RuntimeError( - "Attempted to all-gather a tensor " - f"with shape={list(input_dims)} " - f"over {self.process_group_size} processes" - ) - output_dims = list(input_dims) - output_dims[0] *= self.process_group_size - - # Perform all-gather - x = convert_tensor(input_, memory_format=torch.contiguous_format) - y = None - if is_float8_tensor(x): - y = Float8Tensor.make_like( - x, - data=torch.empty( - output_dims, - dtype=torch.uint8, - device=x.device, - ), - ) - torch.distributed.all_gather_into_tensor( - y._data, - x._data, - group=self.process_group, - ) + out = input_.detach() else: - y = torch.empty(output_dims, dtype=x.dtype, device=x.device) - torch.distributed.all_gather_into_tensor( - y, - x, - group=self.process_group, - ) - return y + out, _ = gather_along_first_dim(input_, self.process_group) + return out def op_backward( self, @@ -110,8 +72,8 @@ def op_backward( # Check output gradient tensor dy = grad_output - if is_float8_tensor(dy): - dy = dy.from_float8() + if isinstance(dy, QuantizedTensor): + dy = dy.dequantize() dy = dy.contiguous() # Perform reduce-scatter diff --git a/transformer_engine/pytorch/ops/basic/all_reduce.py b/transformer_engine/pytorch/ops/basic/all_reduce.py index f466ade3a3..8b4593b934 100644 --- a/transformer_engine/pytorch/ops/basic/all_reduce.py +++ b/transformer_engine/pytorch/ops/basic/all_reduce.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index ad86861114..cb93eb5e6b 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -12,33 +12,24 @@ import torch -from transformer_engine.pytorch.cpp_extensions import ( - FP8TensorMeta, - fp8_gemm, - gemm, -) -from transformer_engine.pytorch.distributed import ( +from transformer_engine.pytorch.module.base import get_workspace +from ...cpp_extensions import general_gemm +from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, reduce_scatter_along_first_dim, ) -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.fp8 import ( - FP8GlobalStateManager, - get_fp8_te_dtype, -) -from transformer_engine.pytorch.module.base import get_workspace -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) +from ...fp8 import FP8GlobalStateManager +from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD +from ...tensor import Quantizer, QuantizedTensor +from ...tensor.float8_tensor import Float8Quantizer +from ...tensor.mxfp8_tensor import MXFP8Quantizer +from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ..op import BasicOperation, OperationContext from .._common import ( canonicalize_device, canonicalize_dtype, - convert_tensor, devices_match, - is_float8_tensor, - reshape, ) from ...utils import clear_tensor_data @@ -110,17 +101,8 @@ def __init__( self.in_features: int = in_features self.out_features: int = out_features - # Weight tensor device - defer_param_init = False + # Weight tensor attributes device = canonicalize_device(device) - if device.type == "meta": - defer_param_init = True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device - - # Weight tensor datatype dtype = canonicalize_dtype(dtype) if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") @@ -147,16 +129,14 @@ def __init__( out_features=out_features, ) - # Whether weight tensor is natively in FP8 - self._with_fp8_parameters: bool = FP8GlobalStateManager.with_fp8_parameters() - if self._with_fp8_parameters: - self._fp8_metas = self._make_fp8_metas() + # Whether weight tensor is natively quantized + self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters() # Initialize parameters if needed weight = torch.empty( self.local_out_features, self.local_in_features, - device="meta", + device=device, dtype=dtype, ) weight = torch.nn.Parameter(weight) @@ -164,7 +144,7 @@ def __init__( self.register_parameter("weight", weight) self._rng_state_tracker_function: Optional[Callable[[], CudaRNGStatesTracker]] self._rng_state_tracker_function = rng_state_tracker_function - if not defer_param_init: + if weight.device.type != "meta": self.reset_parameters() # Whether to accumulate weight gradient into main_grad @@ -273,43 +253,48 @@ def _canonicalize_tensor_parallelism( local_out_features, ) - def num_fp8_scales(self, mode: str) -> int: - if mode in ("input", "param", "grad_output"): + def num_quantizers(self, mode: str) -> int: + if mode == "forward": + return 2 + if mode == "backward": return 1 return 0 def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight - if weight.device.type != "cuda" or is_float8_tensor(weight): - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Allocate buffer if needed + if isinstance(weight, QuantizedTensor): + weight = torch.empty( + weight.size(), + dtype=weight.dtype, + device=device, + ) + elif not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) # Initialize values - init_context = contextlib.nullcontext + init_context = contextlib.nullcontext() if self._rng_state_tracker_function is not None: - init_context = self._rng_state_tracker_function().fork - with init_context(): + init_context = self._rng_state_tracker_function().fork() + with init_context: torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5)) - # Cast to FP8 if needed - if self._with_fp8_parameters: - dummy_amax = torch.empty( - (1, 1), - dtype=torch.float32, - device=self.device, - ) # Dummy buffer to avoid overwriting amax history - weight = Float8Tensor.to_float8( - weight, - fp8_meta=self.get_fp8_meta("param"), - fp8_meta_forward=True, - fp8_meta_index=0, - amax=dummy_amax, - with_transpose_cache=torch.is_grad_enabled(), + # Quantize if needed + if self._with_quantized_weight: + quantizer = self.get_quantizer("forward", 1) + quantizer.set_usage( + rowwise=True, + columnwise=torch.is_grad_enabled(), ) + with torch.no_grad(): + weight = quantizer(weight) # Save updated parameter if not isinstance(weight, torch.nn.Parameter): @@ -318,8 +303,33 @@ def reset_parameters(self) -> None: def pre_forward(self, *args, **kwargs) -> None: super().pre_forward(*args, **kwargs) - if self.weight.device.type == "meta": + + # Initialize weights if needed + weight = self.weight + if weight.device.type == "meta": self.reset_parameters() + weight = self.weight + + # Configure quantizers + if FP8GlobalStateManager.is_fp8_enabled(): + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + grad_output_quantizer = self.get_quantizer("backward", 0) + + # Specify required tensor formats + is_grad_enabled = torch.is_grad_enabled() + weight_requires_grad = is_grad_enabled and weight.requires_grad + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + + # Make sure weight tensor has correct quantizer + # Note: Quantizer might have changed if quantization + # recipe changed + if isinstance(weight_quantizer, Float8Quantizer) and isinstance( + weight, Float8TensorBase + ): + weight._quantizer = weight_quantizer @staticmethod def _functional_forward( @@ -327,17 +337,17 @@ def _functional_forward( weight: torch.Tensor, *, bias: Optional[torch.Tensor] = None, - device: Optional[torch.device] = None, + device: Optional[torch.device] = None, # pylint: disable=unused-argument dtype: Optional[torch.dtype] = None, out: Optional[torch.Tensor] = None, accumulate_into_out: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, - with_fp8_compute: bool = False, - input_fp8_meta: Optional[dict[str, Any]] = None, - weight_fp8_meta: Optional[dict[str, Any]] = None, - output_fp8_meta: Optional[dict[str, Any]] = None, + with_quantized_compute: bool = False, + input_quantizer: Optional[Quantizer] = None, + weight_quantizer: Optional[Quantizer] = None, + output_quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Functional API for forward pass @@ -366,16 +376,14 @@ def _functional_forward( parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_fp8_compute: bool, default = `False` - Whether to perform compute in FP8 - input_fp8_meta: dict, optional - FP8 metadata for casting input tensor to FP8. Required for - FP8 compute if input is not already in FP8. - weight_fp8_meta: dict, optional - FP8 metadata for casting weight tensor to FP8. Required for - FP8 compute if weight is not already in FP8. - output_fp8_meta: dict, optional - FP8 metadata for casting output tensor to FP8 + with_quantized_compute: bool, default = `False` + Whether to perform compute with quantized data. + input_quantizer: Quantizer, optional + Builder class for quantized input tensor. + weight_quantizer: Quantizer, optional + Builder class for quantized weight tensor. + output_quantizer: Quantizer, optional + Builder class for quantized output tensor. Returns ------- @@ -390,17 +398,6 @@ def _functional_forward( """ - # Check device - if device is None: - device = weight.device if out is None else out.device - device = canonicalize_device(device) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - if out is not None and not devices_match(out.device, device): - raise ValueError( - f"Output tensor has invalid device (expected {device}, got {out.device})" - ) - # Check datatype if dtype is None: dtype = weight.dtype if out is None else out.dtype @@ -410,36 +407,88 @@ def _functional_forward( if out is not None and out.dtype != dtype: raise ValueError(f"Output tensor has invalid dtype (expected {dtype}, got {out.dtype})") - # Check input tensor dims - input_dims = tuple(input.size()) - weight_dims = tuple(weight.size()) - if len(weight_dims) != 2: - raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") - if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: - raise ValueError( - f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - - # Check output tensor dims - output_dims: list[int] - if out is None: - output_dims = list(input_dims) - output_dims[0] = -1 - output_dims[-1] = weight_dims[0] + # Check input tensor + x_local = input + x = None + x_async = None + with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel + own_quantized_x_local = False + if with_quantized_compute: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + input_quantizer.set_usage(rowwise=True) + if with_x_all_gather: + input_quantizer.set_usage(columnwise=False) + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + quantizer=input_quantizer, + ) + else: + if not isinstance(x_local, QuantizedTensor): + x_local = input_quantizer(x_local) + own_quantized_x_local = True + x = x_local else: - output_dims = list(out.size()) - if len(output_dims) == 0 or weight_dims[0] != output_dims[-1]: + if isinstance(x_local, QuantizedTensor): + x_local = x_local.dequantize() + if x_local.dtype != dtype: + x_local = x_local.to(dtype=dtype) + if with_x_all_gather: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + ) + else: + x = x_local + + # Check weight tensor + w = weight + w_is_quantized = isinstance(w, QuantizedTensor) + if with_quantized_compute and not w_is_quantized: + if weight_quantizer is None: + raise ValueError("Missing quantizer for weight tensor") + weight_quantizer.set_usage(rowwise=True) + w = weight_quantizer(w) + elif not with_quantized_compute and w_is_quantized: + w = w.dequantize() + if not with_quantized_compute and w.dtype != dtype: + w = w.to(dtype=dtype) + + # Check output tensor + y = out + if y is None: + if not with_quantized_compute: + output_quantizer = None + if tensor_parallel_mode == "row": + output_quantizer = None + elif isinstance(y, QuantizedTensor): + if not with_quantized_compute: + raise ValueError("Output tensor is quantized, but quantized compute is not enabled") + if tensor_parallel_mode == "row": raise ValueError( - f"Output tensor (shape={output_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" + "Output tensor is quantized, " + "but row tensor parallelism does not support quantized output" ) + if output_quantizer is None: + output_quantizer = getattr(y, "_quantizer", None) + if output_quantizer is None: + raise ValueError("Output tensor is quantized, but quantizer was not provided") + else: + output_quantizer = None + if isinstance(output_quantizer, MXFP8Quantizer): + raise RuntimeError( + "Attempting to generate MXFP8 output tensor, " + "but GEMM with MXFP8 output is not supported" + ) + if output_quantizer is not None: + output_quantizer.set_usage(rowwise=True, columnwise=False) # Check if accumulating into output tensor if accumulate_into_out: - if out is None: + if y is None: raise ValueError( "Attempted to accumulate into output tensor without providing output tensor" ) @@ -448,181 +497,22 @@ def _functional_forward( "Accumulating into output tensor is not supported with row tensor parallelism" ) - # Check if FP8 is enabled - if with_fp8_compute: - if input_fp8_meta is None and not is_float8_tensor(input): - raise ValueError("No FP8 metadata was provided for casting input to FP8") - if weight_fp8_meta is None and not is_float8_tensor(weight): - raise ValueError("No FP8 metadata was provided for casting weight to FP8") - else: - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - with_fp8_output = with_fp8_compute and tensor_parallel_mode != "row" - if out is None: - with_fp8_output = with_fp8_output and output_fp8_meta is not None - else: - if is_float8_tensor(out): - if not with_fp8_output: - raise ValueError( - "Output tensor is a Float8Tensor, but FP8 output is not supported" - ) - out._reset_caches() - else: - with_fp8_output = False - - # Check input tensor - x_local = reshape( - input, - (-1, input_dims[-1]), - device=device, - dtype=dtype, - ) - if with_fp8_compute and not is_float8_tensor(x_local): - fp8_dtype = get_fp8_te_dtype( - input_fp8_meta["recipe"], - fprop_tensor=True, - ) - with_transpose_cache = weight.requires_grad - if tensor_parallel_mode == "column" and sequence_parallel: - with_transpose_cache = False - x_local = Float8Tensor.to_float8( - x_local, - fp8_meta=input_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=with_transpose_cache, - ) - elif not with_fp8_compute and is_float8_tensor(x_local): - x_local = x_local.dequantize() - x = x_local + # Synchronize communication for input + _wait_async(x_async) x_async = None - if tensor_parallel_mode == "column" and sequence_parallel: - x, x_async = gather_along_first_dim( - x_local, - tensor_parallel_group, - async_op=True, - ) - - # Check weight tensor - w = convert_tensor( - weight, - device=device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - if with_fp8_compute and not is_float8_tensor(w): - fp8_dtype = get_fp8_te_dtype( - weight_fp8_meta["recipe"], - fprop_tensor=True, - ) - w = Float8Tensor.to_float8( - w, - fp8_meta=weight_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) - elif not with_fp8_compute and is_float8_tensor(w): - w = w.dequantize() - - # Check bias tensor - b = None - if bias is not None: - b = convert_tensor( - bias, - device=device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - - # Construct output tensor - y = None - if out is not None: - y = reshape(out, (-1, output_dims[-1])) - elif with_fp8_output: - fp8_dtype = get_fp8_te_dtype( - output_fp8_meta["recipe"], - fprop_tensor=True, - ) - data = torch.empty( - (x.size(0), weight_dims[0]), - dtype=torch.uint8, - device=device, - ) - y = Float8Tensor( - data=data, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - y = torch.empty( - (x.size(0), weight_dims[0]), - dtype=dtype, - device=device, - ) # Perform GEMM - _wait_async(x_async) - x_async = None - if with_fp8_compute: - kwargs = { - "accumulate": accumulate_into_out, - "out": y, - "bias": b, - "use_bias": (b is not None), - } - if with_fp8_output: - if y._fp8_meta is None: - # Hackily create FP8TensorMeta if needed - fp8_meta = FP8TensorMeta() - fp8_meta.scale = y._scale_inv.reciprocal() - fp8_meta.amax_history = torch.empty(1, 1, dtype=torch.float32, device=device) - fp8_meta.scale_inv = y._scale_inv - fp8_meta_index = 0 - else: - # Get FP8TensorMeta from Float8Tensor - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=y._fp8_meta_forward, - ) - fp8_meta = y._fp8_meta[fp8_meta_key] - fp8_meta_index = y._fp8_meta_index - kwargs.update( - { - "out": y._data, - "out_index": fp8_meta_index, - "fp8_meta_tensor": fp8_meta, - "D_dtype": y._fp8_dtype, - } - ) - fp8_gemm( - w._data, - w._scale_inv, - 0, - w._fp8_dtype, - x._data, - x._scale_inv, - 0, - x._fp8_dtype, - y.dtype, - get_workspace(), - **kwargs, - ) - else: - gemm( - w, - x, - y.dtype, - get_workspace(), - accumulate=accumulate_into_out, - out=y, - bias=b, - use_bias=(b is not None), - ) + y, *_ = general_gemm( + w, + x, + get_workspace(), + out_dtype=dtype, + quantization_params=output_quantizer, + accumulate=accumulate_into_out, + out=y, + bias=bias, + use_split_accumulator=_2X_ACC_FPROP, + ) # Reduce tensor-parallel output if needed if tensor_parallel_mode == "row": @@ -631,23 +521,27 @@ def _functional_forward( else: torch.distributed.all_reduce(y, group=tensor_parallel_group) - # Reshape output tensor if needed - if out is None: - out = reshape(y, output_dims) + # Configure input tensor for backward pass + if own_quantized_x_local: + x_local.update_usage(rowwise_usage=False) + + # Detach input tensor if needed + # Note: PyTorch autograd produces esoteric errors if we save + # input tensor as context for backward pass. + if x_local is input: + x_local = x_local.detach() - return out, x_local, w + return y, x_local, w @staticmethod def _functional_backward( grad_output: torch.Tensor, input: Optional[torch.Tensor], # pylint: disable=redefined-builtin weight: Optional[torch.Tensor], - input_dims: Iterable[int], - weight_dims: Iterable[int], *, input_requires_grad: bool = True, weight_requires_grad: bool = True, - device: Optional[torch.device] = None, + device: Optional[torch.device] = None, # pylint: disable=unused-argument dtype: Optional[torch.dtype] = None, grad_weight: Optional[torch.Tensor] = None, accumulate_into_grad_weight: bool = False, @@ -656,11 +550,11 @@ def _functional_backward( tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, - with_fp8_compute: bool = False, - input_fp8_meta: Optional[dict[str, Any]] = None, - weight_fp8_meta: Optional[dict[str, Any]] = None, - grad_output_fp8_meta: Optional[dict[str, Any]] = None, - grad_input_fp8_meta: Optional[dict[str, Any]] = None, + with_quantized_compute: bool = False, + input_quantizer: Optional[Quantizer] = None, + weight_quantizer: Optional[Quantizer] = None, + grad_output_quantizer: Optional[Quantizer] = None, + grad_input_quantizer: Optional[Quantizer] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Functional API for backward pass @@ -674,10 +568,6 @@ def _functional_backward( weight: torch.Tensor, optional Weight tensor. Required to compute loss gradient w.r.t. input. - input_dims: iterable of int - Input tensor dimensions - weight_dims: iterable of int - Weight tensor dimensions input_requires_grad: bool Whether to compute loss gradient w.r.t. input tensor weight_requires_grad: bool @@ -703,21 +593,18 @@ def _functional_backward( parallelism, i.e. distributing input or output tensors along outer dimension (sequence or batch dim) when not distributing along inner dimension (embedding dim) - with_fp8_compute: bool, default = `False` - Whether to perform compute in FP8 - input_fp8_meta: dict, optional - FP8 metadata for casting input tensor to FP8. Required for - FP8 compute if input is not already in FP8. - weight_fp8_meta: dict, optional - FP8 metadata for casting weight tensor to FP8. Required for - FP8 compute if weight is not already in FP8. - grad_output_fp8_meta: dict, optional - FP8 metadata for casting loss gradient w.r.t. output - tensor to FP8. Required if output grad is not already in - FP8. - grad_input_fp8_meta: dict, optional - FP8 metadata for casting loss gradient w.r.t. input - tensor to FP8 + with_quantized_compute: bool, default = `False` + Whether to perform compute with quantized data. + input_quantizer: Quantizer, optional + Builder class for quantized input tensor. + weight_quantizer: Quantizer, optional + Builder class for quantized weight tensor. + grad_output_quantizer: Quantizer, optional + Builder class for quantized loss gradient w.r.t. output + tensor. + grad_input_quantizer: dict, optional + Builder class for quantized loss gradient w.r.t. input + tensor. Returns ------- @@ -728,13 +615,6 @@ def _functional_backward( """ - # Check device - if device is None: - device = weight.device - device = canonicalize_device(device) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - # Check datatype if dtype is None: dtype = weight.dtype @@ -742,109 +622,42 @@ def _functional_backward( if dtype not in (torch.float32, torch.float16, torch.bfloat16): raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})") - # Check tensor dims - output_dims = tuple(grad_output.size()) - input_dims = tuple(input_dims) - weight_dims = tuple(weight_dims) - if len(weight_dims) != 2: - raise ValueError(f"Weight tensor is not 2D (shape={weight_dims})") - if len(input_dims) == 0 or weight_dims[1] != input_dims[-1]: - raise ValueError( - f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - if weight_dims[0] != output_dims[-1]: - raise ValueError( - f"Grad output tensor (shape={output_dims}) " - f"and weight tensor (shape={weight_dims}) " - "are not compatible" - ) - if grad_input is not None and tuple(grad_input.size()) != input_dims: - raise ValueError( - f"Grad input tensor (shape={tuple(grad_input.size())}) " - f"does not match expected shape ({input_dims})" - ) - - # Check grad input tensor - if not input_requires_grad: - grad_input = None - if grad_input is not None and not devices_match(grad_input.device, device): - raise ValueError( - f"Grad input tensor has invalid device (expected {device}, got {grad_input.device})" - ) - if grad_input is not None and grad_input.dtype != dtype: - raise ValueError( - f"Grad input tensor has invalid dtype (expected {dtype}, got {grad_input.dtype})" + # Check grad output tensor + dy_local = grad_output + dy = None + dy_async = None + with_dy_all_gather = tensor_parallel_mode == "row" and sequence_parallel + if with_quantized_compute: + if grad_output_quantizer is None: + raise ValueError("Missing quantizer for grad output tensor") + grad_output_quantizer.set_usage( + rowwise=input_requires_grad, + columnwise=weight_requires_grad, ) - if accumulate_into_grad_input: - if grad_input is None: - raise ValueError( - "Attempted to accumulate into grad input tensor " - "without providing grad input tensor" - ) - if tensor_parallel_mode == "column": - raise ValueError( - "Accumulating into grad input tensor " - "is not supported with column tensor parallelism" + if with_dy_all_gather: + dy, dy_async = gather_along_first_dim( + dy_local, + tensor_parallel_group, + async_op=True, + quantizer=grad_output_quantizer, ) - - # Check if FP8 is enabled - if with_fp8_compute: - if grad_output_fp8_meta is None and not is_float8_tensor(grad_output): - raise ValueError("No FP8 metadata was provided for casting output gradient to FP8") - else: - input_fp8_meta = None - weight_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - with_fp8_grad_input = ( - with_fp8_compute and input_requires_grad and tensor_parallel_mode != "column" - ) - if grad_input is None: - with_fp8_grad_input = with_fp8_grad_input and grad_input_fp8_meta is not None + else: + if not isinstance(dy_local, QuantizedTensor): + dy_local = grad_output_quantizer(dy_local) + dy = dy_local else: - if is_float8_tensor(grad_input): - if not with_fp8_grad_input: - raise ValueError( - "Grad input tensor is a Float8Tensor, but FP8 output is not supported" - ) - grad_input._reset_caches() + if isinstance(dy_local, QuantizedTensor): + dy_local = dy_local.dequantize() + if dy_local.dtype != dtype: + dy_local = dy_local.to(dtype=dtype) + if with_dy_all_gather: + dy, dy_async = gather_along_first_dim( + dy_local, + tensor_parallel_group, + async_op=True, + ) else: - with_fp8_grad_input = False - - # Check grad output tensor - dy_async = None - dy = reshape( - grad_output, - (-1, output_dims[-1]), - device=device, - dtype=dtype, - ) - if with_fp8_compute and not is_float8_tensor(dy): - fp8_dtype = get_fp8_te_dtype( - grad_output_fp8_meta["recipe"], - fprop_tensor=False, - ) - with_transpose_cache = weight_requires_grad - if tensor_parallel_mode == "row" and sequence_parallel: - with_transpose_cache = False - dy = Float8Tensor.to_float8( - dy, - fp8_meta=grad_output_fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=with_transpose_cache, - ) - elif not with_fp8_compute and is_float8_tensor(dy): - dy = dy.dequantize() - if tensor_parallel_mode == "row" and sequence_parallel: - dy, dy_async = gather_along_first_dim( - dy, - tensor_parallel_group, - async_op=True, - ) + dy = dy_local # Check input tensor x = None @@ -852,35 +665,36 @@ def _functional_backward( if weight_requires_grad: if input is None: raise ValueError("Input tensor is required to compute weight grad") - x_local = reshape( - input, - (-1, input_dims[-1]), - device=device, - dtype=dtype, - ) - x_is_sharded = tensor_parallel_mode == "column" and sequence_parallel - if with_fp8_compute and not is_float8_tensor(x_local): - fp8_dtype = get_fp8_te_dtype( - input_fp8_meta["recipe"], - fprop_tensor=True, - ) - x_local = Float8Tensor.to_float8( - x_local, - fp8_meta=input_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=(not x_is_sharded), - ) - elif not with_fp8_compute and is_float8_tensor(x_local): - x_local = x_local.from_float8() - x = x_local - if x_is_sharded: - x, x_async = gather_along_first_dim( - x_local, - tensor_parallel_group, - async_op=True, - ) + x_local = input + with_x_all_gather = tensor_parallel_mode == "column" and sequence_parallel + if with_quantized_compute: + if input_quantizer is None: + raise ValueError("Missing quantizer for input tensor") + input_quantizer.set_usage(columnwise=True) + if with_x_all_gather: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + quantizer=input_quantizer, + ) + else: + if not isinstance(x_local, QuantizedTensor): + x_local = input_quantizer(x_local) + x = x_local + else: + if isinstance(x_local, QuantizedTensor): + x_local = x_local.dequantize() + if x_local.dtype != dtype: + x_local = x_local.to(dtype=dtype) + if with_x_all_gather: + x, x_async = gather_along_first_dim( + x_local, + tensor_parallel_group, + async_op=True, + ) + else: + x = x_local # Compute grad input dx = None @@ -890,110 +704,80 @@ def _functional_backward( # Check weight tensor if weight is None: raise ValueError("Weight tensor is required to compute input grad") - w = convert_tensor( - weight, - device=device, - dtype=dtype, - memory_format=torch.contiguous_format, - ) - if with_fp8_compute and not is_float8_tensor(w): - fp8_dtype = get_fp8_te_dtype( - weight_fp8_meta["recipe"], - fprop_tensor=True, - ) - w = Float8Tensor.to_float8( - w, - fp8_meta=weight_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - with_transpose_cache=True, - ) - elif not with_fp8_compute and is_float8_tensor(w): + w = weight + w_is_quantized = isinstance(w, QuantizedTensor) + if with_quantized_compute and not w_is_quantized: + if weight_quantizer is None: + raise ValueError("Missing quantizer for weight tensor") + weight_quantizer.set_usage(columnwise=True) + w = weight_quantizer(w) + elif not with_quantized_compute and w_is_quantized: w = w.dequantize() + if not with_quantized_compute and w.dtype != dtype: + w = w.to(dtype=dtype) - # Construct grad input tensor - if grad_input is not None: - dx = reshape(grad_input, (-1, input_dims[-1])) - elif with_fp8_grad_input: - fp8_dtype = get_fp8_te_dtype( - grad_input_fp8_meta["recipe"], - fprop_tensor=False, - ) - data = torch.empty( - (dy.size(0), weight_dims[1]), - dtype=torch.uint8, - device=device, - ) - dx = Float8Tensor( - data=data, - fp8_meta=grad_input_fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - dx = torch.empty( - (dy.size(0), weight_dims[1]), - dtype=dtype, - device=device, - ) - - # Perform dgrad GEMM + # Synchronize tensor-parallel communication _wait_async(dy_async) dy_async = None - if with_fp8_compute: - kwargs = {"accumulate": accumulate_into_grad_input, "out": dx} - if with_fp8_grad_input: - if dx._fp8_meta is None: - # Hackily create FP8TensorMeta if needed - fp8_meta = FP8TensorMeta() - fp8_meta.scale = dx._scale_inv.reciprocal() - fp8_meta.amax_history = torch.empty( - 1, 1, dtype=torch.float32, device=device - ) - fp8_meta.scale_inv = dx._scale_inv - fp8_meta_index = 0 - else: - # Get FP8TensorMeta from Float8Tensor - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dx._fp8_meta_forward, - ) - fp8_meta = dx._fp8_meta[fp8_meta_key] - fp8_meta_index = dx._fp8_meta_index - kwargs.update( - { - "out": dx._data, - "out_index": fp8_meta_index, - "fp8_meta_tensor": fp8_meta, - "D_dtype": dx._fp8_dtype, - } + + # Check grad input tensor + dx = grad_input + if dx is None: + if not with_quantized_compute: + grad_input_quantizer = None + if tensor_parallel_mode == "column": + grad_input_quantizer = None + elif isinstance(dx, QuantizedTensor): + if not with_quantized_compute: + raise ValueError( + "Grad input tensor is quantized, but quantized compute is not enabled" + ) + if tensor_parallel_mode == "column": + raise ValueError( + "Grad input tensor is quantized, " + "but column tensor parallelism does not support quantized grad input" + ) + if grad_input_quantizer is None: + grad_input_quantizer = getattr(dx, "_quantizer", None) + if grad_input_quantizer is None: + raise ValueError( + "Grad input tensor is quantized, but quantizer was not provided" ) - fp8_gemm( - w.transpose_2d(), - w._scale_inv, - 0, - w._fp8_dtype, - dy._data, - dy._scale_inv, - 0, - dy._fp8_dtype, - dx.dtype, - get_workspace(), - **kwargs, - ) else: - gemm( - w, - dy, - dx.dtype, - get_workspace(), - accumulate=accumulate_into_grad_input, - layout="NN", - out=dx, + grad_input_quantizer = None + if isinstance(grad_input_quantizer, MXFP8Quantizer): + raise RuntimeError( + "Attempting to generate MXFP8 grad input tensor, " + "but GEMM with MXFP8 output is not supported" ) + # Check if accumulating into grad input tensor + if accumulate_into_grad_input: + if dx is None: + raise ValueError( + "Attempted to accumulate into grad input tensor " + "without providing grad input tensor" + ) + if tensor_parallel_mode == "column": + raise ValueError( + "Accumulating into grad input tensor " + "is not supported with column tensor parallelism" + ) + + # Perform dgrad GEMM + dx, *_ = general_gemm( + w, + dy, + get_workspace(), + out_dtype=dtype, + quantization_params=grad_input_quantizer, + accumulate=accumulate_into_grad_input, + layout="NN", + out=dx, + use_split_accumulator=_2X_ACC_DGRAD, + grad=True, + ) + # Reduce tensor-parallel grad input if needed if tensor_parallel_mode == "column": if sequence_parallel: @@ -1009,59 +793,46 @@ def _functional_backward( async_op=True, ) - # Perform wgrad GEMM - if not weight_requires_grad: - grad_weight = None - else: - if grad_weight is None: + # Compute grad weight + dw = None + if weight_requires_grad: + + # Synchronize tensor-parallel communication + _wait_async(x_async) + _wait_async(dy_async) + x_async = None + dy_async = None + + # Check grad input tensor + dw = grad_weight + dw_dtype = dtype + if dw is None: if accumulate_into_grad_weight: raise ValueError( - "Attempted to accumulate into grad weight buffer" - "without providing grad weight" + "Attempted to accumulate into grad weight tensor " + "without providing grad weight tensor" ) - grad_weight = torch.empty( - weight_dims, - dtype=dtype, - device=device, - memory_format=torch.contiguous_format, - ) - _wait_async(dy_async) - _wait_async(x_async) - dy_async = None - x_async = None - if with_fp8_compute: - fp8_gemm( - x.transpose_2d(), - x._scale_inv, - 0, - x._fp8_dtype, - dy.transpose_2d(), - dy._scale_inv, - 0, - dy._fp8_dtype, - grad_weight.dtype, - get_workspace(), - accumulate=accumulate_into_grad_weight, - out=grad_weight, - ) else: - gemm( - x, - dy, - x.dtype, - get_workspace(), - accumulate=accumulate_into_grad_weight, - layout="NT", - out=grad_weight, - ) + dw_dtype = dw.dtype + + # Perform wgrad GEMM + dw, *_ = general_gemm( + x, + dy, + get_workspace(), + out_dtype=dw_dtype, + accumulate=accumulate_into_grad_weight, + layout="NT", + out=dw, + use_split_accumulator=_2X_ACC_WGRAD, + grad=True, + ) # Clean up and return grads _wait_async(dy_async) _wait_async(x_async) _wait_async(dx_async) - if dx is not None and grad_input is None: - grad_input = reshape(dx, input_dims) - return grad_input, grad_weight + return dx, dw def op_forward( self, @@ -1071,21 +842,33 @@ def op_forward( next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: + # Check which grads are required + input_requires_grad = ctx.requires_grad and input_.requires_grad + weight_requires_grad = ctx.requires_grad and self.weight.requires_grad + # FP8 metadata - with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - if with_fp8_compute: - input_fp8_meta = self.get_fp8_meta("input") - weight_fp8_meta = self.get_fp8_meta("param") - if next_op is not None and next_op.num_fp8_scales("input") > 0: - output_fp8_meta = next_op.get_fp8_meta("input") - grad_output_fp8_meta = self.get_fp8_meta("grad_output") - if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: - grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + input_quantizer = None + weight_quantizer = None + output_quantizer = None + grad_output_quantizer = None + grad_input_quantizer = None + if with_quantized_compute: + + # Get quantizers + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + if next_op is not None and next_op.num_quantizers("forward") > 0: + output_quantizer = next_op.get_quantizer("forward", 0) + grad_output_quantizer = self.get_quantizer("backward", 0) + if prev_op is not None and prev_op.num_quantizers("backward") > 0: + grad_input_quantizer = prev_op.get_quantizer("backward", 0) + + # Configure quantizers + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + input_quantizer.set_usage(columnwise=weight_requires_grad) + weight_quantizer.set_usage(columnwise=False) # Get autocast dtype if needed dtype = None @@ -1096,27 +879,26 @@ def op_forward( output, x_local, _ = BasicLinear._functional_forward( input=input_, weight=self.weight, - device=self.device, dtype=dtype, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, - with_fp8_compute=with_fp8_compute, - input_fp8_meta=input_fp8_meta, - weight_fp8_meta=weight_fp8_meta, - output_fp8_meta=output_fp8_meta, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, ) # Save state for backward pass ctx.save_for_backward(x_local) - ctx.with_fp8_compute = with_fp8_compute - ctx.weight_fp8_meta = weight_fp8_meta - ctx.grad_output_fp8_meta = grad_output_fp8_meta - ctx.grad_input_fp8_meta = grad_input_fp8_meta + ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizer = input_quantizer + ctx.weight_quantizer = weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer ctx.dtype = dtype - ctx.input_dims = input_.size() - ctx.input_requires_grad = input_.requires_grad - ctx.weight_requires_grad = self.weight.requires_grad + ctx.input_requires_grad = input_requires_grad + ctx.weight_requires_grad = weight_requires_grad ctx.has_prev_op = prev_op is not None return output @@ -1149,21 +931,19 @@ def op_backward( grad_output=grad_output, input=x_local, weight=self.weight, - input_dims=ctx.input_dims, - weight_dims=self.weight.size(), input_requires_grad=ctx.input_requires_grad, weight_requires_grad=ctx.weight_requires_grad, - device=self.device, dtype=ctx.dtype, grad_weight=grad_weight, accumulate_into_grad_weight=accumulate_into_main_grad, tensor_parallel_mode=self.tensor_parallel_mode, tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, - with_fp8_compute=ctx.with_fp8_compute, - weight_fp8_meta=ctx.weight_fp8_meta, - grad_output_fp8_meta=ctx.grad_output_fp8_meta, - grad_input_fp8_meta=ctx.grad_input_fp8_meta, + with_quantized_compute=ctx.with_quantized_compute, + input_quantizer=ctx.input_quantizer, + weight_quantizer=ctx.weight_quantizer, + grad_output_quantizer=ctx.grad_output_quantizer, + grad_input_quantizer=ctx.grad_input_quantizer, ) # Clear input tensor if possible diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index eac1865566..5a73ec6c25 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/identity.py b/transformer_engine/pytorch/ops/basic/identity.py index 73179c68a6..d0466be15e 100644 --- a/transformer_engine/pytorch/ops/basic/identity.py +++ b/transformer_engine/pytorch/ops/basic/identity.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 99c9c493db..c5897486e3 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -13,14 +13,15 @@ import torch from transformer_engine_torch import layernorm_bwd, layernorm_fwd -from ...cpp_extensions import ( - layernorm_fwd_fp8, - layernorm_fwd_fp8_inf, - layernorm_fwd_inf, +from ...fp8 import FP8GlobalStateManager +from ...constants import TE_DType +from ...tensor import QuantizedTensor +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, ) -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor -from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, reshape @@ -84,28 +85,23 @@ def __init__( normalized_shape = (normalized_shape,) else: normalized_shape = tuple(normalized_shape) - self._shape: tuple[int, ...] = normalized_shape # Parameter device defer_param_init = False device = canonicalize_device(device) if device.type == "meta": defer_param_init = True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device # Initialize parameters if needed dtype = canonicalize_dtype(dtype) weight = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=dtype, ) bias = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=dtype, ) weight = torch.nn.Parameter(weight) @@ -143,17 +139,18 @@ def getenv(name: str) -> int: def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight bias = self.bias - if weight.device.type != "cuda": - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) - if bias.device.type != "cuda": - bias = torch.empty_like(bias, device=self.device) - else: - bias = bias.to(device=self.device) + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize param buffers + if not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) + if not devices_match(bias.device, device): + bias = torch.empty_like(bias, device=device) # Initialize values if self.zero_centered_gamma: @@ -184,17 +181,21 @@ def op_forward( ) -> torch.Tensor: # Check tensor dims + weight = self.weight + weight_dims = tuple(weight.size()) input_dims = tuple(input_.size()) - if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape: + if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: raise ValueError( f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={self._shape}) are not compatible" + f"and weight tensor (shape={weight_dims}) are not compatible" ) # Check input tensors - inner_dim = math.prod(self._shape) - device = self.device - dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype) + inner_dim = math.prod(weight_dims) + device = weight.device + if device.type != "cuda": + device = canonicalize_device(None) + dtype = maybe_autocast_dtype(default_dtype=weight.dtype) x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype) @@ -208,64 +209,33 @@ def op_forward( # Check if backward pass is needed requires_grad = ctx.requires_grad - # Check if FP8 is enabled - with_fp8_output = ( + # Check if output is quantized + output_quantizer = None + if ( FP8GlobalStateManager.is_fp8_enabled() and next_op is not None - and next_op.num_fp8_scales("input") > 0 - ) - output_fp8_meta = None - if with_fp8_output: - output_fp8_meta = next_op.get_fp8_meta("input") + and next_op.num_quantizers("forward") > 0 + ): + output_quantizer = next_op.get_quantizer("forward", 0) # Compute layer norm - y = None - means = None - rstdevs = None sm_margin = self._sm_margins["forward" if requires_grad else "inference"] - if with_fp8_output: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True) - args = ( - x, - w, - b, - self.eps, - output_fp8_meta[fp8_meta_key], - 0, # fp8_meta_index - fp8_dtype, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - data, means, rstdevs = layernorm_fwd_fp8(*args) - else: - data = layernorm_fwd_fp8_inf(*args) - y = Float8Tensor( - data=data, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - args = ( - x, - w, - b, - self.eps, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - y, means, rstdevs = layernorm_fwd(*args) - else: - y = layernorm_fwd_inf(*args) + y, means, rstdevs = layernorm_fwd( + x, + w, + b, + self.eps, + None, + output_quantizer, + TE_DType[dtype], + sm_margin, + self.zero_centered_gamma, + ) # Save state for backward pass if requires_grad: ctx.save_for_backward(x, means, rstdevs) + ctx.device = device ctx.dtype = dtype ctx.has_prev_op = prev_op is not None @@ -282,9 +252,12 @@ def op_backward( # Saved tensors from forward pass x, means, rstdevs = ctx.saved_tensors + # Tensor dims + weight_dims = self.weight.size() + inner_dim = math.prod(weight_dims) + # Check input tensors - inner_dim = x.size(-1) - device = self.device + device = ctx.device dtype = ctx.dtype dy = reshape(grad_output, x.size(), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) @@ -312,6 +285,6 @@ def op_backward( # Reshape results grad_input = reshape(dx, grad_output.size()) - grad_weight = reshape(dw, self._shape) - grad_bias = reshape(db, self._shape) + grad_weight = reshape(dw, weight_dims) + grad_bias = reshape(db, weight_dims) return grad_input, (grad_weight, grad_bias) diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py index db1651c184..73d08b5c7f 100644 --- a/transformer_engine/pytorch/ops/basic/make_extra_output.py +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 313b6e5583..448954fc69 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -9,8 +9,8 @@ import torch -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor +from ...fp8 import FP8GlobalStateManager +from ...tensor import QuantizedTensor from ..op import BasicOperation, OperationContext @@ -38,10 +38,10 @@ def __init__( self._quantize_forward = forward self._quantize_backward = backward - def num_fp8_scales(self, mode: str) -> int: - if mode == "input" and self._quantize_forward: + def num_quantizers(self, mode: str) -> int: + if mode == "forward" and self._quantize_forward: return 1 - if mode == "grad_output" and self._quantize_backward: + if mode == "backward" and self._quantize_backward: return 1 return 0 @@ -61,15 +61,7 @@ def op_forward( # Quantize if needed out = input_ if quantize_forward and not isinstance(out, QuantizedTensor): - fp8_meta = self.get_fp8_meta("input") - fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) - out = Float8Tensor.to_float8( - out, - fp8_meta=fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) + out = self.get_quantizer("forward", 0)(out) ctx.quantize_backward = quantize_backward return out @@ -81,13 +73,5 @@ def op_backward( ) -> tuple[torch.Tensor, tuple[()]]: grad_input = grad_output if ctx.quantize_backward and not isinstance(grad_input, QuantizedTensor): - fp8_meta = self.get_fp8_meta("grad_output") - fp8_dtype = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=False) - grad_input = Float8Tensor.to_float8( - grad_input, - fp8_meta=fp8_meta, - fp8_meta_forward=False, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - ) + grad_input = self.get_quantizer("backward", 0)(grad_input) return grad_input, () diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py index c78dbc2877..adfd46641b 100644 --- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -9,9 +9,9 @@ import torch -from ...tensor import Float8Tensor, QuantizedTensor +from ...distributed import gather_along_first_dim +from ...tensor import QuantizedTensor from ..op import BasicOperation, OperationContext -from .._common import convert_tensor class ReduceScatter(BasicOperation): @@ -45,7 +45,7 @@ def op_forward( # Trivial case if self.process_group_size == 1: - return input_ + return input_.detach() # Tensor dimensions input_dims = input_.size() @@ -74,47 +74,9 @@ def op_backward( ctx: OperationContext, grad_output: torch.Tensor, ) -> tuple[torch.Tensor, tuple[()]]: - - # Trivial case + grad_input: torch.Tensor if self.process_group_size == 1: - return grad_output, () - - # Tensor dimensions - output_dims = grad_output.size() - if not output_dims: - raise RuntimeError( - "Attempted to all-gather a tensor " - f"with shape={list(output_dims)} " - f"over {self.process_group_size} processes" - ) - input_dims = list(output_dims) - input_dims[0] *= self.process_group_size - - # Perform all-gather - dy = convert_tensor(grad_output, memory_format=torch.contiguous_format) - dx = None - if isinstance(dy, Float8Tensor): - dx = Float8Tensor.make_like( - dy, - data=torch.empty( - input_dims, - dtype=torch.uint8, - device=dy.device, - ), - ) - torch.distributed.all_gather_into_tensor( - dx._data, - dy._data, - group=self.process_group, - ) + grad_input = grad_output.detach() else: - if isinstance(dy, QuantizedTensor): - dy = dy.dequantize() - dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device) - torch.distributed.all_gather_into_tensor( - dx, - dy, - group=self.process_group, - ) - - return dx, () + grad_input, _ = gather_along_first_dim(grad_output, self.process_group) + return grad_input, () diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py index c3b1816635..1e9095169c 100644 --- a/transformer_engine/pytorch/ops/basic/reshape.py +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -14,7 +14,6 @@ BasicOperation, OperationContext, ) -from .._common import reshape class Reshape(BasicOperation): @@ -42,11 +41,11 @@ def op_forward( next_op: Optional[BasicOperation] = None, ) -> torch.Tensor: ctx.input_shape = input_.size() - return reshape(input_, self._shape) + return input_.reshape(*self._shape) def op_backward( self, ctx: OperationContext, grad_output: torch.Tensor, ) -> tuple[torch.Tensor, tuple[()]]: - return reshape(grad_output, ctx.input_shape), () + return grad_output.reshape(*ctx.input_shape), () diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 4f0e2ddc22..c1f32af93a 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -13,14 +13,15 @@ import torch from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd -from ...cpp_extensions import ( - rmsnorm_fwd_fp8, - rmsnorm_fwd_fp8_inf, - rmsnorm_fwd_inf, +from ...fp8 import FP8GlobalStateManager +from ...tensor import QuantizedTensor +from ...constants import TE_DType +from ...utils import ( + canonicalize_device, + canonicalize_dtype, + clear_tensor_data, + devices_match, ) -from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype -from ...tensor import Float8Tensor, QuantizedTensor -from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, reshape @@ -83,22 +84,17 @@ def __init__( normalized_shape = (normalized_shape,) else: normalized_shape = tuple(normalized_shape) - self._shape: tuple[int, ...] = normalized_shape # Parameter device defer_param_init = False device = canonicalize_device(device) if device.type == "meta": defer_param_init = True - device = canonicalize_device(None) - if device.type != "cuda": - raise ValueError(f"Only CUDA devices are supported (got {device})") - self.device: torch.device = device # Initialize parameters if needed weight = torch.empty( - self._shape, - device="meta", + normalized_shape, + device=device, dtype=canonicalize_dtype(dtype), ) weight = torch.nn.Parameter(weight) @@ -133,12 +129,15 @@ def getenv(name: str) -> int: def reset_parameters(self) -> None: """Initialize parameter buffers and values""" - # Make sure parameter is initialized + # Parameter device weight = self.weight - if weight.device.type != "cuda": - weight = torch.empty_like(weight, device=self.device) - else: - weight = weight.to(device=self.device) + device = weight.device + if device.type == "meta": + device = canonicalize_device(None) + + # Initialize param buffers + if not devices_match(weight.device, device): + weight = torch.empty_like(weight, device=device) # Initialize values if self.zero_centered_gamma: @@ -165,17 +164,21 @@ def op_forward( ) -> torch.Tensor: # Check tensor dims + weight = self.weight + weight_dims = tuple(weight.size()) input_dims = tuple(input_.size()) - if len(input_dims) < len(self._shape) or input_dims[-len(self._shape) :] != self._shape: + if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: raise ValueError( f"Input tensor (shape={input_dims}) " - f"and weight tensor (shape={self._shape}) are not compatible" + f"and weight tensor (shape={weight_dims}) are not compatible" ) # Check input tensors - inner_dim = math.prod(self._shape) - device = self.device - dtype = maybe_autocast_dtype(default_dtype=self.weight.dtype) + inner_dim = math.prod(weight_dims) + device = weight.device + if device.type != "cuda": + device = canonicalize_device(None) + dtype = maybe_autocast_dtype(default_dtype=weight.dtype) x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) if isinstance(x, QuantizedTensor): @@ -186,61 +189,32 @@ def op_forward( # Check if backward pass is needed requires_grad = ctx.requires_grad - # Check if FP8 is enabled - with_fp8_output = ( + # Check if output is quantized + output_quantizer = None + if ( FP8GlobalStateManager.is_fp8_enabled() and next_op is not None - and next_op.num_fp8_scales("input") > 0 - ) - output_fp8_meta = None - if with_fp8_output: - output_fp8_meta = next_op.get_fp8_meta("input") + and next_op.num_quantizers("forward") > 0 + ): + output_quantizer = next_op.get_quantizer("forward", 0) # Compute RMSNorm - y = None - rstdevs = None sm_margin = self._sm_margins["forward" if requires_grad else "inference"] - if with_fp8_output: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True) - fp8_dtype = get_fp8_te_dtype(output_fp8_meta["recipe"], fprop_tensor=True) - args = ( - x, - w, - self.eps, - output_fp8_meta[fp8_meta_key], - 0, # fp8_meta_index - fp8_dtype, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - data, rstdevs = rmsnorm_fwd_fp8(*args) - else: - data = rmsnorm_fwd_fp8_inf(*args) - y = Float8Tensor( - data=data, - fp8_meta=output_fp8_meta, - fp8_meta_forward=True, - fp8_meta_index=0, - fp8_dtype=fp8_dtype, - dtype=dtype, - ) - else: - args = ( - x, - w, - self.eps, - sm_margin, - self.zero_centered_gamma, - ) - if requires_grad: - y, rstdevs = rmsnorm_fwd(*args) - else: - y = rmsnorm_fwd_inf(*args) + y, _, rstdevs = rmsnorm_fwd( + x, + w, + self.eps, + None, + output_quantizer, + TE_DType[dtype], + sm_margin, + self.zero_centered_gamma, + ) # Save state for backward pass if requires_grad: ctx.save_for_backward(x, rstdevs) + ctx.device = device ctx.dtype = dtype ctx.has_prev_op = prev_op is not None @@ -257,9 +231,12 @@ def op_backward( # Saved tensors from forward pass x, rstdevs = ctx.saved_tensors + # Tensor dims + weight_dims = self.weight.size() + inner_dim = math.prod(weight_dims) + # Check input tensors - inner_dim = x.size(-1) - device = self.device + device = ctx.device dtype = ctx.dtype dy = reshape(grad_output, x.size(), device=device, dtype=dtype) w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype) @@ -285,5 +262,5 @@ def op_backward( # Reshape results grad_input = reshape(dx, grad_output.size()) - grad_weight = reshape(dw, self._shape) + grad_weight = reshape(dw, weight_dims) return grad_input, (grad_weight,) diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 08b9f06123..b9b5ec9508 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 123c560066..e295929e98 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -73,11 +73,8 @@ def fuser_backward( grad_output=grad_output, input=x_local, weight=linear_op.weight, - input_dims=linear_op_ctx.input_dims, - weight_dims=linear_op.weight.size(), input_requires_grad=linear_op_ctx.input_requires_grad, weight_requires_grad=linear_op_ctx.weight_requires_grad, - device=linear_op.device, dtype=grad_input.dtype, grad_weight=grad_weight, accumulate_into_grad_weight=accumulate_into_main_grad, @@ -86,10 +83,11 @@ def fuser_backward( tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, - with_fp8_compute=linear_op_ctx.with_fp8_compute, - weight_fp8_meta=linear_op_ctx.weight_fp8_meta, - grad_output_fp8_meta=linear_op_ctx.grad_output_fp8_meta, - grad_input_fp8_meta=linear_op_ctx.grad_input_fp8_meta, + with_quantized_compute=linear_op_ctx.with_quantized_compute, + input_quantizer=linear_op_ctx.input_quantizer, + weight_quantizer=linear_op_ctx.weight_quantizer, + grad_output_quantizer=linear_op_ctx.grad_output_quantizer, + grad_input_quantizer=linear_op_ctx.grad_input_quantizer, ) if accumulate_into_main_grad: grad_weight = None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 3afdc3a0c3..6088b3c0db 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -83,22 +83,22 @@ def fuser_forward( raise NotImplementedError("Activations are not yet supported") # FP8 metadata - with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - if with_fp8_compute: - input_fp8_meta = linear_op.get_fp8_meta("input") - weight_fp8_meta = linear_op.get_fp8_meta("param") + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + input_quantizer = None + weight_quantizer = None + output_quantizer = None + grad_output_quantizer = None + grad_input_quantizer = None + if with_quantized_compute: + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) next_op = basic_op_next_ops[-1] - if next_op is not None and next_op.num_fp8_scales("input") > 0: - output_fp8_meta = next_op.get_fp8_meta("input") - grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + if next_op is not None and next_op.num_quantizers("forward") > 0: + output_quantizer = next_op.get_quantizer("forward", 0) + grad_output_quantizer = linear_op.get_quantizer("backward", 0) prev_op = basic_op_prev_ops[0] - if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: - grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + if prev_op is not None and prev_op.num_quantizers("backward") > 0: + grad_input_quantizer = prev_op.get_quantizer("backward", 0) # Get autocast dtype if needed dtype = None @@ -110,25 +110,24 @@ def fuser_forward( input=input_, weight=linear_op.weight, bias=bias, - device=linear_op.device, dtype=dtype, tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, - with_fp8_compute=with_fp8_compute, - input_fp8_meta=input_fp8_meta, - weight_fp8_meta=weight_fp8_meta, - output_fp8_meta=output_fp8_meta, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, ) # Save state for backward pass linear_op_ctx.save_for_backward(x_local) - linear_op_ctx.with_fp8_compute = with_fp8_compute - linear_op_ctx.weight_fp8_meta = weight_fp8_meta - linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta - linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.dtype = dtype - linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 3d994d80f0..69b0c3ba5a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -77,19 +77,19 @@ def fuser_forward( raise ValueError("Bias operation forward does not expect keyword arguments") # FP8 metadata - with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled() - input_fp8_meta = None - weight_fp8_meta = None - output_fp8_meta = None - grad_output_fp8_meta = None - grad_input_fp8_meta = None - if with_fp8_compute: - input_fp8_meta = linear_op.get_fp8_meta("input") - weight_fp8_meta = linear_op.get_fp8_meta("param") - grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output") + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + input_quantizer = None + weight_quantizer = None + output_quantizer = None + grad_output_quantizer = None + grad_input_quantizer = None + if with_quantized_compute: + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) + grad_output_quantizer = linear_op.get_quantizer("backward", 0) prev_op = basic_op_prev_ops[0] - if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0: - grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output") + if prev_op is not None and prev_op.num_quantizers("backward") > 0: + grad_input_quantizer = prev_op.get_quantizer("backward", 0) # Get autocast dtype if needed dtype = None @@ -102,26 +102,25 @@ def fuser_forward( input=input_, weight=linear_op.weight, bias=bias, - device=linear_op.device, out=output, accumulate_into_out=True, tensor_parallel_mode=linear_op.tensor_parallel_mode, tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, - with_fp8_compute=with_fp8_compute, - input_fp8_meta=input_fp8_meta, - weight_fp8_meta=weight_fp8_meta, - output_fp8_meta=output_fp8_meta, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, ) # Save state for backward pass linear_op_ctx.save_for_backward(x_local) - linear_op_ctx.with_fp8_compute = with_fp8_compute - linear_op_ctx.weight_fp8_meta = weight_fp8_meta - linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta - linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer linear_op_ctx.dtype = dtype - linear_op_ctx.input_dims = input_.size() linear_op_ctx.input_requires_grad = input_.requires_grad linear_op_ctx.weight_requires_grad = linear_op.weight.requires_grad linear_op_ctx.has_prev_op = basic_op_prev_ops[0] is not None diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 907cff1c81..bbb27f86e6 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -1,9 +1,11 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Linear layer backward with Userbuffers communication.""" +# pylint: skip-file ### TODO Debug Userbuffers support + from __future__ import annotations from collections.abc import Iterable from typing import Any, Optional @@ -12,11 +14,7 @@ import torch from transformer_engine_torch import CommOverlapAlgo -from ...cpp_extensions import ( - fp8_cast_transpose_bgrad_fused, - fp8_gemm, - gemm, -) +from ...cpp_extensions import general_gemm from ...distributed import get_distributed_world_size from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype @@ -49,6 +47,9 @@ def __init__( reduce_scatter: Optional[ReduceScatter], ) -> None: + ### TODO Debug Userbuffers support + raise NotImplementedError("Userbuffers support has been broken by recent refactors") + # Basic operations that comprise this fused operation op_idxs = {"linear": None, "bias": None, "reduce_scatter": None} ops = [] @@ -706,6 +707,8 @@ def fuse_userbuffers_backward_linear( """ + return ops ### TODO Debug Userbuffers support + # Return immediately if environment is not distributed if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index a1b0ca6a9e..a08c0a6ef9 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -1,9 +1,11 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Linear layer forward with Userbuffers communication.""" +# pylint: skip-file ### TODO Debug Userbuffers support + from __future__ import annotations from collections.abc import Iterable from typing import Any, Optional @@ -11,7 +13,7 @@ import torch from transformer_engine_torch import CommOverlapAlgo -from ...cpp_extensions import fp8_gemm, gemm +from ...cpp_extensions import general_gemm from ...distributed import get_distributed_world_size from ...float8_tensor import Float8Tensor from ...fp8 import FP8GlobalStateManager, get_fp8_te_dtype @@ -49,6 +51,9 @@ def __init__( reduce_scatter: Optional[ReduceScatter], ) -> None: + ### TODO Debug Userbuffers support + raise NotImplementedError("Userbuffers support has been broken by recent refactors") + # Basic operations that comprise this fused operation op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None} ops = [linear] @@ -524,6 +529,8 @@ def fuse_userbuffers_forward_linear( """ + return ops ### TODO Debug Userbuffers support + # Return immediately if environment is not distributed if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 6fcb435e5c..7c638032f1 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -135,7 +135,11 @@ def forward( requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) for idx in basic_op_idxs: basic_op_ctxs[idx].requires_grad = requires_grad - x.requires_grad_(requires_grad=requires_grad) + if requires_grad != x.requires_grad: + if requires_grad: + x.requires_grad_() + else: + x = x.detach() # Forward op extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] @@ -188,7 +192,7 @@ def forward( func_ctx.backward_ops = backward_ops func_ctx.basic_ops = basic_ops func_ctx.basic_op_ctxs = basic_op_ctxs - func_ctx.num_params = num_params + func_ctx.basic_op_num_params = [sum(1 for _ in op.parameters()) for op in basic_ops] func_ctx.num_extra_inputs = num_extra_inputs func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() @@ -254,14 +258,14 @@ def backward( # Flatten list of parameter gradients grad_params_flat = [] for idx, dparams in enumerate(grad_params): - params = list(basic_ops[idx].parameters()) + num_params = func_ctx.basic_op_num_params[idx] if dparams is None: - dparams = [None for _ in range(len(params))] + dparams = [None for _ in range(num_params)] else: dparams = list(dparams) - if len(dparams) != len(params): + if len(dparams) != num_params: raise RuntimeError( - f"Expected op {idx} to generate {len(params)} param grads, " + f"Expected op {idx} to generate {num_params} param grads, " f"but got {len(dparams)}" ) grad_params_flat.extend(dparams) diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index 68472f171a..8ed2702a72 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 0bb6f25db8..8346d31a40 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -13,13 +13,14 @@ import torch -import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import ( - DelayedScaling, +from transformer_engine.common.recipe import Recipe +from ..fp8 import ( + MXFP8BlockScalingRecipeState, + DelayedScalingRecipeState, FP8GlobalStateManager, - get_default_fp8_recipe, + RecipeState, ) -from ._common import canonicalize_device, is_float8_tensor +from ..tensor import Quantizer @dataclasses.dataclass @@ -174,132 +175,148 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): def __init__(self) -> None: super().__init__() - # FP8 metadata objects + # Objects for quantization + self._quantizers: Optional[dict[str, list[Quantizer]]] = None self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None @property def is_fused_op(self) -> bool: return False - def num_fp8_scales( + def num_quantizers( self, mode: str, # pylint: disable=unused-argument ) -> int: - """Number of FP8 scaling factors + """Number of quantizers + + Matches number of quantized tensors used in operation. Parameters ---------- - mode: {"input", "param", "grad_output"} - Type of FP8 scaling factor + mode: {"forward", "backward"} + Quantizer type """ return 0 - def _make_fp8_metas(self) -> dict[str, Optional[dict[str, Any]]]: - """Construct FP8 metadata""" - - # Shared objects for FP8 metadata - dtype = torch.float32 - device = canonicalize_device(None) - recipe = get_default_fp8_recipe() - - def _make_meta( - num_scales: int, - is_forward: bool, - ) -> Optional[dict[str, Any]]: - """Construct FP8 metadata for one tensor type""" - if num_scales == 0: - return None - key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) - meta = tex.FP8TensorMeta() - meta.scale = torch.ones(num_scales, dtype=dtype, device=device) - meta.scale_inv = torch.ones(num_scales, dtype=dtype, device=device) - meta.amax_history = torch.zeros( - (recipe.amax_history_len, num_scales), - dtype=dtype, - device=device, + def _reset_quantization_recipe_state( + self, + *, + recipe: Optional[Recipe] = None, + ) -> None: + """Construct state for quantization recipe""" + + # Quantization recipe + if recipe is None: + recipe = FP8GlobalStateManager.get_fp8_recipe() + + # Quantization recipe state for forward and backward pass + self._fp8_metas = {"forward": None, "backward": None} + self._quantizers = {"forward": [], "backward": []} + for mode in ("forward", "backward"): + num_quantizers = self.num_quantizers(mode) + if num_quantizers == 0: + continue + + # Construct quantization recipe state + recipe_state = RecipeState.create( + recipe, + mode=mode, + num_quantizers=num_quantizers, ) - return { - key: meta, + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + self._fp8_metas[mode] = { + fp8_meta_key: recipe_state, "recipe": recipe, - "fp8_group": None, + "fp8_group": FP8GlobalStateManager.get_fp8_group(), } - # Construct FP8 metadata for all tensor types - return { - "input": _make_meta(self.num_fp8_scales("input"), True), - "param": _make_meta(self.num_fp8_scales("param"), True), - "grad_output": _make_meta(self.num_fp8_scales("grad_output"), False), - } - - @classmethod - def _maybe_update_fp8_meta( - cls, - fp8_meta: Optional[dict[str, Any]], + # Construct builder class for quantized tensors + self._quantizers[mode] = recipe_state.make_quantizers() + + def _update_quantization_recipe_state( + self, *, - fp8_recipe: Optional[DelayedScaling] = None, + recipe: Optional[Recipe] = None, ) -> None: - if fp8_meta is None: - return + """Make sure quantizer state matches quantization recipe""" - # Update FP8 recipe - if fp8_recipe is None: - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - fp8_meta["recipe"] = fp8_recipe + # Quantization recipe + if recipe is None: + recipe = FP8GlobalStateManager.get_fp8_recipe() - # Update FP8 communication group - fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() - - # Adjust amax history length if needed - amax_history_len = fp8_recipe.amax_history_len - for is_forward in (True, False): - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=is_forward) - if fp8_meta_key not in fp8_meta: + # Reset quantization state if needed + if self._fp8_metas is None or self._quantizers is None: + self._reset_quantization_recipe_state(recipe=recipe) + return + for mode in ("forward", "backward"): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]: continue - meta = fp8_meta[fp8_meta_key] - curr_len = meta.amax_history.size(0) - - # Nothing to be done if amax history is already correct - if curr_len == amax_history_len: + recipe_state = self._fp8_metas[mode][fp8_meta_key] + need_to_reset_recipe_state = ( + recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState) + ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) + if need_to_reset_recipe_state: + self._reset_quantization_recipe_state(recipe=recipe) + return + + # Quantization recipe state for forward and backward pass + for mode in ("forward", "backward"): + num_quantizers = self.num_quantizers(mode) + if num_quantizers == 0: continue - # Reallocate amax history - with torch.no_grad(): - if curr_len > amax_history_len: - meta.amax_history = meta.amax_history[:amax_history_len].clone() - else: - meta.amax_history = torch.nn.functional.pad( - meta.amax_history, - pad=(0, 0, 0, amax_history_len - curr_len), - ) + # Update FP8 metadata + fp8_meta = self._fp8_metas[mode] + fp8_meta["recipe"] = recipe + fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() - # Update global buffers for amax reductions - buffer_info_key = FP8GlobalStateManager.get_buffer_info() - if buffer_info_key in fp8_meta: - fwd_pos, fwd_key, bwd_pos, bwd_key = fp8_meta[buffer_info_key] - for pos, buffer_key in zip((fwd_pos, bwd_pos), (fwd_key, bwd_key)): - assert ( - buffer_key in FP8GlobalStateManager.global_amax_history_buffer - ), "TE internal error during amax history change." - FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = fp8_meta[ - fp8_meta_key - ].amax_history[0] - FP8GlobalStateManager.global_amax_history_buffer[buffer_key][pos] = fp8_meta[ - fp8_meta_key - ].amax_history - - def get_fp8_meta(self, mode: str) -> Optional[dict[str, Any]]: - """FP8 metadata + # Get recipe state + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + recipe_state = fp8_meta[fp8_meta_key] + + # Reallocate amax history if needed + if recipe.mxfp8(): + continue + + current_length = recipe_state.amax_history.size(0) + target_length = recipe.amax_history_len + if current_length != target_length: + with torch.no_grad(): + if target_length < current_length: + recipe_state.amax_history = recipe_state.amax_history[ + :target_length + ].clone() + else: + recipe_state.amax_history = torch.nn.functional.pad( + recipe_state.amax_history, + pad=(0, 0, 0, target_length - current_length), + ) + self._quantizers[mode] = recipe_state.make_quantizers() + + def get_quantizer( + self, + mode: str, + index: int, + ) -> Quantizer: + """Get builder class for quantized tensor Parameters ---------- - mode: {"input", "param", "grad_output"} - Type of FP8 scaling factor + mode: {"forward", "backward"} + Quantizer type """ - if self._fp8_metas is None: - self._fp8_metas = self._make_fp8_metas() - return self._fp8_metas[mode] + if self._quantizers is None: + self._reset_quantization_recipe_state() + return self._quantizers[mode][index] @torch.no_grad() def _save_fp8_metas(self) -> Optional[dict[str, Any]]: @@ -321,7 +338,6 @@ def _save_fp8_metas(self) -> Optional[dict[str, Any]]: continue out[mode][fp8_meta_key] = ( fp8_meta[fp8_meta_key].scale.clone(), - fp8_meta[fp8_meta_key].scale_inv.clone(), fp8_meta[fp8_meta_key].amax_history.clone(), ) return out @@ -346,16 +362,15 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: assert ( fp8_meta_key in self._fp8_metas[mode] ), f"Found an unexpected key ({mode=}, {fp8_meta_key=}) in saved FP8 metadata" - scale, scale_inv, amax_history = tensors + scale, amax_history = tensors self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) - self._fp8_metas[mode][fp8_meta_key].scale_inv.copy_(scale_inv) self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) def pre_forward( self, *, fp8_enabled: Optional[bool] = None, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, ) -> None: """Preprocessing before forward pass""" @@ -363,30 +378,15 @@ def pre_forward( if fp8_enabled is None: fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() if fp8_enabled: - - # Construct FP8 metadata if needed - if self._fp8_metas is None: - self._fp8_metas = self._make_fp8_metas() - - # Make sure FP8 metadata matches FP8 autocast context - for fp8_meta in self._fp8_metas.values(): - self._maybe_update_fp8_meta(fp8_meta, fp8_recipe=fp8_recipe) - - # Register FP8 metadata for amax and scale update + self._update_quantization_recipe_state(recipe=fp8_recipe) if not FP8GlobalStateManager.fp8_graph_capturing(): - if self.num_fp8_scales("input"): + if self.num_quantizers("forward"): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.get_fp8_meta("input"), + self._fp8_metas["forward"], ) - if self.num_fp8_scales("param"): - fp8_params = list(filter(is_float8_tensor, self.parameters())) + if self.num_quantizers("backward"): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.get_fp8_meta("param"), - fp8_weights=(fp8_params if fp8_params else None), - ) - if self.num_fp8_scales("grad_output"): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self.get_fp8_meta("grad_output"), + self._fp8_metas["backward"], ) @abc.abstractmethod @@ -505,7 +505,7 @@ def forward( basic_op_kwargs=[kwargs], ) - def get_extra_state(self) -> Optional[torch.Tensor]: + def get_extra_state(self) -> torch.Tensor: """Serialize extra state Contains metadata for FP8 casting. @@ -516,7 +516,7 @@ def get_extra_state(self) -> Optional[torch.Tensor]: # # (1) PyTorch's "extra state" infrastructure might be able to # support any picklable type, but they make no guarantees. - # It seems that ONNX export experiences issues with + # We have experienced problems (e.g. in ONNX export) with # non-tensor extra state. # (2) PyTorch's checkpointing infrastructure does not remap # devices for "extra state" like it does for "state dict". @@ -529,13 +529,6 @@ def get_extra_state(self) -> Optional[torch.Tensor]: # See: https://github.com/NVIDIA/TransformerEngine/pull/351 # See: https://github.com/NVIDIA/TransformerEngine/pull/363 - # Return immediately if op has no FP8 state - has_fp8_state = any( - self.num_fp8_scales(mode) > 0 for mode in ("input", "param", "grad_output") - ) - if not has_fp8_state: - return None - def to_cpu(src: torch.Tensor) -> torch.Tensor: """Helper function to make CPU copy of tensor @@ -549,25 +542,20 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Store FP8 state state = {} - for mode in ("input", "param", "grad_output"): + for mode in ("forward", "backward"): # Get state for a given FP8 tensor - if self.num_fp8_scales(mode) == 0: - state[mode] = None + if self.num_quantizers(mode) == 0: continue fp8_meta = self.get_fp8_meta(mode) - if fp8_meta is None: - continue state[mode] = {} # Store tensors if "scaling_fwd" in fp8_meta: state[mode]["scale_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale) - state[mode]["scale_inv_fwd"] = to_cpu(fp8_meta["scaling_fwd"].scale_inv) state[mode]["amax_history_fwd"] = to_cpu(fp8_meta["scaling_fwd"].amax_history) if "scaling_bwd" in fp8_meta: state[mode]["scale_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale) - state[mode]["scale_inv_bwd"] = to_cpu(fp8_meta["scaling_bwd"].scale_inv) state[mode]["amax_history_bwd"] = to_cpu(fp8_meta["scaling_bwd"].amax_history) # Store other picklable items @@ -588,12 +576,12 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: def set_extra_state(self, state: Optional[torch.Tensor]) -> None: """Load extra state""" - if state is None: + if state is None or state.numel() == 0: return # Deserialize state from byte tensor state = pickle.loads(state.detach().numpy(force=True).tobytes()) - if state is None: + if state is None or len(state) == 0: return def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: @@ -608,12 +596,12 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.copy_(src, non_blocking=True) # Load FP8 state - for mode in ("input", "param", "grad_output"): + for mode in ("forward", "backward"): # Get state for a given FP8 tensor if mode not in state: continue - if self.num_fp8_scales(mode) == 0: + if self.num_quantizers(mode) == 0: continue fp8_meta = self.get_fp8_meta(mode) if fp8_meta is None: @@ -633,12 +621,10 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: if "scaling_fwd" in fp8_meta: fp8_meta_fwd = fp8_meta["scaling_fwd"] copy_tensor(state[mode]["scale_fwd"], fp8_meta_fwd.scale) - copy_tensor(state[mode]["scale_inv_fwd"], fp8_meta_fwd.scale_inv) copy_tensor(state[mode]["amax_history_fwd"], fp8_meta_fwd.amax_history) if "scaling_bwd" in fp8_meta: fp8_meta_bwd = fp8_meta["scaling_bwd"] copy_tensor(state[mode]["scale_bwd"], fp8_meta_bwd.scale) - copy_tensor(state[mode]["scale_inv_bwd"], fp8_meta_bwd.scale_inv) copy_tensor(state[mode]["amax_history_bwd"], fp8_meta_bwd.amax_history) # Finish CPU-GPU memory transfers diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index 8d4fefb4c5..3240bd73d6 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/optimizers/__init__.py b/transformer_engine/pytorch/optimizers/__init__.py index fc9bdc304a..c76f75743d 100644 --- a/transformer_engine/pytorch/optimizers/__init__.py +++ b/transformer_engine/pytorch/optimizers/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 93f6191dfe..070f46e937 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -1,31 +1,31 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Fused Adam optimizer.""" +from __future__ import annotations +from collections.abc import Iterable from copy import deepcopy from itertools import chain +from typing import Optional +import warnings import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from .multi_tensor_apply import multi_tensor_applier -from ..float8_tensor import Float8Tensor def get_fp8_meta(fp8_tensor): """FP8 metadata getter.""" - if fp8_tensor._fp8_meta is None: - raise RuntimeError("FP8 meta data is not initialized.") + assert isinstance(fp8_tensor, Float8Tensor), "Fused optimizer supports only Float8Tensor class" + if fp8_tensor._quantizer is None: + raise RuntimeError("FP8 quantizer data is not initialized.") - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=fp8_tensor._fp8_meta_forward, - ) + quantizer = fp8_tensor._quantizer - fp8_meta_index = fp8_tensor._fp8_meta_index - scale = fp8_tensor._fp8_meta[fp8_meta_key].scale[fp8_meta_index] - amax = fp8_tensor._fp8_meta[fp8_meta_key].amax_history[0][fp8_meta_index] + scale = quantizer.scale + amax = quantizer.amax scale_inv = fp8_tensor._scale_inv return scale, amax, scale_inv @@ -56,8 +56,6 @@ class FusedAdam(torch.optim.Optimizer): params (iterable): iterable of parameters to optimize or dicts defining parameter groups. lr (float, optional): learning rate. (default: 1e-3) - bias_correction (bool, optional): apply correction factor to - moment estimates. (default: True) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve @@ -66,10 +64,10 @@ class FusedAdam(torch.optim.Optimizer): amsgrad (boolean, optional): whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`_ (default: False) NOT SUPPORTED in FusedAdam! + bias_correction (bool, optional): apply correction factor to + moment estimates. (default: True) adam_w_mode (boolean, optional): Apply L2 regularization or weight decay True for decoupled weight decay(also known as AdamW) (default: True) - set_grad_none (bool, optional): whether set grad to None when zero_grad() - method is called. (default: True) capturable (bool, optional): whether to use the version of the optimizer that can be used with CUDA Graphs. (default: False) master_weights (bool, optional): whether to maintain FP32 master weights @@ -94,6 +92,13 @@ class FusedAdam(torch.optim.Optimizer): instead of ".grad" for reading gradients. It's useful when the dtypes of grad and param are different. (default: False) + store_param_remainders (bool, optional): Whether to store entire FP32 master + params or just store the trailing 16 remainder bits. Whole FP32 master can be + reconstructed from BF16 params plus the trailing remainder bits. Works only + when param type is BF16 and master weight type is FP32, no effect otherwise. + Useful memory saving optimization. + (default: False) + .. _Adam - A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 @@ -103,21 +108,23 @@ class FusedAdam(torch.optim.Optimizer): def __init__( self, - params, - lr=1e-3, + params: Iterable[torch.nn.Parameter | dict], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + amsgrad: bool = False, + *, bias_correction=True, - betas=(0.9, 0.999), - eps=1e-8, adam_w_mode=True, - weight_decay=0.0, - amsgrad=False, - set_grad_none=True, capturable=False, master_weights=False, master_weight_dtype=torch.float32, exp_avg_dtype=torch.float32, exp_avg_sq_dtype=torch.float32, use_decoupled_grad=False, + store_param_remainders=False, + set_grad_none: Optional[bool] = None, # deprecated ): if amsgrad: @@ -142,6 +149,8 @@ def __init__( raise RuntimeError("Capturable mode only supports fp32 exp_avg.") if capturable and exp_avg_sq_dtype != torch.float32: raise RuntimeError("Capturable mode only supports fp32 exp_avg_sq") + if capturable and store_param_remainders: + raise RuntimeError("Capturable mode doesn't support storing param remainders") # If the optimizer is capturable then LR should be a tensor (on GPU) lr = torch.tensor(lr, dtype=torch.float32) if capturable else lr @@ -154,7 +163,6 @@ def __init__( } super().__init__(params, defaults) self.adam_w_mode = 1 if adam_w_mode else 0 - self.set_grad_none = set_grad_none self.capturable = capturable self.master_weights = master_weights @@ -172,6 +180,7 @@ def __init__( # Skip buffer self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device="cuda") self.multi_tensor_adam = tex.multi_tensor_adam + self.multi_tensor_adam_param_remainder = tex.multi_tensor_adam_param_remainder self.multi_tensor_adam_fp8 = tex.multi_tensor_adam_fp8 self.multi_tensor_adam_capturable = tex.multi_tensor_adam_capturable self.multi_tensor_adam_capturable_master = tex.multi_tensor_adam_capturable_master @@ -192,20 +201,51 @@ def __init__( } self._scales = {} self.use_decoupled_grad = use_decoupled_grad + # Works only when master params is in FP32 + self.store_param_remainders = ( + store_param_remainders and master_weights and master_weight_dtype == torch.float32 + ) + + # Deprecated options + self.set_grad_none = set_grad_none + if self.set_grad_none is not None: + warnings.warn( + "set_grad_none kwarg in FusedAdam constructor is deprecated. " + "Use set_to_none kwarg in zero_grad instead.", + DeprecationWarning, + ) + + def zero_grad(self, set_to_none: Optional[bool] = None) -> None: + """Reset parameter gradients. + + Arguments: + set_to_none (bool, optional): whether to set grads to `None` + instead of zeroing out buffers. (default: True) + + """ + + # Handle deprecated set_grad_none option + if self.set_grad_none is not None: + if set_to_none is not None and set_to_none != self.set_grad_none: + raise ValueError( + f"Called zero_grad with set_to_none={set_to_none}, " + f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}" + ) + set_to_none = self.set_grad_none + if set_to_none is None: + set_to_none = True - def zero_grad(self): - # pylint: disable=missing-function-docstring - if not self.use_decoupled_grad and not self.set_grad_none: - super().zero_grad() + if not self.use_decoupled_grad and not set_to_none: + super().zero_grad(set_to_none=set_to_none) return for group in self.param_groups: for p in group["params"]: - if self.use_decoupled_grad and self.set_grad_none: + if self.use_decoupled_grad and set_to_none: p.decoupled_grad = None - elif self.use_decoupled_grad and not self.set_grad_none: + elif self.use_decoupled_grad and not set_to_none: p.decoupled_grad.zero_() - elif not self.use_decoupled_grad and self.set_grad_none: + elif not self.use_decoupled_grad and set_to_none: p.grad = None def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): @@ -222,6 +262,10 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): dtype = self.name_to_dtype_map[state_name] if dtype == torch.uint8: assert isinstance(scaled_state, Float8Tensor) + assert len(scaled_state._quantizer.scale) == 1, ( + "Only scaling with one scaling factor per tensor is supported by the" + " FusedAdam." + ) else: assert scaled_state.dtype == dtype @@ -236,7 +280,7 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): absmax = absmax.to(dtype=torch.float32, device=unscaled_state.device) torch.div(absmax, max_range, out=scale) if isinstance(scaled_state, Float8Tensor): - scaled_state._scale_inv.copy_(scale) + scaled_state._quantizer.scale.copy_(1 / scale) scaled_state.copy_(unscaled_state) else: rscale = torch.where(scale > 0, scale.reciprocal(), 0.0) @@ -254,14 +298,20 @@ def get_unscaled_state(self, param, state_name): state = self.state[param] dtype = self.name_to_dtype_map[state_name] if dtype == torch.uint8: - assert isinstance(state[state_name], Float8Tensor) unscaled = state[state_name].float() elif dtype == torch.float16: assert state[state_name].dtype == torch.float16 unscaled = state[state_name].float() unscaled.mul_(self._scales[param][state_name]) elif dtype == torch.float32: - assert state[state_name].dtype == torch.float32 + if ( + self.store_param_remainders + and state_name == "master_param" + and param.dtype == torch.bfloat16 + ): + assert state[state_name].dtype == torch.int16 + else: + assert state[state_name].dtype == torch.float32 unscaled = state[state_name] else: raise RuntimeError(f"Dtype of {state_name} can only be fp8/fp16/fp32.") @@ -279,10 +329,19 @@ def set_scaled_state(self, param, state_name, unscaled_state): and 'master_param`. unscaled_state (torch.Tensor): The original high-precision(FP32) state. """ - assert unscaled_state.dtype == torch.float32 + store_param_remainders = ( + self.store_param_remainders + and state_name == "master_param" + and param.dtype == torch.bfloat16 + ) + + if store_param_remainders: + assert unscaled_state.dtype == torch.int16 + else: + assert unscaled_state.dtype == torch.float32 state = self.state[param] if state_name not in state: - self._initialize_state(param, state_name, False) + self._initialize_state(param, state_name, False, store_param_remainders) dtype = self.name_to_dtype_map[state_name] if dtype != torch.float32: @@ -291,7 +350,9 @@ def set_scaled_state(self, param, state_name, unscaled_state): else: state[state_name].copy_(unscaled_state) - def _initialize_state(self, param, state_name, zero_buffer: bool): + def _initialize_state( + self, param, state_name, zero_buffer: bool, store_param_remainders: bool = False + ): """Initialize one of the optimizer states according to `state_name`. Arguments: @@ -299,19 +360,26 @@ def _initialize_state(self, param, state_name, zero_buffer: bool): state_name (string): Name of optimizer states, can be one of 'exp_avg', 'exp_avg_sq', and 'master_param`. zero_buffer (bool): Whether to initialize the optimizer state with zeros. + store_param_remainders (bool): Store only trailing remainder bits. """ dtype = self.name_to_dtype_map[state_name] - data = torch.empty_like(param, dtype=dtype) + if store_param_remainders: + data = torch.zeros_like(param, dtype=torch.int16) + else: + data = torch.empty_like(param, dtype=dtype) if zero_buffer: data.zero_() if dtype == torch.uint8: - self.state[param][state_name] = Float8Tensor( - data=data, - dtype=torch.float32, - fp8_scale_inv=torch.ones([1], dtype=torch.float32, device=param.device), + quantizer = Float8Quantizer( + scale=torch.ones([1], dtype=torch.float32, device=param.device), + amax=torch.zeros([1], dtype=torch.float32, device=param.device), + fp8_dtype=tex.DType.kFloat8E4M3, ) + self.state[param][state_name] = quantizer.make_empty(param.shape) + self.state[param][state_name].quantize_(data.float()) else: + self.state[param][state_name] = data # Create scale if necessary. @@ -322,17 +390,24 @@ def _initialize_state(self, param, state_name, zero_buffer: bool): [1], dtype=torch.float32, device=param.device ) - def initialize_state(self, param): + def initialize_state(self, param, store_param_remainders): """Initialize optimizer states. Arguments: param (torch.nn.Parameter): One of parameters in this optimizer. + store_param_remainders (bool): Store trailing remainder bits. """ self._initialize_state(param, "exp_avg", zero_buffer=True) self._initialize_state(param, "exp_avg_sq", zero_buffer=True) if self.master_weights: - self._initialize_state(param, "master_param", zero_buffer=False) - self.set_scaled_state(param, "master_param", param.clone().detach().float()) + self._initialize_state( + param, + "master_param", + zero_buffer=False, + store_param_remainders=store_param_remainders, + ) + if not store_param_remainders: + self.set_scaled_state(param, "master_param", param.clone().detach().float()) def state_dict(self): """Override the state_dict() of pytorch. Before returning the state_dict, cast all @@ -377,7 +452,17 @@ def load_state_dict(self, state_dict): param = id_map[k] self.state[param] = {} for name in v: - self.set_scaled_state(param, name, v[name].float()) + if v[name] is None: + continue + if ( + self.store_param_remainders + and name == "master_param" + and param.dtype == torch.bfloat16 + ): + self.set_scaled_state(param, name, v[name]) + assert v[name].dtype == torch.int16 + else: + self.set_scaled_state(param, name, v[name].float()) def step(self, closure=None, grad_scaler=None): """Performs a single optimization step. @@ -444,9 +529,11 @@ def step(self, closure=None, grad_scaler=None): for p in group["params"]: state = self.state[p] + store_param_remainders = self.store_param_remainders and p.dtype == torch.bfloat16 + # State initialization if len(state) == 0: - self.initialize_state(p) + self.initialize_state(p, store_param_remainders) if self.use_decoupled_grad: p_grad = p.decoupled_grad if hasattr(p, "decoupled_grad") else None @@ -462,8 +549,12 @@ def step(self, closure=None, grad_scaler=None): unscaled_state = {} for name in ["exp_avg", "exp_avg_sq", "master_param"]: if name in state: - unscaled = self.get_unscaled_state(p, name) - unscaled_state[name] = unscaled + if name == "master_param" and store_param_remainders: + unscaled_state[name] = self.state[p][name] + assert unscaled_state[name].dtype == torch.int16 + else: + unscaled = self.get_unscaled_state(p, name) + unscaled_state[name] = unscaled if self.name_to_dtype_map[name] != torch.float32: unscaled_lists[name].append(unscaled) scaled_lists[name].append(state[name]) @@ -506,6 +597,12 @@ def step(self, closure=None, grad_scaler=None): ) if has_fp16 and has_bf16: + if self.store_param_remainders: + raise RuntimeError( + "FusedAdam doesn't support a mix of FP16/BF16 weights + Store param" + " remainder." + ) + # simple to add support for this, but not needed for now raise RuntimeError( "FusedAdam does not support a mix of float16 and bfloat16 model weights." @@ -599,7 +696,14 @@ def apply_multi_tensor_adam(adam_func, tensor_lists, inv_scale=None, out_dtype=N v_of_f16_model, p_main_of_f16_model, ] - apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) + if self.store_param_remainders and has_bf16 and not has_fp16: + # When you have BF16 params and need FP32 master params, you can reconstruct + # the FP32 master params with BF16 params + int16 remainders + apply_multi_tensor_adam( + self.multi_tensor_adam_param_remainder, tensor_lists + ) + else: + apply_multi_tensor_adam(self.multi_tensor_adam, tensor_lists) if len(p_fp8_model) > 0: tensor_lists = [ g_of_fp8_model, diff --git a/transformer_engine/pytorch/optimizers/fused_sgd.py b/transformer_engine/pytorch/optimizers/fused_sgd.py index ee428d2417..8a76ec5901 100644 --- a/transformer_engine/pytorch/optimizers/fused_sgd.py +++ b/transformer_engine/pytorch/optimizers/fused_sgd.py @@ -1,8 +1,13 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Fused SGD optimizer.""" +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional +import warnings + import torch from torch.optim.optimizer import Optimizer, required @@ -37,8 +42,8 @@ class FusedSGD(Optimizer): parameter groups lr (float): learning rate momentum (float, optional): momentum factor (default: 0) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) dampening (float, optional): dampening for momentum (default: 0) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) nesterov (bool, optional): enables Nesterov momentum (default: False) Example: @@ -74,15 +79,16 @@ class FusedSGD(Optimizer): def __init__( self, - params, - lr=required, - momentum=0, - dampening=0, - weight_decay=0, - nesterov=False, + params: Iterable[torch.nn.Parameter | dict], + lr: float | Any = required, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False, + *, wd_after_momentum=False, materialize_master_grads=True, - set_grad_none=False, + set_grad_none: Optional[bool] = None, # deprecated ): if lr is not required and lr < 0.0: raise ValueError(f"Invalid learning rate: {lr}") @@ -98,7 +104,7 @@ def __init__( "weight_decay": weight_decay, "nesterov": nesterov, } - if nesterov and (momentum <= 0 or dampening != 0): + if nesterov and (momentum <= 0.0 or dampening != 0.0): raise ValueError("Nesterov momentum requires a momentum and zero dampening") super().__init__(params, defaults) @@ -106,7 +112,6 @@ def __init__( self.materialize_master_grads = materialize_master_grads self.most_recent_scale = 1.0 self.scale_set_by_backward = False - self.set_grad_none = set_grad_none # Skip buffer self._dummy_overflow_buf = torch.tensor( @@ -114,14 +119,42 @@ def __init__( ) self.multi_tensor_sgd = tex.multi_tensor_sgd + # Deprecated options + self.set_grad_none = set_grad_none + if self.set_grad_none is not None: + warnings.warn( + "set_grad_none kwarg in FusedAdam constructor is deprecated. " + "Use set_to_none kwarg in zero_grad instead.", + DeprecationWarning, + ) + def __setstate__(self, state): super().__setstate__(state) for group in self.param_groups: group.setdefault("nesterov", False) - def zero_grad(self): - # pylint: disable=missing-function-docstring - if self.set_grad_none: + def zero_grad(self, set_to_none: Optional[bool] = None) -> None: + """Reset parameter gradients. + + Arguments: + set_to_none (bool, optional): whether to set grads to `None` + instead of zeroing out buffers. (default: True) + + """ + + # Handle deprecated set_grad_none option + if self.set_grad_none is not None: + if set_to_none is not None and set_to_none != self.set_grad_none: + raise ValueError( + f"Called zero_grad with set_to_none={set_to_none}, " + f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}" + ) + set_to_none = self.set_grad_none + if set_to_none is None: + set_to_none = True + + # Reset grads + if set_to_none: for group in self.param_groups: for p in group["params"]: p.grad = None diff --git a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py index 191b57eab9..64ec0a28da 100644 --- a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py +++ b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 540bacbf84..dd2f60deba 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -1,25 +1,27 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""Linear API""" +"""MoE Permutaion API""" import warnings from typing import Tuple import torch import transformer_engine_torch as tex -from .constants import TE_DType -from .float8_tensor import Float8Tensor +import transformer_engine.pytorch.triton.permutation as triton_permutation +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.float8_tensor import Float8Tensor __all__ = [ "moe_permute", "moe_unpermute", + "moe_sort_chunks_by_index", ] -class _moe_permute(torch.autograd.Function): - """functional Permute""" +class _moe_permute_index_map(torch.autograd.Function): + """functional Permute with index router map""" workspace = None max_expanded_token_num = 0 @@ -28,7 +30,7 @@ class _moe_permute(torch.autograd.Function): def forward( ctx, inp: torch.Tensor, - indices: torch.Tensor, + index: torch.Tensor, num_out_tokens: int, max_token_num: int, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -39,49 +41,57 @@ def forward( # Device check assert inp.is_cuda, "TransformerEngine needs CUDA." - assert indices.is_cuda, "TransformerEngine needs CUDA." + assert index.is_cuda, "TransformerEngine needs CUDA." # Shape check - assert inp.size(0) == indices.size(0), "Permute not possible" + assert inp.size(0) == index.size(0), "Permute not possible" # Data type check fp8 = isinstance(inp, Float8Tensor) if fp8: + assert ( + inp._quantizer.scale.ndim == 0 + ), "Only one factor scaling per tensor (Delayed Scaling) supported by moe_permute." dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data else: dtype = TE_DType[inp.dtype] - if indices.dtype != torch.int32: + if index.dtype != torch.int32: warnings.warn( - f"The data type of the input `indices` of Permute is {indices.dtype}! " + f"The data type of the input `index` of Permute is {index.dtype}! " "The recommended type is torch.int32." ) - indices = indices.to(torch.int32) + index = index.to(torch.int32) - topK = indices.size(1) + topK = index.size(1) input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK - if _moe_permute.max_expanded_token_num < input_max_expanded_token_num: - _moe_permute.max_expanded_token_num = input_max_expanded_token_num - _moe_permute.workspace = [] + if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num: + _moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num + _moe_permute_index_map.workspace = [] - permuted_act, row_id_map, _moe_permute.workspace = tex.moe_permute_fwd( + permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd( inp, dtype, - indices, + index, num_out_tokens, - _moe_permute.workspace, - _moe_permute.max_expanded_token_num, + _moe_permute_index_map.workspace, + _moe_permute_index_map.max_expanded_token_num, ) if fp8: permuted_act = Float8Tensor( - data=permuted_act, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv + data=permuted_act, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv, + shape=permuted_act.shape, + dtype=fake_dtype, ) ctx.row_id_map = row_id_map - ctx.num_tokens = indices.size(0) - ctx.topK = indices.size(1) + ctx.num_tokens = index.size(0) + ctx.topK = index.size(1) ctx.fp8 = fp8 return permuted_act, row_id_map @@ -105,6 +115,7 @@ def backward( ), "Grad of the output must be in Float8Tensor type for FP8 moe_permute." dtype = permuted_act_grad._fp8_dtype fp8_scale_inv = permuted_act_grad._scale_inv + fake_dtype = permuted_act_grad.dtype permuted_act_grad = permuted_act_grad._data else: dtype = TE_DType[permuted_act_grad.dtype] @@ -116,14 +127,18 @@ def backward( ) if ctx.fp8: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv * ctx.topK + data=act_grad, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv * ctx.topK, + shape=act_grad.shape, + dtype=fake_dtype, ) return act_grad, None, None, None -class _moe_unpermute(torch.autograd.Function): - """functional Unpermute""" +class _moe_unpermute_index_map(torch.autograd.Function): + """functional Unpermute with index router map""" @staticmethod def forward( @@ -165,6 +180,7 @@ def forward( if fp8: dtype = inp._fp8_dtype fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype inp = inp._data else: dtype = TE_DType[inp.dtype] @@ -179,7 +195,11 @@ def forward( if fp8: unpermuted_output = Float8Tensor( - data=unpermuted_output, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv + data=unpermuted_output, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv, + shape=unpermuted_output.shape, + dtype=fake_dtype, ) ctx.save_for_backward(inp, row_id_map, probs) @@ -205,6 +225,7 @@ def backward( ), "Grad of the output must be in Float8Tensor type for FP8 moe_unpermute." dtype = unpermuted_act_grad._fp8_dtype fp8_scale_inv = unpermuted_act_grad._scale_inv + fake_dtype = unpermuted_act_grad.dtype unpermuted_act_grad = unpermuted_act_grad._data else: dtype = TE_DType[unpermuted_act_grad.dtype] @@ -218,28 +239,279 @@ def backward( unpermuted_act_grad, inp, dtype, row_id_map, probs ) if ctx.fp8: - act_grad = Float8Tensor(data=act_grad, fp8_dtype=dtype, fp8_scale_inv=fp8_scale_inv) + act_grad = Float8Tensor( + data=act_grad, + fp8_dtype=dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, + ) if not ctx.needs_input_grad[2]: prob_grad = None return act_grad, None, prob_grad +class _moe_permute_mask_map(torch.autograd.Function): + """functional Permute with mask router map""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + routing_map: torch.Tensor, + num_out_tokens: int, + probs: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # pylint: disable=missing-function-docstring + if not inp.numel(): + ctx.probs = probs + return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device) + + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert routing_map.is_cuda, "TransformerEngine needs CUDA." + if probs is not None: + assert probs.is_cuda, "TransformerEngine needs CUDA." + + assert inp.size(0) == routing_map.size(0), "Permute not possible" + num_tokens, hidden_size = inp.size() + num_experts = routing_map.size(1) + assert ( + num_out_tokens is not None + ), "num_out_tokens must be provided to the fused permute function." + + row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) + + fp8 = isinstance(inp, Float8Tensor) + if fp8: + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype + inp = inp._data + output, permuted_probs = triton_permutation.permute_with_mask_map( + inp, + row_id_map, + probs, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + ) + if fp8: + output = Float8Tensor( + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, + ) + + ctx.save_for_backward(row_id_map) + ctx.num_experts = num_experts + ctx.num_tokens = num_tokens + ctx.hidden_size = hidden_size + return output, row_id_map, permuted_probs + + @staticmethod + def backward( + ctx, + permuted_act_grad: torch.Tensor, + _, + permuted_probs_grad: torch.Tensor, + ) -> Tuple[torch.Tensor, ...]: + # pylint: disable=missing-function-docstring + if not permuted_act_grad.numel(): + return permuted_act_grad, None, None, ctx.probs + + act_grad = None + probs_grad = None + if ctx.needs_input_grad[0]: + (row_id_map,) = ctx.saved_tensors + fp8 = isinstance(permuted_act_grad, Float8Tensor) + if fp8: + fp8_dtype = permuted_act_grad._fp8_dtype + fp8_scale_inv = permuted_act_grad._scale_inv + fake_dtype = permuted_act_grad.dtype + permuted_act_grad = permuted_act_grad._data + else: + fp8_dtype = None + act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( + permuted_act_grad, + row_id_map, + None, + permuted_probs_grad, + ctx.num_tokens, + ctx.num_experts, + ctx.hidden_size, + fp8_dtype, + ) + if fp8: + act_grad = Float8Tensor( + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv * ctx.num_experts, + shape=act_grad.shape, + dtype=fake_dtype, + ) + if not ctx.needs_input_grad[3]: + probs_grad = None + return act_grad, None, None, probs_grad + + +class _moe_unpermute_mask_map(torch.autograd.Function): + """functional Unpermute with mask router map""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + row_id_map: torch.Tensor, + merging_probs: torch.Tensor, + restore_shape: torch.Size, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + if not inp.numel(): + ctx.merging_probs = merging_probs + return inp + + if restore_shape is None: + restore_shape = inp.shape + num_tokens, hidden_size = restore_shape + num_experts = row_id_map.size(0) + + with_probs = merging_probs is not None + if with_probs: + assert merging_probs.is_cuda, "TransformerEngine needs CUDA." + + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + + fp8 = isinstance(inp, Float8Tensor) + if fp8: + fp8_dtype = inp._fp8_dtype + if not with_probs: + fp8_scale_inv = inp._scale_inv * num_experts + else: + fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype + inp = inp._data + else: + fp8_dtype = None + unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( + inp, + row_id_map, + merging_probs, + None, + num_tokens, + num_experts, + hidden_size, + fp8_dtype=fp8_dtype, + ) + if fp8: + unpermuted_output = Float8Tensor( + data=unpermuted_output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=unpermuted_output.shape, + dtype=fake_dtype, + ) + + if with_probs: + ctx.save_for_backward(inp, row_id_map, merging_probs) + else: + ctx.save_for_backward(row_id_map) + ctx.num_experts = num_experts + ctx.num_tokens = num_tokens + ctx.num_permuted_tokens = inp.size(0) + ctx.hidden_size = hidden_size + ctx.with_probs = with_probs + return unpermuted_output + + @staticmethod + def backward(ctx, unpermuted_act_grad): + # pylint: disable=missing-function-docstring + if not unpermuted_act_grad.numel(): + return unpermuted_act_grad, None, ctx.merging_probs, None + + act_grad = None + probs_grad = None + if ctx.needs_input_grad[0]: + if ctx.with_probs: + fwd_input, row_id_map, merging_probs = ctx.saved_tensors + else: + (row_id_map,) = ctx.saved_tensors + + fp8 = isinstance(unpermuted_act_grad, Float8Tensor) + if fp8: + fp8_dtype = unpermuted_act_grad._fp8_dtype + fp8_scale_inv = unpermuted_act_grad._scale_inv + fake_dtype = unpermuted_act_grad.dtype + unpermuted_act_grad = unpermuted_act_grad._data + else: + fp8_dtype = None + + if ctx.with_probs: + act_grad, probs_grad = ( + triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( + unpermuted_act_grad, + row_id_map, + fwd_input, + merging_probs, + ctx.num_tokens, + ctx.num_experts, + ctx.num_permuted_tokens, + ctx.hidden_size, + fp8_dtype, + ) + ) + else: + act_grad, _ = triton_permutation.permute_with_mask_map( + unpermuted_act_grad, + row_id_map, + None, + ctx.num_tokens, + ctx.num_experts, + ctx.num_permuted_tokens, + ctx.hidden_size, + ) + + if fp8: + act_grad = Float8Tensor( + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, + ) + + if not ctx.needs_input_grad[2]: + probs_grad = None + return act_grad, None, probs_grad, None + + def moe_permute( inp: torch.Tensor, - indices: torch.Tensor, + routing_map: torch.Tensor, num_out_tokens: int = -1, max_token_num: int = -1, + map_type: str = "mask", ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Permute the tokens based on the indices. Token with the same index will be grouped together. + Permute the tokens based on the routing_map. Token with the same index will be grouped together. + Tokens with the same designated expert will be grouped together. + The routing_map indicates which experts were selected by each token. Parameters ---------- inp: torch.Tensor Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. - indices: torch.Tensor - The token to expert indices tensor of shape [num_tokens, topK] and dtype 'int32'. + routing_map: torch.Tensor + The token to expert mapping tensor. + If map_type is 'mask', routing_map is of shape [num_tokens, num_experts] and dtype 'int32'. + The values in it: 1 means the token is routed to this expert and 0 means not. + If map_type is 'index', routing_map is of shape [num_tokens, topK] and dtype 'int32'. + The values in it are the routed expert indices. num_out_tokens: int, default = -1 The effective output token count, representing the number of tokens not dropped. By default, set to '-1', meaning no tokens are dropped. @@ -247,13 +519,58 @@ def moe_permute( The maximum number of tokens, used for workspace allocation. By default, set to '-1', meaning the calculation of the size of workspace is automatically taken over by the operator. + map_type: str, default = 'mask' + Type of the routing map tensor. + Options are: 'mask', 'index'. + Refer to `routing_map` for more details. + """ + if map_type == "index": + return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num) + if map_type == "mask": + output, row_id_map, _ = _moe_permute_mask_map.apply(inp, routing_map, num_out_tokens, None) + return output, row_id_map + raise ValueError("map_type should be one of 'mask' or 'index'") + + +def moe_permute_with_probs( + inp: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + num_out_tokens: int = -1, +) -> Tuple[torch.Tensor, torch.Tensor]: """ - return _moe_permute.apply(inp, indices, num_out_tokens, max_token_num) + Permute the tokens and probs based on the routing_map. + Token with the same index will be grouped together. + Tokens with the same designated expert will be grouped together. + The routing_map indicates which experts were selected by each token. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens and is + of shape [num_tokens, num_experts]. It will be permuted with the tokens + according to the routing_map. + routing_map: torch.Tensor + The token to expert mapping tensor of shape [num_tokens, num_experts] and dtype 'int32'. + The values in it: 1 means the token is routed to this expert and 0 means not. + num_out_tokens: int, default = -1 + The effective output token count, representing the number of tokens not dropped. + By default, set to '-1', meaning no tokens are dropped. + """ + output, row_id_map, permuted_probs = _moe_permute_mask_map.apply( + inp, routing_map, num_out_tokens, probs + ) + return output, permuted_probs, row_id_map def moe_unpermute( inp: torch.Tensor, row_id_map: torch.Tensor, + merging_probs: torch.Tensor = None, + restore_shape: torch.Tensor = None, + map_type: str = "mask", probs: torch.Tensor = None, ) -> torch.Tensor: """ @@ -267,9 +584,172 @@ def moe_unpermute( row_id_map: torch.Tensor The tensor of a mapping table for sorted indices used to unpermute the tokens, which is the second output tensor of `Permute`. - probs: torch.Tensor + merging_probs: torch.Tensor, default = None The tensor of probabilities corresponding to the permuted tokens. If provided, the unpermuted tokens will be merged with their respective probabilities. By default, set to an empty tensor, which means that the tokens are directly merged by accumulation. + restore_shape: torch.Tensor + The output shape after the unpermute operation. + map_type: str, default = 'mask' + Type of the routing map tensor. Should be the same as the value passed to moe_permute. + Options are: 'mask', 'index'. + probs: torch.Tensor, default = None + Renamed to merging_probs. Keep for backward compatibility. + """ + if probs is not None: + if merging_probs is not None: + raise ValueError( + "Both merging_probs and probs kwarg are provided. probs is deprecated." + ) + warnings.warn("probs kwarg is deprecated. Use merging_probs kwarg instead.") + merging_probs = probs + if map_type == "index": + return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs) + if map_type == "mask": + return _moe_unpermute_mask_map.apply(inp, row_id_map, merging_probs, restore_shape) + raise ValueError("map_type should be one of 'mask' or 'index'") + + +class _moe_chunk_sort(torch.autograd.Function): + """functional MoE chunk permute""" + + @staticmethod + def forward( + ctx, + inp: torch.Tensor, + split_sizes: torch.Tensor, + sorted_idxs: torch.Tensor, + probs: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # pylint: disable=missing-function-docstring + if not inp.numel(): + return inp, probs + + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert split_sizes.is_cuda, "TransformerEngine needs CUDA." + assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA." + if probs is not None: + assert probs.is_cuda, "TransformerEngine needs CUDA." + + num_tokens, hidden_size = inp.shape + num_splits = split_sizes.size(0) + assert num_splits == sorted_idxs.size(0) + + fp8 = isinstance(inp, Float8Tensor) + if fp8: + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype + inp = inp._data + output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx( + inp, + split_sizes, + sorted_idxs, + probs, + num_tokens, + hidden_size, + num_splits, + ) + if fp8: + output = Float8Tensor( + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, + ) + + ctx.save_for_backward(row_id_map) + ctx.num_tokens = num_tokens + ctx.hidden_size = hidden_size + return output, permuted_probs + + @staticmethod + def backward( + ctx, + permuted_act_grad: torch.Tensor, + permuted_probs_grad: torch.Tensor, + ) -> Tuple[torch.Tensor, ...]: + # pylint: disable=missing-function-docstring + if not permuted_act_grad.numel(): + return permuted_act_grad, None, None, permuted_probs_grad + + act_grad = None + probs_grad = None + if ctx.needs_input_grad[0]: + (row_id_map,) = ctx.saved_tensors + fp8 = isinstance(permuted_act_grad, Float8Tensor) + if fp8: + fp8_dtype = permuted_act_grad._fp8_dtype + fp8_scale_inv = permuted_act_grad._scale_inv + fake_dtype = permuted_act_grad.dtype + permuted_act_grad = permuted_act_grad._data + act_grad, probs_grad = triton_permutation.sort_chunks_by_map( + permuted_act_grad, + row_id_map, + permuted_probs_grad, + ctx.num_tokens, + ctx.hidden_size, + ) + if fp8: + act_grad = Float8Tensor( + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, + ) + if not ctx.needs_input_grad[3]: + probs_grad = None + return act_grad, None, None, probs_grad + + +def moe_sort_chunks_by_index( + inp: torch.Tensor, + split_sizes: torch.Tensor, + sorted_index: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Split and sort the input tensor based on the split_sizes and sorted indices. + The inp tensor is splitted along dim-0 according to the split_sizes list and then sorted + according to the sorted_indices. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + split_sizes: torch.Tensor + Chunk sizes of the inp tensor along the 0-th dimension. + sorted_indices: torch.Tensor + Chunk indices used to permute the chunks. + """ + output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None) + return output + + +def moe_sort_chunks_by_index_with_probs( + inp: torch.Tensor, + probs: torch.Tensor, + split_sizes: torch.Tensor, + sorted_index: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Split and sort the input tensor and probs based on the split_sizes and sorted indices. + The inp tensor is splitted along dim-0 according to the split_sizes list and then sorted + according to the sorted_indices. + + Parameters + ---------- + inp: torch.Tensor + Input tensor of shape `[num_tokens, hidden_size]`, on which permutation will be applied. + probs: torch.Tensor + The tensor of probabilities corresponding to the permuted tokens and is + of shape [num_tokens]. It will be permuted with the tokens according to + the split_sizes and sorted_indices. + split_sizes: torch.Tensor + Chunk sizes of the inp tensor along the 0-th dimension. + sorted_indices: torch.Tensor + Chunk indices used to permute the chunks. """ - return _moe_unpermute.apply(inp, row_id_map, probs) + output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs) + return output, permuted_probs diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index c527ca83ef..4499c28826 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -12,10 +12,9 @@ from pathlib import Path import setuptools -from torch.utils.cpp_extension import BuildExtension try: - import torch # pylint: disable=unused-import + from torch.utils.cpp_extension import BuildExtension except ImportError as e: raise RuntimeError("This package needs Torch to build.") from e @@ -36,7 +35,7 @@ os.environ["NVTE_PROJECT_BUILDING"] = "1" -CMakeBuildExtension = get_build_ext(BuildExtension) +CMakeBuildExtension = get_build_ext(BuildExtension, True) if __name__ == "__main__": @@ -57,7 +56,7 @@ ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, install_requires=["torch"], - tests_require=["numpy", "onnxruntime", "torchvision"], + tests_require=["numpy", "torchvision"], ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): shutil.rmtree(common_headers_dir) diff --git a/transformer_engine/pytorch/softmax.py b/transformer_engine/pytorch/softmax.py index a632851a76..25362e1d58 100644 --- a/transformer_engine/pytorch/softmax.py +++ b/transformer_engine/pytorch/softmax.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -7,11 +7,7 @@ from typing import Callable, Tuple, Union, Optional import torch from torch import nn -import torch._C._onnx as _C_onnx -from torch.onnx import _type_utils import transformer_engine_torch as tex -from transformer_engine.pytorch.export import is_in_onnx_export_mode -from transformer_engine.pytorch.te_onnx_extensions import compute_in_fp32 THREADS_PER_WARP = 32 @@ -32,35 +28,6 @@ def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor: return _default_causal_mask[matrix_identifiers] -def _get_onnx_export_causal_mask( - seq_q: int, seq_k: int, onnx_causal_mask: torch.Tensor -) -> torch.Tensor: - """Return the causal upper triangular mask for softmax input, for ONNX export. - - ONNX does not support dynamic control-flow and requires non-square masks when - using a KV-cache (seq_k's length len(context)+len(generative) while seq_q's length is 1). - - Argument `onnx_causal_mask` is a square triu (k=1) mask that is sliced to the correct - shape for GPT context and generation phases. - In the context phase the derived mask is a square triu of shape (seq_k, seq_k), and in - the generation phase the mask is rectangular with shape (1, seq_k). - """ - assert len(onnx_causal_mask.size()) == 2 - assert onnx_causal_mask.size(0) == onnx_causal_mask.size(1) - assert onnx_causal_mask.size(0) >= (seq_k - seq_q) >= 0 - derived_mask = onnx_causal_mask[seq_k - seq_q : seq_k, :seq_k] - return derived_mask - - -def fp32_compute(onnx_symbolic_fn): - """A decorator that wraps an ONNX symoblic function with FP32 compute operators.""" - - def wrapper(g: torch.Graph, inp: torch._C.Value, scale: float, *args, **kwargs): - return compute_in_fp32(g, inp, onnx_symbolic_fn, scale, *args, **kwargs) - - return wrapper - - class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): """ Fused operation which performs following three operations in sequence @@ -88,34 +55,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] return input_grads, None - @staticmethod - @fp32_compute - def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: - """ScaledUpperTriangMaskedSoftmax symbolic method""" - - def triangular_mask(): - dtype = _type_utils.JitScalarType.INT64 - ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype) - k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - mask = g.op("Trilu", ones, k, upper_i=1) - mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) - return mask - - # Captures the logic of function scaled_upper_triang_masked_softmax_warp_forward - mask = triangular_mask() - one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - inv_mask = g.op("Sub", one, mask) - - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) - softmax_mask = g.op("Mul", mask, neg_tenK) - - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - masked_scaled = g.op("Mul", inv_mask, scaled) - masked = g.op("Add", masked_scaled, softmax_mask) - out = g.op("Softmax", masked) - return out - class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function): """ @@ -143,40 +82,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] return input_grads, None - @staticmethod - @fp32_compute - def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: - """ScaledAlignedCausalMaskedSoftmax symbolic method""" - - def triangular_mask(): - dtype = _type_utils.JitScalarType.INT64 - ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype) - k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - - # rectangular causal mask aligned to the bottom right corner of Attention matrix - rows = inputs.size(dim=-2) - cols = inputs.size(dim=-1) - diag_shift = cols - rows + 1 - - mask = g.op("Trilu", ones, k, upper_i=diag_shift) - mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL) - return mask - - # Captures the logic of function scaled_aligned_masked_softmax_warp_forward - mask = triangular_mask() - one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - inv_mask = g.op("Sub", one, mask) - - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) - softmax_mask = g.op("Mul", mask, neg_tenK) - - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - masked_scaled = g.op("Mul", inv_mask, scaled) - masked = g.op("Add", masked_scaled, softmax_mask) - out = g.op("Softmax", masked) - return out - class ScaledMaskedSoftmax(torch.autograd.Function): """ @@ -203,30 +108,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] input_grads = tex.scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) return input_grads, None, None - @staticmethod - @fp32_compute - def symbolic( - g: torch.Graph, inputs: torch._C.Value, mask: torch._C.Value, scale: float - ) -> torch._C.Value: - """ScaledMaskedSoftmax symbolic method""" - # Captures the logic of function scaled_masked_softmax_warp_forward. - # output = softmax(mask(input*scale) - # Computed as: - # masked_scaled = (1 - mask)*(input*scale) - # softmax_mask = mask * -10000 - # output = softmax(masked_scaled + softmax_mask) - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)) - inv_mask = g.op("Sub", one, mask) - # Note: type is hard coded because softmax uses FP16 or BF16 - neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16)) - softmax_mask = g.op("Mul", mask, neg_tenK) - masked_scaled = g.op("Mul", inv_mask, scaled) - masked = g.op("Add", masked_scaled, softmax_mask) - out = g.op("Softmax", masked) - return out - class ScaledSoftmax(torch.autograd.Function): """ @@ -252,15 +133,6 @@ def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None] input_grads = tex.scaled_softmax_backward(output_grads, softmax_results, scale_t[0]) return input_grads, None, None - @staticmethod - @fp32_compute - def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value: - """ScaledSoftmax symbolic method""" - scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16)) - scaled = g.op("Mul", inputs, scale_input) - out = g.op("Softmax", scaled) - return out - class FusedScaleMaskSoftmax(nn.Module): """ @@ -281,18 +153,6 @@ def __init__( self.mask_func = mask_func self.softmax_in_fp32 = softmax_in_fp32 - # Users exporting to ONNX can optimize the attention mask for GPT text generation. - self.kvcache_max_seq = int(os.getenv("NVTE_ONNX_KVCACHE_MAX_SEQ_LEN", "-1")) - if self.kvcache_max_seq > 0: - self.register_buffer( - "onnx_causal_mask", - torch.triu( - torch.ones(self.kvcache_max_seq, self.kvcache_max_seq, device="cuda"), - diagonal=1, - ).bool(), - persistent=False, - ) - def forward( self, inp: torch.Tensor, @@ -310,7 +170,7 @@ def forward( assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled" - if self.is_kernel_available(mask, *inp.size()) and not is_in_onnx_export_mode(): + if self.is_kernel_available(mask, *inp.size()): return self.forward_fused_softmax(inp, mask, scale) return self.forward_torch_softmax(inp, mask, scale) @@ -363,8 +223,9 @@ def forward_fused_softmax( """ scale = 1.0 if scale is None else scale - if self.attn_mask_type in ["causal", "causal_bottom_right"]: - return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale) + # Disable for now until unalignment bug is fixed. + # if self.attn_mask_type in ["causal", "causal_bottom_right"]: + # return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale) # input is 4D tensor (1, 1, sq, sk) or (b, 1, sq, sk) if mask is not None and self.attn_mask_type != "no_mask": @@ -383,13 +244,7 @@ def forward_torch_softmax( if self.attn_mask_type in ["causal", "causal_bottom_right"]: seq_len_q, seq_len_k = inp.size(2), inp.size(3) - if is_in_onnx_export_mode() and self.kvcache_max_seq > 0: - assert self.kvcache_max_seq >= seq_len_k - causal_mask = _get_onnx_export_causal_mask( - seq_len_q, seq_len_k, self.onnx_causal_mask - ) - else: - causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) + causal_mask = _get_default_causal_mask(self.attn_mask_type, seq_len_q, seq_len_k) if mask is None: mask = causal_mask else: diff --git a/transformer_engine/pytorch/te_onnx_extensions.py b/transformer_engine/pytorch/te_onnx_extensions.py deleted file mode 100755 index 9b4b2df145..0000000000 --- a/transformer_engine/pytorch/te_onnx_extensions.py +++ /dev/null @@ -1,519 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -""" -ONNX symbolic functions for Transformer Engine - -Warnings of the type pasted below are a known Pytorch issue -(https://github.com/pytorch/pytorch/issues/81693): - -tests/test_onnx_export.py::test_export_cast_ops[112] - /opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py:649: - UserWarning: The shape inference of trt::TRT_FP8DequantizeLinear type is missing, - so it may result in wrong shape inference for the exported graph. - Please consider adding it in symbolic function. (Triggered internally at - /opt/pytorch/pytorch/torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1880.) - _C._jit_pass_onnx_graph_shape_type_inference( - - -Scale tensors are treated as lists ("fs") instead of tensors ("v") because we need to access -specific entries using the index passes as `fp8_tensor`. If you fail to do this you will get -the following error when accessing a sepcific scale element (e.g. `scale_inv[fp8_tensor]`): - TypeError: 'torch._C.Value' object is not subscriptable -""" - -import torch -from torch.onnx import symbolic_helper, register_custom_op_symbolic, _type_utils -import torch._C._onnx as _C_onnx - -# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics -from torch.onnx._internal import jit_utils - -import transformer_engine_torch as tex - - -# This file registers custom op symbolic ONNX functions and does not export any symbols. -__all__ = [] - - -# Custom ops spec version -VER = 1 -UNSPECIFIED_TYPE = -1 - - -def make_op_name(op_name: str) -> str: - """custom op name""" - return "trt::" + op_name - - -def get_TensorProtoDataType(t): - """Return the _C_onnx.TensorProtoDataType of the input tensor""" - try: - return { - "Float": _C_onnx.TensorProtoDataType.FLOAT, - "Half": _C_onnx.TensorProtoDataType.FLOAT16, - "BFloat16": _C_onnx.TensorProtoDataType.BFLOAT16, - }[t.type().scalarType()] - except KeyError as e: - raise TypeError(f"Onnx export for dtype {t.type().scalarType()} not supported.") from e - - -def is_dtype_fp32(t): - """Check fp32 dtype""" - return t.type().scalarType() == "Float" - - -def is_dtype_fp16(t): - """Check fp16 dtype""" - return t.type().scalarType() == "Half" - - -def is_dtype_bf16(t): - """Check bf16 dtype""" - return t.type().scalarType() == "BFloat16" - - -def quantize(g, inputs, scale, fp8_tensor): - """Helper Function for Quantization""" - output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) - - # Q inputs are currently constrained to FP32 due to a similar limitation in ORT - # custom ops, so cast the input if needed. - if not is_dtype_fp32(inputs): - inputs = g.op("Cast", inputs, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - scale = g.op("Constant", value_t=torch.tensor(1 / scale[fp8_tensor])) - q_op = g.op(make_op_name("TRT_FP8QuantizeLinear"), inputs, scale).setType( - inputs.type().with_dtype(torch.uint8).with_sizes(output_shape) - ) - return q_op - - -def dequantize(g, inputs, scale_inv, fp8_tensor, otype): - """Helper Function for Dequantization""" - output_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) - - scale = g.op("Constant", value_t=torch.tensor(scale_inv[fp8_tensor])) - out = g.op(make_op_name("TRT_FP8DequantizeLinear"), inputs, scale).setType( - inputs.type().with_dtype(torch.float32).with_sizes(output_shape) - ) - - # DQ outputs are currently constrained to FP32 due to a similar limitation in ORT - # custom ops, so cast the output if needed. - if otype == int(tex.DType.kFloat16): - out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.FLOAT16) - elif otype == int(tex.DType.kBFloat16): - out = g.op("Cast", out, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) - return out - - -def compute_in_fp32(g, inp, subgraph, *args, **kwargs): - """Wrap subgraph with casts to/from FP32 so that its precision is FP32. - - If `inp` data type is not FP32, add a cast of `inp` to FP32 and feed that into `subgraph`; - then cast subgraphs's output back to `inp` data type. - """ - inp_dtype = get_TensorProtoDataType(inp) - is_fp32 = inp_dtype == _type_utils.JitScalarType.FLOAT - if not is_fp32: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - sg_out = subgraph(g, inp, *args, **kwargs) - if not is_fp32: - sg_out = g.op("Cast", sg_out, to_i=inp_dtype) - return sg_out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_cast_to_fp8(g, inputs, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for cast_to_fp8""" - # pylint: disable=unused-argument - return quantize(g, inputs, scale, fp8_tensor) - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "v", "i", "i") -def onnx_cast_to_fp8_noalloc(g, inputs, scale, output, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for cast_to_fp8_noalloc""" - # pylint: disable=unused-argument - return quantize(g, inputs, scale, fp8_tensor) - - -@symbolic_helper.parse_args("v", "fs", "i", "i", "i") -def onnx_cast_from_fp8(g, inputs, scale_inv, fp8_tensor, itype, otype): - """ONNX graph for cast_from_fp8""" - # pylint: disable=unused-argument - return dequantize(g, inputs, scale_inv, fp8_tensor, otype) - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_gelu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_gelu""" - # pylint: disable=unused-argument - # TE computes GELU using float32 precision so wrap the GELU subgraph with - # conversion to/from float32. - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = torch.onnx.symbolic_opset9.gelu(g, inp, "tanh") - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_relu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_relu""" - # pylint: disable=unused-argument - out = torch.onnx.symbolic_opset9.relu(g, inp) - if scale: - out = quantize(g, out, scale, fp8_tensor) - return out - - -@symbolic_helper.parse_args("v", "i") -def onnx_swiglu(g: jit_utils.GraphContext, inp, dim): - """ONNX graph for swiglu""" - - # Check dimensions - dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) - if dim_size is not None: - assert dim_size % 2 == 0 - - # Perform compute in FP32 - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - first, second = g.op("Split", inp, axis_i=dim, outputs=2) - out = g.op("Mul", g.op("Sigmoid", first), second) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_swiglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_swiglu""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_swiglu(g, inp, 1) - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "i") -def onnx_reglu(g: jit_utils.GraphContext, inp, dim): - """ONNX graph for reglu""" - - # Check dimensions - dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) - if dim_size is not None: - assert dim_size % 2 == 0 - - # Perform compute in FP32 - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - first, second = g.op("Split", inp, axis_i=dim, outputs=2) - out = g.op("Mul", g.op("Relu", first), second) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_reglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_reglu""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_reglu(g, inp, 1) - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "i") -def onnx_geglu(g: jit_utils.GraphContext, inp, dim): - """ONNX graph for geglu""" - - # Check dimensions - dim_size = symbolic_helper._get_tensor_dim_size(inp, dim) - if dim_size is not None: - assert dim_size % 2 == 0 - - # Perform compute in FP32 - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - first, second = g.op("Split", inp, axis_i=dim, outputs=2) - first = torch.onnx.symbolic_opset9.gelu(g, first, "tanh") - out = g.op("Mul", first, second) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args("v", "fs", "v", "v", "i", "i") -def onnx_fp8_geglu(g, inp, scale, amax, scale_inv, fp8_tensor, otype): - """ONNX graph for fp8_geglu""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_geglu(g, inp, 1) - if scale: - out = quantize(g, out, scale, fp8_tensor) - elif dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -@symbolic_helper.parse_args( - "v", - "fs", - "i", - "i", - "i", - "v", - "fs", - "i", - "i", - "i", - "v", - "fs", - "i", - "v", - "v", - "i", - "v", - "i", - "v", - "i", - "i", - "i", -) -def onnx_te_gemm( - g, - weight, - weight_scale_inverse, - weight_fp8_tensor, - weight_type, - trans_weight, - inputs, - input_scale_inverse, - input_fp8_tensor, - input_type, - trans_input, - out, - out_scale, - out_type, - out_amax, - bias, - bias_type, - pre_gelu_out, - grad, - workspace, - workspaceSize, - accumulate, - use_split_accumulator, -): - """ONNX graph for te_gemm""" - # pylint: disable=unused-argument - is_fp16 = is_dtype_fp16(inputs) - is_bf16 = is_dtype_bf16(inputs) - if input_type == int(tex.DType.kFloat8E4M3): - inputs = dequantize(g, inputs, input_scale_inverse, input_fp8_tensor, out_type) - - if weight_type == int(tex.DType.kFloat8E4M3): - weight = dequantize(g, weight, weight_scale_inverse, weight_fp8_tensor, out_type) - - empty_tensor_size = [0] - bias_empty = torch.onnx.symbolic_helper._get_tensor_sizes(bias) == empty_tensor_size - pre_gelu_out_empty = ( - torch.onnx.symbolic_helper._get_tensor_sizes(pre_gelu_out) == empty_tensor_size - ) - - if not bias_empty: - output = g.op("Gemm", inputs, weight, bias, transA_i=trans_input, transB_i=trans_weight) - else: - output = g.op("Gemm", inputs, weight, transA_i=trans_input, transB_i=trans_weight) - if not bias_empty: - if not pre_gelu_out_empty: - # TE computes GELU using float32 precision so wrap the GELU subgraph with - # conversion to/from float32. - output = compute_in_fp32(g, output, torch.onnx.symbolic_opset9.gelu, "tanh") - else: - if is_fp16: - output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.FLOAT16) - elif is_bf16: - output = g.op("Cast", output, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) - return output - - -def _ones_like(g, inp, dtype): - """Returns a tensor filled with the scalar value 1, with the same size as input and - with dtype data-type""" - shape = g.op("Shape", inp) - # WAR ONNX spec: ConstantOfShape accepts all data types except for BF16. To WAR - # create a ConstantOfShape with type FP32 and then add a Cast to BF16. - is_bf16 = dtype == torch.bfloat16 - one = g.op( - "ConstantOfShape", - shape, - value_t=torch.tensor([1], dtype=torch.float32 if is_bf16 else dtype), - ) - if is_bf16: - one = g.op("Cast", one, to_i=_C_onnx.TensorProtoDataType.BFLOAT16) - return one - - -@symbolic_helper.parse_args("v", "v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") -def onnx_layernorm_fwd_fp8( - g, - inputs, - weight, - bias, - eps, - scale, - amax, - scale_inv, - fp8_tensor, - otype, - sm_margin, - zero_centered_gamma, -): - """ONNX graph for layernorm_fwd_fp8""" - # pylint: disable=unused-argument - inp_dtype = get_TensorProtoDataType(inputs) - - if inp_dtype != get_TensorProtoDataType(weight): - weight = g.op("Cast", weight, to_i=inp_dtype) - if inp_dtype != get_TensorProtoDataType(bias): - bias = g.op("Cast", bias, to_i=inp_dtype) - - ln = onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma) - fp8_ln = quantize(g, ln, scale, fp8_tensor) - return fp8_ln - - -@symbolic_helper.parse_args("v", "v", "v", "f", "i", "b") -def onnx_layernorm_fwd(g, inputs, weight, bias, eps, sm_margin, zero_centered_gamma): - """ONNX graph for layernorm_fwd""" - # pylint: disable=unused-argument - - normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inputs) - if normalized_shape is None: - ndim = torch.onnx.symbolic_helper._get_tensor_rank(inputs) - assert ndim is not None - normalized_shape = list(range(0, ndim)) - # Normalization axis = 0, so normalized_shape uses all dims except dim = 0 - normalized_shape = normalized_shape[1:] - - if zero_centered_gamma: - inputs_dtype = inputs.type().dtype() - one = _ones_like(g, weight, inputs_dtype) - weight = g.op("Add", weight, one) - - axis = -len(normalized_shape) - ln = g.op( - "LayerNormalization", - inputs, - weight, - bias, - epsilon_f=eps, - axis_i=axis, - # This sets the LN compute precision - use FP32 always as does TE. - stash_type_i=_C_onnx.TensorProtoDataType.FLOAT, - ) - return ln - - -@symbolic_helper.parse_args("v", "v", "f", "fs", "v", "v", "i", "i", "i", "b") -def onnx_rmsnorm_fwd_fp8( - g, - inp, - weight, - eps, - scale, - amax, - scale_inv, - fp8_tensor, - otype, - sm_margin, - zero_centered_gamma, -): - """ONNX graph for rmsnorm_fwd_fp8""" - # pylint: disable=unused-argument - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - out = onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma) - out = quantize(g, out, scale, fp8_tensor) - return out - - -@symbolic_helper.parse_args("v", "v", "f", "i", "b") -def onnx_rmsnorm_fwd(g, inp, weight, eps, sm_margin, zero_centered_gamma): - """ONNX graph for rmsnorm_fwd""" - # pylint: disable=unused-argument - - # Check dimensions - normalized_shape = torch.onnx.symbolic_helper._get_tensor_sizes(inp) - if normalized_shape is None: - ndim = torch.onnx.symbolic_helper._get_tensor_rank(inp) - assert ndim is not None - normalized_shape = list(range(0, ndim)) - # Normalization axis = 0, so normalized_shape uses all dims except dim = 0 - normalized_shape = normalized_shape[1:] - axis = -len(normalized_shape) - - # Cast input tensors to FP32 if needed - dtype = get_TensorProtoDataType(inp) - if dtype != _type_utils.JitScalarType.FLOAT: - inp = g.op("Cast", inp, to_i=_C_onnx.TensorProtoDataType.FLOAT) - if get_TensorProtoDataType(weight) != _type_utils.JitScalarType.FLOAT: - weight = g.op("Cast", weight, to_i=_C_onnx.TensorProtoDataType.FLOAT) - - # Adjust zero-centered weights - if zero_centered_gamma: - one = _ones_like(g, weight, torch.float32) - weight = g.op("Add", weight, one) - - # Perform compute in FP32 - sum_square = g.op("ReduceSumSquare", inp, axes_i=[axis]) - shape = g.op("Shape", inp, start_i=-1) - shape_f = g.op("Cast", shape, to_i=_C_onnx.TensorProtoDataType.FLOAT) - mean_squared = g.op("Div", sum_square, shape_f) - eps_tensor = g.op("ConstantOfShape", shape, value_t=torch.tensor([eps], dtype=torch.float32)) - rms_squared = g.op("Add", mean_squared, eps_tensor) - rms_eps = g.op("Sqrt", rms_squared) - normalized_input = g.op("Div", inp, rms_eps) - out = g.op("Mul", weight, normalized_input) - if dtype != _type_utils.JitScalarType.FLOAT: - out = g.op("Cast", out, to_i=dtype) - return out - - -register_custom_op_symbolic("tex_ts::cast_to_fp8_ts", onnx_cast_to_fp8, VER) -register_custom_op_symbolic("tex_ts::cast_to_fp8_noalloc_ts", onnx_cast_to_fp8_noalloc, VER) -register_custom_op_symbolic("tex_ts::cast_from_fp8_ts", onnx_cast_from_fp8, VER) -register_custom_op_symbolic("tex_ts::gelu_ts", onnx_fp8_gelu, VER) -register_custom_op_symbolic("tex_ts::relu_ts", onnx_fp8_relu, VER) -register_custom_op_symbolic("tex_ts::reglu_ts", onnx_fp8_reglu, VER) -register_custom_op_symbolic("tex_ts::geglu_ts", onnx_fp8_geglu, VER) -register_custom_op_symbolic("tex_ts::swiglu_ts", onnx_fp8_swiglu, VER) -register_custom_op_symbolic("tex_ts::te_gemm_ts", onnx_te_gemm, VER) -register_custom_op_symbolic("tex_ts::layernorm_fwd_fp8_inf_ts", onnx_layernorm_fwd_fp8, VER) -register_custom_op_symbolic("tex_ts::layernorm_fwd_inf_ts", onnx_layernorm_fwd, VER) -register_custom_op_symbolic("tex_ts::rmsnorm_fwd_fp8_inf_ts", onnx_rmsnorm_fwd_fp8, VER) -register_custom_op_symbolic("tex_ts::rmsnorm_fwd_inf_ts", onnx_rmsnorm_fwd, VER) diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 16b7f8b623..610ec2a777 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -6,10 +6,12 @@ import torch -from .float8_tensor import Float8Tensor -from .quantized_tensor import QuantizedTensor +from .quantized_tensor import QuantizedTensor, Quantizer -__all__ = ["Float8Tensor", "QuantizedTensor"] +__all__ = [ + "QuantizedTensor", + "Quantizer", +] def _make_module_cast_func(dtype): @@ -22,14 +24,8 @@ def _make_module_cast_func(dtype): def tensor_cast_func(tensor: torch.Tensor) -> torch.Tensor: """Cast tensor dtype""" - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data, - fp8_attrs=tensor._fp8_attrs, - dtype=dtype, - requires_grad=tensor.requires_grad, - ) + if isinstance(tensor, QuantizedTensor): + return tensor.__class__.make_like(tensor, dtype=dtype) if tensor.is_floating_point(): return getattr(tensor, cast_func_name)() return tensor diff --git a/transformer_engine/pytorch/tensor/_internal/__init__.py b/transformer_engine/pytorch/tensor/_internal/__init__.py new file mode 100644 index 0000000000..e13014bf75 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Internal data structures for quantized tensors.""" diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py new file mode 100644 index 0000000000..bf518cae22 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -0,0 +1,136 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for Float8Tensor""" + +from __future__ import annotations +import math +from typing import Any, Dict, Optional, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype + +from ..quantized_tensor import Quantizer + + +class _FromFloat8Func(torch.autograd.Function): + """Cast from FP8 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: Float8TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + if tensor._data is not None: + # Cast from FP8 + return tex.dequantize(tensor, dtype) + + raise NotImplementedError("Casting back from the transpose not implemented yet!") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class Float8TensorBase: + """Mixin class that holds data attributes of Float8Tensor. + + Float8Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + + """ + + _data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _fp8_dtype: TE_DType + _scale_inv: torch.Tensor + + # FP8 transpose cache + _transpose: Optional[torch.Tensor] + _transpose_invalid: bool + + def __new__( + cls, + *args, + data: Optional[torch.Tensor], + fp8_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + data_transpose: Optional[torch.Tensor] = None, + quantizer: Optional[Quantizer] = None, + **kwargs, + ): + if cls is Float8TensorBase: + instance = object.__new__(cls) + else: + instance = super().__new__(cls, *args, **kwargs) + instance._data = data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._scale_inv = fp8_scale_inv + instance._transpose = data_transpose + instance._transpose_invalid = instance._transpose is None + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "data": self._data, + "fp8_scale_inv": self._scale_inv, + "fp8_dtype": self._fp8_dtype, + "data_transpose": self._transpose, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + """Prepare the tensor base for saving for backward""" + tensors = [self._data, self._transpose] + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list""" + self._data = tensors[0] + self._transpose = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._data, self._transpose + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromFloat8Func.forward(None, self, dtype) + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + if self._data is not None: + return self._data.size(*args, **kwargs) + size = self._transpose.size(*args, **kwargs) + return torch.Size([size[-1], math.prod(size[:-1])]) + + def __repr__(self): + return ( + "Float8TensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"scale_inv={self._scale_inv.item()}, " + f"data={self.dequantize()}" + ")" + ) diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py new file mode 100644 index 0000000000..e6dcf1d48f --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -0,0 +1,131 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for MXFP8Tensor""" + +from __future__ import annotations +from typing import Optional, Dict, Any, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype + +from ..quantized_tensor import Quantizer + + +class _FromMXFP8Func(torch.autograd.Function): + """Cast from MXFP8 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: MXFP8TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + dtype = torch_to_transformer_engine_dtype[dtype] + + # Make sure FP8 data is in expected format + if tensor._rowwise_data is not None: + return tex.dequantize(tensor, dtype) + raise NotImplementedError("Casting back from the transpose not implemented yet!") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class MXFP8TensorBase: + """Mixin class that holds data attributes of MXFP8Tensor. + + MXFP8Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _fp8_dtype: TE_DType + _rowwise_scale_inv: torch.Tensor + _columnwise_scale_inv: torch.Tensor + + def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + quantizer: Optional[Quantizer] = None, + **kwargs, + ): + instance = super().__new__(cls, *args, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "fp8_dtype": self._fp8_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: + """Prepare the tensor base for saving for backward""" + tensors = [self._rowwise_data, self._columnwise_data] + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromMXFP8Func.forward(None, self, dtype) + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + if self._rowwise_data is not None: + return self._rowwise_data.size(*args, **kwargs) + return self._columnwise_data.size(*args, **kwargs) + + def __repr__(self): + data_rowwise = self.dequantize() + + return ( + "MXFP8TensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"rowwise_scaled_data={data_rowwise}" + f"rowwise_scale_inv={self._rowwise_scale_inv}, " + ")" + ) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 36136292df..e45010bb00 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -1,313 +1,334 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Tensor class with FP8 data""" from __future__ import annotations -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple, Iterable import warnings import torch import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..constants import TE_DType as torch_to_transformer_engine_dtype -from ..cpp_extensions import ( - cast_from_fp8, - cast_to_fp8, - fp8_cast_transpose_fused, -) -from ..fp8 import FP8GlobalStateManager -from ..utils import devices_match -from .quantized_tensor import QuantizedTensor +from ..utils import devices_match, non_tn_fp8_gemm_supported +from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from ..constants import dist_group_type aten = torch.ops.aten -updated_fp8_params = {} - -def _make_fp8_attr_property_funcs(name: str) -> Any: - """Make accessors for an FP8 attribute - - We store FP8 attributes in a dictionary so we can share them - between tensors with the same data, e.g. detached tensors. For - convenience, we also expose them as property attributes. This - function creates the accessors for property attributes. - - Parameters - ---------- - name: str - Key in dictionary of FP8 attributes +_ops_to_preserve_subclass_in_fsdp2 = { + torch.ops.aten.empty_like.default, + torch.ops.aten.new_zeros.default, + torch.ops.aten.slice.Tensor, + torch.ops.aten.copy_.default, + torch.ops.aten.view.default, + torch.ops.aten.as_strided.default, + torch.ops.aten._to_copy.default, + torch.ops.aten._pin_memory.default, + torch.ops.aten.split.Tensor, + torch.ops.aten.clone.default, +} + + +class Float8Quantizer(Quantizer): + """Builder class for FP8 tensors with per-tensor delayed scaling + + High-precision tensors (e.g. in FP32 or BF16) are quantized by + multiplying with a scaling factor and casting to FP8. The max-abs + value ("amax") in the tensor is also computed, which can be used + for updating the scaling factor (handled externally by + DelayedScalingRecipeState and FP8GlobalStateManager). """ - def get_func(self) -> Any: - return self._fp8_attrs[name] - - def set_func(self, value: Any) -> None: - self._fp8_attrs[name] = value - - def del_func(self) -> None: - del self._fp8_attrs[name] - - return {"fget": get_func, "fset": set_func, "fdel": del_func} - - -class _FromFloat8Func(torch.autograd.Function): - """Cast from FP8 to other dtype""" - - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: Float8Tensor, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - return tensor.dequantize(dtype=dtype) - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - # Assume that we want gradients in full precision - return grad, None - - -def post_optimizer_step_fwd_amax_reduction(param: Float8Tensor) -> None: - """Amax scale and update when there is at least 1 trainable FP8 parameter.""" - param_id = id(param._data) + """Scaling factor to multiply when quantizing to FP8""" + scale: torch.Tensor + """Max-abs value from last FP8 cast""" + amax: torch.Tensor + """FP8 datatype""" + dtype: TE_DType - if param_id not in FP8GlobalStateManager.fp8_param_to_autocast: - return - - autocast_key = FP8GlobalStateManager.fp8_param_to_autocast[param_id] + def __init__( + self, + scale: torch.Tensor, + amax: torch.Tensor, + fp8_dtype: TE_DType, + *, + rowwise: bool = True, + columnwise: bool = True, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.scale = scale + self.amax = amax + self.dtype = fp8_dtype - if autocast_key not in FP8GlobalStateManager.autocast_to_fp8_params: - return + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + if not isinstance(dst, Float8Tensor): + raise ValueError("Float8Quantizer can only update Float8Tensor") - if autocast_key in updated_fp8_params: - updated_fp8_params[autocast_key].add(param_id) - else: - updated_fp8_params[autocast_key] = {param_id} + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() - current_fp8_params_set = FP8GlobalStateManager.autocast_to_fp8_params[autocast_key] - # All FP8 trainable parameters have been updated. - if updated_fp8_params[autocast_key] == current_fp8_params_set: - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True, fp8_weights=True) - del updated_fp8_params[autocast_key] + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + # Update FP8 dtype + dst._fp8_dtype = self.dtype -class _ToFloat8Func(torch.autograd.Function): - """Cast to FP8 from other dtype""" + return dst - @staticmethod - def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: torch.Tensor, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - data: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - with_transpose_cache: bool = False, - data_transpose: Optional[torch.Tensor] = None, + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, ) -> Float8Tensor: - # pylint: disable=missing-function-docstring - # Tensor attributes - dtype = tensor.dtype - if dtype not in (torch.float32, torch.bfloat16, torch.float16): - dtype = torch.float32 - device = tensor.device - if device.type != "cuda": + # Canonicalize tensor attributes + if device is None: device = torch.device("cuda") - # FP8 data buffer - if data is None: - data = torch.empty(tensor.size(), dtype=torch.uint8, device=device) - - # Check scale - if scale is None and fp8_meta is None: - scale = torch.full([1], 1, dtype=torch.float32, device=device) - if scale is not None: - scale = scale.to(device=device, dtype=torch.float32) - - # Check scale-inverse - if scale_inv is None: - scale_inv = torch.empty([1], dtype=torch.float32, device=device) - elif not devices_match(scale_inv.device, device) or scale_inv.dtype != dtype: - scale_inv = scale_inv.to(device=device, dtype=torch.float32) + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) - # Transpose cache - if data_transpose is None and with_transpose_cache: + # Allocate FP8 data transpose if needed + data_transpose = None + if self.columnwise_usage: + inner_dim = data.size(-1) data_transpose = torch.empty( - (data.size(-1), data.numel() // data.size(-1)), + inner_dim, + data.numel() // inner_dim, dtype=torch.uint8, - device=tensor.device, + device=device, ) # Construct FP8 tensor - out = Float8Tensor( - data=data, - fp8_meta=fp8_meta, - fp8_meta_forward=fp8_meta_forward, - fp8_meta_index=fp8_meta_index, - fp8_dtype=fp8_dtype, - fp8_scale_inv=scale_inv, + return Float8Tensor( + shape=shape, dtype=dtype, + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, data_transpose=data_transpose, + quantizer=self, ) - # Cast to FP8 tensor - out.quantize_(tensor, scale=scale, amax=amax) - - return out - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - # Assume that we want gradients in full precision - return grad, None, None, None, None, None, None, None, None, None - - -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new Float8Tensor using the provided tensor's attributes. + def calibrate(self, tensor: torch.Tensor) -> None: + amin, amax = tensor.aminmax() + self.amax.copy_(torch.max(-amin, amax)) - """ - - @staticmethod - def forward( - ctx, - tensor: Float8Tensor, - init_kwargs: Optional[Dict[str, Any]] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring + def create_tensor_from_data( + self, + data: torch.Tensor, + fake_dtype=torch.float32, + requires_grad: bool = False, + internal: bool = False, + ): + """Create Float8Tensor from raw uint8 data""" + assert data.dtype in [ + torch.uint8, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + ] + if internal: + return Float8TensorBase( + data=data, + fp8_scale_inv=1 / self.scale, + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, + ) + return Float8Tensor( + shape=data.shape, + dtype=fake_dtype, + data=data, + fp8_scale_inv=1 / self.scale, + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, + ) - # Return input tensor if constructor kwargs are not provided - ctx.input_dtype = tensor.dtype - if init_kwargs is None: - return tensor - - # Construct new tensor if constructor kwargs are provided - default_kwargs = { - "data": tensor._data, - "fp8_meta": tensor._fp8_meta, - "fp8_meta_forward": tensor._fp8_meta_forward, - "fp8_meta_index": tensor._fp8_meta_index, - "fp8_dtype": tensor._fp8_dtype, - "fp8_scale_inv": tensor._scale_inv, - "dtype": tensor.dtype, - } - for key, val in default_kwargs.items(): - if key not in init_kwargs: - init_kwargs[key] = val - return Float8Tensor(**init_kwargs) - @staticmethod - def backward(ctx, grad): - # pylint: disable=missing-function-docstring - return grad.to(ctx.input_dtype), None +class Float8CurrentScalingQuantizer(Quantizer): + """Builder class for FP8 tensors with per-tensor current scaling + High-precision tensors (e.g. in FP32 or BF16) are quantized by + multiplying with a scaling factor and casting to FP8. The max-abs + value ("amax") in the tensor is computed directly by scanning the input + high-precision tensor, without the need of any history window. -class _ViewFunc(torch.autograd.Function): - """View function + Unlike delayed scaling, scale and amax tensors are not needed to initialize the + quantizer, becuse they are simply GPU buffers that will be filled by current + scaling quantization kernels, instead of using values taken from delayed scaling + history window. Therefore, device parameter is needed for tensor allocation. - View the Float8Tensor using the provided shape. + Both Float8CurrentScalingQuantizer and Float8Quantizer produces Float8Tensor, + because they are both per-tensor scaling, ie. one scaling factor per tensor. """ - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor - - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.view(*shape), - ) - return tensor.view(*shape) - - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring + """Scaling factor to multiply when quantizing to FP8""" + scale: torch.Tensor + """Max-abs value from last FP8 cast""" + amax: torch.Tensor + """FP8 datatype""" + dtype: TE_DType + """amax reduction options""" + with_amax_reduction: bool + amax_reduction_group: Optional[dist_group_type] + amax_reduction_size: Optional[int] + """Options about how to quantize the tensor""" + force_pow_2_scales: bool + amax_epsilon: float + + def __init__( + self, + fp8_dtype: TE_DType, + device: torch.device, + *, + rowwise: bool = True, + columnwise: bool = True, + with_amax_reduction: bool = False, + amax_reduction_group: Optional[dist_group_type] = None, + amax_reduction_size: Optional[int] = 1, + force_pow_2_scales: bool = False, + amax_epsilon: float = 0.0, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.scale = torch.ones(1, dtype=torch.float32, device=device) + self.amax = torch.empty(1, dtype=torch.float32, device=device) + self.dtype = fp8_dtype + self.with_amax_reduction = with_amax_reduction + self.amax_reduction_group = amax_reduction_group + self.amax_reduction_size = amax_reduction_size + self.force_pow_2_scales = force_pow_2_scales + self.amax_epsilon = amax_epsilon + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + if not isinstance(dst, Float8Tensor): + raise ValueError("Float8CurrentScalingQuantizer can only update Float8Tensor") - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.view(ctx.shape), - ) - return dgrad, None - return grad.view(ctx.shape), None + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) -class _ReshapeFunc(torch.autograd.Function): - """Reshape function + # Update FP8 dtype + dst._fp8_dtype = self.dtype - Reshape the Float8Tensor using the provided shape. + return dst - """ + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> Float8Tensor: - @staticmethod - def forward( - ctx, - tensor: torch.Tensor, - shape: Tuple[int] = None, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") - # Return input tensor if shape is not provided - ctx.shape = tensor.shape - if shape is None: - return tensor + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) - # Construct new tensor if shape is provided - if isinstance(tensor, Float8Tensor): - return Float8Tensor.make_like( - tensor, - data=tensor._data.reshape(*shape), + # Allocate FP8 data transpose if needed + data_transpose = None + if self.columnwise_usage: + inner_dim = data.size(-1) + data_transpose = torch.empty( + inner_dim, + data.numel() // inner_dim, + dtype=torch.uint8, + device=device, ) - return tensor.reshape(*shape) - @staticmethod - def backward( - ctx, - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring + # Construct FP8 tensor + return Float8Tensor( + shape=shape, + dtype=dtype, + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=data_transpose, + quantizer=self, + ) - if isinstance(grad, Float8Tensor): - dgrad = Float8Tensor.make_like( - grad, - data=grad._data.reshape(ctx.shape), + def calibrate(self, tensor: torch.Tensor) -> None: + # current scaling don't need to calibrate + return + + def create_tensor_from_data( + self, + data: torch.Tensor, + fake_dtype=torch.float32, + requires_grad: bool = False, + internal: bool = False, + ): + """ + Create Float8Tensor from raw uint8 data, unlike delayed scaling, + self.scale doesn't mean anything, so we are simply creating empty scale_inv + """ + assert data.dtype in [ + torch.uint8, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + ] + if internal: + return Float8TensorBase( + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, ) - return dgrad, None - return grad.reshape(ctx.shape), None + return Float8Tensor( + shape=data.shape, + dtype=fake_dtype, + data=data, + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), + fp8_dtype=self.dtype, + requires_grad=requires_grad, + data_transpose=None, + quantizer=self, + ) -class Float8Tensor(QuantizedTensor): +class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, @@ -317,225 +338,68 @@ class Float8Tensor(QuantizedTensor): Parameters ---------- + shape: int or iterable of int + Tensor dimensions. + dtype: torch.dtype + Nominal tensor datatype. + requires_grad: bool, optional = False + Whether to compute gradients for this tensor. data: torch.Tensor - Raw FP8 data in a uint8 tensor - fp8_attrs: dict, optional - FP8 metadata, primarily managed by Float8Tensor. If - provided, all other FP8 configuration is ignored. - fp8_meta: dict, optional - FP8 metadata object, primarily managed by TE modules. - fp8_meta_forward: bool, default = `True` - Whether to access the FP8 metadata for the - forward pass. Ignored if fp8_meta is not - provided. - fp8_meta_index: int, optional - Index to access in FP8 meta tensors. Required if - fp8_meta is provided and otherwise ignored. - fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 - FP8 format. + Raw FP8 data in a uint8 tensor fp8_scale_inv: torch.Tensor - Reciprocal of the scaling factor applied when - casting to FP8, i.e. the scaling factor that must - be applied when casting from FP8 to higher - precision. Can be inferred from fp8_meta if - provided. - dtype: torch.dtype, default = torch.float32 - Nominal tensor datatype. + Reciprocal of the scaling factor applied when casting to FP8, + i.e. the scaling factor that must be applied when casting from + FP8 to higher precision. + fp8_dtype: transformer_engine_torch.DType + FP8 format. + data_transpose: torch.Tensor, optional + FP8 transpose data in a uint8 tensor + quantizer: Float8Quantizer, Float8CurrentScalingQuantizer, optional + Builder class for FP8 tensors """ - _data: torch.Tensor - _fp8_attrs: Dict[str, Any] - _fp8_meta: Optional[Dict[str, Any]] - _fp8_meta_forward: bool - _fp8_meta_index: Optional[int] - _fp8_dtype: TE_DType - _scale_inv: torch.Tensor - - # FP8 transpose cache - _transpose: Optional[torch.Tensor] - _transpose_invalid: bool - - def __new__( - cls, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - fp8_scale_inv: Optional[torch.Tensor] = None, - dtype: torch.dtype = torch.float32, - requires_grad: bool = False, - data_transpose: Optional[torch.Tensor] = None, - ): - - # Check that data buffer is valid - if data.element_size() != 1: - raise ValueError( - f"Float8Tensor requires data buffer with 8-bit dtype (got dtype={data.dtype})" - ) - if data.requires_grad: - raise ValueError("Float8Tensor requires non-differentiable data buffer") - if not data.is_cuda: - data = data.cuda() - - # Initialize tensor object - self = torch.Tensor._make_wrapper_subclass( - cls, - data.size(), - strides=data.stride(), - storage_offset=data.storage_offset(), - dtype=dtype, - layout=data.layout, - requires_grad=requires_grad, - device=data.device, - ) - self._data = data - - # Initialize dict of class attributes - # Note: We store FP8 attributes in a dictionary so we can - # share them between tensors with the same data, e.g. detached - # tensors. - if fp8_attrs is None: - self._fp8_attrs = {} - else: - self._fp8_attrs = fp8_attrs - return self - - # FP8 meta tensors - if fp8_meta is not None and fp8_meta_index is None: - raise ValueError( - "To initialize Float8Tensor with FP8 meta tensors, " - "the FP8 meta tensor index must also be provided" - ) - self._fp8_meta = fp8_meta - self._fp8_meta_forward = fp8_meta_forward - self._fp8_meta_index = fp8_meta_index - - # FP8 dtype - assert fp8_dtype in ( - TE_DType.kFloat8E4M3, - TE_DType.kFloat8E5M2, - ), f"Unsupported fp8_dtype {fp8_dtype}." - self._fp8_dtype = fp8_dtype - - # FP8 scale-inverse - if fp8_scale_inv is None and self._fp8_meta is not None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - fp8_scale_inv = self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index] - fp8_scale_inv = fp8_scale_inv.detach().view(1).clone() - if fp8_scale_inv is None: - raise ValueError( - "Attempted to initialize Float8Tensor without specifying scale-inverse" - ) - if fp8_scale_inv.numel() != 1: - raise ValueError( - "Attempted to initialize Float8Tensor with invalid scale-inverse tensor" - ) - if fp8_scale_inv.dim() != 1: - fp8_scale_inv = fp8_scale_inv.reshape(1) - if ( - not devices_match(fp8_scale_inv.device, self._data.device) - or fp8_scale_inv.dtype != torch.float32 - ): - fp8_scale_inv = fp8_scale_inv.to( - device=self._data.device, - dtype=torch.float32, - ) - self._scale_inv = fp8_scale_inv - - # FP8 transpose cache - self._transpose = data_transpose - self._transpose_invalid = self._transpose is None - - return self - - @classmethod - def make_like( - cls, - tensor: Float8Tensor, - *, - data: torch.Tensor, - fp8_attrs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Float8Tensor: - """Use attributes of a Float8Tensor to create another Float8Tensor - - See constructor for list of keyword arguments. - - """ - default_kwargs = { - "fp8_meta": tensor._fp8_meta, - "fp8_meta_forward": tensor._fp8_meta_forward, - "fp8_meta_index": tensor._fp8_meta_index, - "fp8_dtype": tensor._fp8_dtype, - "fp8_scale_inv": tensor._scale_inv, - "dtype": tensor.dtype, - } - for key, val in default_kwargs.items(): - if key not in kwargs: - kwargs[key] = val - return Float8Tensor(data=data, fp8_attrs=fp8_attrs, **kwargs) - - def __repr__(self): + def __repr__(self, *, tensor_contents=None): return ( "Float8Tensor(" f"fp8_dtype={self._fp8_dtype}, " f"scale_inv={self._scale_inv.item()}, " - f"data={self.from_float8(dtype=self.dtype)}" + f"data={self.dequantize(dtype=self.dtype)}" ")" ) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8Tensor + By default the resulting tensor's dtype is the + Float8Tensor's nominal dtype. + """ # Convert PyTorch dtype to TE dtype if dtype is None: dtype = self.dtype - dtype = torch_to_transformer_engine_dtype[dtype] - # Make sure FP8 data is in expected format - data = self._data - if data.device.type != "cuda": - data = data.cuda() - if not data.is_contiguous(): - data = data.contiguous() - if data.dim() != 2: - data = data.view(1, -1) - - # Cast from FP8 - out = cast_from_fp8( - data.view(1, -1), - None, # fp8_meta_tensor - None, # fp8_tensor - self._fp8_dtype, - dtype, - scale_inv=self._scale_inv, - ) + if torch.is_grad_enabled(): + return _FromFloat8Func.apply(self, dtype) + return _FromFloat8Func.forward(None, self, dtype) - # Make sure output is in expected format - if out.size() != self.size(): - out = out.view(self.size()) - return out + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor - def from_float8(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - """ - Construct plain PyTorch tensor from Float8Tensor + Quantizer can be used for in-place operations. - By default the resulting tensor's dtype is the - Float8Tensor's nominal dtype. """ - return _FromFloat8Func.apply(self, dtype) + if self._quantizer is not None: + return self._quantizer + # Now the quantizer for Float8Tensor can be not just Float8Quantizer (delayed scaling) + raise ValueError( + "Float8Tensor's quantizer is None, cannot get a quantizer from Float8Tensor variable" + ) def quantize_( self, tensor: torch.Tensor, *, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, noop_flag: Optional[torch.Tensor] = None, ) -> Float8Tensor: """Update FP8 data @@ -544,184 +408,66 @@ def quantize_( ---------- tensor: torch.Tensor Tensor to copy from - scale: torch.Tensor, optional - Scaling factor to use for FP8 quantization - amax: torch.Tensor, optional - History of maximum absolute values. The first entry will - be updated with the absmax of `tensor`. noop_flag: torch.Tensor, optional float32 flag indicating whether to avoid performing update """ - src = tensor - dst = self - - # In-place operations invalidate transpose cache - self._reset_caches() - - # Special logic if other tensor is Float8Tensor - if isinstance(src, Float8Tensor): - - # Cast to plain tensor if FP8 dtypes don't match - if dst._fp8_dtype != src._fp8_dtype: - return dst.quantize_(src.dequantize()) - - # Directly copy FP8 data - dst._data.copy_(src._data.detach()) - dst._scale_inv.copy_(src._scale_inv.detach()) - if amax is not None or dst._fp8_meta is not None: - src_amax: torch.Tensor - if src._fp8_meta is None: - src_min, src_max = src.dequantize().aminmax() - src_amax = torch.maximum(-src_min, src_max) - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=src._fp8_meta_forward, - ) - fp8_meta_index = src._fp8_meta_index - src_amax = src._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] - dst_amax: torch.Tensor - if amax is None: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta_index = dst._fp8_meta_index - dst_amax = dst._fp8_meta[fp8_meta_key].amax_history[0, fp8_meta_index] - else: - dst_amax = amax - if dst_amax.dim() > 0: - dst_amax = dst_amax[tuple([0] * dst_amax.dim())] - torch.maximum(src_amax, dst_amax, out=dst_amax) - if dst._transpose is not None: - if src._transpose is None: - dst.transpose_2d(force_compute=True, fill_cache=True) - else: - dst._transpose.copy_(src._transpose) - dst._transpose_invalid = False - return self - - # Convert QuantizedTensor to plain tensor - if isinstance(src, QuantizedTensor): - return dst.quantize_(src.dequantize()) - - # Make sure input is in expected format - if src.size() != dst.size(): - src = src.expand(dst.size()) - if not devices_match(src.device, dst.device): - src = src.to(device=dst.device) - if src.dtype not in (torch.float32, torch.bfloat16, torch.float16): - src = src.float() - if not src.is_contiguous(): - src = src.contiguous() - - # Make sure FP8 scaling factors are in expected format - if scale is not None: - if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32: - scale = scale.to(device=dst.device, dtype=torch.float32) - if amax is not None: - while amax.dim() < 2: - amax = amax.unsqueeze(0) - if not devices_match(amax.device, dst.device): - raise ValueError( - f"Invalid device for amax (expected {dst.device}, found {amax.device})" - ) - if amax.dtype != torch.float32: - raise ValueError(f"Invalid dtype for amax (expected float32, found {amax.type})") - - # Default FP8 scaling factors - fp8_meta = None - if dst._fp8_meta is None: - if scale is None: - scale = dst._scale_inv.reciprocal() - if amax is None: - amax = torch.empty((1, 1), dtype=torch.float32, device=dst.device) - else: - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=dst._fp8_meta_forward, - ) - fp8_meta = dst._fp8_meta[fp8_meta_key] - - # Check local data - if not dst._data.is_contiguous(): - raise RuntimeError("Transformer Engine cast kernels require contiguous data") - - # Perform FP8 cast - if dst._transpose is None: - dst_data = dst._data - if src.dim() != 2: - src = src.view(1, -1) - dst_data = dst_data.view(1, -1) - cast_to_fp8( - src, - fp8_meta, - dst._fp8_meta_index, - dst._fp8_dtype, - out=dst_data, - scale=scale, - amax=amax, - scale_inv=dst._scale_inv, - ) - else: - fp8_cast_transpose_fused( - src.view(-1, src.size(-1)), - fp8_meta, - dst._fp8_meta_index, - dst._fp8_dtype, - cast_out=dst._data, - transpose_out=dst._transpose, - scale=scale, - amax=amax, - scale_inv=dst._scale_inv, - noop_flag=noop_flag, - ) - dst._transpose_invalid = False - - # Callback hook to perform amax reduction after optimizer step - post_optimizer_step_fwd_amax_reduction(self) - + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) return self - @classmethod - def to_float8( - cls, - tensor: torch.Tensor, - *, - fp8_meta: Optional[Dict[str, Any]] = None, - fp8_meta_forward: bool = True, - fp8_meta_index: Optional[int] = None, - fp8_dtype: TE_DType = TE_DType.kFloat8E4M3, - data: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - amax: Optional[torch.Tensor] = None, - scale_inv: Optional[torch.Tensor] = None, - with_transpose_cache: bool = False, - data_transpose: Optional[torch.Tensor] = None, - ): - """Construct Float8Tensor from plain PyTorch tensor""" - return _ToFloat8Func.apply( - tensor, - fp8_meta, - fp8_meta_forward, - fp8_meta_index, - fp8_dtype, - data, - scale, - amax, - scale_inv, - with_transpose_cache, - data_transpose, - ) - def detach(self) -> Float8Tensor: # pylint: disable=missing-function-docstring - return Float8Tensor.make_like( - self, - data=self._data, - fp8_attrs=self._fp8_attrs, - ) + return Float8Tensor.make_like(self) + + def _create_transpose(self): + data = self._data + if not data.is_contiguous(): + data = data.contiguous() + self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose) + self._transpose_invalid = False + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + # Figure out what data is available and what is required + has_data = self._data is not None + has_data_transpose = self._transpose is not None and not self._transpose_invalid + needs_data = has_data + needs_data_transpose = has_data_transpose + if non_tn_fp8_gemm_supported(): + if rowwise_usage is not None and rowwise_usage: + needs_data = True + if columnwise_usage is not None and columnwise_usage: + needs_data = True + needs_data_transpose = False + else: + if rowwise_usage is not None: + needs_data = rowwise_usage + if columnwise_usage is not None: + needs_data_transpose = columnwise_usage + + # Generate data that is required + if needs_data and not has_data: + raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose") + if needs_data_transpose and not has_data_transpose: + if not has_data: + raise RuntimeError("FP8 data is required to generate FP8 data transpose") + self._create_transpose() + + # Delete data that is not required + if not needs_data: + self._data = None + if not needs_data_transpose: + self._transpose = None + self._transpose_invalid = True def clone(self) -> Float8Tensor: # pylint: disable=missing-function-docstring + assert self._data is not None data = self._data.detach().clone() data_transpose = None if self._transpose is not None: @@ -744,7 +490,6 @@ def reshape(self, *shape: Tuple[int]) -> Float8Tensor: def contiguous( self, - *, memory_format: torch.memory_format = torch.contiguous_format, ) -> Float8Tensor: """Returns tensor with data in provided memory format @@ -752,161 +497,71 @@ def contiguous( Returns `self` if data is already in correct memory format. """ - if self._data.is_contiguous(memory_format=memory_format): + if self._data is not None and self._data.is_contiguous(memory_format=memory_format): return self - return _IdentityFunc.apply( - self, - {"data": self._data.detach().contiguous(memory_format=memory_format)}, - ) - - def transpose_2d( - self, - *, - force_compute: bool = False, - fill_cache: bool = False, - noop_flag: Optional[torch.Tensor] = None, - cache: Optional[bool] = None, - ) -> torch.Tensor: - """ - 2D transpose with caching support. + if self._transpose is not None and self._transpose.is_contiguous( + memory_format=memory_format + ): + return self + return Float8Tensor.make_like(tensor=self, data=self._data.contiguous()) - Parameters - ---------- - force_compute: bool, default = `False` - Force computation of transpose. Otherwise use - cached values, if possible. - fill_cache: bool, default = `False` - Cache output tensor for future function calls. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - cached values, if possible. - cache: bool, deprecated + # raise ValueError("Float8Tensor does not support different memory formats!") + def _reset_caches(self) -> None: """ - - # Handle deprecated cache kwarg - if cache is not None: - msg = ( - "cache kwarg for Float8Tensor.transpose_2d is deprecated, " - "please use force_compute and fill_cache instead" - ) - warnings.warn(msg, DeprecationWarning) - if cache: - force_compute = False - fill_cache = True - else: - force_compute = True - fill_cache = False - - # Need to compute transpose if cache is invalid - need_compute = ( - force_compute - or (self._transpose is None) - or self._transpose_invalid - or (noop_flag is not None) - ) - - # Return cached transpose if possible - if not need_compute: - assert self._transpose is not None - return self._transpose - - # Allocate output if needed - data = self._data.contiguous().reshape(-1, self.size(-1)) - out: Optional[torch.Tensor] = self._transpose - if out is None: - out = torch.empty( - (data.size(1), data.size(0)), - dtype=torch.uint8, - device=data.device, - ) - noop_flag = None - else: - self._transpose_invalid = False - - # Apply transpose kernel - fp8_dtype = self._fp8_dtype - if noop_flag is None: - tex.fp8_transpose_noalloc(data, out, fp8_dtype) - else: - noop_flag = noop_flag.to(dtype=torch.float32, device=data.device) - tex.fp8_transpose_noalloc_noop(data, out, noop_flag, fp8_dtype) - - # Fill cache if needed - if fill_cache: - self._transpose = out - self._transpose_invalid = False - - return out - - @torch.no_grad() - def cast_transpose_( - self, - tensor: torch.Tensor, - noop_flag: Optional[torch.Tensor] = None, - ) -> None: - """Cast from tensor and populate transpose cache - - Tensor is reshaped as a 2D matrix. - - Parameters - ---------- - tensor: torch.Tensor - Tensor to copy from. Must have same dimensions as - destination tensor. - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid updating - destination tensor. - + Set transpose cache as invalid. + Should be called after any in-place operation. """ - if self._transpose is None: - self._transpose = torch.empty( - (self.size(-1), self.numel() // self.size(-1)), - dtype=torch.uint8, - device=self.device, - ) - self.quantize_(tensor, noop_flag=noop_flag) - - @torch.no_grad() - def reset_fp8_meta_scale_inv(self) -> None: - """Replace FP8 meta tensor scale-inverse with cached value - - The FP8 meta tensor scale_inv entry corresponding to this - tensor is replaced with the scale_inv value used to construct - the tensor. + self._transpose_invalid = True + def remove_caches(self) -> None: """ - assert self._fp8_meta is not None, "FP8 meta tensors not found." - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=self._fp8_meta_forward, - ) - self._fp8_meta[fp8_meta_key].scale_inv[self._fp8_meta_index].copy_(self._scale_inv[0]) - - def to_dtype(self, dtype: torch.dtype) -> Float8Tensor: - """Create `Float8Tensor` with given nominal dtype - - The new tensor has the same underlying FP8 data. - + Remove transpose cache and mark it as invalid. """ - return Float8Tensor.make_like( - self, - data=self._data, - fp8_attrs=self._fp8_attrs, - dtype=dtype, - ) + self._transpose_invalid = True + del self._transpose # explicitly deletes the data for safety + self._transpose = None - def _reset_caches(self) -> None: - """ - Set transpose cache as invalid. - Should be called after any in-place operation. - """ + def clear(self): + """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" + self._data = torch.Tensor() if self._data is not None else None + self._transpose = torch.Tensor() if self._transpose is not None else None self._transpose_invalid = True @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - # Slice op - if func == aten.slice.Tensor: + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._data + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + out_shape = out_data.size() + out_transpose = None if tensor._transpose_invalid else tensor._transpose + if out_transpose is not None: + out_transpose_shape = out_transpose.size() + if ( + out_transpose_shape[0] != out_shape[-1] + or out_transpose_shape[1:] != out_shape[:-1] + ): + out_transpose = None + return Float8Tensor( + shape=out_shape, + dtype=tensor.dtype, + requires_grad=False, + data=out_data, + fp8_scale_inv=tensor._scale_inv, + fp8_dtype=tensor._fp8_dtype, + data_transpose=out_transpose, + quantizer=tensor._quantizer, + ) + + if func in [aten.slice.Tensor, aten.select.int]: tensor = args[0] data = tensor._data data_slice = data.__torch_dispatch__( @@ -915,21 +570,64 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=data_slice) + return Float8Tensor.make_like(tensor, data=data_slice, shape=data_slice.shape) - # View op - if func == aten.view.default: + # Related to FSDP2 + if func == aten.split.Tensor: tensor = args[0] data = tensor._data - data_view = data.__torch_dispatch__( + func_out = data.__torch_dispatch__( func, types, [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=data_view) + return [ + Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape) + for split_tensor in func_out + ] + if func == aten.new_zeros.default: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) + if func == torch.ops.aten.as_strided.default: + tensor = args[0] + data = tensor._data + func_out = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) + if func == torch.ops.aten.detach.default: + return cls.detach(args[0]) + if func == torch.ops.aten.clone.default: + return cls.clone(args[0]) + if func == torch.ops.aten.copy_.default: + dst, src = args[0], args[1] + # Just copy FP8 attrs if copying between Float8Tensors + if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor): + dst._data.copy_(src._data.detach()) + dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size())) + if src._transpose is not None or dst._transpose is not None: + dst._create_transpose() + return dst + elif func in _ops_to_preserve_subclass_in_fsdp2: + # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 + warnings.warn( + f"A function call({func}) in {cls} may not return {cls} tensor as an output. It" + " might cause an error in torch FSDP2!" + ) + else: + pass - # Default case return super().__torch_dispatch__(func, types, args, kwargs) @classmethod @@ -939,6 +637,7 @@ def _make_in_reduce_ex( fp8_dtype: TE_DType, fp8_scale_inv: torch.Tensor, dtype: torch.dtype, + shape: torch.shape, ) -> Float8Tensor: """Build Float8Tensor, for use in __reduce__ @@ -951,13 +650,14 @@ def _make_in_reduce_ex( fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv, dtype=dtype, + shape=shape, ) def __reduce_ex__(self, protocol: int) -> tuple: """Custom pickling to remove references to FP8 metadata objects""" return ( Float8Tensor._make_in_reduce_ex, - (self._data, self._fp8_dtype, self._scale_inv, self.dtype), + (self._data, self._fp8_dtype, self._scale_inv, self.dtype, self.shape), ) def _get_data(self) -> Float8Tensor: @@ -975,13 +675,13 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Tensor device new_device = tensor.device if tensor.is_cuda else self.device - - # Check whether grad is required - if self.requires_grad != tensor.requires_grad: - self.requires_grad_(requires_grad=tensor.requires_grad) + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) # Just copy FP8 data if other tensor is Float8Tensor if isinstance(tensor, Float8Tensor): + + # PyTorch tensor attributes if ( # pylint: disable=too-many-boolean-expressions self.size() != tensor.size() or self.stride() != tensor.stride() @@ -1002,57 +702,110 @@ def _set_data(self, tensor: torch.Tensor) -> None: ) # pylint: disable=unnecessary-dunder-call super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) + + # Float8Tensor attributes self._data = tensor._data - self._fp8_attrs = tensor._fp8_attrs + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + self._scale_inv = tensor._scale_inv + self._transpose = tensor._transpose + self._transpose_invalid = tensor._transpose_invalid return - # Reallocate FP8 data if needed - if ( - self.size() != tensor.size() - or self.stride() != tensor.stride() - or self.dtype != tensor.dtype - or self.layout != tensor.layout - or not devices_match(self.device, new_device) - ): - self._data = torch.empty_like( - tensor, - dtype=torch.uint8, - device=new_device, - ) - dummy_tensor = torch.Tensor._make_wrapper_subclass( - Float8Tensor, - self._data.size(), - strides=self._data.stride(), - storage_offset=self._data.storage_offset(), - dtype=tensor.dtype, - layout=self._data.layout, - requires_grad=tensor.requires_grad, - device=self._data.device, - ) - # pylint: disable=unnecessary-dunder-call - super(Float8Tensor, type(self)).data.__set__(self, dummy_tensor) - if self._transpose is not None: - self._transpose = torch.empty( - (self._data.size(-1), self._data.numel() // self._data.size(-1)), - dtype=torch.uint8, - device=self.device, - ) - self._transpose_invalid = True - - # Copy values from other tensor - self.quantize_(tensor) + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.internal = False + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) # Cast to FP8 when setting Float8Tensor.data data = property(_get_data, _set_data) - # Accessors for objects in self._fp8_attrs - # Note: We store FP8 attributes in a dictionary so we can share - # them between tensors with the same data, e.g. detached tensors. - # For convenience, we also expose them as property attributes. - _fp8_meta = property(**_make_fp8_attr_property_funcs("fp8_meta")) - _fp8_meta_forward = property(**_make_fp8_attr_property_funcs("fp8_meta_forward")) - _fp8_meta_index = property(**_make_fp8_attr_property_funcs("fp8_meta_index")) - _fp8_dtype = property(**_make_fp8_attr_property_funcs("dtype")) - _transpose = property(**_make_fp8_attr_property_funcs("transpose")) - _transpose_invalid = property(**_make_fp8_attr_property_funcs("transpose_invalid")) - _scale_inv = property(**_make_fp8_attr_property_funcs("scale_inv")) + +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + shape: Optional[list[int]] = None, + ) -> Float8Tensor: + # pylint: disable=missing-function-docstring + ctx.shape = tensor.shape + if shape is None: + return tensor.detach() + out_data = tensor._data.view(*shape) + out_shape = out_data.size() + out_transpose = None if tensor._transpose_invalid else tensor._transpose + if out_transpose is not None: + out_transpose_shape = out_transpose.size() + if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: + out_transpose = None + return Float8Tensor( + shape=out_shape, + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + data=out_data, + fp8_scale_inv=tensor._scale_inv, + fp8_dtype=tensor._fp8_dtype, + data_transpose=out_transpose, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + return grad.reshape(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8Tensor, + shape: Tuple[int], + ) -> Float8Tensor: + # pylint: disable=missing-function-docstring + ctx.shape = tensor.shape + if shape is None: + return tensor.detach() + out_data = tensor._data.reshape(*shape) + out_shape = out_data.size() + out_transpose = None if tensor._transpose_invalid else tensor._transpose + if out_transpose is not None: + out_transpose_shape = out_transpose.size() + if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: + out_transpose = None + return Float8Tensor( + shape=out_shape, + dtype=tensor.dtype, + requires_grad=tensor.requires_grad, + data=out_data, + fp8_scale_inv=tensor._scale_inv, + fp8_dtype=tensor._fp8_dtype, + data_transpose=out_transpose, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + return grad.reshape(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py new file mode 100644 index 0000000000..843c7936f2 --- /dev/null +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -0,0 +1,608 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with FP8 data""" +from __future__ import annotations +from collections.abc import Iterable +import math +from typing import Optional, Tuple + +import torch +import transformer_engine_torch as tex + +from transformer_engine_torch import DType as TE_DType +from ..constants import MXFP8_BLOCK_SCALING_SIZE +from ..utils import devices_match, round_up_to_nearest_multiple + +from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc + +aten = torch.ops.aten + + +class MXFP8Quantizer(Quantizer): + """Builder class for FP8 tensors with MX block scaling + + High-precision tensors (e.g. in FP32 or BF16) are quantized by + dividing them into groups of 32 elements, each scaled and cast + separately using current data. + + """ + + dtype: TE_DType + + def __init__( + self, + fp8_dtype: TE_DType, + *, + rowwise: bool = True, + columnwise: bool = True, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = fp8_dtype + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + + assert isinstance(dst, MXFP8Tensor), f"Cannot store quantized MXFP8 in {type(dst)} type." + + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + # Update FP8 dtype + dst._fp8_dtype = self.dtype + + return dst + + def is_quantizable(self, inp: torch.Tensor) -> bool: + """Returns whether or not given inp can be quantized""" + if inp.ndim < 2: + return False + if inp.shape[-1] % MXFP8_BLOCK_SCALING_SIZE != 0: + return False + if math.prod(inp.shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE != 0: + return False + return True + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> MXFP8Tensor: + + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") + + assert ( + shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0 + and math.prod(shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE == 0 + ), ( + f"Incorrect shape {shape} for MXFP8. Tensor dims must be divisible by" + f" {MXFP8_BLOCK_SCALING_SIZE}" + ) + + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + dtype=torch.uint8, + device=device, + ) + + # Allocate FP8 data transpose if needed + columnwise_data = None + columnwise_scale_inv = None + if self.columnwise_usage: + columnwise_data = torch.empty_like(data) + columnwise_scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + dtype=torch.uint8, + device=device, + ) + + # Construct FP8 tensor + return MXFP8Tensor( + shape=shape, + dtype=dtype, + fp8_dtype=self.dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=self, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + # TODO(ksivamani): No calibration needed for mxfp8? + pass + + +class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): + """Experimental tensor class with FP8 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + data: torch.Tensor + Raw FP8 data in a uint8 tensor + fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 + FP8 format. + fp8_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP8, i.e. the scaling factor that must + be applied when casting from FP8 to higher + precision. Can be inferred from fp8_meta if + provided. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype. + + """ + + def __repr__(self, *, tensor_contents=None): + return f"MXFP8Tensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from MXFP8Tensor + + By default the resulting tensor's dtype is the + MXFP8Tensor's nominal dtype. + """ + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + + if torch.is_grad_enabled(): + return _FromMXFP8Func.apply(self, dtype) + return _FromMXFP8Func.forward(None, self, dtype) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + if self._quantizer is not None: + return self._quantizer + return MXFP8Quantizer( + fp8_dtype=self._fp8_dtype, + ) + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> MXFP8Tensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def detach(self) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + # TODO(ksivamani): Fix the detach bug + return MXFP8Tensor.make_like(self) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """ + For MXFP8, columnwise scaled output is only produced by x2 + scaling kernels, so this function only disables usages. + """ + + # Default usage is based on available data + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None + + # Update row-scaled data + if rowwise_usage: + if self._rowwise_data is None: + raise RuntimeError( + "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data" + ) + if self._rowwise_scale_inv is None: + raise RuntimeError( + "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses" + ) + else: + self._rowwise_data = None + self._rowwise_scale_inv = None + + # Update column-scaled data + if columnwise_usage: + if self._columnwise_data is None: + raise RuntimeError( + "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data" + ) + if self._columnwise_scale_inv is None: + raise RuntimeError( + "Requested column-wise usage, " + "but MXFP8Tensor is missing column-scaled scale-inverses" + ) + else: + self._columnwise_data = None + self._columnwise_scale_inv = None + + def clone(self) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + assert self._rowwise_data is not None + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> MXFP8Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if self._rowwise_data is not None and self._rowwise_data.is_contiguous( + memory_format=memory_format + ): + return self + if self._columnwise_data is not None and self._columnwise_data.is_contiguous( + memory_format=memory_format + ): + return self + raise ValueError("MXFP8Tensor does not support different memory formats!") + + def clear(self): + """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" + self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None + self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._rowwise_data + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + out_shape = out_data.size() + return MXFP8Tensor( + shape=out_shape, + dtype=tensor.dtype, + rowwise_data=out_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=tensor._columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + ) -> MXFP8Tensor: + """Build MXFP8Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return MXFP8Tensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=fp8_dtype, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + MXFP8Tensor._make_in_reduce_ex, + ( + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + ), + ) + + def _get_data(self) -> MXFP8Tensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a MXFP8Tensor. Otherwise + casts to FP8. + + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) + + # Just copy FP8 data if other tensor is MXFP8Tensor + if isinstance(tensor, MXFP8Tensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + MXFP8Tensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(MXFP8Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + return + + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting MXFP8Tensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the MXFP8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: MXFP8Tensor, + shape: Optional[list[int]] = None, + ) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(ctx.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != ctx.shape[-1]: + raise RuntimeError( + "MXFP8Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Construct new tensor if shape is provided + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + new_rowwise_data = tensor._rowwise_data.view(*shape) + if tensor._columnwise_data is not None: + columnwise_shape = [shape[-1]] + list(shape[:-1]) + new_columnwise_data = tensor._columnwise_data.view(columnwise_shape) + return MXFP8Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + fp8_dtype=tensor._fp8_dtype, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, MXFP8Tensor): + new_data = ( + grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None + ) + if grad._columnwise_data is not None: + new_columnwise_data = grad._columnwise_data.view(ctx.shape[-1], -1) + else: + new_columnwise_data = None + dgrad = MXFP8Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + fp8_dtype=grad._fp8_dtype, + quantizer=grad._quantizer, + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the MXFP8Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: MXFP8Tensor, + shape: Optional[list[int]] = None, + ) -> MXFP8Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(ctx.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != ctx.shape[-1]: + raise RuntimeError( + "MXFP8Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Construct new tensor if shape is provided + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + new_rowwise_data = tensor._rowwise_data.reshape(*shape) + if tensor._columnwise_data is not None: + columnwise_shape = [shape[-1]] + list(shape[:-1]) + new_columnwise_data = tensor._columnwise_data.view(columnwise_shape) + + return MXFP8Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + fp8_dtype=tensor._fp8_dtype, + quantizer=tensor._quantizer, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, MXFP8Tensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + new_rowwise_data = grad._rowwise_data.view(*ctx.shape) + if grad._columnwise_data is not None: + columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1]) + new_columnwise_data = grad._columnwise_data.view(columnwise_shape) + dgrad = MXFP8Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + fp8_dtype=grad._fp8_dtype, + quantizer=grad._quantizer, + ) + return dgrad, None + return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 92c95b56ca..019aca9f60 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -1,55 +1,240 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Tensor with quantized data""" from __future__ import annotations -from typing import Optional, Tuple +from typing import Optional, Tuple, Iterable, Any, Dict, Union +import abc +import copy import torch from torch.utils._pytree import tree_map +import transformer_engine_torch as tex + + +def prepare_for_saving( + *tensors, +) -> Tuple[list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], Optional[Any]]: + """Prepare tensors for saving. Needed because save_for_backward accepts only + torch.Tensor/torch.nn.Parameter types, while we want to be able to save + the internal TensorBase types too.""" + + tensor_list, tensor_objects_list = [], [] + for tensor in tensors: + if tensor is None or isinstance(tensor, torch.Tensor): + tensor_list.append(tensor) + tensor_objects_list.append(None) + else: + t, t_obj = tensor.prepare_for_saving() + tensor_list.extend(t) + tensor_objects_list.append(t_obj) + return tensor_list, tensor_objects_list + + +def restore_from_saved( + tensors: list[Optional[Any]], + saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], +) -> list[Optional[Any]]: + """Recombine the tensor data and metadata during backward pass.""" + tensor_objects = [] + for tensor in tensors: + if tensor is None or isinstance(tensor, torch.Tensor): + tensor_objects.append(saved_tensors[0]) + saved_tensors = saved_tensors[1:] + else: + saved_tensors = tensor.restore_from_saved(saved_tensors) + tensor_objects.append(tensor) + return tensor_objects + + +class Quantizer(abc.ABC): + """Builder class for quantized tensors. + + This class is typically used to convert a high-precision tensor + (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8). -class _DequantizeFunc(torch.autograd.Function): - """Autograd function to convert quantized tensor to standard tensor""" + """ + + """Whether to construct quantized tensors with "row-wise usage" + + Hand-wave explanation: Consider the matrix multiplication C = A * + B^T (used in linear forward). Tensor Cores prefer "TN GEMMs" (in + Fortran-style column-major order), so A and B should be in + row-major order. + + """ + rowwise_usage: bool + + """Whether to construct quantized tensors with "column-wise usage" + + Hand-wave explanation: Consider the matrix multiplication C = A^T + * B (used in linear backward wgrad). Tensor Cores prefer "TN + GEMMs" (in Fortran-style column-major order), so A and B should be + in column-major order. + + """ + columnwise_usage: bool + + """Whether to instantiates tensor for purely internal usage + + Internal tensors are storage classes with minimal logic. They have + less overhead than PyTorch tensor sub-classes, but are not + compatible with PyTorch's autograd infrastructure nor PyTorch + operations. + + """ + internal: bool + + def __init__(self, *, rowwise: bool, columnwise: bool) -> None: + self.rowwise_usage = rowwise + self.columnwise_usage = columnwise + self.internal = False + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"rowwise_usage={self.rowwise_usage}, " + f"columnwise_usage={self.columnwise_usage}, " + f"internal={self.internal}, " + ")" + ) + + @abc.abstractmethod + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Quantize tensor in-place""" + + def quantize( + self, tensor: torch.Tensor, *, out: Optional[QuantizedTensor] = None + ) -> QuantizedTensor: + """Quantize tensor""" + if out is not None: + return self.update_quantized(tensor, out) + if (not self.internal) and torch.is_grad_enabled(): + return _QuantizeFunc.apply(tensor, self) + return _QuantizeFunc.forward(None, tensor, self) + + def multi_quantize(self, list_of_tensors): + """Quantize multiple tensors""" + list_of_output_tensors = [] + for tensor in list_of_tensors: + list_of_output_tensors.append(self.quantize(tensor)) + return list_of_output_tensors + + def __call__(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor""" + return self.quantize(tensor) + + @abc.abstractmethod + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> QuantizedTensor: + """Construct quantized tensor with uninitialized data""" + + @abc.abstractmethod + def calibrate(self, tensor: torch.Tensor) -> None: + """Calibrate quantizer state + + Updates quantization state as if quantizing a tensor, but + without actually performing the quantization. + + """ + + def set_usage( + self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None + ) -> None: + """Set how the quantized tensor is expected to be used + + See documentation for `rowwise_usage` and `columnwise_usage` + variables. + + """ + if rowwise is not None: + self.rowwise_usage = rowwise + if columnwise is not None: + self.columnwise_usage = columnwise + + def copy(self) -> Quantizer: + """Create shallow copy""" + return copy.copy(self) + + +class _QuantizeFunc(torch.autograd.Function): + """Cast to FP8 from other dtype""" @staticmethod def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: QuantizedTensor, - dtype: Optional[torch.dtype] = None, - ) -> torch.Tensor: + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: torch.Tensor, + quantizer: Quantizer, + ) -> QuantizedTensor: # pylint: disable=missing-function-docstring - return tensor.dequantize(dtype=dtype) + return tex.quantize(tensor, quantizer) @staticmethod def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, + _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision return grad, None class _IdentityFunc(torch.autograd.Function): - """Autograd function to create quantized tensor with same data""" + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. + + """ @staticmethod def forward( - _ctx: torch.autograd.function.FunctionCtx, # unused - tensor: QuantizedTensor, + ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None ) -> QuantizedTensor: # pylint: disable=missing-function-docstring - return tensor.detach() + + # Return input tensor if constructor kwargs are not provided + if init_kwargs is None: + return tensor.detach() + + # Construct new tensor if constructor kwargs are provided + ctx.input_dtype = tensor.dtype + kwargs = tensor.get_metadata() + for key, val in init_kwargs.items(): + kwargs[key] = val + return type(tensor)(tensor.shape, tensor.dtype, **kwargs) @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> torch.Tensor: + def backward(ctx, grad_output): # pylint: disable=missing-function-docstring - return grad + grad_input = grad_output + if grad_input.dtype == ctx.input_dtype: + grad_input = grad_input.detach() + else: + grad_input = grad_input.to(ctx.input_dtype) + return grad_input, None + + +def _stride_from_shape(shape: list[int]): + if len(shape) == 0: + return [] + rstride = [1] + for d in reversed(shape[1:]): + rstride.append(rstride[-1] * d) + return list(reversed(rstride)) class QuantizedTensor(torch.Tensor): @@ -62,6 +247,22 @@ class QuantizedTensor(torch.Tensor): """ + def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False): + # We are assuming only contiguous tensors + stride = _stride_from_shape(shape) + instance = torch.Tensor._make_wrapper_subclass( + cls, + shape, + strides=stride, + storage_offset=0, + dtype=dtype, + layout=torch.strided, + requires_grad=requires_grad, + device=torch.cuda.current_device(), + ) + + return instance + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" raise NotImplementedError( @@ -85,24 +286,42 @@ def detach(self) -> QuantizedTensor: f"{self.__class__.__name__} class does not implement detach function" ) - def __repr__(self) -> str: + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """Indicate to the tensor how it is going to be used + + This enables optimizations to memory usage in some cases + where forward and backward passes use the tensor in + different directions. + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_usage function" + ) + + def clear(self): + """Deallocate this tensor's memory. Typically not needed and must be used carefully""" + + def __repr__(self, *, tensor_contents=None) -> str: return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" def float(self) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self, torch.float32) + return self.dequantize(dtype=torch.float32) def bfloat16(self) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self, torch.bfloat16) + return self.dequantize(dtype=torch.bfloat16) def half(self) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self, torch.float16) + return self.dequantize(dtype=torch.float16) - def cpu(self) -> torch.Tensor: + def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor: # pylint: disable=missing-function-docstring - return _DequantizeFunc.apply(self).cpu() + return self.dequantize().cpu(memory_format=memory_format) def expand_as(self, other: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -179,3 +398,54 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): kwargs = {} # Do not force the QuantizedTensor type on the returned tensor return torch._C._disabled_torch_function_impl(func, types, args, kwargs) + + def contiguous( + self, memory_format: torch.memory_format = torch.contiguous_format + ) -> QuantizedTensor: + # pylint: disable=missing-function-docstring + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement contiguous function" + ) + + def get_metadata(self) -> Dict[str, Any]: + """Get keyword arguments for quantized tensor constructor + + Contains metadata so that the new quantized tensor has the + same underlying quantized data. + + """ + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement get_metadata function" + ) + + @classmethod + def make_like( + cls, + tensor: QuantizedTensor, + *, + shape: Optional[Iterable[int]] = None, + dtype: Optional[torch.dtype] = None, + requires_grad: bool = False, + data: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Create new quantized tensor + + By default, new tensor has the same attributes and underlying + data. + + """ + if shape is None: + shape = data.shape if data is not None else tensor.shape + dtype = dtype if dtype is not None else tensor.dtype + kwargs = tensor.get_metadata() + if data is not None: + kwargs["data"] = data + return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs) + + def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: + """Create `QuantizedTensor` with given nominal dtype + + The new tensor has the same underlying data. + + """ + return self.__class__.make_like(self, dtype=dtype) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index ad5476450b..d829275777 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -12,10 +12,10 @@ from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.pytorch.attention import ( - InferenceParams, MultiheadAttention, - check_set_window_size, ) +from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams +from transformer_engine.pytorch.dot_product_attention.utils import check_set_window_size from transformer_engine.pytorch.jit import ( set_jit_fusion_options, warmup_jit_bias_dropout_add_all_dtypes, @@ -267,11 +267,11 @@ def __init__( zero_centered_gamma: bool = False, qkv_weight_interleaved: bool = True, ub_tp_comm_overlap: bool = False, - ub_bulk_wgrad: bool = True, - ub_bulk_dgrad: bool = True, ub_overlap_ag: bool = True, ub_overlap_rs: bool = True, ub_overlap_rs_dgrad: bool = False, + ub_bulk_dgrad: bool = True, + ub_bulk_wgrad: bool = True, bias: bool = True, activation: str = "gelu", normalization: str = "LayerNorm", @@ -546,6 +546,7 @@ def forward( max_seqlen_q: Optional[int] = None, max_seqlen_kv: Optional[int] = None, fast_zero_fill: bool = True, + pad_between_seqs: Optional[bool] = None, ) -> torch.Tensor: """ Transformer Layer: attention block and a feedforward network (MLP) @@ -637,6 +638,9 @@ def forward( inference_params: InferenceParams, default = None Inference parameters that are passed to the main model in order to efficiently calculate and store the context during inference. + pad_between_seqs: Optional[bool], default = `None` + If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. + If true, there are padding tokens between individual sequences in a packed batch. """ if self_attn_mask_type is None: @@ -697,6 +701,7 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, fast_zero_fill=fast_zero_fill, + pad_between_seqs=pad_between_seqs, ) if self.apply_residual_connection_post_layernorm and not self.output_layernorm: diff --git a/transformer_engine/pytorch/triton/__init__.py b/transformer_engine/pytorch/triton/__init__.py new file mode 100644 index 0000000000..76c9b98d0e --- /dev/null +++ b/transformer_engine/pytorch/triton/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Kernels written with OpenAI Triton.""" diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py new file mode 100644 index 0000000000..43a3100926 --- /dev/null +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -0,0 +1,341 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Efficient Cross Entropy kernels written with OpenAI Triton.""" + +from typing import Union +from functools import reduce +from operator import mul + +import torch +import torch.distributed as dist + +import triton +import triton.language as tl + + +@triton.jit +def online_softmax_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + m_d_X_y_ptr, + m_d_X_y_stride, + rank, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This kernel computes the m/d components on this TP rank for the online softmax. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + m_d_X_y_ptr: Pointer to m/d/X_y tensor. + m_d_X_y_stride (int): The stride of the m/d/X_y tensor. + rank (int): The rank of this device in the TP group. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + program_id = tl.program_id(0).to(tl.int64) + + # locate the start index + X_ptr += program_id * X_stride + + # Load Y_ptr + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + vocab_start_idx = rank * n_cols + vocab_end_idx = (rank + 1) * n_cols + if y >= vocab_start_idx: + if y < vocab_end_idx: + X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32) + else: + X_y = float("-inf") + else: + X_y = float("-inf") + + m_d_X_y_ptr += program_id * m_d_X_y_stride * 3 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. use the notation from the paper + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to( + tl.float32 + ) + block_max = tl.max(X_block) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + tl.store(m_d_X_y_ptr, m) + tl.store(m_d_X_y_ptr + m_d_X_y_stride, d) + tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y) + + +@triton.jit +def cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + loss_ptr, + loss_stride, + m_d_X_y_ptr, + m_d_X_y_stride, + rank, + world_size, + n_cols, + n_non_ignore, + label_smoothing: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + loss_ptr: Pointer to tensor to store the loss. + loss_stride (int): The stride of the loss tensor. + m_d_X_y_ptr: Pointer to m/d/X_y tensor. + m_d_X_y_stride: The stride of m/d/X_y tensor. + rank (int): The rank of this device in the TP group. + world_size (int): The size of world involved in this distributed loss calculation. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (int): The number of non-ignored elements in the batch. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + program_id = tl.program_id(0).to(tl.int64) + + # locate the start index + X_ptr += program_id * X_stride + + # Load Y_ptr + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + loss_ptr += program_id * loss_stride + m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride + + # Need to reduce the m/d/X_y values from other TP ranks + m = tl.load(m_d_X_y_ptr) + d = tl.load(m_d_X_y_ptr + m_d_X_y_stride) + ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) + + for i in range(1, world_size): + offset = i * 3 * n_non_ignore * m_d_X_y_stride + access_ptr = m_d_X_y_ptr + offset + m_new = tl.load(access_ptr) + d_new = tl.load(access_ptr + m_d_X_y_stride) + X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride)) + + d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new)) + m = tl.maximum(m, m_new) + ori_X_y = tl.maximum(ori_X_y, X_y_new) + + # Label smoothing is a general case of normal cross entropy + scaled_x_sum = 0.0 + eps = label_smoothing / (n_cols * world_size) + + # 4. [Online softmax] second pass: calculate the gradients + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # N is the number of non ignored elements in the batch + # For label smoothing: + # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) + grad_dtype = X_block.dtype + X_block = X_block.to(tl.float32) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written + tl.debug_barrier() + + # 5. Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + loss = -(ori_X_y - m - tl.log(d)) + + # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + if label_smoothing > 0: + smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) + loss = loss * (1 - label_smoothing) + smooth_loss + + # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` + vocab_start_idx = rank * n_cols + vocab_end_idx = (rank + 1) * n_cols + if y >= vocab_start_idx: + if y < vocab_end_idx: + X_y = tl.load(X_ptr + y - vocab_start_idx) + X_y += -(1 - label_smoothing) / (n_non_ignore) + tl.store(X_ptr + y - vocab_start_idx, X_y) + + tl.store(loss_ptr, loss) + + +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) + + +def cross_entropy_forward( + _input: torch.Tensor, + target: torch.Tensor, + label_smoothing: float, + reduce_loss: bool, + dist_process_group: Union[dist.ProcessGroup, None], +): + """Forward implementation of Cross Entropy kernel""" + + B, SQ, V = _input.shape + n_rows = B * SQ + + assert reduce(mul, list(target.size())) == (B * SQ), "Each token needs a target token ID." + + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # unreduced loss + loss_1d = torch.zeros(n_rows, dtype=torch.float32, device=_input.device) + + # tensor to hold this rank's m/d/X_y values + m_d_X_y = torch.zeros(n_rows * 3, dtype=torch.float32, device=_input.device) + + # ensure _input and target are contiguous in the last dimension + if _input.stride(-1) != 1: + _input = _input.contiguous() + if target.stride(-1) != 1: + target = target.contiguous() + + rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + + online_softmax_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), # always 1 + m_d_X_y_ptr=m_d_X_y, + m_d_X_y_stride=m_d_X_y.stride(-1), + rank=rank, + n_cols=V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + if world_size > 1: + m_d_X_y_gathered = torch.zeros( + n_rows * 3 * world_size, dtype=torch.float32, device=_input.device + ) + dist.all_gather_into_tensor(m_d_X_y_gathered, m_d_X_y, group=dist_process_group) + else: + m_d_X_y_gathered = m_d_X_y + + cross_entropy_kernel[(n_rows,)]( + X_ptr=_input, + X_stride=_input.stride(-2), + Y_ptr=target, + Y_stride=target.stride(-1), + loss_ptr=loss_1d, + loss_stride=loss_1d.stride(-1), + m_d_X_y_ptr=m_d_X_y_gathered, + m_d_X_y_stride=m_d_X_y_gathered.stride(-1), + rank=rank, + world_size=world_size, + n_cols=V, + n_non_ignore=n_rows, + label_smoothing=label_smoothing, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + loss = torch.reshape(loss_1d, (B, SQ)) if not reduce_loss else (torch.sum(loss_1d) / n_rows) + + return loss, _input + + +def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor): + """Backward implementation of cross entropy loss kernel""" + + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + pass + + else: + B, SQ, V = _input.shape + n_rows = B * SQ + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + element_mul_kernel[(n_rows,)]( + _input, + _input.stride(-2), + grad_output, + V, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + + return _input diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py new file mode 100644 index 0000000000..1c5fd73581 --- /dev/null +++ b/transformer_engine/pytorch/triton/permutation.py @@ -0,0 +1,734 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Permutation kernels written with OpenAI Triton.""" + +from typing import Union + +import torch +import triton +import triton.language as tl + +from transformer_engine_torch import DType as TE_DType + + +@triton.jit +def _row_id_map_pass_1_kernel( + # pointers + routing_map_ptr, + row_id_map_ptr, + workspace_ptr, + # sizes + num_tokens, + # strides + stride_routing_map_token, + stride_routing_map_expert, + # metas + BLOCK_SIZE: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + expert_token_mask = tl.load( + routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token, + mask=(offset < num_tokens), + other=0, + ).to(tl.int64) + row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask + tl.store( + row_id_map_ptr + pid_m * num_tokens + offset, + row_id_within_token_block, + mask=offset < num_tokens, + ) + n_tokens_per_block = tl.sum(expert_token_mask) + tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block) + + +@triton.jit +def _row_id_map_pass_2_kernel( + # pointers + row_id_map_ptr, + workspace_ptr, + # sizes + num_tokens, + # metas + WORKSPACE_LOAD_WIDTH: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n + offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + row_id_within_token_block = tl.load( + row_id_map_ptr + pid_m * num_tokens + offset, mask=(offset < num_tokens), other=0 + ) + + workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH) + n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx) + row_id = tl.where( + row_id_within_token_block == 0, + -1, + row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1, + ) + tl.store( + row_id_map_ptr + pid_m * num_tokens + offset, + row_id, + mask=(offset < num_tokens), + ) + + +def make_row_id_map( + routing_map: torch.Tensor, + num_tokens: int, + num_experts: int, +): + # pylint: disable=missing-function-docstring + row_id_map = torch.empty((num_experts, num_tokens), dtype=torch.int64, device="cuda") + block_size = 256 + grid = (num_experts, triton.cdiv(num_tokens, block_size)) + workspace_tensor = torch.empty(grid, dtype=torch.int64, device="cuda") + # block cumsum + _row_id_map_pass_1_kernel[grid]( + routing_map, + row_id_map, + workspace_tensor, + num_tokens, + routing_map.stride(0), + routing_map.stride(1), + block_size, + ) + # cumsum all and process the mask + _row_id_map_pass_2_kernel[grid]( + row_id_map, + workspace_tensor, + num_tokens, + triton.next_power_of_2(num_experts * triton.cdiv(num_tokens, block_size)), + block_size, + ) + return row_id_map + + +@triton.jit +def _permute_kernel( + # pointers + input_ptr, + output_ptr, + row_id_map_ptr, + probs_ptr, + permuted_probs_ptr, + # sizes + num_tokens, + num_experts, + hidden_size, + # strides + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + stride_probs_token, + stride_probs_expert, + stride_permuted_probs_token, + # metas + PERMUTE_PROBS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + cur_pos = 0 + while cur_pos < hidden_size: + cur_off = cur_pos + tl.arange(0, BLOCK_SIZE) + mask = cur_off < hidden_size + input_off = pid * stride_input_token + cur_off * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + for expert_idx in range(num_experts): + dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) + if dst_row != -1: + output_off = dst_row * stride_output_token + cur_off * stride_output_hidden + tl.store(output_ptr + output_off, inp, mask=mask) + if PERMUTE_PROBS: + if cur_pos == 0: + prob_off = pid * stride_probs_token + expert_idx * stride_probs_expert + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = dst_row * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) + cur_pos += BLOCK_SIZE + + +try: + _permute_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], + )(_permute_kernel) +except RuntimeError: + pass + + +def permute_with_mask_map( + inp: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor, + num_tokens: int, + num_experts: int, + num_out_tokens: int, + hidden_size: int, +): + # pylint: disable=missing-function-docstring + output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if probs is not None: + permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device="cuda") + else: + permuted_probs = None + grid = (num_tokens,) + _permute_kernel[grid]( + inp, + output, + row_id_map, + probs, + permuted_probs, + num_tokens, + num_experts, + hidden_size, + inp.stride(0), + inp.stride(1), + output.stride(0), + output.stride(1), + probs.stride(0) if probs is not None else None, + probs.stride(1) if probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + PERMUTE_PROBS=probs is not None, + ) + return output, permuted_probs + + +@triton.jit +def _unpermute_kernel( + # pointers + input_ptr, + output_ptr, + row_id_map_ptr, + merging_probs_ptr, + permuted_probs_ptr, + unpermuted_probs_ptr, + # sizes + num_tokens, + num_experts, + hidden_size, + # strides + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + stride_merging_probs_token, + stride_merging_probs_expert, + stride_permuted_probs_token, + stride_unpermuted_probs_token, + stride_unpermuted_probs_expert, + # metas + WITH_MERGING_PROBS: tl.constexpr, + PERMUTE_PROBS: tl.constexpr, + FP8_DTYPE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + if FP8_DTYPE == "e5m2": + data_type = tl.float8e5 + pytorch_tensor_dtype = tl.uint8 + elif FP8_DTYPE == "e4m3": + data_type = tl.float8e4nv + pytorch_tensor_dtype = tl.uint8 + else: + data_type = input_ptr.dtype.element_ty + assert FP8_DTYPE is None + compute_type = tl.float32 + + pid = tl.program_id(0) + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type) + for expert_idx in range(num_experts): + src_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) + if src_row != -1: + input_off = src_row * stride_input_token + current_offset * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + if FP8_DTYPE is not None: + inp = inp.to(data_type, bitcast=True) + inp = inp.to(compute_type) + if WITH_MERGING_PROBS: + merging_prob_off = ( + pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) + inp *= merging_prob + accumulator += inp + if PERMUTE_PROBS: + if current_start == 0: + unpermuted_prob_off = ( + pid * stride_unpermuted_probs_token + + expert_idx * stride_unpermuted_probs_expert + ) + if src_row != -1: + permuted_prob_off = src_row * stride_permuted_probs_token + prob = tl.load(permuted_probs_ptr + permuted_prob_off) + tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) + else: + tl.store(unpermuted_probs_ptr + unpermuted_prob_off, 0.0) + if FP8_DTYPE is not None: + if not WITH_MERGING_PROBS: + # Directly adding these value may cause overflow for fp8, we scale it here. + # The outside fp8_scale_inv is also scaled in the meantime. + accumulator /= num_experts + accumulator = accumulator.to(data_type).to(pytorch_tensor_dtype, bitcast=True) + else: + accumulator = accumulator.to(data_type) + output_off = pid * stride_output_token + current_offset * stride_output_hidden + tl.store(output_ptr + output_off, accumulator, mask=mask) + current_start += BLOCK_SIZE + + +try: + _unpermute_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], + )(_unpermute_kernel) +except RuntimeError: + pass + + +def unpermute_with_mask_map( + inp: torch.Tensor, + row_id_map: torch.Tensor, + merging_probs: Union[torch.Tensor, None], + permuted_probs: Union[torch.Tensor, None], + num_tokens: int, + num_experts: int, + hidden_size: int, + fp8_dtype: TE_DType, +): + # pylint: disable=missing-function-docstring + if fp8_dtype == TE_DType.kFloat8E5M2: + fp8_dtype = "e5m2" + elif fp8_dtype == TE_DType.kFloat8E4M3: + fp8_dtype = "e4m3" + else: + fp8_dtype = None + output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if permuted_probs is not None: + unpermuted_probs = torch.empty( + (num_tokens, num_experts), dtype=permuted_probs.dtype, device="cuda" + ) + else: + unpermuted_probs = None + grid = (num_tokens,) + _unpermute_kernel[grid]( + inp, + output, + row_id_map, + merging_probs, + permuted_probs, + unpermuted_probs, + num_tokens, + num_experts, + hidden_size, + inp.stride(0), + inp.stride(1), + output.stride(0), + output.stride(1), + merging_probs.stride(0) if merging_probs is not None else None, + merging_probs.stride(1) if merging_probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + unpermuted_probs.stride(0) if unpermuted_probs is not None else None, + unpermuted_probs.stride(1) if unpermuted_probs is not None else None, + WITH_MERGING_PROBS=merging_probs is not None, + PERMUTE_PROBS=permuted_probs is not None, + FP8_DTYPE=fp8_dtype, + ) + return output, unpermuted_probs + + +@triton.jit +def _unpermute_bwd_with_merging_probs_kernel( + # pointers + fwd_output_grad_ptr, + fwd_input_grad_ptr, + fwd_input_ptr, + merging_probs_ptr, + merging_probs_grad_ptr, + row_id_map_ptr, + # sizes + num_tokens, + num_experts, + hidden_size, + # strides + stride_fwd_output_grad_token, + stride_fwd_output_grad_hidden, + stride_fwd_input_grad_token, + stride_fwd_input_grad_hidden, + stride_fwd_input_token, + stride_fwd_input_hidden, + stride_merging_probs_token, + stride_merging_probs_expert, + stride_merging_probs_grad_token, + stride_merging_probs_grad_expert, + # metas + FP8_DTYPE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + if FP8_DTYPE == "e5m2": + data_type = tl.float8e5 + pytorch_tensor_dtype = tl.uint8 + elif FP8_DTYPE == "e4m3": + data_type = tl.float8e4nv + pytorch_tensor_dtype = tl.uint8 + else: + data_type = fwd_output_grad_ptr.dtype.element_ty + assert FP8_DTYPE is None + compute_type = tl.float32 + + pid = tl.program_id(0) + for expert_idx in range(num_experts): + dst_row = tl.load(row_id_map_ptr + expert_idx * num_tokens + pid) + if dst_row != -1: + prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + input_off = ( + pid * stride_fwd_output_grad_token + + current_offset * stride_fwd_output_grad_hidden + ) + inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) + if FP8_DTYPE is not None: + inp = inp.to(data_type, bitcast=True) + inp = inp.to(compute_type) + merging_prob_off = ( + pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) + output = inp * merging_prob + output = output.to(data_type) + if FP8_DTYPE is not None: + output = output.to(pytorch_tensor_dtype, bitcast=True) + output_off = ( + dst_row * stride_fwd_input_grad_token + + current_offset * stride_fwd_input_grad_hidden + ) + tl.store(fwd_input_grad_ptr + output_off, output, mask=mask) + + fwd_input_off = ( + dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden + ) + fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) + if FP8_DTYPE is not None: + fwd_input = fwd_input.to(data_type, bitcast=True) + prob_grad_accum += fwd_input.to(compute_type) * inp + current_start += BLOCK_SIZE + probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) + probs_grad_off = ( + pid * stride_merging_probs_grad_token + + expert_idx * stride_merging_probs_grad_expert + ) + tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad) + else: + probs_grad_off = ( + pid * stride_merging_probs_grad_token + + expert_idx * stride_merging_probs_grad_expert + ) + tl.store(merging_probs_grad_ptr + probs_grad_off, 0.0) + + +try: + _unpermute_bwd_with_merging_probs_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], + )(_unpermute_bwd_with_merging_probs_kernel) +except RuntimeError: + pass + + +def unpermute_with_mask_map_bwd_with_merging_probs( + fwd_output_grad: torch.Tensor, + row_id_map: torch.Tensor, + fwd_input: torch.Tensor, + merging_probs: torch.Tensor, + num_tokens: int, + num_experts: int, + num_out_tokens: int, + hidden_size: int, + fp8_dtype: TE_DType, +): + # pylint: disable=missing-function-docstring + if fp8_dtype == TE_DType.kFloat8E5M2: + fp8_dtype = "e5m2" + elif fp8_dtype == TE_DType.kFloat8E4M3: + fp8_dtype = "e4m3" + else: + fp8_dtype = None + act_grad = torch.empty( + (num_out_tokens, hidden_size), dtype=fwd_output_grad.dtype, device="cuda" + ) + merging_probs_grad = torch.empty( + (num_tokens, num_experts), dtype=merging_probs.dtype, device="cuda" + ) + grid = (num_tokens,) + _unpermute_bwd_with_merging_probs_kernel[grid]( + fwd_output_grad, + act_grad, + fwd_input, + merging_probs, + merging_probs_grad, + row_id_map, + num_tokens, + num_experts, + hidden_size, + fwd_output_grad.stride(0), + fwd_output_grad.stride(1), + act_grad.stride(0), + act_grad.stride(1), + fwd_input.stride(0), + fwd_input.stride(1), + merging_probs.stride(0), + merging_probs.stride(1), + merging_probs_grad.stride(0), + merging_probs_grad.stride(1), + fp8_dtype, + ) + return act_grad, merging_probs_grad + + +@triton.jit +def _sort_chunks_by_idxs_kernel( + # pointers + input_ptr, + split_sizes_ptr, + sorted_indices_ptr, + output_ptr, + dst_rows_ptr, + probs_ptr, + permuted_probs_ptr, + # sizes + num_splits, + hidden_size, + # strides + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + stride_probs_token, + stride_permuted_probs_token, + # metas + PERMUTE_PROBS: tl.constexpr, + IDX_LOAD_WIDTH: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + load_split_offset = tl.arange(0, IDX_LOAD_WIDTH) + sorted_indices = tl.load( + sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits + ) + + # get chunk idx of the current token in the input tensor + input_chunk_idx = -1 + in_chunk_offset = tl.zeros([], dtype=tl.int64) + acc_chunk_sizes = tl.zeros([], dtype=tl.int64) + cursor = 0 + while cursor < num_splits: + cur_chunk_size = tl.load(split_sizes_ptr + cursor).to(tl.int64) + acc_chunk_sizes += cur_chunk_size + if input_chunk_idx == -1 and acc_chunk_sizes > pid: + input_chunk_idx = cursor + in_chunk_offset = pid - (acc_chunk_sizes - cur_chunk_size) + cursor += 1 + + # get chunk idx of the current token in the output tensor + output_chunk_idx = 0 + cursor = 0 + while cursor < num_splits: + cur_input_idx = tl.load(sorted_indices_ptr + cursor) + if cur_input_idx == input_chunk_idx: + output_chunk_idx = cursor + cursor += 1 + + # make row_id_map + output_split_sizes = tl.load( + split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits + ).to(tl.int64) + output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) + dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset + tl.store(dst_rows_ptr + pid, dst_row) + + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + input_offsets = pid * stride_input_token + current_offset * stride_input_hidden + output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden + inp = tl.load(input_ptr + input_offsets, mask=mask) + tl.store(output_ptr + output_offsets, inp, mask=mask) + current_start += BLOCK_SIZE + + if PERMUTE_PROBS: + prob_off = pid * stride_probs_token + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = dst_row * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) + + +try: + _sort_chunks_by_idxs_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], + )(_sort_chunks_by_idxs_kernel) +except RuntimeError: + pass + + +def sort_chunks_by_idx( + inp: torch.Tensor, + split_sizes: torch.Tensor, + sorted_indices: torch.Tensor, + probs: torch.Tensor, + num_tokens: int, + hidden_size: int, + num_splits: int, +): + # pylint: disable=missing-function-docstring + row_id_map = torch.empty((num_tokens,), dtype=torch.int64, device="cuda") + output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if probs is not None: + permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") + else: + permuted_probs = None + grid = (num_tokens,) + _sort_chunks_by_idxs_kernel[grid]( + inp, + split_sizes, + sorted_indices, + output, + row_id_map, + probs, + permuted_probs, + num_splits, + hidden_size, + inp.stride(0), + inp.stride(1), + output.stride(0), + output.stride(1), + probs.stride(0) if probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + PERMUTE_PROBS=probs is not None, + IDX_LOAD_WIDTH=triton.next_power_of_2(num_splits), + ) + return output, row_id_map, permuted_probs + + +@triton.jit +def _sort_chunks_by_map_kernel( + # pointers + input_ptr, + output_ptr, + row_id_map_ptr, + probs_ptr, + permuted_probs_ptr, + # sizes + hidden_size, + # strides + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + stride_probs_token, + stride_permuted_probs_token, + # metas + PERMUTE_PROBS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + dst_row = tl.load(row_id_map_ptr + pid) + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + input_offsets = dst_row * stride_input_token + current_offset * stride_input_hidden + output_offsets = pid * stride_output_token + current_offset * stride_output_hidden + inp = tl.load(input_ptr + input_offsets, mask=mask) + tl.store(output_ptr + output_offsets, inp, mask=mask) + current_start += BLOCK_SIZE + if PERMUTE_PROBS: + prob_off = dst_row * stride_probs_token + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = pid * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) + + +try: + _sort_chunks_by_map_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + ], + key=["hidden_size"], + )(_sort_chunks_by_map_kernel) +except RuntimeError: + pass + + +def sort_chunks_by_map( + inp: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor, + num_tokens: int, + hidden_size: int, +): + # pylint: disable=missing-function-docstring + output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device="cuda") + if probs is not None: + permuted_probs = torch.empty((num_tokens,), dtype=probs.dtype, device="cuda") + else: + permuted_probs = None + grid = (num_tokens,) + _sort_chunks_by_map_kernel[grid]( + inp, + output, + row_id_map, + probs, + permuted_probs, + hidden_size, + inp.stride(0), + inp.stride(1), + output.stride(0), + output.stride(1), + probs.stride(0) if probs is not None else None, + permuted_probs.stride(0) if permuted_probs is not None else None, + PERMUTE_PROBS=probs is not None, + ) + return output, permuted_probs diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 947c642c2c..1922a7e867 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -6,11 +6,14 @@ from __future__ import annotations import functools import math -from typing import Any, Callable, Optional, Tuple +import os +from typing import Any, Callable, List, Optional, Tuple import torch import transformer_engine.pytorch.cpp_extensions as ext +from .tensor.quantized_tensor import QuantizedTensor + def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """Check if any of the given tensors require gradient.""" @@ -27,12 +30,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: Must be used carefully. """ - from .float8_tensor import Float8Tensor - for t in tensors: if t is not None: - if isinstance(t, Float8Tensor): - t._data.data = torch.Tensor() + if isinstance(t, QuantizedTensor): + t.clear() else: t.data = torch.Tensor() del t @@ -231,14 +232,15 @@ def check_dim_for_fp8_exec(tensor: torch.Tensor) -> bool: return tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0 -def assert_dim_for_fp8_exec(tensor: torch.Tensor) -> None: - """Assert that tensor dimensions are supported for FP8 TN GEMM""" - # single tensor check so it's clear which tensor is triggering the assertion - assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, ( - "FP8 execution requires 2D input matrices with " - "height divisible by 8 and width divisible by 16, " - f"but got tensor with dims={list(tensor.size())}" - ) +def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: + """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM.""" + + for tensor in tensors: + assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, ( + "FP8 execution requires 2D input matrices with " + "height divisible by 8 and width divisible by 16, " + f"but got tensor with dims={list(tensor.size())}" + ) def is_bf16_compatible() -> None: @@ -248,6 +250,13 @@ def is_bf16_compatible() -> None: return torch.cuda.get_device_capability()[0] >= 8 +def non_tn_fp8_gemm_supported() -> bool: + """Checks whether the device supports + non-TN layouts for FP8 GEMMs. + """ + return torch.cuda.get_device_capability() >= (10, 0) + + @functools.lru_cache(maxsize=None) def get_cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" @@ -305,3 +314,75 @@ def devices_match(device1: torch.device, device2: torch.device) -> bool: index2 = torch.cuda.current_device() return index1 == index2 return device1 == device2 + + +@functools.lru_cache +def get_sm_count() -> int: + """Returns the number of streaming multiprocessors in the current device.""" + return torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count + + +def round_up_to_nearest_multiple(value, multiple): + """Round up `value` to the next mutiple of `multiple`""" + if multiple == 0: + raise ValueError("multiple cannot be zero.") + return ((value + multiple - 1) // multiple) * multiple + + +@functools.lru_cache(maxsize=None) +def _nvtx_enabled() -> bool: + """Check if NVTX range profiling is enabled""" + return bool(int(os.getenv("NVTE_NVTX_ENABLED", "0"))) + + +# Messages associated with active NVTX ranges +_nvtx_range_messages: list[str] = [] + + +def nvtx_range_push(msg: str) -> None: + """Push NVTX range onto stack, if NVTX range profiling is enabled + + Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range + profiling. + + Parameters + ---------- + msg: str + Message to associate with range + + """ + if not _nvtx_enabled(): + return + _nvtx_range_messages.append(msg) + torch.cuda.nvtx.range_push(msg) + + +def nvtx_range_pop(msg: Optional[str] = None) -> None: + """Pop NVTX range from stack, if NVTX range profiling is enabled + + Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range + profiling. + + Parameters + ---------- + msg: str, optional + Message associated with range + + """ + + # Return immediately if NVTX range profiling is not enabled + if not _nvtx_enabled(): + return + + # Update list of NVTX range messages and check for consistency + if not _nvtx_range_messages: + raise RuntimeError("Attempted to pop NVTX range from empty stack") + last_msg = _nvtx_range_messages.pop() + if msg is not None and msg != last_msg: + raise ValueError( + f"Attempted to pop NVTX range from stack with msg={msg}, " + f"but last range has msg={last_msg}" + ) + + # Pop NVTX range + torch.cuda.nvtx.range_pop()
    cuDNN 8.9.6+: sm90
    JAX, PaddlePaddle: `no_bias`, `post_scale_bias`JAX: `no_bias`, `post_scale_bias`ALiBi slopes: FP32cuDNN 9.0+: sm80+