diff --git a/.codex b/.codex new file mode 100644 index 0000000000..e69de29bb2 diff --git a/docs/attention/attn_qat/index.md b/docs/attention/attn_qat/index.md new file mode 100644 index 0000000000..2fffe67491 --- /dev/null +++ b/docs/attention/attn_qat/index.md @@ -0,0 +1,318 @@ +# Attention QAT + +Attention QAT in FastVideo covers two related, but different, backends: + +- `ATTN_QAT_INFER`: the inference-oriented CUDA kernel path +- `ATTN_QAT_TRAIN`: the training-oriented Triton attention path + +Both are selected with `FASTVIDEO_ATTENTION_BACKEND`, but they are not +interchangeable. The main practical split is: + +- use `ATTN_QAT_INFER` for standalone inference with the dedicated inference + kernel +- use `ATTN_QAT_TRAIN` for finetuning, validation during training, or when you + specifically want to reproduce the training-side attention path + +## Quick Start + +If your goal is "run Wan 2.1 14B with Attention QAT inference weights", this is +the shortest path: + +1. Build the in-repo kernel package so FastVideo can import `attn_qat_infer`. +2. Download the Wan 2.1 14B QAT checkpoint. +3. Edit the provided inference example to point at the 14B base model and the + downloaded QAT safetensors. +4. Run the example with `ATTN_QAT_INFER`. + +### Step 1. Build the kernel package + +Before using either Attention QAT backend, build the in-repo +`fastvideo-kernel` package from source: + +```bash +git submodule update --init --recursive +cd fastvideo-kernel +./build.sh +``` + +After a successful build: + +- `ATTN_QAT_TRAIN` should be able to import `fastvideo_kernel` +- `ATTN_QAT_INFER` should be able to import `attn_qat_infer` + +`ATTN_QAT_INFER` currently targets the Blackwell CUDA path under +`fastvideo-kernel/attn_qat_infer/` and requires CUDA 12.8+. + +### Step 2. Download the Wan 2.1 14B QAT checkpoint + +FastVideo includes a helper script: + +- `examples/inference/optimizations/download_14B_qat.sh` + +By default it downloads: + +- Hugging Face repo: `FastVideo/14B_qat_400` +- local directory: `checkpoints/14B_qat_400` + +Prerequisites: + +- `huggingface_hub` installed, for example: + `uv pip install huggingface_hub` +- access to the model repo if it is private or gated: + `huggingface-cli login` + +Run the downloader: + +```bash +bash examples/inference/optimizations/download_14B_qat.sh +``` + +To download into a custom directory: + +```bash +bash examples/inference/optimizations/download_14B_qat.sh /path/to/14B_qat_400 +``` + +The script prints a ready-to-copy `init_weights_from_safetensors=...` value at +the end. + +### Step 3. Edit the provided inference example + +The example to start from is: + +- `examples/inference/optimizations/attn_qat_inference_example.py` + +Open that file and update these two values: + +1. Change the base model from `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` to + `Wan-AI/Wan2.1-T2V-14B-Diffusers` +2. Replace + `init_weights_from_safetensors="safetensors_path"` with the directory that + contains the downloaded `.safetensors` files + +Example: + +```python +import os + +from fastvideo import VideoGenerator + +os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "ATTN_QAT_INFER" + +generator = VideoGenerator.from_pretrained( + "Wan-AI/Wan2.1-T2V-14B-Diffusers", + num_gpus=1, + use_fsdp_inference=True, + dit_cpu_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=True, + pin_cpu_memory=False, + init_weights_from_safetensors="checkpoints/14B_qat_400", +) +``` + +Important: + +- the checked-in example currently uses the `1.3B` base model until you edit it +- do not load the 14B QAT weights on top of the `1.3B` base model; the weights + and model config will not match + +### Step 4. Run the inference example + +```bash +python examples/inference/optimizations/attn_qat_inference_example.py +``` + +Generated videos are written to `video_samples/` by default. + +## Backend Overview + +| Backend | Best for | Package requirement | Primary kernel location | +|---------|----------|---------------------|-------------------------| +| `ATTN_QAT_TRAIN` | finetuning, training-time validation, reproducing the training path | `fastvideo_kernel` | `fastvideo-kernel/python/fastvideo_kernel/triton_kernels/attn_qat_train.py` | +| `ATTN_QAT_INFER` | standalone inference with the dedicated CUDA kernel | `attn_qat_infer` from the in-repo `fastvideo-kernel` checkout | `fastvideo-kernel/attn_qat_infer/` | + +FastVideo routes backend selection through: + +- `fastvideo/envs.py` +- `fastvideo/platforms/cuda.py` +- `fastvideo/attention/backends/attn_qat_train.py` +- `fastvideo/attention/backends/attn_qat_infer.py` + +The legacy training pipeline also contains explicit Attention QAT integration: + +- `fastvideo/training/training_pipeline.py` + +That pipeline forces generator loading through `ATTN_QAT_TRAIN` when +`FASTVIDEO_ATTENTION_BACKEND=ATTN_QAT_TRAIN` or `--generator_4bit_attn` is +enabled. + +## Inference Workflows + +For standalone inference, prefer `ATTN_QAT_INFER` when the CUDA kernel is +available. Use `ATTN_QAT_TRAIN` for inference only if you intentionally want to +exercise the training-side attention path for debugging or parity checks. + +### Minimal Python example + +```python +import os + +from fastvideo import VideoGenerator + +os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "ATTN_QAT_INFER" + +generator = VideoGenerator.from_pretrained( + "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + num_gpus=1, +) + +generator.generate_video( + "A cinematic close-up of rain on a neon street at night.", + output_path="video_samples", + save_video=True, +) +``` + +### Loading custom safetensors during inference + +FastVideo supports loading custom transformer weights through +`init_weights_from_safetensors`. + +This value can point to either: + +- a directory containing one or more `.safetensors` files +- a single `.safetensors` file + +For Wan 2.1 14B QAT inference, the common pattern is: + +```python +generator = VideoGenerator.from_pretrained( + "Wan-AI/Wan2.1-T2V-14B-Diffusers", + num_gpus=1, + use_fsdp_inference=True, + init_weights_from_safetensors="checkpoints/14B_qat_400", +) +``` + +### CLI example + +You can also force the backend from the command line: + +```bash +FASTVIDEO_ATTENTION_BACKEND=ATTN_QAT_INFER \ +fastvideo generate \ + --model-path Wan-AI/Wan2.1-T2V-14B-Diffusers \ + --num-gpus 1 \ + --sp-size 1 \ + --tp-size 1 \ + --height 480 \ + --width 832 \ + --num-frames 77 \ + --num-inference-steps 50 \ + --guidance-scale 6.0 \ + --prompt "A cinematic close-up of rain on a neon street at night." \ + --output-path outputs_video/ +``` + +If you want to use custom QAT transformer weights from the CLI, pass the same +custom weight override that the Python API uses: + +```bash +FASTVIDEO_ATTENTION_BACKEND=ATTN_QAT_INFER \ +fastvideo generate \ + --model-path Wan-AI/Wan2.1-T2V-14B-Diffusers \ + --init-weights-from-safetensors checkpoints/14B_qat_400 \ + --num-gpus 1 \ + --output-path outputs_video/ \ + --prompt "A cinematic close-up of rain on a neon street at night." +``` + +## Training Workflows + +Today the checked-in Attention QAT training launchers use the legacy training +pipeline in `fastvideo/training/wan_training_pipeline.py`. + +### Ready-made launchers + +Use the provided SLURM scripts directly: + +```bash +sbatch examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v_qat_attn.sh +sbatch examples/training/finetune/wan_t2v_14B/finetune_t2v_qat_attn.sh +``` + +Both scripts already set: + +```bash +export FASTVIDEO_ATTENTION_BACKEND=ATTN_QAT_TRAIN +``` + +Before launching, update the script-local values that depend on your +environment: + +- `WANDB_API_KEY` +- `MODEL_PATH` +- `DATA_DIR` +- `VALIDATION_DATASET_FILE` +- output directory and SLURM resource requests + +### What the launchers run + +The training scripts eventually invoke: + +```bash +torchrun fastvideo/training/wan_training_pipeline.py ... +``` + +If you are adapting the workflow to your own cluster or running outside SLURM, +the main Attention QAT requirement is still: + +```bash +export FASTVIDEO_ATTENTION_BACKEND=ATTN_QAT_TRAIN +``` + +Then launch the normal Wan training pipeline with your preferred `torchrun` +arguments and training flags. + +## Where The Code Lives + +Use these paths when you want to trace or modify the Attention QAT flow: + +| Location | Purpose | +|----------|---------| +| `fastvideo/attention/backends/attn_qat_train.py` | FastVideo wrapper that imports and calls the Triton training kernel | +| `fastvideo/attention/backends/attn_qat_infer.py` | FastVideo wrapper that imports and calls the inference kernel | +| `fastvideo-kernel/CMakeLists.txt` | Kernel build definition that compiles the `attn_qat_infer` inference extensions | +| `fastvideo/platforms/cuda.py` | Chooses the concrete attention backend at runtime | +| `fastvideo/envs.py` | Documents supported `FASTVIDEO_ATTENTION_BACKEND` values | +| `fastvideo/training/training_pipeline.py` | Training-time forcing logic for the generator attention backend | +| `fastvideo-kernel/python/fastvideo_kernel/triton_kernels/attn_qat_train.py` | Triton implementation for `ATTN_QAT_TRAIN` | +| `fastvideo-kernel/attn_qat_infer/api.py` | Python API entrypoint for the inference kernel | +| `fastvideo-kernel/benchmarks/benchmark_*.py` | Kernel-side benchmark scripts for FlashAttn2, SageAttention3, FP4, and comparison plots | +| `fastvideo-kernel/attn_qat_infer/blackwell/api.cu` | CUDA implementation behind `ATTN_QAT_INFER` | +| `fastvideo-kernel/tests/test_attn_qat_train.py` | Kernel-level test coverage for the training path | +| `examples/inference/optimizations/attn_qat_inference_example.py` | Ready-to-edit inference example for custom Attention QAT weights | +| `examples/inference/optimizations/download_14B_qat.sh` | Helper script for downloading the Wan 2.1 14B QAT checkpoint | +| `examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v_qat_attn.sh` | Ready-to-run Wan 1.3B Attention QAT finetune launcher | +| `examples/training/finetune/wan_t2v_14B/finetune_t2v_qat_attn.sh` | Ready-to-run Wan 14B Attention QAT finetune launcher | + +## Troubleshooting + +- If `ATTN_QAT_TRAIN` fails to import, verify that `fastvideo-kernel` built + successfully and exposes `fastvideo_kernel`. +- If `ATTN_QAT_INFER` fails to import, verify that the local build exposes the + `attn_qat_infer` package. +- If the Wan 2.1 14B example fails after you changed only the checkpoint path, + make sure you also changed the base model to + `Wan-AI/Wan2.1-T2V-14B-Diffusers`. +- If you hit issues with CPU memory pressure or obscure CUDA argument errors in + the example script, try setting `pin_cpu_memory=False`. +- If you want a known-safe fallback for debugging, use + `FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA`. + +## Related Pages + +- [Attention Overview](../index.md) +- [Inference Optimizations](../../inference/optimizations.md) +- [Debugging](../../utilities/debugging.md) diff --git a/docs/attention/index.md b/docs/attention/index.md index f11744901a..e274962faa 100644 --- a/docs/attention/index.md +++ b/docs/attention/index.md @@ -5,6 +5,8 @@ FastVideo provides highly optimized custom attention kernels to accelerate video ## Supported Kernels * **[Video Sparse Attention (VSA)](vsa/index.md)**: Sparse attention mechanism selecting top-k blocks. +* **[Attention QAT](attn_qat/index.md)**: Dedicated guide for Attention QAT + inference, training, checkpoint loading, and troubleshooting. * **[Sliding Tile Attention (STA)](sta/index.md)**: STA kernel support is kept in `fastvideo-kernel`; full FastVideo STA pipeline workflow is archived in `sta_do_not_delete`. diff --git a/docs/design/inference_schema_parity_inventory.yaml b/docs/design/inference_schema_parity_inventory.yaml index f4a6a66811..95c46537dd 100644 --- a/docs/design/inference_schema_parity_inventory.yaml +++ b/docs/design/inference_schema_parity_inventory.yaml @@ -35,6 +35,7 @@ surfaces: prompt_txt: request.inputs.prompt_path override_text_encoder_safetensors: generator.pipeline.components.text_encoder_weights override_text_encoder_quant: generator.engine.quantization.text_encoder_quant + transformer_quant: generator.engine.quantization.transformer_quant override_transformer_cls_name: generator.pipeline.components.override_transformer_cls_name init_weights_from_safetensors: generator.pipeline.components.transformer_weights init_weights_from_safetensors_2: generator.pipeline.components.transformer_2_weights @@ -345,6 +346,7 @@ surfaces: num_inference_steps: request.sampling.num_inference_steps num_inference_steps_sr: request.sampling.num_inference_steps_sr guidance_scale: request.sampling.guidance_scale + guidance_scale_2: request.sampling.guidance_scale_2 guidance_rescale: request.sampling.guidance_rescale boundary_ratio: request.sampling.boundary_ratio sigmas: request.sampling.sigmas @@ -364,15 +366,7 @@ surfaces: data_type: "Derived from the request shape and not a public input." sampling_param_extensions: - moved: - guidance_scale_2: - target: request.sampling.guidance_scale_2 - sources: - - fastvideo.configs.sample.lingbotworld.LingBotWorld_SamplingParam - - fastvideo.configs.sample.lingbotworld.Wan2_2_I2V_A14B_SamplingParam - - fastvideo.configs.sample.wan.SelfForcingWan2_2_T2V_A14B_480P_SamplingParam - - fastvideo.configs.sample.wan.Wan2_2_I2V_A14B_SamplingParam - - fastvideo.configs.sample.wan.Wan2_2_T2V_A14B_SamplingParam + moved: {} profile_owned: action_list: target: request.extensions.hunyuangamecraft.action_list @@ -597,6 +591,7 @@ cli: - text_encoder_cpu_offload - text_encoder_precisions - torch_compile_kwargs + - transformer_quant - tp_size - trust_remote_code - use_fsdp_inference @@ -702,6 +697,7 @@ cli: - text_encoder_cpu_offload - text_encoder_precisions - torch_compile_kwargs + - transformer_quant - tp_size - trust_remote_code - use_fsdp_inference diff --git a/docs/design/overview.md b/docs/design/overview.md index 8365bc3a0e..e26e19749f 100644 --- a/docs/design/overview.md +++ b/docs/design/overview.md @@ -167,6 +167,11 @@ How this maps to FastVideo: - Attention backends live in `fastvideo/attention/` and can be selected via `FASTVIDEO_ATTENTION_BACKEND`. +- SageAttention3 is split into two selectable backends: + `SAGE_ATTN_THREE` for the regular upstream package and + `ATTN_QAT_INFER` for the FastVideoKernel-backed inference variant. +- `ATTN_QAT_TRAIN` is a separate FastVideoKernel Triton backend for the QAT attention + path. - `LocalAttention` is used for cross-attention and most attention layers. - `DistributedAttention` is used for full-sequence self-attention in the DiT. - Tensor-parallel layers live in `fastvideo/layers/`. diff --git a/docs/inference/inference_quick_start.md b/docs/inference/inference_quick_start.md index a5e0db2f95..4a647ab482 100644 --- a/docs/inference/inference_quick_start.md +++ b/docs/inference/inference_quick_start.md @@ -107,6 +107,8 @@ If you encounter CUDA out of memory errors: (single GPU) or `use_fsdp_inference=True` (multi-GPU) - Try a smaller model or use distilled versions - Use `num_gpus` > 1 if multiple GPUs are available +- Try enabling FSDP inference with `use_fsdp_inference=True` (may slow down generation) +- Try enabling DiT layerwise offload with `dit_layerwise_offload=True` (now only a few models support this, but may introduce less overhead than FSDP) ### Slow Generation diff --git a/docs/inference/optimizations.md b/docs/inference/optimizations.md index 27f4439f6e..22d6450ef4 100644 --- a/docs/inference/optimizations.md +++ b/docs/inference/optimizations.md @@ -21,6 +21,8 @@ This page describes the various options for speeding up generation times in Fast - Video Sparse Attention: `FASTVIDEO_ATTENTION_BACKEND=VIDEO_SPARSE_ATTN` - Sage Attention: `FASTVIDEO_ATTENTION_BACKEND=SAGE_ATTN` - Sage Attention 3: `FASTVIDEO_ATTENTION_BACKEND=SAGE_ATTN_THREE` +- Attn QAT Infer: `FASTVIDEO_ATTENTION_BACKEND=ATTN_QAT_INFER` +- Attn QAT Train: `FASTVIDEO_ATTENTION_BACKEND=ATTN_QAT_TRAIN` - Video MoBA Attention: `FASTVIDEO_ATTENTION_BACKEND=VMOBA_ATTN` - Sparse Linear Attention: `FASTVIDEO_ATTENTION_BACKEND=SLA_ATTN` - SageSLA Attention: `FASTVIDEO_ATTENTION_BACKEND=SAGE_SLA_ATTN` @@ -103,6 +105,14 @@ python setup.py install # or pip install -e . ### Sage Attention 3 +FastVideo now exposes two SageAttention3-compatible backends with distinct +environment variable values: + +- `SAGE_ATTN_THREE`: the regular upstream SageAttention3 backend imported from + the `sageattn3` package. +- `ATTN_QAT_INFER`: the inference CUDA-kernel backend imported from the + in-repo `attn_qat_infer` package. + **`SAGE_ATTN_THREE`** [SageAttention 3](https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell) is an advanced attention mechanism that leverages FP4 quantization and Blackwell GPU Tensor Cores for significant performance improvements. @@ -117,6 +127,53 @@ Note that Sage Attention 3 requires `python>=3.13`, `torch>=2.8.0`, `CUDA >=12.8 To use Sage Attention 3 in FastVideo, follow the `README.md` in the linked repository to install the package from source. +### Attn QAT Infer + +**`ATTN_QAT_INFER`** + +This backend uses the `attn_qat_infer` implementation that lives in the +`fastvideo-kernel` repository alongside the `fastvideo_kernel` Triton kernels. +Use this backend when you want to run the dedicated FP4 inference CUDA kernel +directly during inference. + +For the full Attention QAT guide, including Wan 2.1 14B checkpoint download, +example editing steps, training launchers, and troubleshooting, see +[Attention QAT](../attention/attn_qat/index.md). + +This backend currently assumes access to the in-repo `fastvideo-kernel` +checkout or an equivalent editable/source install that exposes: + +- `attn_qat_infer` + +Example: + +```python +os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "ATTN_QAT_INFER" +``` + +### QAT Attention + +**`ATTN_QAT_TRAIN`** + +This backend uses the FastVideoKernel Triton attention implementation from +`fastvideo_kernel.triton_kernels.attn_qat_train`. Use it when you specifically +want the training-oriented Triton attention path rather than the +`attn_qat_infer` CUDA kernel path. + +The dedicated [Attention QAT](../attention/attn_qat/index.md) page covers when +to use `ATTN_QAT_TRAIN` versus `ATTN_QAT_INFER`, the ready-made training +launchers, and the end-to-end Wan 2.1 14B inference workflow. + +This backend currently assumes access to an install that exposes: + +- `fastvideo_kernel` + +Example: + +```python +os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "ATTN_QAT_TRAIN" +``` + ### V-MoBA / SLA / SageSLA These backends are model-specific and require the corresponding kernels and diff --git a/docs/utilities/debugging.md b/docs/utilities/debugging.md index 137b127c08..aeeaddd6d2 100644 --- a/docs/utilities/debugging.md +++ b/docs/utilities/debugging.md @@ -27,7 +27,8 @@ Useful variables: - `FASTVIDEO_LOGGING_LEVEL`: `DEBUG`, `INFO`, `WARNING`, `ERROR` - `FASTVIDEO_STAGE_LOGGING`: print per-stage timings during pipeline execution - `FASTVIDEO_ATTENTION_BACKEND`: force an attention backend (for example - `TORCH_SDPA` or `FLASH_ATTN`) + `TORCH_SDPA`, `FLASH_ATTN`, `SAGE_ATTN_THREE`, or + `ATTN_QAT_INFER`, or `ATTN_QAT_TRAIN`) ## Common Failure Modes @@ -52,7 +53,11 @@ If forcing a backend fails, verify optional dependencies are installed: - `VIDEO_SPARSE_ATTN`: `fastvideo-kernel` - `SLIDING_TILE_ATTN`: STA legacy workflow in `sta_do_not_delete` + `fastvideo-kernel` -- `SAGE_ATTN` / `SAGE_ATTN_THREE`: SageAttention packages +- `SAGE_ATTN`: SageAttention package +- `SAGE_ATTN_THREE`: upstream `sageattn3` package +- `ATTN_QAT_INFER`: `fastvideo-kernel` checkout/source install that exposes + `attn_qat_infer` +- `ATTN_QAT_TRAIN`: `fastvideo-kernel` install exposing `fastvideo_kernel` As a fallback, use: diff --git a/examples/distill/SFWan2.2-A14B/distill_dmd.sh b/examples/distill/SFWan2.2-A14B/distill_dmd.sh index c9bbfd882e..44c0ff4885 100644 --- a/examples/distill/SFWan2.2-A14B/distill_dmd.sh +++ b/examples/distill/SFWan2.2-A14B/distill_dmd.sh @@ -22,7 +22,7 @@ export NODE_RANK=$SLURM_PROCID nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) ) export MASTER_ADDR=${nodes[0]} export TOKENIZERS_PARALLELISM=false -export WANDB_API_KEY="2f25ad37933894dbf0966c838c0b8494987f9f2f" +export WANDB_API_KEY=YOUR_WANDB_API_KEY # export WANDB_API_KEY='your_wandb_api_key_here' export WANDB_BASE_URL="https://api.wandb.ai" export WANDB_MODE=online diff --git a/examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/README.md b/examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/README.md index cce5a097f8..b43e87fdf9 100644 --- a/examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/README.md +++ b/examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/README.md @@ -9,7 +9,7 @@ pip install vsa ### 1. Download dataset: ```bash -bash examples/distill/Wan-Syn-480P/download_dataset.sh +bash examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/download_dataset.sh ``` ### 2. Configure and run distillation: diff --git a/examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/download_dataset.sh b/examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/download_dataset.sh index 50e072a4a4..390027143e 100644 --- a/examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/download_dataset.sh +++ b/examples/distill/Wan2.1-T2V/Wan-Syn-Data-480P/download_dataset.sh @@ -1,3 +1,3 @@ #!/bin/bash - -python scripts/huggingface/download_hf.py --repo_id "FastVideo/Wan-Syn_77x448x832_600k" --local_dir "FastVideo/Wan-Syn_77x448x832_600k" --repo_type "dataset" +mkdir -p data +python scripts/huggingface/download_hf.py --repo_id "FastVideo/Wan-Syn_77x448x832_600k" --local_dir "data/Wan-Syn_77x448x832_600k" --repo_type "dataset" diff --git a/examples/inference/optimizations/README.md b/examples/inference/optimizations/README.md index ee644a90b2..9147183fbc 100644 --- a/examples/inference/optimizations/README.md +++ b/examples/inference/optimizations/README.md @@ -1,5 +1,79 @@ # Optimization Examples +## Wan 2.1 QAT Attention 14B Inference + +Use these files for Wan 2.1 14B inference with the `ATTN_QAT_INFER` backend: + +- `examples/inference/optimizations/download_14B_qat.sh` +- `examples/inference/optimizations/attn_qat_inference_example.py` + +### 1. Download the 14B QAT checkpoint + +The helper script downloads the QAT safetensors from +`FastVideo/14B_qat_400` into `checkpoints/14B_qat_400` by default. + +Prerequisites: + +- `huggingface_hub` installed, for example: `uv pip install huggingface_hub` +- access to the model repo if it is private or gated: `huggingface-cli login` + +Run: + +```bash +bash examples/inference/optimizations/download_14B_qat.sh +``` + +To download into a custom directory, pass it as the first argument: + +```bash +bash examples/inference/optimizations/download_14B_qat.sh /path/to/14B_qat_400 +``` + +### 2. Edit the inference example for Wan 2.1 14B + +Open `examples/inference/optimizations/attn_qat_inference_example.py` and +update these two values: + +1. Change the base model from `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` to + `Wan-AI/Wan2.1-T2V-14B-Diffusers`. +2. Replace the placeholder + `init_weights_from_safetensors="safetensors_path"` with the directory that + contains the downloaded `.safetensors` files. + +Example: + +```python +generator = VideoGenerator.from_pretrained( + "Wan-AI/Wan2.1-T2V-14B-Diffusers", + num_gpus=1, + use_fsdp_inference=True, + dit_cpu_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=True, + pin_cpu_memory=False, + init_weights_from_safetensors="checkpoints/14B_qat_400", +) +``` + +The script already sets: + +```python +os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "ATTN_QAT_INFER" +``` + +### 3. Run the example + ```bash -python examples/inference/optimizations/attention_example.py +python examples/inference/optimizations/attn_qat_inference_example.py ``` + +The generated videos are written to `video_samples/` by default. + +### Notes + +- `ATTN_QAT_INFER` requires the in-repo `fastvideo-kernel` build to expose the + `attn_qat_infer` package. +- If you have not built the kernel yet, run `cd fastvideo-kernel && ./build.sh` + first. +- If you keep the example on the `1.3B` base model while loading the 14B QAT + weights, the model/config will not match. diff --git a/examples/inference/optimizations/attn_qat_inference_example.py b/examples/inference/optimizations/attn_qat_inference_example.py new file mode 100644 index 0000000000..b612959fdd --- /dev/null +++ b/examples/inference/optimizations/attn_qat_inference_example.py @@ -0,0 +1,54 @@ +from fastvideo import VideoGenerator +import os +from pathlib import Path +# from fastvideo.configs.sample import SamplingParam + +OUTPUT_PATH = "video_samples" +os.environ["FASTVIDEO_ATTENTION_BACKEND"] = "ATTN_QAT_INFER" + +CHECKPOINT_PATH = Path(__file__).parent.parent.parent + +def main(): + # FastVideo will automatically use the optimal default arguments for the + # model. + # If a local path is provided, FastVideo will make a best effort + # attempt to identify the optimal arguments. + generator = VideoGenerator.from_pretrained( + "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + # FastVideo will automatically handle distributed setup + num_gpus=1, + use_fsdp_inference=True, + dit_cpu_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=True, + pin_cpu_memory=False, # set to false if low CPU RAM or hit obscure "CUDA error: Invalid argument" + # image_encoder_cpu_offload=False, + # Load custom weights from checkpoint + init_weights_from_safetensors="safetensors_path" + ) + + # sampling_param = SamplingParam.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") + # sampling_param.num_frames = 45 + # sampling_param.image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + # Generate videos with the same simple API, regardless of GPU count + prompt = ( + "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes " + "wide with interest. The playful yet serene atmosphere is complemented by soft " + "natural light filtering through the petals. Mid-shot, warm and cheerful tones." + ) + video = generator.generate_video(prompt, output_path=OUTPUT_PATH, save_video=True) + # video = generator.generate_video(prompt, sampling_param=sampling_param, output_path="wan_t2v_videos/") + + # Generate another video with a different prompt, without reloading the + # model! + prompt2 = ( + "A majestic lion strides across the golden savanna, its powerful frame " + "glistening under the warm afternoon sun. The tall grass ripples gently in " + "the breeze, enhancing the lion's commanding presence. The tone is vibrant, " + "embodying the raw energy of the wild. Low angle, steady tracking shot, " + "cinematic.") + video2 = generator.generate_video(prompt2, output_path=OUTPUT_PATH, save_video=True) + + +if __name__ == "__main__": + main() diff --git a/examples/inference/optimizations/download_14B_qat.sh b/examples/inference/optimizations/download_14B_qat.sh new file mode 100755 index 0000000000..41b0dc39b3 --- /dev/null +++ b/examples/inference/optimizations/download_14B_qat.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd -- "${SCRIPT_DIR}/../../.." && pwd)" + +HF_REPO_ID="${HF_REPO_ID:-FastVideo/14B_qat_400}" +HF_REVISION="${HF_REVISION:-main}" +LOCAL_DIR="${1:-${REPO_ROOT}/checkpoints/14B_qat_400}" +PYTHON_BIN="${PYTHON:-python}" + +if ! command -v "${PYTHON_BIN}" >/dev/null 2>&1; then + echo "Python executable not found: ${PYTHON_BIN}" >&2 + exit 1 +fi + +if ! "${PYTHON_BIN}" -c "import huggingface_hub" >/dev/null 2>&1; then + echo "Missing dependency: huggingface_hub" >&2 + echo "Install it with: uv pip install huggingface_hub" >&2 + exit 1 +fi + +mkdir -p "${LOCAL_DIR}" + +echo "Downloading ${HF_REPO_ID}@${HF_REVISION}" +echo "Local directory: ${LOCAL_DIR}" + +"${PYTHON_BIN}" -c ' +import argparse +from huggingface_hub import snapshot_download + +parser = argparse.ArgumentParser() +parser.add_argument("--repo-id", required=True) +parser.add_argument("--revision", required=True) +parser.add_argument("--local-dir", required=True) +args = parser.parse_args() + +snapshot_download( + repo_id=args.repo_id, + revision=args.revision, + repo_type="model", + local_dir=args.local_dir, + local_dir_use_symlinks=False, + resume_download=True, +) +' \ + --repo-id "${HF_REPO_ID}" \ + --revision "${HF_REVISION}" \ + --local-dir "${LOCAL_DIR}" + +echo +echo "Download complete." +echo "Use this in your inference script:" +echo "init_weights_from_safetensors=\"${LOCAL_DIR}\"" +echo +echo "If the repo is private or gated, make sure you are logged in with:" +echo "huggingface-cli login" diff --git a/examples/inference/optimizations/quantization_example.py b/examples/inference/optimizations/quantization_example.py new file mode 100644 index 0000000000..d5518a489d --- /dev/null +++ b/examples/inference/optimizations/quantization_example.py @@ -0,0 +1,88 @@ +import torch +from fastvideo import VideoGenerator +from fastvideo.configs.pipelines.base import PipelineConfig + +OUTPUT_PATH = "video_samples" + + +def main(): + print("=== FP4 Quantization Video Generation Example ===") + + if not torch.cuda.is_available(): + print("Warning: CUDA not available. FP4 quantization requires GPU.") + return + + gpu_capability = torch.cuda.get_device_capability() + if gpu_capability[0] < 9: # H100 and newer + print(f"Warning: GPU capability {gpu_capability} may not support FP4. Recommended: 9.0+") + + print(f"GPU: {torch.cuda.get_device_name()}") + print(f"GPU Capability: {gpu_capability}") + + model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + # model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + pipeline_config = PipelineConfig.from_pretrained(model_id) + pipeline_config.dit_precision = "bf16" + + print("\nLoading model with FP4 quantization...") + generator = VideoGenerator.from_pretrained( + model_id, + pipeline_config=pipeline_config, + num_gpus=1, + use_fsdp_inference=True, + transformer_quant="fp4", + dit_cpu_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=True, + pin_cpu_memory=False, + ) + + print("FP4 configuration applied. Generating videos...") + + print("\n=== Generating Video with FP4 Quantization ===") + + prompt1 = ( + "A curious raccoon peers through a vibrant field of yellow sunflowers, its eyes " + "wide with interest. The playful yet serene atmosphere is complemented by soft " + "natural light filtering through the petals. Mid-shot, warm and cheerful tones." + ) + + print(f"Prompt: {prompt1}") + print("Generating video...") + + try: + video1 = generator.generate_video( + prompt1, + output_path=OUTPUT_PATH, + save_video=True, + ) + print("✓ First video generated successfully with FP4 quantization!") + + # # Generate a second video to show the model can be reused + prompt2 = ( + "A majestic lion strides across the golden savanna, its powerful frame " + "glistening under the warm afternoon sun. The tall grass ripples gently in " + "the breeze, enhancing the lion's commanding presence. The tone is vibrant, " + "embodying the raw energy of the wild. Low angle, steady tracking shot, " + "cinematic." + ) + + print(f"\nGenerating second video...") + print(f"Prompt: {prompt2}") + + video2 = generator.generate_video( + prompt2, + output_path=OUTPUT_PATH, + save_video=True, + ) + print("✓ Second video generated successfully with FP4 quantization!") + + except Exception as e: + print(f"Error during video generation: {e}") + return + + print(f"Videos saved to: {OUTPUT_PATH}") + + +if __name__ == "__main__": + main() diff --git a/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh b/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh index 3dd2f77ec3..440edf8baf 100644 --- a/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh +++ b/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v.sh @@ -1,27 +1,47 @@ #!/bin/bash +#SBATCH --job-name=wan_t2v_1.3B_finetune +#SBATCH --partition=all +#SBATCH --nodes=1 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --output=logs/wan_t2v_1.3B_finetune.out +#SBATCH --error=logs/wan_t2v_1.3B_finetune.err + +source .venv/bin/activate + export WANDB_BASE_URL="https://api.wandb.ai" export WANDB_MODE=online export TOKENIZERS_PARALLELISM=false # export FASTVIDEO_ATTENTION_BACKEND=TORCH_SDPA +# export TRITON_PRINT_AUTOTUNING=1 # to print the best config +export WANDB_API_KEY=YOUR_WANDB_API_KEY MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers" -DATA_DIR="data/crush-smol_processed_t2v/combined_parquet_dataset/" -VALIDATION_DATASET_FILE="$(dirname "$0")/validation.json" -NUM_GPUS=4 +DATA_DIR=data/Wan-Syn_77x448x832_600k +VALIDATION_DATASET_FILE="examples/training/finetune/wan_t2v_1.3B/crush_smol/validation.json" +NUM_GPUS=1 # export CUDA_VISIBLE_DEVICES=4,5 +set -euo pipefail + +# ---- torchrun rendezvous (multi-node) ---- +# Launch ONE torchrun per node (via srun) and let torchrun spawn 4 workers per node. +MASTER_ADDR="$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)" +MASTER_PORT="${MASTER_PORT:-29500}" +export MASTER_ADDR MASTER_PORT + # Training arguments training_args=( - --tracker_project_name "wan_t2v_finetune" - --output_dir "checkpoints/wan_t2v_finetune" - --max_train_steps 5000 + --tracker_project_name "wan_t2v_finetune_qat" + --output_dir "checkpoints/wan_t2v_finetune_1.3B_77" + --max_train_steps 4000 --train_batch_size 1 --train_sp_batch_size 1 - --gradient_accumulation_steps 8 + --gradient_accumulation_steps 1 --num_latent_t 20 - --num_height 480 + --num_height 448 --num_width 832 --num_frames 77 --enable_gradient_checkpointing_type "full" @@ -30,7 +50,7 @@ training_args=( # Parallel arguments parallel_args=( --num_gpus $NUM_GPUS - --sp_size $NUM_GPUS + --sp_size 1 --tp_size 1 --hsdp_replicate_dim 1 --hsdp_shard_dim $NUM_GPUS @@ -45,7 +65,7 @@ model_args=( # Dataset arguments dataset_args=( --data_path $DATA_DIR - --dataloader_num_workers 1 + --dataloader_num_workers 4 ) # Validation arguments @@ -54,16 +74,16 @@ validation_args=( --validation_dataset_file $VALIDATION_DATASET_FILE --validation_steps 200 --validation_sampling_steps "50" - --validation_guidance_scale "3.0" + --validation_guidance_scale "5.0" ) # Optimizer arguments optimizer_args=( - --learning_rate 5e-5 + --learning_rate 1e-6 --mixed_precision "bf16" --weight_only_checkpointing_steps 1000 --training_state_checkpointing_steps 1000 - --weight_decay 1e-4 + --weight_decay 0.01 --max_grad_norm 1.0 ) @@ -72,23 +92,24 @@ miscellaneous_args=( --inference_mode False --checkpoints_total_limit 3 --training_cfg_rate 0.1 - --multi_phased_distill_schedule "4000-1" - --not_apply_cfg_solver --dit_precision "fp32" - --num_euler_timesteps 50 --ema_start_step 0 - --enable_gradient_checkpointing_type "full" - # --resume_from_checkpoint "checkpoints/wan_t2v_finetune/checkpoint-2500" + --flow_shift 5 + --seed 1000 ) -torchrun \ - --nnodes 1 \ - --nproc_per_node $NUM_GPUS \ - fastvideo/training/wan_training_pipeline.py \ - "${parallel_args[@]}" \ - "${model_args[@]}" \ - "${dataset_args[@]}" \ - "${training_args[@]}" \ - "${optimizer_args[@]}" \ - "${validation_args[@]}" \ - "${miscellaneous_args[@]}" +srun --nodes="$SLURM_NNODES" --ntasks="$SLURM_NNODES" --ntasks-per-node=1 \ + torchrun \ + --nnodes "$SLURM_NNODES" \ + --nproc_per_node 4 \ + --rdzv_backend c10d \ + --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}" \ + --rdzv_id "$SLURM_JOB_ID" \ + fastvideo/training/wan_training_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" diff --git a/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v_qat_attn.sh b/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v_qat_attn.sh new file mode 100644 index 0000000000..33f124c1ba --- /dev/null +++ b/examples/training/finetune/wan_t2v_1.3B/crush_smol/finetune_t2v_qat_attn.sh @@ -0,0 +1,124 @@ +#!/bin/bash +#SBATCH --job-name=wan_t2v_1.3B_finetune_qat_16 +#SBATCH --partition=all +#SBATCH --nodes=4 +#SBATCH --gres=gpu:4 +#SBATCH --ntasks-per-node=1 +#SBATCH --output=logs/wan_t2v_1.3B_finetune_qat_16.out +#SBATCH --error=logs/wan_t2v_1.3B_finetune_qat_16.err + +source .venv/bin/activate + + +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online +export TOKENIZERS_PARALLELISM=false +export FASTVIDEO_ATTENTION_BACKEND=ATTN_QAT_TRAIN + +# export TRITON_PRINT_AUTOTUNING=1 # to print the best config +export WANDB_API_KEY=YOUR_WANDB_API_KEY +# Use node-local Triton cache to avoid stale file handle errors on shared filesystems +export TRITON_CACHE_DIR="/tmp/triton_cache_${SLURM_JOB_ID}_${SLURM_NODEID}" +MODEL_PATH="Wan-AI/Wan2.1-T2V-1.3B-Diffusers" +DATA_DIR=YOUR_DATA_DIR +VALIDATION_DATASET_FILE="examples/training/finetune/wan_t2v_1.3B/crush_smol/validation.json" +NUM_GPUS=16 +# export CUDA_VISIBLE_DEVICES=4,5 + +set -euo pipefail + +# ---- torchrun rendezvous (multi-node) ---- +# 1. Get the hostname of the first node (Master) +nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) +nodes_array=($nodes) +head_node=${nodes_array[0]} +MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) +MASTER_PORT=29500 + +# 2. Get the node count automatically +NNODES=$SLURM_NNODES +GPUS_PER_NODE=$SLURM_GPUS_ON_NODE +NUM_GPUS=$((NNODES * GPUS_PER_NODE)) + +echo "MASTER_ADDR=$MASTER_ADDR MASTER_PORT=$MASTER_PORT NNODES=$NNODES" + +# Training arguments +training_args=( + --tracker_project_name "wan_t2v_finetune_qat" + --output_dir "checkpoints/wan_1.3B_t2v_finetune_qat" + --max_train_steps 4000 + --train_batch_size 1 + --train_sp_batch_size 1 + --gradient_accumulation_steps 1 + --num_latent_t 20 + --num_height 448 + --num_width 832 + --num_frames 77 + --enable_gradient_checkpointing_type "full" # if OOM enable this +) + +# Parallel arguments +parallel_args=( + --num_gpus $NUM_GPUS + --sp_size 1 + --tp_size 1 + --hsdp_replicate_dim $NUM_GPUS + --hsdp_shard_dim 1 +) + +# Model arguments +model_args=( + --model_path $MODEL_PATH + --pretrained_model_name_or_path $MODEL_PATH +) + +# Dataset arguments +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 4 +) + +# Validation arguments +validation_args=( + --log_validation + --validation_dataset_file $VALIDATION_DATASET_FILE + --validation_steps 200 + --validation_sampling_steps "50" + --validation_guidance_scale "5.0" +) + +# Optimizer arguments +optimizer_args=( + --learning_rate 1e-6 + --mixed_precision "bf16" + --weight_only_checkpointing_steps 200 + --training_state_checkpointing_steps 200 + --weight_decay 0.01 + --max_grad_norm 1.0 +) + +# Miscellaneous arguments +miscellaneous_args=( + --inference_mode False + --checkpoints_total_limit 3 + --training_cfg_rate 0.1 + --dit_precision "fp32" + --ema_start_step 0 + --flow_shift 1 + --seed 1000 +) + +srun torchrun \ + --nnodes $NNODES \ + --nproc_per_node $GPUS_PER_NODE \ + --node_rank $SLURM_PROCID \ + --rdzv_backend c10d \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + fastvideo/training/wan_training_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" diff --git a/examples/training/finetune/wan_t2v_1.3B/crush_smol/validation.json b/examples/training/finetune/wan_t2v_1.3B/crush_smol/validation.json index 674c3cb621..d324f5ee65 100644 --- a/examples/training/finetune/wan_t2v_1.3B/crush_smol/validation.json +++ b/examples/training/finetune/wan_t2v_1.3B/crush_smol/validation.json @@ -1,31 +1,131 @@ { "data": [ { - "caption": "A large metal cylinder is seen pressing down on a pile of Oreo cookies, flattening them as if they were under a hydraulic press.", - "image_path": null, - "video_path": null, - "num_inference_steps": 50, - "height": 480, + "caption": "In the video, a woman is elegantly showcasing her earrings, bringing attention to their intricate design with a gentle touch of her fingers. She is bathed in ambient purple and pink lighting, which casts a soft glow on her delicate features and enhances the vivid tones of her lipstick and eye makeup. Her hair is styled to frame her face smoothly, emphasizing the contours of her jawline and cheekbones. The background features a blurred neon light, adding an artistic and modern touch to the overall aesthetic.", + "video_path": "Fashion/mixkit-face-of-an-elegant-and-captivating-woman-41914_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, "width": 832, "num_frames": 77 }, { - "caption": "A large metal cylinder is seen compressing colorful clay into a compact shape, demonstrating the power of a hydraulic press.", - "image_path": null, - "video_path": null, - "num_inference_steps": 50, - "height": 480, + "caption": "In the video, a lone rider guides a majestic horse across an expansive, open field as the sun sets in the background. The rider, dressed in a classic blue shirt and wide-brimmed hat, sits confidently in the saddle, silhouetted against the warm glow of the evening sky. The horse moves gracefully, its mane and tail flowing with each step, creating a sense of harmony between horse and rider. Surrounding the pair, towering trees form a natural border, their leaves gently rustling in the breeze. The shadows lengthen on the ground, accentuating the serene and timeless feel of the scene. The distant hills and wooden fences frame the horizon, adding depth to the tranquil landscape. A few horses graze peacefully in the background, blending into the pastoral setting. The overall ambiance evokes a sense of calmness and quietude, capturing a perfect moment in the golden light of dusk.", + "video_path": "Man/mixkit-a-rancher-riding-a-horse-at-sunset-1143_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, "width": 832, "num_frames": 77 }, { - "caption": "A large metal cylinder is seen pressing down on a pile of colorful candies, flattening them as if they were under a hydraulic press. The candies are crushed and broken into small pieces, creating a mess on the table.", - "image_path": null, - "video_path": null, - "num_inference_steps": 50, - "height": 480, + "caption": "In a dimly lit, eerie setting, a mysterious pink bottle labeled \"Authentic 100% organic POISON\" sits prominently in the foreground, casting a menacing aura. The bottle is accentuated by green fog, which swirls lightly around it, enhancing its sinister allure. Behind it, a shadowy golden bottle adorned with a spider emblem subtly emerges, adding an extra layer of mystery to the scene. Dim candles provide faint, flickering light, which complements the dark atmosphere, making the setting ideal for an illusion of hidden dangers.", + "video_path": "smoke/mixkit-poison-in-halloween-ritual-33879_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, "width": 832, "num_frames": 77 - } - ] + }, + { + "caption": "The video opens with a tranquil scene in the heart of a dense forest, emphasizing two large, textured tree trunks in the foreground framing the view. Sunlight filters through the canopy above, casting intricate patterns of light and shadow on the trees and the ground. Between the tree trunks, a clear view of a calm, muddy river unfolds, its surface shimmering under the gentle sunlight. The riverbank is decorated with a variety of small bushes and vibrant foliage, subtly transitioning into the deep greens of tall, leafy plants. In the background, the dense forest looms, filled with dark, towering trees, their branches intertwining to form an intricate canopy. The scene is bathed in the soft glow of the sun, creating a serene and picturesque setting. Occasional sunbeams pierce through the foliage, adding a magical aura to the landscape. The vibrant reds and oranges of the smaller plants add contrast, bringing warmth to the earthy tones of the scenery. Overall, this harmonious blend of natural elements creates a peaceful and idyllic forest setting.", + "video_path": "forest/mixkit-view-of-a-river-between-two-old-trees-560_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "In the video, a martial artist dressed in a traditional white uniform with a black belt demonstrates a series of precise movements against a stark black background. The individual gracefully transitions between stances, embodying a sense of focused discipline and control. Each motion is executed with a deliberate pace, showcasing the fluidity of martial arts techniques. The soft lighting creates subtle highlights on the uniform, adding depth to the figure as it moves. The practitioner begins with an open-hand pose, feet firmly grounded, gradually shifting to a powerful forward punch. The fluidity of the sequence displays a mastery of balance and poise. Every trajectory of the limbs is precise and deliberate, capturing the elegance and strength of martial arts. The serene, isolated setting enhances the intensity and concentration of the practitioner. This visual presentation is an elegant interplay of motion and stillness, displaying the art form's discipline and grace.", + "video_path": "Man/mixkit-a-young-man-practicing-his-karate-moves-49635_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "A tranquil coastal scene unfolds with a drone's aerial view capturing a serene beach landscape. The camera glides over a quiet stretch of sandy shoreline, where gentle waves kiss the shore under a clear blue sky. Nestled amidst lush palm trees are a series of traditional thatched-roof huts, their earthy tones blending harmoniously with the natural surroundings. The sandy beach stretches endlessly, bordered by the rhythmic dance of ocean waves on one side and verdant greenery on the other. A pair of white umbrellas is set up on the sand, suggesting a place to relax and enjoy the sun. In the distance, two small human figures can be seen walking leisurely along the water's edge, leaving faint footprints behind them. The scene exudes a calm and inviting atmosphere, with the soft rustle of palm leaves and the whisper of the ocean breeze almost audible. The overall composition is a captivating blend of nature's tranquility and architectural simplicity. This picturesque setting invites viewers to imagine themselves steps away from this idyllic coastal escape.", + "video_path": "beach/mixkit-sunny-beach-in-a-dynamic-shot-from-a-drone-44383_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "A lone figure stands on a large, moss-covered rock, surrounded by the soft rush of a nearby stream. The figure is wearing white sneakers and shorts, with a plaid shirt that hangs loosely in the breeze. The lighting creates dramatic shadows, enhancing the textures of the rock and the subtle movement of the water below. In the background, a waterfall cascades into the stream, completing this tranquil and serene nature scene.", + "video_path": "forest/mixkit-woman-standing-in-front-of-waterfall-559_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "In an industrial setting, a person leans casually against a railing, exuding a sense of confidence and composure. They are wearing a striking outfit, consisting of a vibrant, patterned jacket over a simple white crop top, creating a bold contrast. The atmosphere is infused with warm, ambient lighting that casts soft shadows on the concrete walls and metallic surfaces. Intricate wiring and pipes form an intricate backdrop, enhancing the urban aesthetic. Their relaxed posture and direct, engaging gaze suggest a sense of ease in this industrial environment. This scene encapsulates a blend of modern fashion and gritty, urban architecture, creating a visually compelling narrative.", + "video_path": "Fashion/mixkit-portrait-of-a-hipster-woman-walking-down-a-stairs-1297_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "A man is energetically stretching in an open-air setting, surrounded by rows of vibrant red seats that suggest an amphitheater or outdoor venue. He wears a sleeveless black shirt layered with a hooded vest, emphasizing his athletic build as he engages in a warm-up routine. Behind him, the striking modern architecture of the building features geometric panels, with large sections of glass and overlapping metallic beams creating a dynamic backdrop. The scene captures the contrast between his focused movements and the static, bold design of the structure, while the surrounding greenery adds a touch of nature to the environment. The overall atmosphere is one of preparation and anticipation, with the man appearing determined and ready for an upcoming event or performance.", + "video_path": "Sport/mixkit-man-doing-arm-stretches-595_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "A young woman is seated on the floor in front of a plush, beige tufted couch, fully engrossed in sorting through a stack of papers. Her dark hair falls loosely past her shoulders, and she wears a green plaid shirt, contributing to the casual yet focused atmosphere. She gently places the papers onto a small round white table, occasionally lifting individual sheets to examine them more closely. Her expression shifts subtly, reflecting concentration and contemplation as she processes the information on the pages. Two small, round nested tables hold her documents, along with a small plant in a gray pot, adding a touch of greenery to the scene. The background features a dark paneled wall, creating a contrasting backdrop for the light-colored furniture. The setting is tranquil and organized, the couch and tables arranged symmetrically, conveying a sense of harmony. A calculator rests on the smaller table, hinting at a task involving calculations or budgeting.", + "video_path": "Woman/mixkit-frustrated-woman-throws-paperwork-on-the-floor-4526_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "A heavily rusted metal gate stands firmly locked, with two vertical bars joined by a thick, old chain that loops elegantly around them. The chain's texture is coarse and rugged, its surface reflecting varying shades of orange and brown, indicative of years exposed to the elements. At the heart of the chain, a black iron padlock, slightly worn yet imposing, secures the gate, its curves and edges smooth against the aged links. The gate's metalwork is outlined by a backdrop of soft, blurred greenery, suggesting a serene and isolated location beyond the barrier. Tall trees rise in the distance, their trunks and leaves creating a lush, forest-like setting that contrasts with the gate's severe rust. A pathway leads away from the gate, its surface uneven with patches of moss and weathered stone visible in the soft focus, inviting yet inaccessible. The ambiance is quiet and mysterious, with a sense of abandonment hanging subtly in the air, evoking curiosity about what lies beyond. Shadows play across the gate, cast by branches swaying gently in the breeze, adding to the dynamic interaction of light and texture. This scene, rich in detail and atmosphere, captures the viewer's imagination, evoking both the allure of the forbidden and the beauty of decay.", + "video_path": "forest/mixkit-rusty-fence-with-a-chain-of-a-property-in-nature-5294_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "In a serene and softly lit yoga studio, three individuals engage in a yoga session, each performing an upward-facing stretch. The central figure is a woman with shoulder-length brown hair, dressed in a light cropped top and green leggings, her posture reflecting grace and concentration. To her right, another participant, a woman in a purple outfit, mirrors the pose with equal poise. On her left, a person with a bun focuses intently, supported slightly by yoga blocks beneath their hands. The warm-colored wooden floor contrasts soothingly with the soft pastel mural on the back wall, featuring an abstract design and partial visage of a serene face. Natural light floods the space from a large window on the right, where lush greens peek through, adding an element of tranquility. In the corner of the room, a collection of meditation instruments, including a gong and a Buddha statue, subtly frame the peaceful setting. The mood is calm yet focused, as all three participants are deeply engaged in their practice. The scene combines elements of balance, harmony, and a shared journey towards mindfulness. This depiction captures the essence of a yoga session that blends personal growth with collective experience.", + "video_path": "People/mixkit-small-group-of-people-doing-yoga-together-43730_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "In the deep blue expanse of the ocean, two dolphins glide effortlessly, their sleek bodies reflecting the sunlight filtering through the water. The prominent shadows and caustics create a shimmering effect on their skin, capturing the beauty of their natural habitat. Each dolphin moves with a fluid grace, occasionally interacting with gentle nudges, showcasing their playful and social nature. The scene is vibrant and dynamic, with the clear blue background accentuating the dolphins' movements, making it an ideal subject for AI recreation.", + "video_path": "sea/mixkit-dolphins-underwater-4133_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "In the video, a young woman stands against a vibrant graffiti-covered wall, deeply engrossed in her smartphone. Her expression reflects a mix of focus and subtle satisfaction as she interacts with the screen. She wears a black floral-patterned top, which contrasts with the bright, abstract shapes and bold colors of the mural behind her. As she continues to engage with her phone, a series of like count notifications appear on the screen, indicating a growing online appreciation. The wall behind her features a striking mix of geometric and organic shapes, including swirls of teal, orange, and black, with large humanoid figures in a pop-art style. Her long, light-brown hair frames her face, adding a calm, composed aura amidst the lively backdrop. The video captures a blend of contemporary digital interaction and expressive urban art, creating a dynamic yet harmonious scene.", + "video_path": "Girl/mixkit-girl-looking-at-the-likes-in-her-post-4914_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "A young mother and her baby sit comfortably on a bed, surrounded by an inviting, cozy atmosphere. The woman, wearing a sleeveless top and jeans, is gently engaging with the baby, who is dressed in an adorable animal-print onesie. The child is seated on the bed with colorful toys scattered around, including a plush toy and a board book. The warm glow from a hanging lamp casts a soft light on them, enhancing the serene environment. Pillows are propped up against the headboard, providing a cushioned backdrop as the mother leans slightly over to interact with the baby. A small bottle is visible beside her, suggesting a nurturing setting. Her hand gestures animatedly as she holds up a soft, white cushion with red and blue accents, likely stimulating the baby\u2019s curiosity. Their shared moment is filled with affection and joy, a perfect snapshot of familial bonding.", + "video_path": "Baby/mixkit-loving-mother-and-her-baby-playing-with-soft-toys-49966_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + }, + { + "caption": "A young girl with long brown hair sits at a round wooden table, engrossed in working on her laptop. The laptop screen is a vivid green, suggesting a green screen effect is in use. To her left, a doll dressed in a yellow and white outfit is casually laid on top of some books, adding a playful and innocent touch to the scene. The setting is cozy, with sheer curtains in the background allowing soft natural light to spill into the room. The girl's posture and focused attention on the laptop suggest she is either playing a game or learning something new. This serene and domestic atmosphere is complemented by the slight blur of a dark couch in the foreground, framing the focused activity of the child.", + "video_path": "Girl/mixkit-little-girl-doing-homework-on-a-laptop-4757_clip_1.mp4", + "num_inference_steps": 50, + "height": 448, + "width": 832, + "num_frames": 77 + } ] } diff --git a/examples/training/finetune/wan_t2v_14B/finetune_t2v_qat_attn.sh b/examples/training/finetune/wan_t2v_14B/finetune_t2v_qat_attn.sh new file mode 100644 index 0000000000..a4db0ca624 --- /dev/null +++ b/examples/training/finetune/wan_t2v_14B/finetune_t2v_qat_attn.sh @@ -0,0 +1,123 @@ +#!/bin/bash +#SBATCH --job-name=wan_t2v_1.3B_finetune_qat_16 +#SBATCH --partition=main +#SBATCH --nodes=4 +#SBATCH --ntasks-per-node=1 +#SBATCH --gres=gpu:8 +#SBATCH --cpus-per-task=128 +#SBATCH --mem=1440G +#SBATCH --output=logs/wan_t2v_1.3B_finetune_qat_16.out +#SBATCH --error=logs/wan_t2v_1.3B_finetune_qat_16.err +#SBATCH --exclusive + +source ~/conda/miniconda/bin/activate +conda activate matthew-fv + +# Basic Info +export WANDB_MODE="online" +export NCCL_P2P_DISABLE=1 +export TORCH_NCCL_ENABLE_MONITORING=0 +# different cache dir for different processes +export TRITON_CACHE_DIR=/tmp/triton_cache_${SLURM_PROCID} +export MASTER_PORT=29500 +export NODE_RANK=$SLURM_PROCID +nodes=( $(scontrol show hostnames $SLURM_JOB_NODELIST) ) +export MASTER_ADDR=${nodes[0]} +export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID +export TOKENIZERS_PARALLELISM=false +export WANDB_BASE_URL="https://api.wandb.ai" +export WANDB_MODE=online +export FASTVIDEO_ATTENTION_BACKEND=ATTN_QAT_TRAIN + +echo "MASTER_ADDR: $MASTER_ADDR" +echo "NODE_RANK: $NODE_RANK" + +# export TRITON_PRINT_AUTOTUNING=1 # to print the best config +export WANDB_API_KEY=YOUR_WANDB_API_KEY +MODEL_PATH="Wan-AI/Wan2.1-T2V-14B-Diffusers" +DATA_DIR=YOUR_DATA_DIR +VALIDATION_DATASET_FILE="examples/training/finetune/wan_t2v_1.3B/crush_smol/validation.json" +NUM_GPUS_PER_NODE=8 +TOTAL_GPUS=$((NUM_GPUS_PER_NODE * SLURM_JOB_NUM_NODES)) +# export CUDA_VISIBLE_DEVICES=4,5 + +# Training arguments +training_args=( + --tracker_project_name "wan_t2v_finetune_qat" + --output_dir "checkpoints/wan_14B_t2v_finetune_qat" + --max_train_steps 4000 + --train_batch_size 1 + --train_sp_batch_size 1 + --gradient_accumulation_steps 1 + --num_latent_t 20 + --num_height 768 + --num_width 1280 + --num_frames 77 + --enable_gradient_checkpointing_type "full" # if OOM enable this +) + +# Parallel arguments +parallel_args=( + --num_gpus $TOTAL_GPUS + --sp_size 4 + --tp_size 1 + --hsdp_replicate_dim 4 + --hsdp_shard_dim 8 +) + +# Model arguments +model_args=( + --model_path $MODEL_PATH + --pretrained_model_name_or_path $MODEL_PATH +) + +# Dataset arguments +dataset_args=( + --data_path "$DATA_DIR" + --dataloader_num_workers 4 +) + +# Validation arguments +validation_args=( + # --log_validation + --validation_dataset_file $VALIDATION_DATASET_FILE + --validation_steps 200 + --validation_sampling_steps "50" + --validation_guidance_scale "5.0" +) + +# Optimizer arguments +optimizer_args=( + --learning_rate 1e-6 + --mixed_precision "bf16" + --weight_only_checkpointing_steps 200 + --training_state_checkpointing_steps 200 + --weight_decay 0.01 + --max_grad_norm 1.0 +) + +# Miscellaneous arguments +miscellaneous_args=( + --inference_mode False + --checkpoints_total_limit 3 + --training_cfg_rate 0.1 + --dit_precision "fp32" + --ema_start_step 0 + --flow_shift 5 + --seed 1000 +) + +srun torchrun \ +--nnodes $SLURM_JOB_NUM_NODES \ +--nproc_per_node $NUM_GPUS_PER_NODE \ +--node_rank $SLURM_PROCID \ +--rdzv_backend=c10d \ +--rdzv_endpoint="$MASTER_ADDR:$MASTER_PORT" \ + fastvideo/training/wan_training_pipeline.py \ + "${parallel_args[@]}" \ + "${model_args[@]}" \ + "${dataset_args[@]}" \ + "${training_args[@]}" \ + "${optimizer_args[@]}" \ + "${validation_args[@]}" \ + "${miscellaneous_args[@]}" diff --git a/examples/training/finetune/wan_t2v_14B/validation.json b/examples/training/finetune/wan_t2v_14B/validation.json new file mode 100644 index 0000000000..1757fb7be5 --- /dev/null +++ b/examples/training/finetune/wan_t2v_14B/validation.json @@ -0,0 +1,131 @@ +{ + "data": [ + { + "caption": "In the video, a woman is elegantly showcasing her earrings, bringing attention to their intricate design with a gentle touch of her fingers. She is bathed in ambient purple and pink lighting, which casts a soft glow on her delicate features and enhances the vivid tones of her lipstick and eye makeup. Her hair is styled to frame her face smoothly, emphasizing the contours of her jawline and cheekbones. The background features a blurred neon light, adding an artistic and modern touch to the overall aesthetic.", + "video_path": "Fashion/mixkit-face-of-an-elegant-and-captivating-woman-41914_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "In the video, a lone rider guides a majestic horse across an expansive, open field as the sun sets in the background. The rider, dressed in a classic blue shirt and wide-brimmed hat, sits confidently in the saddle, silhouetted against the warm glow of the evening sky. The horse moves gracefully, its mane and tail flowing with each step, creating a sense of harmony between horse and rider. Surrounding the pair, towering trees form a natural border, their leaves gently rustling in the breeze. The shadows lengthen on the ground, accentuating the serene and timeless feel of the scene. The distant hills and wooden fences frame the horizon, adding depth to the tranquil landscape. A few horses graze peacefully in the background, blending into the pastoral setting. The overall ambiance evokes a sense of calmness and quietude, capturing a perfect moment in the golden light of dusk.", + "video_path": "Man/mixkit-a-rancher-riding-a-horse-at-sunset-1143_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "In a dimly lit, eerie setting, a mysterious pink bottle labeled \"Authentic 100% organic POISON\" sits prominently in the foreground, casting a menacing aura. The bottle is accentuated by green fog, which swirls lightly around it, enhancing its sinister allure. Behind it, a shadowy golden bottle adorned with a spider emblem subtly emerges, adding an extra layer of mystery to the scene. Dim candles provide faint, flickering light, which complements the dark atmosphere, making the setting ideal for an illusion of hidden dangers.", + "video_path": "smoke/mixkit-poison-in-halloween-ritual-33879_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "The video opens with a tranquil scene in the heart of a dense forest, emphasizing two large, textured tree trunks in the foreground framing the view. Sunlight filters through the canopy above, casting intricate patterns of light and shadow on the trees and the ground. Between the tree trunks, a clear view of a calm, muddy river unfolds, its surface shimmering under the gentle sunlight. The riverbank is decorated with a variety of small bushes and vibrant foliage, subtly transitioning into the deep greens of tall, leafy plants. In the background, the dense forest looms, filled with dark, towering trees, their branches intertwining to form an intricate canopy. The scene is bathed in the soft glow of the sun, creating a serene and picturesque setting. Occasional sunbeams pierce through the foliage, adding a magical aura to the landscape. The vibrant reds and oranges of the smaller plants add contrast, bringing warmth to the earthy tones of the scenery. Overall, this harmonious blend of natural elements creates a peaceful and idyllic forest setting.", + "video_path": "forest/mixkit-view-of-a-river-between-two-old-trees-560_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "In the video, a martial artist dressed in a traditional white uniform with a black belt demonstrates a series of precise movements against a stark black background. The individual gracefully transitions between stances, embodying a sense of focused discipline and control. Each motion is executed with a deliberate pace, showcasing the fluidity of martial arts techniques. The soft lighting creates subtle highlights on the uniform, adding depth to the figure as it moves. The practitioner begins with an open-hand pose, feet firmly grounded, gradually shifting to a powerful forward punch. The fluidity of the sequence displays a mastery of balance and poise. Every trajectory of the limbs is precise and deliberate, capturing the elegance and strength of martial arts. The serene, isolated setting enhances the intensity and concentration of the practitioner. This visual presentation is an elegant interplay of motion and stillness, displaying the art form's discipline and grace.", + "video_path": "Man/mixkit-a-young-man-practicing-his-karate-moves-49635_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "A tranquil coastal scene unfolds with a drone's aerial view capturing a serene beach landscape. The camera glides over a quiet stretch of sandy shoreline, where gentle waves kiss the shore under a clear blue sky. Nestled amidst lush palm trees are a series of traditional thatched-roof huts, their earthy tones blending harmoniously with the natural surroundings. The sandy beach stretches endlessly, bordered by the rhythmic dance of ocean waves on one side and verdant greenery on the other. A pair of white umbrellas is set up on the sand, suggesting a place to relax and enjoy the sun. In the distance, two small human figures can be seen walking leisurely along the water's edge, leaving faint footprints behind them. The scene exudes a calm and inviting atmosphere, with the soft rustle of palm leaves and the whisper of the ocean breeze almost audible. The overall composition is a captivating blend of nature's tranquility and architectural simplicity. This picturesque setting invites viewers to imagine themselves steps away from this idyllic coastal escape.", + "video_path": "beach/mixkit-sunny-beach-in-a-dynamic-shot-from-a-drone-44383_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "A lone figure stands on a large, moss-covered rock, surrounded by the soft rush of a nearby stream. The figure is wearing white sneakers and shorts, with a plaid shirt that hangs loosely in the breeze. The lighting creates dramatic shadows, enhancing the textures of the rock and the subtle movement of the water below. In the background, a waterfall cascades into the stream, completing this tranquil and serene nature scene.", + "video_path": "forest/mixkit-woman-standing-in-front-of-waterfall-559_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "In an industrial setting, a person leans casually against a railing, exuding a sense of confidence and composure. They are wearing a striking outfit, consisting of a vibrant, patterned jacket over a simple white crop top, creating a bold contrast. The atmosphere is infused with warm, ambient lighting that casts soft shadows on the concrete walls and metallic surfaces. Intricate wiring and pipes form an intricate backdrop, enhancing the urban aesthetic. Their relaxed posture and direct, engaging gaze suggest a sense of ease in this industrial environment. This scene encapsulates a blend of modern fashion and gritty, urban architecture, creating a visually compelling narrative.", + "video_path": "Fashion/mixkit-portrait-of-a-hipster-woman-walking-down-a-stairs-1297_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "A man is energetically stretching in an open-air setting, surrounded by rows of vibrant red seats that suggest an amphitheater or outdoor venue. He wears a sleeveless black shirt layered with a hooded vest, emphasizing his athletic build as he engages in a warm-up routine. Behind him, the striking modern architecture of the building features geometric panels, with large sections of glass and overlapping metallic beams creating a dynamic backdrop. The scene captures the contrast between his focused movements and the static, bold design of the structure, while the surrounding greenery adds a touch of nature to the environment. The overall atmosphere is one of preparation and anticipation, with the man appearing determined and ready for an upcoming event or performance.", + "video_path": "Sport/mixkit-man-doing-arm-stretches-595_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "A young woman is seated on the floor in front of a plush, beige tufted couch, fully engrossed in sorting through a stack of papers. Her dark hair falls loosely past her shoulders, and she wears a green plaid shirt, contributing to the casual yet focused atmosphere. She gently places the papers onto a small round white table, occasionally lifting individual sheets to examine them more closely. Her expression shifts subtly, reflecting concentration and contemplation as she processes the information on the pages. Two small, round nested tables hold her documents, along with a small plant in a gray pot, adding a touch of greenery to the scene. The background features a dark paneled wall, creating a contrasting backdrop for the light-colored furniture. The setting is tranquil and organized, the couch and tables arranged symmetrically, conveying a sense of harmony. A calculator rests on the smaller table, hinting at a task involving calculations or budgeting.", + "video_path": "Woman/mixkit-frustrated-woman-throws-paperwork-on-the-floor-4526_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "A heavily rusted metal gate stands firmly locked, with two vertical bars joined by a thick, old chain that loops elegantly around them. The chain's texture is coarse and rugged, its surface reflecting varying shades of orange and brown, indicative of years exposed to the elements. At the heart of the chain, a black iron padlock, slightly worn yet imposing, secures the gate, its curves and edges smooth against the aged links. The gate's metalwork is outlined by a backdrop of soft, blurred greenery, suggesting a serene and isolated location beyond the barrier. Tall trees rise in the distance, their trunks and leaves creating a lush, forest-like setting that contrasts with the gate's severe rust. A pathway leads away from the gate, its surface uneven with patches of moss and weathered stone visible in the soft focus, inviting yet inaccessible. The ambiance is quiet and mysterious, with a sense of abandonment hanging subtly in the air, evoking curiosity about what lies beyond. Shadows play across the gate, cast by branches swaying gently in the breeze, adding to the dynamic interaction of light and texture. This scene, rich in detail and atmosphere, captures the viewer's imagination, evoking both the allure of the forbidden and the beauty of decay.", + "video_path": "forest/mixkit-rusty-fence-with-a-chain-of-a-property-in-nature-5294_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "In a serene and softly lit yoga studio, three individuals engage in a yoga session, each performing an upward-facing stretch. The central figure is a woman with shoulder-length brown hair, dressed in a light cropped top and green leggings, her posture reflecting grace and concentration. To her right, another participant, a woman in a purple outfit, mirrors the pose with equal poise. On her left, a person with a bun focuses intently, supported slightly by yoga blocks beneath their hands. The warm-colored wooden floor contrasts soothingly with the soft pastel mural on the back wall, featuring an abstract design and partial visage of a serene face. Natural light floods the space from a large window on the right, where lush greens peek through, adding an element of tranquility. In the corner of the room, a collection of meditation instruments, including a gong and a Buddha statue, subtly frame the peaceful setting. The mood is calm yet focused, as all three participants are deeply engaged in their practice. The scene combines elements of balance, harmony, and a shared journey towards mindfulness. This depiction captures the essence of a yoga session that blends personal growth with collective experience.", + "video_path": "People/mixkit-small-group-of-people-doing-yoga-together-43730_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "In the deep blue expanse of the ocean, two dolphins glide effortlessly, their sleek bodies reflecting the sunlight filtering through the water. The prominent shadows and caustics create a shimmering effect on their skin, capturing the beauty of their natural habitat. Each dolphin moves with a fluid grace, occasionally interacting with gentle nudges, showcasing their playful and social nature. The scene is vibrant and dynamic, with the clear blue background accentuating the dolphins' movements, making it an ideal subject for AI recreation.", + "video_path": "sea/mixkit-dolphins-underwater-4133_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "In the video, a young woman stands against a vibrant graffiti-covered wall, deeply engrossed in her smartphone. Her expression reflects a mix of focus and subtle satisfaction as she interacts with the screen. She wears a black floral-patterned top, which contrasts with the bright, abstract shapes and bold colors of the mural behind her. As she continues to engage with her phone, a series of like count notifications appear on the screen, indicating a growing online appreciation. The wall behind her features a striking mix of geometric and organic shapes, including swirls of teal, orange, and black, with large humanoid figures in a pop-art style. Her long, light-brown hair frames her face, adding a calm, composed aura amidst the lively backdrop. The video captures a blend of contemporary digital interaction and expressive urban art, creating a dynamic yet harmonious scene.", + "video_path": "Girl/mixkit-girl-looking-at-the-likes-in-her-post-4914_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "A young mother and her baby sit comfortably on a bed, surrounded by an inviting, cozy atmosphere. The woman, wearing a sleeveless top and jeans, is gently engaging with the baby, who is dressed in an adorable animal-print onesie. The child is seated on the bed with colorful toys scattered around, including a plush toy and a board book. The warm glow from a hanging lamp casts a soft light on them, enhancing the serene environment. Pillows are propped up against the headboard, providing a cushioned backdrop as the mother leans slightly over to interact with the baby. A small bottle is visible beside her, suggesting a nurturing setting. Her hand gestures animatedly as she holds up a soft, white cushion with red and blue accents, likely stimulating the baby\u2019s curiosity. Their shared moment is filled with affection and joy, a perfect snapshot of familial bonding.", + "video_path": "Baby/mixkit-loving-mother-and-her-baby-playing-with-soft-toys-49966_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + }, + { + "caption": "A young girl with long brown hair sits at a round wooden table, engrossed in working on her laptop. The laptop screen is a vivid green, suggesting a green screen effect is in use. To her left, a doll dressed in a yellow and white outfit is casually laid on top of some books, adding a playful and innocent touch to the scene. The setting is cozy, with sheer curtains in the background allowing soft natural light to spill into the room. The girl's posture and focused attention on the laptop suggest she is either playing a game or learning something new. This serene and domestic atmosphere is complemented by the slight blur of a dark couch in the foreground, framing the focused activity of the child.", + "video_path": "Girl/mixkit-little-girl-doing-homework-on-a-laptop-4757_clip_1.mp4", + "num_inference_steps": 50, + "height": 768, + "width": 1280, + "num_frames": 77 + } ] +} diff --git a/fastvideo-kernel/CMakeLists.txt b/fastvideo-kernel/CMakeLists.txt index ca8b67c5aa..5fdce43204 100644 --- a/fastvideo-kernel/CMakeLists.txt +++ b/fastvideo-kernel/CMakeLists.txt @@ -12,6 +12,18 @@ else() enable_language(CUDA) # Ensure CUDA toolkit targets (CUDA::cudart, CUDA::cuda_driver, etc.) are available. find_package(CUDAToolkit REQUIRED) + if(NOT DEFINED CUDA_TOOLKIT_ROOT_DIR) + if(DEFINED CUDAToolkit_ROOT) + set(CUDA_TOOLKIT_ROOT_DIR "${CUDAToolkit_ROOT}" CACHE PATH + "CUDA toolkit root directory" FORCE) + elseif(DEFINED ENV{CUDAToolkit_ROOT}) + set(CUDA_TOOLKIT_ROOT_DIR "$ENV{CUDAToolkit_ROOT}" CACHE PATH + "CUDA toolkit root directory" FORCE) + elseif(DEFINED ENV{CUDA_HOME}) + set(CUDA_TOOLKIT_ROOT_DIR "$ENV{CUDA_HOME}" CACHE PATH + "CUDA toolkit root directory" FORCE) + endif() + endif() endif() # Import common utils if needed, but we keep it simple for now @@ -19,13 +31,46 @@ endif() # Find Python and Torch find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) -# Robustly find Torch include paths using Python +# Locate the installed torch package without importing it. This keeps CMake +# configure working even on nodes where CUDA runtime libraries are not yet on +# the dynamic loader path. execute_process( - COMMAND "${Python_EXECUTABLE}" -c "import torch; from torch.utils.cpp_extension import include_paths; print(';'.join(include_paths()))" - OUTPUT_VARIABLE TORCH_INCLUDE_PATHS + COMMAND "${Python_EXECUTABLE}" -c "import sysconfig; print(sysconfig.get_path('platlib'))" + OUTPUT_VARIABLE PYTHON_PLATLIB + OUTPUT_STRIP_TRAILING_WHITESPACE +) +execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import sysconfig; print(sysconfig.get_path('purelib'))" + OUTPUT_VARIABLE PYTHON_PURELIB OUTPUT_STRIP_TRAILING_WHITESPACE ) -list(APPEND TORCH_INCLUDE_DIRS ${TORCH_INCLUDE_PATHS}) + +set(TORCH_PYTHON_PACKAGE_DIR "") +foreach(_candidate + "${PYTHON_PLATLIB}/torch" + "${PYTHON_PURELIB}/torch" +) + if(EXISTS "${_candidate}") + set(TORCH_PYTHON_PACKAGE_DIR "${_candidate}") + break() + endif() +endforeach() + +if(NOT TORCH_PYTHON_PACKAGE_DIR) + message(FATAL_ERROR "Could not locate the installed torch Python package.") +endif() + +list(APPEND TORCH_INCLUDE_DIRS + "${TORCH_PYTHON_PACKAGE_DIR}/include" + "${TORCH_PYTHON_PACKAGE_DIR}/include/torch/csrc/api/include" +) + +if(NOT Torch_DIR) + set(_TORCH_CONFIG_DIR "${TORCH_PYTHON_PACKAGE_DIR}/share/cmake/Torch") + if(EXISTS "${_TORCH_CONFIG_DIR}/TorchConfig.cmake") + set(Torch_DIR "${_TORCH_CONFIG_DIR}" CACHE PATH "Path to Torch CMake config" FORCE) + endif() +endif() # Find Torch package (still useful for libraries) find_package(Torch REQUIRED) @@ -50,6 +95,21 @@ include_directories( set(FASTVIDEO_KERNEL_BUILD_TK "AUTO" CACHE STRING "Build ThunderKittens kernels: AUTO/ON/OFF") set_property(CACHE FASTVIDEO_KERNEL_BUILD_TK PROPERTY STRINGS AUTO ON OFF) +set(_FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER_DEFAULT "AUTO") +if(DEFINED FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3 AND NOT DEFINED CACHE{FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER}) + set(_FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER_DEFAULT "${FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3}") +endif() + +set(FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER "${_FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER_DEFAULT}" CACHE STRING + "Build attn_qat_infer Blackwell inference kernels: AUTO/ON/OFF") +set_property(CACHE FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER PROPERTY STRINGS AUTO ON OFF) + +if(DEFINED FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3) + message(DEPRECATION + "FASTVIDEO_KERNEL_BUILD_MODIFIED_SAGE3 is deprecated. " + "Use FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER instead.") +endif() + # Prefer environment variable (used by CI) if CMake var is not explicitly set. if(NOT DEFINED TORCH_CUDA_ARCH_LIST AND DEFINED ENV{TORCH_CUDA_ARCH_LIST}) set(TORCH_CUDA_ARCH_LIST "$ENV{TORCH_CUDA_ARCH_LIST}") @@ -57,6 +117,7 @@ endif() message(STATUS "TORCH_CUDA_ARCH_LIST (cmake/env): ${TORCH_CUDA_ARCH_LIST}") message(STATUS "FASTVIDEO_KERNEL_BUILD_TK: ${FASTVIDEO_KERNEL_BUILD_TK}") +message(STATUS "FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER: ${FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER}") set(ENABLE_TK_KERNELS OFF) if(FASTVIDEO_KERNEL_BUILD_TK STREQUAL "ON") @@ -91,6 +152,54 @@ else() message(STATUS "ThunderKittens kernels: DISABLED (will use Triton fallbacks at runtime)") endif() +set(ENABLE_ATTN_QAT_INFER OFF) +if(GPU_BACKEND STREQUAL "ROCM") + message(STATUS "attn_qat_infer kernels: DISABLED (ROCm build)") +else() + set(_WANTS_ATTN_QAT_INFER OFF) + if(FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER STREQUAL "ON") + set(_WANTS_ATTN_QAT_INFER ON) + elseif(FASTVIDEO_KERNEL_BUILD_ATTN_QAT_INFER STREQUAL "AUTO") + if(TORCH_CUDA_ARCH_LIST) + string(REGEX MATCH + "(^|[; ,])((12\\.0a)|(120a)|(sm_120a))([; ,]|$)" + _HAS_120A "${TORCH_CUDA_ARCH_LIST}") + if(_HAS_120A) + set(_WANTS_ATTN_QAT_INFER ON) + endif() + else() + execute_process( + COMMAND "${Python_EXECUTABLE}" -c + "import torch; print('1' if (torch.cuda.is_available() and torch.version.cuda and torch.cuda.get_device_capability()[0] >= 12) else '0')" + OUTPUT_VARIABLE _LOCAL_HAS_BLACKWELL + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + ) + if(_LOCAL_HAS_BLACKWELL STREQUAL "1") + set(_WANTS_ATTN_QAT_INFER ON) + endif() + endif() + endif() + + if(_WANTS_ATTN_QAT_INFER) + if(CUDAToolkit_VERSION VERSION_LESS 12.8) + message(WARNING + "attn_qat_infer kernels require CUDA Toolkit 12.8+. " + "Skipping because CUDAToolkit_VERSION=${CUDAToolkit_VERSION}.") + else() + set(ENABLE_ATTN_QAT_INFER ON) + endif() + endif() + + if(ENABLE_ATTN_QAT_INFER) + message(STATUS "attn_qat_infer kernels: ENABLED") + else() + message(STATUS + "attn_qat_infer kernels: DISABLED " + "(requires CUDA 12.8+ and Blackwell sm_120a)") + endif() +endif() + # Always try to build the extension if CUDA is available, but conditionally add sources/flags set(BUILD_CXX_KERNELS ON) @@ -161,12 +270,15 @@ if(BUILD_CXX_KERNELS) # Also link against libtorch_python to satisfy Python-binding symbols # (e.g., torch::PyWarningHandler) required by torch/extension.h. - execute_process( - COMMAND "${Python_EXECUTABLE}" -c "import torch; from pathlib import Path; p=Path(torch.__file__).parent/'lib'; m=sorted(p.glob('libtorch_python*')); print(str(m[0]) if m else '')" - OUTPUT_VARIABLE TORCH_PYTHON_LIBRARY_PATH - OUTPUT_STRIP_TRAILING_WHITESPACE - ERROR_QUIET + file(GLOB TORCH_PYTHON_LIBRARY_CANDIDATES + "${TORCH_PYTHON_PACKAGE_DIR}/lib/libtorch_python*" ) + list(LENGTH TORCH_PYTHON_LIBRARY_CANDIDATES _TORCH_PYTHON_LIBRARY_COUNT) + if(_TORCH_PYTHON_LIBRARY_COUNT GREATER 0) + list(GET TORCH_PYTHON_LIBRARY_CANDIDATES 0 TORCH_PYTHON_LIBRARY_PATH) + else() + set(TORCH_PYTHON_LIBRARY_PATH "") + endif() if(TORCH_PYTHON_LIBRARY_PATH) message(STATUS "TORCH_PYTHON_LIBRARY_PATH: ${TORCH_PYTHON_LIBRARY_PATH}") target_link_libraries(fastvideo_kernel_ops PRIVATE "${TORCH_PYTHON_LIBRARY_PATH}") @@ -183,3 +295,73 @@ if(BUILD_CXX_KERNELS) install(TARGETS fastvideo_kernel_ops LIBRARY DESTINATION fastvideo_kernel/_C) endif() +if(ENABLE_ATTN_QAT_INFER) + set(ATTN_QAT_INFER_DIR ${CMAKE_SOURCE_DIR}/attn_qat_infer) + set(ATTN_QAT_INFER_INCLUDE_DIRS + ${ATTN_QAT_INFER_DIR} + ${CMAKE_SOURCE_DIR}/include/cutlass/include + ${CMAKE_SOURCE_DIR}/include/cutlass/tools/util/include + ${TORCH_INCLUDE_DIRS} + ) + set(ATTN_QAT_INFER_CUDA_FLAGS + "-O3" + "-std=c++17" + "-U__CUDA_NO_HALF_OPERATORS__" + "-U__CUDA_NO_HALF_CONVERSIONS__" + "-U__CUDA_NO_BFLOAT16_OPERATORS__" + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__" + "-U__CUDA_NO_BFLOAT162_OPERATORS__" + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__" + "--expt-relaxed-constexpr" + "--expt-extended-lambda" + "--use_fast_math" + "--ptxas-options=--verbose,--warn-on-local-memory-usage" + "-lineinfo" + "-DCUTLASS_DEBUG_TRACE_LEVEL=0" + "-DNDEBUG" + "-DQBLKSIZE=128" + "-DKBLKSIZE=128" + "-DCTA256" + "-DDQINRMEM" + ) + + Python_add_library(fp4attn_cuda MODULE WITH_SOABI + attn_qat_infer/blackwell/api.cu + ) + target_include_directories(fp4attn_cuda PRIVATE ${ATTN_QAT_INFER_INCLUDE_DIRS}) + target_compile_definitions(fp4attn_cuda PRIVATE TORCH_EXTENSION_NAME=fp4attn_cuda) + target_compile_options(fp4attn_cuda PRIVATE + $<$:-O3 -std=c++17> + $<$:${ATTN_QAT_INFER_CUDA_FLAGS}> + ) + set_target_properties(fp4attn_cuda PROPERTIES + CUDA_ARCHITECTURES "120a" + CXX_STANDARD 17 + CUDA_STANDARD 17 + ) + target_link_libraries(fp4attn_cuda PRIVATE ${TORCH_LIBRARIES} CUDA::cudart CUDA::cuda_driver) + + Python_add_library(fp4quant_cuda MODULE WITH_SOABI + attn_qat_infer/quantization/fp4_quantization_4d.cu + ) + target_include_directories(fp4quant_cuda PRIVATE ${ATTN_QAT_INFER_INCLUDE_DIRS}) + target_compile_definitions(fp4quant_cuda PRIVATE TORCH_EXTENSION_NAME=fp4quant_cuda) + target_compile_options(fp4quant_cuda PRIVATE + $<$:-O3 -std=c++17> + $<$:${ATTN_QAT_INFER_CUDA_FLAGS}> + ) + set_target_properties(fp4quant_cuda PROPERTIES + CUDA_ARCHITECTURES "120a" + CXX_STANDARD 17 + CUDA_STANDARD 17 + ) + target_link_libraries(fp4quant_cuda PRIVATE ${TORCH_LIBRARIES} CUDA::cudart CUDA::cuda_driver) + + if(TORCH_PYTHON_LIBRARY_PATH) + target_link_libraries(fp4attn_cuda PRIVATE "${TORCH_PYTHON_LIBRARY_PATH}") + target_link_libraries(fp4quant_cuda PRIVATE "${TORCH_PYTHON_LIBRARY_PATH}") + endif() + + install(TARGETS fp4attn_cuda LIBRARY DESTINATION .) + install(TARGETS fp4quant_cuda LIBRARY DESTINATION .) +endif() diff --git a/fastvideo-kernel/MANIFEST.in b/fastvideo-kernel/MANIFEST.in index 6952427846..aa69fbbd7f 100644 --- a/fastvideo-kernel/MANIFEST.in +++ b/fastvideo-kernel/MANIFEST.in @@ -2,5 +2,6 @@ include LICENSE include README.md include pyproject.toml recursive-include python/fastvideo_kernel *.py +recursive-include attn_qat_infer *.py *.cu *.cuh *.cpp *.h recursive-include csrc *.cu *.cuh *.cpp *.h recursive-include include/tk *.cu *.cuh *.cpp *.h *.src diff --git a/fastvideo-kernel/README.md b/fastvideo-kernel/README.md index 82d3fe81d9..f47f4c5161 100644 --- a/fastvideo-kernel/README.md +++ b/fastvideo-kernel/README.md @@ -20,6 +20,11 @@ cd fastvideo-kernel ./build.sh ``` +On supported Blackwell environments, the same install also packages +`attn_qat_infer` and builds its `fp4attn_cuda` / `fp4quant_cuda` +extensions directly from `fastvideo-kernel/attn_qat_infer/`. This path +requires CUDA Toolkit 12.8+ and targets `sm_120a`. + ### Rocm Build If you are in a rocm environment without the compilation toolchaine of CUDA. @@ -58,6 +63,18 @@ cd fastvideo-kernel python benchmarks/bench_vsa.py --batch_size 1 --num_heads 16 --head_dim 128 --q_seq_lens 49152 --topk 64 ``` +### Attn QAT Attention Benchmarks + +The Attn QAT microbenchmarks now live alongside the kernel package: + +```bash +cd fastvideo-kernel +python benchmarks/benchmark_flashattn2.py --batch-size 1 --num-heads 16 --seq-len 4096 --head-dim 128 +python benchmarks/benchmark_sageattn3.py --batch-size 1 --num-heads 16 --seq-len 4096 --head-dim 128 +python benchmarks/benchmark_blockscaled_fp4_attn.py --batch-size 1 --num-heads 16 --seq-len 4096 --head-dim 128 +python benchmarks/benchmark_combined.py --output benchmark_attention.png +``` + ### TurboDiffusion Kernels This package also includes kernels from [TurboDiffusion](https://github.com/thu-ml/TurboDiffusion), including INT8 GEMM, Quantization, RMSNorm and LayerNorm. diff --git a/fastvideo-kernel/attn_qat_infer/__init__.py b/fastvideo-kernel/attn_qat_infer/__init__.py new file mode 100644 index 0000000000..9b2ee4e37d --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/__init__.py @@ -0,0 +1,16 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from .api import sageattn_blackwell \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/api.py b/fastvideo-kernel/attn_qat_infer/api.py new file mode 100644 index 0000000000..05d1d80567 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/api.py @@ -0,0 +1,185 @@ +# Modified from the original SageATtention3 code +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import torch +import triton +import triton.language as tl +import torch.nn.functional as F +from typing import Tuple +from torch.nn.functional import scaled_dot_product_attention as sdpa +import fp4attn_cuda +import fp4quant_cuda + +# Centralized block size configuration for sageattn_blackwell kernels +# These should match the values in fastvideo/attention/backends/sageattn/blackwell/block_config.h +BLOCK_M = 128 # Block size for M dimension (query sequence length) +BLOCK_N = 128 # Block size for N dimension (key/value sequence length) + + +@triton.jit +def group_mean_kernel( + q_ptr, + q_out_ptr, + qm_out_ptr, + B, H, L, D: tl.constexpr, + stride_qb, stride_qh, stride_ql, stride_qd, + stride_qmb, stride_qmh, stride_qml, stride_qmd, + GROUP_SIZE: tl.constexpr +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_group = tl.program_id(2) + + group_start = pid_group * GROUP_SIZE + offsets = group_start + tl.arange(0, GROUP_SIZE) + + q_offsets = pid_b * stride_qb + pid_h * stride_qh + offsets[:, None] * stride_ql + tl.arange(0, D)[None, :] * stride_qd + q_group = tl.load(q_ptr + q_offsets) + + qm_group = tl.sum(q_group, axis=0) / GROUP_SIZE + + q_group = q_group - qm_group + tl.store(q_out_ptr + q_offsets, q_group) + + qm_offset = pid_b * stride_qmb + pid_h * stride_qmh + pid_group * stride_qml + tl.arange(0, D) * stride_qmd + tl.store(qm_out_ptr + qm_offset, qm_group) + + +def triton_group_mean(q: torch.Tensor): + B, H, L, D = q.shape + GROUP_SIZE = BLOCK_M + num_groups = L // GROUP_SIZE + + q_out = torch.empty_like(q) # [B, H, L, D] + qm = torch.empty(B, H, num_groups, D, device=q.device, dtype=q.dtype) + + grid = (B, H, num_groups) + + group_mean_kernel[grid]( + q, q_out, qm, + B, H, L, D, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + qm.stride(0), qm.stride(1), qm.stride(2), qm.stride(3), + GROUP_SIZE=GROUP_SIZE + ) + return q_out, qm + + +def preprocess_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, per_block_mean: bool = True, enable_smoothing_q: bool = False, enable_smoothing_k: bool = False): + + def pad_to_block_size(x): + L = x.size(2) + pad_len = (BLOCK_M - L % BLOCK_M) % BLOCK_M + if pad_len == 0: + return x.contiguous() + return F.pad(x, (0, 0, 0, pad_len), value=0).contiguous() + + if enable_smoothing_k: + k -= k.mean(dim=-2, keepdim=True) + q, k, v = map(lambda x: pad_to_block_size(x), [q, k, v]) + if per_block_mean and enable_smoothing_q: + q, qm = triton_group_mean(q) + elif enable_smoothing_q: + qm = q.mean(dim=-2, keepdim=True) + q = q - qm + if enable_smoothing_q: + delta_s = torch.matmul(qm, k.transpose(-2, -1)).to(torch.float32).contiguous() + else: # used to disable q smoothing + B, H, L, D = q.shape + delta_s = torch.zeros((B, H, L // BLOCK_M, k.shape[2]), device=q.device, dtype=torch.float32) + + return q, k, v, delta_s + +def scale_and_quant_fp4(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty((B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn) + fp4quant_cuda.scaled_fp4_quant(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + +def scale_and_quant_fp4_permute(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, N, D // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty((B, H, N, D // 16), device=x.device, dtype=torch.float8_e4m3fn) + fp4quant_cuda.scaled_fp4_quant_permute(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + +def scale_and_quant_fp4_transpose(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.ndim == 4 + B, H, N, D = x.shape + packed_fp4 = torch.empty((B, H, D, N // 2), device=x.device, dtype=torch.uint8) + fp8_scale = torch.empty((B, H, D, N // 16), device=x.device, dtype=torch.float8_e4m3fn) + fp4quant_cuda.scaled_fp4_quant_trans(x, packed_fp4, fp8_scale, 1) + return packed_fp4, fp8_scale + +def blockscaled_fp4_attn(qlist: Tuple, + klist: Tuple, + vlist: Tuple, + delta_s: torch.Tensor, + KL: int, + is_causal: bool = False, + per_block_mean: bool = True, + is_bf16: bool = True, + single_level_p_quant: bool = False + ): + softmax_scale = (qlist[0].shape[-1] * 2) ** (-0.5) + return fp4attn_cuda.fwd(qlist[0], klist[0], vlist[0], qlist[1], klist[1], vlist[1], delta_s, KL, None, softmax_scale, is_causal, per_block_mean, is_bf16, single_level_p_quant) + + +def sageattn_blackwell(q, k, v, attn_mask = None, is_causal = False, per_block_mean = True, single_level_p_quant = True, **kwargs): + """ + SageAttention3 Blackwell kernel for FP4 attention. + + Args: + q: Query tensor [B, H, L, D] + k: Key tensor [B, H, L, D] + v: Value tensor [B, H, L, D] + attn_mask: Attention mask (not used) + is_causal: Whether to use causal masking + per_block_mean: Whether to use per-block mean for Q smoothing + single_level_p_quant: If True, use single-level quantization: s_P2, P̂_2 = φ(P̃) directly + (standard per-block FP4 quantization like V, no s_P1). + If False (default), use two-level quantization: + s_P1 = rowmax(P̃)/(448×6), then s_P2, P̂_2 = φ(P̃/s_P1). + **kwargs: Additional arguments (ignored) + + Returns: + Output tensor [B, H, L, D] + """ + if q.size(-1) >= 256: + print(f"Unsupported Headdim {q.size(-1)}") + return sdpa(q, k, v, is_causal = is_causal) + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist_from_cuda = scale_and_quant_fp4(q) + klist_from_cuda = scale_and_quant_fp4_permute(k) + vlist_from_cuda = scale_and_quant_fp4_transpose(v) + o_fp4 = blockscaled_fp4_attn( + qlist_from_cuda, + klist_from_cuda, + vlist_from_cuda, + delta_s, + KL, + is_causal, + per_block_mean, + is_bf16, + single_level_p_quant + )[0][:, :, :QL, :].contiguous() + return o_fp4 \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/__init__.py b/fastvideo-kernel/attn_qat_infer/blackwell/__init__.py new file mode 100644 index 0000000000..2e33087c53 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/__init__.py @@ -0,0 +1 @@ +__version__ = "3.0.0.b1" diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/api.cu b/fastvideo-kernel/attn_qat_infer/blackwell/api.cu new file mode 100644 index 0000000000..b454f6b13c --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/api.cu @@ -0,0 +1,346 @@ +// Modified from the original SageAttention3 code +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. +#include +#include +#include +#include + +#include + +#include "params.h" +#include "launch.h" +#include "static_switch.h" +#include "block_config.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t unpadded_seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor delta_s, + at::Tensor out, + const at::Tensor sfq, + const at::Tensor sfk, + const at::Tensor sfv, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_k, + void *p_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + bool per_block_mean, + bool is_bf16, + bool single_level_p_quant=false, + bool seqlenq_ngroups_swapped=false) { + + // Reset the parameters + params = {}; + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + params.delta_s_ptr = delta_s.data_ptr(); + params.sfq_ptr = sfq.data_ptr(); + params.sfk_ptr = sfk.data_ptr(); + params.sfv_ptr = sfv.data_ptr(); + + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-2) * 2; + params.k_row_stride = k.stride(-2) * 2; + params.v_row_stride = v.stride(-2) * 2;; + params.q_head_stride = q.stride(-3) * 2; + params.k_head_stride = k.stride(-3) * 2; + params.v_head_stride = v.stride(-3) * 2; // for packed q k v + + params.ds_row_stride = delta_s.stride(-2); + params.ds_head_stride = delta_s.stride(-3); + + params.sfq_row_stride = sfq.stride(-2); + params.sfk_row_stride = sfk.stride(-2); + params.sfv_row_stride = sfv.stride(-2); + params.sfq_head_stride = sfq.stride(-3); + params.sfk_head_stride = sfk.stride(-3); + params.sfv_head_stride = sfv.stride(-3); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-2); + params.o_head_stride = out.stride(-3); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0) * 2; + params.k_batch_stride = k.stride(0) * 2; + params.v_batch_stride = v.stride(0) * 2; + params.ds_batch_stride = delta_s.stride(0); + params.sfq_batch_stride = sfq.stride(0); + params.sfk_batch_stride = sfk.stride(0); + params.sfv_batch_stride = sfv.stride(0); + params.o_batch_stride = out.stride(0); + if (seqlenq_ngroups_swapped) { + params.q_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; + } + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.unpadded_seqlen_k = unpadded_seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + params.head_divmod = cutlass::FastDivmod(int(h)); + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2); + __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half); + params.scale_softmax_log2_half2 = reinterpret_cast(scale_softmax_log2_half2); + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + TORCH_CHECK(p_dropout < 1.f); + #ifdef FLASHATTENTION_DISABLE_DROPOUT + TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + #endif + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0; + params.per_block_mean = per_block_mean; + if (per_block_mean) { + params.seqlen_s = seqlen_q; + } else { + params.seqlen_s = flash::BLOCK_M; // size of BLOCK_M + } + if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } + if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + #ifdef FLASHATTENTION_DISABLE_LOCAL + TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), + "This flash attention build does not support local attention."); + #endif + + params.is_seqlens_k_cumulative = true; + params.is_bf16 = is_bf16; + params.single_level_p_quant = single_level_p_quant; + #ifdef FLASHATTENTION_DISABLE_UNEVEN_K + TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); + #endif +} + +template +void run_mha_fwd_dispatch_dtype(Flash_fwd_params ¶ms, cudaStream_t stream) { + using OType = std::conditional_t; + if (params.d == 64) { + run_mha_fwd_, 64, OType>(params, stream); + } else if (params.d == 128) { + run_mha_fwd_, 128, OType>(params, stream); + } +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel = false) { + BOOL_SWITCH(params.is_bf16, IsBF16, ([&] { + run_mha_fwd_dispatch_dtype(params, stream); + })); +} + +std::vector +mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x (head_size // 2) + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x (head_size // 2) + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x (head_size // 2) + const at::Tensor &sfq, + const at::Tensor &sfk, + const at::Tensor &sfv, + const at::Tensor &delta_s, + int unpadded_k, + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const float softmax_scale, + bool is_causal, + bool per_block_mean, + bool is_bf16, + bool single_level_p_quant=false // If true, use only per-row scale s_P2 (no per-block s_P1) + ) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm120 = dprops->major == 12 && dprops->minor == 0; + TORCH_CHECK(is_sm120, "only supports Blackwell GPUs or newer."); + + auto q_dtype = q.dtype(); + auto sfq_dtype = sfq.dtype(); + TORCH_CHECK(q_dtype == torch::kUInt8, "q dtype must be uint8"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(sfq_dtype == torch::kFloat8_e4m3fn, "q dtype must be uint8"); + TORCH_CHECK(sfk.dtype() == sfq_dtype, "query and key must have the same dtype"); + TORCH_CHECK(sfv.dtype() == sfq_dtype, "query and value must have the same dtype"); + CHECK_DEVICE(sfq); CHECK_DEVICE(sfk); CHECK_DEVICE(sfv); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(delta_s.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + TORCH_CHECK(q.is_contiguous(), "Input tensor must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "Input tensor must be contiguous"); + TORCH_CHECK(v.is_contiguous(), "Input tensor must be contiguous"); + + const auto sizes = q.sizes(); + auto opts = q.options(); + const int batch_size = sizes[0]; + int seqlen_q = sizes[2]; + int num_heads = sizes[1]; + const int head_size_og = sizes[3]; + const int unpacked_head_size = head_size_og * 2; + const int seqlen_k = k.size(2); + const int num_heads_k = k.size(1); + + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(unpacked_head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(num_heads == num_heads_k, "We do not support MQA/GQA yet"); + + TORCH_CHECK(unpacked_head_size == 64 || unpacked_head_size == 128 || unpacked_head_size == 256, "Only support head size 64, 128, and 256 for now"); + + CHECK_SHAPE(q, batch_size, num_heads, seqlen_q, head_size_og); + CHECK_SHAPE(k, batch_size, num_heads_k, seqlen_k, head_size_og); + CHECK_SHAPE(v, batch_size, num_heads_k, unpacked_head_size, seqlen_k/2); + // CHECK_SHAPE(delta_s, batch_size, num_heads, seqlen_q / 128, seqlen_k); + // CHECK_SHAPE(sfq, batch_size, seqlen_q, num_heads, unpacked_head_size); + // CHECK_SHAPE(sfk, batch_size, seqlen_k, num_heads_k, unpacked_head_size); + // CHECK_SHAPE(sfv, batch_size, unpacked_head_size, num_heads_k, seqlen_k); + TORCH_CHECK(unpacked_head_size % 8 == 0, "head_size must be a multiple of 8"); + + auto dtype = is_bf16 ? at::ScalarType::BFloat16 : at::ScalarType::Half; + at::Tensor out = torch::empty({batch_size, num_heads, seqlen_q, unpacked_head_size}, opts.dtype(dtype)); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + // const int head_size = round_multiple(head_size_og, 8); + // const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, flash::BLOCK_M); + const int seqlen_k_rounded = round_multiple(seqlen_k, flash::BLOCK_N); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, unpadded_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + unpacked_head_size, unpacked_head_size, + q, k, v, delta_s, out, + sfq, sfk, sfv, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + /*window_size_left=*/-1, + /*window_size_right=*/is_causal ? 0 : -1, + per_block_mean, + is_bf16, + single_level_p_quant + ); + // TODO: 132 sm count? + auto tile_count_semaphore = is_causal ? torch::full({1}, 132, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + + if (seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. + out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + // at::Tensor out_padded = out; + // if (head_size_og % 8 != 0) { + // out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + // if (out_.has_value()) { out_.value().copy_(out); } + // } + + // return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; + // cudaDeviceSynchronize(); + // auto err = cudaGetLastError(); + // printf("%s\n", cudaGetErrorString(err)); + return {out, softmax_lse}; +} + + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashAttention"; + m.def("fwd", &mha_fwd, "Forward pass"); +} \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/block_config.h b/fastvideo-kernel/attn_qat_infer/blackwell/block_config.h new file mode 100644 index 0000000000..e4ecfda77f --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/block_config.h @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// Centralized block size configuration for sageattn_blackwell kernels +// Block sizes for M and N dimensions +namespace flash { + // Block size for M dimension (query sequence length) + static constexpr int BLOCK_M = 128; + + // Block size for N dimension (key/value sequence length) + static constexpr int BLOCK_N = 128; +} + diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/block_info.h b/fastvideo-kernel/attn_qat_infer/blackwell/block_info.h new file mode 100644 index 0000000000..bd0d454886 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/block_info.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * This code is based on code from FlashAttention3, https://github.com/Dao-AILab/flash-attention + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + + template + __device__ BlockInfo(const Params ¶ms, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + { + } + + template + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/blockscaled_layout.h b/fastvideo-kernel/attn_qat_infer/blackwell/blockscaled_layout.h new file mode 100644 index 0000000000..6c36f3ccec --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/blockscaled_layout.h @@ -0,0 +1,149 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Blocked Scale configs specific for SM100 BlockScaled MMA +*/ + +#pragma once + +#include "cutlass/layout/matrix.h" + +#include "cute/int_tuple.hpp" +#include "cute/atom/mma_traits_sm100.hpp" + +namespace flash { + +///////////////////////////////////////////////////////////////////////////////////////////////// +using namespace cute; + +template +struct BlockScaledBasicChunk { + + using Blk_MN = _64; + using Blk_SF = _4; + + using SfAtom = Layout< Shape< Shape<_16,_4>, Shape, _4>>, + Stride, Stride< _0, _1>>>; +}; + +template +struct BlockScaledConfig { + // We are creating the SFA and SFB tensors' layouts in the collective since they always have the same layout. + // k-major order + static constexpr int SFVecSize = SFVecSize_; + static constexpr int MMA_NSF = 4; // SFVecSize, MMA_NSF + using BlkScaledChunk = BlockScaledBasicChunk; + using Blk_MN = _64; + using Blk_SF = _4; + using mnBasicBlockShape = Shape<_16,_4>; + using mnBasicBlockStride = Stride<_16,_4>; + using kBasicBlockShape = Shape, Int>; // SFVecSize, MMA_NSF + using kBasicBlockStride = Stride<_0, _1>; + using SfAtom = Layout< Shape< mnBasicBlockShape, kBasicBlockShape>, + Stride>; + + using LayoutSF = decltype(blocked_product(SfAtom{}, + make_layout( + make_shape(int32_t(0), int32_t(0), int32_t(0), int32_t(0)), + make_stride(int32_t(0), _1{}, int32_t(0), int32_t(0))))); + // A single indivisible block will hold 4 scale factors of 64 rows/columns (A/B matrix). + // 4 is chosen to make consecutive 32bits of data to have scale factors for only a single row (col). 32bits corresponds to the TMEM word size + using Blk_Elems = decltype(Blk_MN{} * Blk_SF{}); + using sSF_strideMN = decltype(prepend(Blk_Elems{}, mnBasicBlockStride{})); + + + // The following function is provided for user fill dynamic problem size to the layout_SFA. + template < class ProblemShape> + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFQKV(ProblemShape problem_shape) { + auto [Seqlen, Dim, HeadNum, Batch] = problem_shape; + return tile_to_shape(SfAtom{}, make_shape(Seqlen, Dim, HeadNum, Batch), Step<_2,_1,_3,_4>{}); + } + + // The following function is provided for user fill dynamic problem size to the layout_SFB. + template + CUTE_HOST_DEVICE + static constexpr auto + tile_atom_to_shape_SFVt(ProblemShape problem_shape) { + auto [Dim, Seqlen, HeadNum, Batch] = problem_shape; + return tile_to_shape(SfAtom{}, make_shape(Dim, Seqlen, HeadNum, Batch), Step<_2,_1,_3,_4>{}); + } + + template + CUTE_HOST_DEVICE + static constexpr auto + deduce_smem_layoutSFQ(TiledMma tiled_mma, TileShape_MNK tileshape_mnk) { + + using sSFQ_shapeK = decltype(prepend(make_shape(Blk_SF{}/Int{}, size<2>(TileShape_MNK{}) / Int{} / Blk_SF{}), kBasicBlockShape{})); + using sSFQ_shapeM = decltype(prepend(size<0>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); + using sSFQ_strideM = sSF_strideMN; + using sSFQ_strideK = decltype(prepend(make_stride(Int{}, size<0>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); + using sSFQ_shape = decltype(make_shape(sSFQ_shapeM{}, sSFQ_shapeK{})); + using sSFQ_stride = decltype(make_stride(sSFQ_strideM{}, sSFQ_strideK{})); + using SmemLayoutAtomSFQ = decltype(make_layout(sSFQ_shape{}, sSFQ_stride{})); + return SmemLayoutAtomSFQ{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + deduce_smem_layoutSFKV(TiledMma tiled_mma, TileShape_MNK tileshape_mnk) { + + using sSFK_shapeK = decltype(prepend(make_shape(Blk_SF{}/Int{}, size<2>(TileShape_MNK{}) / Int{} / Blk_SF{}), kBasicBlockShape{})); + using sSFK_shapeN = decltype(prepend(size<1>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); + using sSFK_strideN = sSF_strideMN; + using sSFK_strideK = decltype(prepend(make_stride(Int{}, size<1>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); + using sSFK_shape = decltype(make_shape(sSFK_shapeN{}, sSFK_shapeK{})); + using sSFK_stride = decltype(make_stride(sSFK_strideN{}, sSFK_strideK{})); + using SmemLayoutAtomSFK = decltype(make_layout(sSFK_shape{}, sSFK_stride{})); + return SmemLayoutAtomSFK{}; + } + + template + CUTE_HOST_DEVICE + static constexpr auto + deduce_smem_layoutSFVt(TiledMma tiled_mma, TileShape_MNK tileshape_mnk) { + + using sSFVt_shapeK = decltype(prepend(make_shape(Blk_SF{}/Int{}, size<2>(TileShape_MNK{}) / Int{} / Blk_SF{}), kBasicBlockShape{})); + using sSFVt_shapeN = decltype(prepend(size<1>(TileShape_MNK{}) / Blk_MN{}, mnBasicBlockShape{})); + using sSFVt_strideN = sSF_strideMN; + using sSFVt_strideK = decltype(prepend(make_stride(Int{}, size<1>(TileShape_MNK{}) / Blk_MN{} * Blk_Elems{}), kBasicBlockStride{})); + using sSFVt_shape = decltype(make_shape(sSFVt_shapeN{}, sSFVt_shapeK{})); + using sSFVt_stride = decltype(make_stride(sSFVt_strideN{}, sSFVt_strideK{})); + using SmemLayoutAtomSFVt = decltype(make_layout(sSFVt_shape{}, sSFVt_stride{})); + return SmemLayoutAtomSFVt{}; + } +}; + + +} // namespace flash diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/cute_extension.h b/fastvideo-kernel/attn_qat_infer/blackwell/cute_extension.h new file mode 100644 index 0000000000..bcd7cf3b5a --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/cute_extension.h @@ -0,0 +1,327 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "cute/arch/mma_sm120.hpp" +#include "cute/atom/mma_traits_sm120.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/float8.h" +#include "cutlass/float_subbyte.h" + +namespace cute::SM120::BLOCKSCALED { + +using cutlass::float_e2m1_t; +using cutlass::float_ue4m3_t; + +// MMA.SF 16x32x64 TN E2M1 x E2M1 with SF E4M3 +struct SM120_16x32x64_TN_VS_NVFP4 { + using DRegisters = float[16]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[8]; + using CRegisters = float[16]; + + static constexpr int SFBits = 32; + using RegTypeSF = cute::uint_bit_t; + + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0 , float & d1 , float & d2 , float & d3 , + float & d4 , float & d5 , float & d6 , float & d7 , + float & d8 , float & d9 , float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& a0 , uint32_t const& a1 , uint32_t const& a2 , uint32_t const& a3 , + uint32_t const& b0 , uint32_t const& b1 , uint32_t const& b2 , uint32_t const& b3 , + uint32_t const& b4 , uint32_t const& b5 , uint32_t const& b6 , uint32_t const& b7 , + float const & c0 , float const & c1 , float const & c2 , float const & c3 , + float const & c4 , float const & c5 , float const & c6 , float const & c7 , + float const & c8 , float const & c9 , float const & c10 , float const & c11, + float const & c12, float const & c13, float const & c14, float const & c15, + RegTypeSF const& sfa0, + RegTypeSF const& sfb0) + { + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t bidB = 0; + static constexpr uint16_t tidB0 = 0; + static constexpr uint16_t tidB1 = 1; + static constexpr uint16_t tidB2 = 2; + static constexpr uint16_t tidB3 = 3; + +#if defined(CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d8), "=f"(d9) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c8), "f"(c9), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB0)); + + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d2), "=f"(d3), "=f"(d10), "=f"(d11) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b2), "r"(b3), + "f"(c2), "f"(c3), "f"(c10), "f"(c11), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB1)); + + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d4), "=f"(d5), "=f"(d12), "=f"(d13) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b4), "r"(b5), + "f"(c4), "f"(c5), "f"(c12), "f"(c13), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB2)); + + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d6), "=f"(d7), "=f"(d14), "=f"(d15) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b6), "r"(b7), + "f"(c6), "f"(c7), "f"(c14), "f"(c15), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x64_TN_VS without CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED"); +#endif + + } +}; + +} // namespace cute::SM120::BLOCKSCALED + +namespace cute { + +// MMA NVFP4 16x32x64 TN +template <> +struct MMA_Traits +{ + // The MMA accepts 4-bit inputs regardless of the types for A and B + using ValTypeA = uint4_t; + using ValTypeB = uint4_t; + + using ValTypeD = float; + using ValTypeC = float; + + using ValTypeSF = cutlass::float_ue4m3_t; + constexpr static int SFVecSize = 16; + + using Shape_MNK = Shape<_16,_32,_64>; + using ThrID = Layout<_32>; + + // (T32,V32) -> (M16,K64) + using ALayout = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_16,_8,_512>>>; + // (T32,V64) -> (N32,K64) + using BLayout = Layout,Shape <_8, _2, _4>>, + Stride,Stride<_32,_1024, _8>>>; + // (T32,V64) -> (M16,K64) + using SFALayout = Layout,_64>, + Stride,_16>>; + // (T32,V64) -> (N32,K64) + using SFBLayout = Layout,_64>, + Stride, _32>>; + // (T32,V16) -> (M16,N32) + using CLayout = Layout,Shape < Shape<_2, _4>,_2>>, + Stride,Stride,_8>>>; +}; + + +template +CUTE_HOST_DEVICE constexpr +auto +thrfrg_SFA(SFATensor&& sfatensor, TiledMMA& mma) +{ + CUTE_STATIC_ASSERT_V(rank(sfatensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFA_TV = typename Atom::Traits::SFALayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfatensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; +} + +template +CUTE_HOST_DEVICE constexpr +auto +thrfrg_SFB(SFBTensor&& sfbtensor, TiledMMA& mma) +{ + CUTE_STATIC_ASSERT_V(rank(sfbtensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFB_TV = typename Atom::Traits::SFBLayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<1>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfbtensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + return thr_tensor; +} + +template +CUTE_HOST_DEVICE constexpr +auto +partition_SFA(SFATensor&& sfatensor, ThrMma& thread_mma) { + auto thr_tensor = make_tensor(static_cast(sfatensor).data(), thrfrg_SFA(sfatensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vmk = make_coord(get<0>(thr_vmnk), make_coord(get<1>(thr_vmnk), get<3>(thr_vmnk))); + return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +partition_fragment_SFA(SFATensor&& sfatensor, ThrMma& thread_mma) { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + return make_fragment_like(partition_SFA(sfatensor, thread_mma)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +partition_SFB(SFBTensor&& sfbtensor, ThrMma& thread_mma) { + auto thr_tensor = make_tensor(static_cast(sfbtensor).data(), thrfrg_SFB(sfbtensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vnk = make_coord(get<0>(thr_vmnk), make_coord(get<2>(thr_vmnk), get<3>(thr_vmnk))); + return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +partition_fragment_SFB(SFBTensor&& sfbtensor, ThrMma& thread_mma) { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + return make_fragment_like(partition_SFB(sfbtensor, thread_mma)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +get_layoutSFA_TV(TiledMma& mma) +{ + // (M,K) -> (M,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_A = make_layout(make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFA(ref_A, mma).compose(atile, _).compose(thridx_2_thrid, _); +} + +template +CUTE_HOST_DEVICE constexpr +auto +get_layoutSFB_TV(TiledMma& mma) +{ + // (N,K) -> (N,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_B = make_layout(make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFB(ref_B, mma).compose(btile, _).compose(thridx_2_thrid, _); +} + +} // namespace cute \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/epilogue_tma_ws.h b/fastvideo-kernel/attn_qat_infer/blackwell/epilogue_tma_ws.h new file mode 100644 index 0000000000..85e82a0537 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/epilogue_tma_ws.h @@ -0,0 +1,222 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "named_barrier.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveEpilogueFwd{ + + using Element = typename Ktraits::ElementOut; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockN = Ktraits::kBlockN; + static constexpr int kHeadDim = Ktraits::kHeadDim; + using TileShape_MNK = Shape, Int, Int>; + static constexpr int kNWarps = Ktraits::kNWarps; + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + static constexpr int NumMmaThreads = kNThreads - cutlass::NumThreadsPerWarpGroup; + + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; + + // These are for storing the output tensor without TMA (e.g., for setting output to zero) + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = kHeadDim / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using SmemLayoutO = typename Ktraits::SmemLayoutO; + + using SmemCopyAtomO = Copy_Atom; + using SharedStorage = cute::array_aligned>; + + using ShapeO = cute::Shape; // (seqlen_q, d, head, batch) + using StrideO = cute::Stride; + using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen_q, head, batch) + + using TMA_O = decltype(make_tma_copy( + GmemTiledCopyOTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideO{}, int32_t(0)), StrideO{}), + SmemLayoutO{}, + select<0, 2>(TileShape_MNK{}), + _1{})); // no mcast for O + + // Host side kernel arguments + struct Arguments { + Element* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + float* ptr_LSE; + StrideLSE const stride_LSE; + }; + + // Device side kernel params + struct Params { + Element* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + float* ptr_LSE; + StrideLSE const stride_LSE; + TMA_O tma_store_O; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); + TMA_O tma_store_O = make_tma_copy( + GmemTiledCopyOTMA{}, + mO, + SmemLayoutO{}, + select<0, 2>(TileShape_MNK{}), + _1{}); // no mcast for O + return {args.ptr_O, args.shape_O, args.stride_O, args.ptr_LSE, args.stride_LSE, tma_store_O}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& epilogue_params) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + mma_store( + SharedStorage& shared_storage, + TiledMma tiled_mma, + FrgTensorO const& tOrO, + int thread_idx + ){ + Tensor sO = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_o.begin()), SmemLayoutO{})); + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + constexpr int numel = decltype(size(tOrO))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tOrO.data())); + auto tOrO_out = make_tensor(make_rmem_ptr(&frag), tOrO.layout()); + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + } + + template + CUTLASS_DEVICE void + tma_store( + SharedStorage& shared_storage, + Params const& epilogue_params, + WorkTileInfo work_tile_info, + SchedulerParams const& scheduler_params, + int thread_idx + ) { + auto [m_block, bidh, bidb] = work_tile_info.get_block_coord(scheduler_params); + Tensor sO = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.smem_o.begin()), SmemLayoutO{})); + Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.shape_O); + Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{}); + Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) + Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) + + // auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O); + // Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE); + // Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape>{}, make_coord(m_block)); + + // Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); + // auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + // Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + // static_assert(decltype(size<0, 0>(taccOcO))::value == 2); + // static_assert(decltype(size<0, 1>(taccOcO))::value == 2); + // // // // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices. + // Tensor taccOcO_row = taccOcO(make_coord(_0{}, _), _, _0{}); + // CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + // if (get<1>(taccOcO_row(_0{})) == 0) { + // #pragma unroll + // for (int mi = 0; mi < size(lse); ++mi) { + // const int row = get<0>(taccOcO_row(mi)); + // if (row < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(row) = lse(mi); } + // } + // } + + // if (cutlass::canonical_warp_idx_sync() == kNWarps - 1) { + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, + // static_cast(FP4NamedBarriers::EpilogueBarrier)); + // int const lane_predicate = cute::elect_one_sync(); + // if (lane_predicate) { + // cute::copy(epilogue_params.tma_store_O, tOsO, tOgO); + // tma_store_arrive(); + // } + // } + cute::copy(epilogue_params.tma_store_O, tOsO, tOgO); + tma_store_arrive(); + } + + CUTLASS_DEVICE void + store_tail() { + tma_store_wait<0>(); + } + + // Write 0 to output and -inf to LSE + CUTLASS_DEVICE void + store_zero( + Params const& epilogue_params, + int thread_idx, + cute::tuple const& block_coord + ) { + auto [m_block, bidh, bidb] = block_coord; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.shape_O, epilogue_params.stride_O); + Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O); + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE); + Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape>{}, make_coord(m_block)); + + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_fragment_like(tOgO); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.shape_O); } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.shape_O) - m_block * kBlockM + ); + static_assert(kBlockM <= NumMmaThreads); + if (thread_idx < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; } + } + +}; + +} // namespace flash diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/kernel_traits.h b/fastvideo-kernel/attn_qat_infer/blackwell/kernel_traits.h new file mode 100644 index 0000000000..5035f4b58d --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/kernel_traits.h @@ -0,0 +1,202 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cute/algorithm/copy.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" + +#include "blockscaled_layout.h" +#include "cute_extension.h" +#include "named_barrier.h" +using namespace cute; + +template < + int kStages, + int EpiStages, + typename Element, + typename ElementSF, + typename OutputType, + typename SmemLayoutQ, + typename SmemLayoutK, + typename SmemLayoutV, + typename SmemLayoutDS, + typename SmemLayoutO, + typename SmemLayoutSFQ, + typename SmemLayoutSFK, + typename SmemLayoutSFV +> +struct SharedStorageQKVOwithSF : cute::aligned_struct<128, _0>{ + + alignas(1024) cute::ArrayEngine> smem_q; + alignas(1024) cute::ArrayEngine> smem_k; + cute::ArrayEngine> smem_SFQ; + cute::ArrayEngine> smem_SFK; + cute::ArrayEngine> smem_SFV; + alignas(1024) cute::ArrayEngine> smem_ds; + alignas(1024) cute::ArrayEngine> smem_v; + alignas(1024) cute::ArrayEngine> smem_o; + + struct { + alignas(16) typename cutlass::PipelineTmaAsync<1>::SharedStorage pipeline_q; + alignas(16) typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + alignas(16) typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + alignas(16) typename flash::OrderedSequenceBarrierVarGroupSize::SharedStorage barrier_o; + int tile_count_semaphore; + }; + }; + +template < + int kHeadDim_, + int kBlockM_, + int kBlockN_, + int kStages_, + int kClusterM_, + bool BlockMean_, + typename ElementPairType_ = cutlass::nv_float4_t, + typename ElementOut_ = cutlass::bfloat16_t +> +struct Flash_fwd_kernel_traits { + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static constexpr bool BlockMean = BlockMean_; + static constexpr bool SmoothQ = true; + static_assert(kHeadDim % 32 == 0); + static_assert(kBlockM == 64 || kBlockM == 128); + static constexpr int kNWarps = kBlockM == 128 ? 12 : 8; + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + static constexpr int kClusterM = kClusterM_; + static constexpr int kStages = kStages_; + static constexpr int EpiStages = 1; + static constexpr int NumSFQK = kHeadDim / 16; + static constexpr int NumSFPV = kBlockN / 16; + using ElementSF = cutlass::float_ue4m3_t; + using Element = cutlass::float_e2m1_t; + using ElementAccum = float; + using ElementOut = ElementOut_; + using index_t = int64_t; + static constexpr auto SFVectorSize = 16; + using TileShape_MNK = Shape, Int, Int>; + using ClusterShape_MNK = Shape<_1, _1, _1>; + using PermTileM = decltype(cute::min(size<0>(TileShape_MNK{}), _128{})); + using PermTileN = _32; + using PermTileK = Int; + + using ElementQMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + using ElementKMma = decltype(cutlass::gemm::collective::detail::sm1xx_kernel_input_element_to_mma_input_element()); + + using AtomLayoutMNK = std::conditional_t>, + Layout> + >; + using TiledMmaQK = decltype(cute::make_tiled_mma( + cute::SM120::BLOCKSCALED::SM120_16x32x64_TN_VS_NVFP4{}, + AtomLayoutMNK{}, + Tile{} + )); + + using TiledMmaPV = decltype(cute::make_tiled_mma( + cute::SM120::BLOCKSCALED::SM120_16x32x64_TN_VS_NVFP4{}, + AtomLayoutMNK{}, + Tile{} + )); + + static constexpr int MMA_NSF = size<2>(typename TiledMmaQK::AtomShape_MNK{}) / SFVectorSize; + + using GmemTiledCopy = SM90_TMA_LOAD; + using GmemTiledCopySF = SM90_TMA_LOAD; + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); + using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); + using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::sm120_rr_smem_selector(TileShape_MNK{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutK = + decltype(tile_to_shape(SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtomV{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutVt = + decltype(tile_to_shape(SmemLayoutAtomVt{}, + make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}))); + using SmemLayoutAtomDS = Layout, Int>, Stride<_0, _1>>; + using SmemLayoutDS = + decltype(tile_to_shape(SmemLayoutAtomDS{}, + make_shape(shape<0>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}))); + + using SmemCopyAtomQ = Copy_Atom; + using SmemCopyAtomKV = Copy_Atom; + using SmemCopyAtomSF = Copy_Atom, ElementSF>; + using SmemCopyAtomDS = Copy_Atom, float>; + + using BlkScaledConfig = flash::BlockScaledConfig; + using LayoutSF = typename BlkScaledConfig::LayoutSF; + using SfAtom = typename BlkScaledConfig::SfAtom; + using SmemLayoutAtomSFQ = decltype(BlkScaledConfig::deduce_smem_layoutSFQ(TiledMmaQK{}, TileShape_MNK{})); + using SmemLayoutAtomSFK = decltype(BlkScaledConfig::deduce_smem_layoutSFKV(TiledMmaQK{}, TileShape_MNK{})); + using SmemLayoutAtomSFV = decltype(BlkScaledConfig::deduce_smem_layoutSFKV(TiledMmaPV{}, TileShape_MNK{})); + using SmemLayoutAtomSFVt = decltype(BlkScaledConfig::deduce_smem_layoutSFVt(TiledMmaPV{}, Shape, Int, Int>{})); + using LayoutSFP = decltype( + make_layout( + make_shape(make_shape(_16{}, _4{}), _1{}, Int{}), + make_stride(make_stride(_0{}, _1{}), _0{}, _4{}) + ) + ); + using LayoutP = decltype( + make_layout( + make_shape(make_shape(_8{}, _2{}, _2{}), _1{}, Int{}), + make_stride(make_stride(_1{}, _8{}, _16{}), _0{}, _32{}) + ) + ); + using SmemLayoutSFQ = decltype(make_layout( + shape(SmemLayoutAtomSFQ{}), + stride(SmemLayoutAtomSFQ{}) + )); + using SmemLayoutSFK = decltype(make_layout( + append(shape(SmemLayoutAtomSFK{}), Int{}), + append(stride(SmemLayoutAtomSFK{}), size(filter_zeros(SmemLayoutAtomSFK{}))) + )); + using SmemLayoutSFV = decltype(make_layout( + append(shape(SmemLayoutAtomSFV{}), Int{}), + append(stride(SmemLayoutAtomSFV{}), size(filter_zeros(SmemLayoutAtomSFV{}))) + )); + using SmemLayoutSFVt = decltype(make_layout( + append(shape(SmemLayoutAtomSFVt{}), Int{}), + append(stride(SmemLayoutAtomSFVt{}), size(filter_zeros(SmemLayoutAtomSFVt{}))) + )); + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}), Step<_1, _2>{})); + using SharedStorage = SharedStorageQKVOwithSF; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using PipelineState = typename cutlass::PipelineState; + using MainloopPipelineQ = cutlass::PipelineTmaAsync<1>; + using PipelineParamsQ = typename MainloopPipelineQ::Params; + using PipelineStateQ = typename cutlass::PipelineState<1>; + using EpilogueBarrier = typename flash::OrderedSequenceBarrierVarGroupSize; +}; + diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/kernel_ws.h b/fastvideo-kernel/attn_qat_infer/blackwell/kernel_ws.h new file mode 100644 index 0000000000..675608726a --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/kernel_ws.h @@ -0,0 +1,204 @@ +// Modified from the original SageAttention3 code +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "params.h" +#include "utils.h" +#include "tile_scheduler.h" +#include "mainloop_tma_ws.h" +#include "epilogue_tma_ws.h" +#include "named_barrier.h" +#include "softmax_fused.h" + +namespace flash { + +using namespace cute; + +template +__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) + compute_attn_ws(CUTE_GRID_CONSTANT Flash_fwd_params const params, + CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, + CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params + ) { + + using Element = typename Ktraits::Element; + using ElementAccum = typename Ktraits::ElementAccum; + using SoftType = ElementAccum; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMmaQK{}); + static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = Ktraits::kBlockM; + + using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveEpilogue = CollectiveEpilogueFwd; + + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + using MainloopPipelineQ = typename Ktraits::MainloopPipelineQ; + using PipelineParamsQ = typename Ktraits::PipelineParamsQ; + using PipelineStateQ = typename Ktraits::PipelineStateQ; + using EpilogueBarrier = typename Ktraits::EpilogueBarrier; + + + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole { + Mainloop = 0, + Epilogue = 1, + Warp2 = 2, + Warp3 = 3 + }; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + int warp_group_idx = cutlass::canonical_warp_group_idx(); + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + int warp_idx_in_warp_group = warp_idx % cutlass::NumWarpsPerWarpGroup; + auto warp_group_role = WarpGroupRole(warp_group_idx); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); + CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params); + } + + // Obtain warp index + + PipelineParams pipeline_params_v; + pipeline_params_v.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV; + pipeline_params_v.role = warp_group_role == WarpGroupRole::Producer + ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + pipeline_params_v.is_leader = warp_group_thread_idx == 0; + pipeline_params_v.num_consumers = NumMmaThreads; + + PipelineParams pipeline_params_k; + pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + pipeline_params_k.role = warp_group_role == WarpGroupRole::Producer + ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + pipeline_params_k.is_leader = warp_group_thread_idx == 0; + pipeline_params_k.num_consumers = NumMmaThreads; + + PipelineParamsQ pipeline_params_q; + pipeline_params_q.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ; + pipeline_params_q.role = warp_group_role == WarpGroupRole::Producer + ? MainloopPipelineQ::ThreadCategory::Producer + : MainloopPipelineQ::ThreadCategory::Consumer; + pipeline_params_q.is_leader = warp_group_thread_idx == 0; + pipeline_params_q.num_consumers = NumMmaThreads; + + // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); + MainloopPipelineQ pipeline_q(shared_storage.pipeline_q, pipeline_params_q, ClusterShape{}); + MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params_k, ClusterShape{}); + MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params_v, ClusterShape{}); + + uint32_t epilogue_barrier_group_size_list[2] = {cutlass::NumThreadsPerWarp, NumMmaThreads}; + typename EpilogueBarrier::Params params_epilogue_barrier; + params_epilogue_barrier.group_id = (warp_group_role == WarpGroupRole::Producer); + params_epilogue_barrier.group_size_list = epilogue_barrier_group_size_list; + EpilogueBarrier barrier_o(shared_storage.barrier_o, params_epilogue_barrier); + + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue; + __syncthreads(); + + if (warp_group_role == WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc<24>(); + TileScheduler scheduler; + + if (producer_warp_role == ProducerWarpRole::Mainloop) { // Load Q, K, V + PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_k = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_v = cutlass::make_producer_start_state(); + + int work_idx = 0; + for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) { + int tile_count_semaphore = 0; + collective_mainloop.load(mainloop_params, scheduler_params, + pipeline_q, pipeline_k, pipeline_v, + smem_pipe_write_q, smem_pipe_write_k, smem_pipe_write_v, + shared_storage, work_tile_info, work_idx, tile_count_semaphore); + } + collective_mainloop.load_tail(pipeline_q, pipeline_k, pipeline_v, + smem_pipe_write_q, smem_pipe_write_k, smem_pipe_write_v); + } else if (producer_warp_role == ProducerWarpRole::Epilogue) { + for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) { + barrier_o.wait(); + collective_epilogue.tma_store(shared_storage, epilogue_params, work_tile_info, scheduler_params, threadIdx.x); + collective_epilogue.store_tail(); + barrier_o.arrive(); + } + + } + } else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { + cutlass::arch::warpgroup_reg_alloc<232>(); + typename Ktraits::TiledMmaPV tiled_mma_pv; + TileScheduler scheduler{}; + PipelineState smem_pipe_read_k, smem_pipe_read_v; + PipelineStateQ smem_pipe_read_q; + + int work_idx = 0; + + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) { + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 2>(TileShape_MNK{})); + // flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax; + // Pass single_level_p_quant flag to control P quantization mode + flash::SoftmaxFused<2 * (2 * kBlockM / NumMmaThreads)> softmax_fused(params.single_level_p_quant); + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [m_block, bidh, bidb] = block_coord; + + int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block); + if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE. + collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord); + continue; + } + + collective_mainloop.mma(mainloop_params, pipeline_q, pipeline_k, pipeline_v, smem_pipe_read_q, smem_pipe_read_k, smem_pipe_read_v, + tOrO, softmax_fused, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage); + barrier_o.wait(); + collective_epilogue.mma_store(shared_storage, tiled_mma_pv, tOrO, threadIdx.x - NumCopyThreads); + barrier_o.arrive(); + ++work_idx; + } + } +} + +} // namespace flash diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/launch.h b/fastvideo-kernel/attn_qat_infer/blackwell/launch.h new file mode 100644 index 0000000000..91398c5f44 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/launch.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cluster_launch.hpp" + +#include "static_switch.h" +#include "params.h" +#include "tile_scheduler.h" +#include "kernel_ws.h" +#include "kernel_traits.h" +#include "block_config.h" + + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + using Element = typename Kernel_traits::Element; + using ElementSF = typename Kernel_traits::ElementSF; + using ElementOut = typename Kernel_traits::ElementOut; + using TileShape_MNK = typename Kernel_traits::TileShape_MNK; + using ClusterShape = typename Kernel_traits::ClusterShape_MNK; + using CollectiveMainloop = flash::CollectiveMainloopFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + // using Scheduler = flash::SingleTileScheduler; + using Scheduler = flash::StaticPersistentTileScheduler; + typename CollectiveMainloop::Params mainloop_params = + CollectiveMainloop::to_underlying_arguments({ + static_cast(params.q_ptr), + {params.seqlen_q, params.d, params.h, params.b}, // shape_Q + {params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride}, // stride_Q + static_cast(params.k_ptr), + {params.seqlen_k, params.d, params.h_k, params.b}, // shape_K + {params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride}, // stride_K + {params.unpadded_seqlen_k, params.d, params.h_k, params.b}, // shape_K + static_cast(params.v_ptr), + {params.d, params.seqlen_k, params.h_k, params.b}, // shape_Vt + {params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride}, // stride_Vt + static_cast(params.sfq_ptr), + {params.seqlen_q, params.d, params.h, params.b}, // shape_SFQ + static_cast(params.sfk_ptr), + {params.seqlen_k, params.d, params.h_k, params.b}, // shape_SFK + static_cast(params.sfv_ptr), + {params.d, params.seqlen_k, params.h_k, params.b}, // shape_SFVt + static_cast(params.delta_s_ptr), + {params.seqlen_s, params.seqlen_k, params.h_k, params.b}, + {params.ds_row_stride, _1{}, params.ds_head_stride, params.ds_batch_stride}, + params.scale_softmax_log2 + }); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments({ + static_cast(params.o_ptr), + {params.seqlen_q, params.d, params.h, params.b}, // shape_O + {params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride}, // stride_O + static_cast(params.softmax_lse_ptr), + {_1{}, params.seqlen_q, params.h * params.seqlen_q}, // stride_LSE + }); + + int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); + num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{}); + typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + // Get the ptr to kernel function. + void *kernel; + kernel = (void *)flash::compute_attn_ws; + int smem_size = sizeof(typename Kernel_traits::SharedStorage); + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + static constexpr int ctaSize = Kernel_traits::kNWarps * 32; + params.m_block_divmod = cutlass::FastDivmod(num_blocks_m); + params.total_blocks = num_blocks_m * params.h * params.b; + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, 170); + dim3 block_dims(ctaSize); + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; + cutlass::launch_kernel_on_cluster(launch_params, kernel, params, mainloop_params, epilogue_params, scheduler_params); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + + +template +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + BOOL_SWITCH(params.per_block_mean, per_block, [&] { + if constexpr (Headdim == 64 || Headdim == 128) { + run_flash_fwd< + Flash_fwd_kernel_traits, + Is_causal + >(params, stream); + } else { + static_assert(Headdim == 64 || Headdim == 128, "Unsupported Headdim"); + } + }); + }); +} diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/mainloop_tma_ws.h b/fastvideo-kernel/attn_qat_infer/blackwell/mainloop_tma_ws.h new file mode 100644 index 0000000000..d6ef18f4ad --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/mainloop_tma_ws.h @@ -0,0 +1,908 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "utils.h" +#include "named_barrier.h" +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopFwd { + + using Element = typename Ktraits::Element; + using ElementSF = typename Ktraits::ElementSF; + // using TMAElement = Element; + // using TMAElementSF = typename Ktraits::ElementSF; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static constexpr int kStages = Ktraits::kStages; + static constexpr int kHeadDim = Ktraits::kHeadDim; + static constexpr int BlockMean = Ktraits::BlockMean; + using GmemTiledCopy = typename Ktraits::GmemTiledCopy; + using SmemLayoutQ = typename Ktraits::SmemLayoutQ; + using SmemLayoutK = typename Ktraits::SmemLayoutK; + using SmemLayoutV = typename Ktraits::SmemLayoutV; + using SmemLayoutVt = typename Ktraits::SmemLayoutVt; + using SmemLayoutDS = typename Ktraits::SmemLayoutDS; + using SmemLayoutAtomDS = typename Ktraits::SmemLayoutAtomDS; + using LayoutDS = decltype( + blocked_product( + SmemLayoutAtomDS{}, + make_layout( + make_shape(int32_t(0), int32_t(0), int32_t(0), int32_t(0)), + make_stride(int32_t(0), _1{}, int32_t(0), int32_t(0))) + ) + ); + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQKV = cute::Stride; + using ShapeSF = cute::Shape; // (seqlen, d // 16, head, batch) + using LayoutSF = typename Ktraits::LayoutSF; + using LayoutP = typename Ktraits::LayoutP; + using LayoutSFP = typename Ktraits::LayoutSFP; + using SfAtom = typename Ktraits::SfAtom; + using TMA_Q = decltype(make_tma_copy( + GmemTiledCopy{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}), + SmemLayoutQ{}, + select<0, 2>(TileShape_MNK{}), + _1{})); + + using TMA_KV = decltype(make_tma_copy( + GmemTiledCopy{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}), + take<0, 2>(SmemLayoutK{}), + select<1, 2>(TileShape_MNK{}), + _1{})); + + using TMA_Vt = decltype(make_tma_copy( + GmemTiledCopy{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}), + take<0, 2>(SmemLayoutVt{}), + make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{})), + _1{})); + + using TMA_DS = decltype(make_tma_copy( + GmemTiledCopy{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), LayoutDS{}), + take<0, 2>(SmemLayoutDS{}), + make_shape(shape<0>(TileShape_MNK{}), shape<1>(TileShape_MNK{})), + _1{})); + + using BlkScaledConfig = typename Ktraits::BlkScaledConfig; + using GmemTiledCopySF = typename Ktraits::GmemTiledCopySF; + using SmemLayoutSFQ = typename Ktraits::SmemLayoutSFQ; + using SmemLayoutSFK = typename Ktraits::SmemLayoutSFK; + using SmemLayoutSFV = typename Ktraits::SmemLayoutSFV; + using SmemLayoutSFVt = typename Ktraits::SmemLayoutSFVt; + + using TMA_SFQ = decltype(make_tma_copy( + GmemTiledCopySF{}, + make_tensor(static_cast(nullptr), LayoutSF{}), + SmemLayoutSFQ{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{})), + _1{})); // No programmatic multicast + + + using TMA_SFKV = decltype(make_tma_copy( + GmemTiledCopySF{}, + make_tensor(static_cast(nullptr), LayoutSF{}), + SmemLayoutSFK{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{})), + _1{})); + + using TMA_SFVt = decltype(make_tma_copy( + GmemTiledCopySF{}, + make_tensor(static_cast(nullptr), LayoutSF{}), + SmemLayoutSFVt{}(_,_,cute::Int<0>{}), + make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{})), + _1{})); + + using SmemCopyAtomQ = typename Ktraits::SmemCopyAtomQ; + using SmemCopyAtomKV = typename Ktraits::SmemCopyAtomKV; + using SmemCopyAtomSF = typename Ktraits::SmemCopyAtomSF; + using TiledMmaQK = typename Ktraits::TiledMmaQK; + using TiledMmaPV = typename Ktraits::TiledMmaPV; + static constexpr int NumMmaThreads = size(TiledMmaQK{}); + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + using MainloopPipelineQ = typename Ktraits::MainloopPipelineQ; + using PipelineParamsQ = typename Ktraits::PipelineParamsQ; + using PipelineStateQ = typename Ktraits::PipelineStateQ; + using EpilogueBarrier = typename Ktraits::EpilogueBarrier; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + static constexpr uint32_t TmaTransactionBytesQ = static_cast( + cutlass::bits_to_bytes(cosize((SmemLayoutSFQ{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size((SmemLayoutQ{})) * sizeof_bits::value)); + + static constexpr uint32_t TmaTransactionBytesK = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFK{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutDS{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutK{})) * sizeof_bits::value)); + + static constexpr uint32_t TmaTransactionBytesV = static_cast( + cutlass::bits_to_bytes(cosize(take<0,2>(SmemLayoutSFVt{})) * cute::sizeof_bits_v) + + cutlass::bits_to_bytes(size(take<0,2>(SmemLayoutVt{})) * sizeof_bits::value)); + + // Host side kernel arguments + struct Arguments { + Element const* ptr_Q; + ShapeQKV const shape_Q; + StrideQKV const stride_Q; + Element const* ptr_K; + ShapeQKV const shape_K; + StrideQKV const stride_K; + ShapeQKV const unpadded_shape_K; + Element const* ptr_Vt; + ShapeQKV const shape_Vt; + StrideQKV const stride_Vt; + ElementSF const* ptr_SFQ{nullptr}; + ShapeSF const shape_SFQ{}; + ElementSF const* ptr_SFK{nullptr}; + ShapeSF const shape_SFK{}; + ElementSF const* ptr_SFVt{nullptr}; + ShapeSF const shape_SFVt{}; + float const* ptr_ds; + ShapeQKV const shape_ds; + StrideQKV const stride_ds; + float const softmax_scale_log2; + }; + + // Device side kernel params + struct Params { + ShapeQKV const shape_Q; + LayoutSF const layout_SFQ; + ShapeQKV const shape_K; + ShapeQKV const unpadded_shape_K; + LayoutSF const layout_SFK; + ShapeQKV const shape_Vt; + LayoutSF const layout_SFVt; + LayoutDS const layout_DS; + TMA_Q tma_load_Q; + TMA_SFQ tma_load_SFQ; + TMA_KV tma_load_K; + TMA_SFKV tma_load_SFK; + TMA_Vt tma_load_Vt; + TMA_SFVt tma_load_SFVt; + TMA_DS tma_load_DS; + float const softmax_scale_log2; + }; + + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); + TMA_Q tma_load_Q = make_tma_copy( + GmemTiledCopy{}, + mQ, + SmemLayoutQ{}, + select<0, 2>(TileShape_MNK{}), + _1{}); // no mcast for Q + Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); + TMA_KV tma_load_K = make_tma_copy( + GmemTiledCopy{}, + mK, + SmemLayoutK{}(_, _, _0{}), + select<1, 2>(TileShape_MNK{}), + _1{}); // mcast along M mode for this N load, if any + Tensor mVt = make_tensor(make_gmem_ptr(args.ptr_Vt), args.shape_Vt, args.stride_Vt); + TMA_Vt tma_load_Vt = make_tma_copy( + GmemTiledCopy{}, + mVt, + SmemLayoutVt{}(_, _, _0{}), + make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{})), + _1{}); // mcast along M mode for this N load, if any + auto [Seqlen_Q, Seqlen_K, HeadNum, Batch] = args.shape_ds; + LayoutDS layout_ds = tile_to_shape(SmemLayoutAtomDS{}, make_shape(Seqlen_Q, Seqlen_K, HeadNum, Batch), Step<_2,_1,_3,_4>{}); + Tensor mDS = make_tensor(make_gmem_ptr(args.ptr_ds), layout_ds); + TMA_DS tma_load_ds = make_tma_copy ( + GmemTiledCopy{}, + mDS, + SmemLayoutDS{}(_, _, _0{}), + make_shape(shape<0>(TileShape_MNK{}), shape<1>(TileShape_MNK{})), + _1{}); + LayoutSF layout_sfq = BlkScaledConfig::tile_atom_to_shape_SFQKV(args.shape_SFQ); + Tensor mSFQ = make_tensor(make_gmem_ptr(args.ptr_SFQ), layout_sfq); + TMA_SFQ tma_load_sfq = make_tma_copy( + GmemTiledCopySF{}, + mSFQ, + SmemLayoutSFQ{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{})), + _1{}); + LayoutSF layout_sfk = BlkScaledConfig::tile_atom_to_shape_SFQKV(args.shape_SFK); + Tensor mSFK = make_tensor(make_gmem_ptr(args.ptr_SFK), layout_sfk); + TMA_SFKV tma_load_sfk = make_tma_copy( + GmemTiledCopySF{}, + mSFK, + SmemLayoutSFK{}(_, _, _0{}), + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{})), + _1{}); + LayoutSF layout_sfvt = BlkScaledConfig::tile_atom_to_shape_SFVt(args.shape_SFVt); + Tensor mSFVt = make_tensor(make_gmem_ptr(args.ptr_SFVt), layout_sfvt); + TMA_SFVt tma_load_sfvt = make_tma_copy( + GmemTiledCopySF{}, + mSFVt, + SmemLayoutSFVt{}(_, _, _0{}), + make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{})), + _1{}); + return {args.shape_Q, layout_sfq, + args.shape_K, args.unpadded_shape_K, layout_sfk, + args.shape_Vt, layout_sfvt, + layout_ds, + tma_load_Q, tma_load_sfq, + tma_load_K, tma_load_sfk, + tma_load_Vt, tma_load_sfvt, + tma_load_ds, + args.softmax_scale_log2}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Vt.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_SFQ.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_SFK.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_SFVt.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_DS.get_tma_descriptor()); + } + + CUTLASS_DEVICE + int get_n_block_max(Params const& mainloop_params, int m_block) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const seqlen_q = get<0>(mainloop_params.shape_Q); + int const seqlen_k = get<0>(mainloop_params.shape_K); + int n_block_max = cute::ceil_div(seqlen_k, kBlockN); + if constexpr (Is_causal) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN)); + } + return n_block_max; + } + + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFA(SFATensor&& sfatensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfatensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFA_TV = typename Atom::Traits::SFALayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<0>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfatensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_SFB(SFBTensor&& sfbtensor, TiledMMA& mma) + { + CUTE_STATIC_ASSERT_V(rank(sfbtensor) >= Int<2>{}); + + using AtomShape_MNK = typename Atom::Shape_MNK; + using AtomLayoutSFB_TV = typename Atom::Traits::SFBLayout; + + auto permutation_mnk = TiledPerm{}; + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(get<1>(permutation_mnk), + get<2>(permutation_mnk)); + auto t_tensor = logical_divide(sfbtensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomN,AtomK),(RestN,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutSFB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk)), + make_layout(size<3>(thr_layout_vmnk)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + return thr_tensor; + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFA(SFATensor&& sfatensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfatensor).data(), thrfrg_SFA(sfatensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vmk = make_coord(get<0>(thr_vmnk), make_coord(get<1>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFA = thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFA); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_SFB(SFBTensor&& sfbtensor, ThrMma& thread_mma) + { + using ValTypeSF = typename ThrMma::Atom::Traits::ValTypeSF; + auto thr_tensor = make_tensor(static_cast(sfbtensor).data(), thrfrg_SFB(sfbtensor.layout(),thread_mma)); + auto thr_vmnk = thread_mma.thr_vmnk_; + auto thr_vnk = make_coord(get<0>(thr_vmnk), make_coord(get<2>(thr_vmnk), get<3>(thr_vmnk))); + auto partition_SFB = thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); + return make_fragment_like(partition_SFB); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFA_TV(TiledMma& mma) + { + // (M,K) -> (M,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_A = make_layout(make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFA(ref_A, mma).compose(atile, _).compose(thridx_2_thrid, _); + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_layoutSFB_TV(TiledMma& mma) + { + // (N,K) -> (N,K) + auto tile_shape_mnk = tile_shape(mma); + auto ref_B = make_layout(make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk))); + auto thr_layout_vmnk = mma.get_thr_layout_vmnk(); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk), size<2>(thr_layout_vmnk)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk); + // (thr_idx,val) -> (M,K) + return thrfrg_SFB(ref_B, mma).compose(btile, _).compose(thridx_2_thrid, _); + } + + template + CUTLASS_DEVICE void + load(Params const& mainloop_params, + SchedulerParams const& scheduler_params, + MainloopPipelineQ pipeline_q, + MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, + PipelineStateQ& smem_pipe_write_q, + PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v, + SharedStorage &shared_storage, + WorkTileInfo work_tile_info, + int& work_idx, + int& tile_count_semaphore + ) { + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + auto [m_block, bidh, bidb] = work_tile_info.get_block_coord(scheduler_params); + + int n_block_max = get_n_block_max(mainloop_params, m_block); + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.begin()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.begin()), SmemLayoutK{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.begin()), SmemLayoutVt{}); + Tensor sSFQ = make_tensor(make_smem_ptr(shared_storage.smem_SFQ.begin()), SmemLayoutSFQ{}); + Tensor sSFK = make_tensor(make_smem_ptr(shared_storage.smem_SFK.begin()), SmemLayoutSFK{}); + Tensor sSFVt = make_tensor(make_smem_ptr(shared_storage.smem_SFV.begin()), SmemLayoutSFVt{}); + Tensor sDS = make_tensor(make_smem_ptr(shared_storage.smem_ds.begin()), SmemLayoutDS{}); + + Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.shape_Q); + Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K); + Tensor mVt = mainloop_params.tma_load_Vt.get_tma_tensor(mainloop_params.shape_Vt); + Tensor mDS = mainloop_params.tma_load_DS.get_tma_tensor(shape(mainloop_params.layout_DS)); + Tensor mSFQ = mainloop_params.tma_load_SFQ.get_tma_tensor(shape(mainloop_params.layout_SFQ)); + Tensor mSFK = mainloop_params.tma_load_SFK.get_tma_tensor(shape(mainloop_params.layout_SFK)); + Tensor mSFVt = mainloop_params.tma_load_SFVt.get_tma_tensor(shape(mainloop_params.layout_SFVt)); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gVt = local_tile(mVt(_, _, bidh, bidb), make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{})), make_coord(_0{}, _)); // (N, K, _) + Tensor gDS = [&] { + if constexpr (BlockMean) { + return local_tile(mDS(_, _, bidh, bidb), select<0, 1>(TileShape_MNK{}), make_coord(m_block, _)); + } else { + return local_tile(mDS(_, _, bidh, bidb), select<0, 1>(TileShape_MNK{}), make_coord(_0{}, _)); + } + }(); + Tensor gSFQ = local_tile(mSFQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); + Tensor gSFK = local_tile(mSFK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); + Tensor gSFVt = local_tile(mSFVt(_, _, bidh, bidb), make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{})), make_coord(_0{}, _)); + auto block_tma_q = mainloop_params.tma_load_Q.get_slice(_0{}); + Tensor tQgQ = block_tma_q.partition_S(gQ); + Tensor tQsQ = block_tma_q.partition_D(sQ); + auto block_tma_sfq = mainloop_params.tma_load_SFQ.get_slice(_0{}); + Tensor tQgSFQ = block_tma_sfq.partition_S(gSFQ); + Tensor tQsSFQ = block_tma_sfq.partition_D(sSFQ); + auto block_tma_k = mainloop_params.tma_load_K.get_slice(cluster_local_block_id.x); + Tensor tKgK = group_modes<0, 3>(block_tma_k.partition_S(gK)); + Tensor tKsK = group_modes<0, 3>(block_tma_k.partition_D(sK)); + auto block_tma_sfk = mainloop_params.tma_load_SFK.get_slice(cluster_local_block_id.x); + Tensor tKgSFK = group_modes<0, 3>(block_tma_sfk.partition_S(gSFK)); + Tensor tKsSFK = group_modes<0, 3>(block_tma_sfk.partition_D(sSFK)); + auto block_tma_vt = mainloop_params.tma_load_Vt.get_slice(cluster_local_block_id.x); + Tensor tVgVt = group_modes<0, 3>(block_tma_vt.partition_S(gVt)); + Tensor tVsVt = group_modes<0, 3>(block_tma_vt.partition_D(sVt)); + auto block_tma_sfvt = mainloop_params.tma_load_SFVt.get_slice(cluster_local_block_id.x); + Tensor tVgSFVt = group_modes<0, 3>(block_tma_sfvt.partition_S(gSFVt)); + Tensor tVsSFVt = group_modes<0, 3>(block_tma_sfvt.partition_D(sSFVt)); + auto block_tma_ds = mainloop_params.tma_load_DS.get_slice(cluster_local_block_id.x); + Tensor tDSgDS = group_modes<0, 3>(block_tma_ds.partition_S(gDS)); + Tensor tDSsDS = group_modes<0, 3>(block_tma_ds.partition_D(sDS)); + uint16_t mcast_mask_kv = 0; + + int n_block = n_block_max - 1; + int lane_predicate = cute::elect_one_sync(); + if (lane_predicate) { + pipeline_q.producer_acquire(smem_pipe_write_q); + copy(mainloop_params.tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), 0), tQgQ, tQsQ); + copy(mainloop_params.tma_load_SFQ.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), 0), tQgSFQ, tQsSFQ); + ++smem_pipe_write_q; + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), + tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index())); + copy(mainloop_params.tma_load_SFK.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), + tKgSFK(_, n_block), tKsSFK(_, smem_pipe_write_k.index())); + copy(mainloop_params.tma_load_DS.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), + tDSgDS(_, n_block), tDSsDS(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_Vt.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), + tVgVt(_, n_block), tVsVt(_, smem_pipe_write_v.index())); + copy(mainloop_params.tma_load_SFVt.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), + tVgSFVt(_, n_block), tVsSFVt(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + + n_block--; + if (lane_predicate) { + // CUTLASS_PRAGMA_NO_UNROLL + #pragma unroll 2 + for (; n_block >= 0; --n_block) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), + tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index())); + copy(mainloop_params.tma_load_SFK.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), + tKgSFK(_, n_block), tKsSFK(_, smem_pipe_write_k.index())); + copy(mainloop_params.tma_load_DS.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), + tDSgDS(_, n_block), tDSsDS(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_Vt.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), + tVgVt(_, n_block), tVsVt(_, smem_pipe_write_v.index())); + copy(mainloop_params.tma_load_SFVt.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), + tVgSFVt(_, n_block), tVsSFVt(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + } + ++work_idx; + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipelineQ pipeline_q, + MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, + PipelineStateQ& smem_pipe_write_q, + PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v) { + int lane_predicate = cute::elect_one_sync(); + // Issue the epilogue waits + if (lane_predicate) { + pipeline_q.producer_tail(smem_pipe_write_q); + pipeline_k.producer_tail(smem_pipe_write_k); + pipeline_v.producer_tail(smem_pipe_write_v); + } + } + + template + CUTLASS_DEVICE void + mma(Params const& mainloop_params, + MainloopPipelineQ pipeline_q, + MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, + PipelineStateQ& smem_pipe_read_q, + PipelineState& smem_pipe_read_k, + PipelineState& smem_pipe_read_v, + FrgTensorO& tOrO_store, + SoftmaxFused& softmax_fused, + int n_block_count, + int thread_idx, + int work_idx, + int m_block, + SharedStorage& shared_storage + ) { + + static_assert(is_rmem::value, "O tensor must be rmem resident."); + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + static constexpr int kBlockK = get<2>(TileShape_MNK{}); + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.begin()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.begin()), SmemLayoutK{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.begin()), SmemLayoutVt{}); + Tensor sDS = make_tensor(make_smem_ptr(shared_storage.smem_ds.begin()), SmemLayoutDS{}); + Tensor sSFQ = make_tensor(make_smem_ptr(shared_storage.smem_SFQ.begin()), SmemLayoutSFQ{}); + Tensor sSFK = make_tensor(make_smem_ptr(shared_storage.smem_SFK.begin()), SmemLayoutSFK{}); + Tensor sSFVt = make_tensor(make_smem_ptr(shared_storage.smem_SFV.begin()), SmemLayoutSFVt{}); + + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + auto thread_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx); + auto thread_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + + Tensor tSrQ = thread_mma_qk.partition_fragment_A(sQ); + Tensor tSrK = thread_mma_qk.partition_fragment_B(sK(_,_,Int<0>{})); + Tensor tOrVt = thread_mma_pv.partition_fragment_B(sVt(_,_,Int<0>{})); + Tensor tOrP = make_tensor_like(LayoutP{}); + Tensor tSrSFQ = partition_fragment_SFA(sSFQ, thread_mma_qk); + Tensor tSrSFK = partition_fragment_SFB(sSFK(_,_,Int<0>{}), thread_mma_qk); + Tensor tOrSFVt = partition_fragment_SFB(sSFVt(_,_,Int<0>{}), thread_mma_pv); + Tensor tOrSFP = make_tensor(LayoutSFP{}); + Tensor tOrSFP_flt = filter_zeros(tOrSFP); + Tensor tSrDS = make_tensor(make_shape(_8{}, _4{}), make_stride(_1{}, _8{})); + // copy qk and sf from smem to rmem + auto smem_tiled_copy_Q = make_tiled_copy_A(SmemCopyAtomQ{}, tiled_mma_qk); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(thread_idx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(as_position_independent_swizzle_tensor(sQ)); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(SmemCopyAtomKV{}, tiled_mma_qk); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(thread_idx); + Tensor tSsK = smem_thr_copy_K.partition_S(as_position_independent_swizzle_tensor(sK)); + Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK); + + auto smem_tiled_copy_V = make_tiled_copy_B(SmemCopyAtomKV{}, tiled_mma_pv); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(thread_idx); + Tensor tOsVt = smem_thr_copy_V.partition_S(as_position_independent_swizzle_tensor(sVt)); + Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt); + + auto tile_shape_mnk = tile_shape(tiled_mma_qk); + auto smem_tiled_copy_SFQ = make_tiled_copy_impl(SmemCopyAtomSF{}, + get_layoutSFA_TV(tiled_mma_qk), + make_shape(size<0>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFQ = smem_tiled_copy_SFQ.get_thread_slice(thread_idx); + Tensor tSsSFQ = smem_thr_copy_SFQ.partition_S(as_position_independent_swizzle_tensor(sSFQ)); + Tensor tSrSFQ_copy_view = smem_thr_copy_SFQ.retile_D(tSrSFQ); + + auto smem_tiled_copy_SFK = make_tiled_copy_impl(SmemCopyAtomSF{}, + get_layoutSFB_TV(tiled_mma_qk), + make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFK = smem_tiled_copy_SFK.get_thread_slice(thread_idx); + Tensor tSsSFK = smem_thr_copy_SFK.partition_S(as_position_independent_swizzle_tensor(sSFK)); + Tensor tSrSFK_copy_view = smem_thr_copy_SFK.retile_D(tSrSFK); + + auto smem_tiled_copy_SFV = make_tiled_copy_impl(SmemCopyAtomSF{}, + get_layoutSFB_TV(tiled_mma_pv), + make_shape(size<1>(tile_shape_mnk), size<2>(tile_shape_mnk)) + ); + auto smem_thr_copy_SFV = smem_tiled_copy_SFV.get_thread_slice(thread_idx); + Tensor tOsSFVt = smem_thr_copy_SFV.partition_S(as_position_independent_swizzle_tensor(sSFVt)); + Tensor tOrSFVt_copy_view = smem_thr_copy_SFV.retile_D(tOrSFVt); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + int const seqlen_q = get<0>(mainloop_params.shape_Q); + int const seqlen_k = get<0>(mainloop_params.shape_K); + int const unpadded_seqlen_k = get<0>(mainloop_params.unpadded_shape_K); + int n_block = n_block_count - 1; + + auto copy_k_block = [&](auto block_id) { + auto tSsK_stage = tSsK(_, _, _, smem_pipe_read_k.index()); + auto tSsSFK_stage = tSsSFK(_, _, _, smem_pipe_read_k.index()); + copy(smem_tiled_copy_K, tSsK_stage(_, _, block_id), tSrK_copy_view(_, _, block_id)); + copy(smem_tiled_copy_SFK, tSsSFK_stage(_, _, block_id), tSrSFK_copy_view(_, _, block_id)); + }; + + auto copy_v_block = [&](auto block_id) { + auto tOsVt_stage = tOsVt(_, _, _, smem_pipe_read_v.index()); + auto tOsSFVt_stage = tOsSFVt(_, _, _, smem_pipe_read_v.index()); + copy(smem_tiled_copy_V, tOsVt_stage(_, _, block_id), tOrVt_copy_view(_, _, block_id)); + copy(smem_tiled_copy_SFV, tOsSFVt_stage(_, _, block_id), tOrSFVt_copy_view(_, _, block_id)); + }; + // auto gemm_qk = [&](auto block_id) { + // cute::gemm(tiled_mma_qk, make_zip_tensor(tSrQ(_, _, block_id), tSrSFQ(_, _, block_id)), make_zip_tensor(tSrK(_, _, block_id), tSrSFK(_, _, block_id)), tSrS); + // }; + // auto gemm_pv = [&](auto block_id) { + // cute::gemm(tiled_mma_pv, make_zip_tensor(tOrP(_, _, block_id), tOrSFP(_, _, block_id)), make_zip_tensor(tOrVt(_, _, block_id), tOrSFVt(_, _, block_id)), tOrO); + // }; + auto add_delta_s = [&](auto& acc) { + auto tSsDS_stage = recast(sDS(_, _, smem_pipe_read_k.index())); + auto acc_float4 = recast(acc); + int quad_id = (threadIdx.x % 4) * 2; + for (int i = 0; i < 4; i++) { + auto num = quad_id + i * 8; + float4 delta_s_0 = tSsDS_stage(make_coord(_0{}, _0{}), make_coord(num, _0{})); + float4 delta_s_1 = tSsDS_stage(make_coord(_0{}, _0{}), make_coord(num + 1, _0{})); + acc_float4(make_coord(make_coord(_0{}, _0{}), _0{}), _0{}, i) = delta_s_0; + acc_float4(make_coord(make_coord(_0{}, _0{}), _1{}), _0{}, i) = delta_s_0; + acc_float4(make_coord(make_coord(_0{}, _1{}), _0{}), _0{}, i) = delta_s_1; + acc_float4(make_coord(make_coord(_0{}, _1{}), _1{}), _0{}, i) = delta_s_1; + } + }; + consumer_wait(pipeline_q, smem_pipe_read_q); + copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + copy(smem_tiled_copy_SFQ, tSsSFQ, tSrSFQ_copy_view); + pipeline_q.consumer_release(smem_pipe_read_q); + ++smem_pipe_read_q; + + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); + Tensor tSrS_converion_view = make_tensor(tSrS.data(), flash::convert_to_conversion_layout(tSrS.layout())); + Tensor AbsMaxP = make_tensor_like( + make_layout(shape(group<1, 4>(flatten(tSrS_converion_view.layout()(make_coord(_0{}, _), _, _))))) + ); + consumer_wait(pipeline_k, smem_pipe_read_k); + copy_k_block(_0{}); + add_delta_s(tSrS); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, make_zip_tensor(tSrQ(_, _, k_block), tSrSFQ(_, _, k_block)), + make_zip_tensor(tSrK(_, _, k_block), tSrSFK(_, _, k_block)), tSrS); + if (k_block < size<2>(tSrQ) - 1) { + copy_k_block(k_block + 1); + } else { + pipeline_k.consumer_release(smem_pipe_read_k); + ++smem_pipe_read_k; + } + } + + + auto col_limit_causal = [&](int row, int n_block) { + return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM; + }; + { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = thread_mma_qk.partition_C(cS); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal) { // Just masking based on col + if (int(get<1>(tScS(i))) >= int(unpadded_seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } + } else { + if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, + col_limit_causal(int(get<0>(tScS(i))), n_block))) { + tSrS(i) = -INFINITY; + } + } + } + } + auto quantize = [&](auto mma_k, auto acc_conversion_view) { + Tensor AbsMaxP_stagek = AbsMaxP(_, make_coord(_, _, mma_k)); + Tensor acc_conversion_stagek = acc_conversion_view(_, _, mma_k); + Tensor SFP = make_tensor_like(AbsMaxP_stagek.layout()); + Tensor SFP_uint32_view = recast(SFP); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(AbsMaxP_stagek); i += 4) { + uint32_t& tmp = SFP_uint32_view(i / 4); + flash::packed_float_to_ue4m3( + AbsMaxP_stagek(i), + AbsMaxP_stagek(i + 1), + AbsMaxP_stagek(i + 2), + AbsMaxP_stagek(i + 3), + tmp + ); + } + int const quad_id = threadIdx.x & 3; + uint32_t MASK = (0xFF00FF) << ((quad_id & 1) * 8); + Tensor tOrSFP_uint32_view = recast(tOrSFP(_, _, mma_k)); + Tensor tOrP_uint32_view = recast(tOrP(_, _, mma_k)); + + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < size<1>(tOrP); ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + flash::packed_float_to_e2m1( + acc_conversion_stagek(make_coord(_0{}, i), mma_m), + acc_conversion_stagek(make_coord(_1{}, i), mma_m), + acc_conversion_stagek(make_coord(_2{}, i), mma_m), + acc_conversion_stagek(make_coord(_3{}, i), mma_m), + acc_conversion_stagek(make_coord(_4{}, i), mma_m), + acc_conversion_stagek(make_coord(_5{}, i), mma_m), + acc_conversion_stagek(make_coord(_6{}, i), mma_m), + acc_conversion_stagek(make_coord(_7{}, i), mma_m), + tOrP_uint32_view(i, mma_m) + ); + } + uint32_t local_sfp = SFP_uint32_view(_0{}, _0{}, mma_m); + uint32_t peer_sfp = __shfl_xor_sync(int32_t(-1), local_sfp, 2); + if ((quad_id & 1) == 0) { + uint32_t sfp = (local_sfp & MASK) | ((peer_sfp & MASK) << 8); + tOrSFP_uint32_view(_0{}, mma_m) = sfp; + } else { + uint32_t sfp = (peer_sfp & MASK) | ((local_sfp & MASK) >> 8); + tOrSFP_uint32_view(_0{}, mma_m) = sfp; + } + } + }; + + softmax_fused.template online_softmax_with_quant(tSrS, AbsMaxP, mainloop_params.softmax_scale_log2); + + consumer_wait(pipeline_v, smem_pipe_read_v); + copy_v_block(_0{}); + quantize(_0{}, tSrS_converion_view); + CUTLASS_PRAGMA_UNROLL + for (int v_block = 0; v_block < size<2>(tOrP); ++v_block) { + cute::gemm(tiled_mma_pv, make_zip_tensor(tOrP(_, _, v_block), tOrSFP(_, _, v_block)), + make_zip_tensor(tOrVt(_, _, v_block), tOrSFVt(_, _, v_block)), tOrO_store); + if (v_block < size<2>(tOrP) - 1) { + copy_v_block(v_block + 1); + quantize(v_block + 1, tSrS_converion_view); + } else { + pipeline_v.consumer_release(smem_pipe_read_v); + ++smem_pipe_read_v; + } + } + + n_block--; + constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; + // // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal + CUTLASS_PRAGMA_UNROLL + for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block >= 0; ++masking_step, --n_block) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); + Tensor tSrS_converion_view = make_tensor(tSrS.data(), flash::convert_to_conversion_layout(tSrS.layout())); + consumer_wait(pipeline_k, smem_pipe_read_k); + copy_k_block(_0{}); + add_delta_s(tSrS); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, make_zip_tensor(tSrQ(_, _, k_block), tSrSFQ(_, _, k_block)), + make_zip_tensor(tSrK(_, _, k_block), tSrSFK(_, _, k_block)), tSrS); + if (k_block < size<2>(tSrQ) - 1) { + copy_k_block(k_block + 1); + } + } + pipeline_k.consumer_release(smem_pipe_read_k); // release K + ++smem_pipe_read_k; + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = thread_mma_qk.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block - 1)) { + tSrS(i) = -INFINITY; + } + } + softmax_fused.template online_softmax_with_quant(tSrS, AbsMaxP, mainloop_params.softmax_scale_log2); + Tensor tOrO = make_fragment_like(tOrO_store); + consumer_wait(pipeline_v, smem_pipe_read_v); + copy_v_block(_0{}); + quantize(_0{}, tSrS_converion_view); + CUTLASS_PRAGMA_UNROLL + for (int v_block = 0; v_block < size<2>(tOrP); ++v_block) { + cute::gemm(tiled_mma_pv, make_zip_tensor(tOrP(_, _, v_block), tOrSFP(_, _, v_block)), + make_zip_tensor(tOrVt(_, _, v_block), tOrSFVt(_, _, v_block)), tOrO); + if (v_block < size<2>(tOrP) - 1) { + copy_v_block(v_block + 1); + quantize(v_block + 1, tSrS_converion_view); + } + } + pipeline_v.consumer_release(smem_pipe_read_v); + ++smem_pipe_read_v; + if (masking_step > 0) { softmax_fused.rescale_o(tOrO_store, tOrO); } + } + + #pragma unroll 1 + for (; n_block >= 0; --n_block) { + Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{})); + Tensor tSrS_converion_view = make_tensor(tSrS.data(), flash::convert_to_conversion_layout(tSrS.layout())); + consumer_wait(pipeline_k, smem_pipe_read_k); + copy_k_block(_0{}); + add_delta_s(tSrS); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, make_zip_tensor(tSrQ(_, _, k_block), tSrSFQ(_, _, k_block)), + make_zip_tensor(tSrK(_, _, k_block), tSrSFK(_, _, k_block)), tSrS); + if (k_block < size<2>(tSrQ) - 1) { + copy_k_block(k_block + 1); + } else { + pipeline_k.consumer_release(smem_pipe_read_k); + ++smem_pipe_read_k; + } + } + softmax_fused.template online_softmax_with_quant(tSrS, AbsMaxP, mainloop_params.softmax_scale_log2); + Tensor tOrO = make_fragment_like(tOrO_store); + consumer_wait(pipeline_v, smem_pipe_read_v); + copy_v_block(_0{}); + quantize(_0{}, tSrS_converion_view); + CUTLASS_PRAGMA_UNROLL + for (int v_block = 0; v_block < size<2>(tOrP); ++v_block) { + cute::gemm(tiled_mma_pv, make_zip_tensor(tOrP(_, _, v_block), tOrSFP(_, _, v_block)), + make_zip_tensor(tOrVt(_, _, v_block), tOrSFVt(_, _, v_block)), tOrO); + if (v_block < size<2>(tOrP) - 1) { + copy_v_block(v_block + 1); + quantize(v_block + 1, tSrS_converion_view); + } else { + pipeline_v.consumer_release(smem_pipe_read_v); + ++smem_pipe_read_v; + } + } + softmax_fused.rescale_o(tOrO_store, tOrO); + } + softmax_fused.finalize(tOrO_store); + return; + } + +}; + +} // namespace flash + diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/named_barrier.h b/fastvideo-kernel/attn_qat_infer/blackwell/named_barrier.h new file mode 100644 index 0000000000..bd8aefb6a0 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/named_barrier.h @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/barrier.h" +#include "cutlass/pipeline/sm90_pipeline.hpp" + +namespace flash { + +enum class FP4NamedBarriers { + QueryEmpty = 1, + WarpSpecializedConsumer = 2, + WarpSpecializedPingPongConsumer1 = 3, + WarpSpecializedPingPongConsumer2 = 4, + ProducerEnd = 5, + ConsumerEnd = 6, + EpilogueBarrier = 7 +}; + +template +struct OrderedSequenceBarrierVarGroupSizeSharedStorage { + using Barrier = cutlass::arch::ClusterBarrier; + Barrier barrier_[SequenceDepth][SequenceLength]; +}; + + template +class OrderedSequenceBarrierVarGroupSize { +public: + static constexpr int SequenceDepth = SequenceDepth_; + static constexpr int SequenceLength = SequenceLength_; + using Barrier = cutlass::arch::ClusterBarrier; + using SharedStorage = flash::OrderedSequenceBarrierVarGroupSizeSharedStorage; + + + struct Params { + uint32_t group_id; + uint32_t* group_size_list; + }; + +private : + // In future this Params object can be replaced easily with a CG object + Params params_; + Barrier *barrier_ptr_; + cutlass::PipelineState stage_; + + static constexpr int Depth = SequenceDepth; + static constexpr int Length = SequenceLength; + +public: + OrderedSequenceBarrierVarGroupSize() = delete; + OrderedSequenceBarrierVarGroupSize(const OrderedSequenceBarrierVarGroupSize&) = delete; + OrderedSequenceBarrierVarGroupSize(OrderedSequenceBarrierVarGroupSize&&) = delete; + OrderedSequenceBarrierVarGroupSize& operator=(const OrderedSequenceBarrierVarGroupSize&) = delete; + OrderedSequenceBarrierVarGroupSize& operator=(OrderedSequenceBarrierVarGroupSize&&) = delete; + ~OrderedSequenceBarrierVarGroupSize() = default; + + CUTLASS_DEVICE + OrderedSequenceBarrierVarGroupSize(SharedStorage& storage, Params const& params) : + params_(params), + barrier_ptr_(&storage.barrier_[0][0]), + // Group 0 - starts with an opposite phase + stage_({0, params.group_id == 0, 0}) { + int warp_idx = cutlass::canonical_warp_idx_sync(); + int lane_predicate = cute::elect_one_sync(); + + // Barrier FULL, EMPTY init + // Init is done only by the one elected thread of the block + if (warp_idx == 0 && lane_predicate) { + for (int d = 0; d < Depth; ++d) { + for (int l = 0; l < Length; ++l) { + barrier_ptr_[d * Length + l].init(*(params.group_size_list + l)); + } + } + } + cutlass::arch::fence_barrier_init(); + } + + // Wait on a stage to be unlocked + CUTLASS_DEVICE + void wait() { + get_barrier_for_current_stage(params_.group_id).wait(stage_.phase()); + } + + // Signal completion of Stage and move to the next stage + // (group_id) signals to (group_id+1) + CUTLASS_DEVICE + void arrive() { + int signalling_id = (params_.group_id + 1) % Length; + get_barrier_for_current_stage(signalling_id).arrive(); + ++stage_; + } + + CUTLASS_DEVICE + void advance() { + ++stage_; + } + +private: + + CUTLASS_DEVICE + Barrier& get_barrier_for_current_stage(int group_id) { + return barrier_ptr_[stage_.index() * Length + group_id]; + } +}; + + } // flash \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/params.h b/fastvideo-kernel/attn_qat_infer/blackwell/params.h new file mode 100644 index 0000000000..fcf9bf8981 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/params.h @@ -0,0 +1,180 @@ +// Modified from the original SageAttention3 code +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + #pragma once + + #include + #include + + #ifdef OLD_GENERATOR_PATH + #include + #else + #include + #endif + + #include // For at::cuda::philox::unpack + + #include "cutlass/fast_math.h" // For cutlass::FastDivmod + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + void *__restrict__ delta_s_ptr; + // The QKV scale factor matrices. + void *__restrict__ sfq_ptr; + void *__restrict__ sfk_ptr; + void *__restrict__ sfv_ptr; + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + index_t ds_batch_stride; + index_t ds_row_stride; + index_t ds_head_stride; + // The stride of the Q, K and V scale factor matrices. + index_t sfq_batch_stride; + index_t sfk_batch_stride; + index_t sfv_batch_stride; + index_t sfq_row_stride; + index_t sfk_row_stride; + index_t sfv_row_stride; + index_t sfq_head_stride; + index_t sfk_head_stride; + index_t sfv_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////// + + struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + void * __restrict__ s_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, unpadded_seqlen_k; + cutlass::FastDivmod head_divmod, m_block_divmod; + int total_blocks; + int seqlen_s; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + uint32_t scale_softmax_log2_half2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + + int *__restrict__ blockmask; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ cache_batch_idx; + + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Local window size + int window_size_left, window_size_right; + + // Random state. + at::PhiloxCudaState philox_args; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; + + bool is_bf16; + bool is_e4m3; + bool is_causal; + bool per_block_mean; + bool single_level_p_quant; // If true, use single-level 1x16 block scale quantization for P (like V), instead of two-level quantization + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + + int * __restrict__ tile_count_semaphore; + }; + + //////////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/softmax_fused.h b/fastvideo-kernel/attn_qat_infer/blackwell/softmax_fused.h new file mode 100644 index 0000000000..54a803e874 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/softmax_fused.h @@ -0,0 +1,190 @@ +// Modified from the original SageAttention3 code +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #pragma once + + #include + #include "cute/tensor.hpp" + #include "cutlass/numeric_types.h" + #include "utils.h" + + namespace flash { + + using namespace cute; + + template + struct SoftmaxFused{ + + using TensorT = decltype(make_fragment_like(Shape>{})); + TensorT row_sum, row_max, scores_scale; + static constexpr float fp8_scalexfp4_scale = 1.f / (448 * 6); + static constexpr float fp8_scalexfp4_scale_log2 = -11.392317422778762f; //log2f(fp8_scalexfp4_scale) + static constexpr float fp4_scale_log2 = -2.584962500721156f; // log2f(fp4_scale) + static constexpr int RowReductionThr = 4; + + // If true, use single-level quantization: s_P2, P̂_2 = φ(P̃) directly (standard per-block FP4 quantization like V) + // If false (default), use two-level quantization: s_P1 = rowmax(P̃)/(448×6), then s_P2, P̂_2 = φ(P̃/s_P1) + bool single_level_p_quant; + + CUTLASS_DEVICE SoftmaxFused(bool single_level = false) : single_level_p_quant(single_level) {}; + + template + CUTLASS_DEVICE auto online_softmax_with_quant( + TensorAcc& acc, + TensorMax& AbsMaxP, + const float softmax_scale_log2 + ) { + Tensor acc_reduction_view = make_tensor(acc.data(), flash::convert_to_reduction_layout(acc.layout())); + Tensor acc_conversion_view = make_tensor(acc.data(), flash::convert_to_conversion_layout(acc.layout())); + Tensor acc_conversion_flatten = group_modes<1, 5>(group_modes<0, 2>(flatten(acc_conversion_view))); + + if constexpr (FirstTile) { + fill(row_max, -INFINITY); + clear(row_sum); + fill(scores_scale, 1.f); + + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(acc_reduction_view); mi++) { + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1, 1>(acc_reduction_view); ni++) { + CUTLASS_PRAGMA_UNROLL + for (int ei = 0; ei < size<1, 0>(acc_reduction_view); ei++) { + AbsMaxP(mi, ni) = fmaxf(AbsMaxP(mi, ni), acc_reduction_view(mi, make_coord(ei, ni))); + } + float max_recv = __shfl_xor_sync(int32_t(-1), AbsMaxP(mi, ni), 1); // exchange max with neighbour thread of 8 elements + AbsMaxP(mi, ni) = fmaxf(AbsMaxP(mi, ni), max_recv); + row_max(mi) = fmaxf(row_max(mi), AbsMaxP(mi, ni)); + } + + float max_recv = __shfl_xor_sync(int32_t(-1), row_max(mi), 2); // exchange max in a quad in a row + row_max(mi) = fmaxf(row_max(mi), max_recv); + + // Two-level P quantization (default): s_P1 = rowmax(P̃)/(448×6), then s_P2,P̂_2 = φ(P̃/s_P1) + // - Pre-scales P to [0, 448×6] range before φ, output scaled by s_P1 + // Single-level P quantization: s_P2, P̂_2 = φ(P̃) directly (like V quantization) + // - No s_P1, just standard per-block FP4 quantization φ + const float s_P1_offset = single_level_p_quant ? 0.f : fp8_scalexfp4_scale_log2; + const float max_scaled = InfCheck + ? (row_max(mi) == -INFINITY ? 0.f : (row_max(mi) * softmax_scale_log2 + s_P1_offset)) + : (row_max(mi) * softmax_scale_log2 + s_P1_offset); + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(acc_reduction_view); ni++) { + acc_reduction_view(mi, ni) = flash::ptx_exp2(acc_reduction_view(mi, ni) * softmax_scale_log2 - max_scaled); + } + // s_P2 = max(P_block)/6 — per-block scale factor from φ function (same formula for both modes) + // The difference is in max_scaled: two-level includes 448×6 pre-scaling, single-level doesn't + CUTLASS_PRAGMA_UNROLL + for (int sfi = 0; sfi < size<1>(AbsMaxP); sfi++) { + AbsMaxP(mi, sfi) = flash::ptx_exp2(AbsMaxP(mi, sfi) * softmax_scale_log2 - max_scaled + fp4_scale_log2); + } + } + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(acc_reduction_view); mi++) { + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(acc_reduction_view); ni++) { + row_sum(mi) += acc_reduction_view(mi, ni); + } + } + } + else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size<0>(acc_reduction_view); mi++) { + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1, 1>(acc_reduction_view); ni++) { + float local_max = -INFINITY; + CUTLASS_PRAGMA_UNROLL + for (int ei = 0; ei < size<1, 0>(acc_reduction_view); ei++) { + local_max = fmaxf(local_max, acc_reduction_view(mi, make_coord(ei, ni))); + } + float max_recv = __shfl_xor_sync(int32_t(-1), local_max, 1); // exchange max with neighbour thread of 8 elements + AbsMaxP(mi, ni) = fmaxf(local_max, max_recv); + row_max(mi) = fmaxf(row_max(mi), AbsMaxP(mi, ni)); + } + + float max_recv = __shfl_xor_sync(int32_t(-1), row_max(mi), 2); // exchange max in a quad in a row + row_max(mi) = fmaxf(row_max(mi), max_recv); + + float scores_max_cur = !InfCheck + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + scores_scale(mi) = flash::ptx_exp2((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + + // Two-level P quantization (default): s_P1 = rowmax(P̃)/(448×6), then s_P2,P̂_2 = φ(P̃/s_P1) + // Single-level P quantization: s_P2, P̂_2 = φ(P̃) directly (like V quantization) + const float s_P1_offset = single_level_p_quant ? 0.f : fp8_scalexfp4_scale_log2; + const float max_scaled = InfCheck + ? (row_max(mi) == -INFINITY ? 0.f : (row_max(mi) * softmax_scale_log2 + s_P1_offset)) + : (row_max(mi) * softmax_scale_log2 + s_P1_offset); + row_sum(mi) = row_sum(mi) * scores_scale(mi); + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(acc_reduction_view); ni++) { + acc_reduction_view(mi, ni) = flash::ptx_exp2(acc_reduction_view(mi, ni) * softmax_scale_log2 - max_scaled); + row_sum(mi) += acc_reduction_view(mi, ni); + } + // s_P2 = max(P_block)/6 — per-block scale factor from φ function + CUTLASS_PRAGMA_UNROLL + for (int sfi = 0; sfi < size<1>(AbsMaxP); sfi++) { + AbsMaxP(mi, sfi) = flash::ptx_exp2(AbsMaxP(mi, sfi) * softmax_scale_log2 - max_scaled + fp4_scale_log2); + } + // scores_scale(mi) = max_scaled; + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(AbsMaxP); ++i) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(acc_conversion_flatten); ++j) + acc_conversion_flatten(j, i) /= AbsMaxP(i); + } + } + + template + CUTLASS_DEVICE void finalize(TensorAcc& o_store) { + Tensor o_store_reduction_view = make_tensor(o_store.data(), flash::convert_to_reduction_layout(o_store.layout())); + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size(row_max); ++mi) { + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < RowReductionThr; i <<= 1) { + float sum_recv = __shfl_xor_sync(int32_t(-1), row_sum(mi), i); + row_sum(mi) += sum_recv; + } + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1 / sum; + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(o_store_reduction_view); ++ni) { + o_store_reduction_view(mi, ni) *= inv_sum; + } + } + } + + template + CUTLASS_DEVICE void rescale_o(TensorAcc& o_store, TensorAcc const& o_tmp) { + Tensor o_store_reduction_view = make_tensor(o_store.data(), flash::convert_to_reduction_layout(o_store.layout())); + Tensor o_tmp_reduction_view = make_tensor(o_tmp.data(), flash::convert_to_reduction_layout(o_tmp.layout())); + CUTLASS_PRAGMA_UNROLL + for (int mi = 0; mi < size(row_max); ++mi) { + CUTLASS_PRAGMA_UNROLL + for (int ni = 0; ni < size<1>(o_store_reduction_view); ++ni) { + o_store_reduction_view(mi, ni) = o_store_reduction_view(mi, ni) * scores_scale(mi) + o_tmp_reduction_view(mi, ni); + } + } + + } + + + }; + } // namespace flash \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/static_switch.h b/fastvideo-kernel/attn_qat_infer/blackwell/static_switch.h new file mode 100644 index 0000000000..e870643e7d --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/static_switch.h @@ -0,0 +1,83 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +// + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define PREC_SWITCH(PRECTYPE, ...) \ + [&] { \ + if (PRECTYPE == 1) { \ + using kPrecType = cutlass::half_t; \ + constexpr static bool kSoftFp16 = false; \ + constexpr static bool kHybrid = false; \ + return __VA_ARGS__(); \ + } else if (PRECTYPE == 2) { \ + using kPrecType = cutlass::float_e4m3_t; \ + constexpr static bool kSoftFp16 = false; \ + constexpr static bool kHybrid = false; \ + return __VA_ARGS__(); \ + } else if (PRECTYPE == 3) { \ + using kPrecType = cutlass::float_e4m3_t; \ + constexpr static bool kSoftFp16 = false; \ + constexpr static bool kHybrid = true; \ + return __VA_ARGS__(); \ + } else if (PRECTYPE == 4) { \ + using kPrecType = cutlass::float_e4m3_t; \ + constexpr static bool kSoftFp16 = true; \ + constexpr static bool kHybrid = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM == 64) { \ + constexpr static int kHeadSize = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 128) { \ + constexpr static int kHeadSize = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 256) { \ + constexpr static int kHeadSize = 256; \ + return __VA_ARGS__(); \ + } \ + }() + +#define SEQLEN_SWITCH(USE_VAR_SEQ_LEN, SEQ_LEN_OUT_OF_BOUND_CHECK, ...) \ + [&] { \ + if (!USE_VAR_SEQ_LEN) { \ + if (SEQ_LEN_OUT_OF_BOUND_CHECK) { \ + using kSeqLenTraitsType = FixedSeqLenTraits; \ + return __VA_ARGS__(); \ + } else { \ + using kSeqLenTraitsType = FixedSeqLenTraits; \ + return __VA_ARGS__(); \ + } \ + } else { \ + using kSeqLenTraitsType = VarSeqLenTraits; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/tile_scheduler.h b/fastvideo-kernel/attn_qat_infer/blackwell/tile_scheduler.h new file mode 100644 index 0000000000..a559bef074 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/tile_scheduler.h @@ -0,0 +1,304 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * This code is based on code from FlashAttention3, https://github.com/Dao-AILab/flash-attention + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cutlass/fast_math.h" + +namespace flash { + +/////////////////////////////////////////////////////////////////////////////// + +class StaticPersistentTileSchedulerOld { + // + // Data members + // + +private: + int current_work_linear_idx_; + cutlass::FastDivmod const &m_block_divmod, &head_divmod; + int const total_blocks; + +public: + struct WorkTileInfo { + int M_idx = 0; + int H_idx = 0; + int B_idx = 0; + bool is_valid_tile = false; + + CUTLASS_HOST_DEVICE + bool + is_valid() const { + return is_valid_tile; + } + + CUTLASS_HOST_DEVICE + static WorkTileInfo + invalid_work_tile() { + return {-1, -1, -1, false}; + } + + }; + +public: + + CUTLASS_DEVICE explicit StaticPersistentTileSchedulerOld(cutlass::FastDivmod const &m_block_divmod_, + cutlass::FastDivmod const &head_divmod_, + int const total_blocks_) : + m_block_divmod(m_block_divmod_), head_divmod(head_divmod_), total_blocks(total_blocks_) { + + // MSVC requires protecting use of CUDA-specific nonstandard syntax, + // like blockIdx and gridDim, with __CUDA_ARCH__. +#if defined(__CUDA_ARCH__) + // current_work_linear_idx_ = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; + current_work_linear_idx_ = blockIdx.x; +#else + CUTLASS_ASSERT(false && "This line should never be reached"); +#endif + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work() const { + return get_current_work_for_linear_idx(current_work_linear_idx_); + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work_for_linear_idx(int linear_idx) const { + if (linear_idx >= total_blocks) { + return WorkTileInfo::invalid_work_tile(); + } + + // Map worker's linear index into the CTA tiled problem shape to the corresponding MHB indices + int M_idx, H_idx, B_idx; + int quotient = m_block_divmod.divmod(M_idx, linear_idx); + B_idx = head_divmod.divmod(H_idx, quotient); + return {M_idx, H_idx, B_idx, true}; + } + + CUTLASS_DEVICE + void + // advance_to_next_work(int advance_count = 1) { + advance_to_next_work() { + // current_work_linear_idx_ += int(gridDim.x * gridDim.y * gridDim.z); + current_work_linear_idx_ += int(gridDim.x); + } + + CUTLASS_DEVICE + WorkTileInfo + fetch_next_work() { + WorkTileInfo new_work_tile_info; + advance_to_next_work(); + new_work_tile_info = get_current_work(); + return new_work_tile_info; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +class SingleTileScheduler { + +public: + + // Host side kernel arguments + struct Arguments { + int const num_blocks_m, num_head, num_batch; + int const* tile_count_semaphore = nullptr; + }; + + // Device side kernel params + struct Params {}; + + static Params + to_underlying_arguments(Arguments const& args) { + return {}; + } + + static dim3 + get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(args.num_blocks_m), uint32_t(args.num_head), uint32_t(args.num_batch)}; + } + + struct WorkTileInfo { + int M_idx = 0; + int H_idx = 0; + int B_idx = 0; + bool is_valid_tile = false; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return is_valid_tile; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + return {M_idx, H_idx, B_idx}; + } + + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params) const { + return {-1, -1, -1, false}; + } + + }; + + CUTLASS_DEVICE + WorkTileInfo + get_initial_work() const { + return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true}; + } + + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {-1, -1, -1, false}; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +class StaticPersistentTileScheduler { + +public: + + // Host side kernel arguments + struct Arguments { + int const num_blocks_m, num_head, num_batch; + int const* tile_count_semaphore = nullptr; + }; + + // Device side kernel params + struct Params { + int total_blocks; + cutlass::FastDivmod m_block_divmod, head_divmod; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + return {args.num_blocks_m * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head)}; + } + + static dim3 + get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int m_block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx)); + return {m_block, bidh, bidb}; + } + + }; + + CUTLASS_DEVICE + WorkTileInfo + get_initial_work() const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {current_work.tile_idx + int(gridDim.x)}; + } + +}; + +class DynamicPersistentTileScheduler { + +public: + + // Host side kernel arguments + struct Arguments { + int const num_blocks_m, num_head, num_batch; + int const* tile_count_semaphore; + }; + + // Device side kernel params + struct Params { + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, head_divmod; + int const* tile_count_semaphore; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + return {args.num_blocks_m * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head), + args.tile_count_semaphore}; + } + + static dim3 + get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(num_sm)}; + } + + using WorkTileInfo = StaticPersistentTileScheduler::WorkTileInfo; + // struct WorkTileInfo { + // int tile_idx; + + // CUTLASS_DEVICE + // bool + // is_valid(Params const& params) const { + // return tile_idx < params.total_blocks; + // } + + // CUTLASS_DEVICE + // cute::tuple + // get_block_coord(Params const& params) const { + // int m_block, bidh, bidb; + // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx)); + // return {m_block, bidh, bidb}; + // } + + // }; + + CUTLASS_DEVICE + WorkTileInfo + get_initial_work() const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {current_work.tile_idx + int(gridDim.x)}; + } + +}; + +} // flash diff --git a/fastvideo-kernel/attn_qat_infer/blackwell/utils.h b/fastvideo-kernel/attn_qat_infer/blackwell/utils.h new file mode 100644 index 0000000000..793270196e --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/blackwell/utils.h @@ -0,0 +1,408 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include + +#include +#include +#include +#include + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); + if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); } +} + +__forceinline__ __device__ __half2 half_exp(__half2 x) { + uint32_t tmp_out, tmp_in; + tmp_in = reinterpret_cast(x); + asm ("ex2.approx.f16x2 %0, %1;\n" + : "=r"(tmp_out) + : "r"(tmp_in)); + __half2 out = reinterpret_cast<__half2&>(tmp_out); + return out; +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = Check_inf + ? (max(mi) == -INFINITY ? 0.f : (max(mi) * (Scale_max ? scale : float(M_LOG2E)))) + : (max(mi) * (Scale_max ? scale : float(M_LOG2E))); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float ptx_exp2(float x) { + float y; + asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x)); + return y; +} + +CUTLASS_DEVICE void +packed_float_to_ue4m3( + float const &f0, float const &f1, float const &f2, float const &f3, + uint32_t &out +) { + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(out) : "f"(f0), "f"(f1), "f"(f2), "f"(f3)); +} + +CUTLASS_DEVICE void +packed_float_to_e2m1( + float const &f0, float const &f1, float const &f2, float const& f3, + float const &f4, float const &f5, float const &f6, float const& f7, + uint32_t &out +) { + + asm volatile( \ + "{\n" \ + ".reg .b8 byte0;\n" \ + ".reg .b8 byte1;\n" \ + ".reg .b8 byte2;\n" \ + ".reg .b8 byte3;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" \ + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" \ + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" \ + "}" \ + : "=r"(out) : "f"(f0), "f"(f1), "f"(f2), "f"(f3), + "f"(f4), "f"(f5), "f"(f6), "f"(f7)); + +} + +CUTLASS_DEVICE void +add(float2 & c, + float2 const& a, + float2 const& b) +{ +asm volatile("add.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); +} + +CUTLASS_DEVICE void +add_inplace(float2 &a, + float2 const& b) +{ + asm volatile("add.f32x2 %0, %0, %1;\n" + : "+l"(reinterpret_cast(a)) // a: input/output + : "l"(reinterpret_cast(b)) // b: input + ); +} + + +CUTLASS_DEVICE void +sub(float2 & c, + float2 const& a, + float2 const& b) +{ +asm volatile("sub.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); +} + +CUTLASS_DEVICE void +sub_inplace(float2 &a, + float2 const& b) +{ + asm volatile("sub.f32x2 %0, %0, %1;\n" + : "+l"(reinterpret_cast(a)) // a: input/output + : "l"(reinterpret_cast(b)) // b: input + ); +} + + +CUTLASS_DEVICE void +mul(float2 & c, + float2 const& a, + float2 const& b) +{ + asm volatile("mul.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); +} + +CUTLASS_DEVICE void +fma(float2 & d, + float2 const& a, + float2 const& b, + float2 const& c) +{ + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); +} + +CUTLASS_DEVICE void +fma_inplace(float2 &a, + float2 const& b, + float2 const& c) +{ + asm volatile("fma.rn.f32x2 %0, %0, %1, %2;\n" + : "+l"(reinterpret_cast(a)) + : "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Layout +> +CUTLASS_DEVICE constexpr +auto convert_to_reduction_layout(Layout mma_layout) { + static_assert(rank(mma_layout) == 3, "Mma Layout should be (MmaAtom, MmaM, MmaN)"); + static_assert(rank(get<0>(shape(mma_layout))) == 2, "MmaAtom should be (AtomN, AtomM)"); + + return make_layout( + make_layout(get<0,1>(mma_layout), get<1>(mma_layout)), + make_layout(get<0,0>(mma_layout), get<2>(mma_layout)) + ); +} + +template < + class Tensor +> +CUTLASS_DEVICE constexpr +auto convert_to_reduction_tensor(Tensor mma_tensor) { + return make_tensor(mma_tensor.data(), convert_to_reduction_layout(mma_tensor.layout())); +} + + +template < + class Layout +> +CUTLASS_DEVICE constexpr +auto convert_to_conversion_layout(Layout mma_layout) { + static_assert(rank(mma_layout) == 3, "Mma Layout should be (MmaAtom, MmaM, MmaN)"); + static_assert(rank(get<0>(shape(mma_layout))) == 2, "MmaAtom should be (AtomN, AtomM)"); + + constexpr int MmaAtomN = size<0, 0>(mma_layout); + constexpr int MmaAtomM = size<0, 1>(mma_layout); + constexpr int MmaM = size<1>(mma_layout); + constexpr int MmaN = size<2>(mma_layout); + + static_assert(MmaAtomN == 8, "MmaAtomN should be 8."); + static_assert(MmaAtomM == 2, "MmaAtomM should be 2."); + static_assert(MmaN % 2 == 0, "MmaN should be multiple of 2."); + + auto mma_n_division = zipped_divide( + layout<2>(mma_layout), make_tile(_2{}) + ); + return make_layout( + make_layout(layout<0,0>(mma_layout), make_layout(layout<0,1>(mma_layout), layout<0>(mma_n_division))), + layout<1>(mma_layout), layout<1>(mma_n_division) + ); +} + +template < + class Tensor +> +CUTLASS_DEVICE constexpr +auto convert_to_conversion_tensor(Tensor mma_tensor) { + return make_tensor(mma_tensor.data(), convert_to_conversion_layout(mma_tensor.layout())); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +} // namespace flash diff --git a/fastvideo-kernel/attn_qat_infer/quantization/__init__.py b/fastvideo-kernel/attn_qat_infer/quantization/__init__.py new file mode 100644 index 0000000000..5d46677a1b --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/quantization/__init__.py @@ -0,0 +1 @@ +__version__ = "3.0.0.b1" \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_quant_k.py b/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_quant_k.py new file mode 100644 index 0000000000..94c10681c1 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_quant_k.py @@ -0,0 +1,90 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import torch +import fp4quant +from triton.tools.mxfp import MXFP4Tensor + +from bench_utils import bench_kineto +b = 1 +h = 32 +n = 16384 +d = 128 + +def test(): + q = torch.randn((b, h, n, d), device="cuda", dtype=torch.float16) + o = torch.empty((b, h, n, d // 2), device="cuda", dtype=torch.uint8) + o_s = torch.empty((b, h, n, d // 16), device="cuda", dtype=torch.float8_e4m3fn) + fp4quant.scaled_fp4_quant_permute(q, o, o_s, 1) + +test() + +t = bench_kineto(test, "scaled_fp4_quant_kernel", suppress_kineto_output=True) + +IO = b * h * n * d * 2 + b * h * n * d * 0.5 + b * h * n * d // 16 * 1 +throughput = IO / t * 1e-9 + +print(f"Throughput: {throughput:.2f} GB/s") + +def scale_and_fp4_tensor(x: torch.Tensor, packed_dim: int = 3, all_ones: bool = False, permuted: bool = False): + assert x.is_contiguous() and x.ndim == 4 and x.shape[-1] % 16 == 0 + B, H, M, N = x.shape + x = x.view(B, H, M, N // 16, 16) + scales = (x.abs().amax(dim=-1, keepdim=True) / 6).to(torch.float32) + if all_ones: + scales = torch.ones_like(scales) + x_scaled = x / scales + packed_fp4 = MXFP4Tensor(x_scaled.flatten(start_dim=-2)).to_packed_tensor(dim=packed_dim) + dequant_x = (MXFP4Tensor(x_scaled).to(torch.float32) * scales.to(torch.float8_e4m3fn).to(torch.float32)).flatten(start_dim=-2) + fp8_scale = scales.flatten(start_dim=-2).to(torch.float8_e4m3fn) + permuted_fp8_scale = None + if permuted: + scales = scales.view(B, H // 64, 4, 16, M, N // 16).permute(0, 1, 3, 2, 4, 5).reshape(B, H, M, N // 16) + permuted_fp8_scale = scales.view(B, H // 64, 64, M, N // 64, 4).permute(0, 1, 4, 3, 2, 5).reshape(B, H, M, N // 16).to(torch.float8_e4m3fn) + return fp8_scale, packed_fp4, dequant_x, permuted_fp8_scale + +b = 2 +h = 4 +n = 251 +n_padded = (n + 127) // 128 * 128 +d = 128 + +q = torch.randn(b, h, n, d, dtype=torch.float16, device='cuda') +o = torch.empty((b, h, n, d // 2), dtype=torch.uint8, device='cuda') +o_s = torch.empty((b, h, n, d // 16), dtype=torch.float8_e4m3fn, device='cuda') + +fp4quant.scaled_fp4_quant(q, o, o_s, 1) + +k_permute = [0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, 22, 23, 30, 31] +o_permuted = torch.empty((b, h, n_padded, d // 2), dtype=torch.uint8, device='cuda') +o_s_permuted = torch.empty((b, h, n_padded, d // 16), dtype=torch.float8_e4m3fn, device='cuda') +fp4quant.scaled_fp4_quant_permute(q, o_permuted, o_s_permuted, 1) + +# padding +if n % 128 != 0: + o_permuted_gt = torch.cat([o, torch.zeros((b, h, n_padded - n, d // 2), dtype=torch.uint8, device='cuda')], dim=2) + o_s_permuted_gt = torch.cat([o_s, torch.zeros((b, h, n_padded - n, d // 16), dtype=torch.float8_e4m3fn, device='cuda')], dim=2) +else: + o_permuted_gt = o + o_s_permuted_gt = o_s + +# use scale_and_fp4_tensor + torch permutation to get the ground truth +o_permuted_gt = o_permuted_gt.reshape(b, h, n_padded // 32, 32, d // 2)[:, :, :, k_permute, :].reshape(b, h, n_padded, d // 2) +o_s_permuted_gt = o_s_permuted_gt.reshape(b, h, n_padded // 32, 32, d // 16)[:, :, :, k_permute, :].reshape(b, h, n_padded, d // 16) + +assert((o_permuted - o_permuted_gt).abs().max() == 0) +assert((o_s_permuted.float() - o_s_permuted_gt.float()).abs().max() == 0) + +print("All tests passed!") \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_quant_q.py b/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_quant_q.py new file mode 100644 index 0000000000..ffbab0a7bf --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_quant_q.py @@ -0,0 +1,86 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import torch +import fp4quant +from triton.tools.mxfp import MXFP4Tensor + +from bench_utils import bench_kineto +b = 1 +h = 32 +n = 16384 +d = 128 + +def test(): + q = torch.randn((b, h, n, d), device="cuda", dtype=torch.float16) + o = torch.empty((b, h, n, d // 2), device="cuda", dtype=torch.uint8) + o_s = torch.empty((b, h, n, d // 16), device="cuda", dtype=torch.float8_e4m3fn) + fp4quant.scaled_fp4_quant(q, o, o_s, 1) + +test() + +t = bench_kineto(test, "scaled_fp4_quant_kernel", suppress_kineto_output=True) + +IO = b * h * n * d * 2 + b * h * n * d * 0.5 + b * h * n * d // 16 * 1 +throughput = IO / t * 1e-9 + +print(f"Throughput: {throughput:.2f} GB/s") + +def scale_and_fp4_tensor(x: torch.Tensor, packed_dim: int = 3, all_ones: bool = False, permuted: bool = False): + assert x.is_contiguous() and x.ndim == 4 and x.shape[-1] % 16 == 0 + B, H, M, N = x.shape + x = x.view(B, H, M, N // 16, 16) + scales = (x.abs().amax(dim=-1, keepdim=True) / 6).to(torch.float32) + if all_ones: + scales = torch.ones_like(scales) + x_scaled = x / scales + packed_fp4 = MXFP4Tensor(x_scaled.flatten(start_dim=-2)).to_packed_tensor(dim=packed_dim) + dequant_x = (MXFP4Tensor(x_scaled).to(torch.float32) * scales.to(torch.float8_e4m3fn).to(torch.float32)).flatten(start_dim=-2) + fp8_scale = scales.flatten(start_dim=-2).to(torch.float8_e4m3fn) + permuted_fp8_scale = None + if permuted: + scales = scales.view(B, H // 64, 4, 16, M, N // 16).permute(0, 1, 3, 2, 4, 5).reshape(B, H, M, N // 16) + permuted_fp8_scale = scales.view(B, H // 64, 64, M, N // 64, 4).permute(0, 1, 4, 3, 2, 5).reshape(B, H, M, N // 16).to(torch.float8_e4m3fn) + return fp8_scale, packed_fp4, dequant_x, permuted_fp8_scale + +b = 2 +h = 4 +n = 251 +d = 128 + +q = torch.randn(b, h, n, d, dtype=torch.float16, device='cuda') +o = torch.empty((b, h, n, d // 2), dtype=torch.uint8, device='cuda') +o_s = torch.empty((b, h, n, d // 16), dtype=torch.float8_e4m3fn, device='cuda') + +fp4quant.scaled_fp4_quant(q, o, o_s, 1) + +fp8_scale, packed_fp4, dequant_x, permuted_fp8_scale = scale_and_fp4_tensor(q, packed_dim=3) + +assert((fp8_scale.float() - o_s.float()).abs().max() == 0) + +o_binary = [ + (int(bin_str[:4], 2), int(bin_str[4:], 2)) + for bin_str in [format(x.item(), '08b') for x in o.view(-1)] +] +o_binary_gt = [ + (int(bin_str[:4], 2), int(bin_str[4:], 2)) + for bin_str in [format(x.item(), '08b') for x in packed_fp4.view(-1)] +] +for i in range(len(o_binary)): + # check contiguous 4 bits. Difference should be at most one + assert(abs(o_binary[i][0] - o_binary_gt[i][0]) <= 1) + assert(abs(o_binary[i][1] - o_binary_gt[i][1]) <= 1) + +print("All tests passed!") \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_quant_v.py b/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_quant_v.py new file mode 100644 index 0000000000..24b87b7005 --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_quant_v.py @@ -0,0 +1,86 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import torch +import fp4quant +from triton.tools.mxfp import MXFP4Tensor + +from bench_utils import bench_kineto +b = 1 +h = 32 +n = 16384 +d = 128 + +def test(): + q = torch.randn((b, h, n, d), device="cuda", dtype=torch.float16) + o = torch.empty((b, h, d, n // 2), device="cuda", dtype=torch.uint8) + o_s = torch.empty((b, h, d, n // 16), device="cuda", dtype=torch.float8_e4m3fn) + fp4quant.scaled_fp4_quant_trans(q, o, o_s, 1) + +test() + +t = bench_kineto(test, "scaled_fp4_quant_trans_kernel", suppress_kineto_output=True) + +IO = b * h * n * d * 2 + b * h * n * d * 0.5 + b * h * n * d // 16 * 1 +throughput = IO / t * 1e-9 + +print(f"Throughput: {throughput:.2f} GB/s") + +def scale_and_fp4_tensor(x: torch.Tensor, packed_dim: int = 3, all_ones: bool = False, permuted: bool = False): + assert x.is_contiguous() and x.ndim == 4 and x.shape[-1] % 16 == 0 + B, H, M, N = x.shape + x = x.view(B, H, M, N // 16, 16) + scales = (x.abs().amax(dim=-1, keepdim=True) / 6).to(torch.float32) + if all_ones: + scales = torch.ones_like(scales) + x_scaled = x / scales + packed_fp4 = MXFP4Tensor(x_scaled.flatten(start_dim=-2)).to_packed_tensor(dim=packed_dim) + dequant_x = (MXFP4Tensor(x_scaled).to(torch.float32) * scales.to(torch.float8_e4m3fn).to(torch.float32)).flatten(start_dim=-2) + fp8_scale = scales.flatten(start_dim=-2).to(torch.float8_e4m3fn) + permuted_fp8_scale = None + if permuted: + scales = scales.view(B, H // 64, 4, 16, M, N // 16).permute(0, 1, 3, 2, 4, 5).reshape(B, H, M, N // 16) + permuted_fp8_scale = scales.view(B, H // 64, 64, M, N // 64, 4).permute(0, 1, 4, 3, 2, 5).reshape(B, H, M, N // 16).to(torch.float8_e4m3fn) + return fp8_scale, packed_fp4, dequant_x, permuted_fp8_scale + +b = 2 +h = 4 +n = 491 +n_padded = (n + 127) // 128 * 128 +d = 128 + +q = torch.randn(b, h, n, d, dtype=torch.float16, device='cuda') +o = torch.empty((b, h, d, n_padded // 2), dtype=torch.uint8, device='cuda') +o_s = torch.empty((b, h, d, n_padded // 16), dtype=torch.float8_e4m3fn, device='cuda') + +fp4quant.scaled_fp4_quant_trans(q, o, o_s, 1) + +if n % 128 != 0: + q_padded = torch.cat([q, torch.zeros((b, h, n_padded - n, d), dtype=torch.float16, device='cuda')], dim=2) +else: + q_padded = q + +# use torch transpose + scaled_fp4_quant to get the ground truth +q_padded = q_padded.transpose(2, 3).reshape(b, h, n_padded, d).contiguous() +o_gt = torch.empty((b, h, n_padded, d // 2), dtype=torch.uint8, device='cuda') +o_s_gt = torch.empty((b, h, n_padded, d // 16), dtype=torch.float8_e4m3fn, device='cuda') +fp4quant.scaled_fp4_quant(q_padded, o_gt, o_s_gt, 1) +o_gt = o_gt.reshape(b, h, d, n_padded // 2).contiguous() +o_s_gt = o_s_gt.reshape(b, h, d, n_padded // 16).contiguous() + +assert((o_s_gt.float() - o_s.float()).abs().max() == 0) +assert((o_gt - o).abs().max() == 0) + +print("All tests passed!") \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_utils.py b/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_utils.py new file mode 100644 index 0000000000..a7046de42f --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/quantization/bench/bench_utils.py @@ -0,0 +1,169 @@ +""" +Copyright (c) 2025 by SageAttention team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import sys +import torch +import torch.distributed as dist + + +def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache.zero_() + + # Warmup + for _ in range(num_warmups): + fn() + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: + x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + x @ y + + # Testing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_tests + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, + trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = False): + # Conflict with Nsight Systems + using_nsys = os.environ.get('DG_NSYS_PROFILING', False) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress() + with profiler: + for i in range(2): + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + if barrier_comm_profiling: + lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + lhs @ rhs + dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) + for _ in range(num_tests): + if flush_l2: + torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda').zero_() + fn() + + if not using_nsys: + profiler.step() + + # Return 1 if using Nsight Systems + if using_nsys: + return 1 + + # Parse the profiling table + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tupled = isinstance(kernel_names, tuple) + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + assert all([isinstance(name, str) for name in kernel_names]) + for name in kernel_names: + assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table' + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + for unit, scale in units.items(): + if unit in time_str: + kernel_times.append(float(time_str.replace(unit, '')) / scale) + break + break + return tuple(kernel_times) if is_tupled else kernel_times[0] + + +def calc_diff(x, y): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(tensors): + total = 0 + for t in tensors: + if isinstance(t, tuple): + total += count_bytes(t) + else: + total += t.numel() * t.element_size() + return total \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/quantization/cuda_utils.h b/fastvideo-kernel/attn_qat_infer/quantization/cuda_utils.h new file mode 100644 index 0000000000..89013d69da --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/quantization/cuda_utils.h @@ -0,0 +1,52 @@ +#pragma once + +#include + +#if defined(__HIPCC__) + #define HOST_DEVICE_INLINE __host__ __device__ + #define DEVICE_INLINE __device__ + #define HOST_INLINE __host__ +#elif defined(__CUDACC__) || defined(_NVHPC_CUDA) + #define HOST_DEVICE_INLINE __host__ __device__ __forceinline__ + #define DEVICE_INLINE __device__ __forceinline__ + #define HOST_INLINE __host__ __forceinline__ +#else + #define HOST_DEVICE_INLINE inline + #define DEVICE_INLINE inline + #define HOST_INLINE inline +#endif + +#define CUDA_CHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +int64_t get_device_attribute(int64_t attribute, int64_t device_id) { + static int value = [=]() { + int device = static_cast(device_id); + if (device < 0) { + CUDA_CHECK(cudaGetDevice(&device)); + } + int value; + CUDA_CHECK(cudaDeviceGetAttribute( + &value, static_cast(attribute), device)); + return static_cast(value); + }(); + + return value; +} + +namespace cuda_utils { + +template +HOST_DEVICE_INLINE constexpr std::enable_if_t, T> +ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +}; // namespace cuda_utils \ No newline at end of file diff --git a/fastvideo-kernel/attn_qat_infer/quantization/fp4_quantization_4d.cu b/fastvideo-kernel/attn_qat_infer/quantization/fp4_quantization_4d.cu new file mode 100644 index 0000000000..536a362edb --- /dev/null +++ b/fastvideo-kernel/attn_qat_infer/quantization/fp4_quantization_4d.cu @@ -0,0 +1,629 @@ +/* + * Copyright (c) 2025 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include + #include + #include + #include + #include + #include + #include + + #include + #include + + #include + + #include "cuda_utils.h" + #include "../blackwell/block_config.h" + + #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + if (pytorch_dtype == at::ScalarType::Half) { \ + using c_type = half; \ + __VA_ARGS__ \ + } else if (pytorch_dtype == at::ScalarType::BFloat16) { \ + using c_type = nv_bfloat16; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + } + + #define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + if (head_dim == 64) { \ + constexpr int HEAD_DIM = 64; \ + __VA_ARGS__ \ + } else if (head_dim == 128) { \ + constexpr int HEAD_DIM = 128; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported head dim: " << int(head_dim); \ + throw std::invalid_argument(err_msg.str()); \ + } + + #define CHECK_CUDA(x) \ + TORCH_CHECK(x.is_cuda(), "Tensor " #x " must be on CUDA") + #define CHECK_DTYPE(x, true_dtype) \ + TORCH_CHECK(x.dtype() == true_dtype, \ + "Tensor " #x " must have dtype (" #true_dtype ")") + #define CHECK_DIMS(x, true_dim) \ + TORCH_CHECK(x.dim() == true_dim, \ + "Tensor " #x " must have dimension number (" #true_dim ")") + #define CHECK_SHAPE(x, ...) \ + TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ + "Tensor " #x " must have shape (" #__VA_ARGS__ ")") + #define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), "Tensor " #x " must be contiguous") + #define CHECK_LASTDIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.stride(-1) == 1, \ + "Tensor " #x " must be contiguous at the last dimension") + + constexpr int CVT_FP4_ELTS_PER_THREAD = 16; + + // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). + inline __device__ uint32_t fp32_vec_to_e2m1(float2 *array) { + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; + #else + return 0; + #endif + } + + // Get type2 from type or vice versa (applied to half and bfloat16) + template + struct TypeConverter { + using Type = half2; + }; // keep for generality + + template <> + struct TypeConverter { + using Type = half; + }; + + template <> + struct TypeConverter { + using Type = half2; + }; + + template <> + struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; + }; + + template <> + struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; + }; + + // Define a 32 bytes packed data type. + template + struct PackedVec { + typename TypeConverter::Type elts[8]; + }; + + template + __global__ void scaled_fp4_quant_kernel( + const T* input, uint8_t* output, uint8_t* output_sf, + int batch_size, int num_heads, int num_tokens, + int stride_bz_input, int stride_h_input, int stride_seq_input, + int stride_bz_output, int stride_h_output, int stride_seq_output, + int stride_bz_output_sf, int stride_h_output_sf, int stride_seq_output_sf) { + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 input are supported"); + using PackedVec = PackedVec; + + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + const int token_block_id = blockIdx.x; + + static_assert(CVT_FP4_ELTS_PER_THREAD == 8 || CVT_FP4_ELTS_PER_THREAD == 16, + "CVT_FP4_ELTS_PER_THREAD must be 8 or 16"); + static_assert(sizeof(PackedVec) == sizeof(T) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + constexpr uint32_t NUM_THREADS_PER_TOKEN = head_dim / CVT_FP4_ELTS_PER_THREAD; + + // load input + const int token_id = token_block_id * BLOCK_SIZE + threadIdx.x / NUM_THREADS_PER_TOKEN; + + int load_token_id; + if constexpr (!permute) { + load_token_id = token_id; + } else { + int local_token_id = threadIdx.x / NUM_THREADS_PER_TOKEN; + int local_token_id_residue = local_token_id % 32; + // [0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, 22, 23, 30, 31] + load_token_id = token_block_id * BLOCK_SIZE + (local_token_id / 32) * 32 + + (local_token_id_residue / 8) * 2 + + ((local_token_id_residue % 8) / 2) * 8 + + (local_token_id_residue % 8) % 2; + } + + PackedVec in_vec; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + reinterpret_cast(in_vec.elts[i]) = 0; + } + + if (load_token_id < num_tokens) { + in_vec = reinterpret_cast(input + + batch_id * stride_bz_input + // batch dim + head_id * stride_h_input + // head dim + load_token_id * stride_seq_input + // seq dim + (threadIdx.x % NUM_THREADS_PER_TOKEN) * CVT_FP4_ELTS_PER_THREAD)[0]; // feature dim + } + + // calculate max of every consecutive 16 elements + auto localMax = __habs2(in_vec.elts[0]); + #pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { // local max + localMax = __hmax2(localMax, __habs2(in_vec.elts[i])); + } + + if constexpr (CVT_FP4_ELTS_PER_THREAD == 8) { // shuffle across two threads + localMax = __hmax2(__shfl_xor_sync(0xffffffff, localMax, 1, 32), localMax); + } + + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // scaling factor + float SFValue = vecMax / 6.0f; + uint8_t SFValueFP8; + reinterpret_cast<__nv_fp8_e4m3&>(SFValueFP8) = __nv_fp8_e4m3(SFValue); + SFValue = float(reinterpret_cast<__nv_fp8_e4m3&>(SFValueFP8)); + + float SFValueInv = (SFValue == 0.0f) ? 0.0f : 1.0f / SFValue; + + // convert input to float2 and apply scale + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same::value) { + fp2Vals[i] = __half22float2(in_vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(in_vec.elts[i]); + } + fp2Vals[i].x = fp2Vals[i].x * SFValueInv; + fp2Vals[i].y = fp2Vals[i].y * SFValueInv; + } + + // convert to e2m1 + uint32_t e2m1Vals[CVT_FP4_ELTS_PER_THREAD / 8]; + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 8; i++) { + e2m1Vals[i] = fp32_vec_to_e2m1(fp2Vals + i * 4); + } + + // save, do not check range + if constexpr (CVT_FP4_ELTS_PER_THREAD == 8) { + reinterpret_cast(output + + batch_id * stride_bz_output + + head_id * stride_h_output + + token_id * stride_seq_output + + (threadIdx.x % NUM_THREADS_PER_TOKEN) * CVT_FP4_ELTS_PER_THREAD / 2)[0] = e2m1Vals[0]; + } else { + reinterpret_cast(output + + batch_id * stride_bz_output + + head_id * stride_h_output + + token_id * stride_seq_output + + (threadIdx.x % NUM_THREADS_PER_TOKEN) * CVT_FP4_ELTS_PER_THREAD / 2)[0] = reinterpret_cast(e2m1Vals)[0]; + } + + uint8_t* output_sf_save_base = output_sf + batch_id * stride_bz_output_sf + head_id * stride_h_output_sf + (token_id / 64) * 64 * stride_seq_output_sf; + uint32_t token_id_local = token_id % 64; + + if constexpr (CVT_FP4_ELTS_PER_THREAD == 16) { + uint32_t col_id_local = threadIdx.x % NUM_THREADS_PER_TOKEN; + uint32_t offset_local = (col_id_local / 4) * 256 + (col_id_local % 4) + + (token_id_local / 16) * 4 + (token_id_local % 16) * 16; + reinterpret_cast(output_sf_save_base + offset_local)[0] = SFValueFP8; + } else { + if (threadIdx.x % 2 == 0) { + uint32_t col_id_local = (threadIdx.x % NUM_THREADS_PER_TOKEN) / 2; + uint32_t offset_local = (col_id_local / 4) * 256 + (col_id_local % 4) + + (token_id_local / 16) * 4 + (token_id_local % 16) * 16; + reinterpret_cast(output_sf_save_base + offset_local)[0] = SFValueFP8; + } + } + } + + template + __global__ void scaled_fp4_quant_trans_kernel( + const T* input, uint8_t* output, uint8_t* output_sf, + int batch_size, int num_heads, int num_tokens, + int stride_bz_input, int stride_h_input, int stride_seq_input, + int stride_bz_output, int stride_h_output, int stride_d_output, + int stride_bz_output_sf, int stride_h_output_sf, int stride_d_output_sf) { + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 input are supported"); + using PackedVec = PackedVec; + + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + const int token_block_id = blockIdx.x; + + static_assert(CVT_FP4_ELTS_PER_THREAD == 8 || CVT_FP4_ELTS_PER_THREAD == 16, + "CVT_FP4_ELTS_PER_THREAD must be 8 or 16"); + static_assert(sizeof(PackedVec) == sizeof(T) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + constexpr uint32_t NUM_THREADS_PER_TOKEN = head_dim / CVT_FP4_ELTS_PER_THREAD; + constexpr uint32_t NUM_THREADS_PER_SEQ = BLOCK_SIZE / CVT_FP4_ELTS_PER_THREAD; + + // load input + const int token_id = token_block_id * BLOCK_SIZE + threadIdx.x / NUM_THREADS_PER_TOKEN; + + PackedVec in_vec; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + reinterpret_cast(in_vec.elts[i]) = 0; + } + + if (token_id < num_tokens) { + in_vec = reinterpret_cast(input + + batch_id * stride_bz_input + // batch dim + head_id * stride_h_input + // head dim + token_id * stride_seq_input + // seq dim + (threadIdx.x % NUM_THREADS_PER_TOKEN) * CVT_FP4_ELTS_PER_THREAD)[0]; // feature dim + } + + // transpose + __shared__ T shared_input[BLOCK_SIZE * head_dim]; + reinterpret_cast(shared_input)[threadIdx.x] = in_vec; + __syncthreads(); + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + in_vec.elts[i].x = shared_input[(threadIdx.x / NUM_THREADS_PER_SEQ) + ((threadIdx.x % NUM_THREADS_PER_SEQ) * CVT_FP4_ELTS_PER_THREAD + 2 * i) * head_dim]; + in_vec.elts[i].y = shared_input[(threadIdx.x / NUM_THREADS_PER_SEQ) + ((threadIdx.x % NUM_THREADS_PER_SEQ) * CVT_FP4_ELTS_PER_THREAD + 2 * i + 1) * head_dim]; + } + + // calculate max of every consecutive 16 elements + auto localMax = __habs2(in_vec.elts[0]); + #pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { // local max + localMax = __hmax2(localMax, __habs2(in_vec.elts[i])); + } + + if constexpr (CVT_FP4_ELTS_PER_THREAD == 8) { // shuffle across two threads + localMax = __hmax2(__shfl_xor_sync(0xffffffff, localMax, 1, 32), localMax); + } + + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // scaling factor + float SFValue = vecMax / 6.0f; + uint8_t SFValueFP8; + reinterpret_cast<__nv_fp8_e4m3&>(SFValueFP8) = __nv_fp8_e4m3(SFValue); + SFValue = float(reinterpret_cast<__nv_fp8_e4m3&>(SFValueFP8)); + + float SFValueInv = (SFValue == 0.0f) ? 0.0f : 1.0f / SFValue; + + // convert input to float2 and apply scale + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same::value) { + fp2Vals[i] = __half22float2(in_vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(in_vec.elts[i]); + } + fp2Vals[i].x = fp2Vals[i].x * SFValueInv; + fp2Vals[i].y = fp2Vals[i].y * SFValueInv; + } + + // convert to e2m1 + uint32_t e2m1Vals[CVT_FP4_ELTS_PER_THREAD / 8]; + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 8; i++) { + e2m1Vals[i] = fp32_vec_to_e2m1(fp2Vals + i * 4); + } + + // save + if constexpr (CVT_FP4_ELTS_PER_THREAD == 8) { + reinterpret_cast(output + + batch_id * stride_bz_output + + head_id * stride_h_output + + (threadIdx.x / NUM_THREADS_PER_SEQ) * stride_d_output + + (token_block_id * BLOCK_SIZE + (threadIdx.x % NUM_THREADS_PER_SEQ) * CVT_FP4_ELTS_PER_THREAD) / 2)[0] = e2m1Vals[0]; + } else { + reinterpret_cast(output + + batch_id * stride_bz_output + + head_id * stride_h_output + + (threadIdx.x / NUM_THREADS_PER_SEQ) * stride_d_output + + (token_block_id * BLOCK_SIZE + (threadIdx.x % NUM_THREADS_PER_SEQ) * CVT_FP4_ELTS_PER_THREAD) / 2)[0] = reinterpret_cast(e2m1Vals)[0]; + } + + uint8_t *output_sf_save_base = output_sf + + batch_id * stride_bz_output_sf + + head_id * stride_h_output_sf + + (threadIdx.x / NUM_THREADS_PER_SEQ / 64) * 64 * stride_d_output_sf; + uint32_t row_id_local = (threadIdx.x / NUM_THREADS_PER_SEQ) % 64; + + if constexpr (CVT_FP4_ELTS_PER_THREAD == 16) { + uint32_t col_id_local = token_block_id * BLOCK_SIZE / CVT_FP4_ELTS_PER_THREAD + threadIdx.x % NUM_THREADS_PER_SEQ; + uint32_t offset_local = (col_id_local / 4) * 256 + (col_id_local % 4) + + (row_id_local / 16) * 4 + (row_id_local % 16) * 16; + reinterpret_cast(output_sf_save_base + offset_local)[0] = SFValueFP8; + } else { + if (threadIdx.x % 2 == 0) { + uint32_t col_id_local = token_block_id * BLOCK_SIZE / CVT_FP4_ELTS_PER_THREAD + (threadIdx.x % NUM_THREADS_PER_SEQ) / 2; + uint32_t offset_local = (col_id_local / 4) * 256 + (col_id_local % 4) + + (row_id_local / 16) * 4 + (row_id_local % 16) * 16; + reinterpret_cast(output_sf_save_base + offset_local)[0] = SFValueFP8; + } + } + } + + void scaled_fp4_quant(torch::Tensor const& input, + torch::Tensor const& output, + torch::Tensor const& output_sf, + int tensor_layout) { + constexpr int BLOCK_SIZE = flash::BLOCK_M; + + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(output_sf); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_LASTDIM_CONTIGUOUS(output_sf); + + CHECK_DTYPE(output, at::ScalarType::Byte); + CHECK_DTYPE(output_sf, at::ScalarType::Float8_e4m3fn); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(output_sf, 4); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + const int stride_bz_input = input.stride(0); + const int stride_bz_output = output.stride(0); + const int stride_bz_output_sf = output_sf.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_seq_output, stride_seq_output_sf; + int stride_h_input, stride_h_output, stride_h_output_sf; + if (tensor_layout == 0) { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_seq_output = output.stride(1); + stride_seq_output_sf = output_sf.stride(1); + stride_h_input = input.stride(2); + stride_h_output = output.stride(2); + stride_h_output_sf = output_sf.stride(2); + + CHECK_SHAPE(output, batch_size, num_tokens, num_heads, head_dim / 2); + CHECK_SHAPE(output_sf, batch_size, num_tokens, num_heads, head_dim / 16); + } else { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_seq_output = output.stride(2); + stride_seq_output_sf = output_sf.stride(2); + stride_h_input = input.stride(1); + stride_h_output = output.stride(1); + stride_h_output_sf = output_sf.stride(1); + + CHECK_SHAPE(output, batch_size, num_heads, num_tokens, head_dim / 2); + CHECK_SHAPE(output_sf, batch_size, num_heads, num_tokens, head_dim / 16); + } + + auto input_dtype = input.scalar_type(); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + dim3 block(BLOCK_SIZE * HEAD_DIM / CVT_FP4_ELTS_PER_THREAD, 1, 1); + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, batch_size, num_heads); + + scaled_fp4_quant_kernel + <<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(output_sf.data_ptr()), + batch_size, num_heads, num_tokens, + stride_bz_input, stride_h_input, stride_seq_input, + stride_bz_output, stride_h_output, stride_seq_output, + stride_bz_output_sf, stride_h_output_sf, stride_seq_output_sf); + }); + }); + } + + void scaled_fp4_quant_permute(torch::Tensor const& input, + torch::Tensor const& output, + torch::Tensor const& output_sf, + int tensor_layout) { + constexpr int BLOCK_SIZE = flash::BLOCK_M; + + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(output_sf); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_LASTDIM_CONTIGUOUS(output_sf); + + CHECK_DTYPE(output, at::ScalarType::Byte); + CHECK_DTYPE(output_sf, at::ScalarType::Float8_e4m3fn); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(output_sf, 4); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + const int stride_bz_input = input.stride(0); + const int stride_bz_output = output.stride(0); + const int stride_bz_output_sf = output_sf.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_seq_output, stride_seq_output_sf; + int stride_h_input, stride_h_output, stride_h_output_sf; + if (tensor_layout == 0) { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_seq_output = output.stride(1); + stride_seq_output_sf = output_sf.stride(1); + stride_h_input = input.stride(2); + stride_h_output = output.stride(2); + stride_h_output_sf = output_sf.stride(2); + + CHECK_SHAPE(output, batch_size, ((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE, num_heads, head_dim / 2); + CHECK_SHAPE(output_sf, batch_size, ((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE, num_heads, head_dim / 16); + } else { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_seq_output = output.stride(2); + stride_seq_output_sf = output_sf.stride(2); + stride_h_input = input.stride(1); + stride_h_output = output.stride(1); + stride_h_output_sf = output_sf.stride(1); + + CHECK_SHAPE(output, batch_size, num_heads, ((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE, head_dim / 2); + CHECK_SHAPE(output_sf, batch_size, num_heads, ((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE, head_dim / 16); + } + + auto input_dtype = input.scalar_type(); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr int BLOCK_SIZE = flash::BLOCK_M; + dim3 block(BLOCK_SIZE * HEAD_DIM / CVT_FP4_ELTS_PER_THREAD, 1, 1); + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, batch_size, num_heads); + + scaled_fp4_quant_kernel + <<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(output_sf.data_ptr()), + batch_size, num_heads, num_tokens, + stride_bz_input, stride_h_input, stride_seq_input, + stride_bz_output, stride_h_output, stride_seq_output, + stride_bz_output_sf, stride_h_output_sf, stride_seq_output_sf); + }); + }); + } + + void scaled_fp4_quant_trans(torch::Tensor const& input, + torch::Tensor const& output, + torch::Tensor const& output_sf, + int tensor_layout) { + constexpr int BLOCK_SIZE = flash::BLOCK_M; + + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(output_sf); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_LASTDIM_CONTIGUOUS(output_sf); + + CHECK_DTYPE(output, at::ScalarType::Byte); + CHECK_DTYPE(output_sf, at::ScalarType::Float8_e4m3fn); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(output_sf, 4); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + const int stride_bz_input = input.stride(0); + const int stride_bz_output = output.stride(0); + const int stride_bz_output_sf = output_sf.stride(0); + + int num_tokens, num_heads; + int stride_seq_input; + int stride_d_output, stride_d_output_sf; + int stride_h_input, stride_h_output, stride_h_output_sf; + if (tensor_layout == 0) { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_d_output = output.stride(1); + stride_d_output_sf = output_sf.stride(1); + stride_h_input = input.stride(2); + stride_h_output = output.stride(2); + stride_h_output_sf = output_sf.stride(2); + + CHECK_SHAPE(output, batch_size, head_dim, num_heads, ((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE / 2); + CHECK_SHAPE(output_sf, batch_size, head_dim, num_heads, ((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE / 16); + } else { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_d_output = output.stride(2); + stride_d_output_sf = output_sf.stride(2); + stride_h_input = input.stride(1); + stride_h_output = output.stride(1); + stride_h_output_sf = output_sf.stride(1); + + CHECK_SHAPE(output, batch_size, num_heads, head_dim, ((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE / 2); + CHECK_SHAPE(output_sf, batch_size, num_heads, head_dim, ((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE / 16); + } + + auto input_dtype = input.scalar_type(); + auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + dim3 block(BLOCK_SIZE * HEAD_DIM / CVT_FP4_ELTS_PER_THREAD, 1, 1); + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, batch_size, num_heads); + + scaled_fp4_quant_trans_kernel + <<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(output_sf.data_ptr()), + batch_size, num_heads, num_tokens, + stride_bz_input, stride_h_input, stride_seq_input, + stride_bz_output, stride_h_output, stride_d_output, + stride_bz_output_sf, stride_h_output_sf, stride_d_output_sf); + }); + }); + } + + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("scaled_fp4_quant", &scaled_fp4_quant); + m.def("scaled_fp4_quant_permute", &scaled_fp4_quant_permute); + m.def("scaled_fp4_quant_trans", &scaled_fp4_quant_trans); + } \ No newline at end of file diff --git a/fastvideo-kernel/benchmarks/_bootstrap.py b/fastvideo-kernel/benchmarks/_bootstrap.py new file mode 100644 index 0000000000..50a5d146f9 --- /dev/null +++ b/fastvideo-kernel/benchmarks/_bootstrap.py @@ -0,0 +1,14 @@ +"""Make local benchmark scripts runnable from common repo entrypoints.""" + +from pathlib import Path +import sys + + +BENCHMARKS_DIR = Path(__file__).resolve().parent +KERNEL_ROOT = BENCHMARKS_DIR.parent +REPO_ROOT = KERNEL_ROOT.parent + +for path in (KERNEL_ROOT, REPO_ROOT): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) diff --git a/fastvideo-kernel/benchmarks/benchmark_blockscaled_fp4_attn.py b/fastvideo-kernel/benchmarks/benchmark_blockscaled_fp4_attn.py new file mode 100644 index 0000000000..9655a5301d --- /dev/null +++ b/fastvideo-kernel/benchmarks/benchmark_blockscaled_fp4_attn.py @@ -0,0 +1,287 @@ +import sys +import traceback + +import _bootstrap # noqa: F401 +import torch + +from attn_qat_infer.api import ( + blockscaled_fp4_attn, + preprocess_qkv, + scale_and_quant_fp4, + scale_and_quant_fp4_permute, + scale_and_quant_fp4_transpose, +) +from attn_qat_infer.quantization.bench.bench_utils import bench + + +def calculate_attention_flops(batch_size, num_heads, seq_len_q, seq_len_k, head_dim, is_causal=False): + """Calculate FLOPs for attention (FlashAttention standard - matmuls only).""" + f = 4 * batch_size * num_heads * seq_len_q * seq_len_k * head_dim + if is_causal: + f = f // 2 + return f + + +def benchmark_blockscaled_fp4_attn(batch_size, num_heads, seq_len, head_dim, + is_causal=False, dtype=torch.bfloat16, + per_block_mean=True, single_level_p_quant=False, + num_warmups=100, num_tests=1000): + """ + Benchmark blockscaled_fp4_attn function (excluding quantization overhead). + + This benchmarks ONLY the core FP4 attention kernel, with pre-quantized inputs. + The quantization step is performed once before benchmarking. + + Args: + batch_size: Batch size + num_heads: Number of attention heads + seq_len: Sequence length (same for Q, K, V) + head_dim: Head dimension + is_causal: Whether to use causal masking + dtype: Data type (torch.bfloat16 or torch.float16) + per_block_mean: Whether to use per-block mean for Q smoothing + single_level_p_quant: If True, use single-level quantization for P matrix + num_warmups: Number of warmup iterations + num_tests: Number of test iterations + + Returns: + dict with performance metrics + """ + device = 'cuda' + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. This benchmark requires a CUDA device.") + + # Create input tensors + q = torch.randn(batch_size, num_heads, seq_len, head_dim, + device=device, dtype=dtype) + k = torch.randn(batch_size, num_heads, seq_len, head_dim, + device=device, dtype=dtype) + v = torch.randn(batch_size, num_heads, seq_len, head_dim, + device=device, dtype=dtype) + + # Pre-process and quantize inputs (done once, not included in benchmark) + is_bf16 = dtype == torch.bfloat16 + KL = k.size(2) + q_processed, k_processed, v_processed, delta_s = preprocess_qkv(q, k, v, per_block_mean) + qlist = scale_and_quant_fp4(q_processed) + klist = scale_and_quant_fp4_permute(k_processed) + vlist = scale_and_quant_fp4_transpose(v_processed) + + # Synchronize to ensure quantization is complete + torch.cuda.synchronize() + + # Create closure for benchmarking (only the attention kernel) + def run_attention(): + return blockscaled_fp4_attn( + qlist, klist, vlist, + delta_s, + KL, + is_causal=is_causal, + per_block_mean=per_block_mean, + is_bf16=is_bf16, + single_level_p_quant=single_level_p_quant + ) + + # Benchmark using the bench utility (handles warmup and timing) + avg_time_ms = bench(run_attention, num_warmups=num_warmups, num_tests=num_tests) + avg_time_s = avg_time_ms / 1000.0 + + # Calculate FLOPs (use processed sequence length after padding) + processed_seq_len = q_processed.size(2) + total_flops = calculate_attention_flops( + batch_size, num_heads, processed_seq_len, processed_seq_len, head_dim, is_causal + ) + + # Calculate TFLOPs + tflops = total_flops / (avg_time_s * 1e12) + + # Calculate throughput (tokens/sec) - use original seq_len for meaningful metric + tokens_per_second = (batch_size * seq_len) / avg_time_s + + return { + 'batch_size': batch_size, + 'num_heads': num_heads, + 'seq_len': seq_len, + 'processed_seq_len': processed_seq_len, + 'head_dim': head_dim, + 'is_causal': is_causal, + 'dtype': str(dtype), + 'per_block_mean': per_block_mean, + 'single_level_p_quant': single_level_p_quant, + 'avg_time_ms': avg_time_ms, + 'avg_time_s': avg_time_s, + 'total_flops': total_flops, + 'tflops': tflops, + 'tokens_per_second': tokens_per_second, + } + + +def print_results(results): + """Print benchmark results in a formatted table.""" + print("\n" + "="*100) + print("blockscaled_fp4_attn Benchmark Results (Kernel Only, No Quantization)") + print("="*100) + print(f"Configuration:") + print(f" Batch Size: {results['batch_size']}") + print(f" Num Heads: {results['num_heads']}") + print(f" Sequence Length: {results['seq_len']}") + print(f" Processed Seq Length: {results['processed_seq_len']} (after padding)") + print(f" Head Dimension: {results['head_dim']}") + print(f" Causal: {results['is_causal']}") + print(f" Data Type: {results['dtype']}") + print(f" Per Block Mean: {results['per_block_mean']}") + print(f" Single Level P Quant: {results['single_level_p_quant']}") + print(f"\nPerformance:") + print(f" Average Time: {results['avg_time_ms']:.3f} ms") + print(f" Total FLOPs: {results['total_flops']/1e12:.4f} TFLOPs (theoretical)") + print(f" Throughput: {results['tflops']:.4f} TFLOPs/s") + print(f" Tokens/sec: {results['tokens_per_second']:,.0f}") + print("="*100 + "\n") + sys.stdout.flush() + + +def run_benchmark_suite(): + """Run a comprehensive benchmark suite with various configurations.""" + print("Starting blockscaled_fp4_attn Benchmark Suite (Kernel Only)...") + print(f"CUDA Device: {torch.cuda.get_device_name(0)}") + print(f"CUDA Version: {torch.version.cuda}") + print(f"PyTorch Version: {torch.__version__}\n") + print("Note: This benchmark measures only the FP4 attention kernel,") + print(" excluding quantization overhead.\n") + sys.stdout.flush() + + # Default configurations to test + # (batch_size, num_heads, seq_len, head_dim, is_causal, dtype) + configs = [ + (1, 16, 512, 64, False, torch.bfloat16), + (1, 16, 1024, 64, False, torch.bfloat16), + (1, 16, 2048, 64, False, torch.bfloat16), + (1, 16, 4096, 64, False, torch.bfloat16), + (1, 16, 8192, 64, False, torch.bfloat16), + (1, 16, 16384, 64, False, torch.bfloat16), + + (1, 16, 512, 128, False, torch.bfloat16), + (1, 16, 1024, 128, False, torch.bfloat16), + (1, 16, 2048, 128, False, torch.bfloat16), + (1, 16, 4096, 128, False, torch.bfloat16), + (1, 16, 8192, 128, False, torch.bfloat16), + (1, 16, 16384, 128, False, torch.bfloat16), + + (1, 32, 512, 64, False, torch.bfloat16), + (1, 32, 1024, 64, False, torch.bfloat16), + (1, 32, 2048, 64, False, torch.bfloat16), + (1, 32, 4096, 64, False, torch.bfloat16), + (1, 32, 8192, 64, False, torch.bfloat16), + (1, 32, 16384, 64, False, torch.bfloat16), + + (1, 32, 512, 128, False, torch.bfloat16), + (1, 32, 1024, 128, False, torch.bfloat16), + (1, 32, 2048, 128, False, torch.bfloat16), + (1, 32, 4096, 128, False, torch.bfloat16), + (1, 32, 8192, 128, False, torch.bfloat16), + (1, 32, 16384, 128, False, torch.bfloat16), + ] + + all_results = [] + + for config in configs: + batch_size, num_heads, seq_len, head_dim, is_causal, dtype = config + + print(f"\nBenchmarking: B={batch_size}, H={num_heads}, L={seq_len}, D={head_dim}, " + f"Causal={is_causal}, dtype={dtype}...") + sys.stdout.flush() + + try: + results = benchmark_blockscaled_fp4_attn( + batch_size=batch_size, + num_heads=num_heads, + seq_len=seq_len, + head_dim=head_dim, + is_causal=is_causal, + dtype=dtype, + num_warmups=10, + num_tests=50 + ) + + print_results(results) + all_results.append(results) + + except Exception as e: + print(f"Error benchmarking configuration {config}:") + print(f" Exception: {e}") + traceback.print_exc() + sys.stdout.flush() + continue + + # Print summary table + print("\n" + "="*120) + print("Summary Table (blockscaled_fp4_attn Kernel Only)") + print("="*120) + print(f"{'B':<4} {'H':<4} {'L':<6} {'D':<4} {'Causal':<7} {'Time (ms)':<12} {'TFLOPs/s':<12} {'Tokens/s':<15}") + print("-"*120) + + for r in all_results: + print(f"{r['batch_size']:<4} {r['num_heads']:<4} {r['seq_len']:<6} {r['head_dim']:<4} " + f"{str(r['is_causal']):<7} {r['avg_time_ms']:<12.3f} {r['tflops']:<12.4f} " + f"{r['tokens_per_second']:<15,.0f}") + + print("="*120 + "\n") + sys.stdout.flush() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='Benchmark blockscaled_fp4_attn kernel (excluding quantization)') + parser.add_argument('--batch-size', type=int, default=None, help='Batch size') + parser.add_argument('--num-heads', type=int, default=None, help='Number of attention heads') + parser.add_argument('--seq-len', type=int, default=None, help='Sequence length') + parser.add_argument('--head-dim', type=int, default=None, help='Head dimension') + parser.add_argument('--causal', action='store_true', help='Use causal attention') + parser.add_argument('--dtype', type=str, default='bfloat16', choices=['bfloat16', 'float16'], + help='Data type') + parser.add_argument('--per-block-mean', action='store_true', default=True, + help='Use per-block mean for Q smoothing (default: True)') + parser.add_argument('--no-per-block-mean', action='store_false', dest='per_block_mean', + help='Disable per-block mean for Q smoothing') + parser.add_argument('--single-level-p-quant', action='store_true', default=False, + help='Use single-level P quantization (default: False)') + parser.add_argument('--two-level-p-quant', action='store_false', dest='single_level_p_quant', + help='Use two-level P quantization') + parser.add_argument('--num-warmups', type=int, default=10, help='Number of warmup iterations') + parser.add_argument('--num-tests', type=int, default=50, help='Number of test iterations') + parser.add_argument('--suite', action='store_true', help='Run full benchmark suite') + + args = parser.parse_args() + + dtype_map = { + 'bfloat16': torch.bfloat16, + 'float16': torch.float16 + } + + if args.suite: + run_benchmark_suite() + elif args.batch_size and args.num_heads and args.seq_len and args.head_dim: + results = benchmark_blockscaled_fp4_attn( + batch_size=args.batch_size, + num_heads=args.num_heads, + seq_len=args.seq_len, + head_dim=args.head_dim, + is_causal=args.causal, + dtype=dtype_map[args.dtype], + per_block_mean=args.per_block_mean, + single_level_p_quant=args.single_level_p_quant, + num_warmups=args.num_warmups, + num_tests=args.num_tests + ) + print_results(results) + else: + print("Running default benchmark suite. Use --suite for full suite or provide all parameters.") + print( + "Example: python benchmarks/benchmark_blockscaled_fp4_attn.py " + "--batch-size 1 --num-heads 16 --seq-len 4096 --head-dim 128" + ) + print("\nNote: This benchmark measures only the FP4 attention kernel,") + print(" excluding quantization overhead.") + sys.stdout.flush() + run_benchmark_suite() diff --git a/fastvideo-kernel/benchmarks/benchmark_combined.py b/fastvideo-kernel/benchmarks/benchmark_combined.py new file mode 100644 index 0000000000..f83b54243b --- /dev/null +++ b/fastvideo-kernel/benchmarks/benchmark_combined.py @@ -0,0 +1,380 @@ +import argparse +import sys +from typing import Dict, List, Optional + +import _bootstrap # noqa: F401 +import matplotlib +matplotlib.use('Agg') # Use non-interactive backend for server environments +import matplotlib.pyplot as plt +import numpy as np +import torch + +from flash_attn import flash_attn_func +from attn_qat_infer.quantization.bench.bench_utils import bench + +# Import SageAttn components for direct control +from attn_qat_infer.api import ( + preprocess_qkv, + scale_and_quant_fp4, + scale_and_quant_fp4_permute, + scale_and_quant_fp4_transpose, + blockscaled_fp4_attn +) + + +def calculate_attention_flops(batch_size, num_heads, seq_len_q, seq_len_k, head_dim, is_causal=False): + """Calculate FLOPs for attention (FlashAttention standard - matmuls only).""" + f = 4 * batch_size * num_heads * seq_len_q * seq_len_k * head_dim + if is_causal: + f = f // 2 + return f + + +def sageattn_blackwell_configurable(q, k, v, is_causal=False, per_block_mean=True, + single_level_p_quant=True, + enable_smoothing_q=False, enable_smoothing_k=False): + """ + Configurable SageAttention3 Blackwell kernel with explicit smoothing control. + + Args: + q: Query tensor [B, H, L, D] + k: Key tensor [B, H, L, D] + v: Value tensor [B, H, L, D] + is_causal: Whether to use causal masking + per_block_mean: Whether to use per-block mean for Q smoothing + single_level_p_quant: If True, use single-level quantization for P matrix + enable_smoothing_q: Enable Q smoothing + enable_smoothing_k: Enable K smoothing + + Returns: + Output tensor [B, H, L, D] + """ + QL = q.size(2) + KL = k.size(2) + is_bf16 = q.dtype == torch.bfloat16 + + # Preprocess with explicit smoothing control + q, k, v, delta_s = preprocess_qkv(q, k, v, per_block_mean, enable_smoothing_q, enable_smoothing_k) + + qlist_from_cuda = scale_and_quant_fp4(q) + klist_from_cuda = scale_and_quant_fp4_permute(k) + vlist_from_cuda = scale_and_quant_fp4_transpose(v) + + o_fp4 = blockscaled_fp4_attn( + qlist_from_cuda, + klist_from_cuda, + vlist_from_cuda, + delta_s, + KL, + is_causal, + per_block_mean, + is_bf16, + single_level_p_quant + )[0][:, :, :QL, :].contiguous() + + return o_fp4 + + +def benchmark_flashattn2(batch_size, num_heads, seq_len, head_dim, + is_causal=False, dtype=torch.bfloat16, + num_warmups=10, num_tests=50): + """Benchmark FlashAttention2.""" + device = 'cuda' + + # FlashAttention2 expects (batch, seq_len, num_heads, head_dim) + q = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) + + def run_attention(): + return flash_attn_func(q, k, v, causal=is_causal) + + avg_time_ms = bench(run_attention, num_warmups=num_warmups, num_tests=num_tests) + return avg_time_ms + + +def benchmark_sageattn3(batch_size, num_heads, seq_len, head_dim, + is_causal=False, dtype=torch.bfloat16, + per_block_mean=True, single_level_p_quant=False, + enable_smoothing_q=True, enable_smoothing_k=True, + num_warmups=10, num_tests=50): + """Benchmark SageAttention3 with configurable smoothing.""" + device = 'cuda' + + # SageAttn expects (batch, num_heads, seq_len, head_dim) + q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype) + k = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype) + + def run_attention(): + return sageattn_blackwell_configurable( + q, k, v, + is_causal=is_causal, + per_block_mean=per_block_mean, + single_level_p_quant=single_level_p_quant, + enable_smoothing_q=enable_smoothing_q, + enable_smoothing_k=enable_smoothing_k + ) + + avg_time_ms = bench(run_attention, num_warmups=num_warmups, num_tests=num_tests) + return avg_time_ms + + +def time_to_tflops(time_ms, batch_size, num_heads, seq_len, head_dim, is_causal=False): + """Convert time to TFLOPs (Tera FLOPs per Second).""" + total_flops = calculate_attention_flops(batch_size, num_heads, seq_len, seq_len, head_dim, is_causal) + time_s = time_ms / 1000.0 + tflops = total_flops / (time_s * 1e12) + return tflops + + +def run_benchmark_suite(head_dim=64, is_causal=False, num_heads=12, batch_size=1, + num_warmups=10, num_tests=50, + seq_lens=None, output_file="benchmark_attention.png"): + """ + Run comprehensive benchmark suite and generate plot. + + Args: + head_dim: Head dimension (64 or 128) + is_causal: Whether to use causal attention + num_heads: Number of attention heads + batch_size: Batch size + num_warmups: Number of warmup iterations + num_tests: Number of test iterations + seq_lens: List of sequence lengths to test + output_file: Output plot filename + """ + if seq_lens is None: + seq_lens = [1024, 2048, 4096, 8192, 16384, 32768] + + device_name = torch.cuda.get_device_name(0) + # Extract short name (e.g., "RTX5090" from full name) + short_name = device_name.split()[-1] if 'RTX' in device_name or 'A100' in device_name else device_name[:20] + + print(f"Starting Combined Attention Benchmark Suite...") + print(f"CUDA Device: {device_name}") + print(f"CUDA Version: {torch.version.cuda}") + print(f"PyTorch Version: {torch.__version__}") + print(f"Head Dim: {head_dim}, Causal: {is_causal}, Num Heads: {num_heads}, Batch Size: {batch_size}") + print("="*80) + sys.stdout.flush() + + # Results storage: {method_name: {seq_len: tflops}} + results: Dict[str, Dict[int, Optional[float]]] = { + 'FlashAttn': {}, + 'SageAttn3': {}, + 'FP4': {}, + } + + dtype = torch.bfloat16 + + for seq_len in seq_lens: + print(f"\n--- Sequence Length: {seq_len} ---") + sys.stdout.flush() + + # FlashAttention2 + print(f" Benchmarking FlashAttn2...", end=" ") + sys.stdout.flush() + try: + time_ms = benchmark_flashattn2( + batch_size, num_heads, seq_len, head_dim, + is_causal=is_causal, dtype=dtype, + num_warmups=num_warmups, num_tests=num_tests + ) + tflops = time_to_tflops(time_ms, batch_size, num_heads, seq_len, head_dim, is_causal) + results['FlashAttn'][seq_len] = tflops + print(f"{tflops:.0f} TFLOPs ({time_ms:.3f} ms)") + except Exception as e: + print(f"OOM or Error: {e}") + results['FlashAttn'][seq_len] = None + sys.stdout.flush() + + # SageAttn3 (with smoothing: single_level_p_quant=False, enable_smoothing_q=True, enable_smoothing_k=True) + print(f" Benchmarking SageAttn3 (smoothing ON)...", end=" ") + sys.stdout.flush() + try: + time_ms = benchmark_sageattn3( + batch_size, num_heads, seq_len, head_dim, + is_causal=is_causal, dtype=dtype, + per_block_mean=True, + single_level_p_quant=False, + enable_smoothing_q=True, + enable_smoothing_k=True, + num_warmups=num_warmups, num_tests=num_tests + ) + tflops = time_to_tflops(time_ms, batch_size, num_heads, seq_len, head_dim, is_causal) + results['SageAttn3'][seq_len] = tflops + print(f"{tflops:.0f} TFLOPs ({time_ms:.3f} ms)") + except Exception as e: + print(f"OOM or Error: {e}") + results['SageAttn3'][seq_len] = None + sys.stdout.flush() + + # FP4 (no smoothing: single_level_p_quant=True, enable_smoothing_q=False, enable_smoothing_k=False) + print(f" Benchmarking FP4 (smoothing OFF)...", end=" ") + sys.stdout.flush() + try: + time_ms = benchmark_sageattn3( + batch_size, num_heads, seq_len, head_dim, + is_causal=is_causal, dtype=dtype, + per_block_mean=True, + single_level_p_quant=True, + enable_smoothing_q=False, + enable_smoothing_k=False, + num_warmups=num_warmups, num_tests=num_tests + ) + tflops = time_to_tflops(time_ms, batch_size, num_heads, seq_len, head_dim, is_causal) + results['FP4'][seq_len] = tflops + print(f"{tflops:.0f} TFLOPs ({time_ms:.3f} ms)") + except Exception as e: + print(f"OOM or Error: {e}") + results['FP4'][seq_len] = None + sys.stdout.flush() + + # Print summary table + print("\n" + "="*100) + print("Summary Table (TFLOPs)") + print("="*100) + header = f"{'SeqLen':<10}" + for method in results.keys(): + header += f"{method:<15}" + print(header) + print("-"*100) + + for seq_len in seq_lens: + row = f"{seq_len:<10}" + for method in results.keys(): + val = results[method].get(seq_len) + if val is not None: + row += f"{val:<15.0f}" + else: + row += f"{'OOM':<15}" + print(row) + print("="*100) + sys.stdout.flush() + + # Generate plot + generate_plot(results, seq_lens, head_dim, is_causal, short_name, output_file) + + return results + + +def generate_plot(results: Dict[str, Dict[int, Optional[float]]], + seq_lens: List[int], + head_dim: int, + is_causal: bool, + device_name: str, + output_file: str): + """Generate bar plot comparing attention implementations.""" + + # Prepare data + methods = list(results.keys()) + x_labels = [f"{sl//1024}K" for sl in seq_lens] + + # Colors for each method (red, blue, green scheme) + colors = { + 'FlashAttn': '#1E90FF', # Blue (Dodger Blue) + 'SageAttn3': '#228B22', # Green (Forest Green) + 'FP4': '#DC143C', # Red (Crimson) + } + + # Number of methods and positions + n_methods = len(methods) + n_positions = len(seq_lens) + + # Bar width and positions + bar_width = 0.25 + x = np.arange(n_positions) + + # Create figure + fig, ax = plt.subplots(figsize=(12, 6)) + + # Plot bars for each method + for i, method in enumerate(methods): + values = [] + for seq_len in seq_lens: + val = results[method].get(seq_len) + values.append(val if val is not None else 0) + + offset = (i - n_methods/2 + 0.5) * bar_width + bars = ax.bar(x + offset, values, bar_width, + label=method, color=colors.get(method, f'C{i}'), + edgecolor='black', linewidth=0.5) + + # Add value labels on top of bars + for bar, val, seq_len in zip(bars, values, seq_lens): + if results[method].get(seq_len) is None: + label = 'OOM' + else: + label = f'{int(val)}' + + height = bar.get_height() + ax.annotate(label, + xy=(bar.get_x() + bar.get_width() / 2, height), + xytext=(0, 3), # 3 points vertical offset + textcoords="offset points", + ha='center', va='bottom', + fontsize=8, rotation=0) + + # Customize plot + ax.set_xlabel('Sequence Length', fontsize=12, fontweight='bold') + ax.set_ylabel('Speed (TFLOPs)', fontsize=12, fontweight='bold') + ax.set_title(f'{device_name}, (Head dim = {head_dim}, causal = {is_causal})', fontsize=14, fontweight='bold') + ax.set_xticks(x) + ax.set_xticklabels(x_labels) + legend = ax.legend(loc='upper left', ncol=len(methods), fontsize=10) + # Make legend text bold + for text in legend.get_texts(): + text.set_fontweight('bold') + + # Set y-axis to start from 0 + ax.set_ylim(bottom=0) + + # Add grid for readability + ax.yaxis.grid(True, linestyle='--', alpha=0.7) + ax.set_axisbelow(True) + + # Tight layout + plt.tight_layout() + + # Save plot + plt.savefig(output_file, dpi=150, bbox_inches='tight') + print(f"\nPlot saved to: {output_file}") + + # Also save as PDF for high quality + pdf_file = output_file.rsplit('.', 1)[0] + '.pdf' + plt.savefig(pdf_file, bbox_inches='tight') + print(f"PDF saved to: {pdf_file}") + + plt.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Combined Attention Benchmark (FlashAttn2 vs SageAttn3 vs FP4)') + parser.add_argument('--batch-size', type=int, default=1, help='Batch size') + parser.add_argument('--num-heads', type=int, default=16, help='Number of attention heads') + parser.add_argument('--head-dim', type=int, default=64, choices=[64, 128], help='Head dimension') + parser.add_argument('--causal', action='store_true', help='Use causal attention') + parser.add_argument('--num-warmups', type=int, default=10, help='Number of warmup iterations') + parser.add_argument('--num-tests', type=int, default=50, help='Number of test iterations') + parser.add_argument('--seq-lens', type=int, nargs='+', + default=[1024, 2048, 4096, 8192, 16384, 32768], + help='Sequence lengths to benchmark') + parser.add_argument('--output', type=str, default='benchmark_attention.png', + help='Output plot filename') + + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. This benchmark requires a CUDA device.") + + run_benchmark_suite( + head_dim=args.head_dim, + is_causal=args.causal, + num_heads=args.num_heads, + batch_size=args.batch_size, + num_warmups=args.num_warmups, + num_tests=args.num_tests, + seq_lens=args.seq_lens, + output_file=args.output + ) diff --git a/fastvideo-kernel/benchmarks/benchmark_flashattn2.py b/fastvideo-kernel/benchmarks/benchmark_flashattn2.py new file mode 100644 index 0000000000..16c1f94441 --- /dev/null +++ b/fastvideo-kernel/benchmarks/benchmark_flashattn2.py @@ -0,0 +1,234 @@ +import sys +import traceback + +import _bootstrap # noqa: F401 +import torch + +from flash_attn import flash_attn_func +from attn_qat_infer.quantization.bench.bench_utils import bench + + +def calculate_attention_flops(batch_size, num_heads, seq_len_q, seq_len_k, head_dim, is_causal=False): + """Calculate FLOPs for attention (FlashAttention standard - matmuls only).""" + f = 4 * batch_size * num_heads * seq_len_q * seq_len_k * head_dim + if is_causal: + f = f // 2 + return f + + +def benchmark_flashattn2(batch_size, num_heads, seq_len, head_dim, + is_causal=False, dtype=torch.bfloat16, + num_warmups=100, num_tests=1000): + """ + Benchmark FlashAttention2 and return performance metrics. + + Args: + batch_size: Batch size + num_heads: Number of attention heads + seq_len: Sequence length (same for Q, K, V) + head_dim: Head dimension + is_causal: Whether to use causal masking + dtype: Data type (torch.bfloat16 or torch.float16) + num_warmups: Number of warmup iterations + num_tests: Number of test iterations + + Returns: + dict with performance metrics + """ + device = 'cuda' + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. This benchmark requires a CUDA device.") + + # Create input tensors - FlashAttention2 expects (batch, seq_len, num_heads, head_dim) + q = torch.randn(batch_size, seq_len, num_heads, head_dim, + device=device, dtype=dtype) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, + device=device, dtype=dtype) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, + device=device, dtype=dtype) + + # Create closure for benchmarking + def run_attention(): + return flash_attn_func( + q, k, v, + causal=is_causal, + ) + + # Benchmark using the bench utility (handles warmup and timing) + avg_time_ms = bench(run_attention, num_warmups=num_warmups, num_tests=num_tests) + avg_time_s = avg_time_ms / 1000.0 + + # Calculate FLOPs + total_flops = calculate_attention_flops( + batch_size, num_heads, seq_len, seq_len, head_dim, is_causal + ) + + # Calculate TFLOPs + tflops = total_flops / (avg_time_s * 1e12) + + # Calculate throughput (tokens/sec) + tokens_per_second = (batch_size * seq_len) / avg_time_s + + return { + 'batch_size': batch_size, + 'num_heads': num_heads, + 'seq_len': seq_len, + 'head_dim': head_dim, + 'is_causal': is_causal, + 'dtype': str(dtype), + 'avg_time_ms': avg_time_ms, + 'avg_time_s': avg_time_s, + 'total_flops': total_flops, + 'tflops': tflops, + 'tokens_per_second': tokens_per_second, + } + + +def print_results(results): + """Print benchmark results in a formatted table.""" + print("\n" + "="*100) + print("FlashAttention2 Benchmark Results") + print("="*100) + print(f"Configuration:") + print(f" Batch Size: {results['batch_size']}") + print(f" Num Heads: {results['num_heads']}") + print(f" Sequence Length: {results['seq_len']}") + print(f" Head Dimension: {results['head_dim']}") + print(f" Causal: {results['is_causal']}") + print(f" Data Type: {results['dtype']}") + print(f"\nPerformance:") + print(f" Average Time: {results['avg_time_ms']:.3f} ms") + print(f" Total FLOPs: {results['total_flops']/1e12:.4f} TFLOPs (theoretical)") + print(f" Throughput: {results['tflops']:.4f} TFLOPs/s") + print(f" Tokens/sec: {results['tokens_per_second']:,.0f}") + print("="*100 + "\n") + sys.stdout.flush() + + +def run_benchmark_suite(): + """Run a comprehensive benchmark suite with various configurations.""" + print("Starting FlashAttention2 Benchmark Suite...") + print(f"CUDA Device: {torch.cuda.get_device_name(0)}") + print(f"CUDA Version: {torch.version.cuda}") + print(f"PyTorch Version: {torch.__version__}\n") + sys.stdout.flush() + + # Default configurations to test + # (batch_size, num_heads, seq_len, head_dim, is_causal, dtype) + configs = [ + (1, 16, 1024, 64, False, torch.bfloat16), + (1, 16, 2048, 64, False, torch.bfloat16), + (1, 16, 4096, 64, False, torch.bfloat16), + (1, 16, 8192, 64, False, torch.bfloat16), + (1, 16, 16384, 64, False, torch.bfloat16), + + (1, 16, 1024, 128, False, torch.bfloat16), + (1, 16, 2048, 128, False, torch.bfloat16), + (1, 16, 4096, 128, False, torch.bfloat16), + (1, 16, 8192, 128, False, torch.bfloat16), + (1, 16, 16384, 128, False, torch.bfloat16), + + (1, 32, 1024, 64, False, torch.bfloat16), + (1, 32, 2048, 64, False, torch.bfloat16), + (1, 32, 4096, 64, False, torch.bfloat16), + (1, 32, 8192, 64, False, torch.bfloat16), + (1, 32, 16384, 64, False, torch.bfloat16), + + (1, 32, 1024, 128, False, torch.bfloat16), + (1, 32, 2048, 128, False, torch.bfloat16), + (1, 32, 4096, 128, False, torch.bfloat16), + (1, 32, 8192, 128, False, torch.bfloat16), + (1, 32, 16384, 128, False, torch.bfloat16), + ] + + all_results = [] + + for config in configs: + batch_size, num_heads, seq_len, head_dim, is_causal, dtype = config + + print(f"\nBenchmarking: B={batch_size}, H={num_heads}, L={seq_len}, D={head_dim}, " + f"Causal={is_causal}, dtype={dtype}...") + sys.stdout.flush() + + try: + results = benchmark_flashattn2( + batch_size=batch_size, + num_heads=num_heads, + seq_len=seq_len, + head_dim=head_dim, + is_causal=is_causal, + dtype=dtype, + num_warmups=10, + num_tests=50 + ) + + print_results(results) + all_results.append(results) + + except Exception as e: + print(f"Error benchmarking configuration {config}:") + print(f" Exception: {e}") + traceback.print_exc() + sys.stdout.flush() + continue + + # Print summary table + print("\n" + "="*120) + print("Summary Table") + print("="*120) + print(f"{'B':<4} {'H':<4} {'L':<6} {'D':<4} {'Causal':<7} {'Time (ms)':<12} {'TFLOPs/s':<12} {'Tokens/s':<15}") + print("-"*120) + + for r in all_results: + print(f"{r['batch_size']:<4} {r['num_heads']:<4} {r['seq_len']:<6} {r['head_dim']:<4} " + f"{str(r['is_causal']):<7} {r['avg_time_ms']:<12.3f} {r['tflops']:<12.4f} " + f"{r['tokens_per_second']:<15,.0f}") + + print("="*120 + "\n") + sys.stdout.flush() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='Benchmark FlashAttention2 in TFLOPs') + parser.add_argument('--batch-size', type=int, default=None, help='Batch size') + parser.add_argument('--num-heads', type=int, default=None, help='Number of attention heads') + parser.add_argument('--seq-len', type=int, default=None, help='Sequence length') + parser.add_argument('--head-dim', type=int, default=None, help='Head dimension') + parser.add_argument('--causal', action='store_true', help='Use causal attention') + parser.add_argument('--dtype', type=str, default='bfloat16', choices=['bfloat16', 'float16'], + help='Data type') + parser.add_argument('--num-warmups', type=int, default=10, help='Number of warmup iterations') + parser.add_argument('--num-tests', type=int, default=50, help='Number of test iterations') + parser.add_argument('--suite', action='store_true', help='Run full benchmark suite') + + args = parser.parse_args() + + dtype_map = { + 'bfloat16': torch.bfloat16, + 'float16': torch.float16 + } + + if args.suite: + run_benchmark_suite() + elif args.batch_size and args.num_heads and args.seq_len and args.head_dim: + results = benchmark_flashattn2( + batch_size=args.batch_size, + num_heads=args.num_heads, + seq_len=args.seq_len, + head_dim=args.head_dim, + is_causal=args.causal, + dtype=dtype_map[args.dtype], + num_warmups=args.num_warmups, + num_tests=args.num_tests + ) + print_results(results) + else: + print("Running default benchmark suite. Use --suite for full suite or provide all parameters.") + print( + "Example: python benchmarks/benchmark_flashattn2.py " + "--batch-size 1 --num-heads 16 --seq-len 4096 --head-dim 128" + ) + sys.stdout.flush() + run_benchmark_suite() diff --git a/fastvideo-kernel/benchmarks/benchmark_sageattn3.py b/fastvideo-kernel/benchmarks/benchmark_sageattn3.py new file mode 100644 index 0000000000..70e0b9543f --- /dev/null +++ b/fastvideo-kernel/benchmarks/benchmark_sageattn3.py @@ -0,0 +1,253 @@ +import sys +import traceback + +import _bootstrap # noqa: F401 +import torch + +from attn_qat_infer.api import sageattn_blackwell +from attn_qat_infer.quantization.bench.bench_utils import bench + + +def calculate_attention_flops(batch_size, num_heads, seq_len_q, seq_len_k, head_dim, is_causal=False): + """Calculate FLOPs for attention (FlashAttention standard - matmuls only).""" + f = 4 * batch_size * num_heads * seq_len_q * seq_len_k * head_dim + if is_causal: + f = f // 2 + return f + + +def benchmark_sageattn3(batch_size, num_heads, seq_len, head_dim, + is_causal=False, dtype=torch.bfloat16, + per_block_mean=True, single_level_p_quant=False, + num_warmups=100, num_tests=1000): + """ + Benchmark SageAttention3 and return performance metrics. + + Args: + batch_size: Batch size + num_heads: Number of attention heads + seq_len: Sequence length (same for Q, K, V) + head_dim: Head dimension + is_causal: Whether to use causal masking + dtype: Data type (torch.bfloat16 or torch.float16) + per_block_mean: Whether to use per-block mean for Q smoothing + single_level_p_quant: If True, use single-level quantization for P matrix + num_warmups: Number of warmup iterations + num_tests: Number of test iterations + + Returns: + dict with performance metrics + """ + device = 'cuda' + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. This benchmark requires a CUDA device.") + + # Create input tensors + q = torch.randn(batch_size, num_heads, seq_len, head_dim, + device=device, dtype=dtype) + k = torch.randn(batch_size, num_heads, seq_len, head_dim, + device=device, dtype=dtype) + v = torch.randn(batch_size, num_heads, seq_len, head_dim, + device=device, dtype=dtype) + + # Create closure for benchmarking (no extra stream needed - bench handles synchronization) + def run_attention(): + return sageattn_blackwell( + q, k, v, + is_causal=is_causal, + per_block_mean=per_block_mean, + single_level_p_quant=single_level_p_quant + ) + + # Benchmark using the bench utility (handles warmup and timing) + avg_time_ms = bench(run_attention, num_warmups=num_warmups, num_tests=num_tests) + avg_time_s = avg_time_ms / 1000.0 + + # Calculate FLOPs + total_flops = calculate_attention_flops( + batch_size, num_heads, seq_len, seq_len, head_dim, is_causal + ) + + # Calculate TFLOPs + tflops = total_flops / (avg_time_s * 1e12) + + # Calculate throughput (tokens/sec) + tokens_per_second = (batch_size * seq_len) / avg_time_s + + return { + 'batch_size': batch_size, + 'num_heads': num_heads, + 'seq_len': seq_len, + 'head_dim': head_dim, + 'is_causal': is_causal, + 'dtype': str(dtype), + 'per_block_mean': per_block_mean, + 'single_level_p_quant': single_level_p_quant, + 'avg_time_ms': avg_time_ms, + 'avg_time_s': avg_time_s, + 'total_flops': total_flops, + 'tflops': tflops, + 'tokens_per_second': tokens_per_second, + } + + +def print_results(results): + """Print benchmark results in a formatted table.""" + print("\n" + "="*100) + print("SageAttention3 Benchmark Results") + print("="*100) + print(f"Configuration:") + print(f" Batch Size: {results['batch_size']}") + print(f" Num Heads: {results['num_heads']}") + print(f" Sequence Length: {results['seq_len']}") + print(f" Head Dimension: {results['head_dim']}") + print(f" Causal: {results['is_causal']}") + print(f" Data Type: {results['dtype']}") + print(f" Per Block Mean: {results['per_block_mean']}") + print(f" Single Level P Quant: {results['single_level_p_quant']}") + print(f"\nPerformance:") + print(f" Average Time: {results['avg_time_ms']:.3f} ms") + print(f" Total FLOPs: {results['total_flops']/1e12:.4f} TFLOPs (theoretical)") + print(f" Throughput: {results['tflops']:.4f} TFLOPs/s") + print(f" Tokens/sec: {results['tokens_per_second']:,.0f}") + print("="*100 + "\n") + sys.stdout.flush() + + +def run_benchmark_suite(): + """Run a comprehensive benchmark suite with various configurations.""" + print("Starting SageAttention3 Benchmark Suite...") + print(f"CUDA Device: {torch.cuda.get_device_name(0)}") + print(f"CUDA Version: {torch.version.cuda}") + print(f"PyTorch Version: {torch.__version__}\n") + sys.stdout.flush() + + # Default configurations to test + # (batch_size, num_heads, seq_len, head_dim, is_causal, dtype) + configs = [ + (1, 16, 1024, 64, False, torch.bfloat16), + (1, 16, 2048, 64, False, torch.bfloat16), + (1, 16, 4096, 64, False, torch.bfloat16), + (1, 16, 8192, 64, False, torch.bfloat16), + (1, 16, 16384, 64, False, torch.bfloat16), + + (1, 16, 1024, 128, False, torch.bfloat16), + (1, 16, 2048, 128, False, torch.bfloat16), + (1, 16, 4096, 128, False, torch.bfloat16), + (1, 16, 8192, 128, False, torch.bfloat16), + (1, 16, 16384, 128, False, torch.bfloat16), + + (1, 32, 1024, 64, False, torch.bfloat16), + (1, 32, 2048, 64, False, torch.bfloat16), + (1, 32, 4096, 64, False, torch.bfloat16), + (1, 32, 8192, 64, False, torch.bfloat16), + (1, 32, 16384, 64, False, torch.bfloat16), + + (1, 32, 1024, 128, False, torch.bfloat16), + (1, 32, 2048, 128, False, torch.bfloat16), + (1, 32, 4096, 128, False, torch.bfloat16), + (1, 32, 8192, 128, False, torch.bfloat16), + (1, 32, 16384, 128, False, torch.bfloat16), + ] + + all_results = [] + + for config in configs: + batch_size, num_heads, seq_len, head_dim, is_causal, dtype = config + + print(f"\nBenchmarking: B={batch_size}, H={num_heads}, L={seq_len}, D={head_dim}, " + f"Causal={is_causal}, dtype={dtype}...") + sys.stdout.flush() + + try: + results = benchmark_sageattn3( + batch_size=batch_size, + num_heads=num_heads, + seq_len=seq_len, + head_dim=head_dim, + is_causal=is_causal, + dtype=dtype, + num_warmups=10, + num_tests=50 + ) + + print_results(results) + all_results.append(results) + + except Exception as e: + print(f"Error benchmarking configuration {config}:") + print(f" Exception: {e}") + traceback.print_exc() + sys.stdout.flush() + continue + + # Print summary table + print("\n" + "="*120) + print("Summary Table") + print("="*120) + print(f"{'B':<4} {'H':<4} {'L':<6} {'D':<4} {'Causal':<7} {'Time (ms)':<12} {'TFLOPs/s':<12} {'Tokens/s':<15}") + print("-"*120) + + for r in all_results: + print(f"{r['batch_size']:<4} {r['num_heads']:<4} {r['seq_len']:<6} {r['head_dim']:<4} " + f"{str(r['is_causal']):<7} {r['avg_time_ms']:<12.3f} {r['tflops']:<12.4f} " + f"{r['tokens_per_second']:<15,.0f}") + + print("="*120 + "\n") + sys.stdout.flush() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='Benchmark SageAttention3 in TFLOPs') + parser.add_argument('--batch-size', type=int, default=None, help='Batch size') + parser.add_argument('--num-heads', type=int, default=None, help='Number of attention heads') + parser.add_argument('--seq-len', type=int, default=None, help='Sequence length') + parser.add_argument('--head-dim', type=int, default=None, help='Head dimension') + parser.add_argument('--causal', action='store_true', help='Use causal attention') + parser.add_argument('--dtype', type=str, default='bfloat16', choices=['bfloat16', 'float16'], + help='Data type') + parser.add_argument('--per-block-mean', action='store_true', default=True, + help='Use per-block mean for Q smoothing (default: True)') + parser.add_argument('--no-per-block-mean', action='store_false', dest='per_block_mean', + help='Disable per-block mean for Q smoothing') + parser.add_argument('--single-level-p-quant', action='store_true', default=False, + help='Use single-level P quantization (default: True)') + parser.add_argument('--two-level-p-quant', action='store_false', dest='single_level_p_quant', + help='Use two-level P quantization') + parser.add_argument('--num-warmups', type=int, default=10, help='Number of warmup iterations') + parser.add_argument('--num-tests', type=int, default=50, help='Number of test iterations') + parser.add_argument('--suite', action='store_true', help='Run full benchmark suite') + + args = parser.parse_args() + + dtype_map = { + 'bfloat16': torch.bfloat16, + 'float16': torch.float16 + } + + if args.suite: + run_benchmark_suite() + elif args.batch_size and args.num_heads and args.seq_len and args.head_dim: + results = benchmark_sageattn3( + batch_size=args.batch_size, + num_heads=args.num_heads, + seq_len=args.seq_len, + head_dim=args.head_dim, + is_causal=args.causal, + dtype=dtype_map[args.dtype], + per_block_mean=args.per_block_mean, + single_level_p_quant=args.single_level_p_quant, + num_warmups=args.num_warmups, + num_tests=args.num_tests + ) + print_results(results) + else: + print("Running default benchmark suite. Use --suite for full suite or provide all parameters.") + print( + "Example: python benchmarks/benchmark_sageattn3.py " + "--batch-size 1 --num-heads 16 --seq-len 4096 --head-dim 128" + ) + sys.stdout.flush() + run_benchmark_suite() diff --git a/fastvideo-kernel/pyproject.toml b/fastvideo-kernel/pyproject.toml index 155f9811df..923d1b41f1 100644 --- a/fastvideo-kernel/pyproject.toml +++ b/fastvideo-kernel/pyproject.toml @@ -32,4 +32,4 @@ dependencies = [ [tool.scikit-build] cmake.build-type = "Release" minimum-version = "build-system.requires" -wheel.packages = ["python/fastvideo_kernel"] +wheel.packages = ["python/fastvideo_kernel", "attn_qat_infer"] diff --git a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/__init__.py b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/__init__.py new file mode 100644 index 0000000000..01008c151e --- /dev/null +++ b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/__init__.py @@ -0,0 +1,5 @@ +"""Triton kernel entrypoints exposed by ``fastvideo_kernel``.""" + +from .fused_attention import attention as fused_attention + +__all__ = ["fused_attention"] diff --git a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/attn_qat_train.py b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/attn_qat_train.py new file mode 100644 index 0000000000..32be67e512 --- /dev/null +++ b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/attn_qat_train.py @@ -0,0 +1,1119 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py + +import os + +import torch +import triton +import triton.language as tl +from triton.tools.tensor_descriptor import TensorDescriptor + +from .quant_utils import fake_quantize_q, fake_quantize_kv, fake_quantize +# from attn_qat_infer.api import triton_group_mean + + +def is_cuda(): + return triton.runtime.driver.active.get_current_target().backend == "cuda" + + +def supports_host_descriptor(): + return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 + + +def is_blackwell(): + return is_cuda() and torch.cuda.get_device_capability()[0] == 10 + + +def is_hopper(): + return is_cuda() and torch.cuda.get_device_capability()[0] == 9 + + +@triton.jit +def _mul_alpha(acc, alpha, BM: tl.constexpr, BN: tl.constexpr): + acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() + acc0 = acc0 * alpha[:, None] + acc1 = acc1 * alpha[:, None] + acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) + return acc + + +@triton.jit +def _attn_fwd_inner(acc, high_prec_acc, l_i, m_i, q, q_valid, + desc_k, desc_v, + offset_y, dtype: tl.constexpr, start_m, qk_scale, + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, + N_CTX: tl.constexpr, warp_specialize: tl.constexpr, IS_HOPPER: tl.constexpr, + IS_QAT: tl.constexpr, + fake_quant_P: tl.constexpr = True, + two_level_quant_P: tl.constexpr = False, + use_global_sf_P: tl.constexpr = True): + # range of values handled by this stage (kv blocks) + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + # causal = False + else: + lo, hi = 0, N_CTX + offsetk_y = offset_y + lo # offset from the start of the current batch-head combination + if dtype == tl.float8e5: + offsetv_y = offset_y * HEAD_DIM + lo + else: + offsetv_y = offset_y + lo + # loop over k, v and update accumulator + for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize): + start_n = tl.multiple_of(start_n, BLOCK_N) + kv_valid = (start_n + offs_n) < N_CTX + # -- compute qk ---- + k = desc_k.load([offsetk_y, 0]) + k = tl.where(kv_valid[:, None], k, 0.0) + qk = tl.dot(q, tl.trans(k)) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask & kv_valid[None, :], 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + qk = tl.where(kv_valid[None, :], qk, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + if IS_QAT: + p, high_prec_p = fake_quantize( + src_tensor=p, + valid_src_mask=tl.full(shape=p.shape, value=1.0, dtype=p.dtype) == 1.0, + BLOCK_SIZE_OUT_DIM=BLOCK_M, + BLOCK_SIZE_QUANT_DIM=BLOCK_N, + dst_dtype=dtype, + two_level_quant_P=two_level_quant_P, + use_global_sf=use_global_sf_P + ) + l_ij = tl.sum(high_prec_p, 1) + else: + l_ij = tl.sum(p, 1) + # -- compute correction factor + alpha = tl.math.exp2(m_i - m_ij) + # -- update output accumulator -- + if not IS_HOPPER and warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128: + BM: tl.constexpr = acc.shape[0] + BN: tl.constexpr = acc.shape[1] + acc = _mul_alpha(acc, alpha, BM, BN) + if IS_QAT: + high_prec_acc = _mul_alpha(high_prec_acc, alpha, BM, BN) + else: + acc = acc * alpha[:, None] + if IS_QAT: + high_prec_acc = high_prec_acc * alpha[:, None] + # prepare p and v for the dot + if dtype == tl.float8e5: + v = desc_v.load([0, offsetv_y]).T + else: + v = desc_v.load([offsetv_y, 0]) + v = tl.where(kv_valid[:, None], v, 0.0) + p = p.to(dtype) + # note that this non transposed v for FP8 is only supported on Blackwell + acc = tl.dot(p, v.to(dtype), acc) + if IS_QAT: + high_prec_acc = tl.dot(high_prec_p, v, high_prec_acc) + # update m_i and l_i + # place this at the end of the loop to reduce register pressure + l_i = l_i * alpha + l_ij + m_i = m_ij + offsetk_y += BLOCK_N + offsetv_y += BLOCK_N + if IS_QAT: + return acc, high_prec_acc, l_i, m_i + else: + return acc, acc, l_i, m_i + + +def _host_descriptor_pre_hook(nargs): + BLOCK_M = nargs["BLOCK_M"] + BLOCK_N = nargs["BLOCK_N"] + HEAD_DIM = nargs["HEAD_DIM"] + if not isinstance(nargs["desc_q"], TensorDescriptor): + return + nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] + if nargs["FP8_OUTPUT"]: + nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N] + else: + nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] + nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] + nargs["desc_o"].block_shape = [1, BLOCK_M, HEAD_DIM] + if "desc_high_prec_o" in nargs and isinstance(nargs["desc_high_prec_o"], TensorDescriptor): + nargs["desc_high_prec_o"].block_shape = [1, BLOCK_M, HEAD_DIM] + + +NUM_STAGES_OPTIONS = [2, 3, 4] + +configs = [ + triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) + for BM in [64, 128] + for BN in [32, 64, 128] + for s in NUM_STAGES_OPTIONS + for w in [4, 8] +] +if "PYTEST_VERSION" in os.environ: + # Use a single config in testing for reproducibility + configs = [ + triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook), + ] + + +def keep(conf): + BLOCK_M = conf.kwargs["BLOCK_M"] + BLOCK_N = conf.kwargs["BLOCK_N"] + return not (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 + and conf.num_warps == 8) + + +def prune_invalid_configs(configs, named_args, **kwargs): + N_CTX_Q = kwargs["N_CTX_Q"] + N_CTX_KV = kwargs["N_CTX_KV"] + return [ + conf for conf in configs + if conf.kwargs.get("BLOCK_M", 0) <= N_CTX_Q + and conf.kwargs.get("BLOCK_N", 0) <= N_CTX_KV + and conf.kwargs.get("BLOCK_N", 0) % conf.kwargs.get("BLOCK_M", 0) == 0 + ] + + +@triton.jit +def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): + if isinstance(desc_or_ptr, tl.tensor_descriptor): + return desc_or_ptr + else: + return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape) + + +# @triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX_Q", "N_CTX_KV", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], +# prune_configs_by={'early_config_prune': prune_invalid_configs}, cache_results=True) +@triton.jit +def _attn_fwd(sm_scale, M, + Z, H, desc_q, desc_k, desc_v, desc_o, desc_high_prec_o, N_CTX_Q, N_CTX_KV, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + FP8_OUTPUT: tl.constexpr, + STAGE: tl.constexpr, + warp_specialize: tl.constexpr, + IS_HOPPER: tl.constexpr, + IS_QAT: tl.constexpr, + fake_quant_P: tl.constexpr = True, + two_level_quant_P: tl.constexpr = False, + use_global_sf_P: tl.constexpr = True, + ): + dtype = tl.float8e5 if FP8_OUTPUT else tl.bfloat16 + tl.static_assert(BLOCK_N <= HEAD_DIM) + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H # since it's Z then H in the shape + off_h = off_hz % H + + y_dim_q = Z * H * N_CTX_Q + y_dim_kv = Z * H * N_CTX_KV + desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim_q, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_M, HEAD_DIM]) + if FP8_OUTPUT: + desc_v = _maybe_make_tensor_desc(desc_v, shape=[HEAD_DIM, y_dim_kv], strides=[N_CTX_KV, 1], + block_shape=[HEAD_DIM, BLOCK_N]) + else: + desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim_kv, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_N, HEAD_DIM]) + desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim_kv, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_N, HEAD_DIM]) + desc_o = _maybe_make_tensor_desc( + desc_o, shape=[Z * H, N_CTX_Q, HEAD_DIM], + strides=[N_CTX_Q * HEAD_DIM, HEAD_DIM, 1], + block_shape=[1, BLOCK_M, HEAD_DIM] + ) + if IS_QAT: + desc_high_prec_o = _maybe_make_tensor_desc( + desc_high_prec_o, shape=[Z * H, N_CTX_Q, HEAD_DIM], + strides=[N_CTX_Q * HEAD_DIM, HEAD_DIM, 1], + block_shape=[1, BLOCK_M, HEAD_DIM] + ) + + offset_y_q = off_z * (N_CTX_Q * H) + off_h * N_CTX_Q # offset for query tensor + offset_y_kv = off_z * (N_CTX_KV * H) + off_h * N_CTX_KV # offset for key/value tensors + qo_offset_y = offset_y_q + start_m * BLOCK_M + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) # for kv blocks + q_valid = offs_m < N_CTX_Q + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + if IS_QAT: + high_prec_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + else: + # dummy value + high_prec_acc = acc + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout + q = desc_q.load([qo_offset_y, 0]) # load from start of q block and start of the entire head dimension + q = tl.where(q_valid[:, None], q, 0.0) + # stage 1: off-band + # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE + # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE + if STAGE & 1: + acc, high_prec_acc, l_i, m_i = _attn_fwd_inner( + acc, high_prec_acc, l_i, m_i, q, q_valid, + desc_k, desc_v, + offset_y_kv, dtype, start_m, qk_scale, + BLOCK_M, HEAD_DIM, BLOCK_N, + 4 - STAGE, offs_m, offs_n, N_CTX_KV, + warp_specialize, IS_HOPPER, IS_QAT, fake_quant_P, two_level_quant_P, use_global_sf_P + ) + # stage 2: on-band + if STAGE & 2: + acc, high_prec_acc, l_i, m_i = _attn_fwd_inner( + acc, high_prec_acc, l_i, m_i, q, q_valid, + desc_k, desc_v, + offset_y_kv, dtype, start_m, qk_scale, + BLOCK_M, HEAD_DIM, BLOCK_N, + 2, offs_m, offs_n, N_CTX_KV, + warp_specialize, IS_HOPPER, IS_QAT, fake_quant_P, two_level_quant_P, use_global_sf_P + ) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + if IS_QAT: + high_prec_acc = high_prec_acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX_Q + offs_m + tl.store(m_ptrs, m_i, mask=q_valid) + desc_o.store([off_hz, start_m * BLOCK_M, 0], acc[None, :, :]) + if IS_QAT: + desc_high_prec_o.store([off_hz, start_m * BLOCK_M, 0], high_prec_acc[None, :, :]) + + +@triton.jit +def _attn_bwd_preprocess(O, DO, + Delta, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr + ): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_hz = tl.program_id(1) + off_n = tl.arange(0, HEAD_DIM) + valid = off_m < N_CTX + # load + o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :], mask=valid[:, None]) + do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :], mask=valid[:, None]).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hz * N_CTX + off_m, delta, mask=valid) + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _attn_bwd_dkdv(dk, dv, + Q, k, v, qk_scale, + DO, + M, D, Q_MEAN, + # shared by Q/K/V/DO. + stride_tok, stride_d, + H, N_CTX, BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + HEAD_DIM: tl.constexpr, + # Filled in by the wrapper. + start_n, start_m, num_steps, + MASK: tl.constexpr, + IS_QAT: tl.constexpr, + two_level_quant_P: tl.constexpr = False, + fake_quant_P: tl.constexpr = True, + SMOOTH_Q: tl.constexpr = False, + use_global_sf_P: tl.constexpr = True, + warp_specialize: tl.constexpr = False): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, HEAD_DIM) + q_ptrs = Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + if SMOOTH_Q: + q_m_ptrs = Q_MEAN + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + offs_m = curr_m + tl.arange(0, BLOCK_M1) + q_valid = offs_m < N_CTX + q = tl.load(q_ptrs, mask=q_valid[:, None]) + if SMOOTH_Q: + q_m = tl.load(q_m_ptrs, mask=q_valid[:, None]) + # Load m before computing qk to reduce pipeline stall. + m = tl.load(M + offs_m, mask=q_valid) + qk = tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + # Autoregressive masking - apply BEFORE exp2 to match forward pass behavior + if MASK: + mask = (offs_m[:, None] >= offs_n[None, :]) + qk = tl.where(mask, qk, -1.0e6) + p = tl.math.exp2(qk - m[:, None]) + do = tl.load(do_ptrs, mask=q_valid[:, None]) + # Compute dV. + p_quant = p + if IS_QAT and fake_quant_P: + p_quant, _ = fake_quantize( + src_tensor=p, + valid_src_mask=tl.full(shape=p.shape, value=1.0, dtype=p.dtype) == 1.0, + BLOCK_SIZE_OUT_DIM=BLOCK_M1, + BLOCK_SIZE_QUANT_DIM=BLOCK_N1, + dst_dtype=p.dtype, + two_level_quant_P=two_level_quant_P, + use_global_sf=use_global_sf_P + ) + dv += tl.dot(tl.trans(p_quant.to(tl.bfloat16)), do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m, mask=q_valid) + # Compute dP and dS. + dp = tl.dot(do, tl.trans(v)) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.bfloat16) + dk += tl.dot(tl.trans(ds), q) + if SMOOTH_Q: + dk += tl.sum(ds, axis=1, keep_dims=True) * q_m + # Increment pointers. + curr_m += step_m + q_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + if SMOOTH_Q: + q_m_ptrs += step_m * stride_tok + return dk, dv + + +# The main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq(dq, q, K, V, + do, m, D, qk_scale, + # shared by Q/K/V/DO. + stride_tok, stride_d, + H, N_CTX, + K_MEAN, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + HEAD_DIM: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, + MASK: tl.constexpr, + SMOOTH_K: tl.constexpr, + warp_specialize: tl.constexpr = False): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, HEAD_DIM) + k_ptrs = K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + v_ptrs = V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + q_valid = offs_m < N_CTX + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m, mask=q_valid) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + + if SMOOTH_K: + k_m = tl.load(K_MEAN + offs_k) + + for blk_idx in range(num_steps): + # bounds checking for kv block (dynamic) + offs_n = curr_n + tl.arange(0, BLOCK_N2) + kv_valid = offs_n < N_CTX + k = tl.load(k_ptrs, mask=kv_valid[:, None]) + v = tl.load(v_ptrs, mask=kv_valid[:, None]) + + qk = tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + # Autoregressive masking - apply BEFORE exp2 to match forward pass behavior + if MASK: + mask = (offs_m[:, None] >= offs_n[None, :]) + qk = tl.where(mask, qk, -1.0e6) + p = tl.math.exp2(qk - m) + # Compute dP and dS. + dp = tl.dot(do, tl.trans(v)) + ds = p * (dp - Di[:, None]) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because k was pre-scaled. + dq += tl.dot(ds.to(tl.bfloat16), k) + + if SMOOTH_K: + dq += tl.sum(ds, axis=1, keep_dims=True) * k_m[None, :] + # Increment pointers. + curr_n += step_n + k_ptrs += step_n * stride_tok + v_ptrs += step_n * stride_tok + return dq + + +@triton.jit +def _compute_cross_attn_pointer_offsets(bhid, H, N_CTX_Q, stride_z_q, stride_z_kv, stride_h_q, stride_h_kv): + """Helper function to compute pointer offsets for cross-attention backward kernels.""" + off_chz = (bhid * N_CTX_Q).to(tl.int64) + adj_q = (stride_h_q * (bhid % H) + stride_z_q * (bhid // H)) + adj_kv = (stride_h_kv * (bhid % H) + stride_z_kv * (bhid // H)) + return off_chz, adj_q, adj_kv + + +@triton.jit +def _attn_bwd_dq_cross(Q, K, V, sm_scale, + DO, DQ, + M, D, + stride_z_q, stride_z_kv, stride_h_q, + stride_h_kv, stride_tok_q, stride_tok_kv, stride_d_q, stride_d_kv, + H, N_CTX_Q, N_CTX_KV, + K_MEAN, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + HEAD_DIM: tl.constexpr, + SMOOTH_K: tl.constexpr, + warp_specialize: tl.constexpr = False): + # Apply scale AFTER dot product for better precision + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + qk_scale = sm_scale * RCP_LN2 + + bhid = tl.program_id(2) + off_chz, adj_q, adj_kv = _compute_cross_attn_pointer_offsets(bhid, H, N_CTX_Q, stride_z_q, stride_z_kv, stride_h_q, stride_h_kv) + + Q += adj_q + K += adj_kv + V += adj_kv + DO += adj_q + DQ += adj_q + M += off_chz + D += off_chz + + pid = tl.program_id(0) + start_m = pid * BLOCK_M2 + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_k = tl.arange(0, HEAD_DIM) + + q_valid = offs_m < N_CTX_Q + q = tl.load(Q + offs_m[:, None] * stride_tok_q + offs_k[None, :] * stride_d_q, mask=q_valid[:, None]) + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok_q + offs_k[None, :] * stride_d_q, mask=q_valid[:, None]) + m = tl.load(M + offs_m, mask=q_valid)[:, None] + + Di = tl.load(D + offs_m, mask=q_valid) + num_steps = (N_CTX_KV + BLOCK_N2 - 1) // BLOCK_N2 + if SMOOTH_K: + k_m = tl.load(K_MEAN + offs_k) + for step in range(num_steps): + start_n = step * BLOCK_N2 + offs_n = start_n + tl.arange(0, BLOCK_N2) + kv_valid = offs_n < N_CTX_KV + + k = tl.load(K + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv, mask=kv_valid[:, None]) + v = tl.load(V + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv, mask=kv_valid[:, None]) + + qk = tl.dot(q, tl.trans(k)) + qk = qk * qk_scale + p = tl.math.exp2(qk - m) + + dp = tl.dot(do, tl.trans(v)) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.bfloat16) + dq += tl.dot(ds, k) + + if SMOOTH_K: + dq += tl.sum(ds, axis=1, keep_dims=True) * k_m[None, :] + + # NOTE: dq is scaled by sm_scale since K is not pre-scaled + dq *= sm_scale + dq_ptrs = DQ + offs_m[:, None] * stride_tok_q + offs_k[None, :] * stride_d_q + tl.store(dq_ptrs, dq, mask=q_valid[:, None]) + + +@triton.jit +def _attn_bwd_dkdv_cross(Q, K, V, sm_scale, + DO, DK, DV, + M, D, Q_MEAN, + stride_z_q, stride_z_kv, stride_h_q, stride_h_kv, + stride_tok_q, stride_tok_kv, stride_d_q, stride_d_kv, + H, N_CTX_Q, N_CTX_KV, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + HEAD_DIM: tl.constexpr, + IS_QAT: tl.constexpr, + two_level_quant_P: tl.constexpr = False, + fake_quant_P: tl.constexpr = True, + SMOOTH_Q: tl.constexpr = False, + use_global_sf_P: tl.constexpr = True, + warp_specialize: tl.constexpr = False + ): + # Apply scale AFTER dot product for better precision + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + qk_scale = sm_scale * RCP_LN2 + + bhid = tl.program_id(2) + off_chz, adj_q, adj_kv = _compute_cross_attn_pointer_offsets(bhid, H, N_CTX_Q, stride_z_q, stride_z_kv, stride_h_q, stride_h_kv) + + Q += adj_q + K += adj_kv + V += adj_kv + DO += adj_q + DK += adj_kv + DV += adj_kv + M += off_chz + D += off_chz + + if SMOOTH_Q: + Q_MEAN += adj_q + + pid = tl.program_id(0) + start_n = pid * BLOCK_N1 + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, HEAD_DIM) + + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + kv_valid = offs_n < N_CTX_KV + k_block = tl.load(K + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv, mask=kv_valid[:, None]) + v_block = tl.load(V + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv, mask=kv_valid[:, None]) + + num_q_steps = (N_CTX_Q + BLOCK_M1 - 1) // BLOCK_M1 + for step in range(num_q_steps): + start_m = step * BLOCK_M1 + offs_m = start_m + tl.arange(0, BLOCK_M1) + q_valid = offs_m < N_CTX_Q + + q = tl.load(Q + offs_m[:, None] * stride_tok_q + offs_k[None, :] * stride_d_q, mask=q_valid[:, None]) + do = tl.load(DO + offs_m[:, None] * stride_tok_q + offs_k[None, :] * stride_d_q, mask=q_valid[:, None]) + m = tl.load(M + offs_m, mask=q_valid) + + qk = tl.dot(q, tl.trans(k_block)) + # Apply scale AFTER dot product (matches forward pass, better precision) + qk = qk * qk_scale + p = tl.math.exp2(qk - m[:, None]) + p_quant = p + if IS_QAT and fake_quant_P: + p_quant, _ = fake_quantize( + src_tensor=p, + valid_src_mask=tl.full(shape=p.shape, value=1.0, dtype=p.dtype) == 1.0, + BLOCK_SIZE_OUT_DIM=BLOCK_M1, + BLOCK_SIZE_QUANT_DIM=BLOCK_N1, + dst_dtype=p.dtype, + two_level_quant_P=two_level_quant_P, + use_global_sf=use_global_sf_P + ) + dv += tl.dot(tl.trans(p_quant.to(tl.bfloat16)), do) + + dp = tl.dot(do, tl.trans(v_block)) + Di = tl.load(D + offs_m, mask=q_valid) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.bfloat16) + dk += tl.dot(tl.trans(ds), q) + + if SMOOTH_Q: + q_m = tl.load(Q_MEAN + offs_m[:, None] * stride_tok_q + offs_k[None, :] * stride_d_q, mask=q_valid[:, None]) + dk += tl.sum(ds, axis=1, keep_dims=True) * q_m + + dv_ptrs = DV + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv + tl.store(dv_ptrs, dv, mask=kv_valid[:, None]) + + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv + tl.store(dk_ptrs, dk, mask=kv_valid[:, None]) + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, + DO, + DQ, DK, DV, + M, D, Q_MEAN, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, + H, N_CTX, + K_MEAN, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + HEAD_DIM: tl.constexpr, + CAUSAL: tl.constexpr, + IS_QAT: tl.constexpr, + SMOOTH_K: tl.constexpr, + two_level_quant_P: tl.constexpr = False, + fake_quant_P: tl.constexpr = True, + SMOOTH_Q: tl.constexpr = False, + use_global_sf_P: tl.constexpr = True, + warp_specialize: tl.constexpr = False): + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + if SMOOTH_Q: + Q_MEAN += adj + + # load scales + offs_k = tl.arange(0, HEAD_DIM) + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + kv_valid = offs_n < N_CTX + + dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d, mask=kv_valid[:, None]) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d, mask=kv_valid[:, None]) + + # qk_scale is expressed in base-2 domain because we use exp2/log2. + # Two equivalent ways to form qk_scale: + # - **Post-scale (default)**: qk = dot(q, k^T) * (sm_scale / ln2) + # - **Pre-scale (Flash-aligned)**: k is pre-multiplied by (sm_scale / ln2) and qk_scale = 1 + RCP_LN2: tl.constexpr = 1.4426950408889634 # = 1.0 / ln(2) + qk_scale = sm_scale * RCP_LN2 + + # For causal attention, process diagonal block with masking, then rest without + # For non-causal attention, process all blocks without masking + if CAUSAL: + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv( + dk, dv, + Q, k, v, qk_scale, + DO, + M, D, Q_MEAN, + stride_tok, stride_d, + H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, + start_n, start_m, num_steps, + MASK=True, + IS_QAT=IS_QAT, + two_level_quant_P=two_level_quant_P, + fake_quant_P=fake_quant_P, + SMOOTH_Q=SMOOTH_Q, + use_global_sf_P=use_global_sf_P, + warp_specialize=warp_specialize + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m + BLOCK_M1 - 1) // BLOCK_M1 + else: + # For non-causal, start from 0 and process all Q blocks + start_m = 0 + num_steps = (N_CTX + BLOCK_M1 - 1) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + dk, dv = _attn_bwd_dkdv( + dk, dv, + Q, k, v, qk_scale, + DO, + M, D, Q_MEAN, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M1, BLOCK_N1, HEAD_DIM, + start_n, start_m, num_steps, + MASK=False, + IS_QAT=IS_QAT, + two_level_quant_P=two_level_quant_P, + fake_quant_P=fake_quant_P, + SMOOTH_Q=SMOOTH_Q, + use_global_sf_P=use_global_sf_P, + warp_specialize=warp_specialize + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv, mask=kv_valid[:, None]) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk, mask=kv_valid[:, None]) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + q_valid = offs_m < N_CTX + + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d, mask=q_valid[:, None]) + + dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d, mask=q_valid[:, None]) + + m = tl.load(M + offs_m, mask=q_valid)[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + if CAUSAL: + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq( + dq, q, K, V, + do, m, D, qk_scale, + stride_tok, stride_d, + H, N_CTX, + K_MEAN, + BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, + MASK=True, + SMOOTH_K=SMOOTH_K, + warp_specialize=warp_specialize, + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2: process KV blocks from 0 to end_n (before diagonal), ceiling division to include remainder + num_steps = (end_n + BLOCK_N2 - 1) // BLOCK_N2 + start_n = 0 + else: + # For non-causal, process all KV blocks from 0 to N_CTX (ceiling division to include remainder) + num_steps = (N_CTX + BLOCK_N2 - 1) // BLOCK_N2 + start_n = 0 + # stage 2 + dq = _attn_bwd_dq( + dq, q, K, V, + do, m, D, qk_scale, + stride_tok, stride_d, + H, N_CTX, + K_MEAN, + BLOCK_M2, BLOCK_N2, HEAD_DIM, + start_m, start_n, num_steps, + MASK=False, + SMOOTH_K=SMOOTH_K, + warp_specialize=warp_specialize, + ) + # Write back dQ. + # NOTE: dq is scaled by sm_scale since K is not pre-scaled (unlike original which used LN2) + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= sm_scale + tl.store(dq_ptrs, dq, mask=q_valid[:, None]) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + causal, + sm_scale, + use_qat_qkv_backward=True, + smooth_k=True, + warp_specialize=True, + IS_QAT=True, + two_level_quant_P=True, + fake_quant_P=True, + use_high_prec_o=False, + smooth_q=False, + use_global_sf_P=True, + use_global_sf_QKV=True, + ): + # shape constraints + HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1] + # when v is in float8_e5m2 it is transposed. + HEAD_DIM_V = v.shape[-1] + assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V + assert HEAD_DIM_K in {16, 32, 64, 128, 256} + + # Support different sequence lengths for q and k/v (needed for cross attention) + N_CTX_Q = q.shape[2] # Query sequence length + N_CTX_KV = k.shape[2] # Key/Value sequence length (may differ from query) + assert k.shape[2] == v.shape[2], "k and v must have the same sequence length" + + # smoothing k from SageAttn + ctx.k_mean = None + if smooth_k: + k_mean = k.mean(dim=(0, 1, 2), keepdim=True) + k = k - k_mean + ctx.k_mean = k_mean.view(-1) + + q_orig, k_orig, v_orig = None, None, None + if not use_qat_qkv_backward: + q_orig, k_orig, v_orig = q, k, v + ctx.q_orig, ctx.k_orig, ctx.v_orig = q_orig, k_orig, v_orig + + o = torch.empty_like(q) + if IS_QAT: + high_prec_o = torch.empty_like(q) + else: + # Initialize to a dummy value + high_prec_o = o + stage = 3 if causal else 1 + extra_kern_args = {} + + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) # (Z, H, N_CTX_Q) + # Use device_descriptor for Hopper + warpspec. + if supports_host_descriptor() and not (is_hopper() and warp_specialize): + # Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor + y_dim_q = q.shape[0] * q.shape[1] * q.shape[2] + y_dim_kv = k.shape[0] * k.shape[1] * k.shape[2] + + dummy_block = [1, 1] + # (Z, H, N_CTX_Q, HEAD_DIM_K) -> (Z*H*N_CTX_Q, HEAD_DIM_K) + desc_q = TensorDescriptor(q, shape=[y_dim_q, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + if q.dtype == torch.float8_e5m2: + desc_v = TensorDescriptor( + v, shape=[HEAD_DIM_K, y_dim_kv], strides=[k.shape[2], 1], + block_shape=dummy_block + ) + else: + desc_v = TensorDescriptor( + v, shape=[y_dim_kv, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], + block_shape=dummy_block + ) + desc_k = TensorDescriptor(k, shape=[y_dim_kv, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + # Use 3D descriptor for output to handle N_CTX_Q boundaries correctly + dummy_block3d = [1, 1, 1] + ZH = q.shape[0] * q.shape[1] + desc_o = TensorDescriptor( + o, shape=[ZH, N_CTX_Q, HEAD_DIM_K], + strides=[N_CTX_Q * HEAD_DIM_K, HEAD_DIM_K, 1], + block_shape=dummy_block3d + ) + if IS_QAT: + desc_high_prec_o = TensorDescriptor( + high_prec_o, shape=[ZH, N_CTX_Q, HEAD_DIM_K], + strides=[N_CTX_Q * HEAD_DIM_K, HEAD_DIM_K, 1], + block_shape=dummy_block3d + ) + else: + desc_high_prec_o = desc_o # Use regular output descriptor when not in QAT mode + else: + desc_q = q + desc_v = v + desc_k = k + desc_o = o + if IS_QAT: + desc_high_prec_o = high_prec_o + else: + desc_high_prec_o = o # Use regular output when not in QAT mode + + def alloc_fn(size: int, align: int, _): + return torch.empty(size, dtype=torch.int8, device="cuda") + + triton.set_allocator(alloc_fn) + + def grid(META): + # (ceil(N_CTX / BLOCK_M), Z * H, 1) + return (triton.cdiv(q.shape[2], META["BLOCK_M"]), q.shape[0] * q.shape[1], 1) + + ctx.grid = grid + if is_blackwell() and warp_specialize: + if HEAD_DIM_K == 128 and q.dtype == torch.float16: + extra_kern_args["maxnreg"] = 168 + else: + extra_kern_args["maxnreg"] = 80 + + BLOCK_M, BLOCK_N = 32, 32 + if IS_QAT: + fake_q = torch.empty_like(q) + fake_k = torch.empty_like(k) + fake_v = torch.empty_like(v) + + # override desc_q, desc_k, desc_v with fake_q, fake_k, fake_v + if supports_host_descriptor() and not (is_hopper() and warp_specialize): + # (Z, H, N_CTX_Q, HEAD_DIM_K) -> (Z*H*N_CTX_Q, HEAD_DIM_K) + desc_q = TensorDescriptor(fake_q, shape=[y_dim_q, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + desc_k = TensorDescriptor(fake_k, shape=[y_dim_kv, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + desc_v = TensorDescriptor(fake_v, shape=[y_dim_kv, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) + else: + desc_q = fake_q + desc_k = fake_k + desc_v = fake_v + + H = q.shape[1] + grid_1 = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + grid_2 = (triton.cdiv(k.shape[2], BLOCK_N), q.shape[0] * q.shape[1], 1) + + fake_quantize_q[grid_1]( + q, fake_q, + q.stride(0), q.stride(1), + q.stride(2), q.stride(3), + fake_q.stride(0), fake_q.stride(1), + fake_q.stride(2), fake_q.stride(3), + H, N_CTX_Q, + BLOCK_M=BLOCK_M, HEAD_DIM=HEAD_DIM_K, + use_global_sf=use_global_sf_QKV, + ) + fake_quantize_kv[grid_2]( + k, v, fake_k, fake_v, + k.stride(0), k.stride(1), + k.stride(2), k.stride(3), + fake_k.stride(0), fake_k.stride(1), + fake_k.stride(2), fake_k.stride(3), + H, N_CTX_KV, + BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, + use_global_sf=use_global_sf_QKV, + ) + + # Apply pre-hook to set block shapes on tensor descriptors + _host_descriptor_pre_hook({ + "BLOCK_M": BLOCK_M, + "BLOCK_N": BLOCK_N, + "HEAD_DIM": HEAD_DIM_K, + "desc_q": desc_q, + "desc_k": desc_k, + "desc_v": desc_v, + "desc_o": desc_o, + "desc_high_prec_o": desc_high_prec_o, + "FP8_OUTPUT": q.dtype == torch.float8_e5m2, + }) + + _attn_fwd[grid]( + sm_scale, M, + q.shape[0], q.shape[1], + desc_q, desc_k, desc_v, desc_o, desc_high_prec_o, + N_CTX_Q=N_CTX_Q, + N_CTX_KV=N_CTX_KV, + HEAD_DIM=HEAD_DIM_K, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + FP8_OUTPUT=q.dtype == torch.float8_e5m2, + STAGE=stage, + warp_specialize=warp_specialize, + IS_HOPPER=is_hopper(), + IS_QAT=IS_QAT, + fake_quant_P=fake_quant_P, + two_level_quant_P=two_level_quant_P, + use_global_sf_P=use_global_sf_P, + num_warps=4, + num_stages=2, + **extra_kern_args + ) + o_for_bwd = high_prec_o if IS_QAT and use_high_prec_o else o + + if IS_QAT: + q = fake_q + k = fake_k + v = fake_v + + ctx.save_for_backward(q, k, v, o_for_bwd, M) + ctx.sm_scale = sm_scale + ctx.HEAD_DIM = HEAD_DIM_K + ctx.causal = causal + ctx.IS_QAT = IS_QAT + ctx.use_qat_qkv_backward = use_qat_qkv_backward + ctx.smooth_k = smooth_k + ctx.two_level_quant_P = two_level_quant_P + ctx.fake_quant_P = fake_quant_P + ctx.smooth_q = smooth_q + ctx.use_global_sf_P = use_global_sf_P + ctx.warp_specialize = warp_specialize + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o_for_bwd, M = ctx.saved_tensors + do = do.contiguous() + assert do.is_contiguous() + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX_Q = q.shape[:3] + N_CTX_KV = k.shape[2] + assert k.shape[2] == v.shape[2], "k and v must have the same sequence length" + PRE_BLOCK = 128 + NUM_STAGES = 3 + NUM_WARPS = 4 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 32, 32, 32 + if not ctx.use_qat_qkv_backward: + q = ctx.q_orig + k = ctx.k_orig + v = ctx.v_orig + BLK_SLICE_FACTOR = 1 if ctx.IS_QAT else 2 # must be 1 for QAT + # NOTE: K is NOT pre-scaled here - scaling is applied AFTER qk dot product in kernels + # This improves precision by avoiding rounding errors in K before the dot product + arg_k = k + pre_grid = ((N_CTX_Q + PRE_BLOCK - 1) // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + _attn_bwd_preprocess[pre_grid]( + o_for_bwd, do, + delta, + BATCH, N_HEAD, N_CTX_Q, + BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM + ) + + q_m = None + if ctx.smooth_q: + # _, q_m = triton_group_mean(q) + q_m = q_m.repeat_interleave(q.shape[2] // q_m.shape[2], dim=2) # B,H,L,D + + if N_CTX_Q == N_CTX_KV: + # Use existing kernel for self-attention (same sequence lengths) + grid = ((N_CTX_KV + BLOCK_N1 - 1) // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, + M, delta, q_m, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + N_HEAD, N_CTX_KV, + ctx.k_mean, + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + HEAD_DIM=ctx.HEAD_DIM, + CAUSAL=ctx.causal, + IS_QAT=ctx.IS_QAT, + SMOOTH_K=ctx.smooth_k, + two_level_quant_P=ctx.two_level_quant_P, + fake_quant_P=ctx.fake_quant_P, + SMOOTH_Q=ctx.smooth_q, + use_global_sf_P=ctx.use_global_sf_P, + warp_specialize=ctx.warp_specialize, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES + ) + else: + # Use separate kernels for cross-attention (different sequence lengths) + grid_dq = ((N_CTX_Q + BLOCK_M2 - 1) // BLOCK_M2, 1, BATCH * N_HEAD) + _attn_bwd_dq_cross[grid_dq]( + q, arg_k, v, ctx.sm_scale, do, dq, M, delta, + q.stride(0), k.stride(0), q.stride(1), k.stride(1), q.stride(2), k.stride(2), q.stride(3), k.stride(3), + N_HEAD, N_CTX_Q, N_CTX_KV, + ctx.k_mean, + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, + HEAD_DIM=ctx.HEAD_DIM, + SMOOTH_K=ctx.smooth_k, + warp_specialize=ctx.warp_specialize, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + ) + grid_dkdv = ((N_CTX_KV + BLOCK_N1 - 1) // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd_dkdv_cross[grid_dkdv]( + q, arg_k, v, ctx.sm_scale, do, dk, dv, M, delta, q_m, + q.stride(0), k.stride(0), q.stride(1), k.stride(1), q.stride(2), k.stride(2), q.stride(3), k.stride(3), + N_HEAD, N_CTX_Q, N_CTX_KV, + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, + HEAD_DIM=ctx.HEAD_DIM, + IS_QAT=ctx.IS_QAT, + two_level_quant_P=ctx.two_level_quant_P, + fake_quant_P=ctx.fake_quant_P, + SMOOTH_Q=ctx.smooth_q, + use_global_sf_P=ctx.use_global_sf_P, + warp_specialize=ctx.warp_specialize, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + ) + + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None + + +attention = _attention.apply diff --git a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_attention.py b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_attention.py new file mode 100644 index 0000000000..148a40d294 --- /dev/null +++ b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/fused_attention.py @@ -0,0 +1,55 @@ +"""Compatibility shim for the legacy non-QAT Triton attention import path. + +Historically callers imported +``fastvideo_kernel.triton_kernels.fused_attention`` directly. The shared +implementation now lives in ``attn_qat_train.py`` and is parameterized by the +``IS_QAT`` flag. This module preserves the original public API for tests and +downstream users while always dispatching to the non-QAT configuration. +""" + +from __future__ import annotations + +import torch + +from .attn_qat_train import attention as _attention + + +def attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + causal: bool, + sm_scale: float, + warp_specialize: bool = True, +) -> torch.Tensor: + """Run the shared Triton attention kernel in non-QAT mode.""" + use_qat_qkv_backward = True + smooth_k = False + is_qat = False + two_level_quant_p = False + fake_quant_p = False + use_high_prec_o = False + smooth_q = False + use_global_sf_p = False + use_global_sf_qkv = False + + return _attention( + q, + k, + v, + causal, + sm_scale, + use_qat_qkv_backward, + smooth_k, + warp_specialize, + is_qat, + two_level_quant_p, + fake_quant_p, + use_high_prec_o, + smooth_q, + use_global_sf_p, + use_global_sf_qkv, + ) + + +__all__ = ["attention"] diff --git a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py new file mode 100644 index 0000000000..6fdbb36e9e --- /dev/null +++ b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/nvfp4_utils.py @@ -0,0 +1,237 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py +# and https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py + +import triton +import triton.language as tl +from triton.language.target_info import cuda_capability_geq + +MXFP_BLOCK_SIZE = tl.constexpr(16) + +@triton.jit +def _compute_quant_and_scale( + src_tensor, + valid_src_mask, + mx_tensor_dtype: tl.constexpr = tl.uint8, + use_global_sf=True, + two_level_quant_P=False, +): + BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0] + BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1] + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE + is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8 + + tl.static_assert( + is_fp4 + or mx_tensor_dtype == tl.float8e4nv + or mx_tensor_dtype == tl.float8e5, + "mx_tensor_dtype must be uint8, float8e4nv, or float8e5", + ) + + # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16 + f32_tensor = src_tensor.to(tl.float32) + abs_tensor = tl.abs(f32_tensor) + abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) # Don't consider padding tensors in scale computation + + if two_level_quant_P: + # row max from SageAttn3 paper + global_max_val = tl.max(f32_tensor, axis=1, keep_dims=True) # (BLOCK_SIZE_OUT_DIM, 1) + global_max_val = tl.maximum(global_max_val, 1e-8) + s_enc = ((6 * 448) / global_max_val).reshape([BLOCK_SIZE_OUT_DIM, 1, 1]) + s_dec = (1 / s_enc) + + abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE]) + + if use_global_sf and not two_level_quant_P: + global_max_val = tl.max(abs_tensor) + # Avoid division by zero: if all values are padding (max is 0), use a default scale + global_max_val = tl.maximum(global_max_val, 1e-8) + s_enc = (6 * 448) / global_max_val + s_dec = (1 / s_enc) + elif not two_level_quant_P and not use_global_sf: + s_dec = 1.0 + s_enc = 1.0 + + max_val = tl.max(abs_tensor, axis=2, keep_dims=True) # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1) # per block maxima + s_dec_b = max_val / 6 # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1) + s_dec_b_e4m3 = (s_dec_b * s_enc).to(tl.float8e4nv) # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1) + s_enc_b = 1 / (s_dec_b_e4m3.to(tl.float32) * s_dec) # (BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1) + + f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE]) + quant_tensor = f32_tensor * s_enc_b + + # Reshape the tensors after scaling + quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format. + quant_tensor = tl.where(valid_src_mask, quant_tensor, 0.0) + dequant_scale = s_dec_b_e4m3.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE]) + + if is_fp4 and cuda_capability_geq(10, 0): + # Convert scaled values to two f32 lanes and use PTX cvt to e2m1x2 with two f32 operands. + pairs = tl.reshape(quant_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2]) + lo_f, hi_f = tl.split(pairs) + lo_f32 = lo_f.to(tl.float32) + hi_f32 = hi_f.to(tl.float32) + + # Inline PTX: cvt.rn.satfinite.e2m1x2.f32 takes two f32 sources and produces one .b8 packed e2m1x2. + out_tensor = tl.inline_asm_elementwise( + """ + { + .reg .b8 r; + cvt.rn.satfinite.e2m1x2.f32 r, $1, $2; + mov.b32 $0, {r, r, r, r}; + } + """, + constraints="=r,f,f", + args=[hi_f32, lo_f32], + dtype=tl.uint8, + is_pure=True, + pack=1, + ) + elif is_fp4: + quant_tensor = quant_tensor.to(tl.uint32, bitcast=True) + signs = quant_tensor & 0x80000000 + exponents = (quant_tensor >> 23) & 0xFF + mantissas_orig = (quant_tensor & 0x7FFFFF) + + # For RTNE: 0.25 < x < 0.75 maps to 0.5 (denormal); exactly 0.25 maps to 0.0 + E8_BIAS = 127 + E2_BIAS = 1 + # Move implicit bit 1 at the beginning to mantissa for denormals + is_subnormal = exponents < E8_BIAS + adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False) + mantissas_pre = (0x400000 | (mantissas_orig >> 1)) + mantissas = tl.where(is_subnormal, mantissas_pre >> adjusted_exponents, mantissas_orig) + + # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0. + exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS) + + # Combine sign, exponent, and mantissa, while saturating + # Round to nearest, ties to even (RTNE): use guard/sticky and LSB to decide increment + m2bits = mantissas >> 21 + lsb_keep = (m2bits >> 1) & 0x1 + guard = m2bits & 0x1 + IS_SRC_FP32: tl.constexpr = src_tensor.dtype == tl.float32 + if IS_SRC_FP32: + bit0_dropped = (mantissas_orig & 0x1) != 0 + mask = (1 << tl.minimum(adjusted_exponents, 31)) - 1 + dropped_post = (mantissas_pre & mask) != 0 + sticky = is_subnormal & (bit0_dropped | dropped_post) + sticky |= ((mantissas & 0x1FFFFF) != 0).to(tl.uint32) + else: + sticky = ((mantissas & 0x1FFFFF) != 0).to(tl.uint32) + round_inc = guard & (sticky | lsb_keep) + e2m1_tmp = tl.minimum((((exponents << 2) | m2bits) + round_inc) >> 1, 0x7) + e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8) + + e2m1_value = tl.reshape(e2m1_value, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2]) + evens, odds = tl.split(e2m1_value) + out_tensor = evens | (odds << 4) + else: + out_tensor = quant_tensor.to(mx_tensor_dtype) + + return out_tensor, dequant_scale, s_dec + +@triton.jit +def _compute_dequant( + mx_tensor, + scale, + s_dec, + BLOCK_SIZE_OUT_DIM: tl.constexpr, + BLOCK_SIZE_QUANT_DIM: tl.constexpr, + dst_dtype: tl.constexpr, +): + tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"Block size along quantization block must be a multiple of {MXFP_BLOCK_SIZE=}") + # uint8 signifies two fp4 e2m1 values packed into a single byte + mx_tensor_dtype: tl.constexpr = mx_tensor.dtype + tl.static_assert(dst_dtype == tl.float16 or dst_dtype == tl.bfloat16 or dst_dtype == tl.float32) + tl.static_assert( + mx_tensor_dtype == tl.uint8 + or ((mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) or mx_tensor_dtype == dst_dtype), + "mx_tensor_ptr must be uint8 or float8 or dst_dtype") + tl.static_assert(scale.dtype == tl.float8e4nv, "scale must be float8e4nv") + + # Determine if we are dealing with fp8 types. + is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8 + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE + + # Upcast the scale to the destination type. + if dst_dtype == tl.bfloat16: + dst_scale = scale.to(tl.bfloat16) + else: + dst_scale = scale.to(tl.float32) + if dst_dtype == tl.float16: + dst_scale = dst_scale.to(tl.float16) + + # Now upcast the tensor. + intermediate_dtype: tl.constexpr = tl.bfloat16 if dst_dtype == tl.float32 else dst_dtype + if cuda_capability_geq(10, 0): + assert is_fp4 + packed_u32 = tl.inline_asm_elementwise( + asm=""" + { + .reg .b8 in_8; + .reg .f16x2 out; + cvt.u8.u32 in_8, $1; + cvt.rn.f16x2.e2m1x2 out, in_8; + mov.b32 $0, out; + } + """, + constraints="=r,r", + args=[mx_tensor], # tl.uint8 passed in as a 32-bit reg with value in low 8 bits + dtype=tl.uint32, + is_pure=True, + pack=1, + ) + lo_u16 = (packed_u32 & 0xFFFF).to(tl.uint16) + hi_u16 = (packed_u32 >> 16).to(tl.uint16) + lo_f16 = lo_u16.to(tl.float16, bitcast=True) + hi_f16 = hi_u16.to(tl.float16, bitcast=True) + + if intermediate_dtype == tl.float16: + x0, x1 = lo_f16, hi_f16 + else: + x0 = lo_f16.to(intermediate_dtype) + x1 = hi_f16.to(intermediate_dtype) + + dst_tensor = tl.interleave(x0, x1) + + else: + assert is_fp4 + dst_bias: tl.constexpr = 127 if intermediate_dtype == tl.bfloat16 else 15 # exponent bias + dst_0p5: tl.constexpr = 16128 if intermediate_dtype == tl.bfloat16 else 0x3800 + dst_m_bits: tl.constexpr = 7 if intermediate_dtype == tl.bfloat16 else 10 # mantissa bits + # e2m1 + em0 = mx_tensor & 0x07 + em1 = mx_tensor & 0x70 + x0 = (em0.to(tl.uint16) << (dst_m_bits - 1)) | ((mx_tensor & 0x08).to(tl.uint16) << 12) + x1 = (em1.to(tl.uint16) << (dst_m_bits - 5)) | ((mx_tensor & 0x80).to(tl.uint16) << 8) + # Three cases: + # 1) x is normal and non-zero: Correct bias + x0 = tl.where((em0 & 0x06) != 0, x0 + ((dst_bias - 1) << dst_m_bits), x0) + x1 = tl.where((em1 & 0x60) != 0, x1 + ((dst_bias - 1) << dst_m_bits), x1) + # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + x0 = tl.where(em0 == 0x01, dst_0p5 | (x0 & 0x8000), x0) + x1 = tl.where(em1 == 0x10, dst_0p5 | (x1 & 0x8000), x1) + # 3) x is zero, do nothing + dst_tensor = tl.interleave(x0, x1).to(intermediate_dtype, bitcast=True) + + dst_tensor = dst_tensor.to(dst_dtype) + + # Reshape for proper broadcasting: the scale was stored with a 16‐sized “inner” grouping. + dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE]) + dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1]) + scale = scale.reshape(dst_scale.shape) + + out_tensor = dst_tensor * dst_scale * s_dec # NVFP4 has the additional global scale factor + if dst_dtype == tl.float32: + max_fin = 3.4028234663852886e+38 + elif dst_dtype == tl.bfloat16: + max_fin = 3.3895313892515355e+38 + else: + tl.static_assert(dst_dtype == tl.float16) + max_fin = 65504 + out_tensor = tl.clamp(out_tensor, min=-max_fin, max=max_fin) + out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + out_tensor = out_tensor.to(dst_dtype) + return out_tensor diff --git a/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py new file mode 100644 index 0000000000..31d4acc796 --- /dev/null +++ b/fastvideo-kernel/python/fastvideo_kernel/triton_kernels/quant_utils.py @@ -0,0 +1,80 @@ +import triton +import triton.language as tl + +from .nvfp4_utils import _compute_quant_and_scale, _compute_dequant + +@triton.jit +def fake_quantize(src_tensor, valid_src_mask, BLOCK_SIZE_OUT_DIM: tl.constexpr, + BLOCK_SIZE_QUANT_DIM: tl.constexpr, + dst_dtype: tl.constexpr, + mx_tensor_dtype: tl.constexpr = tl.uint8, + use_global_sf: tl.constexpr = True, + two_level_quant_P: tl.constexpr = False): + high_prec_src_tensor = src_tensor + src_tensor, src_scale, src_s_dec = _compute_quant_and_scale(src_tensor=src_tensor, + valid_src_mask=valid_src_mask, + mx_tensor_dtype=mx_tensor_dtype, + use_global_sf=use_global_sf, + two_level_quant_P=two_level_quant_P) + src_tensor = _compute_dequant(mx_tensor=src_tensor, + scale=src_scale, + s_dec=src_s_dec, + BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM, + BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM, + dst_dtype=dst_dtype) + return src_tensor, high_prec_src_tensor.to(src_tensor.dtype) + +@triton.jit +def fake_quantize_q(Q, fake_Q, stride_z_q, stride_h_q, + stride_tok_q, stride_d_q, + fake_stride_z_q, fake_stride_h_q, + fake_stride_tok_q, fake_stride_d_q, + H, N_CTX_Q, + BLOCK_M: tl.constexpr, + HEAD_DIM: tl.constexpr, + use_global_sf: tl.constexpr = True): + bhid = tl.program_id(1) + adj_q = (stride_h_q * (bhid % H) + stride_z_q * (bhid // H)) + fake_adj_q = (fake_stride_h_q * (bhid % H) + fake_stride_z_q * (bhid // H)) + Q += adj_q + fake_Q += fake_adj_q + + pid = tl.program_id(0) + start_m = pid * BLOCK_M + offs_m = start_m + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, HEAD_DIM) + + q_valid = offs_m < N_CTX_Q + q = tl.load(Q + offs_m[:, None] * stride_tok_q + offs_k[None, :] * stride_d_q, mask=q_valid[:, None], other=0.0) + q, _ = fake_quantize(src_tensor=q, valid_src_mask=q_valid[:, None], BLOCK_SIZE_OUT_DIM=BLOCK_M, BLOCK_SIZE_QUANT_DIM=HEAD_DIM, dst_dtype=q.dtype, use_global_sf=use_global_sf) + tl.store(fake_Q + offs_m[:, None] * fake_stride_tok_q + offs_k[None, :] * fake_stride_d_q, q, mask=q_valid[:, None]) + +@triton.jit +def fake_quantize_kv(K, V, fake_K, fake_V, stride_z_kv, stride_h_kv, + stride_tok_kv, stride_d_kv, + fake_stride_z_kv, fake_stride_h_kv, + fake_stride_tok_kv, fake_stride_d_kv, + H, N_CTX_KV, + BLOCK_N: tl.constexpr, + HEAD_DIM: tl.constexpr, + use_global_sf: tl.constexpr = True): + bhid = tl.program_id(1) + adj_kv = (stride_h_kv * (bhid % H) + stride_z_kv * (bhid // H)) + fake_adj_kv = (fake_stride_h_kv * (bhid % H) + fake_stride_z_kv * (bhid // H)) + K += adj_kv + V += adj_kv + fake_K += fake_adj_kv + fake_V += fake_adj_kv + + pid = tl.program_id(0) + start_n = pid * BLOCK_N + offs_n = start_n + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, HEAD_DIM) + + kv_valid = offs_n < N_CTX_KV + k_block = tl.load(K + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv, mask=kv_valid[:, None], other=0.0) + v_block = tl.load(V + offs_n[:, None] * stride_tok_kv + offs_k[None, :] * stride_d_kv, mask=kv_valid[:, None], other=0.0) + k, _ = fake_quantize(src_tensor=k_block, valid_src_mask=kv_valid[:, None], BLOCK_SIZE_OUT_DIM=BLOCK_N, BLOCK_SIZE_QUANT_DIM=HEAD_DIM, dst_dtype=k_block.dtype, use_global_sf=use_global_sf) + v, _ = fake_quantize(src_tensor=v_block, valid_src_mask=kv_valid[:, None], BLOCK_SIZE_OUT_DIM=BLOCK_N, BLOCK_SIZE_QUANT_DIM=HEAD_DIM, dst_dtype=v_block.dtype, use_global_sf=use_global_sf) + tl.store(fake_K + offs_n[:, None] * fake_stride_tok_kv + offs_k[None, :] * fake_stride_d_kv, k, mask=kv_valid[:, None]) + tl.store(fake_V + offs_n[:, None] * fake_stride_tok_kv + offs_k[None, :] * fake_stride_d_kv, v, mask=kv_valid[:, None]) diff --git a/fastvideo-kernel/tests/__init__.py b/fastvideo-kernel/tests/__init__.py index e69de29bb2..e8ef18d510 100644 --- a/fastvideo-kernel/tests/__init__.py +++ b/fastvideo-kernel/tests/__init__.py @@ -0,0 +1,4 @@ +from ._bootstrap import ensure_local_kernel_sources_first + + +ensure_local_kernel_sources_first() diff --git a/fastvideo-kernel/tests/_bootstrap.py b/fastvideo-kernel/tests/_bootstrap.py new file mode 100644 index 0000000000..c37bfcede6 --- /dev/null +++ b/fastvideo-kernel/tests/_bootstrap.py @@ -0,0 +1,56 @@ +"""Test import helpers for preferring the in-tree fastvideo-kernel sources.""" + +from __future__ import annotations + +import importlib +import sys +from pathlib import Path + + +def _prepend_import_path(path: Path) -> None: + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + + +def _purge_package(name: str) -> None: + prefix = f"{name}." + for module_name in tuple(sys.modules): + if module_name == name or module_name.startswith(prefix): + sys.modules.pop(module_name, None) + + +def _module_is_from_checkout(module_file: str | None, checkout_root: Path) -> bool: + if module_file is None: + return False + + try: + module_path = Path(module_file).resolve() + except OSError: + return False + + return module_path.is_relative_to(checkout_root.resolve()) + + +def ensure_local_kernel_sources_first() -> None: + tests_root = Path(__file__).resolve().parent + kernel_root = tests_root.parent + repo_root = kernel_root.parent + kernel_python_root = kernel_root / "python" + + # Keep the in-tree kernel sources ahead of any preinstalled wheel so tests + # exercise the checkout under review. + for path in (repo_root, kernel_root, kernel_python_root): + _prepend_import_path(path) + + importlib.invalidate_caches() + + loaded_kernel = sys.modules.get("fastvideo_kernel") + if loaded_kernel is None: + return + + if not _module_is_from_checkout( + getattr(loaded_kernel, "__file__", None), + kernel_python_root, + ): + _purge_package("fastvideo_kernel") diff --git a/fastvideo-kernel/tests/conftest.py b/fastvideo-kernel/tests/conftest.py new file mode 100644 index 0000000000..e8ef18d510 --- /dev/null +++ b/fastvideo-kernel/tests/conftest.py @@ -0,0 +1,4 @@ +from ._bootstrap import ensure_local_kernel_sources_first + + +ensure_local_kernel_sources_first() diff --git a/fastvideo-kernel/tests/test_attn_qat_train.py b/fastvideo-kernel/tests/test_attn_qat_train.py new file mode 100644 index 0000000000..20783f306c --- /dev/null +++ b/fastvideo-kernel/tests/test_attn_qat_train.py @@ -0,0 +1,1503 @@ +#!/usr/bin/env python3 +""" +Simple test for attn_qat_train. +Tests forward and backward passes with and without QAT enabled. +""" + +import torch +from fastvideo_kernel.triton_kernels.attn_qat_train import _attention +from fastvideo_kernel.triton_kernels.fused_attention import attention as fused_attention +from math import sqrt + +attention = _attention.apply +DEVICE = torch.device("cuda") + + +def attn_qat_train_wrapper(q_BLHD, k_BLHD, v_BLHD, is_causal=False, sm_scale=None): + """ + Wrapper function that mimics attn_qat_train from the backend wrapper. + Converts from BLHD format to BHLD format, calls attention, then converts back. + + Args: + q_BLHD: Query tensor in (B, L, H, D) format + k_BLHD: Key tensor in (B, L, H, D) format + v_BLHD: Value tensor in (B, L, H, D) format + is_causal: Whether to apply causal masking + sm_scale: Scale factor for attention scores (if None, uses 1.0 / sqrt(D)) + + Returns: + Output tensor in (B, L, H, D) format + """ + if sm_scale is None: + sm_scale = 1.0 / sqrt(q_BLHD.shape[-1]) + + # Convert from BLHD to BHLD format + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + + # Call attention with BHLD format + o_BHLD = attention(q_BHLD, k_BHLD, v_BHLD, is_causal, sm_scale) + + # Convert back from BHLD to BLHD format + return o_BHLD.permute(0, 2, 1, 3).contiguous() + + +def fused_attn_wrapper(q_BLHD, k_BLHD, v_BLHD, is_causal=False, sm_scale=None, warp_specialize=True): + """ + Wrapper function for fused_attention from fused_attention.py. + Converts from BLHD format to BHLD format, calls attention, then converts back. + + Args: + q_BLHD: Query tensor in (B, L, H, D) format + k_BLHD: Key tensor in (B, L, H, D) format + v_BLHD: Value tensor in (B, L, H, D) format + is_causal: Whether to apply causal masking + sm_scale: Scale factor for attention scores (if None, uses 1.0 / sqrt(D)) + warp_specialize: Whether to use warp specialization (default: True) + + Returns: + Output tensor in (B, L, H, D) format + """ + if sm_scale is None: + sm_scale = 1.0 / sqrt(q_BLHD.shape[-1]) + + # Convert from BLHD to BHLD format + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + + # Call fused attention with BHLD format + # Note: fused_attention expects inputs in BHLD format and only supports same sequence lengths + o_BHLD = fused_attention(q_BHLD, k_BHLD, v_BHLD, is_causal, sm_scale, warp_specialize) + + # Convert back from BHLD to BLHD format + return o_BHLD.permute(0, 2, 1, 3).contiguous() + + +def cosine_similarity(tensor1, tensor2): + """ + Compute cosine similarity between two tensors. + + Args: + tensor1: First tensor + tensor2: Second tensor (same shape as tensor1) + + Returns: + Cosine similarity value (scalar) + """ + # Flatten tensors for computation + t1_flat = tensor1.flatten().float() + t2_flat = tensor2.flatten().float() + + # Compute cosine similarity: (A · B) / (||A|| * ||B||) + dot_product = torch.dot(t1_flat, t2_flat) + norm1 = torch.norm(t1_flat) + norm2 = torch.norm(t2_flat) + + # Avoid division by zero + if norm1 == 0 or norm2 == 0: + return 0.0 + + cos_sim = dot_product / (norm1 * norm2) + return cos_sim.item() + + +def naive_attention(q, k, v, causal, sm_scale): + """ + Naive PyTorch implementation of attention for comparison. + + Args: + q: Query tensor of shape (Z, H, N_CTX_Q, HEAD_DIM) + k: Key tensor of shape (Z, H, N_CTX_KV, HEAD_DIM) + v: Value tensor of shape (Z, H, N_CTX_KV, HEAD_DIM) + causal: Whether to apply causal masking (only meaningful when N_CTX_Q == N_CTX_KV) + sm_scale: Scale factor for attention scores + + Returns: + Output tensor of shape (Z, H, N_CTX_Q, HEAD_DIM) + """ + # Compute attention scores: QK^T + scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale + + # Apply causal mask if needed + # Note: Causal masking is only meaningful when Q and KV have the same sequence length + if causal: + N_CTX_Q = q.shape[-2] + N_CTX_KV = k.shape[-2] + if N_CTX_Q == N_CTX_KV: + # Standard causal masking: query i can only attend to keys 0 to i + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_KV, device=scores.device, dtype=scores.dtype)) + scores = scores.masked_fill(mask == 0, float("-inf")) + else: + # For different sequence lengths, causal masking doesn't apply in the same way + # In cross-attention, typically no causal masking is applied + pass + + # Apply softmax + attn_weights = torch.softmax(scores, dim=-1) + + # Apply attention weights to values + out = torch.matmul(attn_weights, v) + + return out + +def test_qat_attention_forward(): + """Test forward pass with QAT enabled vs disabled.""" + torch.manual_seed(42) + + # Test parameters + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = False + + # Create input tensors in BLHD format (B, L, H, D) = (Z, N_CTX, H, HEAD_DIM) + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + + # Test with QAT (using wrapper that does permute/contiguous) + out_qat_BLHD = attn_qat_train_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Convert to BHLD for naive attention comparison + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + out_naive_BHLD = naive_attention(q_BHLD.clone(), k_BHLD.clone(), v_BHLD.clone(), causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + # Check that outputs have correct shape + assert out_qat_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_naive_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + + # Check that outputs are finite + assert torch.isfinite(out_qat_BLHD).all() + assert torch.isfinite(out_naive_BLHD).all() + + # Compare QAT and naive outputs + max_diff = (out_qat_BLHD - out_naive_BLHD).abs().max() + mean_diff = (out_qat_BLHD - out_naive_BLHD).abs().mean() + cos_sim = cosine_similarity(out_qat_BLHD, out_naive_BLHD) + print(f" QAT vs Naive Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM} - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # QAT output should be reasonably close to naive (within tolerance due to quantization) + # Using a reasonable tolerance for float16 and quantization effects + # assert max_diff < 1.0, f"QAT and naive outputs should be reasonably close, got max_diff={max_diff.item():.6f}" + + print("✓ Forward pass test passed.") + + +def test_qat_attention_backward(): + """Test backward pass with QAT enabled.""" + torch.manual_seed(42) + + # Test parameters + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = True + + # Create input tensors in BLHD format for QAT + q_qat_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + k_qat_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + v_qat_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + + # Create input tensors for naive (same values, convert to BHLD) + q_naive_BHLD = q_qat_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + k_naive_BHLD = k_qat_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + v_naive_BHLD = v_qat_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + + # Forward pass with QAT (using wrapper) + out_qat_BLHD = attn_qat_train_wrapper(q_qat_BLHD, k_qat_BLHD, v_qat_BLHD, causal, sm_scale) + + # Forward pass with naive + out_naive_BHLD = naive_attention(q_naive_BHLD, k_naive_BHLD, v_naive_BHLD, causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + # Create dummy gradient (same for both, in BLHD format) + # Ensure gradient is contiguous as required by the backward function + dout = torch.randn_like(out_qat_BLHD).contiguous() + + # Backward pass for QAT + out_qat_BLHD.backward(dout) + + # Backward pass for naive + out_naive_BLHD.backward(dout) + + # Check that gradients are computed + assert q_qat_BLHD.grad is not None + assert k_qat_BLHD.grad is not None + assert v_qat_BLHD.grad is not None + assert q_naive_BHLD.grad is not None + assert k_naive_BHLD.grad is not None + assert v_naive_BHLD.grad is not None + + # Check that gradients have correct shape + assert q_qat_BLHD.grad.shape == q_qat_BLHD.shape + assert k_qat_BLHD.grad.shape == k_qat_BLHD.shape + assert v_qat_BLHD.grad.shape == v_qat_BLHD.shape + + # Check that gradients are finite + assert torch.isfinite(q_qat_BLHD.grad).all() + assert torch.isfinite(k_qat_BLHD.grad).all() + assert torch.isfinite(v_qat_BLHD.grad).all() + + # Convert naive gradients to BLHD format for comparison + q_naive_grad_BLHD = q_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + k_naive_grad_BLHD = k_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + v_naive_grad_BLHD = v_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + + # Compare gradients with naive + dq_diff = (q_qat_BLHD.grad - q_naive_grad_BLHD).abs() + dk_diff = (k_qat_BLHD.grad - k_naive_grad_BLHD).abs() + dv_diff = (v_qat_BLHD.grad - v_naive_grad_BLHD).abs() + + dq_cos_sim = cosine_similarity(q_qat_BLHD.grad, q_naive_grad_BLHD) + dk_cos_sim = cosine_similarity(k_qat_BLHD.grad, k_naive_grad_BLHD) + dv_cos_sim = cosine_similarity(v_qat_BLHD.grad, v_naive_grad_BLHD) + + print(f" dQ - Max diff: {dq_diff.max().item():.6f}, Mean diff: {dq_diff.mean().item():.6f}, Cosine sim: {dq_cos_sim:.6f}") + print(f" dK - Max diff: {dk_diff.max().item():.6f}, Mean diff: {dk_diff.mean().item():.6f}, Cosine sim: {dk_cos_sim:.6f}") + print(f" dV - Max diff: {dv_diff.max().item():.6f}, Mean diff: {dv_diff.mean().item():.6f}, Cosine sim: {dv_cos_sim:.6f}") + + # Gradients should be reasonably close (within tolerance due to quantization) + # assert dq_diff.max() < 2.0, f"dQ gradients should be reasonably close, got max_diff={dq_diff.max().item():.6f}" + # assert dk_diff.max() < 2.0, f"dK gradients should be reasonably close, got max_diff={dk_diff.max().item():.6f}" + # assert dv_diff.max() < 2.0, f"dV gradients should be reasonably close, got max_diff={dv_diff.max().item():.6f}" + + print("✓ Backward pass test passed.") + + +def test_qat_attention_different_shapes(): + """Test QAT attention with different input shapes.""" + torch.manual_seed(42) + + test_configs = [ + (2, 4, 128, 64), # Medium + (1, 8, 256, 128), # Large head dim + ] + + dtype = torch.bfloat16 + causal = True + + for Z, H, N_CTX, HEAD_DIM in test_configs: + # Create input tensors in BLHD format + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + + sm_scale = 1.0 / sqrt(HEAD_DIM) + out_qat_BLHD = attn_qat_train_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Convert to BHLD for naive attention + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + out_naive_BHLD = naive_attention(q_BHLD.clone(), k_BHLD.clone(), v_BHLD.clone(), causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + assert out_qat_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_naive_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert torch.isfinite(out_qat_BLHD).all() + assert torch.isfinite(out_naive_BLHD).all() + + # Compare outputs + max_diff = (out_qat_BLHD - out_naive_BLHD).abs().max() + mean_diff = (out_qat_BLHD - out_naive_BLHD).abs().mean() + cos_sim = cosine_similarity(out_qat_BLHD, out_naive_BLHD) + print(f" (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM}) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Outputs should be reasonably close + # assert max_diff < 1.0, f"QAT and naive outputs should be reasonably close for shape (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM}), got max_diff={max_diff.item():.6f}" + + print(f"✓ Shape test passed for (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM})") + + +def test_qat_attention_non_causal(): + """Test QAT attention with non-causal masking.""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = False + + # Create input tensors in BLHD format + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + + out_qat_BLHD = attn_qat_train_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Convert to BHLD for naive attention + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + out_naive_BHLD = naive_attention(q_BHLD.clone(), k_BHLD.clone(), v_BHLD.clone(), causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + assert out_qat_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_naive_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert torch.isfinite(out_qat_BLHD).all() + assert torch.isfinite(out_naive_BLHD).all() + + # Compare outputs + max_diff = (out_qat_BLHD - out_naive_BLHD).abs().max() + mean_diff = (out_qat_BLHD - out_naive_BLHD).abs().mean() + cos_sim = cosine_similarity(out_qat_BLHD, out_naive_BLHD) + print(f" QAT vs Naive (non-causal) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Outputs should be reasonably close + # assert max_diff < 1.0, f"QAT and naive outputs should be reasonably close for non-causal, got max_diff={max_diff.item():.6f}" + + print("✓ Non-causal test passed.") + + +def test_qat_attention_causal(): + """Test QAT attention with causal masking.""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = True + + # Create input tensors in BLHD format + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + + out_qat_BLHD = attn_qat_train_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Convert to BHLD for naive attention + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + out_naive_BHLD = naive_attention(q_BHLD.clone(), k_BHLD.clone(), v_BHLD.clone(), causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + assert out_qat_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_naive_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert torch.isfinite(out_qat_BLHD).all() + assert torch.isfinite(out_naive_BLHD).all() + + # Compare outputs + max_diff = (out_qat_BLHD - out_naive_BLHD).abs().max() + mean_diff = (out_qat_BLHD - out_naive_BLHD).abs().mean() + cos_sim = cosine_similarity(out_qat_BLHD, out_naive_BLHD) + print(f" QAT vs Naive (causal) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Outputs should be reasonably close + # assert max_diff < 1.0, f"QAT and naive outputs should be reasonably close for causal, got max_diff={max_diff.item():.6f}" + + print("✓ Causal test passed.") + + +def test_qat_attention_causal_backward(): + """Test backward pass of QAT attention with causal masking vs naive.""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = True + + # Create input tensors in BLHD format for QAT + q_qat_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + k_qat_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + v_qat_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + + # Create input tensors for naive (same values, convert to BHLD) + q_naive_BHLD = q_qat_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + k_naive_BHLD = k_qat_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + v_naive_BHLD = v_qat_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + + # Forward pass with QAT + out_qat_BLHD = attn_qat_train_wrapper(q_qat_BLHD, k_qat_BLHD, v_qat_BLHD, causal, sm_scale) + + # Forward pass with naive + out_naive_BHLD = naive_attention(q_naive_BHLD, k_naive_BHLD, v_naive_BHLD, causal, sm_scale) + out_naive_BHLD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + # Create dummy gradient + dout = torch.randn_like(out_qat_BLHD).contiguous() + + # Backward pass for QAT + out_qat_BLHD.backward(dout) + + # Backward pass for naive + out_naive_BHLD.backward(dout) + + # Check that gradients are computed + assert q_qat_BLHD.grad is not None + assert k_qat_BLHD.grad is not None + assert v_qat_BLHD.grad is not None + + # Convert naive gradients to BLHD format for comparison + q_naive_grad_BLHD = q_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + k_naive_grad_BLHD = k_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + v_naive_grad_BLHD = v_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + + # Compare gradients + dq_diff = (q_qat_BLHD.grad - q_naive_grad_BLHD).abs() + dk_diff = (k_qat_BLHD.grad - k_naive_grad_BLHD).abs() + dv_diff = (v_qat_BLHD.grad - v_naive_grad_BLHD).abs() + + dq_cos_sim = cosine_similarity(q_qat_BLHD.grad, q_naive_grad_BLHD) + dk_cos_sim = cosine_similarity(k_qat_BLHD.grad, k_naive_grad_BLHD) + dv_cos_sim = cosine_similarity(v_qat_BLHD.grad, v_naive_grad_BLHD) + + print(f" dQ - Max diff: {dq_diff.max().item():.6f}, Mean diff: {dq_diff.mean().item():.6f}, Cosine sim: {dq_cos_sim:.6f}") + print(f" dK - Max diff: {dk_diff.max().item():.6f}, Mean diff: {dk_diff.mean().item():.6f}, Cosine sim: {dk_cos_sim:.6f}") + print(f" dV - Max diff: {dv_diff.max().item():.6f}, Mean diff: {dv_diff.mean().item():.6f}, Cosine sim: {dv_cos_sim:.6f}") + + print("✓ Causal backward test passed.") + + +def test_qat_attention_different_seq_lengths(): + """Test QAT attention with different sequence lengths for Q and KV (cross-attention).""" + torch.manual_seed(42) + + test_configs = [ + (2, 4, 64, 128, 64, False), # Q shorter than KV, non-causal + (2, 4, 128, 64, 64, False), # Q longer than KV, non-causal + (1, 8, 256, 128, 64, False), # Q longer than KV, larger dimensions + ] + + dtype = torch.bfloat16 + + for Z, H, N_CTX_Q, N_CTX_KV, HEAD_DIM, causal in test_configs: + sm_scale = 1.0 / sqrt(HEAD_DIM) + + # Create input tensors in BLHD format with different sequence lengths + q_BLHD = torch.randn((Z, N_CTX_Q, H, HEAD_DIM), dtype=dtype, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX_KV, H, HEAD_DIM), dtype=dtype, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX_KV, H, HEAD_DIM), dtype=dtype, device=DEVICE) + + # Test with QAT (using wrapper that does permute/contiguous) + out_qat_BLHD = attn_qat_train_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Convert to BHLD for naive attention comparison + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + out_naive_BHLD = naive_attention(q_BHLD.clone(), k_BHLD.clone(), v_BHLD.clone(), causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + # Check that outputs have correct shape (should match Q sequence length) + assert out_qat_BLHD.shape == (Z, N_CTX_Q, H, HEAD_DIM) + assert out_naive_BLHD.shape == (Z, N_CTX_Q, H, HEAD_DIM) + + # Check that outputs are finite + assert torch.isfinite(out_qat_BLHD).all() + assert torch.isfinite(out_naive_BLHD).all() + + # Compare QAT and naive outputs + max_diff = (out_qat_BLHD - out_naive_BLHD).abs().max() + mean_diff = (out_qat_BLHD - out_naive_BLHD).abs().mean() + cos_sim = cosine_similarity(out_qat_BLHD, out_naive_BLHD) + print(f" (N_CTX_Q={N_CTX_Q}, N_CTX_KV={N_CTX_KV}, causal={causal}) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Outputs should be reasonably close + # assert max_diff < 1.0, f"QAT and naive outputs should be reasonably close for (N_CTX_Q={N_CTX_Q}, N_CTX_KV={N_CTX_KV}), got max_diff={max_diff.item():.6f}" + + print(f"✓ Different sequence lengths test passed for (N_CTX_Q={N_CTX_Q}, N_CTX_KV={N_CTX_KV})") + + +def test_qat_attention_different_seq_lengths_backward(): + """Test backward pass of QAT attention with different sequence lengths for Q and KV (cross-attention).""" + torch.manual_seed(42) + + test_configs = [ + (2, 4, 64, 128, 64, False), # Q shorter than KV, non-causal + (2, 4, 128, 64, 64, False), # Q longer than KV, non-causal + ] + + dtype = torch.bfloat16 + + for Z, H, N_CTX_Q, N_CTX_KV, HEAD_DIM, causal in test_configs: + sm_scale = 1.0 / sqrt(HEAD_DIM) + + # Create input tensors in BLHD format with different sequence lengths + q_qat_BLHD = torch.randn((Z, N_CTX_Q, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + k_qat_BLHD = torch.randn((Z, N_CTX_KV, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + v_qat_BLHD = torch.randn((Z, N_CTX_KV, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + + # Create input tensors for naive (same values, convert to BHLD) + q_naive_BHLD = q_qat_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + k_naive_BHLD = k_qat_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + v_naive_BHLD = v_qat_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + + # Forward pass with QAT + out_qat_BLHD = attn_qat_train_wrapper(q_qat_BLHD, k_qat_BLHD, v_qat_BLHD, causal, sm_scale) + + # Forward pass with naive + out_naive_BHLD = naive_attention(q_naive_BHLD, k_naive_BHLD, v_naive_BHLD, causal, sm_scale) + out_naive_BHLD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + # Create dummy gradient + dout = torch.randn_like(out_qat_BLHD).contiguous() + + # Backward pass for QAT + out_qat_BLHD.backward(dout) + + # Backward pass for naive + out_naive_BHLD.backward(dout) + + # Check that gradients are computed + assert q_qat_BLHD.grad is not None + assert k_qat_BLHD.grad is not None + assert v_qat_BLHD.grad is not None + assert q_naive_BHLD.grad is not None + assert k_naive_BHLD.grad is not None + assert v_naive_BHLD.grad is not None + + # Check that gradients have correct shape + assert q_qat_BLHD.grad.shape == q_qat_BLHD.shape + assert k_qat_BLHD.grad.shape == k_qat_BLHD.shape + assert v_qat_BLHD.grad.shape == v_qat_BLHD.shape + + # Check that gradients are finite + assert torch.isfinite(q_qat_BLHD.grad).all() + assert torch.isfinite(k_qat_BLHD.grad).all() + assert torch.isfinite(v_qat_BLHD.grad).all() + + # Convert naive gradients to BLHD format for comparison + q_naive_grad_BLHD = q_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + k_naive_grad_BLHD = k_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + v_naive_grad_BLHD = v_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + + # Compare gradients + dq_diff = (q_qat_BLHD.grad - q_naive_grad_BLHD).abs() + dk_diff = (k_qat_BLHD.grad - k_naive_grad_BLHD).abs() + dv_diff = (v_qat_BLHD.grad - v_naive_grad_BLHD).abs() + + dq_cos_sim = cosine_similarity(q_qat_BLHD.grad, q_naive_grad_BLHD) + dk_cos_sim = cosine_similarity(k_qat_BLHD.grad, k_naive_grad_BLHD) + dv_cos_sim = cosine_similarity(v_qat_BLHD.grad, v_naive_grad_BLHD) + + print(f" (N_CTX_Q={N_CTX_Q}, N_CTX_KV={N_CTX_KV}) - dQ: Max diff: {dq_diff.max().item():.6f}, Mean diff: {dq_diff.mean().item():.6f}, Cosine sim: {dq_cos_sim:.6f}") + print(f" (N_CTX_Q={N_CTX_Q}, N_CTX_KV={N_CTX_KV}) - dK: Max diff: {dk_diff.max().item():.6f}, Mean diff: {dk_diff.mean().item():.6f}, Cosine sim: {dk_cos_sim:.6f}") + print(f" (N_CTX_Q={N_CTX_Q}, N_CTX_KV={N_CTX_KV}) - dV: Max diff: {dv_diff.max().item():.6f}, Mean diff: {dv_diff.mean().item():.6f}, Cosine sim: {dv_cos_sim:.6f}") + + # Gradients should be reasonably close (within tolerance due to quantization) + # assert dq_diff.max() < 2.0, f"dQ gradients should be reasonably close, got max_diff={dq_diff.max().item():.6f}" + # assert dk_diff.max() < 2.0, f"dK gradients should be reasonably close, got max_diff={dk_diff.max().item():.6f}" + # assert dv_diff.max() < 2.0, f"dV gradients should be reasonably close, got max_diff={dv_diff.max().item():.6f}" + + print(f"✓ Cross-attention backward test passed for (N_CTX_Q={N_CTX_Q}, N_CTX_KV={N_CTX_KV})") + + +def test_fused_attention_forward(): + """Test forward pass of fused attention vs naive PyTorch implementation.""" + torch.manual_seed(42) + + # Test parameters + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + dtype = torch.bfloat16 # fused_attention uses float16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = False + + # Create input tensors in BLHD format (B, L, H, D) = (Z, N_CTX, H, HEAD_DIM) + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + + # Test with fused attention (using wrapper that does permute/contiguous) + out_fused_BLHD = fused_attn_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Convert to BHLD for naive attention comparison + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + out_naive_BHLD = naive_attention(q_BHLD.clone(), k_BHLD.clone(), v_BHLD.clone(), causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + # Check that outputs have correct shape + assert out_fused_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_naive_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + + # Check that outputs are finite + assert torch.isfinite(out_fused_BLHD).all() + assert torch.isfinite(out_naive_BLHD).all() + + # Compare fused and naive outputs + max_diff = (out_fused_BLHD - out_naive_BLHD).abs().max() + mean_diff = (out_fused_BLHD - out_naive_BLHD).abs().mean() + cos_sim = cosine_similarity(out_fused_BLHD, out_naive_BLHD) + print(f" Fused vs Naive - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Fused attention should be reasonably close to naive (within tolerance for float16) + # assert max_diff < 1e-1, f"Fused and naive outputs should be reasonably close, got max_diff={max_diff.item():.6f}" + + print("✓ Fused attention forward pass test passed.") + + +def test_fused_attention_backward(): + """Test backward pass of fused attention vs naive PyTorch implementation.""" + torch.manual_seed(42) + + # Test parameters + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = True + + # Create input tensors in BLHD format for fused attention + q_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + k_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + v_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + + # Create input tensors for naive (same values, convert to BHLD) + q_naive_BHLD = q_fused_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + k_naive_BHLD = k_fused_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + v_naive_BHLD = v_fused_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + + # Forward pass with fused attention (using wrapper) + out_fused_BLHD = fused_attn_wrapper(q_fused_BLHD, k_fused_BLHD, v_fused_BLHD, causal, sm_scale) + + # Forward pass with naive + out_naive_BHLD = naive_attention(q_naive_BHLD, k_naive_BHLD, v_naive_BHLD, causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + # Create dummy gradient (same for both, in BLHD format) + dout = torch.randn_like(out_fused_BLHD).contiguous() + + # Backward pass for fused attention + out_fused_BLHD.backward(dout) + + # Backward pass for naive + out_naive_BLHD.backward(dout) + + # Check that gradients are computed + assert q_fused_BLHD.grad is not None + assert k_fused_BLHD.grad is not None + assert v_fused_BLHD.grad is not None + assert q_naive_BHLD.grad is not None + assert k_naive_BHLD.grad is not None + assert v_naive_BHLD.grad is not None + + # Check that gradients have correct shape + assert q_fused_BLHD.grad.shape == q_fused_BLHD.shape + assert k_fused_BLHD.grad.shape == k_fused_BLHD.shape + assert v_fused_BLHD.grad.shape == v_fused_BLHD.shape + + # Check that gradients are finite + assert torch.isfinite(q_fused_BLHD.grad).all() + assert torch.isfinite(k_fused_BLHD.grad).all() + assert torch.isfinite(v_fused_BLHD.grad).all() + + # Convert naive gradients to BLHD format for comparison + q_naive_grad_BLHD = q_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + k_naive_grad_BLHD = k_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + v_naive_grad_BLHD = v_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + + # Compare gradients with naive + dq_diff = (q_fused_BLHD.grad - q_naive_grad_BLHD).abs() + dk_diff = (k_fused_BLHD.grad - k_naive_grad_BLHD).abs() + dv_diff = (v_fused_BLHD.grad - v_naive_grad_BLHD).abs() + + dq_cos_sim = cosine_similarity(q_fused_BLHD.grad, q_naive_grad_BLHD) + dk_cos_sim = cosine_similarity(k_fused_BLHD.grad, k_naive_grad_BLHD) + dv_cos_sim = cosine_similarity(v_fused_BLHD.grad, v_naive_grad_BLHD) + + print(f" dQ - Max diff: {dq_diff.max().item():.6f}, Mean diff: {dq_diff.mean().item():.6f}, Cosine sim: {dq_cos_sim:.6f}") + print(f" dK - Max diff: {dk_diff.max().item():.6f}, Mean diff: {dk_diff.mean().item():.6f}, Cosine sim: {dk_cos_sim:.6f}") + print(f" dV - Max diff: {dv_diff.max().item():.6f}, Mean diff: {dv_diff.mean().item():.6f}, Cosine sim: {dv_cos_sim:.6f}") + + # Gradients should be reasonably close (within tolerance for float16) + # assert dq_diff.max() < 1e-1, f"dQ gradients should be reasonably close, got max_diff={dq_diff.max().item():.6f}" + # assert dk_diff.max() < 1e-1, f"dK gradients should be reasonably close, got max_diff={dk_diff.max().item():.6f}" + # assert dv_diff.max() < 1e-1, f"dV gradients should be reasonably close, got max_diff={dv_diff.max().item():.6f}" + + print("✓ Fused attention backward pass test passed.") + + +def test_fused_attention_different_shapes(): + """Test fused attention with different input shapes.""" + torch.manual_seed(42) + + test_configs = [ + (1, 2, 64, 32), # Small + (2, 4, 128, 64), # Medium + (1, 8, 256, 128), # Large head dim + ] + + dtype = torch.bfloat16 + causal = True + + for Z, H, N_CTX, HEAD_DIM in test_configs: + # Create input tensors in BLHD format + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + + sm_scale = 1.0 / sqrt(HEAD_DIM) + out_fused_BLHD = fused_attn_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Convert to BHLD for naive attention + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + out_naive_BHLD = naive_attention(q_BHLD.clone(), k_BHLD.clone(), v_BHLD.clone(), causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + assert out_fused_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_naive_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert torch.isfinite(out_fused_BLHD).all() + assert torch.isfinite(out_naive_BLHD).all() + + # Compare outputs + max_diff = (out_fused_BLHD - out_naive_BLHD).abs().max() + mean_diff = (out_fused_BLHD - out_naive_BLHD).abs().mean() + cos_sim = cosine_similarity(out_fused_BLHD, out_naive_BLHD) + print(f" (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM}) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Outputs should be reasonably close + # assert max_diff < 1e-1, f"Fused and naive outputs should be reasonably close for shape (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM}), got max_diff={max_diff.item():.6f}" + + print(f"✓ Fused attention shape test passed for (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM})") + + +def test_fused_attention_non_causal(): + """Test fused attention with non-causal masking.""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = False + + # Create input tensors in BLHD format + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + + out_fused_BLHD = fused_attn_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Convert to BHLD for naive attention + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + out_naive_BHLD = naive_attention(q_BHLD.clone(), k_BHLD.clone(), v_BHLD.clone(), causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + assert out_fused_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_naive_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert torch.isfinite(out_fused_BLHD).all() + assert torch.isfinite(out_naive_BLHD).all() + + # Compare outputs + max_diff = (out_fused_BLHD - out_naive_BLHD).abs().max() + mean_diff = (out_fused_BLHD - out_naive_BLHD).abs().mean() + cos_sim = cosine_similarity(out_fused_BLHD, out_naive_BLHD) + print(f" Fused vs Naive (non-causal) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Outputs should be reasonably close + # assert max_diff < 1e-1, f"Fused and naive outputs should be reasonably close for non-causal, got max_diff={max_diff.item():.6f}" + + print("✓ Fused attention non-causal test passed.") + + +def test_fused_attention_causal(): + """Test fused attention with causal masking.""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = True + + # Create input tensors in BLHD format + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + + out_fused_BLHD = fused_attn_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Convert to BHLD for naive attention + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + out_naive_BHLD = naive_attention(q_BHLD.clone(), k_BHLD.clone(), v_BHLD.clone(), causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + assert out_fused_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_naive_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert torch.isfinite(out_fused_BLHD).all() + assert torch.isfinite(out_naive_BLHD).all() + + # Compare outputs + max_diff = (out_fused_BLHD - out_naive_BLHD).abs().max() + mean_diff = (out_fused_BLHD - out_naive_BLHD).abs().mean() + cos_sim = cosine_similarity(out_fused_BLHD, out_naive_BLHD) + print(f" Fused vs Naive (causal) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Outputs should be reasonably close + # assert max_diff < 1e-1, f"Fused and naive outputs should be reasonably close for causal, got max_diff={max_diff.item():.6f}" + + print("✓ Fused attention causal test passed.") + + +def test_fused_attention_causal_backward(): + """Test backward pass of fused attention with causal masking vs naive.""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = True + + # Create input tensors in BLHD format for fused attention + q_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + k_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + v_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + + # Create input tensors for naive (same values, convert to BHLD) + q_naive_BHLD = q_fused_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + k_naive_BHLD = k_fused_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + v_naive_BHLD = v_fused_BLHD.clone().permute(0, 2, 1, 3).contiguous().detach().requires_grad_(True) + + # Forward pass with fused attention + out_fused_BLHD = fused_attn_wrapper(q_fused_BLHD, k_fused_BLHD, v_fused_BLHD, causal, sm_scale) + + # Forward pass with naive + out_naive_BHLD = naive_attention(q_naive_BHLD, k_naive_BHLD, v_naive_BHLD, causal, sm_scale) + out_naive_BHLD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + + # Create dummy gradient + dout = torch.randn_like(out_fused_BLHD).contiguous() + + # Backward pass for fused attention + out_fused_BLHD.backward(dout) + + # Backward pass for naive + out_naive_BHLD.backward(dout) + + # Check that gradients are computed + assert q_fused_BLHD.grad is not None + assert k_fused_BLHD.grad is not None + assert v_fused_BLHD.grad is not None + + # Convert naive gradients to BLHD format for comparison + q_naive_grad_BLHD = q_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + k_naive_grad_BLHD = k_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + v_naive_grad_BLHD = v_naive_BHLD.grad.permute(0, 2, 1, 3).contiguous() + + # Compare gradients + dq_diff = (q_fused_BLHD.grad - q_naive_grad_BLHD).abs() + dk_diff = (k_fused_BLHD.grad - k_naive_grad_BLHD).abs() + dv_diff = (v_fused_BLHD.grad - v_naive_grad_BLHD).abs() + + dq_cos_sim = cosine_similarity(q_fused_BLHD.grad, q_naive_grad_BLHD) + dk_cos_sim = cosine_similarity(k_fused_BLHD.grad, k_naive_grad_BLHD) + dv_cos_sim = cosine_similarity(v_fused_BLHD.grad, v_naive_grad_BLHD) + + print(f" dQ - Max diff: {dq_diff.max().item():.6f}, Mean diff: {dq_diff.mean().item():.6f}, Cosine sim: {dq_cos_sim:.6f}") + print(f" dK - Max diff: {dk_diff.max().item():.6f}, Mean diff: {dk_diff.mean().item():.6f}, Cosine sim: {dk_cos_sim:.6f}") + print(f" dV - Max diff: {dv_diff.max().item():.6f}, Mean diff: {dv_diff.mean().item():.6f}, Cosine sim: {dv_cos_sim:.6f}") + + print("✓ Fused attention causal backward test passed.") + + +def test_fused_vs_qat_attention_forward(): + """Test forward pass comparing fused attention vs QAT attention.""" + torch.manual_seed(42) + + # Test parameters + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = False + + # Create input tensors in BLHD format + # Use float16 for both (QAT can handle float16, though it typically uses float16) + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + + # Test with fused attention + out_fused_BLHD = fused_attn_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Test with QAT attention (convert to float16 for QAT) + q_qat_BLHD = q_BLHD.clone().to(torch.bfloat16) + k_qat_BLHD = k_BLHD.clone().to(torch.bfloat16) + v_qat_BLHD = v_BLHD.clone().to(torch.bfloat16) + out_qat_BLHD = attn_qat_train_wrapper(q_qat_BLHD, k_qat_BLHD, v_qat_BLHD, causal, sm_scale) + + # Convert QAT output back to float16 for comparison + out_qat_BLHD = out_qat_BLHD.to(torch.bfloat16) + + # Check that outputs have correct shape + assert out_fused_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_qat_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + + # Check that outputs are finite + assert torch.isfinite(out_fused_BLHD).all() + assert torch.isfinite(out_qat_BLHD).all() + + # Compare fused and QAT outputs + max_diff = (out_fused_BLHD - out_qat_BLHD).abs().max() + mean_diff = (out_fused_BLHD - out_qat_BLHD).abs().mean() + cos_sim = cosine_similarity(out_fused_BLHD, out_qat_BLHD) + print(f" Fused vs QAT - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # They may differ due to quantization in QAT and different implementations + # assert max_diff < 1.0, f"Fused and QAT outputs should be reasonably close, got max_diff={max_diff.item():.6f}" + + print("✓ Fused vs QAT forward pass test passed.") + + +def test_fused_vs_qat_attention_backward(): + """Test backward pass comparing fused attention vs QAT attention.""" + torch.manual_seed(42) + + # Test parameters + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = True + + # Create input tensors in BLHD format for fused attention (float16) + q_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE, requires_grad=True) + k_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE, requires_grad=True) + v_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE, requires_grad=True) + + # Create input tensors for QAT (float16, same values) + q_qat_BLHD = q_fused_BLHD.clone().to(torch.bfloat16).detach().requires_grad_(True) + k_qat_BLHD = k_fused_BLHD.clone().to(torch.bfloat16).detach().requires_grad_(True) + v_qat_BLHD = v_fused_BLHD.clone().to(torch.bfloat16).detach().requires_grad_(True) + + # Forward pass with fused attention + out_fused_BLHD = fused_attn_wrapper(q_fused_BLHD, k_fused_BLHD, v_fused_BLHD, causal, sm_scale) + + # Forward pass with QAT + out_qat_BLHD = attn_qat_train_wrapper(q_qat_BLHD, k_qat_BLHD, v_qat_BLHD, causal, sm_scale) + + # Create dummy gradient (same for both, in BLHD format) + dout = torch.randn_like(out_fused_BLHD).contiguous() + dout_qat = dout.to(torch.bfloat16) + + # Backward pass for fused attention + out_fused_BLHD.backward(dout) + + # Backward pass for QAT + out_qat_BLHD.backward(dout_qat) + + # Check that gradients are computed + assert q_fused_BLHD.grad is not None + assert k_fused_BLHD.grad is not None + assert v_fused_BLHD.grad is not None + assert q_qat_BLHD.grad is not None + assert k_qat_BLHD.grad is not None + assert v_qat_BLHD.grad is not None + + # Check that gradients have correct shape + assert q_fused_BLHD.grad.shape == q_fused_BLHD.shape + assert k_fused_BLHD.grad.shape == k_fused_BLHD.shape + assert v_fused_BLHD.grad.shape == v_fused_BLHD.shape + + # Check that gradients are finite + assert torch.isfinite(q_fused_BLHD.grad).all() + assert torch.isfinite(k_fused_BLHD.grad).all() + assert torch.isfinite(v_fused_BLHD.grad).all() + + # Convert QAT gradients to float16 for comparison + q_qat_grad_f16 = q_qat_BLHD.grad.to(torch.bfloat16) + k_qat_grad_f16 = k_qat_BLHD.grad.to(torch.bfloat16) + v_qat_grad_f16 = v_qat_BLHD.grad.to(torch.bfloat16) + + # Compare gradients + dq_diff = (q_fused_BLHD.grad - q_qat_grad_f16).abs() + dk_diff = (k_fused_BLHD.grad - k_qat_grad_f16).abs() + dv_diff = (v_fused_BLHD.grad - v_qat_grad_f16).abs() + + dq_cos_sim = cosine_similarity(q_fused_BLHD.grad, q_qat_grad_f16) + dk_cos_sim = cosine_similarity(k_fused_BLHD.grad, k_qat_grad_f16) + dv_cos_sim = cosine_similarity(v_fused_BLHD.grad, v_qat_grad_f16) + + print(f" dQ - Max diff: {dq_diff.max().item():.6f}, Mean diff: {dq_diff.mean().item():.6f}, Cosine sim: {dq_cos_sim:.6f}") + print(f" dK - Max diff: {dk_diff.max().item():.6f}, Mean diff: {dk_diff.mean().item():.6f}, Cosine sim: {dk_cos_sim:.6f}") + print(f" dV - Max diff: {dv_diff.max().item():.6f}, Mean diff: {dv_diff.mean().item():.6f}, Cosine sim: {dv_cos_sim:.6f}") + + # Gradients may differ due to quantization in QAT and different implementations + # assert dq_diff.max() < 2.0, f"dQ gradients should be reasonably close, got max_diff={dq_diff.max().item():.6f}" + # assert dk_diff.max() < 2.0, f"dK gradients should be reasonably close, got max_diff={dk_diff.max().item():.6f}" + # assert dv_diff.max() < 2.0, f"dV gradients should be reasonably close, got max_diff={dv_diff.max().item():.6f}" + + print("✓ Fused vs QAT backward pass test passed.") + + +def test_fused_vs_qat_attention_different_shapes(): + """Test fused vs QAT attention with different input shapes.""" + torch.manual_seed(42) + + test_configs = [ + (2, 4, 128, 64), # Medium + (1, 8, 256, 128), # Large head dim + ] + + causal = True + + for Z, H, N_CTX, HEAD_DIM in test_configs: + sm_scale = 1.0 / sqrt(HEAD_DIM) + + # Create input tensors in BLHD format (float16) + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + + # Test with fused attention + out_fused_BLHD = fused_attn_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Test with QAT attention (convert to float16) + q_qat_BLHD = q_BLHD.clone().to(torch.bfloat16) + k_qat_BLHD = k_BLHD.clone().to(torch.bfloat16) + v_qat_BLHD = v_BLHD.clone().to(torch.bfloat16) + out_qat_BLHD = attn_qat_train_wrapper(q_qat_BLHD, k_qat_BLHD, v_qat_BLHD, causal, sm_scale) + out_qat_BLHD = out_qat_BLHD.to(torch.bfloat16) + + assert out_fused_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_qat_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert torch.isfinite(out_fused_BLHD).all() + assert torch.isfinite(out_qat_BLHD).all() + + # Compare outputs + max_diff = (out_fused_BLHD - out_qat_BLHD).abs().max() + mean_diff = (out_fused_BLHD - out_qat_BLHD).abs().mean() + cos_sim = cosine_similarity(out_fused_BLHD, out_qat_BLHD) + print(f" (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM}) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Outputs may differ due to quantization in QAT + # assert max_diff < 1.0, f"Fused and QAT outputs should be reasonably close for shape (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM}), got max_diff={max_diff.item():.6f}" + + print(f"✓ Fused vs QAT shape test passed for (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM})") + + +def test_fused_vs_qat_attention_non_causal(): + """Test fused vs QAT attention with non-causal masking.""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = False + + # Create input tensors in BLHD format (float16) + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + + # Test with fused attention + out_fused_BLHD = fused_attn_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Test with QAT attention (convert to float16) + q_qat_BLHD = q_BLHD.clone().to(torch.bfloat16) + k_qat_BLHD = k_BLHD.clone().to(torch.bfloat16) + v_qat_BLHD = v_BLHD.clone().to(torch.bfloat16) + out_qat_BLHD = attn_qat_train_wrapper(q_qat_BLHD, k_qat_BLHD, v_qat_BLHD, causal, sm_scale) + out_qat_BLHD = out_qat_BLHD.to(torch.bfloat16) + + assert out_fused_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_qat_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert torch.isfinite(out_fused_BLHD).all() + assert torch.isfinite(out_qat_BLHD).all() + + # Compare outputs + max_diff = (out_fused_BLHD - out_qat_BLHD).abs().max() + mean_diff = (out_fused_BLHD - out_qat_BLHD).abs().mean() + cos_sim = cosine_similarity(out_fused_BLHD, out_qat_BLHD) + print(f" Fused vs QAT (non-causal) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Outputs may differ due to quantization in QAT + # assert max_diff < 1.0, f"Fused and QAT outputs should be reasonably close for non-causal, got max_diff={max_diff.item():.6f}" + + print("✓ Fused vs QAT non-causal test passed.") + + +def test_fused_vs_qat_attention_causal(): + """Test fused vs QAT attention with causal masking.""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = True + + # Create input tensors in BLHD format (float16) + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE) + + # Test with fused attention + out_fused_BLHD = fused_attn_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Test with QAT attention (convert to float16) + q_qat_BLHD = q_BLHD.clone().to(torch.bfloat16) + k_qat_BLHD = k_BLHD.clone().to(torch.bfloat16) + v_qat_BLHD = v_BLHD.clone().to(torch.bfloat16) + out_qat_BLHD = attn_qat_train_wrapper(q_qat_BLHD, k_qat_BLHD, v_qat_BLHD, causal, sm_scale) + out_qat_BLHD = out_qat_BLHD.to(torch.bfloat16) + + assert out_fused_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_qat_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert torch.isfinite(out_fused_BLHD).all() + assert torch.isfinite(out_qat_BLHD).all() + + # Compare outputs + max_diff = (out_fused_BLHD - out_qat_BLHD).abs().max() + mean_diff = (out_fused_BLHD - out_qat_BLHD).abs().mean() + cos_sim = cosine_similarity(out_fused_BLHD, out_qat_BLHD) + print(f" Fused vs QAT (causal) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Outputs may differ due to quantization in QAT + # assert max_diff < 1.0, f"Fused and QAT outputs should be reasonably close for causal, got max_diff={max_diff.item():.6f}" + + print("✓ Fused vs QAT causal test passed.") + + +def test_fused_vs_qat_attention_causal_backward(): + """Test backward pass of fused vs QAT attention with causal masking.""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 2, 4, 128, 64 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = True + + # Create input tensors in BLHD format for fused attention (float16) + q_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE, requires_grad=True) + k_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE, requires_grad=True) + v_fused_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=torch.bfloat16, device=DEVICE, requires_grad=True) + + # Create input tensors for QAT (float16, same values) + q_qat_BLHD = q_fused_BLHD.clone().to(torch.bfloat16).detach().requires_grad_(True) + k_qat_BLHD = k_fused_BLHD.clone().to(torch.bfloat16).detach().requires_grad_(True) + v_qat_BLHD = v_fused_BLHD.clone().to(torch.bfloat16).detach().requires_grad_(True) + + # Forward pass with fused attention + out_fused_BLHD = fused_attn_wrapper(q_fused_BLHD, k_fused_BLHD, v_fused_BLHD, causal, sm_scale) + + # Forward pass with QAT + out_qat_BLHD = attn_qat_train_wrapper(q_qat_BLHD, k_qat_BLHD, v_qat_BLHD, causal, sm_scale) + + # Create dummy gradient + dout = torch.randn_like(out_fused_BLHD).contiguous() + dout_qat = dout.to(torch.bfloat16) + + # Backward pass for fused attention + out_fused_BLHD.backward(dout) + + # Backward pass for QAT + out_qat_BLHD.backward(dout_qat) + + # Check that gradients are computed + assert q_fused_BLHD.grad is not None + assert k_fused_BLHD.grad is not None + assert v_fused_BLHD.grad is not None + assert q_qat_BLHD.grad is not None + assert k_qat_BLHD.grad is not None + assert v_qat_BLHD.grad is not None + + # Convert QAT gradients to float16 for comparison + q_qat_grad_f16 = q_qat_BLHD.grad.to(torch.bfloat16) + k_qat_grad_f16 = k_qat_BLHD.grad.to(torch.bfloat16) + v_qat_grad_f16 = v_qat_BLHD.grad.to(torch.bfloat16) + + # Compare gradients + dq_diff = (q_fused_BLHD.grad - q_qat_grad_f16).abs() + dk_diff = (k_fused_BLHD.grad - k_qat_grad_f16).abs() + dv_diff = (v_fused_BLHD.grad - v_qat_grad_f16).abs() + + dq_cos_sim = cosine_similarity(q_fused_BLHD.grad, q_qat_grad_f16) + dk_cos_sim = cosine_similarity(k_fused_BLHD.grad, k_qat_grad_f16) + dv_cos_sim = cosine_similarity(v_fused_BLHD.grad, v_qat_grad_f16) + + print(f" dQ - Max diff: {dq_diff.max().item():.6f}, Mean diff: {dq_diff.mean().item():.6f}, Cosine sim: {dq_cos_sim:.6f}") + print(f" dK - Max diff: {dk_diff.max().item():.6f}, Mean diff: {dk_diff.mean().item():.6f}, Cosine sim: {dk_cos_sim:.6f}") + print(f" dV - Max diff: {dv_diff.max().item():.6f}, Mean diff: {dv_diff.mean().item():.6f}, Cosine sim: {dv_cos_sim:.6f}") + + print("✓ Fused vs QAT causal backward test passed.") + + +def test_qat_attention_wan_shape_forward(): + """Test QAT attention forward pass with WAN shape [1, 40, 9360, 128].""" + torch.manual_seed(42) + + # WAN shape: [1, 40, 9360, 128] = (Z, H, N_CTX, HEAD_DIM) + Z, H, N_CTX, HEAD_DIM = 1, 40, 9360, 128 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = False # WAN uses non-causal attention + + print(f" Testing WAN shape: Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM}") + + # Create input tensors in BLHD format (B, L, H, D) = (Z, N_CTX, H, HEAD_DIM) + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + + # Test with QAT (using wrapper that does permute/contiguous) + out_qat_BLHD = attn_qat_train_wrapper(q_BLHD.clone(), k_BLHD.clone(), v_BLHD.clone(), causal, sm_scale) + + # Convert to BHLD for naive attention comparison + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + out_naive_BHLD = naive_attention(q_BHLD.clone(), k_BHLD.clone(), v_BHLD.clone(), causal, sm_scale) + print(f" out_naive_BHLD shape (before permute): {out_naive_BHLD.shape}") + out_naive_BLHD = out_naive_BHLD.permute(0, 2, 1, 3).contiguous() + print(f" out_naive_BLHD shape (after permute): {out_naive_BLHD.shape}") + print(f" Expected shape: {(Z, N_CTX, H, HEAD_DIM)}") + print(f" out_qat_BLHD shape: {out_qat_BLHD.shape}") + + # Check that outputs have correct shape + assert out_qat_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + assert out_naive_BLHD.shape == (Z, N_CTX, H, HEAD_DIM) + + # Check that outputs are finite + assert torch.isfinite(out_qat_BLHD).all() + assert torch.isfinite(out_naive_BLHD).all() + + # Compare QAT and naive outputs + max_diff = (out_qat_BLHD - out_naive_BLHD).abs().max() + mean_diff = (out_qat_BLHD - out_naive_BLHD).abs().mean() + cos_sim = cosine_similarity(out_qat_BLHD, out_naive_BLHD) + print(f" QAT vs Naive (WAN shape) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # QAT output should be reasonably close to naive (within tolerance due to quantization) + print("✓ WAN shape forward pass test passed.") + + +def test_qat_attention_wan_shape_backward(): + """Test QAT attention backward pass with WAN shape [1, 40, 9360, 128].""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 1, 40, 9360, 128 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = False # WAN is non-causal + + print(f" Testing WAN shape: Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM}") + + # ----------------------------- + # CREATE BLHD INPUTS FOR QAT + # ----------------------------- + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + + # ----------------------------- + # CREATE BHLD INPUTS FOR NAIVE + # ----------------------------- + q_BHLD = q_BLHD.detach().permute(0,2,1,3).contiguous().requires_grad_(True) + k_BHLD = k_BLHD.detach().permute(0,2,1,3).contiguous().requires_grad_(True) + v_BHLD = v_BLHD.detach().permute(0,2,1,3).contiguous().requires_grad_(True) + + # ----------------------------- + # FORWARD — QAT + # ----------------------------- + out_qat_BLHD = attn_qat_train_wrapper(q_BLHD, k_BLHD, v_BLHD, causal, sm_scale) + + # ----------------------------- + # FORWARD — NAIVE + # ----------------------------- + out_naive_BHLD = naive_attention(q_BHLD, k_BHLD, v_BHLD, causal, sm_scale) + + # Convert naive output back to BLHD for comparing outputs + out_naive_BLHD = out_naive_BHLD.permute(0,2,1,3).contiguous() + + # ----------------------------- + # BACKWARD + # ----------------------------- + dout_BLHD = torch.randn_like(out_qat_BLHD).contiguous() + + # QAT backward (BLHD dout) + out_qat_BLHD.backward(dout_BLHD) + + # ❗ FIXED: convert BLHD dout → BHLD for naive backward + dout_BHLD = dout_BLHD.permute(0,2,1,3).contiguous() + + out_naive_BHLD.backward(dout_BHLD) + + # ----------------------------- + # GRADIENT FORMAT FIX + # Convert naive BHLD grads → BLHD + # ----------------------------- + q_naive_grad_BLHD = q_BHLD.grad.permute(0,2,1,3).contiguous() + k_naive_grad_BLHD = k_BHLD.grad.permute(0,2,1,3).contiguous() + v_naive_grad_BLHD = v_BHLD.grad.permute(0,2,1,3).contiguous() + + # ----------------------------- + # PRINT GRADIENTS + # ----------------------------- + print("\n === QAT Gradients ===") + print(f" dQ_QAT - Shape: {q_BLHD.grad.shape}, Min: {q_BLHD.grad.min().item():.6f}, Max: {q_BLHD.grad.max().item():.6f}, Mean: {q_BLHD.grad.mean().item():.6f}, Std: {q_BLHD.grad.std().item():.6f}") + print(f" dK_QAT - Shape: {k_BLHD.grad.shape}, Min: {k_BLHD.grad.min().item():.6f}, Max: {k_BLHD.grad.max().item():.6f}, Mean: {k_BLHD.grad.mean().item():.6f}, Std: {k_BLHD.grad.std().item():.6f}") + print(f" dV_QAT - Shape: {v_BLHD.grad.shape}, Min: {v_BLHD.grad.min().item():.6f}, Max: {v_BLHD.grad.max().item():.6f}, Mean: {v_BLHD.grad.mean().item():.6f}, Std: {v_BLHD.grad.std().item():.6f}") + + print("\n === Naive Gradients ===") + print(f" dQ_Naive - Shape: {q_naive_grad_BLHD.shape}, Min: {q_naive_grad_BLHD.min().item():.6f}, Max: {q_naive_grad_BLHD.max().item():.6f}, Mean: {q_naive_grad_BLHD.mean().item():.6f}, Std: {q_naive_grad_BLHD.std().item():.6f}") + print(f" dK_Naive - Shape: {k_naive_grad_BLHD.shape}, Min: {k_naive_grad_BLHD.min().item():.6f}, Max: {k_naive_grad_BLHD.max().item():.6f}, Mean: {k_naive_grad_BLHD.mean().item():.6f}, Std: {k_naive_grad_BLHD.std().item():.6f}") + print(f" dV_Naive - Shape: {v_naive_grad_BLHD.shape}, Min: {v_naive_grad_BLHD.min().item():.6f}, Max: {v_naive_grad_BLHD.max().item():.6f}, Mean: {v_naive_grad_BLHD.mean().item():.6f}, Std: {v_naive_grad_BLHD.std().item():.6f}") + + # Print sample values (first few elements) + # print("\n === Sample Gradient Values (first 10 elements) ===") + # print(f" dQ_QAT[0,0,0,:10]: {q_BLHD.grad[0,0,0,:10].cpu().tolist()}") + # print(f" dQ_Naive[0,0,0,:10]: {q_naive_grad_BLHD[0,0,0,:10].cpu().tolist()}") + # print(f" dK_QAT[0,0,0,:10]: {k_BLHD.grad[0,0,0,:10].cpu().tolist()}") + # print(f" dK_Naive[0,0,0,:10]: {k_naive_grad_BLHD[0,0,0,:10].cpu().tolist()}") + # print(f" dV_QAT[0,0,0,:10]: {v_BLHD.grad[0,0,0,:10].cpu().tolist()}") + # print(f" dV_Naive[0,0,0,:10]: {v_naive_grad_BLHD[0,0,0,:10].cpu().tolist()}") + + # ----------------------------- + # COMPARE + # ----------------------------- + dq_diff = (q_BLHD.grad - q_naive_grad_BLHD).abs() + dk_diff = (k_BLHD.grad - k_naive_grad_BLHD).abs() + dv_diff = (v_BLHD.grad - v_naive_grad_BLHD).abs() + + dq_cos = cosine_similarity(q_BLHD.grad, q_naive_grad_BLHD) + dk_cos = cosine_similarity(k_BLHD.grad, k_naive_grad_BLHD) + dv_cos = cosine_similarity(v_BLHD.grad, v_naive_grad_BLHD) + + print("\n === Gradient Comparison ===") + print(f" dQ - Max diff: {dq_diff.max().item():.6f}, Mean diff: {dq_diff.mean().item():.6f}, Cos: {dq_cos:.6f}") + print(f" dK - Max diff: {dk_diff.max().item():.6f}, Mean diff: {dk_diff.mean().item():.6f}, Cos: {dk_cos:.6f}") + print(f" dV - Max diff: {dv_diff.max().item():.6f}, Mean diff: {dv_diff.mean().item():.6f}, Cos: {dv_cos:.6f}") + + print("\n✓ WAN shape backward pass test passed.") + + + +def test_qat_attention_wan_shape_non_divisible_64(): + """Test QAT attention where N_CTX is not divisible by 64.""" + torch.manual_seed(42) + + Z, H, N_CTX, HEAD_DIM = 1, 40, 9367, 128 + dtype = torch.bfloat16 + sm_scale = 1.0 / sqrt(HEAD_DIM) + causal = False + + print(f" Testing WAN non-divisible: Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM}") + print(f" N_CTX % 64 = {N_CTX % 64}") + + # ----------------------------- + # CREATE BLHD INPUTS + # ----------------------------- + q_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + k_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + v_BLHD = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True) + + # ----------------------------- + # BHLD for naive + # ----------------------------- + q_BHLD = q_BLHD.detach().permute(0,2,1,3).contiguous().requires_grad_(True) + k_BHLD = k_BLHD.detach().permute(0,2,1,3).contiguous().requires_grad_(True) + v_BHLD = v_BLHD.detach().permute(0,2,1,3).contiguous().requires_grad_(True) + + # ----------------------------- + # FORWARD + # ----------------------------- + out_qat_BLHD = attn_qat_train_wrapper(q_BLHD, k_BLHD, v_BLHD, causal, sm_scale) + out_naive_BHLD = naive_attention(q_BHLD, k_BHLD, v_BHLD, causal, sm_scale) + out_naive_BLHD = out_naive_BHLD.permute(0,2,1,3).contiguous() + + # ----------------------------- + # BACKWARD + # ----------------------------- + dout_BLHD = torch.randn_like(out_qat_BLHD).contiguous() + + # QAT backward + out_qat_BLHD.backward(dout_BLHD) + + # FIXED: Convert dout to BHLD for naive + dout_BHLD = dout_BLHD.permute(0,2,1,3).contiguous() + + out_naive_BHLD.backward(dout_BHLD) + + # ----------------------------- + # GRADIENT FORMAT FIX + # ----------------------------- + q_naive_grad_BLHD = q_BHLD.grad.permute(0,2,1,3).contiguous() + k_naive_grad_BLHD = k_BHLD.grad.permute(0,2,1,3).contiguous() + v_naive_grad_BLHD = v_BHLD.grad.permute(0,2,1,3).contiguous() + + # ----------------------------- + # PRINT GRADIENTS + # ----------------------------- + print("\n === QAT Gradients ===") + print(f" dQ_QAT - Shape: {q_BLHD.grad.shape}, Min: {q_BLHD.grad.min().item():.6f}, Max: {q_BLHD.grad.max().item():.6f}, Mean: {q_BLHD.grad.mean().item():.6f}, Std: {q_BLHD.grad.std().item():.6f}") + print(f" dK_QAT - Shape: {k_BLHD.grad.shape}, Min: {k_BLHD.grad.min().item():.6f}, Max: {k_BLHD.grad.max().item():.6f}, Mean: {k_BLHD.grad.mean().item():.6f}, Std: {k_BLHD.grad.std().item():.6f}") + print(f" dV_QAT - Shape: {v_BLHD.grad.shape}, Min: {v_BLHD.grad.min().item():.6f}, Max: {v_BLHD.grad.max().item():.6f}, Mean: {v_BLHD.grad.mean().item():.6f}, Std: {v_BLHD.grad.std().item():.6f}") + + print("\n === Naive Gradients ===") + print(f" dQ_Naive - Shape: {q_naive_grad_BLHD.shape}, Min: {q_naive_grad_BLHD.min().item():.6f}, Max: {q_naive_grad_BLHD.max().item():.6f}, Mean: {q_naive_grad_BLHD.mean().item():.6f}, Std: {q_naive_grad_BLHD.std().item():.6f}") + print(f" dK_Naive - Shape: {k_naive_grad_BLHD.shape}, Min: {k_naive_grad_BLHD.min().item():.6f}, Max: {k_naive_grad_BLHD.max().item():.6f}, Mean: {k_naive_grad_BLHD.mean().item():.6f}, Std: {k_naive_grad_BLHD.std().item():.6f}") + print(f" dV_Naive - Shape: {v_naive_grad_BLHD.shape}, Min: {v_naive_grad_BLHD.min().item():.6f}, Max: {v_naive_grad_BLHD.max().item():.6f}, Mean: {v_naive_grad_BLHD.mean().item():.6f}, Std: {v_naive_grad_BLHD.std().item():.6f}") + + # Print sample values (first few elements) + # print("\n === Sample Gradient Values (first 10 elements) ===") + # print(f" dQ_QAT[0,0,0,:10]: {q_BLHD.grad[0,0,0,:10].cpu().tolist()}") + # print(f" dQ_Naive[0,0,0,:10]: {q_naive_grad_BLHD[0,0,0,:10].cpu().tolist()}") + # print(f" dK_QAT[0,0,0,:10]: {k_BLHD.grad[0,0,0,:10].cpu().tolist()}") + # print(f" dK_Naive[0,0,0,:10]: {k_naive_grad_BLHD[0,0,0,:10].cpu().tolist()}") + # print(f" dV_QAT[0,0,0,:10]: {v_BLHD.grad[0,0,0,:10].cpu().tolist()}") + # print(f" dV_Naive[0,0,0,:10]: {v_naive_grad_BLHD[0,0,0,:10].cpu().tolist()}") + + # ----------------------------- + # COMPARE + # ----------------------------- + dq_diff = (q_BLHD.grad - q_naive_grad_BLHD).abs() + dk_diff = (k_BLHD.grad - k_naive_grad_BLHD).abs() + dv_diff = (v_BLHD.grad - v_naive_grad_BLHD).abs() + + dq_cos = cosine_similarity(q_BLHD.grad, q_naive_grad_BLHD) + dk_cos = cosine_similarity(k_BLHD.grad, k_naive_grad_BLHD) + dv_cos = cosine_similarity(v_BLHD.grad, v_naive_grad_BLHD) + + print("\n === Gradient Comparison ===") + print(f" dQ - Max diff: {dq_diff.max().item():.6f}, Mean: {dq_diff.mean().item():.6f}, Cos: {dq_cos:.6f}") + print(f" dK - Max diff: {dk_diff.max().item():.6f}, Mean: {dk_diff.mean().item():.6f}, Cos: {dk_cos:.6f}") + print(f" dV - Max diff: {dv_diff.max().item():.6f}, Mean: {dv_diff.mean().item():.6f}, Cos: {dv_cos:.6f}") + + print("\n✓ WAN shape non-divisible-by-64 forward/backward test passed.") + + +if __name__ == "__main__": + print("Running QAT attention tests...") + print(f"Device: {DEVICE}") + print() + + test_qat_attention_forward() + test_qat_attention_backward() + test_qat_attention_different_shapes() + test_qat_attention_non_causal() + test_qat_attention_causal() + test_qat_attention_causal_backward() + test_qat_attention_different_seq_lengths() + test_qat_attention_different_seq_lengths_backward() + test_qat_attention_wan_shape_forward() + test_qat_attention_wan_shape_backward() + test_qat_attention_wan_shape_non_divisible_64() + + print() + print("Running Fused attention tests...") + print() + + # test_fused_attention_forward() + # test_fused_attention_backward() + # test_fused_attention_different_shapes() + # test_fused_attention_non_causal() + # test_fused_attention_causal() + # test_fused_attention_causal_backward() + + # print() + # print("Running Fused vs QAT attention comparison tests...") + # print() + + # test_fused_vs_qat_attention_forward() + # test_fused_vs_qat_attention_backward() + # test_fused_vs_qat_attention_different_shapes() + # test_fused_vs_qat_attention_non_causal() + # test_fused_vs_qat_attention_causal() + # test_fused_vs_qat_attention_causal_backward() + + print() + print("All tests passed! ✓") diff --git a/fastvideo-kernel/tests/test_bootstrap.py b/fastvideo-kernel/tests/test_bootstrap.py new file mode 100644 index 0000000000..a1f6c09427 --- /dev/null +++ b/fastvideo-kernel/tests/test_bootstrap.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import sys +import types +from pathlib import Path + +from tests._bootstrap import ensure_local_kernel_sources_first + + +def test_bootstrap_prefers_local_kernel_checkout(monkeypatch) -> None: + tests_root = Path(__file__).resolve().parent + kernel_root = tests_root.parent + repo_root = kernel_root.parent + kernel_python_root = kernel_root / "python" + + stale_module = types.ModuleType("fastvideo_kernel") + stale_module.__file__ = ( + "/tmp/site-packages/fastvideo_kernel/__init__.py" + ) + monkeypatch.setitem(sys.modules, "fastvideo_kernel", stale_module) + monkeypatch.setattr(sys, "path", ["/tmp/site-packages"]) + + ensure_local_kernel_sources_first() + + assert "fastvideo_kernel" not in sys.modules + assert sys.path[:3] == [ + str(kernel_python_root), + str(kernel_root), + str(repo_root), + ] diff --git a/fastvideo-kernel/tests/test_fake_quant.py b/fastvideo-kernel/tests/test_fake_quant.py new file mode 100644 index 0000000000..02e2870ad0 --- /dev/null +++ b/fastvideo-kernel/tests/test_fake_quant.py @@ -0,0 +1,503 @@ +#!/usr/bin/env python3 +""" +Precision test to compare numerical differences for fake_quantize triton function: +1. fake_quantize (triton implementation) vs reference implementations +2. Tests various shapes, dtypes, and value ranges +3. Evaluates cosine similarity, max diff, and mean diff between input and output +""" + +import torch +import triton +import triton.language as tl +from flashinfer import SfLayout, nvfp4_quantize, e2m1_and_ufp8sf_scale_to_float +from fastvideo_kernel.triton_kernels.nvfp4_utils import ( + _compute_quant_and_scale, + _compute_dequant, +) +from typing import Optional + +# MXFP_BLOCK_SIZE is 16 - use Python int for runtime checks +MXFP_BLOCK_SIZE = 16 + +DEVICE = torch.device("cuda") + + +def cosine_similarity(tensor1, tensor2): + """ + Compute cosine similarity between two tensors. + + Args: + tensor1: First tensor + tensor2: Second tensor (same shape as tensor1) + + Returns: + Cosine similarity value (scalar) + - Returns 1.0 if both tensors are zero (identical zero vectors) + - Returns 0.0 if only one tensor is zero (orthogonal to non-zero vector) + """ + # Flatten tensors for computation + t1_flat = tensor1.flatten().float() + t2_flat = tensor2.flatten().float() + + # Compute cosine similarity: (A · B) / (||A|| * ||B||) + dot_product = torch.dot(t1_flat, t2_flat) + norm1 = torch.norm(t1_flat) + norm2 = torch.norm(t2_flat) + + # Handle zero vectors + if norm1 == 0 and norm2 == 0: + # Both are zero vectors - they are identical, so similarity is 1.0 + return 1.0 + elif norm1 == 0 or norm2 == 0: + # One is zero, one is not - they are orthogonal, so similarity is 0.0 + return 0.0 + + cos_sim = dot_product / (norm1 * norm2) + return cos_sim.item() + + +@triton.jit +def fake_quantize(src_tensor, valid_src_mask, BLOCK_SIZE_OUT_DIM: tl.constexpr, + BLOCK_SIZE_QUANT_DIM: tl.constexpr, + dst_dtype: tl.constexpr, + mx_tensor_dtype: tl.constexpr = tl.uint8): + """ + Fake quantize function - matches API from attn_qat_train.py. + """ + high_prec_src_tensor = src_tensor + src_tensor, src_scale, src_s_dec = _compute_quant_and_scale( + src_tensor=src_tensor, + valid_src_mask=valid_src_mask, + mx_tensor_dtype=mx_tensor_dtype + ) + src_tensor = _compute_dequant( + mx_tensor=src_tensor, + scale=src_scale, + s_dec=src_s_dec, + BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM, + BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM, + dst_dtype=dst_dtype + ) + return src_tensor, high_prec_src_tensor + + +def get_fake_quant_reference(x: torch.Tensor): + """ + Reference implementation using FlashInfer for comparison. + """ + orig_shape = x.shape + orig_dtype = x.dtype + device = x.device + x = x.view(-1, x.shape[-1]) + x_global_sf = (448 * 6) / x.float().abs().nan_to_num().max() + x_fp4, x_scale = nvfp4_quantize(x, x_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False) + x_dequant = e2m1_and_ufp8sf_scale_to_float(x_fp4, x_scale, 1 / x_global_sf) + return x_dequant.view(orig_shape).to(orig_dtype).to(device) + + +@triton.jit +def fake_quantize_kernel( + src_ptr, + dst_ptr, + BLOCK_SIZE_OUT_DIM: tl.constexpr, + BLOCK_SIZE_QUANT_DIM: tl.constexpr, + dst_dtype: tl.constexpr, + mx_tensor_dtype: tl.constexpr, + stride_src_outer, + stride_src_quant, + stride_dst_outer, + stride_dst_quant, + outer_dim, + quant_dim, +): + """ + Kernel wrapper to call fake_quantize on a block of data. + """ + outer_idx = tl.program_id(0) + quant_idx = tl.program_id(1) + + # Compute offsets + start_outer = outer_idx * BLOCK_SIZE_OUT_DIM + start_quant = quant_idx * BLOCK_SIZE_QUANT_DIM + + # Create offset arrays + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None] + offs_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :] + + # Create masks for valid elements + mask_outer = (start_outer + offs_outer) < outer_dim + mask_quant = (start_quant + offs_quant) < quant_dim + full_mask = mask_outer & mask_quant + + # Load source tensor + src_offsets = (start_outer + offs_outer) * stride_src_outer + (start_quant + offs_quant) * stride_src_quant + src_tensor = tl.load(src_ptr + src_offsets, mask=full_mask, other=0.0) + + # Call fake_quantize with valid_src_mask parameter + quantized_tensor, high_prec_tensor = fake_quantize( + src_tensor=src_tensor, + valid_src_mask=full_mask, + BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM, + BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM, + dst_dtype=dst_dtype, + mx_tensor_dtype=mx_tensor_dtype + ) + + # Store result + dst_offsets = (start_outer + offs_outer) * stride_dst_outer + (start_quant + offs_quant) * stride_dst_quant + tl.store(dst_ptr + dst_offsets, quantized_tensor, mask=full_mask) + + +def triton_fake_quantize( + x: torch.Tensor, + BLOCK_SIZE_OUT_DIM: int = 128, + BLOCK_SIZE_QUANT_DIM: int = 128, + use_fp4: bool = True, # True for fp4 (uint8), False for fp8 (float8e4nv) + dst_dtype: Optional[torch.dtype] = None +) -> torch.Tensor: + """ + Call fake_quantize triton function on a tensor. + + Args: + x: Input tensor (2D or can be reshaped to 2D) + BLOCK_SIZE_OUT_DIM: Block size for outer dimension + BLOCK_SIZE_QUANT_DIM: Block size for quantization dimension (must be multiple of 16) + use_fp4: If True, use fp4 (uint8), else use fp8 (float8e4nv) + dst_dtype: Output dtype (defaults to input dtype) + + Returns: + Fake quantized tensor + """ + assert x.is_cuda, "Input must be on CUDA" + assert BLOCK_SIZE_QUANT_DIM % 16 == 0, f"BLOCK_SIZE_QUANT_DIM must be multiple of 16" + + orig_shape = x.shape + orig_dtype = x.dtype + + # Reshape to 2D + x_2d = x.view(-1, x.shape[-1]) + outer_dim, quant_dim = x_2d.shape + + if dst_dtype is None: + dst_dtype = orig_dtype + + # Map torch dtype to triton dtype + dtype_map = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + } + triton_dst_dtype = dtype_map.get(dst_dtype, tl.float16) + + # Allocate output + output = torch.empty_like(x_2d, dtype=dst_dtype) + + # Launch kernel with appropriate quantization dtype + grid = ( + triton.cdiv(outer_dim, BLOCK_SIZE_OUT_DIM), + triton.cdiv(quant_dim, BLOCK_SIZE_QUANT_DIM), + ) + + if use_fp4: + # Use fp4 (uint8) + fake_quantize_kernel[grid]( + src_ptr=x_2d, + dst_ptr=output, + BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM, + BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM, + dst_dtype=triton_dst_dtype, + mx_tensor_dtype=tl.uint8, # fp4 uses uint8 + stride_src_outer=x_2d.stride(0), + stride_src_quant=x_2d.stride(1), + stride_dst_outer=output.stride(0), + stride_dst_quant=output.stride(1), + outer_dim=outer_dim, + quant_dim=quant_dim, + ) + else: + # Use fp8 (float8e4nv) + fake_quantize_kernel[grid]( + src_ptr=x_2d, + dst_ptr=output, + BLOCK_SIZE_OUT_DIM=BLOCK_SIZE_OUT_DIM, + BLOCK_SIZE_QUANT_DIM=BLOCK_SIZE_QUANT_DIM, + dst_dtype=triton_dst_dtype, + mx_tensor_dtype=tl.float8e4nv, # fp8 uses float8e4nv + stride_src_outer=x_2d.stride(0), + stride_src_quant=x_2d.stride(1), + stride_dst_outer=output.stride(0), + stride_dst_quant=output.stride(1), + outer_dim=outer_dim, + quant_dim=quant_dim, + ) + + return output.view(orig_shape).to(dst_dtype) + + +def test_fake_quantize_basic(): + """Test basic functionality of fake_quantize.""" + torch.manual_seed(42) + + # Test parameters + shape = (128, 128) + dtype = torch.bfloat16 + + # Create input tensor + x = torch.randn(shape, dtype=dtype, device=DEVICE) + + # Test triton fake_quantize + x_fq = triton_fake_quantize(x, BLOCK_SIZE_OUT_DIM=128, BLOCK_SIZE_QUANT_DIM=128, use_fp4=True, dst_dtype=dtype) + + # Check that outputs have correct shape + assert x_fq.shape == x.shape + assert x_fq.dtype == x.dtype + + # Check that outputs are finite + assert torch.isfinite(x_fq).all() + + # Compare input and output + max_diff = (x_fq - x).abs().max() + mean_diff = (x_fq - x).abs().mean() + cos_sim = cosine_similarity(x_fq, x) + print(f" Input vs Output - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + print("✓ Basic test passed.") + + +def test_fake_quantize_different_shapes(): + """Test fake_quantize with different input shapes.""" + torch.manual_seed(42) + + test_configs = [ + (128, 64), # Small + (256, 128), # Medium + (512, 256), # Large + (128, 256), # Rectangular + (256, 128), # Rectangular (reversed) + ] + + dtype = torch.bfloat16 + + for outer_dim, quant_dim in test_configs: + # Create input tensor + x = torch.randn((outer_dim, quant_dim), dtype=dtype, device=DEVICE) + + # Test triton fake_quantize + x_fq = triton_fake_quantize(x, BLOCK_SIZE_OUT_DIM=128, BLOCK_SIZE_QUANT_DIM=128, use_fp4=True, dst_dtype=dtype) + + assert x_fq.shape == x.shape + assert torch.isfinite(x_fq).all() + + # Compare input and output + max_diff = (x_fq - x).abs().max() + mean_diff = (x_fq - x).abs().mean() + cos_sim = cosine_similarity(x_fq, x) + print(f" Shape ({outer_dim}, {quant_dim}) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + print(f"✓ Shape test passed for ({outer_dim}, {quant_dim})") + + +def test_fake_quantize_different_dtypes(): + """Test fake_quantize with different input dtypes.""" + torch.manual_seed(42) + + shape = (128, 128) + dtypes = [torch.float32, torch.float16, torch.bfloat16] + + for dtype in dtypes: + # Create input tensor + x = torch.randn(shape, dtype=dtype, device=DEVICE) + + # Test triton fake_quantize + x_fq = triton_fake_quantize(x, BLOCK_SIZE_OUT_DIM=128, BLOCK_SIZE_QUANT_DIM=128, use_fp4=True, dst_dtype=dtype) + + assert x_fq.shape == x.shape + assert x_fq.dtype == x.dtype + assert torch.isfinite(x_fq).all() + + # Compare input and output + max_diff = (x_fq - x).abs().max() + mean_diff = (x_fq - x).abs().mean() + cos_sim = cosine_similarity(x_fq, x) + print(f" Dtype {dtype} - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + print(f"✓ Dtype test passed for {dtype}") + + +def test_fake_quantize_3d_4d_tensors(): + """Test fake_quantize with 3D and 4D tensors (reshaped to 2D internally).""" + torch.manual_seed(42) + + dtype = torch.bfloat16 + + # Test 3D tensor (B, H, D) + print("\nTesting 3D tensor (B, H, D)") + x_3d = torch.randn((2, 8, 128), dtype=dtype, device=DEVICE) + x_fq_3d = triton_fake_quantize(x_3d, BLOCK_SIZE_OUT_DIM=128, BLOCK_SIZE_QUANT_DIM=128, use_fp4=True, dst_dtype=dtype) + + assert x_fq_3d.shape == x_3d.shape + assert torch.isfinite(x_fq_3d).all() + + max_diff = (x_fq_3d - x_3d).abs().max() + mean_diff = (x_fq_3d - x_3d).abs().mean() + cos_sim = cosine_similarity(x_fq_3d, x_3d) + print(f" 3D (2, 8, 128) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Test 4D tensor (B, H, L, D) + print("\nTesting 4D tensor (B, H, L, D)") + x_4d = torch.randn((1, 8, 256, 128), dtype=dtype, device=DEVICE) + x_fq_4d = triton_fake_quantize(x_4d, BLOCK_SIZE_OUT_DIM=128, BLOCK_SIZE_QUANT_DIM=128, use_fp4=True, dst_dtype=dtype) + + assert x_fq_4d.shape == x_4d.shape + assert torch.isfinite(x_fq_4d).all() + + max_diff = (x_fq_4d - x_4d).abs().max() + mean_diff = (x_fq_4d - x_4d).abs().mean() + cos_sim = cosine_similarity(x_fq_4d, x_4d) + print(f" 4D (1, 8, 256, 128) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + print("✓ 3D/4D tensor test passed.") + + +def test_fake_quantize_edge_cases(): + """Test edge cases like zeros, ones, extreme values.""" + torch.manual_seed(42) + + dtype = torch.bfloat16 + shape = (128, 128) + + test_cases = [ + ("zeros", torch.zeros(shape, dtype=dtype, device=DEVICE)), + ("ones", torch.ones(shape, dtype=dtype, device=DEVICE)), + ("negative_ones", -torch.ones(shape, dtype=dtype, device=DEVICE)), + ("very_small", torch.randn(shape, dtype=dtype, device=DEVICE) * 1e-6), + ("very_large", torch.randn(shape, dtype=dtype, device=DEVICE) * 1e6), + ] + + for name, x in test_cases: + x_fq = triton_fake_quantize(x, BLOCK_SIZE_OUT_DIM=128, BLOCK_SIZE_QUANT_DIM=128, use_fp4=True, dst_dtype=dtype) + + assert x_fq.shape == x.shape + assert torch.isfinite(x_fq).all() + + max_diff = (x_fq - x).abs().max() + mean_diff = (x_fq - x).abs().mean() + cos_sim = cosine_similarity(x_fq, x) + print(f" {name} - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + print("✓ Edge cases test passed.") + + +def test_fake_quantize_vs_reference(): + """Test triton fake_quantize vs reference implementation.""" + torch.manual_seed(42) + + dtype = torch.bfloat16 + shape = (128, 128) + + # Create input tensor + x = torch.randn(shape, dtype=dtype, device=DEVICE) + + # Test triton fake_quantize + x_fq_triton = triton_fake_quantize(x, BLOCK_SIZE_OUT_DIM=128, BLOCK_SIZE_QUANT_DIM=128, use_fp4=True, dst_dtype=dtype) + + # Test reference implementation + x_fq_ref = get_fake_quant_reference(x) + + # Compare triton vs reference + max_diff = (x_fq_triton - x_fq_ref).abs().max() + mean_diff = (x_fq_triton - x_fq_ref).abs().mean() + cos_sim = cosine_similarity(x_fq_triton, x_fq_ref) + print(f" Triton vs Reference - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + # Also compare each vs input + max_diff_triton = (x_fq_triton - x).abs().max() + mean_diff_triton = (x_fq_triton - x).abs().mean() + cos_sim_triton = cosine_similarity(x_fq_triton, x) + print(f" Triton vs Input - Max diff: {max_diff_triton.item():.6f}, Mean diff: {mean_diff_triton.item():.6f}, Cosine sim: {cos_sim_triton:.6f}") + + max_diff_ref = (x_fq_ref - x).abs().max() + mean_diff_ref = (x_fq_ref - x).abs().mean() + cos_sim_ref = cosine_similarity(x_fq_ref, x) + print(f" Reference vs Input - Max diff: {max_diff_ref.item():.6f}, Mean diff: {mean_diff_ref.item():.6f}, Cosine sim: {cos_sim_ref:.6f}") + + print("✓ Reference comparison test passed.") + + +def test_fake_quantize_attention_shapes(): + """Test fake_quantize with attention-like shapes (similar to test_attn_qat_train.py).""" + torch.manual_seed(42) + + test_configs = [ + (2, 4, 128, 64), # Z, H, N_CTX, HEAD_DIM - Medium + (1, 8, 256, 128), # Large head dim + (1, 40, 9360, 128), # WAN shape + ] + + dtype = torch.bfloat16 + + for Z, H, N_CTX, HEAD_DIM in test_configs: + # Create input tensor in BLHD format (B, L, H, D) = (Z, N_CTX, H, HEAD_DIM) + x = torch.randn((Z, N_CTX, H, HEAD_DIM), dtype=dtype, device=DEVICE) + + # Test triton fake_quantize + x_fq = triton_fake_quantize(x, BLOCK_SIZE_OUT_DIM=128, BLOCK_SIZE_QUANT_DIM=128, use_fp4=True, dst_dtype=dtype) + + assert x_fq.shape == x.shape + assert torch.isfinite(x_fq).all() + + # Compare input and output + max_diff = (x_fq - x).abs().max() + mean_diff = (x_fq - x).abs().mean() + cos_sim = cosine_similarity(x_fq, x) + print(f" (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM}) - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + print(f"✓ Attention shape test passed for (Z={Z}, H={H}, N_CTX={N_CTX}, HEAD_DIM={HEAD_DIM})") + + +def test_fake_quantize_non_divisible_blocks(): + """Test fake_quantize with shapes that are not divisible by block sizes.""" + torch.manual_seed(42) + + dtype = torch.bfloat16 + + # Test with non-divisible dimensions + test_shapes = [ + (100, 100), # Not divisible by 128 + (150, 200), # Not divisible by 128 + (256, 100), # One dimension divisible, one not + ] + + for shape in test_shapes: + x = torch.randn(shape, dtype=dtype, device=DEVICE) + + # Test triton fake_quantize + x_fq = triton_fake_quantize(x, BLOCK_SIZE_OUT_DIM=128, BLOCK_SIZE_QUANT_DIM=128, use_fp4=True, dst_dtype=dtype) + + assert x_fq.shape == x.shape + assert torch.isfinite(x_fq).all() + + max_diff = (x_fq - x).abs().max() + mean_diff = (x_fq - x).abs().mean() + cos_sim = cosine_similarity(x_fq, x) + print(f" Shape {shape} - Max diff: {max_diff.item():.6f}, Mean diff: {mean_diff.item():.6f}, Cosine sim: {cos_sim:.6f}") + + print(f"✓ Non-divisible block test passed for {shape}") + + +if __name__ == "__main__": + print("Running fake_quantize tests...") + print(f"Device: {DEVICE}") + print() + + test_fake_quantize_basic() + test_fake_quantize_different_shapes() + test_fake_quantize_different_dtypes() + test_fake_quantize_3d_4d_tensors() + test_fake_quantize_edge_cases() + test_fake_quantize_vs_reference() + test_fake_quantize_attention_shapes() + test_fake_quantize_non_divisible_blocks() + + print() + print("All tests passed! ✓") diff --git a/fastvideo/api/compat.py b/fastvideo/api/compat.py index fefb749768..de35a53594 100644 --- a/fastvideo/api/compat.py +++ b/fastvideo/api/compat.py @@ -107,6 +107,8 @@ def legacy_from_pretrained_to_config( engine[key] = value elif key == "override_text_encoder_quant": quantization["text_encoder_quant"] = value + elif key == "transformer_quant": + quantization["transformer_quant"] = value elif key == "workload_type": pipeline["workload_type"] = value elif key == "lora_path": diff --git a/fastvideo/attention/backends/abstract.py b/fastvideo/attention/backends/abstract.py index a244f2a82d..ae8d2f96cf 100644 --- a/fastvideo/attention/backends/abstract.py +++ b/fastvideo/attention/backends/abstract.py @@ -2,7 +2,7 @@ # Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/backends/abstract.py from abc import ABC, abstractmethod -from dataclasses import dataclass, fields +from dataclasses import dataclass, field, fields from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar if TYPE_CHECKING: @@ -53,6 +53,10 @@ class AttentionMetadata: """Attention metadata for prefill and decode batched together.""" # Current step of diffusion process current_timestep: int + VSA_sparsity: float = field(default=0.0, kw_only=True) + + def __getattr__(self, name: str) -> Any: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") def asdict_zerocopy(self, skip_fields: set[str] | None = None) -> dict[str, Any]: """Similar to dataclasses.asdict, but avoids deepcopying.""" @@ -82,7 +86,7 @@ def prepare(self) -> None: @abstractmethod def build( self, - **kwargs: dict[str, Any], + **kwargs: Any, ) -> AttentionMetadata: """Build attention metadata with on-device tensors.""" raise NotImplementedError diff --git a/fastvideo/attention/backends/attn_qat_infer.py b/fastvideo/attention/backends/attn_qat_infer.py new file mode 100644 index 0000000000..c2f8472b6c --- /dev/null +++ b/fastvideo/attention/backends/attn_qat_infer.py @@ -0,0 +1,121 @@ +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import sys +from collections.abc import Callable +from pathlib import Path + +import torch + +from fastvideo.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from fastvideo.logger import init_logger + +logger = init_logger(__name__) + +_project_root = Path(__file__).resolve().parent.parent.parent.parent +_kernel_root = _project_root / "fastvideo-kernel" +_kernel_python_root = _kernel_root / "python" +_attn_qat_infer: Callable[..., torch.Tensor] | None = None +_attn_qat_infer_import_attempted = False + + +def _ensure_kernel_paths() -> None: + for path in (_project_root, _kernel_root, _kernel_python_root): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + + +def _get_attn_qat_infer() -> Callable[..., torch.Tensor] | None: + global _attn_qat_infer + global _attn_qat_infer_import_attempted + + if _attn_qat_infer_import_attempted: + return _attn_qat_infer + + _attn_qat_infer_import_attempted = True + _ensure_kernel_paths() + + try: + # Prefer the in-repo kernel implementation during local development. + _attn_qat_infer = importlib.import_module("attn_qat_infer").sageattn_blackwell + except ImportError: + _attn_qat_infer = None + + return _attn_qat_infer + + +def is_attn_qat_infer_available() -> bool: + return _get_attn_qat_infer() is not None + + +class AttnQatInferBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 128] + + @staticmethod + def get_name() -> str: + return "ATTN_QAT_INFER" + + @staticmethod + def get_impl_cls() -> type["AttnQatInferImpl"]: + return AttnQatInferImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + raise NotImplementedError + + +class AttnQatInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = extra_impl_args.get("dropout_p", 0.0) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + attn_qat_infer = _get_attn_qat_infer() + if attn_qat_infer is None: + raise ImportError("attn_qat_infer is not available. Please ensure the " + "attn_qat_infer kernel package is installed.") + + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + output = attn_qat_infer( + query, + key, + value, + attn_mask=None, + is_causal=self.causal, + ) + return output.transpose(1, 2).contiguous() diff --git a/fastvideo/attention/backends/attn_qat_train.py b/fastvideo/attention/backends/attn_qat_train.py new file mode 100644 index 0000000000..271b78dbec --- /dev/null +++ b/fastvideo/attention/backends/attn_qat_train.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 + +import importlib +import sys +from collections.abc import Callable +from pathlib import Path + +import torch + +from fastvideo.attention.backends.abstract import ( + AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, +) +from fastvideo.logger import init_logger + +logger = init_logger(__name__) + +_project_root = Path(__file__).resolve().parent.parent.parent.parent +_kernel_root = _project_root / "fastvideo-kernel" +_kernel_python_root = _kernel_root / "python" +_attn_qat_train_attention: Callable[..., torch.Tensor] | None = None +_attn_qat_train_import_attempted = False + + +def _ensure_kernel_paths() -> None: + for path in (_project_root, _kernel_root, _kernel_python_root): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + + +def _get_attn_qat_train_attention() -> Callable[..., torch.Tensor] | None: + global _attn_qat_train_attention + global _attn_qat_train_import_attempted + + if _attn_qat_train_import_attempted: + return _attn_qat_train_attention + + _attn_qat_train_import_attempted = True + _ensure_kernel_paths() + + try: + _attn_qat_train_attention = importlib.import_module("fastvideo_kernel.triton_kernels.attn_qat_train").attention + except ImportError: + _attn_qat_train_attention = None + + return _attn_qat_train_attention + + +def attn_qat_train(q_BLHD: torch.Tensor, + k_BLHD: torch.Tensor, + v_BLHD: torch.Tensor, + is_causal: bool = False) -> torch.Tensor: + attention = _get_attn_qat_train_attention() + if attention is None: + raise ImportError("fastvideo_kernel.triton_kernels.attn_qat_train is not available. " + "Please ensure the FastVideo kernel package is installed.") + + q_BHLD = q_BLHD.permute(0, 2, 1, 3).contiguous() + k_BHLD = k_BLHD.permute(0, 2, 1, 3).contiguous() + v_BHLD = v_BLHD.permute(0, 2, 1, 3).contiguous() + + use_qat_qkv_backward = True + smooth_k = False + warp_specialize = True + is_qat = True + two_level_quant_p_sage3 = False + fake_quant_p_bwd = True + use_high_prec_o = True + smooth_q = False + sm_scale = 1.0 / (q_BHLD.shape[-1]**0.5) + use_global_sf_qkv = False + use_global_sf_p = False + + o_BHLD = attention( + q_BHLD, + k_BHLD, + v_BHLD, + is_causal, + sm_scale, + use_qat_qkv_backward, + smooth_k, + warp_specialize, + is_qat, + two_level_quant_p_sage3, + fake_quant_p_bwd, + use_high_prec_o, + smooth_q, + use_global_sf_p, + use_global_sf_qkv, + ) + return o_BHLD.permute(0, 2, 1, 3).contiguous() + + +class AttnQatTrainBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "ATTN_QAT_TRAIN" + + @staticmethod + def get_impl_cls() -> type["AttnQatTrainImpl"]: + return AttnQatTrainImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + def get_builder_cls() -> type["AttentionMetadataBuilder"]: + raise NotImplementedError + + +class AttnQatTrainImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + causal: bool, + softmax_scale: float, + num_kv_heads: int | None = None, + prefix: str = "", + **extra_impl_args, + ) -> None: + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = extra_impl_args.get("dropout_p", 0.0) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + return attn_qat_train(query, key, value, is_causal=self.causal) diff --git a/fastvideo/attention/backends/bsa_attn.py b/fastvideo/attention/backends/bsa_attn.py index be804f3e52..8ab656e2eb 100644 --- a/fastvideo/attention/backends/bsa_attn.py +++ b/fastvideo/attention/backends/bsa_attn.py @@ -33,13 +33,11 @@ try: from fastvideo.attention.utils.flash_attn_no_pad import ( flash_attn_varlen_func_impl, ) + FLASH_ATTN_AVAILABLE = True except ImportError: - try: - from flash_attn import flash_attn_varlen_func as flash_attn_varlen_func_impl - FLASH_ATTN_AVAILABLE = True - except ImportError: - FLASH_ATTN_AVAILABLE = False + flash_attn_varlen_func_impl = None + FLASH_ATTN_AVAILABLE = False logger = init_logger(__name__) diff --git a/fastvideo/attention/backends/sage_attn3.py b/fastvideo/attention/backends/sage_attn3.py index b2d3379bb7..7780845b7a 100644 --- a/fastvideo/attention/backends/sage_attn3.py +++ b/fastvideo/attention/backends/sage_attn3.py @@ -16,7 +16,7 @@ class SageAttention3Backend(AttentionBackend): @staticmethod def get_supported_head_sizes() -> list[int]: - return [64, 128, 256] + return [64, 128] @staticmethod def get_name() -> str: diff --git a/fastvideo/attention/backends/video_sparse_attn.py b/fastvideo/attention/backends/video_sparse_attn.py index 43960a9401..a9a6952022 100644 --- a/fastvideo/attention/backends/video_sparse_attn.py +++ b/fastvideo/attention/backends/video_sparse_attn.py @@ -133,7 +133,6 @@ def get_builder_cls() -> type["VideoSparseAttentionMetadataBuilder"]: class VideoSparseAttentionMetadata(AttentionMetadata): current_timestep: int dit_seq_shape: list[int] - VSA_sparsity: float num_tiles: list[int] total_seq_length: int tile_partition_indices: torch.LongTensor @@ -144,10 +143,10 @@ class VideoSparseAttentionMetadata(AttentionMetadata): class VideoSparseAttentionMetadataBuilder(AttentionMetadataBuilder): - def __init__(self): + def __init__(self) -> None: pass - def prepare(self): + def prepare(self) -> None: pass def build( # type: ignore diff --git a/fastvideo/attention/backends/vmoba.py b/fastvideo/attention/backends/vmoba.py index eaa618968e..2694bb31b8 100644 --- a/fastvideo/attention/backends/vmoba.py +++ b/fastvideo/attention/backends/vmoba.py @@ -61,10 +61,10 @@ class VideoMobaAttentionMetadata(AttentionMetadata): class VideoMobaAttentionMetadataBuilder(AttentionMetadataBuilder): - def __init__(self): + def __init__(self) -> None: pass - def prepare(self): + def prepare(self) -> None: pass def build( # type: ignore @@ -81,7 +81,7 @@ def build( # type: ignore moba_select_mode: str = 'threshold', moba_threshold: float = 0.25, moba_threshold_type: str = 'query_head', - device: torch.device = None, + device: torch.device | None = None, first_full_layer: int = 0, first_full_step: int = 12, temporal_layer: int = 1, @@ -142,7 +142,7 @@ def forward( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_metadata: AttentionMetadata, + attn_metadata: VideoMobaAttentionMetadata, ) -> torch.Tensor: """ query: [B, L, H, D] @@ -154,7 +154,9 @@ def forward( # select chunk type according to layer idx: loop_layer_num = attn_metadata.temporal_layer + attn_metadata.spatial_layer + attn_metadata.st_layer + assert self.layer_idx is not None, "VMoBA attention requires layer_idx to be set" moba_layer = self.layer_idx - attn_metadata.first_full_layer + moba_chunk_size: int | tuple[int, int] | tuple[int, int, int] if moba_layer % loop_layer_num < attn_metadata.temporal_layer: moba_chunk_size = attn_metadata.temporal_chunk_size moba_topk = attn_metadata.temporal_topk @@ -164,6 +166,8 @@ def forward( elif moba_layer % loop_layer_num < attn_metadata.temporal_layer + attn_metadata.spatial_layer + attn_metadata.st_layer: moba_chunk_size = attn_metadata.st_chunk_size moba_topk = attn_metadata.st_topk + else: + raise ValueError(f"Invalid MoBA layer selection for layer {moba_layer}") query, chunk_size = process_moba_input(query, attn_metadata.patch_resolution, moba_chunk_size) key, chunk_size = process_moba_input(key, attn_metadata.patch_resolution, moba_chunk_size) diff --git a/fastvideo/attention/layer.py b/fastvideo/attention/layer.py index 84fa3d3bb4..56a0911dba 100644 --- a/fastvideo/attention/layer.py +++ b/fastvideo/attention/layer.py @@ -158,6 +158,7 @@ def forward( replicated_v: torch.Tensor | None = None, gate_compress: torch.Tensor | None = None, freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None, + attention_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """Forward pass for distributed attention. @@ -170,6 +171,7 @@ def forward( replicated_q (Optional[torch.Tensor]): Replicated query tensor, typically for text tokens replicated_k (Optional[torch.Tensor]): Replicated key tensor replicated_v (Optional[torch.Tensor]): Replicated value tensor + attention_mask (Optional[torch.Tensor]): Attention mask [batch_size, seq_len] Returns: Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing: diff --git a/fastvideo/attention/selector.py b/fastvideo/attention/selector.py index 2ce231ee10..ad6e2ffc9f 100644 --- a/fastvideo/attention/selector.py +++ b/fastvideo/attention/selector.py @@ -85,7 +85,26 @@ def get_attn_backend( supported_attention_backends: tuple[AttentionBackendEnum, ...] | None = None, ) -> type[AttentionBackend]: - return _cached_get_attn_backend(head_size, dtype, supported_attention_backends) + selected_backend, is_forced = _resolve_backend_override() + return _cached_get_attn_backend( + head_size, + dtype, + supported_attention_backends, + selected_backend, + is_forced, + ) + + +def _resolve_backend_override() -> tuple[AttentionBackendEnum | None, bool]: + backend_by_global_setting = get_global_forced_attn_backend() + if backend_by_global_setting is not None: + return backend_by_global_setting, True + + backend_by_env_var: str | None = envs.FASTVIDEO_ATTENTION_BACKEND + if backend_by_env_var is not None: + return backend_name_to_enum(backend_by_env_var), False + + return None, False @cache @@ -94,28 +113,16 @@ def _cached_get_attn_backend( dtype: torch.dtype, supported_attention_backends: tuple[AttentionBackendEnum, ...] | None = None, + selected_backend: AttentionBackendEnum | None = None, + is_forced_backend: bool = False, ) -> type[AttentionBackend]: - # Check whether a particular choice of backend was - # previously forced. - # - # THIS SELECTION OVERRIDES THE FASTVIDEO_ATTENTION_BACKEND - # ENVIRONMENT VARIABLE. if not supported_attention_backends: raise ValueError("supported_attention_backends is empty") - selected_backend = None - backend_by_global_setting: AttentionBackendEnum | None = (get_global_forced_attn_backend()) - if backend_by_global_setting is not None: - selected_backend = backend_by_global_setting - else: - # Check the environment variable and override if specified - backend_by_env_var: str | None = envs.FASTVIDEO_ATTENTION_BACKEND - if backend_by_env_var is not None: - selected_backend = backend_name_to_enum(backend_by_env_var) # get device-specific attn_backend from fastvideo.platforms import current_platform - if selected_backend not in supported_attention_backends: + if not is_forced_backend and selected_backend not in supported_attention_backends: selected_backend = None attention_cls = current_platform.get_attn_backend_cls(selected_backend, head_size, dtype) if not attention_cls: diff --git a/fastvideo/attention/utils/flash_attn_no_pad.py b/fastvideo/attention/utils/flash_attn_no_pad.py index 4698b17619..89e71a1183 100644 --- a/fastvideo/attention/utils/flash_attn_no_pad.py +++ b/fastvideo/attention/utils/flash_attn_no_pad.py @@ -14,30 +14,44 @@ # of rights and permissions under this agreement. # See the License for the specific language governing permissions and limitations under the License. +from typing import Any + +import torch from einops import rearrange -from flash_attn.bert_padding import pad_input, unpad_input from flash_attn import flash_attn_varlen_qkvpacked_func +from flash_attn.bert_padding import pad_input, unpad_input -try: - from fastvideo.attention.utils.flash_attn_cute import ( - flash_attn_varlen_func as flash_attn_varlen_func_impl, ) -except ImportError: + +def _resolve_flash_attn_varlen_func() -> Any: try: - from flash_attn_interface import ( - flash_attn_varlen_func as flash_attn_varlen_func_impl, ) + from fastvideo.attention.utils.flash_attn_cute import ( + flash_attn_varlen_func as flash_attn_varlen_func_cute, ) + + return flash_attn_varlen_func_cute except ImportError: - from flash_attn import ( - flash_attn_varlen_func as flash_attn_varlen_func_impl, ) + try: + from flash_attn_interface import ( + flash_attn_varlen_func as flash_attn_varlen_func_interface, ) + + return flash_attn_varlen_func_interface + except ImportError: + from flash_attn import ( + flash_attn_varlen_func as flash_attn_varlen_func_flash, ) + + return flash_attn_varlen_func_flash + + +flash_attn_varlen_func_impl = _resolve_flash_attn_varlen_func() def flash_attn_no_pad( - qkv, - key_padding_mask, - causal=False, - dropout_p=0.0, - softmax_scale=None, - deterministic=False, -): + qkv: torch.Tensor, + key_padding_mask: torch.Tensor, + causal: bool = False, + dropout_p: float = 0.0, + softmax_scale: float | None = None, + deterministic: bool = False, +) -> torch.Tensor: batch_size = qkv.shape[0] seqlen = qkv.shape[1] nheads = qkv.shape[-2] @@ -68,13 +82,13 @@ def flash_attn_no_pad( def flash_attn_no_pad_v3( - qkv, - key_padding_mask, - causal=False, - dropout_p=0.0, - softmax_scale=None, - deterministic=False, -): + qkv: torch.Tensor, + key_padding_mask: torch.Tensor, + causal: bool = False, + dropout_p: float = 0.0, + softmax_scale: float | None = None, + deterministic: bool = False, +) -> torch.Tensor: from flash_attn_interface import ( flash_attn_varlen_func as flash_attn_varlen_func_v3, ) @@ -120,16 +134,16 @@ def flash_attn_no_pad_v3( def flash_attn_varlen_qk_no_pad( - query, - key, - value, - query_padding_mask, - key_padding_mask, - causal=False, - dropout_p=0.0, - softmax_scale=None, - deterministic=False, -): + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + query_padding_mask: torch.Tensor, + key_padding_mask: torch.Tensor, + causal: bool = False, + dropout_p: float = 0.0, + softmax_scale: float | None = None, + deterministic: bool = False, +) -> torch.Tensor: batch_size, q_seqlen, nheads, _ = query.shape query_unpad, q_indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), diff --git a/fastvideo/configs/models/base.py b/fastvideo/configs/models/base.py index 12510d86eb..6112e0bc7a 100644 --- a/fastvideo/configs/models/base.py +++ b/fastvideo/configs/models/base.py @@ -12,8 +12,15 @@ # 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users @dataclass class ArchConfig: - stacked_params_mapping: list[tuple[str, str, str]] = field( + stacked_params_mapping: list[tuple[str, str, str | int]] = field( default_factory=list) # mapping from huggingface weight names to custom names + output_hidden_states: bool = False + + def __post_init__(self) -> None: + pass + + def __getattr__(self, name: str) -> Any: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") @dataclass @@ -24,19 +31,22 @@ class ModelConfig: # FastVideo-specific parameters here - def __getattr__(self, name): + def __post_init__(self) -> None: + pass + + def __getattr__(self, name: str) -> Any: # Only called if 'name' is not found in ModelConfig directly if hasattr(self.arch_config, name): return getattr(self.arch_config, name) raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") - def __getstate__(self): + def __getstate__(self) -> dict[str, Any]: # Return a dictionary of attributes to pickle # Convert to dict and exclude any problematic attributes state = self.__dict__.copy() return state - def __setstate__(self, state): + def __setstate__(self, state: dict[str, Any]) -> None: # Restore instance attributes from the unpickled state self.__dict__.update(state) diff --git a/fastvideo/configs/models/dits/base.py b/fastvideo/configs/models/dits/base.py index 3b70c33aa2..7d8bbda2dd 100644 --- a/fastvideo/configs/models/dits/base.py +++ b/fastvideo/configs/models/dits/base.py @@ -19,13 +19,19 @@ class DiTArchConfig(ArchConfig): AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.VIDEO_SPARSE_ATTN, AttentionBackendEnum.VMOBA_ATTN, AttentionBackendEnum.SAGE_ATTN_THREE, - AttentionBackendEnum.SLA_ATTN, AttentionBackendEnum.SAGE_SLA_ATTN) + AttentionBackendEnum.ATTN_QAT_INFER, + AttentionBackendEnum.ATTN_QAT_TRAIN, AttentionBackendEnum.SLA_ATTN, + AttentionBackendEnum.SAGE_SLA_ATTN) hidden_size: int = 0 num_attention_heads: int = 0 num_channels_latents: int = 0 - in_channels: int = 0 - out_channels: int = 0 + in_channels: int | None = 0 + out_channels: int | None = 0 + patch_size: int | tuple[int, int, int] | None = None + expand_timesteps: bool = False + num_layers: int = 0 + ffn_dim: int = 0 exclude_lora_layers: list[str] = field(default_factory=list) boundary_ratio: float | None = None @@ -41,6 +47,13 @@ class DiTConfig(ModelConfig): # FastVideoDiT-specific parameters prefix: str = "" quant_config: QuantizationConfig | None = None + expand_timesteps: bool = False + boundary_ratio: float | None = None + + def __post_init__(self) -> None: + super().__post_init__() + self.arch_config.expand_timesteps = self.expand_timesteps + self.arch_config.boundary_ratio = self.boundary_ratio @staticmethod def add_cli_args(parser: Any, prefix: str = "dit-config") -> Any: diff --git a/fastvideo/configs/models/dits/hyworld.py b/fastvideo/configs/models/dits/hyworld.py index 66d33e055f..a1be3b9032 100644 --- a/fastvideo/configs/models/dits/hyworld.py +++ b/fastvideo/configs/models/dits/hyworld.py @@ -85,7 +85,7 @@ class HYWorldArchConfig(DiTArchConfig): reverse_param_names_mapping: dict = field(default_factory=lambda: {}) # Parameters from HY-WorldPlay config.json (loaded from checkpoint) - patch_size: list | tuple | int = field(default_factory=lambda: [1, 1, 1]) + patch_size: tuple[int, int, int] = (1, 1, 1) # Base latent channels - will be expanded in __post_init__ if concat_condition=True in_channels: int = 32 concat_condition: bool = True @@ -120,7 +120,7 @@ class HYWorldArchConfig(DiTArchConfig): task_type: str = "i2v" exclude_lora_layers: list[str] = field(default_factory=lambda: ["img_in", "txt_in", "time_in", "vector_in"]) - def __post_init__(self): + def __post_init__(self) -> None: super().__post_init__() # Convert HY-WorldPlay naming to FastVideo naming conventions self.num_attention_heads: int = self.heads_num diff --git a/fastvideo/configs/models/encoders/base.py b/fastvideo/configs/models/encoders/base.py index cd98267af9..76b8fde5aa 100644 --- a/fastvideo/configs/models/encoders/base.py +++ b/fastvideo/configs/models/encoders/base.py @@ -24,15 +24,16 @@ class TextEncoderArchConfig(EncoderArchConfig): hidden_size: int = 0 num_hidden_layers: int = 0 num_attention_heads: int = 0 - pad_token_id: int = 0 - eos_token_id: int = 0 + pad_token_id: int | None = 0 + eos_token_id: int | None = 0 text_len: int = 0 hidden_state_skip_layer: int = 0 decoder_start_token_id: int = 0 output_past: bool = True scalable_attention: bool = True tie_word_embeddings: bool = False - stacked_params_mapping: list[tuple[str, str, str]] = field( + padding_side: str = "right" + stacked_params_mapping: list[tuple[str, str, str | int]] = field( default_factory=list) # mapping from huggingface weight names to custom names tokenizer_kwargs: dict[str, Any] = field(default_factory=dict) _fsdp_shard_conditions: list = field(default_factory=lambda: []) @@ -70,10 +71,10 @@ class EncoderConfig(ModelConfig): @dataclass class TextEncoderConfig(EncoderConfig): - arch_config: ArchConfig = field(default_factory=TextEncoderArchConfig) + arch_config: TextEncoderArchConfig = field(default_factory=TextEncoderArchConfig) is_chat_model: bool = False @dataclass class ImageEncoderConfig(EncoderConfig): - arch_config: ArchConfig = field(default_factory=ImageEncoderArchConfig) + arch_config: ImageEncoderArchConfig = field(default_factory=ImageEncoderArchConfig) diff --git a/fastvideo/configs/models/encoders/clip.py b/fastvideo/configs/models/encoders/clip.py index 997bcbef97..71e771dffc 100644 --- a/fastvideo/configs/models/encoders/clip.py +++ b/fastvideo/configs/models/encoders/clip.py @@ -32,7 +32,7 @@ class CLIPTextArchConfig(TextEncoderArchConfig): bos_token_id: int = 49406 eos_token_id: int = 49407 text_len: int = 77 - stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=lambda: [ + stacked_params_mapping: list[tuple[str, str, str | int]] = field(default_factory=lambda: [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -57,7 +57,7 @@ class CLIPVisionArchConfig(ImageEncoderArchConfig): attention_dropout: float = 0.0 initializer_range: float = 0.02 initializer_factor: float = 1.0 - stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=lambda: [ + stacked_params_mapping: list[tuple[str, str, str | int]] = field(default_factory=lambda: [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), diff --git a/fastvideo/configs/models/encoders/t5.py b/fastvideo/configs/models/encoders/t5.py index 555009fa9d..1ca44bcb2c 100644 --- a/fastvideo/configs/models/encoders/t5.py +++ b/fastvideo/configs/models/encoders/t5.py @@ -41,7 +41,7 @@ class T5ArchConfig(TextEncoderArchConfig): text_len: int = 512 dtype: str | None = None gradient_checkpointing: bool = False - stacked_params_mapping: list[tuple[str, str, str]] = field(default_factory=lambda: [ + stacked_params_mapping: list[tuple[str, str, str | int]] = field(default_factory=lambda: [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q", "q"), (".qkv_proj", ".k", "k"), @@ -51,7 +51,7 @@ class T5ArchConfig(TextEncoderArchConfig): default_factory=lambda: [_is_transformer_layer, _is_embeddings, _is_final_layernorm]) # Referenced from https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py - def __post_init__(self): + def __post_init__(self) -> None: super().__post_init__() act_info = self.feed_forward_proj.split("-") self.dense_act_fn: str = act_info[-1] diff --git a/fastvideo/configs/models/vaes/base.py b/fastvideo/configs/models/vaes/base.py index 23192b569c..094e7c8838 100644 --- a/fastvideo/configs/models/vaes/base.py +++ b/fastvideo/configs/models/vaes/base.py @@ -13,6 +13,8 @@ @dataclass class VAEArchConfig(ArchConfig): scaling_factor: float | torch.Tensor = 0 + scale_factor_temporal: int = 1 + scale_factor_spatial: int = 1 temporal_compression_ratio: int = 4 spatial_compression_ratio: int = 8 @@ -37,8 +39,12 @@ class VAEConfig(ModelConfig): use_tiling: bool = True use_temporal_tiling: bool = True use_parallel_tiling: bool = True + ltx2_spatial_tile_size_in_pixels: int | None = None + ltx2_spatial_tile_overlap_in_pixels: int | None = None + ltx2_temporal_tile_size_in_frames: int | None = None + ltx2_temporal_tile_overlap_in_frames: int | None = None - def __post_init__(self): + def __post_init__(self) -> None: self.blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames @staticmethod diff --git a/fastvideo/configs/pipelines/base.py b/fastvideo/configs/pipelines/base.py index 80ddc04b39..c0cbe5d7c9 100644 --- a/fastvideo/configs/pipelines/base.py +++ b/fastvideo/configs/pipelines/base.py @@ -31,7 +31,7 @@ class PipelineConfig: pipeline_config_path: str | None = None # Video generation parameters - embedded_cfg_scale: float = 6.0 + embedded_cfg_scale: float | None = 6.0 flow_shift: float | None = None flow_shift_sr: float | None = None disable_autocast: bool = False @@ -40,7 +40,7 @@ class PipelineConfig: # Model configuration dit_config: DiTConfig = field(default_factory=DiTConfig) dit_precision: str = "bf16" - upsampler_config: UpsamplerConfig = field(default_factory=UpsamplerConfig) + upsampler_config: UpsamplerConfig | tuple[UpsamplerConfig, ...] = field(default_factory=UpsamplerConfig) upsampler_precision: str = "fp32" # VAE configuration @@ -58,8 +58,7 @@ class PipelineConfig: text_encoder_configs: tuple[EncoderConfig, ...] = field(default_factory=lambda: (EncoderConfig(), )) text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp32", )) preprocess_text_funcs: tuple[Callable[[str], str], ...] = field(default_factory=lambda: (preprocess_text, )) - postprocess_text_funcs: tuple[Callable[[BaseEncoderOutput], torch.tensor], - ...] = field(default_factory=lambda: (postprocess_text, )) + postprocess_text_funcs: tuple[Callable[..., Any], ...] = field(default_factory=lambda: (postprocess_text, )) # DMD parameters dmd_denoising_steps: list[int] | None = field(default=None) @@ -71,6 +70,12 @@ class PipelineConfig: # Compilation # enable_torch_compile: bool = False + def __post_init__(self) -> None: + pass + + def __getattr__(self, name: str) -> Any: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + @staticmethod def add_cli_args(parser: FlexibleArgumentParser, prefix: str = "") -> FlexibleArgumentParser: prefix_with_dot = f"{prefix}." if (prefix.strip() != "") else "" diff --git a/fastvideo/configs/pipelines/cosmos2_5.py b/fastvideo/configs/pipelines/cosmos2_5.py index 447048e3a1..cf96a2fd3c 100644 --- a/fastvideo/configs/pipelines/cosmos2_5.py +++ b/fastvideo/configs/pipelines/cosmos2_5.py @@ -41,9 +41,9 @@ class Cosmos25Config(PipelineConfig): in_channels=16, out_channels=16, num_layers=28, - patch_size=[1, 2, 2], - max_size=[128, 240, 240], - rope_scale=[1.0, 3.0, 3.0], + patch_size=(1, 2, 2), + max_size=(128, 240, 240), + rope_scale=(1.0, 3.0, 3.0), text_embed_dim=1024, mlp_ratio=4.0, adaln_lora_dim=256, @@ -75,7 +75,7 @@ class Cosmos25Config(PipelineConfig): vae_tiling: bool = False vae_sp: bool = False - def __post_init__(self): + def __post_init__(self) -> None: self.vae_config.load_encoder = True self.vae_config.load_decoder = True self._vae_latent_dim = 16 diff --git a/fastvideo/configs/pipelines/gen3c.py b/fastvideo/configs/pipelines/gen3c.py index a98c3d016f..e2dafe4db6 100644 --- a/fastvideo/configs/pipelines/gen3c.py +++ b/fastvideo/configs/pipelines/gen3c.py @@ -22,7 +22,7 @@ class _Gen3CT5LargeArchConfig(T5LargeArchConfig): refactor [PR#1142](https://github.com/hao-ai-lab/FastVideo/pull/1142). """ - def __post_init__(self): + def __post_init__(self) -> None: super().__post_init__() self.tokenizer_kwargs["padding"] = "max_length" diff --git a/fastvideo/configs/pipelines/hunyuan15.py b/fastvideo/configs/pipelines/hunyuan15.py index b8a06bae35..6899e2f283 100644 --- a/fastvideo/configs/pipelines/hunyuan15.py +++ b/fastvideo/configs/pipelines/hunyuan15.py @@ -112,7 +112,7 @@ class Hunyuan15T2V480PConfig(PipelineConfig): vae_tiling: bool = True - def __post_init__(self): + def __post_init__(self) -> None: self.vae_config.load_encoder = False self.vae_config.load_decoder = True if self.text_encoder_configs: diff --git a/fastvideo/configs/pipelines/hunyuangamecraft.py b/fastvideo/configs/pipelines/hunyuangamecraft.py index d67846fe3c..ba777d1913 100644 --- a/fastvideo/configs/pipelines/hunyuangamecraft.py +++ b/fastvideo/configs/pipelines/hunyuangamecraft.py @@ -44,7 +44,7 @@ class HunyuanGameCraftPipelineConfig(PipelineConfig): # Denoising parameters # Official GameCraft does NOT use embedded guidance (passes guidance=None) # It uses standard CFG with guidance_scale=6.0 instead - embedded_cfg_scale = None + embedded_cfg_scale: float | None = None flow_shift: int = 5 # Official GameCraft uses flow_shift=5.0 # Text encoding stage - same as HunyuanVideo @@ -60,7 +60,7 @@ class HunyuanGameCraftPipelineConfig(PipelineConfig): vae_precision: str = "fp16" text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp16", "fp16")) - def __post_init__(self): + def __post_init__(self) -> None: # VAE only needs decoder for inference self.vae_config.load_encoder = False self.vae_config.load_decoder = True diff --git a/fastvideo/configs/pipelines/hyworld.py b/fastvideo/configs/pipelines/hyworld.py index 64de5faf7e..d4753a6a75 100644 --- a/fastvideo/configs/pipelines/hyworld.py +++ b/fastvideo/configs/pipelines/hyworld.py @@ -22,7 +22,7 @@ class HYWorldConfig(Hunyuan15T2V480PConfig): # Text encoding text_encoder_precisions: tuple[str, ...] = field(default_factory=lambda: ("fp16", "fp32")) - def __post_init__(self): + def __post_init__(self) -> None: super().__post_init__() self.vae_config.load_encoder = True self.vae_config.load_decoder = True diff --git a/fastvideo/configs/pipelines/longcat.py b/fastvideo/configs/pipelines/longcat.py index 4140d1513e..9f45049024 100644 --- a/fastvideo/configs/pipelines/longcat.py +++ b/fastvideo/configs/pipelines/longcat.py @@ -32,7 +32,7 @@ class LongCatDiTArchConfig(DiTArchConfig): num_heads: int = 32 out_channels: int = 16 text_tokens_zero_pad: bool = True - patch_size: list[int] = field(default_factory=lambda: [1, 2, 2]) + patch_size: tuple[int, int, int] = (1, 2, 2) cp_split_hw: list[int] | None = None bsa_params: dict | None = None @@ -314,7 +314,7 @@ class LongCatT2V704PConfig(LongCatT2V480PConfig): } -def get_bucket_config(resolution, scale_factor_spatial): +def get_bucket_config(resolution: str, scale_factor_spatial: int) -> dict[str, tuple[list[int], int]]: if resolution == '480p': if scale_factor_spatial == 16 or scale_factor_spatial == 32: return ASPECT_RATIO_627 diff --git a/fastvideo/configs/sample/base.py b/fastvideo/configs/sample/base.py index e27eb5c978..82e4189b81 100644 --- a/fastvideo/configs/sample/base.py +++ b/fastvideo/configs/sample/base.py @@ -47,7 +47,7 @@ class SamplingParam: # Text inputs prompt: str | list[str] | None = None - negative_prompt: str = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" + negative_prompt: str | None = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" prompt_path: str | None = None output_path: str = "outputs/" output_video_name: str | None = None @@ -68,6 +68,7 @@ class SamplingParam: num_inference_steps: int = 50 num_inference_steps_sr: int = 50 guidance_scale: float = 1.0 + guidance_scale_2: float | None = None guidance_rescale: float = 0.0 boundary_ratio: float | None = None sigmas: list[float] | None = None @@ -89,7 +90,10 @@ class SamplingParam: def __post_init__(self) -> None: self.data_type = "video" if self.num_frames > 1 else "image" - def check_sampling_param(self): + def __getattr__(self, name: str) -> Any: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def check_sampling_param(self) -> None: if self.prompt_path and not self.prompt_path.endswith(".txt"): raise ValueError("prompt_path must be a txt file") diff --git a/fastvideo/distributed/device_communicators/pyhccl.py b/fastvideo/distributed/device_communicators/pyhccl.py index ed5ecda565..062e974074 100644 --- a/fastvideo/distributed/device_communicators/pyhccl.py +++ b/fastvideo/distributed/device_communicators/pyhccl.py @@ -134,3 +134,18 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): buffer = buffer_type(tensor.data_ptr()) self.hccl.hcclBroadcast(buffer, tensor.numel(), hcclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, aclrtStream_t(stream.npu_stream)) + + def send(self, tensor: torch.Tensor, dst: int) -> None: + if isinstance(self.group, StatelessProcessGroup): + self.group.send_obj(tensor.cpu(), dst) + return + dist.send(tensor, dst=dist.get_process_group_ranks(self.group)[dst], group=self.group) + + def recv(self, tensor: torch.Tensor, src: int) -> None: + if isinstance(self.group, StatelessProcessGroup): + received = self.group.recv_obj(src) + if not isinstance(received, torch.Tensor): + raise TypeError(f"Expected a tensor from rank {src}, got {type(received)!r}") + tensor.copy_(received.to(device=tensor.device, dtype=tensor.dtype)) + return + dist.recv(tensor, src=dist.get_process_group_ranks(self.group)[src], group=self.group) diff --git a/fastvideo/entrypoints/streaming_generator.py b/fastvideo/entrypoints/streaming_generator.py index b0f8aec1cf..a6b3fcee1f 100644 --- a/fastvideo/entrypoints/streaming_generator.py +++ b/fastvideo/entrypoints/streaming_generator.py @@ -256,7 +256,7 @@ def _process_output_batch(self, output_batch: ForwardBatch) -> list[np.ndarray]: return frames - def shutdown(self): + def shutdown(self) -> None: if self.writer: self.writer.close() self.writer = None diff --git a/fastvideo/entrypoints/video_generator.py b/fastvideo/entrypoints/video_generator.py index 7047e10bf2..2503d7b972 100644 --- a/fastvideo/entrypoints/video_generator.py +++ b/fastvideo/entrypoints/video_generator.py @@ -65,6 +65,7 @@ "pin_cpu_memory", "enable_torch_compile", "torch_compile_kwargs", + "transformer_quant", }) @@ -324,7 +325,7 @@ def _generate_request_impl( results.extend(wrapped) continue wrapped.prompt_index = index - if wrapped.prompt is None: + if wrapped.prompt is None and isinstance(prompt, str): wrapped.prompt = prompt results.append(wrapped) return results @@ -357,7 +358,7 @@ def _generate_single_request( def _generate_video_impl( self, - prompt: str | None = None, + prompt: str | list[str] | None = None, sampling_param: SamplingParam | None = None, mouse_cond: torch.Tensor | None = None, keyboard_cond: torch.Tensor | None = None, @@ -429,6 +430,8 @@ def _generate_video_impl( # Single prompt generation (original behavior) if prompt is None: raise ValueError("Either prompt or prompt_txt must be provided") + if not isinstance(prompt, str): + raise ValueError("Single-prompt generation expects a string prompt") output_path = self._prepare_output_path(sampling_param.output_path, prompt) kwargs["output_path"] = output_path return self._generate_single_video( @@ -786,7 +789,7 @@ def unmerge_lora_weights(self) -> None: def merge_lora_weights(self) -> None: self.executor.merge_lora_weights() - def shutdown(self): + def shutdown(self) -> None: """ Shutdown the video generator. """ diff --git a/fastvideo/envs.py b/fastvideo/envs.py index b6de4d588f..0f3efa0cb0 100644 --- a/fastvideo/envs.py +++ b/fastvideo/envs.py @@ -201,6 +201,8 @@ def maybe_convert_int(value: str | None) -> int | None: # - "VIDEO_SPARSE_ATTN": use Video Sparse Attention # - "SAGE_ATTN": use Sage Attention # - "SAGE_ATTN_THREE": use Sage Attention 3 + # - "ATTN_QAT_INFER": use the in-repo attn_qat_infer inference backend + # - "ATTN_QAT_TRAIN": use the FastVideoKernel Triton attn_qat_train backend "FASTVIDEO_ATTENTION_BACKEND": lambda: os.getenv("FASTVIDEO_ATTENTION_BACKEND", None), diff --git a/fastvideo/fastvideo_args.py b/fastvideo/fastvideo_args.py index 10a5fe0b3e..8611b7f32a 100644 --- a/fastvideo/fastvideo_args.py +++ b/fastvideo/fastvideo_args.py @@ -172,6 +172,7 @@ class FastVideoArgs: override_text_encoder_safetensors: str | None = None # path to safetensors file for text encoder override override_text_encoder_quant: QuantizationMethods = None + transformer_quant: QuantizationMethods = None override_transformer_cls_name: str | None = None init_weights_from_safetensors: str = "" # path to safetensors file for initial weight loading @@ -183,7 +184,7 @@ class FastVideoArgs: # dmd_denoising_steps: List[int] | None = field(default=None) # MoE parameters used by Wan2.2 - boundary_ratio: float | None = 0.875 + boundary_ratio: float = 0.875 @property def training_mode(self) -> bool: @@ -201,6 +202,9 @@ def __post_init__(self): self._apply_ltx2_vae_overrides() self.check_fastvideo_args() + def __getattr__(self, name: str) -> Any: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + def _apply_ltx2_vae_overrides(self) -> None: if self.pipeline_config is None: return @@ -461,8 +465,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--use-fsdp-inference", action=StoreBoolean, help= - "Use FSDP for inference by sharding the model weights. FSDP helps reduce GPU memory usage but may introduce" - + " weight transfer overhead depending on the specific setup. Enable if run out of memory.", + "Use FSDP for inference by sharding the model weights. Latency is very low due to prefetch--enable if run out of memory.", ) parser.add_argument( "--text-encoder-cpu-offload", @@ -528,6 +531,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=FastVideoArgs.override_text_encoder_quant, help="Quantization method for text encoder override", ) + parser.add_argument( + "--transformer-quant", + type=str, + choices=QUANTIZATION_METHODS, + default=FastVideoArgs.transformer_quant, + help="Quantization method for transformer loading", + ) parser.add_argument( "--override-transformer-cls-name", type=str, @@ -782,7 +792,8 @@ class TrainingArgs(FastVideoArgs): trackers: list[str] = dataclasses.field(default_factory=list) tracker_project_name: str = "" wandb_run_name: str = "" - seed: int | None = None + seed: int = 0 + _loading_teacher_critic_model: bool = False # output output_dir: str = "" @@ -855,6 +866,8 @@ class TrainingArgs(FastVideoArgs): # simulate generator forward to match inference simulate_generator_forward: bool = False warp_denoising_step: bool = False + generator_4bit_attn: bool = False + generator_4bit_linear: bool = False # Self-forcing specific arguments num_frame_per_block: int = 3 diff --git a/fastvideo/layers/fp4linear.py b/fastvideo/layers/fp4linear.py new file mode 100644 index 0000000000..7c6733e8b8 --- /dev/null +++ b/fastvideo/layers/fp4linear.py @@ -0,0 +1,115 @@ +from typing import Any + +import torch + +try: + import flashinfer +except ImportError: + flashinfer = None + + +def _require_flashinfer() -> Any: + if flashinfer is None: + raise ImportError("flashinfer is required for FP4 linear layers. " + "Please install flashinfer to use this path.") + return flashinfer + + +class _LinearFWD4BWD16Fn(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, weight, bias, backend="cutlass", block_size=16, use_128x4_sf_layout=True): + flashinfer_mod = _require_flashinfer() + + # assert activation dtype + if x.dtype not in (torch.float16, torch.bfloat16): + x = x.to(dtype=torch.bfloat16) + + # cast params (can be fp32) to activation dtype for quantization + weight_cast = weight.to(dtype=x.dtype) + bias_cast = bias.to(dtype=x.dtype) if bias is not None else None + + # shapes + orig_shape = x.shape + k = weight_cast.shape[1] + n = weight_cast.shape[0] + x2d = x.reshape(-1, k).contiguous() + M = x2d.shape[0] + + out2d = torch.empty((M, n), device=x.device, dtype=x.dtype) + + @torch.compile + def _global_sf(t: torch.Tensor) -> torch.Tensor: + maxabs = t.float().abs().nan_to_num().max() + maxabs = torch.maximum(maxabs, torch.tensor(1e-12, device=t.device, dtype=maxabs.dtype)) + return (448.0 * 6.0) / maxabs + + a_sf_layout = (flashinfer_mod.SfLayout.layout_128x4 + if use_128x4_sf_layout else flashinfer_mod.SfLayout.layout_8x4) + global_sf_a = _global_sf(x2d) + global_sf_b = _global_sf(weight_cast) + + a_fp4, a_inv_s = flashinfer_mod.nvfp4_quantize( + x2d, + global_sf_a, + sfLayout=a_sf_layout, + do_shuffle=False, + ) + b_fp4, b_inv_s = flashinfer_mod.nvfp4_quantize( + weight_cast, + global_sf_b, + sfLayout=flashinfer_mod.SfLayout.layout_128x4, + do_shuffle=False, + ) + + alpha = 1.0 / (global_sf_a * global_sf_b) + + flashinfer_mod.mm_fp4( + a_fp4, + b_fp4.T, + a_inv_s, + b_inv_s.T, + alpha, + x.dtype, + out2d, + block_size=block_size, + use_8x4_sf_layout=(not use_128x4_sf_layout), + backend=backend, + ) + + if bias_cast is not None: + out2d.add_(bias_cast) + + # save tensors for backward (keep original dtypes) + ctx.save_for_backward(x2d, weight, bias) + ctx.k = k + ctx.n = n + ctx.orig_shape = orig_shape + return out2d.reshape(*orig_shape[:-1], n) + + @staticmethod + def backward(ctx, grad_out): + x2d, weight, bias = ctx.saved_tensors + M = x2d.shape[0] + n = ctx.n + + grad_out_2d = grad_out.reshape(M, n).contiguous() + + # cast to grad dtype for matmuls + weight_cast = weight.to(dtype=grad_out.dtype) + x_cast = x2d.to(dtype=grad_out.dtype) + + grad_x = grad_out_2d.matmul(weight_cast).reshape(*ctx.orig_shape) + grad_w = grad_out_2d.t().matmul(x_cast) + grad_b = grad_out_2d.sum(dim=0) if bias is not None else None + + # None for the three extra forward args + return grad_x, grad_w, grad_b, None, None, None + + +def fp4_linear_forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor | None]: + # pass config **positionally**; autograd.Function.apply ignores kwargs + bias = self.bias if not self.skip_bias_add else None + output = _LinearFWD4BWD16Fn.apply(x, self.weight, bias, "cutlass", 16, True) + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias diff --git a/fastvideo/layers/linear.py b/fastvideo/layers/linear.py index e33a738ec1..524dc19672 100644 --- a/fastvideo/layers/linear.py +++ b/fastvideo/layers/linear.py @@ -211,36 +211,30 @@ class ReplicatedLinear(LinearBase): (e.g. model.layers.0.qkv_proj) """ - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: torch.dtype | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - super().__init__( - input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix=prefix, - ) + enable_shape_tracking = False + _unique_shapes: set[tuple[torch.Size, torch.Size]] = set() + _shape_to_layer_types: dict[tuple[torch.Size, torch.Size], list[str]] = {} + + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: torch.dtype | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = ""): + super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix=prefix) # All the linear layer supports quant method. - assert self.quant_method is not None - self.quant_method.create_weights( - self, - self.input_size, - [self.output_size], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader, - ) + if self.quant_method is None: + self.quant_method = UnquantizedLinearMethod() + + self.quant_method.create_weights(self, + self.input_size, [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter(torch.empty( @@ -269,8 +263,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor) -> None: def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, Parameter | None]: bias = self.bias if not self.skip_bias_add else None - assert self.quant_method is not None + if self.quant_method is None: + self.quant_method = UnquantizedLinearMethod() + output = self.quant_method.apply(self, x, bias) + if self.enable_shape_tracking: + self._track_shape(x.shape, output.shape) + output_bias = self.bias if self.skip_bias_add else None return output, output_bias @@ -280,6 +279,48 @@ def extra_repr(self) -> str: s += f", bias={self.bias is not None}" return s + @classmethod + def get_shape_mapping(cls) -> dict: + """Get the mapping from (input_shape, output_shape) to layer types.""" + return cls._shape_to_layer_types.copy() + + @classmethod + def reset_shape_tracking(cls) -> None: + """Clear tracked shapes and layer type mappings.""" + cls._unique_shapes.clear() + cls._shape_to_layer_types.clear() + + def _track_shape(self, input_shape: torch.Size, output_shape: torch.Size) -> None: + shape_key = (input_shape, output_shape) + + if shape_key not in self._unique_shapes: + self._unique_shapes.add(shape_key) + self._shape_to_layer_types[shape_key] = [] + print(f"Layer: {self.prefix} | input shape: {input_shape} --> " + f"output shape: {output_shape}, Quant Method: " + f"{self.quant_method.__class__.__name__}") + + layer_type = self.__class__.__name__ + if layer_type not in self._shape_to_layer_types[shape_key]: + self._shape_to_layer_types[shape_key].append(layer_type) + + @classmethod + def print_shape_summary(cls): + """Print a summary of all unique shapes and their layer types.""" + if not cls._shape_to_layer_types: + print("No shapes have been processed yet.") + return + + print("\n=== Matrix Multiplication Shape Summary ===") + print(f"Total unique shapes: {len(cls._shape_to_layer_types)}") + print() + + for i, (shape_key, layer_types) in enumerate(cls._shape_to_layer_types.items(), 1): + input_shape, output_shape = shape_key + print(f"{i}. Input: {input_shape} → Output: {output_shape}") + print(f" Layer types: {', '.join(layer_types)}") + print() + class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. @@ -340,7 +381,9 @@ def __init__( if output_sizes is None: output_sizes = [output_size] - assert self.quant_method is not None + if self.quant_method is None: + self.quant_method = UnquantizedLinearMethod() + self.quant_method.create_weights( layer=self, input_size_per_partition=self.input_size_per_partition, @@ -399,7 +442,8 @@ def forward(self, input_: torch.Tensor) -> tuple[torch.Tensor, Parameter | None] bias = self.bias if not self.skip_bias_add else None # Matrix multiply. - assert self.quant_method is not None + if self.quant_method is None: + self.quant_method = UnquantizedLinearMethod() output_parallel = self.quant_method.apply(self, input_, bias) # All-gather across the partitions if needed. output = tensor_model_parallel_all_gather(output_parallel) if self.gather_output else output_parallel @@ -916,7 +960,9 @@ def __init__( self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results - assert self.quant_method is not None + if self.quant_method is None: + self.quant_method = UnquantizedLinearMethod() + self.quant_method.create_weights( layer=self, input_size_per_partition=self.input_size_per_partition, @@ -983,7 +1029,8 @@ def forward(self, input_) -> tuple[torch.Tensor, Parameter | None]: input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. - assert self.quant_method is not None + if self.quant_method is None: + self.quant_method = UnquantizedLinearMethod() # Only fuse bias add into GEMM for rank 0 (this ensures that # bias will not get added more than once in TP>1 case) bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias diff --git a/fastvideo/layers/mlp.py b/fastvideo/layers/mlp.py index 7610a775d3..caa0f3fcf4 100644 --- a/fastvideo/layers/mlp.py +++ b/fastvideo/layers/mlp.py @@ -5,6 +5,7 @@ from fastvideo.layers.activation import get_act_fn from fastvideo.layers.linear import ReplicatedLinear +from fastvideo.layers.quantization import QuantizationConfig class MLP(nn.Module): @@ -21,18 +22,35 @@ def __init__( act_type: str = "gelu_pytorch_tanh", dtype: torch.dtype | None = None, prefix: str = "", + quant_config: QuantizationConfig | None = None, ): super().__init__() self.fc_in = ReplicatedLinear( input_dim, mlp_hidden_dim, # For activation func like SiLU that need 2x width bias=bias, - params_dtype=dtype) + params_dtype=dtype, + quant_config=quant_config, + prefix=f"{prefix}.fc_in", + ) + if quant_config is not None: + quant_method = self.fc_in.quant_config.get_quant_method(self.fc_in, f"{prefix}.fc_in") + if quant_method is not None: + quant_method.process_weights_after_loading(self.fc_in) self.act = get_act_fn(act_type) if output_dim is None: output_dim = input_dim - self.fc_out = ReplicatedLinear(mlp_hidden_dim, output_dim, bias=bias, params_dtype=dtype) + self.fc_out = ReplicatedLinear(mlp_hidden_dim, + output_dim, + bias=bias, + params_dtype=dtype, + quant_config=quant_config, + prefix=f"{prefix}.fc_out") + if quant_config is not None: + quant_method = self.fc_out.quant_config.get_quant_method(self.fc_out, f"{prefix}.fc_out") + if quant_method is not None: + quant_method.process_weights_after_loading(self.fc_out) def forward(self, x: torch.Tensor) -> torch.Tensor: x, _ = self.fc_in(x) diff --git a/fastvideo/layers/quantization/__init__.py b/fastvideo/layers/quantization/__init__.py index 366fd463b7..3067e49763 100644 --- a/fastvideo/layers/quantization/__init__.py +++ b/fastvideo/layers/quantization/__init__.py @@ -2,7 +2,7 @@ from fastvideo.layers.quantization.base_config import QuantizationConfig -QuantizationMethods = Literal[None, "AbsMaxFP8"] +QuantizationMethods = Literal[None, "AbsMaxFP8", "fp4"] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -51,9 +51,11 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: # lazy import to avoid triggering `torch.compile` too early from .absmax_fp8 import AbsMaxFP8Config + from .fp4_config import FP4Config method_to_config: dict[str, type[QuantizationConfig]] = { "AbsMaxFP8": AbsMaxFP8Config, + "fp4": FP4Config, } # Update the `method_to_config` with customized quantization methods. method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/fastvideo/layers/quantization/base_config.py b/fastvideo/layers/quantization/base_config.py index 20f2ea0b54..c7dc2754a9 100644 --- a/fastvideo/layers/quantization/base_config.py +++ b/fastvideo/layers/quantization/base_config.py @@ -61,7 +61,7 @@ def method_has_implemented_embedding(method_class: type[QuantizeMethodBase]) -> class QuantizationConfig(ABC): """Base class for quantization configs.""" - def __init__(self): + def __init__(self) -> None: super().__init__() # mapping is updated by models as they initialize self.packed_modules_mapping: dict[str, list[str]] = dict() diff --git a/fastvideo/layers/quantization/fp4_config.py b/fastvideo/layers/quantization/fp4_config.py new file mode 100644 index 0000000000..539ecfd1be --- /dev/null +++ b/fastvideo/layers/quantization/fp4_config.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 +import logging +from typing import Any + +import torch +from torch.nn.parameter import Parameter + +from fastvideo.layers.quantization.base_config import QuantizationConfig, QuantizeMethodBase +from fastvideo.models.utils import set_weight_attrs + +try: + import flashinfer +except ImportError: + flashinfer = None + +logger = logging.getLogger(__name__) + + +def _require_flashinfer() -> Any: + if flashinfer is None: + raise ImportError("flashinfer is required for FP4 quantization. " + "Please install flashinfer to use the fp4 quantization backend.") + return flashinfer + + +class FP4QuantizeMethod(QuantizeMethodBase): + + def __init__(self) -> None: + super().__init__() + self.weight_fp4 = None + self.weight_scale = None + + def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], + input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): + """Create weights for a linear layer. Note the corrected signature to match LinearMethodBase.""" + weight = Parameter(torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + @torch.compile + def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: + """Apply FP4 quantized computation.""" + flashinfer_mod = _require_flashinfer() + out_dim = layer.weight.shape[0] + original_shape = x.shape + assert x.dtype == torch.bfloat16 or x.dtype == torch.float16, f"only allow bf16/fp16 inputs to fp4 linear, got {x.dtype}" + + x = x.view(-1, x.shape[-1]) + + x_global_sf = (448 * 6) / x.float().abs().nan_to_num().max() + x_fp4, x_scale = flashinfer_mod.nvfp4_quantize( + x, + x_global_sf, + sfLayout=flashinfer_mod.SfLayout.layout_128x4, + do_shuffle=False, + ) + weight_fp4 = layer._fp4_weight + weight_scale = layer._fp4_weight_scale + weight_global_sf = layer._weight_global_sf + + out = flashinfer_mod.mm_fp4( + x_fp4, + weight_fp4.T, + x_scale, + weight_scale.T, + 1.0 / (x_global_sf * weight_global_sf), + torch.bfloat16, + None, + backend="cutlass", + ) + + if bias is not None: + if bias.device != out.device or bias.dtype != out.dtype: + bias = bias.to(device=out.device, dtype=out.dtype) + out = out + bias + + if len(original_shape) == 3: + out = out.view(original_shape[0], original_shape[1], out_dim) + + return out + + +class FP4Config(QuantizationConfig): + + def __init__(self) -> None: + super().__init__() + + def get_name(self): + return "fp4" + + def get_supported_act_dtypes(self): + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls): + return 100 + + @staticmethod + def get_config_filenames(): + return [] + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "FP4Config": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, prefix: str): + from fastvideo.layers.linear import LinearBase + fp4_layers = ["ffn.fc_in", "ffn.fc_out", "to_q", "to_k", "to_v", "to_out"] + if isinstance(layer, LinearBase) and any(layer_name in prefix for layer_name in fp4_layers): + return FP4QuantizeMethod() + return None + + +@torch.compile +def convert_model_to_fp4(model: torch.nn.Module): + flashinfer_mod = _require_flashinfer() + from torch.distributed.tensor import DTensor # type: ignore + for mod in model.modules(): + qm = getattr(mod, "quant_method", None) + if isinstance(qm, FP4QuantizeMethod): + weight = getattr(mod, "weight", None) + if weight is None: + continue + weight_local = weight.to_local() if isinstance(weight, DTensor) else weight # type: ignore[arg-type] + weight_global_sf = (448 * 6) / weight_local.float().abs().nan_to_num().max() + fp4_w, fp4_s = flashinfer_mod.nvfp4_quantize( + weight_local, + weight_global_sf, + sfLayout=flashinfer_mod.SfLayout.layout_128x4, + do_shuffle=False, + ) + mod.register_buffer("_fp4_weight", fp4_w, persistent=False) + mod.register_buffer("_fp4_weight_scale", fp4_s, persistent=False) + mod.register_buffer("_weight_global_sf", + torch.tensor(weight_global_sf, dtype=torch.bfloat16), + persistent=False) diff --git a/fastvideo/models/dits/wanvideo.py b/fastvideo/models/dits/wanvideo.py index 359db5742b..480a36cc86 100644 --- a/fastvideo/models/dits/wanvideo.py +++ b/fastvideo/models/dits/wanvideo.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import math +from contextlib import nullcontext from typing import Any import torch @@ -26,6 +27,7 @@ from fastvideo.logger import init_logger from fastvideo.models.dits.base import BaseDiT from fastvideo.platforms import AttentionBackendEnum, current_platform +from fastvideo.layers.quantization import QuantizationConfig from fastvideo.distributed.parallel_state import get_sp_world_size @@ -106,7 +108,9 @@ def __init__(self, window_size=(-1, -1), qk_norm=True, eps=1e-6, - parallel_attention=False) -> None: + parallel_attention=False, + quant_config: QuantizationConfig | None = None, + prefix: str = "") -> None: assert dim % num_heads == 0 super().__init__() self.dim = dim @@ -118,10 +122,10 @@ def __init__(self, self.parallel_attention = parallel_attention # layers - self.to_q = ReplicatedLinear(dim, dim) - self.to_k = ReplicatedLinear(dim, dim) - self.to_v = ReplicatedLinear(dim, dim) - self.to_out = ReplicatedLinear(dim, dim) + self.to_q = ReplicatedLinear(dim, dim, quant_config=quant_config, prefix=f"{prefix}.to_q") + self.to_k = ReplicatedLinear(dim, dim, quant_config=quant_config, prefix=f"{prefix}.to_k") + self.to_v = ReplicatedLinear(dim, dim, quant_config=quant_config, prefix=f"{prefix}.to_v") + self.to_out = ReplicatedLinear(dim, dim, quant_config=quant_config, prefix=f"{prefix}.to_out") self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() @@ -194,13 +198,15 @@ def __init__( qk_norm=True, eps=1e-6, supported_attention_backends: tuple[AttentionBackendEnum, ...] - | None = None + | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", ) -> None: super().__init__(dim, num_heads, window_size, qk_norm, eps, - supported_attention_backends) + supported_attention_backends, quant_config=quant_config, prefix=prefix) - self.add_k_proj = ReplicatedLinear(dim, dim) - self.add_v_proj = ReplicatedLinear(dim, dim) + self.add_k_proj = ReplicatedLinear(dim, dim, quant_config=quant_config, prefix=f"{prefix}.add_k_proj") + self.add_v_proj = ReplicatedLinear(dim, dim, quant_config=quant_config, prefix=f"{prefix}.add_v_proj") self.norm_added_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() self.norm_added_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() @@ -246,16 +252,17 @@ def __init__(self, added_kv_proj_dim: int | None = None, supported_attention_backends: tuple[AttentionBackendEnum, ...] | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = ""): super().__init__() # 1. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) - self.to_q = ReplicatedLinear(dim, dim, bias=True) - self.to_k = ReplicatedLinear(dim, dim, bias=True) - self.to_v = ReplicatedLinear(dim, dim, bias=True) + self.to_q = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.to_q") + self.to_k = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.to_k") + self.to_v = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.to_v") - self.to_out = ReplicatedLinear(dim, dim, bias=True) + self.to_out = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.to_out") self.attn1 = DistributedAttention( num_heads=num_heads, head_size=dim // num_heads, @@ -290,13 +297,17 @@ def __init__(self, self.attn2 = WanI2VCrossAttention(dim, num_heads, qk_norm=qk_norm, - eps=eps) + eps=eps, + quant_config=quant_config, + prefix=f"{prefix}.attn2") else: # T2V self.attn2 = WanT2VCrossAttention(dim, num_heads, qk_norm=qk_norm, - eps=eps) + eps=eps, + quant_config=quant_config, + prefix=f"{prefix}.attn2") self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( dim, norm_type="layer", @@ -306,7 +317,7 @@ def __init__(self, compute_dtype=torch.float32) # 3. Feed-forward - self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh") + self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh", quant_config=quant_config, prefix=f"{prefix}.ffn") self.mlp_residual = ScaleResidual() self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) @@ -406,17 +417,17 @@ def __init__(self, added_kv_proj_dim: int | None = None, supported_attention_backends: tuple[AttentionBackendEnum, ...] | None = None, + quant_config: QuantizationConfig | None = None, prefix: str = ""): super().__init__() # 1. Self-attention self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) - self.to_q = ReplicatedLinear(dim, dim, bias=True) - self.to_k = ReplicatedLinear(dim, dim, bias=True) - self.to_v = ReplicatedLinear(dim, dim, bias=True) - self.to_gate_compress = ReplicatedLinear(dim, dim, bias=True) - - self.to_out = ReplicatedLinear(dim, dim, bias=True) + self.to_q = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.to_q") + self.to_k = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.to_k") + self.to_v = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.to_v") + self.to_gate_compress = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.to_gate_compress") + self.to_out = ReplicatedLinear(dim, dim, bias=True, quant_config=quant_config, prefix=f"{prefix}.to_out") self.attn1 = DistributedAttention_VSA( num_heads=num_heads, head_size=dim // num_heads, @@ -451,13 +462,17 @@ def __init__(self, self.attn2 = WanI2VCrossAttention(dim, num_heads, qk_norm=qk_norm, - eps=eps) + eps=eps, + quant_config=quant_config, + prefix=f"{prefix}.attn2") else: # T2V self.attn2 = WanT2VCrossAttention(dim, num_heads, qk_norm=qk_norm, - eps=eps) + eps=eps, + quant_config=quant_config, + prefix=f"{prefix}.attn2") self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift( dim, norm_type="layer", @@ -467,7 +482,7 @@ def __init__(self, compute_dtype=torch.float32) # 3. Feed-forward - self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh") + self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh", quant_config=quant_config, prefix=f"{prefix}.ffn") self.mlp_residual = ScaleResidual() self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) @@ -556,6 +571,7 @@ class WanTransformer3DModel(BaseDiT): def __init__(self, config: WanVideoConfig, hf_config: dict[str, Any]) -> None: super().__init__(config=config, hf_config=hf_config) + self.quant_config = config.quant_config inner_dim = config.num_attention_heads * config.attention_head_dim self.hidden_size = config.hidden_size @@ -594,6 +610,7 @@ def __init__(self, config: WanVideoConfig, hf_config: dict[str, config.eps, config.added_kv_proj_dim, self._supported_attention_backends, + quant_config=config.quant_config, prefix=f"{config.prefix}.blocks.{i}") for i in range(config.num_layers) ]) diff --git a/fastvideo/models/loader/component_loader.py b/fastvideo/models/loader/component_loader.py index 9dfde7cb24..a2b0418fe5 100644 --- a/fastvideo/models/loader/component_loader.py +++ b/fastvideo/models/loader/component_loader.py @@ -561,6 +561,7 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): # If parsing fails, fall through to AutoTokenizer below. pass + tokenizer = AutoTokenizer.from_pretrained( resolved_model_path, # "/tokenizer" # in v0, this was same string as encoder_name "ClipTextModel" @@ -872,6 +873,9 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): # Config from Diffusers supersedes fastvideo's model config dit_config = deepcopy(fastvideo_args.pipeline_config.dit_config) dit_config.update_model_arch(config) + if fastvideo_args.transformer_quant is not None: + quant_cls = get_quantization_config(fastvideo_args.transformer_quant) + dit_config.quant_config = quant_cls() model_cls, _ = ModelRegistry.resolve_model_cls(cls_name) @@ -951,6 +955,7 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): training_mode=fastvideo_args.training_mode, enable_torch_compile=fastvideo_args.enable_torch_compile, torch_compile_kwargs=fastvideo_args.torch_compile_kwargs, + transformer_quant=fastvideo_args.transformer_quant, ) total_params = sum(p.numel() for p in model.parameters()) diff --git a/fastvideo/models/loader/fsdp_load.py b/fastvideo/models/loader/fsdp_load.py index f0e4494141..40e738c9c9 100644 --- a/fastvideo/models/loader/fsdp_load.py +++ b/fastvideo/models/loader/fsdp_load.py @@ -20,6 +20,7 @@ from torch.nn.modules.module import _IncompatibleKeys from fastvideo.logger import init_logger +from fastvideo.layers.quantization import QuantizationMethods from fastvideo.models.loader.utils import (get_param_names_mapping, hf_to_custom_state_dict) from fastvideo.models.loader.weight_utils import safetensors_weights_iterator @@ -75,6 +76,7 @@ def maybe_load_fsdp_model( pin_cpu_memory: bool = True, enable_torch_compile: bool = False, torch_compile_kwargs: dict[str, Any] | None = None, + transformer_quant: QuantizationMethods = None, ) -> torch.nn.Module: """ Load the model with FSDP if is training, else load the model without FSDP. @@ -157,6 +159,11 @@ def maybe_load_fsdp_model( # Avoid unintended computation graph accumulation during inference if isinstance(p, torch.nn.Parameter): p.requires_grad = False + if transformer_quant == "fp4": + from fastvideo.layers.quantization.fp4_config import convert_model_to_fp4 + + logger.info("Converting model to FP4 quantized runtime weights") + convert_model_to_fp4(model) compile_in_loader = enable_torch_compile and training_mode if compile_in_loader: diff --git a/fastvideo/pipelines/basic/matrixgame/matrixgame_causal_dmd_pipeline.py b/fastvideo/pipelines/basic/matrixgame/matrixgame_causal_dmd_pipeline.py index c1a3c9caa5..bb8950bc51 100644 --- a/fastvideo/pipelines/basic/matrixgame/matrixgame_causal_dmd_pipeline.py +++ b/fastvideo/pipelines/basic/matrixgame/matrixgame_causal_dmd_pipeline.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Matrix-Game causal DMD pipeline implementation.""" +from typing import Any, cast + from fastvideo.fastvideo_args import FastVideoArgs import torch from fastvideo.logger import init_logger @@ -9,6 +11,7 @@ from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, InputValidationStage, LatentPreparationStage, TextEncodingStage, MatrixGameImageEncodingStage, MatrixGameCausalDenoisingStage) from fastvideo.pipelines.stages.image_encoding import (MatrixGameImageVAEEncodingStage) +from fastvideo.pipelines.stages.matrixgame_denoising import BlockProcessingContext logger = init_logger(__name__) @@ -78,7 +81,7 @@ def streaming_reset(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs): def streaming_step(self, keyboard_action, mouse_action) -> ForwardBatch: denoiser = self._stage_name_mapping["denoising_stage"] - ctx = denoiser._streaming_ctx + ctx = cast(BlockProcessingContext | None, denoiser._streaming_ctx) assert ctx is not None, "streaming_ctx must be set" start_idx = ctx.start_index @@ -87,13 +90,15 @@ def streaming_step(self, keyboard_action, mouse_action) -> ForwardBatch: # Decode only the new generated block if end_idx > start_idx: + assert batch.latents is not None, "latents must be set after streaming_step" current_latents = batch.latents[:, :, start_idx:end_idx, :, :] args = ctx.fastvideo_args decoder = self._stage_name_mapping["decoding_stage"] - decoded_frames, self._vae_cache = decoder.streaming_decode(current_latents, - args, - cache=self._vae_cache, - is_first_chunk=(start_idx == 0)) + decoded_frames_and_cache = cast( + tuple[Any, Any], + decoder.streaming_decode(current_latents, args, cache=self._vae_cache, is_first_chunk=(start_idx == 0)), + ) + decoded_frames, self._vae_cache = decoded_frames_and_cache batch.output = decoded_frames else: batch.output = None diff --git a/fastvideo/pipelines/basic/turbodiffusion/turbodiffusion_pipeline.py b/fastvideo/pipelines/basic/turbodiffusion/turbodiffusion_pipeline.py index f49f0c91c6..2eaf15a1a3 100644 --- a/fastvideo/pipelines/basic/turbodiffusion/turbodiffusion_pipeline.py +++ b/fastvideo/pipelines/basic/turbodiffusion/turbodiffusion_pipeline.py @@ -33,7 +33,6 @@ def initialize_pipeline(self, fastvideo_args: FastVideoArgs): def create_pipeline_stages(self, fastvideo_args: FastVideoArgs) -> None: """Set up pipeline stages with proper dependency injection.""" - self.add_stage(stage_name="input_validation_stage", stage=InputValidationStage()) self.add_stage(stage_name="prompt_encoding_stage", diff --git a/fastvideo/pipelines/composed_pipeline_base.py b/fastvideo/pipelines/composed_pipeline_base.py index 079ea169a8..76c11957b6 100644 --- a/fastvideo/pipelines/composed_pipeline_base.py +++ b/fastvideo/pipelines/composed_pipeline_base.py @@ -39,8 +39,8 @@ class ComposedPipelineBase(ABC): is_video_pipeline: bool = False # To be overridden by video pipelines _required_config_modules: list[str] = [] _extra_config_module_map: dict[str, str] = {} - training_args: TrainingArgs | None = None - fastvideo_args: FastVideoArgs | TrainingArgs | None = None + training_args: Any = None + fastvideo_args: Any = None modules: dict[str, Any] = {} # do not need to include moe related transformers trainable_transformer_names: list[str] = ["transformer"] @@ -203,7 +203,7 @@ def from_pretrained(cls, device: str | None = None, torch_dtype: torch.dtype | None = None, pipeline_config: str | PipelineConfig | None = None, - args: argparse.Namespace | None = None, + args: argparse.Namespace | FastVideoArgs | TrainingArgs | None = None, required_config_modules: list[str] | None = None, loaded_modules: dict[str, torch.nn.Module] | None = None, @@ -213,13 +213,16 @@ def from_pretrained(cls, loaded_modules: Optional[Dict[str, torch.nn.Module]] = None, If provided, loaded_modules will be used instead of loading from config/pretrained weights. """ - if args is None or args.inference_mode: + if args is None or (isinstance(args, FastVideoArgs) and args.inference_mode): kwargs['model_path'] = model_path fastvideo_args = FastVideoArgs.from_kwargs(**kwargs) else: - assert args is not None, "args must be provided for training mode" - fastvideo_args = TrainingArgs.from_cli_args(args) + if isinstance(args, TrainingArgs): + fastvideo_args = args + else: + assert isinstance(args, argparse.Namespace), "training mode expects argparse.Namespace args" + fastvideo_args = TrainingArgs.from_cli_args(args) # TODO(will): fix this so that its not so ugly fastvideo_args.model_path = model_path for key, value in kwargs.items(): @@ -249,6 +252,11 @@ def get_module(self, module_name: str, default_value: Any = None) -> Any: def add_module(self, module_name: str, module: Any): self.modules[module_name] = module + def __getattr__(self, name: str) -> Any: + if "_stage_name_mapping" in self.__dict__ and name in self._stage_name_mapping: + return self._stage_name_mapping[name] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + def _load_config(self, model_path: str) -> dict[str, Any]: model_path = maybe_download_model(self.model_path) self.model_path = model_path @@ -455,3 +463,12 @@ def forward( def train(self) -> None: raise NotImplementedError("if training_mode is True, the pipeline must implement this method") + + def streaming_reset(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch: + raise NotImplementedError(f"{type(self).__name__} does not support streaming_reset") + + def streaming_step(self, *args: Any, **kwargs: Any) -> ForwardBatch: + raise NotImplementedError(f"{type(self).__name__} does not support streaming_step") + + def streaming_clear(self) -> None: + raise NotImplementedError(f"{type(self).__name__} does not support streaming_clear") diff --git a/fastvideo/pipelines/pipeline_batch_info.py b/fastvideo/pipelines/pipeline_batch_info.py index 8331d663c5..d457fd5ea8 100644 --- a/fastvideo/pipelines/pipeline_batch_info.py +++ b/fastvideo/pipelines/pipeline_batch_info.py @@ -26,18 +26,18 @@ class PipelineLoggingInfo: """Simple approach using OrderedDict to track stage metrics.""" - def __init__(self): + def __init__(self) -> None: # OrderedDict preserves insertion order and allows easy access self.stages: OrderedDict[str, dict[str, Any]] = OrderedDict() - def add_stage_execution_time(self, stage_name: str, execution_time: float): + def add_stage_execution_time(self, stage_name: str, execution_time: float) -> None: """Add execution time for a stage.""" if stage_name not in self.stages: self.stages[stage_name] = {} self.stages[stage_name]['execution_time'] = execution_time self.stages[stage_name]['timestamp'] = time.time() - def add_stage_metric(self, stage_name: str, metric_name: str, value: Any): + def add_stage_metric(self, stage_name: str, metric_name: str, value: Any) -> None: """Add any metric for a stage.""" if stage_name not in self.stages: self.stages[stage_name] = {} @@ -121,9 +121,12 @@ class ForwardBatch: # Latent tensors latents: torch.Tensor | None = None lq_latents: torch.Tensor | None = None - raw_latent_shape: tuple[int, ...] | None = None + raw_latent_shape: Any | None = None noise_pred: torch.Tensor | None = None image_latent: torch.Tensor | None = None + conditioning_latents: torch.Tensor | None = None + cond_latents: torch.Tensor | None = None + padding_mask: torch.Tensor | None = None # Action control inputs (Matrix-Game) mouse_cond: torch.Tensor | None = None # Shape: (B, T, 2) @@ -147,18 +150,28 @@ class ForwardBatch: trajectory_type: str | None = None movement_distance: float | None = None camera_rotation: str | None = None + condition_video_pose: torch.Tensor | None = None + condition_video_input_mask: torch.Tensor | None = None + condition_video_augment_sigma: torch.Tensor | None = None + rendered_warp_images: torch.Tensor | None = None + rendered_warp_masks: torch.Tensor | None = None + input_image_conditioning: torch.Tensor | None = None + cache_3d: Any = None # Latent dimensions - height_latents: list[int] | int | None = None - width_latents: list[int] | int | None = None - num_frames: list[int] | int = 1 # Default for image models + height_latents: Any = None + width_latents: Any = None + num_frames: Any = 1 # Default for image models + latent_height: int | None = None + latent_width: int | None = None + latent_frames: int | None = None # Original dimensions (before VAE scaling) - height: list[int] | int | None = None - width: list[int] | int | None = None - height_sr: list[int] | int | None = None - width_sr: list[int] | int | None = None - fps: list[int] | int | None = None + height: Any = None + width: Any = None + height_sr: Any = None + width_sr: Any = None + fps: Any = None # Timesteps timesteps: torch.Tensor | None = None @@ -170,7 +183,7 @@ class ForwardBatch: num_inference_steps: int = 50 num_inference_steps_sr: int = 50 guidance_scale: float = 1.0 - guidance_scale_2: float | None = None + guidance_scale_2: Any = None guidance_rescale: float = 0.0 eta: float = 0.0 sigmas: list[float] | None = None @@ -196,18 +209,18 @@ class ForwardBatch: extra_step_kwargs: dict[str, Any] = field(default_factory=dict) # Component modules (populated by the pipeline) - modules: dict[str, Any] = field(default_factory=dict) + modules: Any = field(default_factory=dict) # Final output (after pipeline completion) output: torch.Tensor | None = None return_trajectory_latents: bool = False return_trajectory_decoded: bool = False - trajectory_timesteps: list[torch.Tensor] | None = None - trajectory_latents: torch.Tensor | None = None - trajectory_decoded: list[torch.Tensor] | None = None + trajectory_timesteps: Any = None + trajectory_latents: Any = None + trajectory_decoded: Any = None # Extra parameters that might be needed by specific pipeline implementations - extra: dict[str, Any] = field(default_factory=dict) + extra: Any = field(default_factory=dict) # Misc save_video: bool = True @@ -217,11 +230,21 @@ class ForwardBatch: # VSA parameters VSA_sparsity: float = 0.0 + kv_cache_dict: dict[int, Any] = field(default_factory=dict) + use_kv_cache: bool = False + num_cond_latents: int = 0 + num_cond_frames_added: int = 0 + num_noise_frames_added: int = 0 + new_frame_size_before_padding: int | None = None + cond_indicator: torch.Tensor | None = None + uncond_indicator: torch.Tensor | None = None + cond_mask: torch.Tensor | None = None + uncond_mask: torch.Tensor | None = None # Logging info - logging_info: PipelineLoggingInfo = field(default_factory=PipelineLoggingInfo) + logging_info: PipelineLoggingInfo | None = field(default_factory=PipelineLoggingInfo) - def __post_init__(self): + def __post_init__(self) -> None: """Initialize dependent fields after dataclass initialization.""" # Enable CFG for standard guidance_scale and LTX-2 text CFG scales. @@ -233,7 +256,10 @@ def __post_init__(self): if self.guidance_scale_2 is None: self.guidance_scale_2 = self.guidance_scale - def __str__(self): + def __getattr__(self, name: str) -> Any: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __str__(self) -> str: return pprint.pformat(asdict(self), indent=2, width=120) @@ -244,7 +270,7 @@ class TrainingBatch: # Dataloader batch outputs latents: torch.Tensor | None = None - raw_latent_shape: tuple[int, ...] | None = None + raw_latent_shape: Any | None = None noise_latents: torch.Tensor | None = None encoder_hidden_states: torch.Tensor | None = None encoder_attention_mask: torch.Tensor | None = None @@ -256,6 +282,12 @@ class TrainingBatch: audio_encoder_hidden_states: torch.Tensor | None = None audio_encoder_attention_mask: torch.Tensor | None = None conditioning_mask: torch.Tensor | None = None + mouse_cond: torch.Tensor | None = None + keyboard_cond: torch.Tensor | None = None + + def __getattr__(self, name: str) -> Any: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + # i2v preprocessed_image: torch.Tensor | None = None image_embeds: torch.Tensor | None = None @@ -277,20 +309,20 @@ class TrainingBatch: attn_metadata: AttentionMetadata | None = None # input kwargs - input_kwargs: dict[str, Any] | None = None + input_kwargs: Any = None # Training loss loss: torch.Tensor | None = None # Training outputs - total_loss: float | None = None - grad_norm: float | None = None + total_loss: Any = None + grad_norm: Any = None # Distillation-specific attributes encoder_hidden_states_neg: torch.Tensor | None = None encoder_attention_mask_neg: torch.Tensor | None = None - conditional_dict: dict[str, Any] | None = None - unconditional_dict: dict[str, Any] | None = None + conditional_dict: Any = None + unconditional_dict: Any = None # Distillation losses generator_loss: float = 0.0 diff --git a/fastvideo/pipelines/preprocess/v1_preprocess.py b/fastvideo/pipelines/preprocess/v1_preprocess.py index d71354f5b0..a01cdcb6e3 100644 --- a/fastvideo/pipelines/preprocess/v1_preprocess.py +++ b/fastvideo/pipelines/preprocess/v1_preprocess.py @@ -7,6 +7,7 @@ from fastvideo.distributed import (get_world_size, maybe_init_distributed_environment_and_model_parallel) from fastvideo.fastvideo_args import FastVideoArgs from fastvideo.logger import init_logger +from fastvideo.pipelines.preprocess.preprocess_pipeline_base import BasePreprocessPipeline from fastvideo.pipelines.preprocess.preprocess_pipeline_i2v import (PreprocessPipeline_I2V) from fastvideo.pipelines.preprocess.preprocess_pipeline_ode_trajectory import (PreprocessPipeline_ODE_Trajectory) from fastvideo.pipelines.preprocess.preprocess_pipeline_t2v import (PreprocessPipeline_T2V) @@ -46,6 +47,7 @@ def main(args) -> None: text_encoder_cpu_offload=False, pipeline_config=pipeline_config, ) + PreprocessPipeline: type[BasePreprocessPipeline] if args.preprocess_task == "t2v": PreprocessPipeline = PreprocessPipeline_T2V elif args.preprocess_task == "i2v": diff --git a/fastvideo/pipelines/preprocess/v1_preprocessing_new.py b/fastvideo/pipelines/preprocess/v1_preprocessing_new.py index 017af0c74e..2f3a22889f 100644 --- a/fastvideo/pipelines/preprocess/v1_preprocessing_new.py +++ b/fastvideo/pipelines/preprocess/v1_preprocessing_new.py @@ -10,6 +10,8 @@ def main(fastvideo_args: FastVideoArgs) -> None: maybe_init_distributed_environment_and_model_parallel(1, 1) preprocess_workflow_cls = WorkflowBase.get_workflow_cls(fastvideo_args) + if preprocess_workflow_cls is None: + raise ValueError(f"No workflow found for mode {fastvideo_args.mode}") preprocess_workflow = preprocess_workflow_cls(fastvideo_args) preprocess_workflow.run() diff --git a/fastvideo/pipelines/stages/base.py b/fastvideo/pipelines/stages/base.py index 06b6bee273..9043ce024b 100644 --- a/fastvideo/pipelines/stages/base.py +++ b/fastvideo/pipelines/stages/base.py @@ -35,6 +35,8 @@ class PipelineStage(ABC): for a specific part of the process, such as prompt encoding, latent preparation, etc. """ + _streaming_ctx: object | None = None + def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult: """ Verify the input for the stage. @@ -110,6 +112,18 @@ def set_logging(self, enable: bool): """ self._enable_logging = enable + def streaming_reset(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> ForwardBatch: + raise NotImplementedError(f"{type(self).__name__} does not support streaming_reset") + + def streaming_step(self, *args: object, **kwargs: object) -> ForwardBatch: + raise NotImplementedError(f"{type(self).__name__} does not support streaming_step") + + def streaming_clear(self) -> None: + raise NotImplementedError(f"{type(self).__name__} does not support streaming_clear") + + def streaming_decode(self, *args: object, **kwargs: object) -> object: + raise NotImplementedError(f"{type(self).__name__} does not support streaming_decode") + def __call__( self, batch: ForwardBatch, diff --git a/fastvideo/pipelines/stages/decoding.py b/fastvideo/pipelines/stages/decoding.py index cc88fbd7e2..83c1c87066 100644 --- a/fastvideo/pipelines/stages/decoding.py +++ b/fastvideo/pipelines/stages/decoding.py @@ -222,7 +222,7 @@ def forward( if hasattr(batch, 'num_cond_frames_added') and hasattr(batch, 'new_frame_size_before_padding'): num_cond_frames_added = batch.num_cond_frames_added new_frame_size = batch.new_frame_size_before_padding - if num_cond_frames_added > 0 or frames.shape[2] != new_frame_size: + if new_frame_size is not None and (num_cond_frames_added > 0 or frames.shape[2] != new_frame_size): # frames is [B, C, T, H, W], crop temporal dimension frames = frames[:, :, num_cond_frames_added:num_cond_frames_added + new_frame_size, :, :] logger.info("Cropped LongCat refinement padding: %s:%s, final shape: %s", num_cond_frames_added, diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index 1efeec7348..30eaaa808c 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -61,9 +61,10 @@ def __init__(self, transformer, scheduler, pipeline=None, transformer_2=None, va self.attn_backend = get_attn_backend( head_size=attn_head_size, dtype=torch.float16, # TODO(will): hack - supported_attention_backends=(AttentionBackendEnum.VIDEO_SPARSE_ATTN, AttentionBackendEnum.BSA_ATTN, - AttentionBackendEnum.VMOBA_ATTN, AttentionBackendEnum.FLASH_ATTN, - AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.SAGE_ATTN_THREE) # hack + supported_attention_backends=( + AttentionBackendEnum.VIDEO_SPARSE_ATTN, AttentionBackendEnum.BSA_ATTN, AttentionBackendEnum.VMOBA_ATTN, + AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, AttentionBackendEnum.SAGE_ATTN_THREE, + AttentionBackendEnum.ATTN_QAT_INFER, AttentionBackendEnum.ATTN_QAT_TRAIN) # hack ) def forward( @@ -205,6 +206,8 @@ def forward( temporal_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_temporal spatial_scale = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size + if not isinstance(patch_size, tuple): + raise ValueError(f"Expected 3D patch_size tuple for denoising, got {patch_size!r}") seq_len = ((F - 1) // temporal_scale + 1) * (batch.height // spatial_scale) * ( batch.width // spatial_scale) // (patch_size[1] * patch_size[2]) @@ -317,6 +320,7 @@ def forward( self.attn_metadata_builder = self.attn_metadata_builder_cls() # Prepare V-MoBA parameters from config moba_params = fastvideo_args.moba_config.copy() + assert batch.raw_latent_shape is not None, "raw_latent_shape must be set for V-MoBA" moba_params.update({ "current_timestep": i, "raw_latent_shape": batch.raw_latent_shape[2:5], @@ -1139,9 +1143,9 @@ def forward( if i < len(timesteps) - 1: next_timestep = timesteps[i + 1] * torch.ones([1], dtype=torch.long, device=pred_video.device) - noise = torch.randn(video_raw_latent_shape, - dtype=pred_video.dtype, - generator=batch.generator[0]).to(self.device) + noise_generator = batch.generator[0] if isinstance(batch.generator, list) else batch.generator + noise = torch.randn(video_raw_latent_shape, dtype=pred_video.dtype, + generator=noise_generator).to(self.device) latents = self.scheduler.add_noise(pred_video.flatten(0, 1), noise.flatten(0, 1), next_timestep).unflatten(0, pred_video.shape[:2]) else: diff --git a/fastvideo/pipelines/stages/gamecraft_denoising.py b/fastvideo/pipelines/stages/gamecraft_denoising.py index 3bac8d95a6..e3ce442f71 100644 --- a/fastvideo/pipelines/stages/gamecraft_denoising.py +++ b/fastvideo/pipelines/stages/gamecraft_denoising.py @@ -133,6 +133,7 @@ def forward( # Get latents and embeddings latents = batch.latents + assert latents is not None, "latents must be initialized before GameCraft denoising" prompt_embeds = batch.prompt_embeds assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan" diff --git a/fastvideo/pipelines/stages/gen3c_stages.py b/fastvideo/pipelines/stages/gen3c_stages.py index e120a5cc09..f374a4c3f0 100644 --- a/fastvideo/pipelines/stages/gen3c_stages.py +++ b/fastvideo/pipelines/stages/gen3c_stages.py @@ -4,7 +4,7 @@ separate from pipeline orchestration. """ -from typing import Any +from typing import Any, cast import torch from diffusers.utils.torch_utils import randn_tensor @@ -109,14 +109,18 @@ def forward( height = getattr(batch, 'height', None) or getattr(pipeline_config, 'video_resolution', (720, 1280))[0] width = getattr(batch, 'width', None) or getattr(pipeline_config, 'video_resolution', (720, 1280))[1] - num_frames = getattr(batch, 'num_frames', None) or getattr(pipeline_config, 'num_frames', 121) - - trajectory_type = (getattr(batch, 'trajectory_type', None) or batch_extra.get("trajectory_type") - or getattr(pipeline_config, 'default_trajectory_type', 'left')) - movement_distance = (getattr(batch, 'movement_distance', None) or batch_extra.get("movement_distance") - or getattr(pipeline_config, 'default_movement_distance', 0.3)) - camera_rotation = (getattr(batch, 'camera_rotation', None) or batch_extra.get("camera_rotation") - or getattr(pipeline_config, 'default_camera_rotation', 'center_facing')) + num_frames_value = getattr(batch, 'num_frames', None) or getattr(pipeline_config, 'num_frames', 121) + num_frames = int(cast(int, num_frames_value)) + + trajectory_type = str( + getattr(batch, 'trajectory_type', None) or batch_extra.get("trajectory_type") + or getattr(pipeline_config, 'default_trajectory_type', 'left')) + movement_distance_value = (getattr(batch, 'movement_distance', None) or batch_extra.get("movement_distance") + or getattr(pipeline_config, 'default_movement_distance', 0.3)) + movement_distance = float(cast(float, movement_distance_value)) + camera_rotation = str( + getattr(batch, 'camera_rotation', None) or batch_extra.get("camera_rotation") + or getattr(pipeline_config, 'default_camera_rotation', 'center_facing')) frame_buffer_max = getattr(pipeline_config, 'frame_buffer_max', 2) noise_aug_strength = getattr(pipeline_config, 'noise_aug_strength', 0.0) @@ -488,6 +492,7 @@ class Gen3CDenoisingStage(DenoisingStage): def __init__(self, transformer, scheduler, pipeline=None) -> None: super().__init__(transformer, scheduler, pipeline) + self.callback_on_step_end_tensor_inputs: tuple[str, ...] = () def _has_edm_preconditioning(self) -> bool: return hasattr(self.scheduler, "precondition_inputs") @@ -582,6 +587,7 @@ def forward( autocast_enabled = (target_dtype != torch.float32) and not fastvideo_args.disable_autocast latents = batch.latents + assert latents is not None, "latents must be initialized before GEN3C denoising" num_inference_steps = batch.num_inference_steps guidance_scale = batch.guidance_scale fps = getattr(fastvideo_args.pipeline_config, 'fps', 24) @@ -622,6 +628,7 @@ def forward( )) timestep = t.flatten().expand(latents.size(0)) + assert batch.batch_size is not None and batch.height is not None and batch.width is not None padding_mask = torch.zeros( batch.batch_size, 1, diff --git a/fastvideo/pipelines/stages/hyworld_denoising.py b/fastvideo/pipelines/stages/hyworld_denoising.py index a0be6b93a3..c62afc3758 100644 --- a/fastvideo/pipelines/stages/hyworld_denoising.py +++ b/fastvideo/pipelines/stages/hyworld_denoising.py @@ -143,6 +143,7 @@ def forward( # Get latents and embeddings latents = batch.latents + assert latents is not None, "latents must be initialized before HYWorld denoising" prompt_embeds = batch.prompt_embeds assert not torch.isnan(prompt_embeds[0]).any(), "prompt_embeds contains nan" if batch.do_classifier_free_guidance: @@ -163,6 +164,8 @@ def forward( # Use conditional latents directly (prepared by HYWorldImageEncodingStage) # batch.image_latent is already [1, 33, T, H, W] with first frame encoded, rest zeros cond_latents = batch.image_latent + assert cond_latents is not None, "image_latent must be initialized before HYWorld denoising" + assert viewmats is not None and Ks is not None and action is not None # Calculate chunk configuration latent_frames = latents.shape[2] diff --git a/fastvideo/pipelines/stages/image_encoding.py b/fastvideo/pipelines/stages/image_encoding.py index 0f7e94e0b3..c7e0fa09f4 100644 --- a/fastvideo/pipelines/stages/image_encoding.py +++ b/fastvideo/pipelines/stages/image_encoding.py @@ -9,6 +9,8 @@ - VideoVAEEncodingStage: Encodes videos to latent space using VAE for V2V and control tasks """ +from typing import Any + import PIL import torch @@ -110,6 +112,7 @@ def forward(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> Forward if batch.pil_image is None: batch.image_embeds = [torch.zeros(1, 729, 1152, device=get_local_torch_device())] + assert batch.raw_latent_shape is not None, "raw_latent_shape must be initialized before image encoding" raw_latent_shape = list(batch.raw_latent_shape) raw_latent_shape[1] = 1 batch.video_latent = torch.zeros(tuple(raw_latent_shape), device=get_local_torch_device()) @@ -124,7 +127,7 @@ class HYWorldImageEncodingStage(ImageEncodingStage): Also encodes reference image with VAE for conditional latent. """ - def __init__(self, image_encoder=None, image_processor=None, vae=None): + def __init__(self, image_encoder: Any = None, image_processor: Any = None, vae: Any = None) -> None: super().__init__(image_encoder=image_encoder, image_processor=image_processor) self.vae = vae @@ -155,6 +158,7 @@ def forward(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> Forward vision_dim = 1152 # SigLIP hidden size # Get temporal dimension from raw_latent_shape (set by LatentPreparationStage) + assert batch.raw_latent_shape is not None, "raw_latent_shape must be initialized before HYWorld image encoding" raw_latent_shape = list(batch.raw_latent_shape) latent_channels = raw_latent_shape[1] latent_temporal = raw_latent_shape[2] # T dimension diff --git a/fastvideo/pipelines/stages/input_validation.py b/fastvideo/pipelines/stages/input_validation.py index 0490954e3a..6705e241c9 100644 --- a/fastvideo/pipelines/stages/input_validation.py +++ b/fastvideo/pipelines/stages/input_validation.py @@ -105,6 +105,8 @@ def forward( else: # Standard Wan logic patch_size = fastvideo_args.pipeline_config.dit_config.arch_config.patch_size + if not isinstance(patch_size, tuple): + raise ValueError(f"Expected 3D patch_size tuple for ti2v preprocessing, got {patch_size!r}") vae_stride = fastvideo_args.pipeline_config.vae_config.arch_config.scale_factor_spatial dh, dw = patch_size[1] * vae_stride, patch_size[2] * vae_stride max_area = 480 * 832 diff --git a/fastvideo/pipelines/stages/latent_preparation.py b/fastvideo/pipelines/stages/latent_preparation.py index 5f60409ded..d31ce55da9 100644 --- a/fastvideo/pipelines/stages/latent_preparation.py +++ b/fastvideo/pipelines/stages/latent_preparation.py @@ -4,7 +4,7 @@ """ import os -from typing import Any +from typing import Any, cast import numpy as np import torch @@ -652,10 +652,12 @@ def forward( return batch def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult: - return Cosmos25LatentPreparationStage.verify_input(self, batch, fastvideo_args) # type: ignore[misc] + typed_self = cast(Cosmos25LatentPreparationStage, self) + return Cosmos25LatentPreparationStage.verify_input(typed_self, batch, fastvideo_args) def verify_output(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult: - return Cosmos25LatentPreparationStage.verify_output(self, batch, fastvideo_args) # type: ignore[misc] + typed_self = cast(Cosmos25LatentPreparationStage, self) + return Cosmos25LatentPreparationStage.verify_output(typed_self, batch, fastvideo_args) class Cosmos25V2WLatentPreparationStage(Cosmos25LatentPreparationStage): diff --git a/fastvideo/pipelines/stages/longcat_denoising.py b/fastvideo/pipelines/stages/longcat_denoising.py index 758e938a80..1f78a524ff 100644 --- a/fastvideo/pipelines/stages/longcat_denoising.py +++ b/fastvideo/pipelines/stages/longcat_denoising.py @@ -82,6 +82,8 @@ def forward( # Extract batch parameters latents = batch.latents timesteps = batch.timesteps + assert latents is not None, "latents must be initialized before LongCat denoising" + assert timesteps is not None, "timesteps must be initialized before LongCat denoising" prompt_embeds = batch.prompt_embeds[0] # LongCat uses single encoder prompt_attention_mask = batch.prompt_attention_mask[0] if batch.prompt_attention_mask else None guidance_scale = batch.guidance_scale @@ -89,6 +91,7 @@ def forward( # Get negative prompts if doing CFG if do_classifier_free_guidance: + assert batch.negative_prompt_embeds is not None, "negative_prompt_embeds required for CFG" negative_prompt_embeds = batch.negative_prompt_embeds[0] negative_prompt_attention_mask = (batch.negative_attention_mask[0] if batch.negative_attention_mask else None) diff --git a/fastvideo/pipelines/stages/longcat_i2v_denoising.py b/fastvideo/pipelines/stages/longcat_i2v_denoising.py index 27292be180..fbf4f694ff 100644 --- a/fastvideo/pipelines/stages/longcat_i2v_denoising.py +++ b/fastvideo/pipelines/stages/longcat_i2v_denoising.py @@ -51,6 +51,8 @@ def forward( latents = batch.latents timesteps = batch.timesteps + assert latents is not None, "latents must be initialized before I2V denoising" + assert timesteps is not None, "timesteps must be initialized before I2V denoising" prompt_embeds = batch.prompt_embeds[0] prompt_attention_mask = (batch.prompt_attention_mask[0] if batch.prompt_attention_mask else None) guidance_scale = batch.guidance_scale @@ -64,6 +66,7 @@ def forward( # Prepare negative prompts for CFG if do_classifier_free_guidance: + assert batch.negative_prompt_embeds is not None, "negative_prompt_embeds required for CFG" negative_prompt_embeds = batch.negative_prompt_embeds[0] negative_prompt_attention_mask = (batch.negative_attention_mask[0] if batch.negative_attention_mask else None) diff --git a/fastvideo/pipelines/stages/longcat_image_vae_encoding.py b/fastvideo/pipelines/stages/longcat_image_vae_encoding.py index 4642f6feae..67ccf36497 100644 --- a/fastvideo/pipelines/stages/longcat_image_vae_encoding.py +++ b/fastvideo/pipelines/stages/longcat_image_vae_encoding.py @@ -6,6 +6,8 @@ LongCat-specific normalization for I2V conditioning. """ +from typing import Any + import PIL import torch @@ -31,7 +33,7 @@ class LongCatImageVAEEncodingStage(PipelineStage): 4. Stores latent and calculates num_cond_latents """ - def __init__(self, vae): + def __init__(self, vae: Any) -> None: super().__init__() self.vae = vae @@ -115,7 +117,7 @@ def forward( return batch - def retrieve_latents(self, encoder_output: object, generator: torch.Generator | None) -> torch.Tensor: + def retrieve_latents(self, encoder_output: Any, generator: torch.Generator | None) -> torch.Tensor: """Sample from VAE posterior.""" # WAN VAE returns an object with .sample() method if hasattr(encoder_output, 'sample'): diff --git a/fastvideo/pipelines/stages/longcat_kv_cache_init.py b/fastvideo/pipelines/stages/longcat_kv_cache_init.py index c85d0ce2b8..dabcf75cda 100644 --- a/fastvideo/pipelines/stages/longcat_kv_cache_init.py +++ b/fastvideo/pipelines/stages/longcat_kv_cache_init.py @@ -6,6 +6,8 @@ """ +from typing import Any + import torch from fastvideo.fastvideo_args import FastVideoArgs @@ -27,7 +29,7 @@ class LongCatKVCacheInitStage(PipelineStage): - batch.latents contains ONLY noise latents """ - def __init__(self, transformer): + def __init__(self, transformer: Any) -> None: super().__init__() self.transformer = transformer @@ -57,6 +59,7 @@ def forward( return batch # Extract conditioning latents + assert batch.latents is not None, "latents must be initialized before KV-cache setup" cond_latents = batch.latents[:, :, :num_cond_latents].clone() logger.info("Initializing KV cache for %d conditioning latents, shape: %s", num_cond_latents, diff --git a/fastvideo/pipelines/stages/longcat_refine_init.py b/fastvideo/pipelines/stages/longcat_refine_init.py index 25cac2b05a..1d07bc12fb 100644 --- a/fastvideo/pipelines/stages/longcat_refine_init.py +++ b/fastvideo/pipelines/stages/longcat_refine_init.py @@ -83,6 +83,8 @@ def forward( else: # Path-based refine: load video from disk (original design) logger.info("Initializing LongCat refinement from file: %s", refine_from) + if refine_from is None: + raise ValueError("refine_from must be provided when stage1_video is not set") stage1_video_path = Path(refine_from) if not stage1_video_path.exists(): raise FileNotFoundError(f"Stage1 video not found: {refine_from}") diff --git a/fastvideo/pipelines/stages/longcat_vc_denoising.py b/fastvideo/pipelines/stages/longcat_vc_denoising.py index ee6b4c4a29..d1547208c8 100644 --- a/fastvideo/pipelines/stages/longcat_vc_denoising.py +++ b/fastvideo/pipelines/stages/longcat_vc_denoising.py @@ -60,6 +60,8 @@ def forward( latents = batch.latents timesteps = batch.timesteps + assert latents is not None, "latents must be initialized before VC denoising" + assert timesteps is not None, "timesteps must be initialized before VC denoising" prompt_embeds = batch.prompt_embeds[0] prompt_attention_mask = (batch.prompt_attention_mask[0] if batch.prompt_attention_mask else None) guidance_scale = batch.guidance_scale @@ -75,6 +77,7 @@ def forward( # Prepare negative prompts for CFG if do_classifier_free_guidance: + assert batch.negative_prompt_embeds is not None, "negative_prompt_embeds required for CFG" negative_prompt_embeds = batch.negative_prompt_embeds[0] negative_prompt_attention_mask = (batch.negative_attention_mask[0] if batch.negative_attention_mask else None) diff --git a/fastvideo/pipelines/stages/longcat_video_vae_encoding.py b/fastvideo/pipelines/stages/longcat_video_vae_encoding.py index 8905dc0821..14062b2d03 100644 --- a/fastvideo/pipelines/stages/longcat_video_vae_encoding.py +++ b/fastvideo/pipelines/stages/longcat_video_vae_encoding.py @@ -35,7 +35,7 @@ class LongCatVideoVAEEncodingStage(PipelineStage): 6. Calculates num_cond_latents """ - def __init__(self, vae): + def __init__(self, vae: Any) -> None: super().__init__() self.vae = vae diff --git a/fastvideo/pipelines/stages/ltx2_text_encoding.py b/fastvideo/pipelines/stages/ltx2_text_encoding.py index c409e29da3..20a6c4573a 100644 --- a/fastvideo/pipelines/stages/ltx2_text_encoding.py +++ b/fastvideo/pipelines/stages/ltx2_text_encoding.py @@ -59,10 +59,13 @@ def forward( else: logger.info("[LTX2TextEncodingStage] SP rank %d: receiving broadcast", sp_rank) # Other ranks: receive broadcast and populate batch - broadcast_dict = sp_group.broadcast_tensor_dict(None, src=0) + broadcast_dict_maybe = sp_group.broadcast_tensor_dict(None, src=0) + if broadcast_dict_maybe is None: + raise RuntimeError("Sequence-parallel text broadcast returned no data on non-zero rank") + received_broadcast_dict: dict[str, torch.Tensor] = broadcast_dict_maybe # Unpack into batch - self._unpack_broadcast_to_batch(batch, broadcast_dict) + self._unpack_broadcast_to_batch(batch, received_broadcast_dict) logger.info("[LTX2TextEncodingStage] SP rank %d: received %d prompt embeds", sp_rank, len(batch.prompt_embeds)) @@ -86,6 +89,7 @@ def _build_broadcast_dict(self, batch: ForwardBatch) -> dict[str, torch.Tensor]: has_prompt_masks = (batch.prompt_attention_mask is not None and len(batch.prompt_attention_mask) > 0) d["_has_prompt_masks"] = torch.tensor([1 if has_prompt_masks else 0], device=device) if has_prompt_masks: + assert batch.prompt_attention_mask is not None for i, pm in enumerate(batch.prompt_attention_mask): d[f"prompt_mask_{i}"] = pm @@ -93,6 +97,7 @@ def _build_broadcast_dict(self, batch: ForwardBatch) -> dict[str, torch.Tensor]: has_neg_embeds = (batch.negative_prompt_embeds is not None and len(batch.negative_prompt_embeds) > 0) d["_has_neg_embeds"] = torch.tensor([1 if has_neg_embeds else 0], device=device) if has_neg_embeds: + assert batch.negative_prompt_embeds is not None d["_num_neg_embeds"] = torch.tensor([len(batch.negative_prompt_embeds)], device=device) for i, ne in enumerate(batch.negative_prompt_embeds): d[f"neg_embed_{i}"] = ne @@ -101,6 +106,7 @@ def _build_broadcast_dict(self, batch: ForwardBatch) -> dict[str, torch.Tensor]: has_neg_masks = (batch.negative_attention_mask is not None and len(batch.negative_attention_mask) > 0) d["_has_neg_masks"] = torch.tensor([1 if has_neg_masks else 0], device=device) if has_neg_masks: + assert batch.negative_attention_mask is not None for i, nm in enumerate(batch.negative_attention_mask): d[f"neg_mask_{i}"] = nm diff --git a/fastvideo/pipelines/stages/matrixgame_denoising.py b/fastvideo/pipelines/stages/matrixgame_denoising.py index 7e1496857b..24790f7de2 100644 --- a/fastvideo/pipelines/stages/matrixgame_denoising.py +++ b/fastvideo/pipelines/stages/matrixgame_denoising.py @@ -669,9 +669,11 @@ def streaming_step(self, start_frame = 0 if start_index == 0 else 1 + vae_ratio * (start_index - 1) if keyboard_action is not None: + assert batch.keyboard_cond is not None, "keyboard_cond must be initialized for streaming updates" n = keyboard_action.shape[1] batch.keyboard_cond[:, start_frame:start_frame + n] = keyboard_action.to(batch.keyboard_cond.device) if mouse_action is not None: + assert batch.mouse_cond is not None, "mouse_cond must be initialized for streaming updates" n = mouse_action.shape[1] batch.mouse_cond[:, start_frame:start_frame + n] = mouse_action.to(batch.mouse_cond.device) diff --git a/fastvideo/pipelines/stages/sr_denoising.py b/fastvideo/pipelines/stages/sr_denoising.py index 0102770d7b..2301fa946c 100644 --- a/fastvideo/pipelines/stages/sr_denoising.py +++ b/fastvideo/pipelines/stages/sr_denoising.py @@ -63,7 +63,8 @@ def __init__(self, transformer, scheduler, upsampler, pipeline=None, vae=None) - dtype=torch.float16, # TODO(will): hack supported_attention_backends=(AttentionBackendEnum.VIDEO_SPARSE_ATTN, AttentionBackendEnum.VMOBA_ATTN, AttentionBackendEnum.FLASH_ATTN, AttentionBackendEnum.TORCH_SDPA, - AttentionBackendEnum.SAGE_ATTN_THREE) # hack + AttentionBackendEnum.SAGE_ATTN_THREE, AttentionBackendEnum.ATTN_QAT_INFER, + AttentionBackendEnum.ATTN_QAT_TRAIN) # hack ) def add_noise_to_lq(self, lq_latents: torch.Tensor, strength: float = 0.7) -> torch.Tensor: @@ -226,6 +227,7 @@ def forward( self.attn_metadata_builder = self.attn_metadata_builder_cls() # Prepare V-MoBA parameters from config moba_params = fastvideo_args.moba_config.copy() + assert batch.raw_latent_shape is not None, "raw_latent_shape must be set for V-MoBA SR" moba_params.update({ "current_timestep": i, "raw_latent_shape": batch.raw_latent_shape[2:5], diff --git a/fastvideo/pipelines/stages/text_encoding.py b/fastvideo/pipelines/stages/text_encoding.py index b62d98bbcd..37337af0cb 100644 --- a/fastvideo/pipelines/stages/text_encoding.py +++ b/fastvideo/pipelines/stages/text_encoding.py @@ -85,9 +85,6 @@ def forward( encoder_index=all_indices, return_attention_mask=True, ) - if self._last_audio_embeds is not None: - batch.extra["ltx2_audio_negative_embeds"] = self._last_audio_embeds - assert batch.negative_prompt_embeds is not None for ne in neg_embeds_list: batch.negative_prompt_embeds.append(ne) @@ -186,7 +183,6 @@ def encode_text( postprocess_funcs = fastvideo_args.pipeline_config.postprocess_text_funcs encoder_cfgs = fastvideo_args.pipeline_config.text_encoder_configs is_ltx2 = getattr(fastvideo_args.pipeline_config.dit_config, "prefix", "") == "ltx2" - if return_type not in ("list", "dict", "stack"): raise ValueError(f"Invalid return_type '{return_type}'. Expected one of: 'list', 'dict', 'stack'") diff --git a/fastvideo/pipelines/stages/validators.py b/fastvideo/pipelines/stages/validators.py index 89300a0b84..c44d6803cd 100644 --- a/fastvideo/pipelines/stages/validators.py +++ b/fastvideo/pipelines/stages/validators.py @@ -376,6 +376,13 @@ def add_check(self, field_name: str, value: Any, return self + def add_failure(self, field_name: str, failure: ValidationFailure | str) -> "VerificationResult": + self._checks[field_name] = False + if isinstance(failure, str): + failure = ValidationFailure(validator_name="manual_failure", actual_value=None, error_msg=failure) + self._failures.setdefault(field_name, []).append(failure) + return self + def _create_validation_failure(self, validator: Callable, value: Any) -> ValidationFailure: """Create a ValidationFailure with detailed information.""" validator_name = getattr(validator, '__name__', str(validator)) diff --git a/fastvideo/platforms/cuda.py b/fastvideo/platforms/cuda.py index aa6c10ec21..0232de36fb 100644 --- a/fastvideo/platforms/cuda.py +++ b/fastvideo/platforms/cuda.py @@ -102,7 +102,7 @@ def get_current_memory_usage(cls, device: torch.types.Device | None = None) -> f return float(torch.cuda.max_memory_allocated(device)) @classmethod - def get_torch_device(cls): + def get_torch_device(cls) -> object: """ Return torch.cuda """ @@ -140,6 +140,31 @@ def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, hea except ImportError as e: logger.info(e) logger.info("Sage Attention 3 backend is not installed. Fall back to Flash Attention.") + elif selected_backend == AttentionBackendEnum.ATTN_QAT_INFER: + try: + from fastvideo.attention.backends.attn_qat_infer import ( # noqa: F401 + AttnQatInferBackend, is_attn_qat_infer_available, + ) + if not is_attn_qat_infer_available(): + raise ImportError("attn_qat_infer could not be imported.") + logger.info("Using attn_qat_infer backend.") + + return "fastvideo.attention.backends.attn_qat_infer.AttnQatInferBackend" + except ImportError as e: + logger.info(e) + logger.info("attn_qat_infer backend is not installed. Fall back to Flash Attention.") + elif selected_backend == AttentionBackendEnum.ATTN_QAT_TRAIN: + try: + from fastvideo_kernel.triton_kernels.attn_qat_train import attention # noqa: F401 + + from fastvideo.attention.backends.attn_qat_train import ( # noqa: F401 + AttnQatTrainBackend) + logger.info("Using attn_qat_train backend.") + + return "fastvideo.attention.backends.attn_qat_train.AttnQatTrainBackend" + except ImportError as e: + logger.info(e) + logger.info("attn_qat_train backend is not installed. Fall back to Flash Attention.") elif selected_backend == AttentionBackendEnum.VIDEO_SPARSE_ATTN: try: from fastvideo_kernel import video_sparse_attn # noqa: F401 diff --git a/fastvideo/platforms/interface.py b/fastvideo/platforms/interface.py index 60e2441f62..b4b433b7da 100644 --- a/fastvideo/platforms/interface.py +++ b/fastvideo/platforms/interface.py @@ -1,6 +1,6 @@ import enum import random -from typing import NamedTuple +from typing import Any, NamedTuple import numpy as np import torch @@ -15,6 +15,8 @@ class AttentionBackendEnum(enum.Enum): TORCH_SDPA = enum.auto() SAGE_ATTN = enum.auto() SAGE_ATTN_THREE = enum.auto() + ATTN_QAT_INFER = enum.auto() + ATTN_QAT_TRAIN = enum.auto() VIDEO_SPARSE_ATTN = enum.auto() BSA_ATTN = enum.auto() VMOBA_ATTN = enum.auto() @@ -174,7 +176,7 @@ def is_async_output_supported(cls, enforce_eager: bool | None) -> bool: raise NotImplementedError @classmethod - def get_torch_device(cls): + def get_torch_device(cls) -> Any: """ Check if the current platform supports torch device. """ diff --git a/fastvideo/platforms/npu.py b/fastvideo/platforms/npu.py index 23e162dd17..b72b200f63 100644 --- a/fastvideo/platforms/npu.py +++ b/fastvideo/platforms/npu.py @@ -87,7 +87,7 @@ def is_pin_memory_available(cls): return True @classmethod - def get_torch_device(cls): + def get_torch_device(cls) -> object: """ Return torch.npu """ diff --git a/fastvideo/platforms/rocm.py b/fastvideo/platforms/rocm.py index 4611046219..a179885e41 100644 --- a/fastvideo/platforms/rocm.py +++ b/fastvideo/platforms/rocm.py @@ -54,7 +54,7 @@ def get_current_memory_usage(cls, device: torch.device | None = None) -> float: return float(torch.cuda.max_memory_allocated(device)) @classmethod - def get_torch_device(cls): + def get_torch_device(cls) -> object: """ Return torch.cuda """ @@ -72,7 +72,7 @@ def get_attn_backend_cls(cls, selected_backend: AttentionBackendEnum | None, hea elif selected_backend in (AttentionBackendEnum.FLASH_ATTN, None): pass - elif selected_backend in (AttentionBackendEnum.SAGE_ATTN): + elif selected_backend in (AttentionBackendEnum.SAGE_ATTN, ): raise ValueError(f"{selected_backend.name} is not supported on {cls.device_name}.") elif selected_backend: raise ValueError(f"Invalid attention backend for {cls.device_name}: {selected_backend}") diff --git a/fastvideo/registry.py b/fastvideo/registry.py index d37a4d639c..20ab445099 100644 --- a/fastvideo/registry.py +++ b/fastvideo/registry.py @@ -97,7 +97,7 @@ # --- Part 1: Pipeline Discovery --- -_PIPELINE_REGISTRY: dict[str, dict[str, type[ComposedPipelineBase]]] = {} +_PIPELINE_REGISTRY: dict[str, dict[str, type[ComposedPipelineBase] | None]] = {} # Registry for pipeline configuration classes (for single-file weights without # model_index.json). Maps pipeline_class_name -> (PipelineConfig, SamplingParam) @@ -250,9 +250,6 @@ def _register_configs() -> None: "FastVideo/LTX2-base", "FastVideo/LTX2-Diffusers", ], - model_detectors=[ - lambda path: ("ltx2" in path.lower() or "ltx-2" in path.lower()) and "distilled" not in path.lower(), - ], ) # LTX-2 (distilled) register_configs( @@ -262,9 +259,6 @@ def _register_configs() -> None: hf_model_paths=[ "FastVideo/LTX2-Distilled-Diffusers", ], - model_detectors=[ - lambda path: ("ltx2" in path.lower() or "ltx-2" in path.lower()) and "distilled" in path.lower(), - ], ) # Hunyuan 1.5 (specific) diff --git a/fastvideo/tests/api/test_schema_parity_inventory.py b/fastvideo/tests/api/test_schema_parity_inventory.py index a0c1610881..a30061df07 100644 --- a/fastvideo/tests/api/test_schema_parity_inventory.py +++ b/fastvideo/tests/api/test_schema_parity_inventory.py @@ -235,8 +235,8 @@ def test_cli_dest_inventory_matches_live_parsers() -> None: def test_review_gap_fields_are_explicitly_inventory_tracked() -> None: inventory = _load_inventory() - sampling_extensions = inventory["surfaces"]["sampling_param_extensions"] - assert "guidance_scale_2" in sampling_extensions["moved"] + sampling_base = inventory["surfaces"]["sampling_param_base"] + assert "guidance_scale_2" in sampling_base["moved"] image_request = inventory["surfaces"]["openai_image_request"] video_request = inventory["surfaces"]["openai_video_request"] diff --git a/fastvideo/tests/attention/test_flash_attn_cute_custom_op.py b/fastvideo/tests/attention/test_flash_attn_cute_custom_op.py index a3f7bb476a..7c266f2c13 100644 --- a/fastvideo/tests/attention/test_flash_attn_cute_custom_op.py +++ b/fastvideo/tests/attention/test_flash_attn_cute_custom_op.py @@ -220,4 +220,4 @@ def test_flash_attn_varlen_func_parity_forward_backward( _assert_close(dq_test, dq_ref, dtype=dtype, is_grad=True) _assert_close(dk_test, dk_ref, dtype=dtype, is_grad=True) - _assert_close(dv_test, dv_ref, dtype=dtype, is_grad=True) + _assert_close(dv_test, dv_ref, dtype=dtype, is_grad=True) \ No newline at end of file diff --git a/fastvideo/tests/entrypoints/test_video_generator.py b/fastvideo/tests/entrypoints/test_video_generator.py index c0f79580b1..f1f1de9d07 100644 --- a/fastvideo/tests/entrypoints/test_video_generator.py +++ b/fastvideo/tests/entrypoints/test_video_generator.py @@ -9,6 +9,7 @@ GenerationResult, GeneratorConfig, InputConfig, + QuantizationConfig, SamplingConfig, ) from fastvideo.configs.sample import SamplingParam @@ -152,9 +153,10 @@ def test_prepare_output_path_empty_prompt_fallback(tmp_path): def test_from_config_normalizes_and_translates(monkeypatch): captured = _patch_from_fastvideo_args(monkeypatch) - _patch_fastvideo_args_from_kwargs(monkeypatch) + fastvideo_args_capture = _patch_fastvideo_args_from_kwargs(monkeypatch) config = GeneratorConfig(model_path="test-model") config.engine.num_gpus = 2 + config.engine.quantization = QuantizationConfig(transformer_quant="fp4") config.pipeline.workload_type = "t2v" generator = VideoGenerator.from_config(config) @@ -162,6 +164,7 @@ def test_from_config_normalizes_and_translates(monkeypatch): assert captured["fastvideo_args"].model_path == "test-model" assert captured["fastvideo_args"].num_gpus == 2 assert captured["fastvideo_args"].workload_type.value == "t2v" + assert fastvideo_args_capture["kwargs"]["transformer_quant"] == "fp4" assert generator.config == config @@ -195,6 +198,7 @@ def test_from_pretrained_convenience_kwargs_do_not_warn(monkeypatch): "test-model", num_gpus=4, use_fsdp_inference=False, + transformer_quant="fp4", text_encoder_cpu_offload=True, pin_cpu_memory=True, dit_cpu_offload=False, @@ -205,6 +209,7 @@ def test_from_pretrained_convenience_kwargs_do_not_warn(monkeypatch): assert captured["fastvideo_args"].model_path == "test-model" assert captured["fastvideo_args"].num_gpus == 4 assert fastvideo_args_capture["kwargs"]["use_fsdp_inference"] is False + assert fastvideo_args_capture["kwargs"]["transformer_quant"] == "fp4" assert fastvideo_args_capture["kwargs"]["text_encoder_cpu_offload"] is True assert fastvideo_args_capture["kwargs"]["pin_cpu_memory"] is True assert fastvideo_args_capture["kwargs"]["dit_cpu_offload"] is False @@ -212,6 +217,8 @@ def test_from_pretrained_convenience_kwargs_do_not_warn(monkeypatch): assert generator.config is not None assert generator.config.model_path == "test-model" assert generator.config.engine.num_gpus == 4 + assert generator.config.engine.quantization is not None + assert generator.config.engine.quantization.transformer_quant == "fp4" def test_from_pretrained_legacy_only_kwargs_warn(monkeypatch): diff --git a/fastvideo/train/callbacks/validation.py b/fastvideo/train/callbacks/validation.py index 41697336d5..ae6780cfb9 100644 --- a/fastvideo/train/callbacks/validation.py +++ b/fastvideo/train/callbacks/validation.py @@ -11,7 +11,7 @@ import contextlib import os from dataclasses import dataclass -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, cast import imageio import numpy as np @@ -29,6 +29,7 @@ ) from fastvideo.logger import init_logger from fastvideo.pipelines import ForwardBatch +from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase from fastvideo.train.callbacks.callback import Callback from fastvideo.train.utils.instantiate import resolve_target from fastvideo.train.utils.moduleloader import ( @@ -272,7 +273,7 @@ def _get_pipeline( return self._pipeline tc = self.training_config - PipelineCls = resolve_target(self.pipeline_target) + PipelineCls = cast(type[ComposedPipelineBase], resolve_target(self.pipeline_target)) flow_shift = getattr( tc.pipeline_config, "flow_shift", diff --git a/fastvideo/train/entrypoint/dcp_to_diffusers.py b/fastvideo/train/entrypoint/dcp_to_diffusers.py index f2ef982817..4db98d966d 100644 --- a/fastvideo/train/entrypoint/dcp_to_diffusers.py +++ b/fastvideo/train/entrypoint/dcp_to_diffusers.py @@ -228,6 +228,8 @@ def convert( checkpoint_dir, output_dir=checkpoint_dir, ) + if resolved is None: + raise FileNotFoundError(f"Could not resolve checkpoint directory from {checkpoint_dir!r}") dcp_dir = resolved / "dcp" if not dcp_dir.is_dir(): raise FileNotFoundError(f"Missing dcp/ under {resolved}") diff --git a/fastvideo/train/models/base.py b/fastvideo/train/models/base.py index c3635cd04d..ba974a2e63 100644 --- a/fastvideo/train/models/base.py +++ b/fastvideo/train/models/base.py @@ -29,6 +29,10 @@ class ModelBase(ABC): transformer: torch.nn.Module noise_scheduler: Any _trainable: bool + dataloader: Any + world_group: Any + sp_group: Any + vae: Any @property def device(self) -> torch.device: @@ -76,6 +80,7 @@ def prepare_batch( *, generator: torch.Generator, latents_source: Literal["data", "zeros"] = "data", + **kwargs: Any, ) -> TrainingBatch: """Convert a dataloader batch into forward primitives.""" diff --git a/fastvideo/train/models/wan/wan.py b/fastvideo/train/models/wan/wan.py index c485f30a1f..3d211c515a 100644 --- a/fastvideo/train/models/wan/wan.py +++ b/fastvideo/train/models/wan/wan.py @@ -46,14 +46,19 @@ from fastvideo.train.utils.training_config import ( TrainingConfig, ) +VideoSparseAttentionMetadataBuilder: type[Any] | None +VideoMobaAttentionMetadataBuilder: type[Any] | None + try: from fastvideo.attention.backends.video_sparse_attn import ( - VideoSparseAttentionMetadataBuilder, ) + VideoSparseAttentionMetadataBuilder as _VideoSparseAttentionMetadataBuilder, ) from fastvideo.attention.backends.vmoba import ( - VideoMobaAttentionMetadataBuilder, ) + VideoMobaAttentionMetadataBuilder as _VideoMobaAttentionMetadataBuilder, ) + VideoSparseAttentionMetadataBuilder = _VideoSparseAttentionMetadataBuilder + VideoMobaAttentionMetadataBuilder = _VideoMobaAttentionMetadataBuilder except Exception: - VideoSparseAttentionMetadataBuilder = None # type: ignore[assignment] - VideoMobaAttentionMetadataBuilder = None # type: ignore[assignment] + VideoSparseAttentionMetadataBuilder = None + VideoMobaAttentionMetadataBuilder = None class WanModel(ModelBase): @@ -327,8 +332,8 @@ def _get_training_dtype(self) -> torch.dtype: def _init_timestep_mechanics(self) -> None: assert self.training_config is not None tc = self.training_config - self.timestep_shift = float(tc.pipeline_config.flow_shift # type: ignore[union-attr] - ) + flow_shift = tc.pipeline_config.flow_shift + self.timestep_shift = float(0.0 if flow_shift is None else flow_shift) self.num_train_timestep = int(self.noise_scheduler.num_train_timesteps) # min/max timestep ratios now come from method_config; # default to full range. @@ -485,6 +490,7 @@ def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBa "(or flash_attn>=2.7.4) is not " "correctly installed.") moba_params = tc.model.moba_config.copy() + assert training_batch.raw_latent_shape is not None moba_params.update({ "current_timestep": (training_batch.timesteps), "raw_latent_shape": (training_batch.raw_latent_shape[2:5]), diff --git a/fastvideo/training/distillation_pipeline.py b/fastvideo/training/distillation_pipeline.py index 319f89e4fc..ffdbfabcf9 100644 --- a/fastvideo/training/distillation_pipeline.py +++ b/fastvideo/training/distillation_pipeline.py @@ -271,6 +271,12 @@ def load_module_from_path(self, model_path: str, module_type: str, training_args logger.info("Loading %s from custom path: %s", module_type, model_path) # Set flag to prevent custom weight loading for teacher/critic models training_args._loading_teacher_critic_model = True + suppressed_attn_backend = None + if self._should_force_generator_attn_qat_train(training_args) and \ + envs.FASTVIDEO_ATTENTION_BACKEND == "ATTN_QAT_TRAIN": + suppressed_attn_backend = os.environ.pop("FASTVIDEO_ATTENTION_BACKEND", None) + logger.info("Temporarily disabling FASTVIDEO_ATTENTION_BACKEND=ATTN_QAT_TRAIN while loading %s", + module_type) try: from fastvideo.models.loader.component_loader import (PipelineComponentLoader) @@ -307,6 +313,8 @@ def load_module_from_path(self, model_path: str, module_type: str, training_args logger.info("Successfully loaded %s from %s", module_type, component_path) return module finally: + if suppressed_attn_backend is not None: + os.environ["FASTVIDEO_ATTENTION_BACKEND"] = suppressed_attn_backend # Always clean up the flag if hasattr(training_args, '_loading_teacher_critic_model'): delattr(training_args, '_loading_teacher_critic_model') diff --git a/fastvideo/training/training_pipeline.py b/fastvideo/training/training_pipeline.py index cd8f7e0b09..a424534439 100644 --- a/fastvideo/training/training_pipeline.py +++ b/fastvideo/training/training_pipeline.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 from dataclasses import asdict +from contextlib import AbstractContextManager, nullcontext import math import os import shutil @@ -35,13 +36,15 @@ from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs from fastvideo.forward_context import set_forward_context from fastvideo.logger import init_logger +from fastvideo.attention.selector import global_force_attn_backend_context_manager from fastvideo.pipelines import (ComposedPipelineBase, ForwardBatch, LoRAPipeline, TrainingBatch) -from fastvideo.platforms import current_platform +from fastvideo.platforms import AttentionBackendEnum, current_platform from fastvideo.training.activation_checkpoint import (apply_activation_checkpointing) from fastvideo.training.trackers import (DummyTracker, TrackerType, initialize_trackers, Trackers) from fastvideo.training.training_utils import (clip_grad_norm_while_handling_failing_dtensor_cases, compute_density_for_timestep_sampling, count_trainable, get_scheduler, - get_sigmas, load_checkpoint, normalize_dit_input, save_checkpoint) + get_sigmas, load_checkpoint, normalize_dit_input, save_checkpoint, + swap_fp4_linear, traverse_swap_module) from fastvideo.utils import (is_vmoba_available, is_vsa_available, set_random_seed, shallow_asdict) try: @@ -84,6 +87,24 @@ def __init__(self, def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): raise RuntimeError("create_pipeline_stages should not be called for training pipeline") + @staticmethod + def _should_force_generator_attn_qat_train(fastvideo_args: FastVideoArgs) -> bool: + if not isinstance(fastvideo_args, TrainingArgs): + return False + return (fastvideo_args.generator_4bit_attn or envs.FASTVIDEO_ATTENTION_BACKEND == "ATTN_QAT_TRAIN") + + def load_modules(self, + fastvideo_args: FastVideoArgs, + loaded_modules: dict[str, torch.nn.Module] | None = None) -> dict[str, Any]: + force_generator_qat = self._should_force_generator_attn_qat_train(fastvideo_args) + load_context: AbstractContextManager[None] = nullcontext() + if force_generator_qat: + logger.info("Forcing generator attention backend to ATTN_QAT_TRAIN during module loading") + load_context = global_force_attn_backend_context_manager(AttentionBackendEnum.ATTN_QAT_TRAIN) + + with load_context: + return super().load_modules(fastvideo_args, loaded_modules) + def set_schemas(self) -> None: self.train_dataset_schema = pyarrow_schema_t2v @@ -114,6 +135,9 @@ def initialize_training_pipeline(self, training_args: TrainingArgs): self.transformer_2 = apply_activation_checkpointing( self.transformer_2, checkpointing_type=training_args.enable_gradient_checkpointing_type) + if training_args.generator_4bit_linear: + num_swaps = traverse_swap_module(self.transformer, swap_fn=swap_fp4_linear) + logger.info("Swapped %s linear layers to the FP4 forward path in self.transformer", num_swaps) noise_scheduler = self.modules["scheduler"] self.set_trainable() params_to_optimize = self.transformer.parameters() @@ -345,6 +369,7 @@ def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBa patch_size = self.training_args.pipeline_config.dit_config.patch_size current_vsa_sparsity = training_batch.current_vsa_sparsity assert latents_shape is not None + assert isinstance(patch_size, tuple), f"Expected tuple patch_size, got {patch_size!r}" assert training_batch.timesteps is not None if envs.FASTVIDEO_ATTENTION_BACKEND == "VIDEO_SPARSE_ATTN": if not vsa_available: @@ -365,7 +390,7 @@ def _build_attention_metadata(self, training_batch: TrainingBatch) -> TrainingBa moba_params = self.training_args.moba_config.copy() moba_params.update({ "current_timestep": training_batch.timesteps, - "raw_latent_shape": training_batch.raw_latent_shape[2:5], + "raw_latent_shape": latents_shape[2:5], "patch_size": self.training_args.pipeline_config.dit_config.patch_size, "device": get_local_torch_device(), }) @@ -466,6 +491,9 @@ def train_one_step(self, training_batch: TrainingBatch) -> TrainingBatch: training_batch = self._normalize_dit_input(training_batch) # Create noisy model input training_batch = self._prepare_dit_inputs(training_batch) + assert training_batch.latents is not None + assert training_batch.noisy_model_input is not None + assert training_batch.noise is not None # old sharding code, need to shard latents and noise but not input # Shard latents across sp groups @@ -590,9 +618,12 @@ def train(self) -> None: "vsa_sparsity": current_vsa_sparsity, } try: + assert training_batch.raw_latent_shape is not None metrics["batch_size"] = int(training_batch.raw_latent_shape[0]) - patch_t, patch_h, patch_w = self.training_args.pipeline_config.dit_config.patch_size + patch_size = self.training_args.pipeline_config.dit_config.patch_size + assert isinstance(patch_size, tuple), f"Expected tuple patch_size, got {patch_size!r}" + patch_t, patch_h, patch_w = patch_size seq_len = (training_batch.raw_latent_shape[2] // patch_t) * ( training_batch.raw_latent_shape[3] // patch_h) * (training_batch.raw_latent_shape[4] // patch_w) if training_batch.encoder_hidden_states is not None: diff --git a/fastvideo/training/training_utils.py b/fastvideo/training/training_utils.py index c539f2608e..2610ee7473 100644 --- a/fastvideo/training/training_utils.py +++ b/fastvideo/training/training_utils.py @@ -5,11 +5,13 @@ import time from collections.abc import Callable, Iterator from enum import Enum +import types from typing import Any import torch import torch.distributed as dist import torch.distributed.checkpoint as dcp +import torch.nn as nn from safetensors.torch import save_file from torch.optim import Optimizer from torch.optim.lr_scheduler import LambdaLR @@ -23,6 +25,40 @@ _HAS_ERRORED_CLIP_GRAD_NORM_WHILE_HANDLING_FAILING_DTENSOR_CASES = False +def swap_fp4_linear(obj: Any, obj_path: str) -> int: + """ + Swap supported linear layers to use the FP4 forward path in-place. + + Returns the number of layers updated on ``obj``. + """ + from fastvideo.layers.fp4linear import fp4_linear_forward + from fastvideo.layers.linear import ReplicatedLinear + + del obj_path # Unused today, but kept for traverse_swap_module compatibility. + + swaps_performed = 0 + + for attr_name in dir(obj): + if attr_name.startswith("_"): + continue + + try: + attr_value = getattr(obj, attr_name) + except Exception: + continue + + should_replace = False + if isinstance(attr_value, ReplicatedLinear): + should_replace = True + + if should_replace: + layer = getattr(obj, attr_name) + layer.forward = types.MethodType(fp4_linear_forward, layer) + swaps_performed += 1 + + return swaps_performed + + def gather_state_dict_on_cpu_rank0( model, device: torch.device | None = None, @@ -1702,3 +1738,90 @@ def __exit__(self, exc_type, exc, tb): def apply_to_model(self, module: torch.nn.Module) -> _ApplyEMACtx: return EMA_FSDP._ApplyEMACtx(self, module) + + +def traverse_swap_module( + root: Any, + swap_fn: Callable[[Any, str], int], + *, + verbose: bool = False, + max_depth: int | None = None, +) -> int: + """ + Traverse `root` (e.g., your transformer) and apply `swap_fn` to every reachable object. + - Cycle-safe via `visited` set. + - Handles nn.Module submodules, regular attributes, and common containers (list/tuple/set/dict). + + Args: + root: object graph root (e.g., transformer) + swap_fn: function(obj, obj_path) -> int (number of swaps it performed on `obj`) + verbose: print what gets swapped + max_depth: optional limit to traversal depth (0 = only root) + + Returns: + Total number of swaps performed across the graph. + """ + visited = set() + + def is_traversable(x: Any) -> bool: + # Skip obvious non-containers / primitives / callables / modules + return not isinstance( + x, + str | bytes | bytearray | memoryview | int | float | bool | complex | types.BuiltinFunctionType + | types.FunctionType | types.MethodType | types.ModuleType, + ) + + def walk(obj: Any, obj_path: str, depth: int) -> int: + if max_depth is not None and depth > max_depth: + return 0 + + oid = id(obj) + if oid in visited: + return 0 + visited.add(oid) + + swaps = 0 + # First, try to swap on the current object + swaps += swap_fn(obj, obj_path) + + # Then, traverse children + # 1) PyTorch modules: prefer named_children() for proper module graph traversal + if isinstance(obj, nn.Module): + + for name, child in obj.named_children(): + if is_traversable(child): + swaps += walk(child, f"{obj_path}.{name}", depth + 1) + + # Also traverse non-module attributes living on the module + d = getattr(obj, "__dict__", None) + if isinstance(d, dict): + for k, v in d.items(): + # Avoid double-visiting known submodules (already covered by named_children) + if is_traversable(v) and not isinstance(v, nn.Module): + swaps += walk(v, f"{obj_path}.{k}", depth + 1) + + return swaps + + # 2) Generic Python objects: __dict__ attrs + d = getattr(obj, "__dict__", None) + if isinstance(d, dict): + for k, v in d.items(): + if is_traversable(v): + swaps += walk(v, f"{obj_path}.{k}", depth + 1) + + # 3) Common containers + if isinstance(obj, list | tuple | set | frozenset): + for idx, v in enumerate(obj): + if is_traversable(v): + swaps += walk(v, f"{obj_path}[{idx}]", depth + 1) + elif isinstance(obj, dict): + for k, v in obj.items(): + if is_traversable(v): + k_show = repr(k) + if len(k_show) > 40: + k_show = k_show[:37] + "..." + swaps += walk(v, f"{obj_path}[{k_show}]", depth + 1) + + return swaps + + return walk(root, "root", 0) diff --git a/fastvideo/utils.py b/fastvideo/utils.py index 75b9bf5208..d08ca6e93a 100644 --- a/fastvideo/utils.py +++ b/fastvideo/utils.py @@ -11,7 +11,6 @@ import json import math import multiprocessing -from multiprocessing.context import BaseContext import os import signal import socket @@ -796,8 +795,10 @@ def dict_to_3d_list( return result -def set_random_seed(seed: int) -> None: +def set_random_seed(seed: int | None) -> None: from fastvideo.platforms import current_platform + if seed is None: + return current_platform.seed_everything(seed) @@ -845,7 +846,7 @@ def masks_like(tensor, zero=False, generator=None, p=0.2) -> tuple[list[torch.Te # adapted from: https://github.com/Wan-Video/Wan2.2/blob/main/wan/utils/utils.py -def best_output_size(w, h, dw, dh, expected_area): +def best_output_size(w: int, h: int, dw: int, dh: int, expected_area: int) -> tuple[int, int]: # float output size ratio = w / h ow = (expected_area * ratio)**0.5 @@ -1070,7 +1071,7 @@ def force_spawn() -> None: os.environ["FASTVIDEO_WORKER_MULTIPROC_METHOD"] = "spawn" -def get_mp_context() -> BaseContext: +def get_mp_context() -> Any: """Get a multiprocessing context with a particular method (spawn or fork). By default we follow the value of the FASTVIDEO_WORKER_MULTIPROC_METHOD to determine the multiprocessing method (default is fork). However, under diff --git a/fastvideo/worker/executor.py b/fastvideo/worker/executor.py index 964c099e13..49625e0e1e 100644 --- a/fastvideo/worker/executor.py +++ b/fastvideo/worker/executor.py @@ -23,6 +23,7 @@ def __init__( ): self.fastvideo_args = fastvideo_args self._log_queue = log_queue + self._streaming_enabled = False self._init_executor() @@ -120,3 +121,30 @@ def shutdown(self) -> None: Shutdown the executor. """ raise NotImplementedError + + def execute_streaming_reset(self, forward_batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> dict[str, Any]: + raise NotImplementedError + + def execute_streaming_step(self, keyboard_action: Any = None, mouse_action: Any = None) -> ForwardBatch: + raise NotImplementedError + + async def execute_streaming_step_async(self, keyboard_action: Any = None, mouse_action: Any = None) -> ForwardBatch: + raise NotImplementedError + + def execute_streaming_clear(self) -> dict[str, Any] | None: + raise NotImplementedError + + def submit_reset(self, forward_batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> None: + raise NotImplementedError + + def submit_step(self, keyboard_action: Any = None, mouse_action: Any = None) -> None: + raise NotImplementedError + + def submit_clear(self) -> None: + raise NotImplementedError + + def wait_result(self) -> Any: + raise NotImplementedError + + def disable_streaming(self) -> None: + self._streaming_enabled = False diff --git a/fastvideo/worker/ray_distributed_executor.py b/fastvideo/worker/ray_distributed_executor.py index 3de101426c..a4117ec88b 100644 --- a/fastvideo/worker/ray_distributed_executor.py +++ b/fastvideo/worker/ray_distributed_executor.py @@ -43,7 +43,7 @@ class RayWorkerMetaData: and we need to reset the rank after creating all workers. """ - worker: ActorHandle + worker: Any created_rank: int adjusted_rank: int = -1 ip: str = "" @@ -84,7 +84,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwar num_gpus = envs.FASTVIDEO_RAY_PER_WORKER_GPUS # The remaining workers are the actual ray actors. - self.workers: list[RayWorkerWrapper] = [] + self.workers: list[Any] = [] # Create the workers. # use the first N bundles that have GPU resources. @@ -242,11 +242,11 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): # This is the list of workers that are rank 0 of each TP group EXCEPT # global rank 0. These are the workers that will broadcast to the # rest of the workers. - self.tp_driver_workers: list[RayWorkerWrapper] = [] + self.tp_driver_workers: list[Any] = [] # This is the list of workers that are not drivers and not the first # worker in a TP group. These are the workers that will be # broadcasted to. - self.non_driver_workers: list[RayWorkerWrapper] = [] + self.non_driver_workers: list[Any] = [] # Enforce rank order for correct rank to return final output. for index, worker in enumerate(self.workers): diff --git a/fastvideo/workflow/preprocess/preprocess_workflow.py b/fastvideo/workflow/preprocess/preprocess_workflow.py index 90355371a8..918212fe99 100644 --- a/fastvideo/workflow/preprocess/preprocess_workflow.py +++ b/fastvideo/workflow/preprocess/preprocess_workflow.py @@ -1,5 +1,4 @@ import os -from typing import cast from torch.utils.data import DataLoader @@ -105,18 +104,18 @@ def prepare_system_environment(self) -> None: self.training_dataset_output_dir = training_dataset_output_dir @classmethod - def get_workflow_cls(cls, fastvideo_args: FastVideoArgs) -> "PreprocessWorkflow": + def get_workflow_cls(cls, fastvideo_args: FastVideoArgs) -> type["PreprocessWorkflow"]: is_ltx2_t2v = (fastvideo_args.workload_type == WorkloadType.T2V and fastvideo_args.pipeline_config.__class__.__name__ == "LTX2T2VConfig") if is_ltx2_t2v: from fastvideo.workflow.preprocess.preprocess_workflow_ltx2_t2v import (PreprocessWorkflowLTX2T2V) - return cast(PreprocessWorkflow, PreprocessWorkflowLTX2T2V) + return PreprocessWorkflowLTX2T2V if fastvideo_args.workload_type == WorkloadType.T2V: from fastvideo.workflow.preprocess.preprocess_workflow_t2v import (PreprocessWorkflowT2V) - return cast(PreprocessWorkflow, PreprocessWorkflowT2V) + return PreprocessWorkflowT2V elif fastvideo_args.workload_type == WorkloadType.I2V: from fastvideo.workflow.preprocess.preprocess_workflow_i2v import (PreprocessWorkflowI2V) - return cast(PreprocessWorkflow, PreprocessWorkflowI2V) + return PreprocessWorkflowI2V else: raise ValueError( f"Workload type: {fastvideo_args.workload_type} is not supported in preprocessing workflow.") diff --git a/fastvideo/workflow/workflow_base.py b/fastvideo/workflow/workflow_base.py index c89fd62092..001dd03557 100644 --- a/fastvideo/workflow/workflow_base.py +++ b/fastvideo/workflow/workflow_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Optional +from typing import Any from fastvideo.fastvideo_args import ExecutionMode, FastVideoArgs from fastvideo.logger import init_logger @@ -37,7 +37,7 @@ class WorkflowBase(ABC): the overall processing flow. """ - def __init__(self, fastvideo_args: FastVideoArgs): + def __init__(self, fastvideo_args: FastVideoArgs) -> None: """ Initialize the workflow with configuration arguments. @@ -145,7 +145,7 @@ def prepare_system_environment(self) -> None: pass @abstractmethod - def run(self): + def run(self) -> Any: """ Execute the main workflow logic. @@ -156,7 +156,7 @@ def run(self): pass @classmethod - def get_workflow_cls(cls, fastvideo_args: FastVideoArgs) -> Optional["WorkflowBase"]: + def get_workflow_cls(cls, fastvideo_args: FastVideoArgs) -> type["WorkflowBase"] | None: """ Factory method to get the appropriate workflow class based on execution mode. diff --git a/mkdocs.yml b/mkdocs.yml index 68264c7940..a9b4b1f1c1 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -152,6 +152,7 @@ nav: - DMD: distillation/dmd.md - Attention: - Overview: attention/index.md + - Attention QAT: attention/attn_qat/index.md - Video Sparse Attention: attention/vsa/index.md - Sliding Tile Attention (Archived): attention/sta/index.md - Backend Development: contributing/attention_backend.md diff --git a/pyproject.toml b/pyproject.toml index 5479aff4c9..63ab439b97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,13 +23,14 @@ dependencies = [ # Machine Learning & Transformers "transformers==4.57.3", "tokenizers>=0.20.1", - "sentencepiece==0.2.0", + "sentencepiece==0.2.1", "timm==1.0.11", "peft>=0.15.0", "diffusers>=0.33.1", "torch>=2.10.0", "torchvision", "torchaudio", + "flashinfer-python", # Acceleration & Optimization "accelerate==1.0.1", @@ -149,6 +150,7 @@ disallow_untyped_calls = true check_untyped_defs = true follow_imports = "silent" + [tool.codespell] skip = "./data,./wandb,ui/package-lock.json" diff --git a/tests/local_tests/transformers/test_gamecraft_parity.py b/tests/local_tests/transformers/test_gamecraft_parity.py index fd5a35b7af..22a28678c3 100644 --- a/tests/local_tests/transformers/test_gamecraft_parity.py +++ b/tests/local_tests/transformers/test_gamecraft_parity.py @@ -17,6 +17,7 @@ # Enable debug logging GAMECRAFT_DEBUG_LOGS=1 pytest tests/local_tests/transformers/test_gamecraft_parity.py -v """ +import contextlib import os import sys from pathlib import Path @@ -33,6 +34,21 @@ repo_root = Path(__file__).resolve().parents[3] +@pytest.fixture(autouse=True) +def _cleanup_parallel_state(): + """Keep local parity tests isolated when they initialize NCCL groups.""" + from fastvideo.distributed.parallel_state import cleanup_dist_env_and_memory + + with contextlib.suppress(Exception): + cleanup_dist_env_and_memory() + + try: + yield + finally: + with contextlib.suppress(Exception): + cleanup_dist_env_and_memory() + + def _add_official_to_path(): """Add official GameCraft implementation to Python path.""" official_path = Path( diff --git a/ui/job_runner.py b/ui/job_runner.py index 725730d938..ff20c9f67e 100644 --- a/ui/job_runner.py +++ b/ui/job_runner.py @@ -13,7 +13,6 @@ import enum import logging import logging.handlers -import multiprocessing as mp import os import re import subprocess @@ -645,7 +644,7 @@ def _get_or_create_generator( vsa_sparsity: float = 0.0, tp_size: int = -1, sp_size: int = -1, - log_queue: mp.Queue | None = None, + log_queue: Any | None = None, ) -> Any: cache_key = ( model_id,