diff --git a/docs/source/builder-cli.md b/docs/source/builder-cli.md index 6715ba91..893926bb 100644 --- a/docs/source/builder-cli.md +++ b/docs/source/builder-cli.md @@ -308,7 +308,7 @@ Install a kernels skill for an AI assistant Default value: `cuda-kernels` - Possible values: `cuda-kernels`, `rocm-kernels` + Possible values: `cuda-kernels`, `rocm-kernels`, `xpu-kernels` * `--claude` — Install for Claude * `--codex` — Install for Codex diff --git a/docs/source/builder/agents-guide.md b/docs/source/builder/agents-guide.md index 6a414aaa..a56f4692 100644 --- a/docs/source/builder/agents-guide.md +++ b/docs/source/builder/agents-guide.md @@ -2,7 +2,7 @@ Code agents are a good fit to build custom kernels because the hard part is not just writing in Domain Specific Language (DSLs) like CUDA. You also need the right project layout, PyTorch bindings, architecture-specific choices, model-specific integration, and trustworthy benchmarks. -Kernels on Hugging Face are compatible with agents via skills and the `hf` CLI. The `cuda-kernels` and `rocm-kernels` skills contain knowledge so an agent can generate and publish a complete kernel project, instead of isolated snippets. +Kernels on Hugging Face are compatible with agents via skills and the `hf` CLI. The `cuda-kernels`, `rocm-kernels`, and `xpu-kernels` skills contain knowledge so an agent can generate and publish a complete kernel project, instead of isolated snippets. This guide is for **authoring new kernels**. If you only want to **load an existing precompiled kernel**, use `get_kernel()` instead. diff --git a/docs/source/cli-skills.md b/docs/source/cli-skills.md index 36ddf8ab..fdfd2883 100644 --- a/docs/source/cli-skills.md +++ b/docs/source/cli-skills.md @@ -4,6 +4,7 @@ Use `kernel-builder skills add` to install the skills for AI coding assistants l Supported skills include: - `cuda-kernels` (default) - `rocm-kernels` +- `xpu-kernels` Skill files are downloaded from the `huggingface/kernels` directory in this [repository](https://github.com/huggingface/kernels/tree/main/kernel-builder/skills). diff --git a/kernel-builder/skills/xpu-kernels/README.md b/kernel-builder/skills/xpu-kernels/README.md new file mode 100644 index 00000000..88243714 --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/README.md @@ -0,0 +1,32 @@ +# XPU Kernels Skill + +This skill was adapted from [Xe-Forge](https://github.com/IntelLabs/Xe-Forge) — an LLM-driven optimization framework that transforms PyTorch code into fast Triton kernels for Intel XPU GPUs. + +The skill includes Xe-Forge's CLI tools (`scripts/`), knowledge base (`references/`), and the optimization workflow, all integrated into the hf-kernels skill format. + +## Full Experience + +For the complete Xe-Forge setup — including the ai-bench harness, test kernels, GEMM/reduction templates, annotated examples, and VTune profiling — clone the full project: + +```bash +# Clone the repository +git clone https://github.com/IntelLabs/Xe-Forge +cd Xe-Forge + +# Install for Intel XPU +uv sync --extra intel +``` + +## Prerequisites + +- Python 3.10+ +- PyTorch with XPU support +- [Intel XPU Backend for Triton](https://github.com/intel/intel-xpu-backend-for-triton) +- Intel XPU hardware (tested on Battlemage G21 / Arc Pro B50) +- Intel VTune Profiler 2025+ *(optional — set `vtune_enabled: false` in `scripts/config.yaml` to skip)* + +## Install Dependencies + +```bash +pip install -r scripts/requirements.txt +``` diff --git a/kernel-builder/skills/xpu-kernels/SKILL.md b/kernel-builder/skills/xpu-kernels/SKILL.md new file mode 100644 index 00000000..e62aeaa5 --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/SKILL.md @@ -0,0 +1,290 @@ +--- +name: xpu-kernels +description: "Provides guidance for writing, optimizing, and benchmarking Triton kernels for Intel XPU GPUs (Battlemage/Arc Pro B50) using the Xe-Forge optimization framework. Includes an LLM-driven trial-loop workflow (analyze, validate, benchmark, profile, finalize), XPU-specific patterns (tensor descriptors, GRF mode, tile swizzling), KernelBench fused kernels, and Flash Attention." +disable-model-invocation: false +user-invocable: true +allowed-tools: "Read, Grep, Glob, Bash" +argument-hint: "kernel type: gemm, reduction, flash-attention, optimize, benchmark, tensor-descriptors, xe-forge" +--- + +# XPU Triton Kernels for Intel GPUs + +This skill provides patterns and guidance for developing optimized Triton kernels targeting Intel XPU GPUs (Battlemage/Arc Pro B50). It integrates the [Xe-Forge](https://github.com/IntelLabs/Xe-Forge) optimization framework — an LLM-driven loop that transforms PyTorch code into fast Triton kernels. + +## Quick Start + +### Optimize a Kernel (Xe-Forge Workflow) + +The full optimization workflow analyzes a PyTorch baseline, generates Triton kernel variants in a branching trial tree, benchmarks each on XPU hardware, and finalizes the best result. + +```bash +# 1. Analyze the baseline +python scripts/analyze_kernel.py test_kernels/70_Gemm_Sigmoid_Scaling_ResidualAdd_pytorch.py + +# 2. Initialize trial tracking +python scripts/trial_manager.py init 70_Gemm_Sigmoid test_kernels/70_Gemm_Sigmoid_Scaling_ResidualAdd_pytorch.py + +# 3. Validate a generated kernel (no GPU needed) +python scripts/validate_triton.py my_kernel.py + +# 4. Benchmark correctness + performance +python scripts/benchmark.py test_kernels/70_Gemm_Sigmoid_Scaling_ResidualAdd_pytorch.py my_kernel.py + +# 5. Profile with VTune (optional) +python scripts/xpu_profiler.py my_kernel.py + +# 6. Finalize best trial +python scripts/trial_manager.py finalize 70_Gemm_Sigmoid optimized_triton.py +``` + +## Supported Hardware + +| GPU | Architecture | XVEs | Mem BW | Key Feature | Verified | +|-----|-------------|------|--------|-------------|:--------:| +| **Battlemage G21 / Arc Pro B50** | Xe2 | 128 | ~500 GB/s | Tensor descriptors, GRF 256 | Yes | + +> See the [Intel XPU Backend for Triton](https://github.com/intel/intel-xpu-backend-for-triton) for supported hardware. + +## When This Skill Applies + +Use this skill when: +- Optimizing PyTorch operations into Triton kernels for **Intel XPU** +- Writing GEMM, fused kernels, reductions, or Flash Attention for Intel GPUs +- Running the **Xe-Forge optimization loop** (analyze → validate → benchmark → profile → finalize) +- Benchmarking kernel performance against PyTorch baseline on XPU + +## Xe-Forge Optimization Workflow + +Transform PyTorch code into optimized Triton kernels for Intel XPU. Kernels must be numerically equivalent and faster than baseline. + +### Configuration — Read `config.yaml` first + +At the start of every session, read `scripts/config.yaml`. It controls: +- **`max_trials`** — hard cap on optimization trials; always run all of them (use this instead of hardcoded "10") +- **`vtune_enabled`** — if `false`, skip ALL VTune profiling steps (Step 3.6 and profiler-related decisions) +- **`vtune_bin`** — path to the VTune binary (also settable via `VTUNE_BIN` env var) + +### Rules — Never Violate + +1. **ONLY create** Triton kernel files (`test_kernels/*_triton.py` or trial files `t.py`). +2. **NEVER create** benchmark scripts, test scripts, helper utilities, or any other Python files. +3. **NEVER write custom scripts** to measure performance or test correctness — ONLY use `scripts/benchmark.py`. +4. If a tool fails, **STOP and report the error**. Do NOT work around it with custom scripts. +5. Generated kernels must be **self-contained** — all helper functions inline. +6. You **MUST run all `max_trials` trials** from `config.yaml`. Do NOT stop early due to plateau — LLM sampling can discover new ideas at any point. The only valid early stop is speedup > 5x. + +### Mandatory Tools + +**CRITICAL — Single-XPU serialization**: There is only ONE XPU on this machine. You MUST NOT run multiple GPU workloads in parallel. `benchmark.py` and `xpu_profiler.py` must execute strictly one at a time — concurrent GPU jobs produce wrong results. CPU-only tools (`analyze_kernel.py`, `validate_triton.py`, `trial_manager.py`) are safe to parallelize with each other and with anything else. + +| Tool | Command | Purpose | +|------|---------|---------| +| **Analyze** | `python scripts/analyze_kernel.py ` | Static analysis: operations, shapes, fusion opportunities | +| **Validate** | `python scripts/validate_triton.py ` | Syntax + constraint checks before GPU time | +| **Benchmark** | `python scripts/benchmark.py [--triton-baseline] [--baseline-us ]` | Correctness + performance via ai-bench | +| **Profile** | `python scripts/xpu_profiler.py ` | VTune GPU hardware counters + recommendations | +| **Init trials** | `python scripts/trial_manager.py init [--triton-baseline]` | Initialize trial tracking | +| **Save trial** | `python scripts/trial_manager.py save [--parent ] [--strategy "..."]` | Save trial to tree | +| **Record result** | `python scripts/trial_manager.py result --validation pass --correctness --speedup --baseline_us --triton_us ` | Record benchmark result | +| **Check status** | `python scripts/trial_manager.py status ` | View trial tree | +| **Best trial** | `python scripts/trial_manager.py best ` | Get best trial | +| **Baseline time** | `python scripts/trial_manager.py baseline-us ` | Cached baseline time for `--baseline-us` | +| **Finalize** | `python scripts/trial_manager.py finalize _triton.py` | Copy best trial to output | + +### Workflow Steps + +#### Step 1: Analyze +- Read the baseline source file. Identify shapes, dtypes, operations, fusion opportunities. +- If baseline is PyTorch: run `python scripts/analyze_kernel.py `. +- If baseline is Triton (`--triton-baseline`): skip `analyze_kernel.py` (it only supports PyTorch). Read the Triton file directly. +- Read relevant knowledge base files: start with `references/correctness.yaml` and `references/xpu_optimizations.yaml`. +- Read `references/implementation_reference.md` for templates and the Model class pattern. + +#### Step 2: Initialize +```bash +python scripts/trial_manager.py init [--triton-baseline] +``` + +#### Step 3: Trial Loop (always run all `max_trials` from config.yaml) +For each trial: +1. **Write kernel** — start from templates or modify previous trial. See `references/implementation_reference.md`. +2. **Validate** — `python scripts/validate_triton.py ` (fix until passing; doesn't count as a trial). +3. **Save** — `python scripts/trial_manager.py save --parent --strategy "description"`. Omit `--parent` for the first trial (t0). +4. **Benchmark** (MANDATORY every trial): + - **Trial t0:** `python scripts/benchmark.py [--triton-baseline]` (measures both baseline and triton). + - **Trials t1+:** Get cached baseline via `python scripts/trial_manager.py baseline-us `, then run `python scripts/benchmark.py [--triton-baseline] --baseline-us ` (skips baseline perf, saves time). + - **After `finalize`:** Re-run `benchmark.py` without `--baseline-us` for final accurate comparison. +5. **Record** — `python scripts/trial_manager.py result --validation pass --correctness --speedup --baseline_us --triton_us ` (runtimes from benchmark output). +6. **Profile (MANDATORY after t1, if `vtune_enabled` is true in config.yaml)** — Run `python scripts/xpu_profiler.py ` after your first benchmarked trial. Use its output to guide subsequent trial strategies. Run again if speedup plateaus after 2+ additional trials. **Skip this step entirely if `vtune_enabled` is false.** +7. **Decide next action** (use profiler output from step 6 to inform decisions): + - Speedup > 5x → stop (excellent), finalize + - Speedup improved → continue on this branch, try next optimization level + - Speedup regressed → branch back to best trial, try different strategy + - Correctness failed → fix on same branch + - Profiler says low occupancy (if vtune_enabled) → increase tile sizes, check `references/xpu_optimizations.yaml` + - Profiler says overhead kernels dominate (if vtune_enabled) → pre-pack to bf16, see `references/optimization_levels.yaml` + - Plateau → do NOT stop. Try a fundamentally different approach (different algorithm, tiling, fusion strategy). LLM sampling can discover new ideas. + - See `references/optimization_strategies.md` for the full "try harder" decision tree + +#### Step 4: Finalize +```bash +python scripts/trial_manager.py finalize _triton.py +``` + +### Reference Docs — Read During Step 1 + +| Doc | Contents | +|-----|----------| +| `references/implementation_reference.md` | Code templates, Model class pattern, GEMM example | +| `references/optimization_strategies.md` | Strategy reference, optimization levels, checklist | +| `references/workflow_details.md` | Detailed workflow, decision tree, benchmarking/validation details | +| `references/correctness.yaml` | Critical constraints to avoid bugs | +| `references/xpu_optimizations.yaml` | XPU-specific patterns (tensor descriptors, GRF, swizzling) | +| `references/fusion_patterns.yaml` | When to fuse vs split operations | +| `references/optimization_levels.yaml` | Progressive optimization with "try harder" decision tree | + +### Existing Baselines Are Naive + +The `test_kernels/*.py` Triton files (non-pytorch) are **unoptimized baselines**. They use manual pointer arithmetic, lack autotune, and miss XPU optimizations. Do NOT copy their patterns. Use `references/implementation_reference.md` instead. + +## Core XPU Kernel Patterns + +### Tensor Descriptors (Preferred on XPU) + +Tensor descriptors produce better address generation and memory access codegen than block pointers on Intel XPU. + +```python +desc = tl.make_tensor_descriptor( + base=ptr, shape=[M, N], + strides=[stride_m, stride_n], + block_shape=[BLOCK_M, BLOCK_N], +) +block = tl.load(desc, [pid_m, pid_n], boundary_check=(0, 1)) +``` + +### GRF Mode '256' + +Use the large register file for compute-heavy kernels: + +```python +@triton.autotune( + configs=[triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256}, num_warps=32)], + key=['M', 'N', 'K'], +) +@triton.jit(launch_metadata=lambda *args, **kwargs: {'grf_mode': '256'}) +def kernel(...): + ... +``` + +### Tile Swizzling + +Use 1D grid with GROUP_SIZE_M for L2 locality: + +```python +grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) +# Inside kernel: +pid = tl.program_id(0) +num_pid_n = tl.cdiv(N, BLOCK_N) +group_id = pid // (GROUP_SIZE_M * num_pid_n) +``` + +### bf16 Inputs with fp32 Accumulation + +```python +a = tl.load(a_desc, [pid_m, k], boundary_check=(0, 1)) +b = tl.load(b_desc, [k, pid_n], boundary_check=(0, 1)) +acc += tl.dot(a.to(tl.bfloat16), b.to(tl.bfloat16), acc=acc) # fp32 accumulator +``` + +## Critical XPU Constraints + +- **NO default values** for `@triton.autotune` meta-parameters in kernel signature +- **1D grid** when using tile swizzling (GROUP_SIZE_M) +- **`boundary_check`** uses dimension indices `(0, 1)`, not booleans +- **Cast batch indices** to `int64` before stride multiplication +- **Prefer tensor descriptors** over block pointers for all new XPU kernels +- **Do NOT mix** block pointer and tensor descriptor APIs on same operation +- **Pre-zero output buffers** when using atomic accumulation +- Model class must be compatible with ai-bench (`nn.Module` with `nn.Linear`) +- Match `get_inputs()`, `get_init_inputs()`, and module-level constants from `*_pytorch.py` + +> Full constraint list: [correctness.yaml](references/correctness.yaml) + +## Performance Results + +Measured on Intel Battlemage G21 / Arc Pro B50 (128 XVEs). All runtimes are median of benchmark trials. + +### KernelBench Level 2 — Fused Kernels (bf16) + +Speedup is vs. PyTorch eager baseline. Includes GEMM+Sigmoid+Scaling, GEMM+GELU+Softmax, Conv+BatchNorm+ReLU, and other fused patterns. + +### Flash Attention Forward (fp16) + +Baseline is the flash attention kernel from the Intel XPU Triton backend; speedup is vs. that kernel across multiple sequence lengths. + +> Full results: see the [Xe-Forge repository](https://github.com/IntelLabs/Xe-Forge). + +## Common Issues + +| Issue | Symptom | Fix | +|-------|---------|-----| +| **Autotune BLOCK_D** | Wrong results (max_abs 4-8+) | **Never autotune BLOCK_D.** Use `triton.next_power_of_2(D)` | +| Python min/max | Runtime error | `tl.minimum()`/`tl.maximum()` | + +## Project Structure + +``` +xpu-kernels/ +├── SKILL.md # This file (skill definition + workflow) +├── manifest.txt # Files included in this skill +│ +├── scripts/ # Standalone CLI tools +│ ├── analyze_kernel.py # PyTorch → operations, shapes, fusion opportunities +│ ├── validate_triton.py # Syntax + constraint checks +│ ├── benchmark.py # Correctness + performance via ai-bench +│ ├── trial_manager.py # Tree-structured trial management +│ ├── xpu_profiler.py # VTune GPU hardware counters +│ ├── config.py # Shared configuration loader +│ ├── config.yaml # Session config (max_trials, vtune) +│ └── requirements.txt # Python dependencies +│ +└── references/ # Knowledge base + integration guides + ├── correctness.yaml # Hard constraints for XPU Triton + ├── xpu_optimizations.yaml # Tensor descriptors, GRF, swizzling + ├── implementation_reference.md # Code templates, Model class pattern + ├── implementation_reference.md # Code templates, Model class pattern + ├── optimization_strategies.md # Strategy reference + "try harder" tree + ├── optimization_levels.yaml # Progressive L1-L5 optimization levels + ├── workflow_details.md # Detailed workflow and decision tree + ├── fusion_patterns.yaml # When to fuse vs split + ├── memory_patterns.yaml # Access patterns and coalescing + ├── dtype_optimizations.yaml # Mixed precision choices + ├── persistent_kernel_patterns.yaml # Stream K and persistent kernels + ├── kernel-templates.md # Triton kernel templates for XPU + └── kernelbench-classification.md # KernelBench operator taxonomy +``` + +## See Also + +### Xe-Forge Tools +- [analyze_kernel.py](scripts/analyze_kernel.py) — Static analysis of PyTorch reference +- [validate_triton.py](scripts/validate_triton.py) — Pre-benchmark constraint checks +- [benchmark.py](scripts/benchmark.py) — Correctness + performance measurement +- [xpu_profiler.py](scripts/xpu_profiler.py) — VTune GPU hardware counters +- [trial_manager.py](scripts/trial_manager.py) — Branching trial tree management + +### XPU Optimization References +- [correctness.yaml](references/correctness.yaml) — Critical constraints +- [xpu_optimizations.yaml](references/xpu_optimizations.yaml) — Tensor descriptors, GRF, swizzling +- [optimization_strategies.md](references/optimization_strategies.md) — Strategy reference +- [optimization_levels.yaml](references/optimization_levels.yaml) — Progressive L1-L5 levels +- [implementation_reference.md](references/implementation_reference.md) — Code templates + +### Other References +- [kernelbench-classification.md](references/kernelbench-classification.md) — KernelBench operator taxonomy + +### External Resources +- [Xe-Forge Repository](https://github.com/IntelLabs/Xe-Forge) +- [AI-Bench](https://github.com/libxsmm/AI-bench) — Benchmark harness for correctness + performance +- [Intel XPU Backend for Triton](https://github.com/intel/intel-xpu-backend-for-triton) +- [Triton Language Guide](https://triton-lang.org/) diff --git a/kernel-builder/skills/xpu-kernels/manifest.txt b/kernel-builder/skills/xpu-kernels/manifest.txt new file mode 100644 index 00000000..e94b972e --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/manifest.txt @@ -0,0 +1,26 @@ +# Files for xpu-kernels skill +SKILL.md +README.md +references/correctness.yaml +references/dtype_optimizations.yaml +references/fusion_patterns.yaml +references/huggingface-kernels-integration.md +references/implementation_reference.md +references/kernelbench-classification.md +references/memory_patterns.yaml +references/optimization_levels.yaml +references/optimization_strategies.md +references/persistent_kernel_patterns.yaml +references/workflow_details.md +references/xpu_optimizations.yaml +scripts/analyze_kernel.py +scripts/benchmark.py +scripts/benchmark_kernels.py +scripts/config.py +scripts/config.yaml +scripts/huggingface_kernels_example.py +scripts/requirements.txt +scripts/transformers_injection_example.py +scripts/trial_manager.py +scripts/validate_triton.py +scripts/xpu_profiler.py diff --git a/kernel-builder/skills/xpu-kernels/references/correctness.yaml b/kernel-builder/skills/xpu-kernels/references/correctness.yaml new file mode 100644 index 00000000..7bc2658c --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/correctness.yaml @@ -0,0 +1,145 @@ +constraints: + - id: outputs_must_match + name: "Outputs must match original" + severity: info + description: | + The verification tool will check that outputs match the original. + If it fails, try a different optimization approach. + + - id: streamk_output_must_be_prezeroed + name: "Pre-zero output buffer when using atomic accumulation (Stream K)" + severity: critical + description: | + When partial tiles use tl.atomic_add to accumulate results, the output + tensor MUST be initialized to zero (torch.zeros, NOT torch.empty). + Otherwise partial sums will include garbage values. + + WRONG: + ```python + c = torch.empty((M, N), device=a.device, dtype=torch.float32) + first_wave[grid](a, b, c, ...) # atomic_add onto garbage + ``` + + CORRECT: + ```python + c = torch.zeros((M, N), device=a.device, dtype=torch.float32) + first_wave[grid](a, b, c, ...) # atomic_add safely onto zeros + ``` + + - id: streamk_atomic_add_needs_mask + name: "Atomic adds on partial tiles must be masked for boundary safety" + severity: critical + description: | + When falling back to tl.atomic_add for partial tiles, you MUST apply + boundary masks (rm < M, rn < N) to avoid writing out-of-bounds. + + ```python + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.atomic_add(c_ptr_, acc, mask=mask, sem='relaxed') + ``` + + - id: int64_cast_for_large_batch_offsets + name: "Cast batch/stride products to int64 to prevent pointer overflow" + severity: critical + description: | + When computing pointer offsets for batched operations, the product of + a batch index and a stride can exceed int32 range for large tensors. + Triton program_id returns int32 by default. You MUST cast to int64 + before multiplying by strides. + + WRONG (silent int32 overflow → wrong memory addresses): + ```python + bid = tl.program_id(axis=1) + offset_a = bid * stride_az # int32 * int32 → overflow for large tensors + a_ptrs = a_ptr + offset_a + ... + ``` + + CORRECT: + ```python + bid = tl.program_id(axis=1) + offset_a = bid.to(tl.int64) * stride_az # safe for large tensors + a_ptrs = a_ptr + offset_a + ... + ``` + + This applies whenever a program_id or loop index is multiplied by a + stride that could produce values > 2^31 (≈2 billion elements). Common + in batched GEMM, multi-head attention, and any kernel with a batch + dimension over large tensors. + + - id: autotune_no_defaults + name: "Do not put default values on @triton.autotune meta-parameters" + severity: critical + description: | + When using @triton.autotune, the meta-parameters (BLOCK_M, BLOCK_N, etc.) + must NOT have default values in the kernel signature. Default values cause + a "Conflicting meta-parameters" error at runtime. + + WRONG: + ```python + @triton.autotune(configs=[...], key=['M', 'N', 'K']) + @triton.jit + def kernel(..., BLOCK_M: tl.constexpr = 128, ...): + ... + ``` + + CORRECT: + ```python + @triton.autotune(configs=[...], key=['M', 'N', 'K']) + @triton.jit + def kernel(..., BLOCK_M: tl.constexpr, ...): + ... + ``` + + - id: model_class_pattern + name: "Model class must be compatible with ai-bench loading" + severity: critical + description: | + ai-bench creates Model via direct `__init__()` and uses standard + `load_state_dict()` for weight synchronization between reference + and optimized models. + + The Model class should use standard nn.Module patterns: + + ```python + class Model(nn.Module): + def __init__(self, input_size, hidden_size, ...): + super().__init__() + self.gemm = nn.Linear(input_size, hidden_size) + self._packed = False + + def _pack_weights(self): + device = torch.device("xpu") + w = self.gemm.weight.data.detach() + b = self.gemm.bias.data.detach() + self.weight_t = w.to(device, torch.float16).t().contiguous() + self.bias_xpu = b.to(device, torch.float16).contiguous() + self._packed = True + + def forward(self, x): + if not self._packed: + self._pack_weights() + # ... launch triton kernel ... + ``` + + - id: descriptor_no_boundary_check_arg + name: "Tensor descriptor .load() does NOT accept boundary_check" + severity: critical + description: | + Tensor descriptors are the preferred memory access API on XPU. + Unlike block pointers which use tl.load(ptr, boundary_check=(0, 1)), + tensor descriptors handle boundaries internally. The .load() method + takes only a coordinate list. + + WRONG: + ```python + desc = tl.make_tensor_descriptor(base=ptr, shape=(M, K), ...) + data = desc.load([row, col], boundary_check=(0, 1)) + ``` + + CORRECT: + ```python + desc = tl.make_tensor_descriptor(base=ptr, shape=(M, K), ...) + data = desc.load([row, col]) + ``` diff --git a/kernel-builder/skills/xpu-kernels/references/dtype_optimizations.yaml b/kernel-builder/skills/xpu-kernels/references/dtype_optimizations.yaml new file mode 100644 index 00000000..51146a5c --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/dtype_optimizations.yaml @@ -0,0 +1,112 @@ +# Dtype Optimization Patterns for Intel XPU + +patterns: + - id: dtype_float64_to_float32 + name: "Float64 to Float32 Accumulator" + stage: dtype_fix + description: "Replace float64 accumulators with float32" + rationale: | + float64 throughput is 16-32x slower than float32 on GPUs/XPUs. + This is the single biggest performance killer in many kernels. + Using float64 alone can cap performance at around 2 TFLOPS on Intel XPU. + pattern_before: | + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64) + a = a_fp32.to(tl.float64) + b = b_fp32.to(tl.float64) + pattern_after: | + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + # No need to convert inputs - keep as float32 + expected_speedup: "5-10x" + applies_to: + - gemm + - matmul + - reduction + examples: + - before: | + @triton.jit + def kernel(a_ptr, b_ptr, c_ptr, ...): + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float64) + for k in range(K): + a = tl.load(a_ptr + ...).to(tl.float64) + b = tl.load(b_ptr + ...).to(tl.float64) + acc = tl.dot(a, b, acc) + after: | + @triton.jit + def kernel(a_ptr, b_ptr, c_ptr, ...): + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(K): + a = tl.load(a_ptr + ...) + b = tl.load(b_ptr + ...) + acc = tl.dot(a, b, acc) + + - id: dtype_input_conversion + name: "Remove Unnecessary Type Conversions" + stage: dtype_fix + description: "Avoid converting inputs to higher precision unnecessarily" + rationale: | + Converting float16 inputs to float64 for computation wastes bandwidth + and compute. Use float32 accumulators with float16 inputs for best + performance on modern accelerators. + pattern_before: | + x = tl.load(x_ptr + offsets).to(tl.float64) + result = x * x # float64 computation + pattern_after: | + x = tl.load(x_ptr + offsets) # Keep as float16 + x_fp32 = x.to(tl.float32) # Upcast to float32 only if needed + result = x_fp32 * x_fp32 + expected_speedup: "2-4x" + applies_to: + - elementwise + - reduction + + - id: dtype_prepack_bf16 + name: "Pre-pack weights and inputs to bf16 before kernel launch" + stage: dtype_fix + description: | + Convert weights to bf16 at _pack_weights() time and inputs to bf16 + before kernel launch, instead of loading fp32 and converting in-kernel. + rationale: | + Loading fp32 data and converting to bf16 inside the kernel wastes + memory bandwidth: + - fp32 load: 4 bytes per element from global memory + - In-kernel .to(tl.bfloat16): discards half the loaded data + - Net: 2x wasted bandwidth in the K-loop (the hottest path) + + Pre-packing to bf16 means the kernel loads 2 bytes per element directly. + For a GEMM with K-loop iterations, this halves the memory traffic for + both A and B tiles — often the difference between 2x and 4x+ speedup. + pattern_before: | + # In _pack_weights(): + self.weight_t = w.to(device).t().contiguous() # stored as fp32 + + # In forward(): + x = x.to(device).contiguous() # fp32 input + + # In kernel K-loop: + a = tl.load(a_block_ptr, boundary_check=(0, 1)) # loads 4B per element + a = a.to(tl.bfloat16) # converts to 2B — 2x waste + b = tl.load(b_block_ptr, boundary_check=(0, 1)) + b = b.to(tl.bfloat16) + acc += tl.dot(a, b) + pattern_after: | + # In _pack_weights(): + self.weight_t = w.to(device).t().contiguous().to(torch.bfloat16) # bf16 + + # In forward(): + x = x.to(device, torch.bfloat16).contiguous() # bf16 input + + # In kernel K-loop (no conversion needed): + a = tl.load(a_block_ptr, boundary_check=(0, 1)) # loads 2B directly + b = tl.load(b_block_ptr, boundary_check=(0, 1)) + acc = tl.dot(a, b, acc=acc) # fused accumulate + expected_speedup: "1.5-2x (halves K-loop memory traffic)" + applies_to: + - gemm + - matmul + - attention + - inference + notes: | + - Keep bias and epilogue vectors in fp32 (small, precision-sensitive) + - Combine with grf_mode='256' and tl.dot(a, b, acc=acc) for best results + - Only for inference; training needs fp32 gradients + - Works with both block pointers and tensor descriptors diff --git a/kernel-builder/skills/xpu-kernels/references/fusion_patterns.yaml b/kernel-builder/skills/xpu-kernels/references/fusion_patterns.yaml new file mode 100644 index 00000000..36a6b252 --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/fusion_patterns.yaml @@ -0,0 +1,363 @@ +# Kernel Fusion Patterns for Intel XPU +# Fusion can reduce memory traffic by eliminating intermediate writes/reads, +# but it can also hurt due to GRF/register pressure, reduced occupancy, and +# losing access to vendor-tuned primitives (e.g., GEMM). +# +# Guidance: +# - Fuse bandwidth-bound elementwise chains aggressively (usually a win). +# - Be cautious fusing into GEMM unless you already need a custom GEMM. +# - Do NOT fuse if the intermediate is not materialized, is dead/redundant, +# or if fusion would replace a faster vendor primitive for a tiny epilogue. + +constraints: + - id: fuse_only_if_intermediate_is_materialized + name: "Fuse only if it removes a materialized intermediate that is otherwise written/read" + severity: critical + description: | + Fusion is beneficial only if the unfused baseline materializes an intermediate tensor + (stores it to memory) and then reloads it in a separate kernel. + + If the intermediate is already: + - handled by a fused library epilogue, OR + - kept in registers within a single kernel, OR + - dead / provably redundant for the workload, + then fusion provides little/no benefit and may harm performance. + + - id: do_not_replace_vendor_gemm_for_tiny_epilogue + name: "Do not replace vendor GEMM solely to fuse a tiny epilogue" + severity: critical + description: | + If a vendor GEMM path exists and is faster, do not implement a custom Triton GEMM + only to fuse a small epilogue (e.g., ReLU / min-sub / clamp). + + Prefer: + - vendor GEMM + separate epilogue kernel, OR + - vendor GEMM with native epilogue (if supported), + unless the epilogue is substantial OR you must use a custom GEMM due to layout/math. + + - id: fusion_register_pressure_guard + name: "Avoid fusion that causes GRF/register pressure collapse" + severity: critical + description: | + Fusion increases live values and temporary tensors. If register pressure causes: + - occupancy collapse, OR + - GRF spills, OR + - much lower parallelism, + fusion can be slower than unfused. + + Red flags: + - long activation chains with exp/log/tanh + multiple clamps + - large tiles (e.g., 256x256) + high num_warps + many fused ops + - reductions + heavy elementwise in same kernel on large blocks + + If unsure, fuse only a light epilogue (bias + simple activation) and keep the rest separate. + + - id: fuse_when_bandwidth_bound + name: "Prefer fusion when the baseline is bandwidth-bound" + severity: critical + description: | + Fusion helps most when the unfused path is bandwidth-bound: + - elementwise chains (add/mul/clip/relu/etc.) + - normalization epilogues (affine, bias, scale, clamp) + - pointwise ops following reductions where intermediates are large + + Fusion helps least when: + - compute-bound kernel dominates (large GEMM), + - the intermediate is small, + - or the baseline already uses a fused primitive. + + - id: skip_fusion_when_noop_or_redundant + name: "Skip fusion when the operation is provably no-op or redundant" + severity: critical + description: | + Do not fuse operations that are provably redundant for the workload: + - multiply by 1, add 0 + - clamp/min/max where thresholds never trigger for the data range + - dead outputs not consumed + Extra fused instructions without reducing memory traffic can slow kernels down. + +patterns: + - id: elementwise_chain_fusion + name: "Elementwise Chain Fusion (safe default)" + stage: fusion + description: "Fuse multiple elementwise ops into one kernel to reduce memory traffic" + rationale: | + Elementwise chains are typically bandwidth-bound. Each unfused step reads+writes the full + tensor again. Fusion usually wins unless the chain is very large or involves heavy + transcendentals that blow up register pressure. + pattern_before: | + y = a + b + y = y * scale + y = clamp(y, lo, hi) + y = relu(y) + pattern_after: | + @triton.jit + def fused_elementwise(a_ptr, b_ptr, out_ptr, ..., lo, hi, scale): + x = tl.load(a_ptr + offsets, mask=mask, other=0.0) + y = tl.load(b_ptr + offsets, mask=mask, other=0.0) + x = x + y + x = x * scale + x = tl.maximum(x, lo) + x = tl.minimum(x, hi) + x = tl.maximum(x, 0.0) + tl.store(out_ptr + offsets, x, mask=mask) + expected_speedup: "1.2-3x (bandwidth dependent)" + applies_to: + - all_elementwise + - normalization_epilogues + + - id: gemm_activation_fusion + name: "GEMM + Activation Fusion (conditional)" + stage: fusion + description: "Fuse activation into GEMM ONLY when you already need a custom GEMM" + rationale: | + Fusion removes an intermediate write+read of the GEMM output. This can be beneficial + when you already use a custom GEMM kernel (special layout, custom math, no vendor path). + However, replacing a vendor GEMM with a custom fused GEMM for a tiny activation can be slower. + applies_when: + - uses_custom_gemm: true + - intermediate_materialized: true + - replaces_vendor_gemm: false + rejects_when: + - replaces_vendor_gemm: true + - epilogue_is_tiny: true + - num_pid_m_le_1: true + pattern_before: | + # Kernel 1: GEMM (custom) + @triton.jit + def gemm_kernel(a_ptr, b_ptr, c_ptr, ...): + acc = tl.dot(a, b) + tl.store(c_ptr + offsets, acc, mask=mask) + + # Kernel 2: Activation (separate launch) + @triton.jit + def activation_kernel(c_ptr, out_ptr, ...): + c = tl.load(c_ptr + offsets, mask=mask, other=0.0) + out = tl.maximum(c, 0.0) # ReLU + tl.store(out_ptr + offsets, out, mask=mask) + pattern_after: | + # Single fused kernel (custom GEMM + activation epilogue) + @triton.jit + def fused_gemm_relu(a_ptr, b_ptr, out_ptr, ...): + acc = tl.dot(a, b) + acc = tl.maximum(acc, 0.0) # ReLU + tl.store(out_ptr + offsets, acc, mask=mask) + expected_speedup: "workload-dependent (positive if avoiding a large materialized intermediate; negative if replacing vendor GEMM)" + applies_to: + - custom_gemm_only + - mlp + - transformer + + - id: gemm_bias_activation_fusion + name: "GEMM + Bias + Activation Chain (partial fusion recommended)" + stage: fusion + description: "Fuse bias + light activation into GEMM epilogue; split heavy chains if needed" + rationale: | + Bias add + simple activations are often good epilogues. Very heavy activation chains + (multiple exp/tanh/clamp) can increase register pressure and reduce occupancy. + Prefer partial fusion when needed: + - fuse bias + simple activation + - keep heavy post-processing in a separate kernel + applies_when: + - uses_custom_gemm: true + - intermediate_materialized: true + - replaces_vendor_gemm: false + rejects_when: + - replaces_vendor_gemm: true + - activation_chain_is_heavy: true + pattern_before: | + c = gemm(a, b) + c = c + bias + c = swish(c) + c = clamp(c, -1, 1) + c = tanh(c) + pattern_after: | + @triton.jit + def fused_gemm_bias_activation(..., bias_ptr, out_ptr, ...): + acc = tl.dot(a, b) + bias = tl.load(bias_ptr + offs_n, mask=mask_n, other=0.0) + acc = acc + bias[None, :] + + # Light activation epilogue (example: ReLU or SiLU) + # ReLU: + acc = tl.maximum(acc, 0.0) + + tl.store(out_ptr + offsets, acc, mask=mask) + # If you truly need heavy chains (swish+clamp+tanh), consider splitting: + # - fused GEMM+bias+swish + # - separate clamp+tanh kernel + expected_speedup: "workload-dependent" + applies_to: + - mlp + - gelu + - silu + + - id: reduction_elementwise_fusion + name: "Reduction + Elementwise Fusion (usually good; watch pressure)" + stage: fusion + description: "Fuse reductions with subsequent broadcast elementwise ops" + rationale: | + Reductions followed by broadcast (softmax / layernorm / rmsnorm) often benefit from + fusion because it avoids writing the reduction results (max/sum/mean/var) to memory. + However, very large reductions or overly large blocks can increase register/shared + usage. If performance regresses, consider a 2-stage approach. + pattern_before: | + # Compute max for numerical stability + max_val = tl.max(x, axis=1) + x_stable = x - max_val[:, None] + # Compute exp + exp_x = tl.exp(x_stable) + # Sum for softmax denominator + sum_exp = tl.sum(exp_x, axis=1) + # Normalize + softmax = exp_x / sum_exp[:, None] + pattern_after: | + max_val = tl.max(x, axis=1, keep_dims=True) + x_stable = x - max_val + exp_x = tl.exp(x_stable) + sum_exp = tl.sum(exp_x, axis=1, keep_dims=True) + softmax = exp_x / sum_exp + expected_speedup: "1.2-3x (shape dependent)" + applies_to: + - softmax + - layernorm + - attention + + - id: fusion_skip_when_dead_or_constant + name: "Skip fusion if the fused op is dead / constant / no-op" + stage: fusion + description: "Do not fuse operations that are provably redundant for the given workload" + rationale: | + If an op is effectively a no-op for the given constants/data distribution, + fusion adds extra instructions without reducing memory traffic meaningfully. + + Examples: + - clamp range is so wide it never triggers + - min/max with a constant that is outside observed values + - multiply by 1, add 0 + expected_speedup: "prevents regressions" + applies_to: + - all + + - id: gemm_matrix_add_epilogue + name: "GEMM + matrix add (residual/bias matrix) epilogue fusion" + stage: fusion + description: "Load a full matrix D after the GEMM K-loop and add to accumulator before storing" + rationale: | + A common pattern is C = A @ B + D where D is a full [M, N] matrix (residual + connection, bias broadcast, etc.). Fusing the matrix add into the GEMM epilogue + avoids a separate kernel launch and an extra read+write of the [M, N] output. + + The epilogue is lightweight (one descriptor load + one add), so register pressure + impact is minimal and this fusion is almost always a win when you already have a + custom GEMM kernel. + + Tensor descriptors preferred on XPU; block pointers also supported. + pattern_before: | + # Separate: GEMM then add + c = matmul(a, b) # writes [M, N] to memory + c = c + d # reads [M, N] + [M, N], writes [M, N] + pattern_after: | + # Fused: add D inside GEMM epilogue (tensor descriptor variant) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + off_k = 0 + for _ in range(0, K, BLOCK_SIZE_K): + a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k]) + b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N]) + accumulator += tl.dot(a, b) + off_k += BLOCK_SIZE_K + + # Epilogue: load D tile and add + d_desc = tl.make_tensor_descriptor(base=d_ptr, shape=(M, N), + strides=(stride_dm, stride_dn), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N)) + d = d_desc.load([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N]) + c = accumulator + d + + c_desc = tl.make_tensor_descriptor(base=c_ptr, shape=(M, N), + strides=(stride_cm, stride_cn), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N)) + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], c) + expected_speedup: "eliminates one full [M, N] read+write pass" + applies_to: + - gemm + - residual_add + - bias_matrix + - transformer + notes: | + - D can be the same dtype as the accumulator or a different dtype (Triton auto-casts) + - For batched variants, create d_desc with base=d_ptr + batch_offset + - This pattern extends to any lightweight elementwise epilogue (add, mul, clamp, relu) + applied after the K-loop and before the store + + - id: algebraic_weight_folding + name: "Algebraic Weight Folding (BN/scale/affine into GEMM weights)" + stage: fusion + description: | + Fold per-channel linear transforms into GEMM weights at pack time, + eliminating the epilogue entirely. This is the most impactful fusion + for kernels with BatchNorm inference, per-channel scaling, or affine transforms. + rationale: | + BatchNorm inference is a per-channel linear operation: + BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta + + When GEMM output feeds through scaling then BN: + y = (x @ W^T + bias) * scale + output = BN(y) = gamma * (y - mean) * inv_std + beta + + This entire chain is linear in W and bias. We can pre-compute: + alpha = gamma * scale * inv_std [N] + W_fused[n, :] = alpha[n] * W[n, :] [N, K] + b_fused[n] = alpha[n] * bias[n] - gamma[n] * inv_std[n] * mean[n] + beta[n] [N] + + The kernel becomes a pure GEMM + bias — zero epilogue overhead, zero extra + vector loads, maximum compute efficiency. + applies_when: + - BatchNorm or LayerNorm in inference mode (fixed running stats) + - Per-channel scaling (multiply by [N] vector) + - Any per-channel affine transform: y = a*x + b + rejects_when: + - Training mode (statistics computed per-batch, not fixed) + - Non-linear epilogues (sigmoid, ReLU) — cannot be folded + - Operations depending on the M dimension (row-wise reductions) + pattern_before: | + # _pack_weights: just transpose + self.weight_t = w.to(device).t().contiguous() + self.bias_xpu = b.to(device) + self.scale_xpu = s.to(device) + self.gamma_xpu = gamma.to(device) + # ... 6 more buffers + + # kernel epilogue: 6 vector loads + arithmetic + bias = tl.load(bias_ptr + offs_n, mask=mask_n, other=0.0) + acc = acc + bias[None, :] + scale = tl.load(scale_ptr + offs_n, mask=mask_n, other=1.0) + acc = acc * scale[None, :] + mean = tl.load(mean_ptr + offs_n, mask=mask_n, other=0.0) + var = tl.load(var_ptr + offs_n, mask=mask_n, other=1.0) + gamma = tl.load(gamma_ptr + offs_n, mask=mask_n, other=1.0) + beta = tl.load(beta_ptr + offs_n, mask=mask_n, other=0.0) + inv_std = 1.0 / tl.sqrt(var + eps) + acc = (acc - mean[None, :]) * inv_std[None, :] + acc = acc * gamma[None, :] + beta[None, :] + pattern_after: | + # _pack_weights: fold BN + scale into weights (one-time cost) + inv_std = 1.0 / torch.sqrt(rv + self.eps) + alpha = gamma * s * inv_std + w_fused = alpha.unsqueeze(1) * w + b_fused = alpha * b - gamma * inv_std * rm + beta + self.weight_t = w_fused.to(device).t().contiguous().to(torch.bfloat16) + self.bias_fused = b_fused.to(device) + + # kernel epilogue: just bias add (everything else is folded) + bias = tl.load(bias_ptr + offs_n, mask=mask_n, other=0.0) + acc = acc + bias[None, :] + expected_speedup: "1.5-2x on top of Level 1 optimizations (eliminates epilogue entirely)" + applies_to: + - gemm + - batchnorm + - inference + - mlp + examples: + - file: test_kernels/39_Gemm_Scale_BatchNorm_triton.py + description: "Folds Linear + Scale + BatchNorm into pure GEMM. Level 1: 2.69x → Level 3: 5.28x" \ No newline at end of file diff --git a/kernel-builder/skills/xpu-kernels/references/huggingface-kernels-integration.md b/kernel-builder/skills/xpu-kernels/references/huggingface-kernels-integration.md new file mode 100644 index 00000000..0d39119c --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/huggingface-kernels-integration.md @@ -0,0 +1,350 @@ +# HuggingFace Kernels Integration Guide (XPU) + +Complete guide for using and publishing kernels with the HuggingFace Kernels library (`get_kernel`) on Intel XPU. + +> **Quick Start:** See [huggingface_kernels_example.py](../scripts/huggingface_kernels_example.py) for a minimal working example. + +## Overview + +The [HuggingFace Kernels](https://huggingface.co/docs/kernels/en/index) library enables dynamic loading of pre-compiled kernels from the Hugging Face Hub. This eliminates the need for local compilation and ensures compatibility across different Python, PyTorch, and backend versions. + +**Key Benefits:** +- **No local compilation** — download pre-built binaries +- **Version management** — load specific kernel versions +- **Multi-version support** — multiple versions coexist in one Python process +- **Automatic compatibility** — matches your PyTorch configuration + +**XPU Note:** Not all Hub kernels have XPU builds. Triton-based kernels (e.g., `triton-layer-norm`) are more likely to work on XPU than CUDA C kernels. Always check with `has_kernel()` first. + +## Installation + +```bash +pip install kernels torch numpy +``` + +Requirements: +- PyTorch >= 2.5 (XPU build) +- Intel XPU GPU +- Python 3.8+ + +## Core API + +### get_kernel + +Download and load a kernel from the Hub: + +```python +from kernels import get_kernel + +kernel = get_kernel("kernels-community/triton-layer-norm") + +# With specific version +kernel = get_kernel("kernels-community/triton-layer-norm", version=1) + +# With specific revision +kernel = get_kernel("kernels-community/flash-attn", revision="v2.0.0") +``` + +**Parameters:** + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `repo_id` | str | required | Hub repository (e.g., "kernels-community/activation") | +| `revision` | str | "main" | Branch, tag, or commit hash | +| `version` | int/str | None | Kernel version number (mutually exclusive with `revision`) | + +**Returns:** `ModuleType` — the imported kernel module + +### has_kernel + +Check if a kernel build exists for your environment: + +```python +from kernels import has_kernel + +if has_kernel("kernels-community/triton-layer-norm"): + kernel = get_kernel("kernels-community/triton-layer-norm") +else: + print("No compatible build for this XPU/PyTorch version") +``` + +### get_local_kernel + +Load a kernel from a local path (useful during development): + +```python +from kernels import get_local_kernel + +kernel = get_local_kernel("/path/to/my-kernel") +``` + +### load_kernel & get_locked_kernel + +For reproducible, offline-capable deployments using lockfiles: + +```python +from kernels import load_kernel, get_locked_kernel + +kernel = load_kernel("lockfile.json") +kernel = get_locked_kernel("kernels-community/activation", lockfile="kernel.lock") +``` + +## Usage Examples + +### 1. RMSNorm Kernel from Hub + +**Note:** The actual function name may vary by kernel version. Use `dir(kernel)` to inspect, and check for `rms_norm_fn`, `rms_norm`, or `rmsnorm`. + +```python +import torch +from kernels import get_kernel, has_kernel + +repo_id = "kernels-community/triton-layer-norm" + +if has_kernel(repo_id): + layer_norm = get_kernel(repo_id) + + # Inspect available functions + print([f for f in dir(layer_norm) if not f.startswith('_')]) + # e.g. ['layer_norm', 'layer_norm_fn', 'rms_norm_fn', ...] + + x = torch.randn(2, 1024, 2048, dtype=torch.bfloat16, device="xpu") + weight = torch.ones(2048, dtype=torch.bfloat16, device="xpu") + + # Use the actual function name (rms_norm_fn in current version) + out = layer_norm.rms_norm_fn(x, weight, eps=1e-6) + print(f"Output shape: {out.shape}") +else: + print("No XPU-compatible build available") +``` + +### 2. Integration with Transformers Models + +```python +import torch +from kernels import get_kernel, has_kernel + +repo_id = "kernels-community/triton-layer-norm" + +if has_kernel(repo_id): + rmsnorm_kernel = get_kernel(repo_id) + + def patch_rmsnorm_with_hub_kernel(model): + """Patch model's RMSNorm to use Hub kernel.""" + patched = 0 + for name, module in model.named_modules(): + if 'RMSNorm' in type(module).__name__: + eps = getattr(module, 'variance_epsilon', None) or getattr(module, 'eps', 1e-6) + + def make_forward(mod, epsilon): + def forward(hidden_states): + return rmsnorm_kernel.rms_norm(hidden_states, mod.weight, eps=epsilon) + return forward + + module.forward = make_forward(module, eps) + patched += 1 + return patched +``` + +### 3. Integration with Diffusers Pipelines + +```python +import torch +from diffusers import LTXPipeline +from kernels import get_kernel, has_kernel + +if has_kernel("kernels-community/triton-layer-norm"): + rmsnorm_kernel = get_kernel("kernels-community/triton-layer-norm") + + def patch_rmsnorm(model): + for name, module in model.named_modules(): + if type(module).__name__ == 'RMSNorm': + eps = getattr(module, 'eps', 1e-6) + has_weight = hasattr(module, 'weight') and module.weight is not None + + if has_weight: + def make_forward(mod, epsilon): + def forward(x): + return rmsnorm_kernel.rms_norm(x, mod.weight, eps=epsilon) + return forward + module.forward = make_forward(module, eps) + + pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + pipe.to("xpu") + patch_rmsnorm(pipe.transformer) +``` + +### 4. Benchmark Hub Kernel vs PyTorch + +```python +import time +import torch +from kernels import get_kernel + +kernel = get_kernel("kernels-community/triton-layer-norm") + +sizes = [(2, 1024, 2048), (4, 4096, 4096)] +for shape in sizes: + x = torch.randn(shape, dtype=torch.bfloat16, device="xpu") + w = torch.ones(shape[-1], dtype=torch.bfloat16, device="xpu") + + for _ in range(10): + kernel.rms_norm(x, w, eps=1e-6) + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + torch.xpu.synchronize() + + iters = 100 + start = time.perf_counter() + for _ in range(iters): + kernel.rms_norm(x, w, eps=1e-6) + torch.xpu.synchronize() + hub_ms = (time.perf_counter() - start) / iters * 1000 + + start = time.perf_counter() + for _ in range(iters): + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + torch.xpu.synchronize() + pt_ms = (time.perf_counter() - start) / iters * 1000 + + print(f"Shape {shape}: Hub={hub_ms:.3f}ms, PyTorch={pt_ms:.3f}ms, Speedup={pt_ms/hub_ms:.2f}x") +``` + +## XPU-Specific Notes + +### Kernel Compatibility + +Not all Hub kernels have XPU builds: + +| Kernel Type | XPU Support | Notes | +|-------------|:----------:|-------| +| Triton-based (e.g., `triton-layer-norm`) | Likely | Triton compiles via Intel XPU backend | +| CUDA C-based (e.g., `flash-attn`) | Check | Needs explicit XPU build | +| Custom CUDA ops | Unlikely | CUDA-only unless ported | + +**Always check availability first:** +```python +from kernels import has_kernel + +if has_kernel("kernels-community/triton-layer-norm"): + print("XPU build available") +else: + print("No XPU build — use local Triton kernel instead") +``` + +### Fallback Strategy + +When a Hub kernel is not available for XPU, fall back to the local Triton implementation: + +```python +from kernels import has_kernel, get_kernel + +def get_rmsnorm_function(): + """Get best available RMSNorm implementation.""" + if has_kernel("kernels-community/triton-layer-norm"): + kernel = get_kernel("kernels-community/triton-layer-norm") + return lambda x, w, eps: kernel.rms_norm(x, w, eps=eps) + else: + from your_local_kernels import triton_rmsnorm + return triton_rmsnorm +``` + +### Environment Check + +```python +import torch +print(f"PyTorch: {torch.__version__}") +print(f"XPU available: {torch.xpu.is_available()}") +print(f"GPU: {torch.xpu.get_device_name()}") +``` + +## Publishing Kernels to Hub + +### Triton Kernel Project Structure + +For Triton-based kernels (best XPU compatibility): + +``` +my-triton-kernel/ +├── build.toml +├── kernel_src/ +│ └── rmsnorm.py # Triton kernel source +└── torch-ext/ + ├── torch_binding.cpp + └── my_kernels/ + └── __init__.py +``` + +### build.toml for Triton Kernels + +```toml +[general] +name = "my_triton_kernels" +backends = ["cuda", "xpu"] # Include XPU backend + +[torch] +src = ["torch-ext/torch_binding.cpp"] + +[kernel.rmsnorm] +backend = "triton" +src = ["kernel_src/rmsnorm.py"] +depends = ["torch"] +``` + +### Build and Publish + +```bash +pip install kernel-builder +kernel-builder build + +huggingface-cli repo create your-org/your-kernel --type model +huggingface-cli upload your-org/your-kernel ./dist +``` + +### Others Load It + +```python +from kernels import get_kernel + +rmsnorm = get_kernel("your-org/your-kernel") +``` + +## Available Community Kernels + +Popular kernels from `kernels-community`: + +| Kernel | Description | XPU? | +|--------|-------------|:----:| +| `triton-layer-norm` | LayerNorm, RMSNorm | Likely | +| `activation` | GELU, SiLU, etc. | Check | +| `flash-attn` | Flash Attention 2 | Check | +| `quantization` | INT8/INT4 ops | Check | + +Browse all kernels: https://huggingface.co/kernels-community + +## Caching and Offline Usage + +```python +import os +os.environ["HF_HUB_OFFLINE"] = "1" + +# Will only use cached kernels +kernel = get_kernel("kernels-community/triton-layer-norm") +``` + +## Best Practices + +1. **Always check availability** — `has_kernel()` before `get_kernel()` +2. **Pin versions** — `get_kernel(repo, version=1)` for reproducibility +3. **Have a fallback** — local Triton kernel when Hub build is unavailable +4. **Use lockfiles in production** — `load_kernel("kernel.lock")` +5. **Test on your GPU** — verify correctness after loading + +## See Also + +- [HuggingFace Kernels Documentation](https://huggingface.co/docs/kernels/en/index) +- [HuggingFace Kernels GitHub](https://github.com/huggingface/kernels) +- [Kernel Builder Documentation](https://github.com/huggingface/kernel-builder) +- [Community Kernels](https://huggingface.co/kernels-community) +- [Blog: Learn the Kernel Hub in 5 Minutes](https://huggingface.co/blog/hello-hf-kernels) diff --git a/kernel-builder/skills/xpu-kernels/references/implementation_reference.md b/kernel-builder/skills/xpu-kernels/references/implementation_reference.md new file mode 100644 index 00000000..8adba6be --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/implementation_reference.md @@ -0,0 +1,221 @@ +# Implementation Reference + +Code templates and patterns for Triton kernel development on Intel XPU. + +## Template Selection + +Start with a template that matches your kernel type. The core patterns are shown below: +- Basic GEMM (with tensor descriptors and tile swizzling) +- GEMM with fused epilogue +- Reduction operations + +## Core Implementation Pattern + +Generated kernels must be **self-contained and shareable**. Define all helper functions inline within the kernel file. + +```python +import math +import torch +import torch.nn as nn +import triton +import triton.language as tl + +# ============================================================================ +# Helper functions (inline definitions for self-contained kernel) +# ============================================================================ + +# Constants +kAlpha = tl.constexpr(math.sqrt(2.0 / math.pi)) # For GeLU +kInvLn2 = tl.constexpr(1.4426950408889634) # For exp2-based ops + +@triton.jit +def swizzle_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M): + """Tile swizzling for L2 cache locality""" + grid_m = tl.cdiv(M, BLOCK_SIZE_M) + grid_n = tl.cdiv(N, BLOCK_SIZE_N) + width = GROUP_SIZE_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(GROUP_SIZE_M, grid_m - group_id * GROUP_SIZE_M) + pid_m = group_id * GROUP_SIZE_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + +@triton.autotune( + configs=[ + # Large tiles for square GEMMs + triton.Config( + {'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_SIZE_M': 4, 'grf_mode': '256'}, + num_warps=32, num_stages=2 + ), + triton.Config( + {'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 4, 'grf_mode': '256'}, + num_warps=16, num_stages=3 + ), + # Medium tiles + triton.Config( + {'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_SIZE_M': 4, 'grf_mode': '256'}, + num_warps=8, num_stages=4 + ), + triton.Config( + {'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 4, 'grf_mode': '256'}, + num_warps=16, num_stages=3 + ), + # Skinny-M configs (for M < 256) + triton.Config( + {'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 2, 'grf_mode': '256'}, + num_warps=8, num_stages=4 + ), + triton.Config( + {'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_SIZE_M': 2, 'grf_mode': '256'}, + num_warps=4, num_stages=5 + ), + ], + key=['M', 'N', 'K'], +) +@triton.jit +def kernel( + # Pointers + a_ptr, b_ptr, c_ptr, + # Shapes (as constexpr for better codegen) + M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + # Strides + stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bk: tl.constexpr, stride_bn: tl.constexpr, + stride_cm: tl.constexpr, stride_cn: tl.constexpr, + # Meta-parameters (NO defaults!) + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Optimized GEMM kernel for Intel XPU using tensor descriptors.""" + # Tile swizzling (1D grid) + pid = tl.program_id(0) + pid_m, pid_n = swizzle_tile(pid, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_SIZE_M) + + # Tensor descriptors (preferred on XPU — better codegen than block pointers) + a_desc = tl.make_tensor_descriptor( + base=a_ptr, shape=[M, K], strides=[stride_am, stride_ak], + block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = tl.make_tensor_descriptor( + base=b_ptr, shape=[K, N], strides=[stride_bk, stride_bn], + block_shape=[BLOCK_K, BLOCK_N], + ) + + # Accumulator (fp32 for numerical stability) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + # K-loop + off_m = pid_m * BLOCK_M + off_n = pid_n * BLOCK_N + for off_k in range(0, K, BLOCK_K): + a = a_desc.load([off_m, off_k]) + b = b_desc.load([off_k, off_n]) + a = a.to(tl.bfloat16) + b = b.to(tl.bfloat16) + acc += tl.dot(a, b) + + # Store result + c_desc = tl.make_tensor_descriptor( + base=c_ptr, shape=[M, N], strides=[stride_cm, stride_cn], + block_shape=[BLOCK_M, BLOCK_N], + ) + c_desc.store([off_m, off_n], acc) +``` + +## Model Class Wrapper (ai-bench compatible) + +The Model class uses standard `nn.Module` patterns. ai-bench creates the model via `__init__()` and syncs weights using `copy_model_weights()`. + +```python +class Model(nn.Module): + def __init__(self, input_size, hidden_size, scaling_factor): + super().__init__() + self.input_size = input_size + self.hidden_size = hidden_size + self.scaling_factor = scaling_factor + self.gemm = nn.Linear(input_size, hidden_size) + self._packed = False + + def _pack_weights(self): + """Pack weight transpose once on XPU for fast tl.dot access.""" + device = torch.device("xpu") + w = self.gemm.weight.data.detach() + b = self.gemm.bias.data.detach() + self.weight_t = w.to(device, torch.float16).t().contiguous() + self.bias_xpu = b.to(device, torch.float16).contiguous() + self._packed = True + + def forward(self, x): + device = torch.device("xpu") + x = x.to(device, torch.float16).contiguous() + if not self._packed: + self._pack_weights() + + M, K = x.shape + N = self.weight_t.shape[1] + output = torch.empty((M, N), device=device, dtype=torch.float32) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), + ) + kernel[grid]( + x, self.weight_t, output, + M, N, K, + x.stride(0), x.stride(1), + self.weight_t.stride(0), self.weight_t.stride(1), + output.stride(0), output.stride(1), + ) + return output + +# ============================================================================ +# Benchmark harness interface (must match *_pytorch.py) +# ============================================================================ +batch_size = 1024 +input_size = 8192 +hidden_size = 8192 +scaling_factor = 2.0 + +def get_inputs(): + return [torch.rand(batch_size, input_size)] + +def get_init_inputs(): + return [input_size, hidden_size, scaling_factor] +``` + +## Example: GEMM Transformation + +**Input** (`test_kernels/14_Gemm_Divide_Sum_Scaling_pytorch.py`): +```python +x = torch.matmul(x, self.weight.T) # Gemm +x = x / 2 # Divide +x = torch.sum(x, dim=1, keepdim=True) # Sum +x = x * self.scaling_factor # Scaling +``` + +**Strategy**: +1. Use tensor descriptors for GEMM (preferred on XPU) +2. Fuse divide into GEMM epilogue (light) +3. Keep sum + scaling in separate reduction kernel (avoid serializing over N) + +**Output**: `gemm_kernel` (matmul + divide fused) + `row_sum_kernel` (sum + scaling). + +See `references/examples/gemm_activation_optimized.py` for a similar pattern. + +## File Naming Convention + +Spec YAML files live in `modules/ai-bench/problems/specs/KernelBench/level*/`. +Auto-detection strips suffixes (`_triton`, `_optimized`, `_opt`, `_pytorch`) from filename and searches `level1/`, `level2/`, `level3/`. Override with `--spec` if needed. + +## Activation Helpers + +```python +# exp2-based sigmoid (faster on XPU) +sigmoid(x) = 1 / (1 + exp2(-x * 1.44269504)) + +# tanh via sigmoid +tanh(x) = 2*sigmoid(2x) - 1 +``` + +Factor into reusable `@triton.jit` helpers defined inline in your kernel file. diff --git a/kernel-builder/skills/xpu-kernels/references/kernelbench-classification.md b/kernel-builder/skills/xpu-kernels/references/kernelbench-classification.md new file mode 100644 index 00000000..a0ec5c1c --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/kernelbench-classification.md @@ -0,0 +1,162 @@ +# KernelBench Operator Classification & Skill Mapping + +This document classifies KernelBench operators into categories and maps each to the appropriate kernel skill/pattern. + +## Classification Taxonomy + +### Level 1: Basic Operators (53 operators) + +#### Category A: GEMM / Matrix Multiplication (18 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 1 | Square matrix multiplication | Dense GEMM | Tile Swizzle + Autotune | +| 2 | Standard matrix multiplication | Dense GEMM (M!=N) | Tile Swizzle + Autotune | +| 3 | Batched matrix multiplication | BMM | Batch-indexed GEMM | +| 4 | Matrix-vector multiplication | MatVec | 1D reduction pattern | +| 5 | Matrix-scalar multiplication | Elementwise | Scale kernel | +| 6 | Matmul with large K | Large-K GEMM | K-dimension blocking | +| 7 | Matmul with small K | Small-K GEMM | Fewer K-iterations | +| 8 | Matmul with irregular shapes | Non-square GEMM | Mask handling | +| 9 | Tall-skinny matmul | Tall-skinny GEMM | Tile shape tuning | +| 10 | 3D tensor-matrix mul | Batched GEMM | Reshape + GEMM | +| 11 | 4D tensor-matrix mul | Batched GEMM | Einsum decomposition | +| 12 | Diagonal matrix mul | Special GEMM | Elementwise pattern | +| 13 | Symmetric matrices | Dense GEMM | Standard GEMM | +| 14 | Upper triangular mul | Masked GEMM | Triangle mask | +| 15 | Lower triangular mul | Masked GEMM | Triangle mask | +| 16 | Transposed A | Transposed GEMM | Stride adjustment | +| 17 | Transposed B | Transposed GEMM | Stride adjustment | +| 18 | Both transposed | Transposed GEMM | Stride adjustment | + +**Key Pattern**: Template 5 (GEMM with Tile Swizzle) +**Critical Optimization**: Tile swizzle + L2 cache grouping + tensor descriptors + +#### Category B: Elementwise / Activation Functions (14 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 19 | ReLU | Branching | `tl.where(x > 0, x, 0)` | +| 20 | LeakyReLU | Branching | `tl.where(x > 0, x, alpha*x)` | +| 21 | Sigmoid | Transcendental | `1/(1+exp(-x))` | +| 22 | Tanh | Transcendental | `(exp(2x)-1)/(exp(2x)+1)` | +| 23 | Softmax | Row reduction | Online softmax | +| 24 | LogSoftmax | Row reduction | Online softmax + log | +| 25 | Swish/SiLU | Transcendental | `x * sigmoid(x)` | +| 26 | GELU | Transcendental | `0.5*x*(1+erf(x/sqrt(2)))` | +| 27 | SELU | Branching + exp | `scale * where(x>0, x, alpha*(exp(x)-1))` | +| 28 | HardSigmoid | Clamp | `clamp((x+3)/6, 0, 1)` | +| 29 | Softplus | Transcendental | `log(1+exp(x))` | +| 30 | Softsign | Division | `x/(1+abs(x))` | +| 31 | ELU | Branching + exp | `where(x>0, x, alpha*(exp(x)-1))` | +| 32 | HardTanh | Clamp | `clamp(x, -1, 1)` | + +**Key Pattern**: Template 1 (Elementwise) +**Critical Optimization**: Large BLOCK_SIZE (4096-16384), FP32 compute + +#### Category C: Normalization (8 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 33 | BatchNorm | Multi-dim reduction | Welford algorithm | +| 34 | InstanceNorm | Per-instance reduction | Per-sample norm | +| 35 | GroupNorm | Group reduction | Grouped channels | +| 36 | RMSNorm | Row reduction | `x * rsqrt(mean(x^2) + eps)` | +| 37 | FrobeniusNorm | Full reduction | `sqrt(sum(x^2))` | +| 38 | L1 Norm | Full reduction | `sum(abs(x))` | +| 39 | L2 Norm | Full reduction | `sqrt(sum(x^2))` | +| 40 | LayerNorm | Row reduction | `(x-mean)/std * w + b` | + +**Key Pattern**: Template 3 (Row-wise Reduction) +**Critical Optimization**: FP32 accumulation, proper reduction + +#### Category D: Pooling (6 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 41 | Max Pooling 1D | Sliding window | Max reduction | +| 42 | Max Pooling 2D | 2D window | 2D index mapping | +| 43 | Max Pooling 3D | 3D window | Program_id flattening | +| 44 | Average Pooling 1D | Sliding window | Sum + divide | +| 45 | Average Pooling 2D | 2D window | 2D index mapping | +| 46 | Average Pooling 3D | 3D window | Program_id flattening | + +**Key Challenge**: 3D grid mapping with Triton's program_id limits + +#### Category E: Reduction (7 operators) + +| ID | Name | Sub-type | Key Skill | +|----|------|----------|-----------| +| 47 | Sum reduction | Sum | `tl.sum()` | +| 48 | Mean reduction | Mean | `tl.sum() / count` | +| 49 | Max reduction | Max | `tl.max()` | +| 50 | Min reduction | Min | `tl.min()` | +| 51 | Argmax | Index + max | Two-pass or manual | +| 52 | Argmin | Index + min | Two-pass or manual | +| 53 | Min (duplicate) | Min | `tl.min()` | + +**Key Pattern**: Template 5 (Dimension Reduction) +**Key Challenge**: Argmax/Argmin require manual implementation + +### Level 2: Fused Operators (20+ operators) + +Combine multiple operations into single kernels. + +| Category | Examples | Strategy | +|----------|---------|----------| +| GEMM + Activation | Gemm_ReLU, Gemm_GELU | Fuse activation into GEMM epilogue | +| GEMM + Norm | Gemm_BatchNorm, Gemm_GroupNorm | Two-phase kernel | +| GEMM + Scale | Gemm_Scale, Gemm_Divide | Fuse into GEMM store | +| Multi-op fusion | Matmul_Sum_Max_AvgPool | Sequential fusion | + +**Key Pattern**: Template 6 (Fused GEMM + Activation) + +### Level 3-4: Network Models / Transformers + +Full models requiring multiple kernel types. Decompose into Level 1 operators. + +### Level 6-7: Advanced / Expert + +| Operator | Type | Strategy | +|----------|------|----------| +| MinGPTNewGelu | Fused activation | GELU approximation kernel | +| ScaledDotProductAttention | Attention | Flash Attention pattern | +| GELU_And_Mul | Fused activation | `gelu(x) * y` | +| MoE_TopK_Softmax | MoE routing | Specialized kernel | +| Gemm_A8W8_Blockwise | Quantized GEMM | INT8 with block scaling | + +## Category → Skill Mapping + +| Category | Skill File | Priority | +|----------|-----------|----------| +| **GEMM** | `gemm-skill.md` (planned) | P0 - Most impactful | +| **Elementwise** | `elementwise-skill.md` (planned) | P0 - Most common | +| **Normalization** | `normalization-skill.md` (planned) | P1 - Frequently used | +| **Reduction** | `reduction-skill.md` (planned) | P1 - Common pattern | +| **Softmax** | `softmax-skill.md` (planned) | P1 - Critical for attention | +| **Pooling** | `pooling-skill.md` (planned) | P2 - Moderate complexity | +| **Attention** | `attention-skill.md` (planned) | P2 - High complexity | +| **Fused** | `fused-skill.md` (planned) | P2 - Combination patterns | + +## Performance Expectations by Category + +Based on kernel-agent test results: + +| Category | Achievable Speedup | Difficulty | Notes | +|----------|-------------------|------------|-------| +| Elementwise | 1.0-3.0x | Low | Large blocks, memory-bound | +| Reduction (sum/mean) | 1.5-5.0x | Medium | Good parallelism | +| Pooling | 1.5-5.0x | Medium | Grid mapping challenge | +| LayerNorm/RMSNorm | 1.5-2.0x | Medium | Row-wise reduction | +| Dense GEMM | 0.8-1.2x | High | Tile swizzle critical | +| Batched GEMM | 0.6-0.9x | High | Memory bandwidth limited | +| BatchNorm | <0.1x | Very High | HIP sync issues | +| Argmax/Argmin | FAIL | Very High | Triton API limitation | +| Fused operators | 0.3-1.0x | Very High | Correctness challenges | + +## Recommended Skill Development Order + +1. **Phase 1 (Quick wins)**: Elementwise activations, Sum/Mean reduction +2. **Phase 2 (Core)**: GEMM with tile swizzle, LayerNorm/RMSNorm +3. **Phase 3 (Advanced)**: Softmax, Pooling, Attention +4. **Phase 4 (Expert)**: Fused operators, BatchNorm, Quantized GEMM diff --git a/kernel-builder/skills/xpu-kernels/references/memory_patterns.yaml b/kernel-builder/skills/xpu-kernels/references/memory_patterns.yaml new file mode 100644 index 00000000..6228121c --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/memory_patterns.yaml @@ -0,0 +1,452 @@ +# Intel XPU Memory Patterns Optimizations +# These patterns focus on memory access, layout, and avoiding hidden sync/overhead. +# They are intended to be used alongside your XPU-specific compute patterns. + +constraints: + - id: no_device_to_host_scalar_sync + name: "Do NOT force device->host scalar sync in hot path" + severity: critical + description: | + Avoid .item(), float(tensor), int(tensor), printing device tensors, or any device->host scalar + extraction in forward()/kernel wrapper hot paths. This forces synchronization and kills perf. + + WRONG: + ```python + c = float(constant_tensor.item()) # syncs XPU -> host + ``` + + CORRECT: + - Keep constants as Python floats / CPU tensors + - Pass scalar to kernel as an argument + + - id: prefer_contiguous_inputs + name: "Prefer contiguous tensors for Triton kernels" + severity: critical + description: | + Triton kernels generally assume or strongly benefit from contiguous memory. + If inputs are strided/non-contiguous, do an explicit .contiguous() in wrapper + (outside the timed region if possible). + + WRONG: + ```python + # Non-contiguous input passed into kernel + kernel[grid](x_transposed, ...) + ``` + + CORRECT: + ```python + if not x.is_contiguous(): + x = x.contiguous() + kernel[grid](x, ...) + ``` + + - id: block_ptr_boundary_check_tuple + name: "Block pointer boundary_check must be tuple of ints" + severity: critical + description: | + tl.load(block_ptr) uses boundary_check=(dim0, dim1) where values are dimension indices. + Use (0,1) not booleans. + + CORRECT: + ```python + tl.load(ptr, boundary_check=(0, 1)) + ``` + + - id: no_tl_multiple_of_on_python_scalars + name: "Do NOT call tl.multiple_of / tl.max_contiguous on Python scalars or constexpr-like values" + severity: critical + description: | + tl.multiple_of / tl.max_contiguous are intended for Triton IR values (tensors/expressions), + not Python integers/constexpr-like scalars such as stride arguments. + + WRONG (can error like: 'constexpr' object has no attribute 'shape'): + ```python + tl.multiple_of(stride_xk, 1) + ``` + + CORRECT (apply to tensor expressions): + ```python + offs_n = col_start + tl.arange(0, BLOCK_N) + tl.multiple_of(offs_n, 8) # only if you KNOW it is aligned/multiple + ``` + +patterns: + - id: mem_block_pointers + name: "Block pointers for structured tiles (fallback — prefer tensor descriptors on XPU)" + stage: memory_access + description: "Use tl.make_block_ptr + tl.load(boundary_check=...) for tiled loads/stores" + rationale: | + NOTE: On Intel XPU, tensor descriptors (tl.make_tensor_descriptor) are preferred + over block pointers — they produce better address generation codegen. See + xpu_optimizations.yaml (xpu_tensor_descriptor_note, xpu_descriptor_gemm_pattern). + + Block pointers improve addressing and bounds handling over manual pointer math. + Use them as a fallback when tensor descriptors are not suitable (e.g., legacy code). + pattern_before: | + offs_m = pid_m * BM + tl.arange(0, BM) + offs_k = k0 + tl.arange(0, BK) + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk + x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0) + pattern_after: | + x_bp = tl.make_block_ptr( + base=x_ptr, + shape=(M, K), + strides=(stride_xm, stride_xk), + offsets=(pid_m * BM, k0), + block_shape=(BM, BK), + order=(1, 0), + ) + x = tl.load(x_bp, boundary_check=(0, 1)) + expected_speedup: "workload-dependent (often positive when indexing is heavy)" + applies_to: [gemm, conv, attention, all_memory_bound] + examples: + - before: | + k0 += BK + after: | + x_bp = tl.advance(x_bp, (0, BK)) + + - id: mem_skip_boundary_check_when_divisible + name: "Specialize away boundary checks when shapes are divisible" + stage: memory_access + description: "Avoid boundary_check/masks when M/N/K are known multiples of tile sizes" + rationale: | + Masks and boundary checks add overhead. If your benchmark/problem guarantees divisibility + (or you can specialize per shape), you can remove masks and use unconditional loads/stores. + pattern_before: | + x = tl.load(x_bp, boundary_check=(0, 1)) + tl.store(o_bp, acc, boundary_check=(0, 1)) + pattern_after: | + # If M % BM == 0 and K % BK == 0 (specialized case): + x = tl.load(x_bp) + tl.store(o_bp, acc) + expected_speedup: "1.05-1.20x" + applies_to: [gemm, large_fixed_shapes] + notes: | + - Only safe when you can guarantee divisibility or guard with a specialized kernel variant. + + - id: mem_use_static_range_for_k_loop + name: "Use tl.static_range for K loops (when K is known / bounded)" + stage: memory_access + description: "Enable better unrolling/pipelining by using tl.static_range" + rationale: | + tl.static_range can help Triton generate better scheduled loops and reduce loop overhead. + Particularly useful in GEMM-like reductions and fixed-shape workloads. + pattern_before: | + for k in range(0, K, BK): + ... + pattern_after: | + for k in tl.static_range(0, K, BK): + ... + expected_speedup: "1.05-1.15x" + applies_to: [gemm, attention] + notes: | + Use only when K is a compile-time constant or you launch specialized kernels by K. + + - id: mem_pointer_arithmetic_coalescing + name: "Prefer contiguous (coalesced) access in the innermost dimension" + stage: memory_access + description: "Make the fastest-changing index map to contiguous memory" + rationale: | + Coalesced loads/stores are critical. In most layouts, the last dimension is contiguous. + Make tl.arange(...) along the contiguous stride dimension when possible. + pattern_before: | + # Strided / scattered pattern (example) + offs = base + tl.arange(0, BN) * stride_n + x = tl.load(x_ptr + offs) + pattern_after: | + # Coalesced: make arange hit contiguous dimension + offs_n = col_start + tl.arange(0, BN) + x = tl.load(x_ptr + offs_n) + expected_speedup: "workload-dependent (can be large if fixing uncoalesced access)" + applies_to: [all_memory_bound] + + - id: mem_cache_modifiers_optional + name: "Optional cache modifiers (backend-dependent)" + stage: memory_access + description: "Use cache_modifier/eviction_policy carefully; keep it optional" + rationale: | + cache modifiers can help (or hurt) depending on backend and access pattern. + Treat them as an autotune dimension rather than a hard rule. + pattern_before: | + x = tl.load(ptr) + pattern_after: | + # Example: try these as tunables if supported by your Triton backend + x = tl.load(ptr, cache_modifier=".ca") + # or: + x = tl.load(ptr, cache_modifier=".cg") + expected_speedup: "workload-dependent" + applies_to: [streaming_reads, bandwidth_bound] + notes: | + - Do not assume Intel XPU backend interprets modifiers the same as CUDA. + - Use only if you confirm correctness and performance. + + - id: mem_alignment_hints_on_offsets + name: "Alignment hints on OFFSET tensors (not strides)" + stage: memory_access + description: "Use tl.multiple_of / tl.max_contiguous on tensor offset expressions" + rationale: | + Alignment hints can improve vectorization/coalescing when you truly have aligned access. + Apply hints to offset tensors (e.g., tl.arange expressions) and only when guaranteed. + pattern_before: | + offs_n = col_start + tl.arange(0, BN) + x = tl.load(x_ptr + offs_n) + pattern_after: | + offs_n = col_start + tl.arange(0, BN) + tl.max_contiguous(offs_n, BN) + # Only if you KNOW alignment/multiple is valid: + # tl.multiple_of(offs_n, 8) + x = tl.load(x_ptr + offs_n) + expected_speedup: "small but sometimes measurable" + applies_to: [gemm, bandwidth_bound] + notes: | + Do NOT apply tl.multiple_of blindly—incorrect assumptions can lead to wrong-code. + + - id: mem_avoid_redundant_reads + name: "Load once, reuse in registers" + stage: memory_access + description: "Avoid re-loading the same values inside inner loops" + rationale: | + Redundant tl.load calls inside reduction loops can dominate memory traffic. + If values are reused (e.g., bias, scale), load once per tile and broadcast. + pattern_before: | + for k in ...: + b = tl.load(b_ptr + offs_n) # re-loaded each iteration (waste) + acc += ... + pattern_after: | + b = tl.load(b_ptr + offs_n) # load once + for k in ...: + acc += ... + acc += b[None, :] + expected_speedup: "1.05-1.30x (depends on redundancy)" + applies_to: [gemm_epilogues, fused_pointwise] + + - id: mem_wrapper_one_time_device_moves + name: "One-time device moves in wrapper (avoid per-forward transfers)" + stage: memory_access + description: "Move parameters to XPU once, not every forward" + rationale: | + Per-forward .to('xpu') or .data = .to('xpu') costs time and can interfere with timing. + Use a one-time guard in the module. + pattern_before: | + def forward(self, x): + if self.weight.device.type != "xpu": + self.weight.data = self.weight.data.to("xpu") + ... + pattern_after: | + def __init__(...): + self._moved_to_xpu = False + + def forward(self, x): + if not self._moved_to_xpu: + self.weight.data = self.weight.data.to("xpu") + self._moved_to_xpu = True + expected_speedup: "huge for microbenchmarks; correctness/measurement improvement" + applies_to: [all] + notes: | + Ideally move model+inputs to XPU outside the timed region in the harness. + + - id: mem_layout_transform_prepack + name: "Pre-pack / pre-transform weights for better access" + stage: memory_access + description: "Change weight layout to match access pattern (when allowed)" + rationale: | + If your kernel accesses W as W^T (K,N), storing W in a layout that makes K-contiguous + for the inner loop can improve coalescing/cache behavior. This is a higher-level change. + pattern_before: | + # W stored as [N, K], kernel reads as [K, N] view + w_ptr shape=(K, N), strides=(stride_wk, stride_wn) + pattern_after: | + # Prepack once (outside timed loop), store as [K, N] or blocked layout if acceptable + # so inner K loads are contiguous in memory for each program. + # (Exact layout depends on your kernel and framework constraints.) + expected_speedup: "workload-dependent (can be large for bandwidth-limited kernels)" + applies_to: [gemm, attention] + notes: | + - Only if you control weight storage and can amortize the prepack cost. + - Ensure KernelBench correctness expectations still hold. + + - id: reduce_liveness_sink_load_and_prefetch + name: "Reduce variable liveness: prefetch early, load late (sink load closer to dot)" + stage: memory_access + description: "Replace long-lived live-in operand loads with prefetch + per-iteration load near dot/use" + rationale: | + On Intel XPU, keeping a large operand tile live across a long loop can reserve many registers, + increasing GRF pressure and causing spills. A better approach is: + 1) prefetch the data earlier to pull it into L1, then + 2) load the operand into registers only right before it is used (inside the loop), + reducing liveness and helping the register allocator. + + This is especially relevant when a tensor is loaded outside a loop and used repeatedly + inside the loop (classic FlashAttention Q tile scenario). + pattern_before: | + # Load once, keep live across loop (long liveness) + q = tl.load(q_ptrs) # q stays live for the whole loop + for k0 in range(0, K, BLOCK_K): + k = tl.load(k_ptrs) + acc += tl.dot(q, k) # q used repeatedly + ... + pattern_after: | + # Prefetch outside loop (warm cache) + # (pseudo; actual prefetch API may differ by backend/compiler lowering) + tl.prefetch(q_ptrs) # bring q closer (L1) without allocating regs long-term + + for k0 in range(0, K, BLOCK_K): + # Load right before use (short liveness) + q = tl.load(q_ptrs) # q live only for a short window + k = tl.load(k_ptrs) + acc += tl.dot(q, k) + ... + expected_speedup: "1.05-1.15x (selective; can regress if cache misses dominate)" + applies_to: + - attention + - flashattention + - dot_in_loop + - reduction_in_loop + + - id: reduce_liveness_duplicate_load_for_multi_use + name: "Reduce liveness when operand has multiple uses (duplicate loads near each use)" + stage: memory_access + description: "If a large tensor has multiple distant uses, reload near each use to shorten live ranges" + rationale: | + If a large tensor is used in multiple regions far apart (e.g., two loops or two stages), + keeping it live across both regions can reserve registers too long. On XPU, reloading + from L1 can be cheaper than spilling registers to memory. + pattern_before: | + q = tl.load(q_ptrs) # q stays live a long time + for ...: + acc += tl.dot(q, k) + ... + for ...: + acc2 += tl.dot(q, k2) # second use far away keeps q live across both loops + ... + pattern_after: | + tl.prefetch(q_ptrs) + for ...: + q = tl.load(q_ptrs) # short liveness + acc += tl.dot(q, k) + ... + for ...: + q = tl.load(q_ptrs) # reload for second stage + acc2 += tl.dot(q, k2) + ... + expected_speedup: "0.95-1.10x (only when it avoids spills; otherwise may regress)" + applies_to: + - attention + - multi_stage_kernels + - dot_in_loop + + - id: avoid_long_lived_large_tiles_across_control_flow + name: "Avoid long-lived tiles across control-flow splits (two-loop / masked attention patterns)" + stage: memory_access + description: "Prefer loading tiles inside each loop when control-flow creates long liveness (off-band/on-band)" + rationale: | + Control-flow splits (e.g., two loops for masked/unmasked regions) often extend the + lifetime of large tiles (Q) across both loops. On XPU, that increases GRF pressure. + Loading inside each loop reduces the live range and can prevent spills. + pattern_before: | + q = tl.load(Q_block_ptr) # q live across both loops + for start_n in range(lo, hi, BLOCK_N): + ... qk += tl.dot(q, k) ... + for start_n in range(lo, hi, BLOCK_N): + ... qk += tl.dot(q, k) ... # second loop extends q lifetime + pattern_after: | + tl.prefetch(Q_block_ptr) + for start_n in range(lo, hi, BLOCK_N): + q = tl.load(Q_block_ptr) # load per-iteration or per-loop (shorter liveness) + ... qk += tl.dot(q, k) ... + for start_n in range(lo, hi, BLOCK_N): + q = tl.load(Q_block_ptr) # reload for second loop + ... qk += tl.dot(q, k) ... + expected_speedup: "1.05-1.10x (most likely on masked/causal variants)" + applies_to: + - attention + - flashattention + - causal_mask + + - id: trade_bandwidth_for_regs_when_spilling + name: "Prefer extra loads over GRF spills (XPU-first rule)" + stage: memory_access + description: "When register pressure is high, allow re-loads (from cache) to reduce spills" + rationale: | + When GRF pressure is high, the real cost is often spilling to memory. + Reloading a tile (especially if it hits in L1) can be cheaper than spilling + and enables other compiler optimizations (unrolling/scheduling). + pattern_before: | + # Keep many intermediates live to avoid re-loads + a = tl.load(...) + b = tl.load(...) + c = tl.load(...) + # long epilogue chain keeps many values live + y = f(g(h(a,b,c,...))) + pattern_after: | + # Shorten live ranges: recompute/reload where cheap + tl.prefetch(...) + a = tl.load(...) # load near use + y = f(a) + # if needed later, reload instead of keeping live + a2 = tl.load(...) + z = g(a2) + expected_speedup: "workload-dependent; key benefit is avoiding catastrophic regressions" + applies_to: + - flashattention + - long_epilogues + - register_heavy_kernels + + - id: do_not_apply_if_cache_eviction_likely + name: "Do NOT sink loads if cache eviction is likely (small L1 / high conflict risk)" + stage: memory_access + description: "Guard against turning cheap L1 loads into expensive global loads" + rationale: | + The sink-load approach assumes the data remains in L1 after prefetch. + If cache conflicts/evictions are likely, sinking loads into the loop can increase + global memory traffic and regress performance. + pattern_before: | + # Always sink loads + tl.prefetch(ptrs) + for ...: + x = tl.load(ptrs) # becomes expensive if evicted + ... + pattern_after: | + # Keep original load placement OR stage to a closer explicitly-managed buffer (if available) + x = tl.load(ptrs) # load once if it avoids repeated global misses + for ...: + use(x) + expected_speedup: "prevents regressions" + applies_to: + - cache_sensitive + - small_cache_targets + - very_large_working_sets + + - id: mem_atomic_relaxed_for_accumulation + name: "Use sem='relaxed' for commutative atomic accumulation" + stage: memory_access + description: "Relaxed semantics for atomic_add when ordering doesn't matter" + rationale: | + When multiple programs accumulate partial sums into the same output + (e.g., Stream K partial tiles), the order of additions doesn't matter. + Using sem='relaxed' avoids unnecessary memory ordering overhead. + pattern_after: | + tl.atomic_add(c_ptr_, acc, mask=mask, sem='relaxed') + expected_speedup: "avoids synchronization overhead vs stricter semantics" + applies_to: [stream_k, split_accumulation, reduction] + + - id: mem_atomic_fallback_from_descriptors + name: "Fall back from descriptors/block pointers to manual pointers for atomic stores" + stage: memory_access + description: "Use manual pointer arithmetic when atomic operations are needed" + rationale: | + Tensor descriptors (desc.store) and block pointer stores do not support + atomic operations. When partial-tile accumulation requires tl.atomic_add, + you must compute pointers manually. This is the standard pattern for + Stream K partial tiles. + pattern_before: | + # Cannot do this: + c_desc.atomic_add(...) # NOT SUPPORTED + pattern_after: | + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + c_ptr_ = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.atomic_add(c_ptr_, acc, mask=mask, sem='relaxed') + expected_speedup: "N/A (required fallback when atomics are needed)" + applies_to: [stream_k, partial_tiles, split_accumulation] \ No newline at end of file diff --git a/kernel-builder/skills/xpu-kernels/references/optimization_levels.yaml b/kernel-builder/skills/xpu-kernels/references/optimization_levels.yaml new file mode 100644 index 00000000..72a0990c --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/optimization_levels.yaml @@ -0,0 +1,161 @@ +# Optimization Levels for Iterative Kernel Development +# +# Use this framework when optimizing a kernel. Don't stop at Level 1 — +# check the "try harder" decision tree to decide whether deeper optimization +# is warranted. Most production kernels should reach at least Level 2. +# +# Case study: kernel #39 (Gemm_Scale_BatchNorm) +# Level 1 alone: 2.69x speedup (tensor descriptors, tile swizzling, fused epilogue) +# Level 2 + 3: 5.28x speedup (bf16 pre-pack, BN folded into weights) +# See: test_kernels/39_Gemm_Scale_BatchNorm_triton.py + +levels: + - id: level_1_baseline_xpu + name: "Level 1: Baseline XPU Optimizations" + description: | + Get the kernel working with standard XPU patterns. This is the minimum + viable optimization — necessary but rarely sufficient for production. + checklist: + - Tensor descriptors (preferred on XPU) or block pointers (not manual pointer arithmetic) + - bf16/fp16 dot inputs, fp32 accumulator + - 1D grid with tile swizzling (GROUP_SIZE_M) for GEMM kernels + - "@triton.autotune with XPU-optimized configs" + - Pre-packed weight transpose cached in _pack_weights() + - Fused light epilogues (bias, simple activations) + - Correct boundary_check=(0, 1) for block pointers (tensor descriptors handle boundaries internally); no default block params + typical_speedup: "1.5-3x vs PyTorch" + when_done: "Move to Level 2 — there is almost always more to gain." + + - id: level_2_bandwidth_reduction + name: "Level 2: Bandwidth Reduction" + description: | + Reduce memory bandwidth consumption — the most common bottleneck after + Level 1. Loading fp32 data and converting in-kernel wastes 2x bandwidth. + checklist: + - "Pre-pack weights to bf16 at pack time (not in-kernel conversion)" + - "Pre-convert inputs to bf16 before kernel launch: x.to(device, torch.bfloat16)" + - "Use grf_mode='256' for large register file (XPU-specific)" + - "Use tl.dot(a, b, acc=acc) fused accumulate pattern" + - "Keep bias in fp32 (small, precision-sensitive)" + code_example: | + # BEFORE (Level 1 — wastes 2x bandwidth): + self.weight_t = w.to(device).t().contiguous() # fp32 on device + # kernel: a = tl.load(ptr).to(tl.bfloat16) # loads 4B, uses 2B + + # AFTER (Level 2 — halves K-loop memory traffic): + self.weight_t = w.to(device).t().contiguous().to(torch.bfloat16) # bf16 + # kernel: a = tl.load(ptr) # loads 2B directly + typical_speedup: "2-4x vs PyTorch" + when_done: | + Check Level 3 — if the epilogue contains linear transforms (BN, scale, + affine), they can often be eliminated entirely. + + - id: level_3_algebraic_fusion + name: "Level 3: Algebraic Fusion" + description: | + Eliminate epilogue work by algebraically folding linear transforms into + GEMM weights at pack time. This is the highest-impact optimization for + kernels with BatchNorm, per-channel scaling, or affine transforms. + + Key insight: BatchNorm inference is a per-channel linear operation: + BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta + + When the GEMM output feeds into BN (possibly through scaling), the entire + chain can be reduced to a single GEMM + bias: + alpha_n = gamma_n * scale_n / sqrt(var_n + eps) + W_fused[n,:] = alpha_n * W[n,:] + b_fused[n] = alpha_n * bias_n - gamma_n * mean_n / sqrt(var_n + eps) + beta_n + checklist: + - Identify linear per-channel transforms in the epilogue + (BatchNorm inference, scaling, affine, LayerNorm with fixed stats) + - Derive fused weight and bias formulas algebraically (on paper first) + - Implement folding in _pack_weights() (one-time cost, not in hot path) + - Verify numerical equivalence (bf16 folding introduces small diffs) + - The kernel becomes a pure GEMM + bias — maximum compute efficiency + applies_when: + - "BatchNorm in inference mode (running_mean/var are fixed)" + - "Per-channel scaling (multiply by [N] vector)" + - "Any per-channel affine transform (y = a*x + b where a,b are [N] vectors)" + does_not_apply_when: + - "BatchNorm in training mode (statistics are computed per-batch)" + - "Non-linear epilogues (sigmoid, ReLU, etc.) — these cannot be folded" + - "Operations that depend on the M dimension (row-wise reductions)" + typical_speedup: "3-6x vs PyTorch" + when_done: | + For most kernels, this is sufficient. Move to Level 4 only for + critical-path kernels where profiling shows remaining bottlenecks. + example: + file: test_kernels/39_Gemm_Scale_BatchNorm_triton.py + description: | + Kernel #39 folds Linear + Scale + BatchNorm into a pure GEMM. + Before (Level 1): 6 vector loads + arithmetic in epilogue → 2.69x + After (Level 3): zero epilogue, pure GEMM + bias → 5.28x + + - id: level_4_expert + name: "Level 4: Expert Techniques" + description: | + Advanced patterns for squeezing the last 10-30% of performance. + Only pursue these for critical-path kernels after profiling confirms + the bottleneck. These techniques increase code complexity significantly. + checklist: + - Stream K decomposition (for non-square or non-tile-divisible GEMMs) + - Persistent kernels (fixed program count, iterate over tiles) + - Hardware capability queries (gpu_subslice_count for program count) + - Warp size sweeping (16 vs 32 on XPU) + - Shape-specific specialization (skip boundary checks when M % BLOCK_M == 0) + - Atomic partial tile accumulation + - Adaptive grid layouts (small vs large workloads) + typical_speedup: "5-10x+ vs PyTorch (shape-dependent)" + when_done: "Benchmark shows no further improvement — try a fundamentally different approach." + reference_files: + - kb/persistent_kernel_patterns.yaml + - kb/examples/stream_k_gemm_descriptors.py + + +# ============================================================================ +# "Try Harder" Decision Tree +# ============================================================================ +# Use this after completing Level 1 to decide whether to keep optimizing. + +try_harder: + description: | + After each optimization level, measure speedup and consult this tree. + The biggest gains typically come from Level 2 (bandwidth) and Level 3 + (algebraic fusion). Level 4 has diminishing returns for most workloads. + + decisions: + - condition: "Speedup < 2x after Level 1" + diagnosis: "Likely bandwidth-bound — loading fp32 and converting wastes 2x BW" + action: "Apply Level 2: pre-pack to bf16, add grf_mode='large'. Run 'python skills/xpu_profiler.py' to confirm bandwidth bottleneck (look for high XVE Stalled %)." + priority: high + + - condition: "Speedup 2-3x after Level 2" + diagnosis: "Epilogue may be adding unnecessary work" + action: | + Check: does the epilogue contain only linear per-channel transforms? + If yes → Apply Level 3: fold into weights algebraically. + If no (non-linear ops like sigmoid, tanh) → the epilogue is already minimal. + Consider Level 4 techniques or accept current speedup. + priority: high + + - condition: "Speedup 3-5x after Level 3" + diagnosis: "Good performance — likely near hardware limits for this shape" + action: | + Run 'python skills/xpu_profiler.py' to identify remaining bottleneck. + If compute-bound (XVE Active > Stalled) → try larger tiles or Stream K. + If memory-bound (XVE Stalled > Active) → check for unnecessary copies or layout transforms. + Keep going — try a different approach if current strategy has plateaued. + priority: medium + + - condition: "Speedup > 5x" + diagnosis: "Excellent — diminishing returns ahead" + action: "Stop unless this kernel is on the critical path. Focus effort elsewhere." + priority: low + + - condition: "Speedup < 1x (slower than PyTorch)" + diagnosis: "Something is fundamentally wrong" + action: | + Check for: fp64 usage, N-loop serialization, missing tile swizzling, + very small tiles, or replacing a fast vendor GEMM with a slow custom one. + Re-read kb/correctness.yaml and kb/fusion_patterns.yaml constraints. + priority: critical diff --git a/kernel-builder/skills/xpu-kernels/references/optimization_strategies.md b/kernel-builder/skills/xpu-kernels/references/optimization_strategies.md new file mode 100644 index 00000000..044dbab8 --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/optimization_strategies.md @@ -0,0 +1,78 @@ +# Optimization Strategies Reference + +## Optimization Levels (Iterative Deepening) + +| Level | Focus | Typical Speedup | +|-------|-------|-----------------| +| **1. Baseline XPU** | Tensor descriptors, tile swizzling, `@triton.autotune`, fused epilogue | 1.5-3x | +| **2. Bandwidth** | Pre-pack to bf16, grf_mode='256', `tl.dot(a, b, acc=acc)` | 2-4x | +| **3. Algebraic** | Fold BN/scale/affine into weights (eliminate epilogue) | 3-6x | +| **4. Expert** | Stream K, persistent kernels, warp sweeping | 5-10x+ | + +**"Try harder" decision tree** (from `references/optimization_levels.yaml`): +- Speedup < 2x after Level 1 -> apply Level 2 (bandwidth is the bottleneck) +- Speedup 2-3x after Level 2 -> check Level 3 (can epilogue be algebraically eliminated?) +- Speedup 3-5x -> good for most workloads; Level 4 only for critical-path kernels +- Speedup > 5x -> diminishing returns, stop + +**Case study**: Kernel #39 (Gemm_Scale_BatchNorm) went from 2.69x (Level 1) to 5.28x (Level 2+3) by pre-packing to bf16 and folding BN into GEMM weights. + +## GEMM Kernels +1. Use tensor descriptors (preferred on XPU) or block pointers (not manual pointer arithmetic) +2. Apply tile swizzling with GROUP_SIZE_M (1D grid required) +3. `@triton.autotune` with varied configs - sweep block sizes, warps, GRF mode +4. Large tiles for square matrices: 256x256, 32 warps, grf_mode='256' +5. Smaller tiles for skinny-M: BLOCK_M in {32, 64}, fewer warps +6. Mixed precision: bf16/fp16 inputs, fp32 accumulator +7. Pre-pack weight transposes: `weight_t = weight.t().contiguous()` once in `_pack_weights()` +8. Pre-pack to bf16: Convert weights AND inputs to bf16 before kernel launch (not in-kernel) - see `references/dtype_optimizations.yaml` +9. Algebraic weight folding: Fold BN/scale/affine into GEMM weights at pack time - see `references/fusion_patterns.yaml` + +## Fusion +1. Fuse light epilogues: bias + simple activation (ReLU, SiLU) +2. Be cautious with heavy chains: multiple exp/tanh/clamp can hurt register pressure +3. Split GEMM + reduction: Use 2D GEMM -> separate reduction kernel (don't serialize over N) + +## Reductions (Softmax, LayerNorm) +1. Multi-row tiling: Process multiple rows per program (BLOCK_SIZE_Y) +2. Query hardware limits: Use `max_work_group_size` to compute BLOCK_SIZE_Y +3. Power-of-2 blocks: `BLOCK_SIZE_X = triton.next_power_of_2(n_cols)` +4. Sweep warp_size: Try both 16 and 32 with different num_warps + +## Critical "DO NOT" List +- Do NOT put default values on `@triton.autotune` meta-parameters in kernel signature +- Do NOT use 2D grid with tile swizzling (must be 1D) +- Do NOT repack weights inside forward() hot path +- Do NOT implement GEMM2 by looping all N tiles inside one program +- Do NOT mix block pointer and tensor descriptor APIs on same load/store +- Do NOT use fp64 unless absolutely required (5-10x slower) + +## KB Quick Index + +- **Starting a GEMM kernel?** -> `references/xpu_optimizations.yaml` +- **Fusing operations?** -> `references/fusion_patterns.yaml` +- **Memory access issues?** -> `references/memory_patterns.yaml` +- **Kernel crashes or wrong results?** -> `references/correctness.yaml` +- **Slow due to fp64?** -> `references/dtype_optimizations.yaml` +- **Advanced techniques?** -> `references/persistent_kernel_patterns.yaml` +- **Need more speedup?** -> `references/optimization_levels.yaml` +- **Looking for examples?** -> `references/examples/index.yaml` + `references/examples/*.py` + +## Common Patterns Checklist + +When transforming PyTorch -> Triton: + +- [ ] Identified operation type (GEMM, reduction, elementwise) +- [ ] Chosen memory access pattern (tensor descriptors preferred; block pointers as fallback) +- [ ] Applied tile swizzling (if GEMM) +- [ ] `@triton.autotune` with varied BLOCK_M/N/K, num_warps, grf_mode configs +- [ ] NO default values on autotune meta-parameters in kernel signature +- [ ] Used 1D grid if swizzling +- [ ] Mixed precision: bf16/fp16 -> fp32 accumulator +- [ ] Fused light epilogues only +- [ ] Pre-packed weight transposes (cached in `_pack_weights()`) +- [ ] Model class compatible with ai-bench (standard nn.Module with nn.Linear) +- [ ] Matched `get_inputs()`, `get_init_inputs()`, module-level constants from *_pytorch.py +- [ ] Triton file name matches base kernel name (for spec YAML auto-detection) +- [ ] Validated with `python scripts/validate_triton.py ` +- [ ] Benchmarked with `python scripts/benchmark.py ` diff --git a/kernel-builder/skills/xpu-kernels/references/persistent_kernel_patterns.yaml b/kernel-builder/skills/xpu-kernels/references/persistent_kernel_patterns.yaml new file mode 100644 index 00000000..e6141d02 --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/persistent_kernel_patterns.yaml @@ -0,0 +1,204 @@ +# Intel XPU Persistent Kernel & Stream K Patterns +# Stage: persistent_kernel, stream_k +# +# Persistent kernels are ADVANCED and should NOT be auto-applied by default. +# They help mainly when: +# - there are many tiles (launch/tail overhead), +# - the workload is bandwidth/reduction bound, +# - cache reuse across tiles is valuable, +# and they do NOT replace a faster vendor primitive (e.g., vendor GEMM). +# +# Stream K is a related advanced scheduling strategy that splits tiles' +# K-loop iterations evenly across programs to fix quantization inefficiency. + +patterns: + - id: persistent_kernel_basic_tile_loop + name: "Persistent kernel: fixed program count loops over tiles" + stage: persistent_kernel + description: "Launch a fixed number of programs and iterate over all tiles with stride = NUM_PROGS" + rationale: | + Persistent kernels reduce kernel launch overhead and can improve cache reuse by having + fewer programs loop over many tiles. This can help Intel XPU when the baseline launches + huge grids (many tiles) and suffers from tail effects or overhead. + + WARNING: + - Persistent kernels can increase GRF/register pressure and reduce occupancy. + - Do NOT use by default. Use only when gated conditions indicate it. + pattern_before: | + # Standard per-tile launch + grid = (tl.cdiv(M, BLOCK_M), tl.cdiv(N, BLOCK_N)) + + @triton.jit + def kernel(...): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + # Process one tile + ... + pattern_after: | + # Persistent kernel launch (1D grid) + # NUM_PROGS is a tuned constant (do NOT assume it equals SM/EU count) + grid = (NUM_PROGS,) + + @triton.jit + def kernel(..., NUM_PROGS: tl.constexpr): + pid = tl.program_id(0) + tiles_m = tl.cdiv(M, BLOCK_M) + tiles_n = tl.cdiv(N, BLOCK_N) + num_tiles = tiles_m * tiles_n + + tile = pid + while tile < num_tiles: + pid_m = tile // tiles_n + pid_n = tile % tiles_n + + row_start = pid_m * BLOCK_M + col_start = pid_n * BLOCK_N + + # Process tile (GEMM / reduction / etc.) + ... + + tile += NUM_PROGS + expected_speedup: "0.8-2.0x (highly shape/workload dependent; can regress)" + applies_to: + - gemm + - attention + - softmax + - layernorm + - large_workloads + + - id: persistent_kernel_autotuned_num_progs + name: "Persistent kernel: autotune NUM_PROGS (required)" + stage: persistent_kernel + description: "Treat NUM_PROGS as a tunable meta-parameter instead of guessing hardware counts" + rationale: | + Intel XPU does not map cleanly to CUDA's SM model. Do not set NUM_PROGS = 'num_sm'. + Instead, autotune NUM_PROGS across a small set (e.g., 32/64/128/256) to find the best. + pattern_before: | + NUM_PROGS = gpu_sm_count # WRONG: not portable / not reliable on XPU + grid = (NUM_PROGS,) + pattern_after: | + @triton.autotune( + configs=[ + triton.Config({"NUM_PROGS": 32, "grf_mode": "256"}, num_warps=4, num_stages=2), + triton.Config({"NUM_PROGS": 64, "grf_mode": "256"}, num_warps=4, num_stages=2), + triton.Config({"NUM_PROGS": 128, "grf_mode": "256"}, num_warps=4, num_stages=2), + triton.Config({"NUM_PROGS": 256, "grf_mode": "256"}, num_warps=4, num_stages=2), + ], + key=["M", "N", "K"], + ) + @triton.jit + def kernel(..., NUM_PROGS: tl.constexpr): + ... + # Launch with 1D grid: + grid = lambda META: (META["NUM_PROGS"],) + expected_speedup: "prevents regressions; enables win when persistent is appropriate" + applies_to: + - gemm + - attention + - large_workloads + + - id: streamk_two_wave_decomposition + name: "Stream K: two-wave SK + data-parallel decomposition" + stage: stream_k + description: "Split tiles into stream-K (wave 1) and standard blocking (wave 2)" + rationale: | + Pure stream-K distributes ALL tiles via iteration splitting, maximizing + atomic contention. A hybrid two-wave approach is more efficient: + + Wave 1 (first_wave): Only "remainder" tiles that cause quantization + inefficiency. Programs share K-loop iterations, using atomic_add + for partial results. Program count = num_xe_core (subslice count). + + Wave 2 (full_tiles): Remaining tiles are 1:1 (standard GEMM tiling), + no atomics needed. Grid size = blocking_tiles. + + This minimizes atomic traffic while fixing utilization gaps. + pattern_after: | + num_xe_core = torch.xpu.get_device_capability(0)['gpu_subslice_count'] + streamk_programs = num_xe_core + + total_tiles = num_block_m * num_block_n + iters_per_tile = triton.cdiv(K, BLOCK_SIZE_K) + + # Two-tile SK + DP heuristic + streamk_tiles = total_tiles % streamk_programs + if total_tiles - streamk_tiles > streamk_programs: + streamk_tiles += streamk_programs + + blocking_tiles = total_tiles - streamk_tiles + streamk_iters = streamk_tiles * iters_per_tile + + streamk_full_tiles = streamk_iters // streamk_programs + streamk_partial_tiles = streamk_iters % streamk_programs + + # Wave 1: stream-K (fixed program count, iteration-split) + first_wave[(streamk_programs,)](a, b, c, ..., + streamk_full_tiles, streamk_partial_tiles, iters_per_tile) + + # Wave 2: standard 1:1 tiling (offset by streamk_tiles) + full_tiles[(blocking_tiles,)](a, b, c, ..., streamk_tiles) + expected_speedup: "significant for non-divisible tile counts; minimal for well-fitting shapes" + applies_to: [gemm, large_workloads, variable_shapes] + notes: | + - Requires output to be pre-zeroed (torch.zeros) — see correctness.yaml + - Wave 1 uses atomic_add for partial tiles only; full tiles use direct store + - Wave 2 tile_ids offset by streamk_tiles to avoid overlap with wave 1 + + - id: streamk_mac_loop_partial_tile + name: "Stream K: MAC loop with arbitrary iteration range" + stage: stream_k + description: "Inner multiply-accumulate loop that handles [start_iter, end_iter) within a tile" + rationale: | + In stream-K, a program may own only a subset of a tile's K iterations. + The MAC loop must: + 1. Derive tile_id from start_iter (tile_id = start_iter // iters_per_tile) + 2. Compute K offset (remain_iters * BLOCK_SIZE_K) + 3. Accumulate over the assigned range + 4. Store via descriptor (full tile) or atomic_add (partial tile) + pattern_after: | + tile_id = start_iter // iters_per_tile + remain_iters = start_iter % iters_per_tile + pid_m, pid_n = swizzle_tile(tile_id, ...) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + off_k = remain_iters * BLOCK_SIZE_K + + for _ in range(start_iter, end_iter): + a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k]) + b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N]) + acc += tl.dot(a, b) + off_k += BLOCK_SIZE_K + + # Full tile → direct store; partial tile → atomic add + if remain_iters == 0 and end_iter % iters_per_tile == 0: + c_desc.store([pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N], acc) + else: + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptr_ = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + tl.atomic_add(c_ptr_, acc, mask=mask, sem='relaxed') + applies_to: [gemm, stream_k] + + - id: streamk_iteration_distribution + name: "Stream K: even iteration distribution (floor + remainder)" + stage: stream_k + description: "Distribute total stream-K iterations across programs with even split" + rationale: | + Each program gets floor(total_iters / num_programs) iterations, plus + one extra if pid < remainder. Programs iterate through their range, + snapping to tile boundaries so each mac_loop stays within one tile. + pattern_after: | + pid = tl.program_id(axis=0) + start_iter = pid * full_tiles + tl.minimum(pid, partial_tiles) + last_iter = (pid + 1) * full_tiles + tl.minimum(pid + 1, partial_tiles) + + while start_iter < last_iter: + end_iter = start_iter + (iters_per_tile - start_iter % iters_per_tile) + end_iter = tl.minimum(end_iter, last_iter) + mac_loop(..., start_iter, end_iter, ...) + start_iter = end_iter + applies_to: [stream_k] + notes: | + - The while loop handles programs that span multiple tiles + - end_iter snaps to tile boundaries within the assigned range \ No newline at end of file diff --git a/kernel-builder/skills/xpu-kernels/references/workflow_details.md b/kernel-builder/skills/xpu-kernels/references/workflow_details.md new file mode 100644 index 00000000..b4895bad --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/workflow_details.md @@ -0,0 +1,205 @@ +# Detailed Workflow Reference + +## Analysis Phase + +When given a PyTorch kernel (typically `*_pytorch.py`, but can be any user-specified path): + +> **Note**: The existing `test_kernels/*.py` Triton files (non-pytorch) are **naive, unoptimized baselines**. Do NOT treat them as examples of good Triton code. Use `references/implementation_reference.md` and `references/examples/` instead. + +1. **Parse the PyTorch code** to identify: + - Input/output shapes and dtypes + - Mathematical operations (matmul, activations, reductions) + - Operation fusion opportunities + - Memory access patterns + +2. **Consult the knowledge base** (`references/` directory): + - `xpu_optimizations.yaml`: XPU-specific patterns (tensor descriptors, GRF mode, warp count, tile swizzling) + - `fusion_patterns.yaml`: When to fuse operations + - `memory_patterns.yaml`: Memory access best practices + - `correctness.yaml`: Critical constraints to avoid bugs + - `dtype_optimizations.yaml`: Data type choices + +3. **Use the skills** to help: + - `python scripts/analyze_kernel.py ` - Extract operation structure + - Review `references/examples/index.yaml` for similar patterns + +## Design Phase + +1. **Identify the kernel type**: Pure GEMM, GEMM + epilogue, GEMM + reduction, complex fusion +2. **Select optimization strategies** from KB (memory, tiling, parallelism, fusion, dtypes) +3. **Apply critical constraints** (from `references/correctness.yaml` and `references/xpu_optimizations.yaml`): + - NO default values for `@triton.autotune` meta-parameters in kernel signature + - Use 1D grid when applying tile swizzling (GROUP_SIZE_M) + - boundary_check uses dimension indices (0, 1), not booleans + - Cast batch indices to int64 before stride multiplication + - Do NOT mix block pointer and tensor descriptor APIs on same operation + - Pre-zero output buffers when using atomic accumulation + - Model class must be compatible with ai-bench (standard `nn.Module` with `nn.Linear`) + +## Trial Loop Detail + +For each trial: + +### a. Implement / Modify Kernel +Start from a template (`references/implementation_reference.md`) or modify the previous trial's code. See `references/implementation_reference.md`. + +### b. Validate Syntax +```bash +python scripts/validate_triton.py +``` +If validation fails, fix and retry - doesn't count as a new trial. Note: `` should be `t.py`. + +### c. Save Trial +```bash +python scripts/trial_manager.py save --parent --strategy "description" +``` +For the first trial, omit `--parent`. + +### d. Benchmark +```bash +# Trial t0 — measures both baseline and triton: +python scripts/benchmark.py [--triton-baseline] + +# Trials t1+ — use cached baseline to save time: +python scripts/trial_manager.py baseline-us # get cached value +python scripts/benchmark.py [--triton-baseline] --baseline-us + +# After finalize — re-run without --baseline-us for final accurate comparison +``` + +### e. Record Results +```bash +python scripts/trial_manager.py result \ + --validation pass --correctness --speedup \ + --baseline_us --triton_us +``` + +### f. Decision Tree + +| Condition | Action | +|--------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------| +| **Speedup > 5x** | Stop - excellent result (the only valid early stop) | +| **Speedup improved** | Continue on this branch, try next optimization level | +| **Speedup regressed** | Branch back to best trial, try a different strategy | +| **Correctness failed** | Fix code on same branch | +| **After t1 (if `vtune_enabled`)** | Run `python scripts/xpu_profiler.py ` — mandatory first profile | +| **Speedup plateaued after 2+ more trials** | Run profiler again (if `vtune_enabled`); try a fundamentally different approach | +| **Plateau / diminishing returns** | Do NOT stop. Try a fundamentally different approach (different algorithm, tiling, fusion strategy). LLM sampling can discover new ideas at any point. | +| **Max trials reached** | Stop — must run all `max_trials` from `config.yaml` | + +### g. Check Status +```bash +python scripts/trial_manager.py status +python scripts/trial_manager.py best +``` + +## Trial Manager Commands Reference +```bash +python scripts/trial_manager.py init [--triton-baseline] +python scripts/trial_manager.py save [--parent ] [--strategy "..."] +python scripts/trial_manager.py result [--validation pass] [--correctness pass] [--speedup 3.2] [--baseline_us 150.0] [--triton_us 47.0] +python scripts/trial_manager.py status +python scripts/trial_manager.py best +python scripts/trial_manager.py baseline-us +python scripts/trial_manager.py finalize _triton.py +``` + +## Benchmarking Details + +`scripts/benchmark.py` uses ai-bench (`modules/ai-bench/`) for both correctness and performance: + +1. **Correctness** - Compares outputs between PyTorch and Triton implementations + - Uses `check_correctness()` with per-variant tolerances from YAML spec (defaults: rtol=1e-2, atol=1e-5) + - Syncs model weights via `copy_model_weights()` + - Falls back to direct module loading when no spec file is available + +2. **Performance** - Benchmarks both implementations on XPU hardware + - Reads the YAML spec file (auto-detected from `modules/ai-bench/problems/specs/KernelBench/level*/`) + - Reports speedup metrics (Triton vs PyTorch) per spec variant + +**Both checks must pass** for the kernel to be considered complete. + +**Setup**: External tools must be initialised: `git submodule update --init` + +## Profiling with VTune (`scripts/xpu_profiler.py`) + +```bash +python scripts/xpu_profiler.py [--warmup 5] [--iters 20] +``` + +Runs Intel VTune `gpu-offload` collection to capture both Level Zero API tasks and OA (Observation Architecture) hardware counters, then maps bottlenecks to KB optimization patterns. + +**Prerequisite**: OA counters require `observation_paranoid=0`: +```bash +echo 0 | sudo tee /proc/sys/dev/xe/observation_paranoid +``` + +### When to Profile +- **MANDATORY** after the first benchmarked trial (t1) — always run at least once per session +- Run again if speedup plateaus after 2+ additional trials +- You're unsure which optimization level to try next + +### What It Reports +1. **Platform info**: GPU name, XVE count, max frequency +2. **Host tasks**: CPU-side overhead (JIT compilation, data copies, synchronization) +3. **GPU computing tasks table** (per-kernel): Time, instance count, XVE Active/Stalled/Idle %, occupancy %, memory bandwidth read/write +4. **Primary kernel detail**: Full OA hardware counter breakdown including: + - XVE execution: Active/Stalled/Idle percentages + - Occupancy limiters: Work Size Limit, SLM Use Limit, Barriers Use Limit (tells WHY occupancy is low) + - Memory bandwidth: Read/Write GB/s + - Cache hierarchy: L3 Busy/Stalled %, L3 Miss Ratio, LSC Miss Ratio, LSC→L3 Miss Ratio + - Register spill size, SLM bank conflicts, TLB misses +5. **Optimization recommendations**: Each grounded in a specific KB pattern: + - XVE Stalled > Active → memory bound → `references/xpu_optimizations.yaml (xpu_descriptor_gemm_pattern)` + `references/optimization_levels.yaml (level_2)` + - Low occupancy + Work Size limiter → grid too small → `references/xpu_optimizations.yaml (xpu_tile_swizzling)` + - Low occupancy + SLM limiter → tile too large → `references/xpu_optimizations.yaml (xpu_grf_mode)` + - High L3 Miss → poor reuse → `references/xpu_optimizations.yaml (xpu_descriptor_gemm_pattern, xpu_tile_swizzling)` + - Register spill > 0 → reduce liveness → `references/memory_patterns.yaml (reduce_liveness_sink_load_and_prefetch)` + - Overhead kernels dominate → pre-pack to bf16 → `references/optimization_levels.yaml (level_2_bandwidth_reduction)` + - Host time >> GPU time → sync in hot path → `references/memory_patterns.yaml (no_device_to_host_scalar_sync)` + +### How to Use the Output +The profiler prints specific recommendations with references: +``` +>> XVE Stalled (72%) > Active (28%): memory/dependency bound. + Use tensor descriptors for better address codegen, pre-pack to bf16 to halve bandwidth. + Reference: references/xpu_optimizations.yaml (xpu_descriptor_gemm_pattern, xpu_tile_swizzling) + + references/optimization_levels.yaml (level_2_bandwidth_reduction) +``` +Read the referenced file and apply the suggested pattern in your next trial. + +## Validation Details + +`scripts/validate_triton.py` checks: +- Syntax correctness +- Autotune config issues (no default params) +- Grid/swizzling consistency +- boundary_check format +- Data type usage + +## Project Structure +``` +xpu-kernels/ +├── SKILL.md # Core rules and workflow (concise) +│ +├── references/ # Knowledge base +│ ├── implementation_reference.md # Templates, code patterns, Model class +│ ├── optimization_strategies.md # Strategy reference, checklist, KB index +│ ├── workflow_details.md # This file — detailed workflow +│ ├── correctness.yaml # Correctness constraints +│ ├── xpu_optimizations.yaml # XPU-specific patterns +│ ├── optimization_levels.yaml # Progressive optimization checklist +│ ├── fusion_patterns.yaml # Kernel fusion guidelines +│ ├── memory_patterns.yaml # Memory access optimizations +│ ├── dtype_optimizations.yaml # Data type optimizations +│ └── persistent_kernel_patterns.yaml # Stream K and persistent kernel patterns +│ +└── scripts/ # Standalone tools (DO NOT recreate) + ├── analyze_kernel.py # PyTorch → operations, shapes, fusion opportunities + ├── validate_triton.py # Syntax + constraint checks before benchmarking + ├── benchmark.py # Correctness + performance via ai-bench + ├── trial_manager.py # Tree-structured trial init/save/record/finalize + ├── xpu_profiler.py # VTune GPU hardware counters + recommendations + ├── config.yaml # max_trials, vtune_enabled, vtune_bin + └── config.py # Shared configuration loader for config.yaml +``` diff --git a/kernel-builder/skills/xpu-kernels/references/xpu_optimizations.yaml b/kernel-builder/skills/xpu-kernels/references/xpu_optimizations.yaml new file mode 100644 index 00000000..3fe22d2b --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/references/xpu_optimizations.yaml @@ -0,0 +1,1140 @@ +constraints: + - id: autotune_no_duplicate_params + name: "Do NOT re-define autotune parameters as kernel constexpr" + severity: critical + description: | + When using @triton.autotune, parameters defined in the Config dict + (like BLOCK_M, BLOCK_N, GROUP_SIZE_M, etc.) are automatically passed + to the kernel. Do NOT also define them with default values in the kernel signature. + + WRONG (conflicting meta-parameters): + ```python + @triton.autotune( + configs=[triton.Config({'BLOCK_M': 256, 'GROUP_SIZE_M': 4}, ...)], + key=['M','N','K'], + ) + @triton.jit + def kernel(..., + BLOCK_M: tl.constexpr = 256, + GROUP_SIZE_M: tl.constexpr = 4): + ... + ``` + + CORRECT (declare only; autotune provides values): + ```python + @triton.autotune( + configs=[triton.Config({'BLOCK_M': 256, 'GROUP_SIZE_M': 4}, ...)], + key=['M','N','K'], + ) + @triton.jit + def kernel(..., + BLOCK_M: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + ... + ``` + + - id: grid_must_match_swizzling + name: "Grid must be 1D when using tile swizzling" + severity: critical + description: | + If you use GROUP_SIZE_M swizzling (flattened pid), the launch grid must be 1D. + + WRONG: + ```python + grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + ``` + + CORRECT: + ```python + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) + ``` + - id: xpu_packed_weight_requires_stride_pass_through + name: "When using packed W^T, pass its strides (do not reuse original W strides)" + severity: critical + description: | + If you pass a packed transpose [K, N] into the kernel, you MUST use + weight_t.stride() (stride_wtk, stride_wtn) when creating block pointers. + Do NOT reuse the original [N, K] tensor strides. + + - id: boundary_check_tuple + name: "boundary_check must be tuple of ints, not bools" + severity: critical + description: | + tl.load(block_ptr) uses boundary_check=(dim0, dim1) where values are dimension indices (0,1), + not booleans. + + WRONG: + ```python + tl.load(ptr, boundary_check=(True, True)) + ``` + + CORRECT: + ```python + tl.load(ptr, boundary_check=(0, 1)) + ``` + + - id: block_ptr_vs_tma + name: "Tensor descriptors vs block pointers - DIFFERENT APIs (descriptors preferred)" + severity: critical + description: | + Tensor descriptors and block pointers are different APIs. Do NOT mix + them for the same load/store operation. **Prefer tensor descriptors + on XPU** — they produce better address generation and memory access + codegen (confirmed by Intel compiler team). + + OPTION 1: TENSOR DESCRIPTORS (preferred on XPU) + ```python + desc = tl.make_tensor_descriptor(...) + x = desc.load([row_offset, col_offset]) # NO boundary_check arg — handled internally + ``` + + OPTION 2: BLOCK POINTERS (fallback) + ```python + ptr = tl.make_block_ptr(...) + x = tl.load(ptr, boundary_check=(0, 1)) + ptr = tl.advance(ptr, (0, BLOCK_K)) + ``` + + CRITICAL ERROR TO AVOID: + ```python + # WRONG - mixing descriptor with block pointer syntax + desc = tl.make_tensor_descriptor(...) + data = desc.load(..., boundary_check=(0, 1)) # ERROR! descriptors don't have boundary_check + ``` + + EXCEPTION — atomic operations: It IS valid to use descriptors for + loads/stores AND manual pointer arithmetic for atomic operations in + the SAME kernel, because descriptors do not support atomics. This is + a different-capability path, not API mixing. + + ```python + # Full tile → descriptor store (fast path) + if is_full_tile: + c_desc.store([pid_m * BM, pid_n * BN], acc) + # Partial tile → manual pointers + atomic_add (required fallback) + else: + rm = pid_m * BM + tl.arange(0, BM) + rn = pid_n * BN + tl.arange(0, BN) + c_ptr_ = c_ptr + rm[:, None] * stride_cm + rn[None, :] * stride_cn + tl.atomic_add(c_ptr_, acc, mask=mask, sem='relaxed') + ``` + + - id: descriptor_no_atomic_support + name: "Tensor descriptors do NOT support atomic operations" + severity: critical + description: | + When you need tl.atomic_add (or other atomics), you MUST fall back to + manual pointer arithmetic. Descriptors only support .load() and .store(). + + This is a legitimate reason to have both descriptor loads AND manual + pointer stores in the same kernel — it is not a violation of the + "don't mix APIs" rule (see block_ptr_vs_tma exception). + + - id: no_device_to_host_scalar_sync + name: "Do NOT force device->host scalar sync in hot path" + severity: critical + description: | + Avoid .item(), float(tensor), int(tensor), printing device tensors, or any device->host scalar + extraction in forward()/kernel wrapper hot paths. This forces synchronization and kills perf. + + WRONG: + ```python + c = float(constant_tensor.item()) # syncs XPU -> host + ``` + + CORRECT: + - Keep constants as Python floats / CPU tensors + - Pass scalar to kernel as an argument + ```python + c = float(constant) # already host + kernel[...,](..., c) + ``` + + - id: dtype_names_are_bfloat16_float16 + name: "Use correct Triton dtype names" + severity: critical + description: | + Triton uses tl.bfloat16 and tl.float16 (not 'bf16'). + tl.dot supports float16 / bfloat16 / float32 inputs. + + - id: xpu_no_repack_transpose_in_forward + name: "Do NOT transpose+contiguous weights inside forward() hot path" + severity: critical + description: | + Creating packed transposes like W.t().contiguous() inside forward()/kernel_function + can dominate runtime and hide kernel improvements, especially for large matrices. + Pack once (on XPU) and reuse across iterations. + + WRONG (repack every call): + ```python + def kernel_function(x, W1, W2, ...): + W1_t = W1.t().contiguous() # expensive copy each call + W2_t = W2.t().contiguous() + kernel1(..., W1_t, ...) + kernel2(..., W2_t, ...) + ``` + + CORRECT (cache once, rebuild only if weights change): + ```python + class Model(nn.Module): + def _move_params_once(self): + self.W1_xpu = self.weight1.to("xpu", torch.float16).contiguous() + self.W2_xpu = self.weight2.to("xpu", torch.float16).contiguous() + self.W1_t = self.W1_xpu.t().contiguous() # [K, H] packed + self.W2_t = self.W2_xpu.t().contiguous() # [H, O] packed + + def forward(self, x): + x = x.to("xpu", torch.float16).contiguous() + inter = gemm_sigmoid(..., self.W1_t, ...) + out = gemm2(..., inter, self.W2_t, ...) + return out + ``` + + - id: xpu_gemm2_must_not_serialize_over_n_tiles + name: "Do NOT implement GEMM2 by looping all N tiles inside one program" + severity: critical + description: | + A fused GEMM2+reduction kernel that uses a 1D grid over M and then loops + over all N tiles inside each program collapses parallelism and often + underutilizes Intel XPU. + + WRONG (serializes N tiles inside each pid_m): + ```python + pid_m = tl.program_id(0) + for n_start in range(0, N, BLOCK_N): + acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32) + # K loop... + # update row-wise reduction... + ``` + + CORRECT (parallelize GEMM2 over (M, N) tiles): + ```python + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + # compute [BLOCK_M, BLOCK_N] tile of C2 in parallel + # then do reduction in a separate kernel (or a properly parallel reduction) + ``` + +patterns: + - id: xpu_block_pointers + name: "Block Pointers (fallback — prefer tensor descriptors)" + stage: block_pointers + description: "Replace manual pointer arithmetic with block pointers (use tensor descriptors when possible)" + rationale: | + Block pointers improve codegen over manual pointer arithmetic, but tensor + descriptors produce even better address generation on XPU. Use block pointers + only when tensor descriptors are not suitable (e.g., legacy code or when + tl.advance sequential patterns are strongly preferred). + pattern_before: | + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak + a = tl.load(a_ptrs, mask=(offs_m[:,None] < M) & (offs_k[None,:] < K), other=0.0) + pattern_after: | + a_bp = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(pid_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_K), + order=(1, 0), + ) + a = tl.load(a_bp, boundary_check=(0, 1)) + expected_speedup: "workload-dependent (often good for indexing-heavy kernels)" + applies_to: [gemm, conv, all_memory_bound] + examples: + - before: | + a_ptrs += BLOCK_K * stride_ak + after: | + a_bp = tl.advance(a_bp, (0, BLOCK_K)) + + - id: xpu_grf_mode + name: "GRF Mode (tune; do not assume always wins)" + stage: xpu_specific + description: "Include grf_mode in autotune sweep" + rationale: | + Intel XPU supports multiple GRF modes. The best mode depends on register pressure and occupancy. + Tune grf_mode along with block sizes/warps/stages. + pattern_after: | + @triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': '128'}, num_warps=4, num_stages=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': '256'}, num_warps=4, num_stages=4), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': '256'}, num_warps=4, num_stages=5), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'grf_mode': '256'}, num_warps=8, num_stages=4), + ], + key=['M','N','K'], + ) + @triton.jit + def kernel(..., + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): + ... + expected_speedup: "workload-dependent" + applies_to: [gemm, compute_bound] + + - id: xpu_warp_count + name: "Warp count (tune; avoid hard-coded '32 warps')" + stage: xpu_specific + description: "Sweep num_warps; choose based on shape and register pressure" + rationale: | + Higher num_warps can help, but can also increase GRF usage and reduce occupancy. + On skinny-M GEMMs, 4–8 warps often wins; 16+ can lose badly. + Treat 32 as 'only if measured best'. + pattern_before: | + triton.Config({...}, num_stages=2, num_warps=8) + pattern_after: | + configs = [ + triton.Config({...}, num_stages=4, num_warps=4), + triton.Config({...}, num_stages=4, num_warps=8), + triton.Config({...}, num_stages=3, num_warps=16), + # num_warps=32 only if autotune proves it faster + ] + expected_speedup: "workload-dependent" + applies_to: [all] + + - id: xpu_tile_swizzling + name: "Tile swizzling (only when num_pid_m > 1)" + stage: xpu_specific + description: "Use GROUP_SIZE_M swizzling only if it can actually reorder M-tiles" + rationale: | + Swizzling helps when there are multiple M-tiles and data reuse across nearby tiles. + If tl.cdiv(M, BLOCK_M) == 1, swizzling adds overhead with no benefit. + pattern_before: | + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pattern_after: | + # Enable ONLY if num_pid_m > 1 + pid = tl.program_id(0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = tl.minimum(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + expected_speedup: "shape-dependent" + applies_to: [gemm, large_matrices] + notes: | + Requires 1D grid: grid = (cdiv(M,BLOCK_M) * cdiv(N,BLOCK_N),) + + Swizzle logic can be factored into a reusable @triton.jit helper: + ```python + @triton.jit + def swizzle_tile(tile_id, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, GROUP_SIZE_M): + grid_m = tl.cdiv(M, BLOCK_SIZE_M) + grid_n = tl.cdiv(N, BLOCK_SIZE_N) + width = GROUP_SIZE_M * grid_n + group_id = tile_id // width + group_size = tl.minimum(GROUP_SIZE_M, grid_m - group_id * GROUP_SIZE_M) + pid_m = group_id * GROUP_SIZE_M + (tile_id % group_size) + pid_n = (tile_id % width) // group_size + return pid_m, pid_n + ``` + + Use GROUP_SIZE_M > 0 as a toggle: when 0, fall back to linear tile mapping. + + - id: xpu_tile_shape_heuristics + name: "Tile shape heuristics (avoid always-256x256)" + stage: xpu_specific + description: "Choose tiles based on aspect ratio" + rationale: | + 256x256 can be good for large, square-ish GEMM, but for skinny M it reduces M-parallelism + and increases register pressure. Prefer smaller BLOCK_M when M is small. + pattern_after: | + # Heuristic guidance: + # - If M <= 256: BLOCK_M in {16, 32, 64} + # - If N is huge: BLOCK_N in {64, 128} + # - BLOCK_K in {32, 64} + # - For square problems: consider {128,256} tiles if it fits GRF/occupancy + expected_speedup: "prevents regressions; enables better autotune search" + applies_to: [gemm] + + - id: xpu_mixed_precision_dot + name: "Mixed precision dot (bfloat16/float16 inputs, fp32 accumulate)" + stage: xpu_specific + description: "Cast blocks to tl.bfloat16 or tl.float16 for tl.dot; accumulate fp32" + rationale: | + tl.dot supports bfloat16/float16 inputs; accumulating in fp32 often improves speed while + keeping reasonable numerics (if acceptable for the workload). + pattern_before: | + x = tl.load(x_bp).to(tl.float32) + w = tl.load(w_bp).to(tl.float32) + acc += tl.dot(x, w) + pattern_after: | + x = tl.load(x_bp).to(tl.bfloat16) # or tl.float16 + w = tl.load(w_bp).to(tl.bfloat16) # or tl.float16 + acc += tl.dot(x, w) # acc is tl.float32 + expected_speedup: "often large if fp32 path is slow" + applies_to: [gemm, attention, compute_bound] + notes: | + Use tl.bfloat16 / tl.float16 (not 'bf16'). + + - id: xpu_tensor_descriptor_note + name: "Tensor descriptors are preferred on XPU" + stage: memory_access + description: "Use tensor descriptors as the default memory access API on XPU" + rationale: | + Tensor descriptors (`tl.make_tensor_descriptor`) produce better address + generation and memory access codegen than block pointers on Intel XPU, + as confirmed by the Intel compiler team. + + Use tensor descriptors as the default for all new kernels: + - Standard GEMM K-loops: create descriptors once, increment offset variable + - Variable-offset access (Stream K, dynamic coordinates): same API, clean code + - Batched variants: bake batch offset into descriptor base pointer + + Fall back to block pointers only when: + - Atomic operations are needed (descriptors don't support atomics) + - Maintaining existing block-pointer code that already works well + expected_speedup: "better codegen vs block pointers for address generation" + applies_to: [gemm, stream_k, large_matrices, all_memory_bound] + + - id: xpu_descriptor_gemm_pattern + name: "Tensor descriptor GEMM pattern (preferred on XPU)" + stage: memory_access + description: "Use tensor descriptors as the default GEMM memory access pattern on XPU" + rationale: | + Tensor descriptors are the recommended memory access API on Intel XPU. + They produce better address generation codegen than block pointers. + Create descriptors once, load by coordinate — no tl.advance needed, + just update the offset variable. + pattern_after: | + # Create descriptors once (outside K-loop) + a_desc = tl.make_tensor_descriptor( + base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K)) + b_desc = tl.make_tensor_descriptor( + base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N)) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + off_k = 0 # or remain_iters * BLOCK_SIZE_K for Stream K + + for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k]) + b = b_desc.load([off_k, pid_n * BLOCK_SIZE_N]) + acc += tl.dot(a, b) + off_k += BLOCK_SIZE_K + expected_speedup: "preferred on XPU; better codegen than block pointers" + applies_to: [gemm, stream_k, variable_offset_access, all_memory_bound] + notes: | + For batched variants, add the batch offset to the descriptor's base pointer + rather than to each load coordinate: + + ```python + offset_a = bid.to(tl.int64) * stride_az + a_desc = tl.make_tensor_descriptor( + base=a_ptr + offset_a, shape=(M, K), # batch offset baked into base + strides=(stride_am, stride_ak), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K)) + # Loads use same coordinates as non-batched: + a = a_desc.load([pid_m * BLOCK_SIZE_M, off_k]) + ``` + + This is cleaner than adjusting every load coordinate and works for any + number of descriptors (A, B, C, D, ...) in the same kernel. + + - id: xpu_persistent_kernel + name: "Persistent kernel (advanced; measure carefully)" + stage: persistent_kernel + description: "Use persistent scheduling to reduce launch overhead / improve utilization" + rationale: | + Persistent matmul can help for some shapes by keeping a fixed number of programs that iterate tiles. + However, the best NUM_SMS / scheduling is hardware- and workload-dependent. + pattern_before: | + grid = (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N)) + kernel[grid](...) + pattern_after: | + # Persistent pattern: + # grid = (NUM_WORKERS,) + # each program iterates tile_idx += NUM_WORKERS + # NOTE: NUM_WORKERS must be selected/tuned (not hard-coded). + expected_speedup: "workload-dependent; can be large for some cases" + applies_to: [gemm, large_workloads] + + - id: xpu_pack_weight_transpose_once + name: "Pack W^T once (contiguous [K, N]) for tl.dot fast path" + stage: xpu_specific + description: "Cache a packed transpose of weights on XPU and pass [K, N] contiguous into the kernel" + rationale: | + PyTorch/XPU GEMM paths frequently use packed/blocked layouts internally. + A Triton kernel that reads W with swapped strides (pretending W is (K,N)) + may still get suboptimal memory access and may miss the fastest matmul path. + Packing W^T once makes the physical memory layout match the kernel's + access order for tl.dot, improving locality and throughput. + pattern_before: | + # Weight stored as [N, K], passed directly + w = weight.to(device="xpu", dtype=torch.float16).contiguous() + + # Kernel "views" W as (K, N) by swapping strides: + w_bp = tl.make_block_ptr( + base=w_ptr, + shape=(K, N), + strides=(stride_wk, stride_wm), # swapped + offsets=(0, n_base), + block_shape=(BLOCK_K, BLOCK_N), + order=(1, 0), + ) + w_tile = tl.load(w_bp, boundary_check=(0, 1)) + pattern_after: | + # Pack once on XPU (inference-friendly; rebuild if weights change) + # Store packed transpose as [K, N] fp16 contiguous. + w16 = weight.to(device="xpu", dtype=torch.float16) + weight_t = w16.t().contiguous() # [K, N] contiguous + + # Kernel reads packed W^T directly with natural contiguous strides: + wt_bp = tl.make_block_ptr( + base=wt_ptr, + shape=(K, N), + strides=(stride_wtk, stride_wtn), # from weight_t.stride() + offsets=(0, n_base), + block_shape=(BLOCK_K, BLOCK_N), + order=(1, 0), + ) + wt_tile = tl.load(wt_bp, boundary_check=(0, 1)) + expected_speedup: "often significant for GEMM-heavy kernels; workload-dependent" + applies_to: [gemm, fused_gemm_epilogue, attention_linear] + notes: | + - This is "layout-for-speed" (XPU optimization), not memory optimization. + - Overhead: extra device memory for packed transpose (~K*N*2 bytes for fp16). + - Best for inference / fixed weights. If training, refresh weight_t when weight updates. + - Consider returning XPU tensor to avoid CPU transfer overhead in benchmarks. + + - id: xpu_cache_packed_transposes + name: "Cache packed weight transposes on XPU (W^T contiguous)" + stage: xpu_specific + description: "Build W1^T and W2^T once on XPU, store contiguous [K,N] tensors, reuse per forward" + rationale: | + Packing W^T makes the physical layout match tl.dot access order and avoids + expensive transpose+contiguous copies every iteration. This is especially + important in KernelBench-style loops where forward() runs many times. + pattern_before: | + # inside kernel_function / forward + W1 = weight1.to("xpu", torch.float16).contiguous() + W2 = weight2.to("xpu", torch.float16).contiguous() + W1_t = W1.t().contiguous() + W2_t = W2.t().contiguous() + pattern_after: | + # inside Model._move_params_once() (or a cache refresh path) + W1_xpu = weight1.to("xpu", torch.float16).contiguous() # [H, K] + W2_xpu = weight2.to("xpu", torch.float16).contiguous() # [O, H] + W1_t = W1_xpu.t().contiguous() # [K, H] + W2_t = W2_xpu.t().contiguous() # [H, O] + # forward reuses W1_t/W2_t without repacking + expected_speedup: "often large if repacking was on the timed path" + applies_to: [gemm, mlp, fused_linear, attention_linear] + notes: | + - This is layout-for-speed, not memory saving (extra memory for packed transposes). + - If weights change (training), rebuild the packed transpose after updates. + - In inference benchmarks, caching is usually always correct. + + - id: xpu_split_gemm2_and_lse_for_parallelism + name: "Split GEMM2 and LogSumExp to preserve parallelism" + stage: fusion + description: "Use a 2D GEMM kernel for GEMM2, then a separate row-wise LSE reduction kernel" + rationale: | + Fusing GEMM2 with row-wise LogSumExp by sweeping N tiles inside a single pid_m + reduces parallel work to ~cdiv(M,BLOCK_M) programs, which can severely underutilize XPU. + A 2D GEMM produces C2 tiles with high parallelism; a follow-up reduction kernel + computes LSE efficiently over rows. + pattern_before: | + # GEMM2+LSE fused with 1D grid over M: + pid_m = tl.program_id(0) + for n_start in range(0, N, BLOCK_N): + # compute tile [BM, BN] and update m_row/s_row + ... + y[m] = log(s_row) + m_row + pattern_after: | + # Step 1: GEMM2 (2D grid over M,N) -> C2 [M, N] (fp16 or fp32) + grid = (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N)) + gemm2_kernel[grid](inter, W2_t, bias2, C2, ...) + + # Step 2: Row-wise LogSumExp (1D grid over M) -> Y [M] + grid = (cdiv(M, BLOCK_M),) + row_lse_kernel[grid](C2, Y, ...) + expected_speedup: "often dramatic vs serialized GEMM2+LSE on XPU" + applies_to: [mlp, attention_scores, logits_lse] + notes: | + - Materializing C2 costs memory (~M*N), but restores occupancy and throughput. + - If memory is tight, consider tiling/reduction strategies that still parallelize over N + (e.g., partial reductions + second-stage reduction), but avoid full serialization. + + - id: xpu_gemm2_2d_grid_required + name: "GEMM2 should use a 2D grid when producing a matrix output" + stage: xpu_specific + description: "Use pid_m/pid_n mapping for GEMM tiles; avoid 1D pid_m-only mapping" + rationale: | + GEMM kernels need parallelism across both output dimensions to saturate the device. + Restricting to pid_m-only often leaves too few programs in flight. + pattern_after: | + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + # compute tile [BLOCK_M, BLOCK_N] of C + expected_speedup: "prevents severe underutilization" + applies_to: [gemm, linear, matmul] + + - id: xpu_packed_transpose_kernel_signature + name: "Kernel signatures should accept packed transpose as [K, N] with its own strides" + stage: block_pointers + description: "Pass packed transpose pointer and its strides into tl.make_block_ptr" + rationale: | + Using packed transposes only helps if the kernel loads them with the correct + stride metadata (from the packed tensor), not reused from the original. + pattern_after: | + wt_bp = tl.make_block_ptr( + base=WT_ptr, + shape=(K, N), + strides=(stride_wtk, stride_wtn), # from WT.stride() + offsets=(k_start, n_base), + block_shape=(BLOCK_K, BLOCK_N), + order=(1, 0), + ) + expected_speedup: "enables fast-path memory access for RHS" + applies_to: [gemm, fused_gemm_epilogue] + + - id: xpu_sigmoid_use_exp2_not_exp_fp32 + name: "Sigmoid: avoid fp32 tl.exp; prefer tl.math.exp2-based sigmoid (XPU fast-path)" + severity: high + description: | + On Intel XPU, using fp32 `tl.exp()` in the sigmoid epilogue can be a throughput bottleneck. + Prefer an exp2-based formulation: + exp(x) = exp2(x / ln(2)) + so: + sigmoid(x) = 1 / (1 + exp(-x)) + = 1 / (1 + exp2((-x) / ln(2))) + + WRONG (often slower on XPU): + ```python + # acc is fp32 + sig = 1.0 / (1.0 + tl.exp(-acc)) + out = acc + 2.0 * sig + ``` + + CORRECT (typically faster / broader backend support): + ```python + @triton.jit + def sigmoid_exp2_fp32(x: tl.tensor): + inv_ln2 = 1.4426950408889634 # 1 / ln(2) + e = tl.math.exp2((-x) * inv_ln2) + return 1.0 / (1.0 + e) + + sig = sigmoid_exp2_fp32(acc) + out = acc + 2.0 * sig + ``` + + Notes: + - Keep `acc` in fp32 if you need accuracy; this rule is about the exp implementation. + - If your workload tolerates lower precision, consider computing sigmoid in fp16/bf16, + but only if correctness thresholds allow it. + + - id: xpu_warp_sweep_no_fixed_32 + name: "Warp count: do NOT force 32 warps; sweep 4/8/16 (32 only if autotune proves it)" + severity: high + description: | + On Intel XPU, hard-coding `num_warps=32` frequently hurts performance due to + higher GRF usage and reduced occupancy—especially when the kernel has a heavy + epilogue (sigmoid/bn/lse) or large tiles. + + WRONG (common regression pattern): + ```python + @triton.autotune( + configs=[triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64, 'grf_mode': '256'}, + num_warps=32, num_stages=4)], + key=['M','N','K'], + ) + ``` + or: + ```python + # "XPU guidance says 32 warps" applied unconditionally + num_warps = 32 + ``` + + CORRECT (treat warps as a tuned parameter): + ```python + cfgs = [ + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'grf_mode': '256'}, num_warps=8, num_stages=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'grf_mode': '256'}, num_warps=16, num_stages=3), + # 32 only if it actually wins in autotune: + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 256, 'BLOCK_K': 64, 'grf_mode': '256'}, num_warps=32, num_stages=3), + ] + @triton.autotune(configs=cfgs, key=['M','N','K']) + @triton.jit + def kernel(..., BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): + ... + ``` + + Practical guidance: + - If epilogue is heavy (sigmoid/bn/lse): start with 8 or 16 warps. + - If the kernel is pure GEMM and tiles are large: 16 may win; 32 is "maybe", not default. + - Always sweep GRF mode together with warps if you see register pressure. + + Exception for large pure GEMMs: + For large, square GEMMs (e.g., 3072x4096x3072) with 256x256 tiles, + 32 warps + grf_mode='256' is a proven strong default — Intel's own XPU + benchmark kernels ship with this as a single config. The "don't blindly + use 32" guidance is most critical for skinny/non-square shapes and + kernels with heavy epilogues. + + - id: xpu_hardware_subslice_query + name: "Query XPU subslice count for hardware-aware program counts" + stage: xpu_specific + description: "Use gpu_subslice_count for Stream K or persistent kernel program counts" + rationale: | + Stream K and persistent kernels need to know the hardware parallelism. + On Intel XPU, the relevant unit is the subslice (Xe core) count. + Do NOT hard-code this — it varies across GPU models. + pattern_after: | + num_xe_core = torch.xpu.get_device_capability(0)['gpu_subslice_count'] + streamk_programs = num_xe_core + expected_speedup: "enables optimal utilization" + applies_to: [stream_k, persistent_kernel, xpu_specific] + notes: | + - This is the XPU equivalent of CUDA's SM count + - For persistent kernels, consider autotuning NUM_PROGS around this value + + - id: xpu_batched_gemm_pattern + name: "Batched GEMM: use batch dimension on grid axis=1 with int64 offsets" + stage: xpu_specific + description: "Add batch dimension via tl.program_id(axis=1) and per-batch pointer offsets" + rationale: | + Batched GEMM (B x M x K) @ (B x K x N) → (B x M x N) extends the standard + tiled GEMM by adding a batch grid dimension. The (M, N) tile swizzling stays + on axis=0; the batch index goes on axis=1. Each batch computes an independent + GEMM by offsetting all pointers by bid * stride_z. + + Key details: + - Cast bid to int64 before multiplying by batch stride to avoid int32 overflow + on large tensors (see correctness.yaml: int64_cast_for_large_batch_offsets) + - Tile swizzling (GROUP_SIZE_M) works identically within each batch + - The K-loop, boundary masking, and accumulation are unchanged per batch + pattern_after: | + # Grid: axis=0 = flattened (M, N) tiles with swizzle, axis=1 = batch + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + B, + ) + + @triton.jit + def matmul_kernel_batched( + a_ptr, b_ptr, c_ptr, + B: tl.constexpr, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, + stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr, + stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, + stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + bid = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + + # Standard tile swizzle on axis=0 (same as non-batched) + # ... pid_m, pid_n from GROUP_SIZE_M swizzle ... + + # Batch offsets — int64 to prevent overflow + offset_a = bid.to(tl.int64) * stride_az + offset_b = bid.to(tl.int64) * stride_bz + a_ptrs = a_ptr + offset_a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + offset_b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + # K-loop, accumulate, store (identical to non-batched) + # ... + offset_c = bid.to(tl.int64) * stride_cz + c_ptrs = c_ptr + offset_c + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + expected_speedup: "enables batched workloads; parallelism scales with B" + applies_to: [gemm, batched_gemm, attention, multi_head] + notes: | + - Grid axis=0 stays 1D (swizzled M*N tiles), axis=1 = batch + - All swizzle/autotune patterns apply unchanged within each batch + - Block pointer / descriptor variants work identically — just add batch offset to base pointer + - For attention: heads dimension often maps to the same batch axis + + - id: xpu_warp_size_autotune + name: "Sweep warp_size (sub-group size) alongside num_warps on XPU" + stage: xpu_specific + description: "Include warp_size={16, 32} in autotune configs for Intel XPU kernels" + rationale: | + Intel XPU supports sub-group (warp) sizes of 16 and 32. The optimal + choice depends on the kernel's register pressure, reduction width, and + memory access pattern. This is a SEPARATE axis from num_warps — both + should be swept. + + For row-wise kernels (softmax, layernorm) with narrow rows, warp_size=16 + with more warps can outperform warp_size=32 because each sub-group + covers fewer columns, allowing more rows to run concurrently. + + KB already covers num_warps and grf_mode sweeping; warp_size is a third + independent axis that can significantly affect performance. + pattern_after: | + @triton.autotune( + configs=[ + # warp_size=32 variants + triton.Config({"warp_size": 32}, num_warps=32), + triton.Config({"warp_size": 32}, num_warps=16), + triton.Config({"warp_size": 32}, num_warps=8), + triton.Config({"warp_size": 32}, num_warps=4), + # warp_size=16 variants (more warps possible) + triton.Config({"warp_size": 16}, num_warps=64), + triton.Config({"warp_size": 16}, num_warps=32), + triton.Config({"warp_size": 16}, num_warps=16), + triton.Config({"warp_size": 16}, num_warps=8), + triton.Config({"warp_size": 16}, num_warps=4), + ], + key=[...], + ) + expected_speedup: "workload-dependent; can be significant for row-wise kernels" + applies_to: [softmax, layernorm, reduction, attention, row_wise] + notes: | + - warp_size=16 allows up to 64 warps (more parallelism per program) + - warp_size=32 is standard for GEMM; warp_size=16 often better for reductions + - Combine with num_warps and grf_mode for full autotune coverage + + - id: xpu_work_group_size_query + name: "Query max_work_group_size for multi-row tiling in row-wise kernels" + stage: xpu_specific + description: "Use device max_work_group_size to compute how many rows fit per program" + rationale: | + Row-wise kernels (softmax, layernorm, RMS norm) that fit one row per block + can underutilize the hardware when rows are narrow. Packing multiple rows + per program (BLOCK_SIZE_Y > 1) increases work per program and utilization. + + The maximum rows per program is bounded by the hardware work group size + divided by the column block size. Query this at runtime — do not hard-code. + pattern_after: | + from triton.runtime import driver + + device = torch.xpu.current_device() + properties = driver.active.utils.get_device_properties(device) + MAX_WORK_GROUP_SIZE = properties["max_work_group_size"] + + BLOCK_SIZE_X = triton.next_power_of_2(n_cols) + BLOCK_SIZE_Y = MAX_WORK_GROUP_SIZE // BLOCK_SIZE_X + BLOCK_SIZE_Y = BLOCK_SIZE_Y if BLOCK_SIZE_Y > 0 else 1 + + grid = (n_rows // BLOCK_SIZE_Y,) + softmax_kernel[grid](..., BLOCK_SIZE_X=BLOCK_SIZE_X, BLOCK_SIZE_Y=BLOCK_SIZE_Y) + expected_speedup: "significant for narrow rows where single-row programs underutilize" + applies_to: [softmax, layernorm, rms_norm, row_wise, reduction] + notes: | + - BLOCK_SIZE_X must be power-of-2 and >= n_cols (use triton.next_power_of_2) + - BLOCK_SIZE_Y = 1 is the fallback when rows are wide enough to fill the work group + - Grid shrinks by BLOCK_SIZE_Y (fewer programs, more rows each) + - This is complementary to xpu_hardware_subslice_query (which queries Xe core count) + + - id: xpu_power_of_2_row_block + name: "Use next_power_of_2 for row-wise kernels that process entire rows per block" + stage: xpu_specific + description: "Set BLOCK_SIZE to triton.next_power_of_2(n_cols) for single-block row processing" + rationale: | + Softmax, layernorm, and similar row-wise kernels often process an entire row + in a single block (no K-loop). Triton requires power-of-2 block sizes, so + the column block must be rounded up. This is computed at launch time, not + inside the kernel. + pattern_after: | + BLOCK_SIZE_X = triton.next_power_of_2(n_cols) + # Mask out-of-bounds columns in the kernel: + col_offsets = tl.arange(0, BLOCK_SIZE_X) + mask = col_offsets < n_cols + row = tl.load(row_ptr + col_offsets, mask=mask, other=-float("inf")) + expected_speedup: "N/A (required for correctness; enables single-block row processing)" + applies_to: [softmax, layernorm, rms_norm, row_wise] + + - id: xpu_module_level_constexpr + name: "Use tl.constexpr() for module-level compile-time constants" + stage: xpu_specific + description: "Define mathematical constants at module scope with tl.constexpr()" + rationale: | + Constants used across multiple kernels or JIT helpers (e.g., sqrt(2/pi) for + GeLU, 1/ln(2) for exp2-based sigmoid) can be defined once at module level + using tl.constexpr(). This ensures they are compile-time constants, avoids + recomputation, and makes them reusable across all @triton.jit functions in + the module. + pattern_after: | + import math + import triton.language as tl + + # Module-level compile-time constants + kAlpha = tl.constexpr(math.sqrt(2.0 / math.pi)) # GeLU constant + kInvLn2 = tl.constexpr(1.4426950408889634) # 1/ln(2) for exp2 + + @triton.jit + def gelu(x): + return 0.5 * x * (1 + tanh(kAlpha * (x + 0.044715 * x * x * x))) + expected_speedup: "N/A (code clarity and reuse pattern)" + applies_to: [all, activation, epilogue] + + - id: xpu_reusable_jit_activation_helpers + name: "Factor activation functions into reusable @triton.jit helpers" + stage: xpu_specific + description: "Define common activations (tanh, gelu, silu, etc.) as standalone JIT functions" + rationale: | + Complex activations like GeLU involve multiple operations (tanh, multiply, + add, cube). Defining them as standalone @triton.jit functions makes them: + - Reusable across multiple kernel epilogues + - Testable independently + - Easy to swap (e.g., replace GeLU with SiLU in an epilogue) + + These helpers compose: gelu() calls tanh(), which calls tl.sigmoid(). + Triton inlines @triton.jit calls, so there is no function call overhead. + pattern_after: | + @triton.jit + def tanh(x): + # sigmoid-based tanh: avoids separate tl.exp/tl.math.tanh + return 2 * tl.sigmoid(2 * x) - 1 + + @triton.jit + def gelu(x): + # Tanh-approximation GeLU (matches PyTorch nn.functional.gelu) + return 0.5 * x * (1 + tanh(kAlpha * (x + 0.044715 * x * x * x))) + + @triton.jit + def silu(x): + return x * tl.sigmoid(x) + + # Use in any GEMM epilogue: + @triton.jit + def gemm_gelu_kernel(...): + # ... K-loop ... + c = gelu(accumulator) + c_desc.store([...], c) + expected_speedup: "N/A (code organization pattern; no perf impact vs inline)" + applies_to: [gemm, mlp, activation, epilogue] + notes: | + - @triton.jit helpers are inlined at compile time — no call overhead + - Same pattern used for swizzle_tile helper (see xpu_tile_swizzling) + - Can define sigmoid_exp2, silu, gelu, tanh, etc. in a shared module + + - id: xpu_sigmoid_based_tanh + name: "Implement tanh via sigmoid: tanh(x) = 2*sigmoid(2x) - 1" + stage: xpu_specific + description: "Use tl.sigmoid to compute tanh, avoiding tl.exp or tl.math.tanh" + rationale: | + The identity tanh(x) = 2*sigmoid(2x) - 1 lets you implement tanh using + tl.sigmoid, which may have a fast hardware path. This avoids needing a + separate tl.math.tanh or tl.exp call. Used in Intel XPU benchmark kernels + for GeLU activation. + + Combine with exp2-based sigmoid (see xpu_sigmoid_use_exp2_not_exp_fp32) + for maximum throughput if tl.sigmoid's built-in path is slow: + sigmoid(x) = 1 / (1 + exp2(-x * 1/ln2)) + tanh(x) = 2 * sigmoid(2x) - 1 + pattern_after: | + @triton.jit + def tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + expected_speedup: "workload-dependent; avoids separate tanh intrinsic" + applies_to: [activation, gelu, tanh, epilogue] + + - id: xpu_exp2_attention_scaling + name: "Use exp2/log2 with RCP_LN2 scaling for attention softmax" + stage: xpu_specific + description: "Scale QK scores by 1/ln(2) and use tl.math.exp2 instead of tl.exp throughout attention" + rationale: | + Flash Attention computes softmax(QK^T * scale). The standard approach uses + tl.exp, but on Intel XPU, tl.math.exp2 is faster (it maps to native hardware). + + The trick: multiply the QK scale factor by 1/ln(2) = 1.44269504, then use + exp2 everywhere (both in the forward online softmax and backward). This is + algebraically equivalent because exp(x) = exp2(x / ln(2)). + + This extends the exp2 pattern from sigmoid (xpu_sigmoid_use_exp2_not_exp_fp32) + to the full attention computation. + pattern_after: | + # Forward: scale by 1/ln(2) once, then use exp2 throughout + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) — converts to base-2 + + # In the attention inner loop: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) # NOT tl.exp! + alpha = tl.math.exp2(m_i - m_ij) # correction factor also uses exp2 + + # Epilogue: + m_i += tl.math.log2(l_i) # log2, not log! + + # Backward: pre-scale K by sm_scale * RCP_LN2 + RCP_LN2 = 1.4426950408889634 + arg_k = k * (sm_scale * RCP_LN2) + # Then exp2 in backward loops too: + pT = tl.math.exp2(qkT - m[None, :]) + expected_speedup: "significant on XPU (exp2 maps to native hardware)" + applies_to: [attention, flash_attention, softmax, online_softmax] + notes: | + - Both forward and backward must use consistent base-2 arithmetic + - The LN2 constant (0.6931471824645996) is used in backward for dQ scaling + - Pre-scaling K in backward avoids repeated multiplication inside the loop + + - id: xpu_tl_multiple_of_hint + name: "Use tl.multiple_of() to hint alignment to the compiler" + stage: xpu_specific + description: "Tell Triton a value is a multiple of a block size for better codegen" + rationale: | + tl.multiple_of(value, N) tells the Triton compiler that `value` is guaranteed + to be a multiple of N. This enables: + - Better vectorized memory access patterns + - Elimination of boundary checks + - More aggressive loop optimizations + + Use it when a loop variable or offset is known to be block-aligned but the + compiler cannot prove it statically. + pattern_after: | + # In Flash Attention causal masking: + if STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) # compiler knows lo is aligned + + # In a K-loop where start_n is always block-aligned: + for start_n in tl.range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # ... loads using start_n are now optimized + expected_speedup: "modest (enables compiler optimizations)" + applies_to: [attention, gemm, any_blocked_loop] + + - id: xpu_1d_tensor_descriptor + name: "Use 1D tensor descriptors for per-row metadata (logsumexp, max, sum)" + stage: memory_access + description: "Create 1D descriptors for storing/loading vector data like attention stats" + rationale: | + Tensor descriptors are not limited to 2D tiles. 1D descriptors work for + per-row metadata vectors (logsumexp M, delta D, row sums, etc.) that are + common in attention and normalization kernels. + + 1D descriptors handle boundary checking automatically, just like 2D ones. + pattern_after: | + # Store per-row logsumexp values + desc_m = tl.make_tensor_descriptor( + base=M + off_hz * N_CTX, # offset to correct batch/head + shape=[N_CTX], # 1D shape + strides=[1], # contiguous + block_shape=[BLOCK_M], # 1D block + ) + desc_m.store([start_m * BLOCK_M], m_i) + + # Load per-row values in backward + # (or use tl.load for simple 1D access when descriptors aren't needed) + expected_speedup: "N/A (API pattern; same perf as manual pointer + mask)" + applies_to: [attention, normalization, reduction, metadata] + + - id: xpu_descriptor_load_transpose + name: "Transpose after descriptor load: desc.load([...]).T" + stage: memory_access + description: "Load a tile via descriptor then transpose in-register with .T" + rationale: | + Some operations need data in transposed layout (e.g., attention loads K as + [BLOCK_N, HEAD_DIM] but needs [HEAD_DIM, BLOCK_N] for QK^T = Q @ K^T). + + Rather than creating a descriptor with swapped strides (which may not match + the physical layout), load in natural layout then transpose in-register: + k = desc_k.load([offset, 0]).T + + For backward passes, tl.trans(tensor) provides the same in-register transpose: + dpT = tl.dot(v, tl.trans(do)) + pattern_after: | + # Forward: load K as [BLOCK_N, HEAD_DIM], transpose for Q @ K^T + desc_k = tl.make_tensor_descriptor( + K, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_N, HEAD_DIM]) + k = desc_k.load([offsetk_y, 0]).T # now [HEAD_DIM, BLOCK_N] + qk = tl.dot(q, k) # [BLOCK_M, HEAD_DIM] @ [HEAD_DIM, BLOCK_N] + + # Backward: transpose with tl.trans() + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dk += tl.dot(dsT, tl.trans(qT)) + expected_speedup: "N/A (required for correct matrix orientation)" + applies_to: [attention, gemm, any_transpose] + notes: | + - .T on a load result is an in-register operation (no extra memory access) + - tl.trans() is equivalent to .T but can be applied to any tensor variable + - Both are used extensively in Flash Attention backward passes + + - id: xpu_adaptive_grid_layout + name: "Adapt grid layout based on problem size for occupancy" + stage: xpu_specific + description: "Use different grid dimension mappings for small vs large problems" + rationale: | + For attention or similar kernels, the grid needs batch (Z), head (H), and + tile dimensions. How these map to grid axes affects occupancy: + + - Large N_CTX (> 512): grid = (Z, H, num_blocks_m) + Each batch and head gets its own grid dimension for maximum parallelism. + + - Small N_CTX (<= 512): grid = (num_blocks_m, 1, Z*H) + Flatten batch×head into axis=2, put tiles on axis=0. This ensures enough + total programs even when num_blocks_m is small. + + The kernel must decode the program IDs differently based on the layout. + pattern_after: | + # Launch-side: + if n_ctx <= 512: + grid = lambda args: (triton.cdiv(n_ctx, args['BLOCK_M']), 1, Z * H) + else: + grid = lambda args: (Z, H, triton.cdiv(n_ctx, args['BLOCK_M'])) + + # Kernel-side: + if N_CTX <= 512: + start_m = tl.program_id(0) + off_hz = tl.program_id(2) + off_z = off_hz // H + off_h = off_hz % H + else: + off_z = tl.program_id(0) + off_h = tl.program_id(1) + start_m = tl.program_id(2) + expected_speedup: "significant for small sequence lengths (avoids underutilization)" + applies_to: [attention, flash_attention, variable_sequence_length] + + - id: xpu_blk_slice_factor_causal + name: "Use smaller blocks for causal-masked (diagonal) tiles" + stage: xpu_specific + description: "BLK_SLICE_FACTOR reduces block size for masked regions to avoid wasted compute" + rationale: | + In causal attention, diagonal tiles (where the causal mask is active) have + many masked-out elements. Using full-size blocks wastes compute on zeros. + + BLK_SLICE_FACTOR divides the block size for masked regions: + MASK_BLOCK_M1 = BLOCK_M1 // BLK_SLICE_FACTOR + + The kernel runs the masked loop with smaller blocks, then switches to full + blocks for the unmasked (fully below-diagonal) region. + pattern_after: | + BLOCK_M1, BLOCK_N1 = 32, 128 + BLK_SLICE_FACTOR = 2 + MASK_BLOCK_M1 = BLOCK_M1 // BLK_SLICE_FACTOR # = 16 for masked tiles + + # Masked (diagonal) pass with smaller blocks: + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + dk, dv = _attn_bwd_dkdv(..., MASK_BLOCK_M1, BLOCK_N1, ..., MASK=True) + + # Unmasked (full) pass with normal blocks: + num_steps = (N_CTX - start_m) // BLOCK_M1 + dk, dv = _attn_bwd_dkdv(..., BLOCK_M1, BLOCK_N1, ..., MASK=False) + expected_speedup: "reduces wasted compute in causal backward pass" + applies_to: [attention, flash_attention, causal, backward] + + - id: xpu_flattened_4d_to_2d_descriptors + name: "Flatten 4D tensors to 2D for tensor descriptors with manual batch/head offset" + stage: memory_access + description: "Reshape [Z, H, N_CTX, D] as [Z*H*N_CTX, D] for descriptors, compute offsets manually" + rationale: | + Tensor descriptors support 2D shapes. For 4D attention tensors [Z, H, N_CTX, D], + flatten the first three dims into one and use a manual offset to select the + correct batch/head slice. This works because Q/K/V/O are contiguous with + stride pattern [H*N_CTX*D, N_CTX*D, D, 1]. + pattern_after: | + y_dim = Z * H * N_CTX + desc_q = tl.make_tensor_descriptor( + Q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], + block_shape=[BLOCK_M, HEAD_DIM]) + + # Compute offset into the flattened dimension: + offset_y = off_z * (N_CTX * H) + off_h * N_CTX + qo_offset_y = offset_y + start_m * BLOCK_M + + # Load from the correct batch/head/position: + q = desc_q.load([qo_offset_y, 0]) + expected_speedup: "N/A (enables descriptor use for multi-dim tensors)" + applies_to: [attention, flash_attention, multi_head] + notes: | + - Requires that Q/K/V are contiguous (assert in wrapper) + - The [HEAD_DIM, 1] strides mean the last dim is always contiguous + - K/V descriptors use the same flattened shape but may load with .T \ No newline at end of file diff --git a/kernel-builder/skills/xpu-kernels/scripts/analyze_kernel.py b/kernel-builder/skills/xpu-kernels/scripts/analyze_kernel.py new file mode 100755 index 00000000..08ba5369 --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/analyze_kernel.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 +""" +Analyze PyTorch kernel to extract structure and guide Triton optimization. + +Usage: + python scripts/analyze_kernel.py +""" + +import ast +import re +import sys +from pathlib import Path +from typing import Dict, List, Set, Tuple + + +class KernelAnalyzer(ast.NodeVisitor): + """AST visitor to analyze PyTorch model operations.""" + + def __init__(self): + self.operations = [] + self.shapes = {} + self.dtypes = set() + self.has_matmul = False + self.has_linear = False + self.activations = [] + self.reductions = [] + self.elementwise = [] + + def visit_Call(self, node): + """Visit function calls to identify operations.""" + # torch.matmul + if isinstance(node.func, ast.Attribute): + if hasattr(node.func.value, "id") and node.func.value.id == "torch": + op_name = node.func.attr + self.operations.append(op_name) + + if op_name == "matmul": + self.has_matmul = True + elif op_name in ["sum", "mean", "max", "min"]: + self.reductions.append(op_name) + elif op_name in ["sigmoid", "tanh", "relu", "gelu", "silu"]: + self.activations.append(op_name) + elif op_name == "clamp": + self.elementwise.append("clamp") + + # torch.nn.functional + elif hasattr(node.func.value, "attr"): + if node.func.value.attr == "functional": + op_name = node.func.attr + self.operations.append(f"F.{op_name}") + if op_name in ["gelu", "relu", "silu", "softmax", "sigmoid"]: + self.activations.append(op_name) + + self.generic_visit(node) + + def visit_BinOp(self, node): + """Visit binary operations (*, /, +, -).""" + op_map = { + ast.Mult: "multiply", + ast.Div: "divide", + ast.Add: "add", + ast.Sub: "subtract", + } + op_type = type(node.op) + if op_type in op_map: + self.elementwise.append(op_map[op_type]) + self.generic_visit(node) + + def visit_Assign(self, node): + """Visit assignments to track nn.Linear.""" + if isinstance(node.value, ast.Call): + if hasattr(node.value.func, "attr") and node.value.func.attr == "Linear": + self.has_linear = True + self.generic_visit(node) + + +def analyze_pytorch_kernel(filepath: Path) -> Dict: + """Analyze PyTorch kernel file and extract optimization hints.""" + + with open(filepath, "r") as f: + source = f.read() + + tree = ast.parse(source) + analyzer = KernelAnalyzer() + analyzer.visit(tree) + + # Extract shape information from module-level variables + shapes = {} + for line in source.split("\n"): + if "=" in line and any( + dim in line + for dim in ["batch_size", "in_features", "out_features", "hidden_size", "input_size"] + ): + match = re.match(r"(\w+)\s*=\s*(\d+)", line.strip()) + if match: + shapes[match.group(1)] = int(match.group(2)) + + # Determine kernel type + kernel_type = "unknown" + if analyzer.has_matmul or analyzer.has_linear: + if analyzer.activations or analyzer.elementwise: + kernel_type = "gemm_epilogue" + elif analyzer.reductions: + kernel_type = "gemm_reduction" + else: + kernel_type = "gemm" + elif analyzer.reductions: + kernel_type = "reduction" + elif analyzer.elementwise: + kernel_type = "elementwise" + + # Fusion analysis + fusion_opportunities = [] + if analyzer.has_matmul or analyzer.has_linear: + if len(analyzer.activations) <= 2 and len(analyzer.elementwise) <= 3: + fusion_opportunities.append("Light epilogue fusion (GEMM + activation + elementwise)") + else: + fusion_opportunities.append("Heavy epilogue - consider partial fusion or split") + + if analyzer.reductions: + fusion_opportunities.append( + "WARNING: GEMM + reduction - use 2D GEMM then separate reduction kernel" + ) + + # Memory pattern recommendation + memory_pattern = "block_pointers" # default + if "Stream K" in str(analyzer.operations) or len(analyzer.reductions) > 1: + memory_pattern = "tensor_descriptors" + + return { + "kernel_type": kernel_type, + "operations": analyzer.operations, + "activations": analyzer.activations, + "reductions": analyzer.reductions, + "elementwise": analyzer.elementwise, + "shapes": shapes, + "fusion_opportunities": fusion_opportunities, + "memory_pattern": memory_pattern, + "has_gemm": analyzer.has_matmul or analyzer.has_linear, + } + + +def print_analysis(analysis: Dict, filepath: Path): + """Pretty print the analysis results.""" + + print(f"\n{'=' * 70}") + print(f"Analysis: {filepath.name}") + print(f"{'=' * 70}\n") + + print(f"Kernel Type: {analysis['kernel_type'].upper()}") + print(f"Memory Pattern: {analysis['memory_pattern']}") + print() + + if analysis["shapes"]: + print("Shapes:") + for key, val in analysis["shapes"].items(): + print(f" {key}: {val}") + print() + + print("Operations:") + print(f" Total: {len(analysis['operations'])}") + if analysis["has_gemm"]: + print(f" ✓ GEMM/Linear") + if analysis["activations"]: + print(f" ✓ Activations: {', '.join(set(analysis['activations']))}") + if analysis["reductions"]: + print(f" ✓ Reductions: {', '.join(set(analysis['reductions']))}") + if analysis["elementwise"]: + print(f" ✓ Elementwise: {', '.join(set(analysis['elementwise']))}") + print() + + if analysis["fusion_opportunities"]: + print("Fusion Opportunities:") + for opp in analysis["fusion_opportunities"]: + if "WARNING" in opp: + print(f" ⚠️ {opp}") + else: + print(f" → {opp}") + print() + + # Recommendations + print("Recommended Optimizations:") + + if analysis["has_gemm"]: + print(" 1. Use tensor descriptors (preferred on XPU) or block pointers") + print(" 2. Apply tile swizzling (GROUP_SIZE_M)") + + # Tile size recommendations based on shape + batch_size = analysis["shapes"].get("batch_size", 0) + if batch_size and batch_size < 256: + print(" 3. Use smaller BLOCK_M (32-64) for skinny M") + else: + print(" 3. Try large tiles (256x256) with autotune") + + print(" 4. Autotune: num_warps={4,8,16,32}, grf_mode='256'") + print(" 5. Mixed precision: bf16/fp16 inputs, fp32 accumulator") + print(" 6. Pre-pack weight transpose: weight_t = weight.t().contiguous()") + + if "sigmoid" in analysis["activations"]: + print(" → Use exp2-based sigmoid (faster on XPU)") + + if "tanh" in analysis["activations"]: + print(" → Implement tanh via sigmoid: tanh(x) = 2*sigmoid(2x) - 1") + + if "gelu" in analysis["activations"]: + print(" → Use tanh-approximation GeLU with JIT helper") + + if analysis["reductions"]: + if analysis["has_gemm"]: + print(" ⚠️ Split GEMM and reduction into separate kernels") + print(" (Don't serialize over N tiles inside one program)") + else: + print(" → Use multi-row tiling for reductions") + print(" → Query max_work_group_size for BLOCK_SIZE_Y") + + print() + + print("Relevant Reference Files:") + print(" • references/xpu_optimizations.yaml - Core XPU patterns") + if analysis["fusion_opportunities"]: + print(" • references/fusion_patterns.yaml - Fusion guidelines") + print(" • references/memory_patterns.yaml - Memory access patterns") + print(" • references/correctness.yaml - Critical constraints") + print() + + # Template suggestion + if analysis["kernel_type"] == "gemm": + print("Suggested Template: See GEMM pattern in references/implementation_reference.md") + elif analysis["kernel_type"] == "gemm_epilogue": + print("Suggested Template: See GEMM with epilogue pattern in references/implementation_reference.md") + elif analysis["kernel_type"] in ["reduction", "gemm_reduction"]: + print("Suggested Template: See reduction pattern in references/implementation_reference.md") + print() + + +def main(): + if len(sys.argv) != 2: + print("Usage: python scripts/analyze_kernel.py ") + sys.exit(1) + + filepath = Path(sys.argv[1]) + if not filepath.exists(): + print(f"Error: File not found: {filepath}") + sys.exit(1) + + analysis = analyze_pytorch_kernel(filepath) + print_analysis(analysis, filepath) + + +if __name__ == "__main__": + main() diff --git a/kernel-builder/skills/xpu-kernels/scripts/benchmark.py b/kernel-builder/skills/xpu-kernels/scripts/benchmark.py new file mode 100755 index 00000000..41e4351b --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/benchmark.py @@ -0,0 +1,485 @@ +#!/usr/bin/env python3 +""" +Benchmark Triton kernel against a baseline (PyTorch or Triton). + +Validates correctness and measures performance using ai-bench. + +Usage: + python scripts/benchmark.py [--spec ] [--device ] [--ci] [--triton-baseline] [--baseline-us 123.45] + +Examples: + python scripts/benchmark.py test_kernels/14_Gemm_Divide_Sum_Scaling_pytorch.py output/14_Gemm_Divide_Sum_Scaling_triton.py + python scripts/benchmark.py test_kernels/14_Gemm_Divide_Sum_Scaling_triton.py output/14_optimized_triton.py --triton-baseline + python scripts/benchmark.py baseline.py triton.py --baseline-us 123.45 # skip baseline perf, use cached value +""" + +import argparse +import importlib.util +import sys +from pathlib import Path +from types import SimpleNamespace + +PROJECT_ROOT = Path(__file__).resolve().parent.parent + + +def _load_module(filepath: Path, module_name: str): + """Dynamically load a Python module from file path.""" + spec = importlib.util.spec_from_file_location(module_name, filepath) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def _determine_spec_type(spec_file: Path, ci: bool): + """Determine which spec key to use for ai-bench.""" + from ai_bench.harness import core as ai_hc + + if ci: + return ai_hc.SpecKey.V_CI + + import yaml + + with open(spec_file) as f: + raw_spec = yaml.safe_load(f) + + # Prefer bench-xpu, fall back to bench-gpu, then ci + if "bench-xpu" in raw_spec: + spec_key = "bench-xpu" + elif str(ai_hc.SpecKey.V_BENCH_GPU) in raw_spec: + spec_key = str(ai_hc.SpecKey.V_BENCH_GPU) + else: + print(" No benchmark spec found (bench-xpu or bench-gpu). Falling back to CI mode.") + spec_key = str(ai_hc.SpecKey.V_CI) + + if spec_key == str(ai_hc.SpecKey.V_CI): + return ai_hc.SpecKey.V_CI + elif spec_key == str(ai_hc.SpecKey.V_BENCH_GPU): + return ai_hc.SpecKey.V_BENCH_GPU + else: + # bench-xpu — pass as string + return spec_key + + +# --------------------------------------------------------------------------- +# Correctness check via ai-bench +# --------------------------------------------------------------------------- +def run_correctness( + pytorch_file: Path, triton_file: Path, spec_file: Path | None, device_str: str +) -> bool: + """Validate numerical equivalence using ai-bench.""" + if spec_file and spec_file.exists(): + return _run_correctness_with_spec(pytorch_file, triton_file, spec_file, device_str) + else: + return _run_correctness_no_spec(pytorch_file, triton_file, device_str) + + +def _run_correctness_with_spec( + pytorch_file: Path, triton_file: Path, spec_file: Path, device_str: str +) -> bool: + """Correctness check using ai-bench spec infrastructure.""" + try: + import torch + from ai_bench.harness import core as ai_hc + from ai_bench.harness.runner.benchmark_compare import ( + check_correctness, + copy_model_weights, + set_all_seeds, + ) + from ai_bench.harness.runner.kernel_runner import KernelRunner + except ImportError as e: + print(f" Could not import ai_bench: {e}") + print(f" Install ai-bench: pip install -r scripts/requirements.txt") + return False + + device = torch.device(device_str) + + # Use the same spec type as performance (bench-xpu / bench-gpu), fall back to CI + spec_type = _determine_spec_type(spec_file, ci=False) + print(f" Using spec type: {spec_type}") + + runner = KernelRunner( + spec_type=spec_type, + device=device, + backend=ai_hc.Backend.PYTORCH, + ) + + # Load spec + spec = runner.load_spec(spec_file) + + if str(spec_type) not in spec: + print( + f" Spec type '{spec_type}' not in {spec_file.name}, falling back to no-spec correctness check" + ) + return _run_correctness_no_spec(pytorch_file, triton_file, device_str) + + variants = runner.get_spec_variants(spec) + + spec_inputs = runner.get_spec_inputs(spec) + spec_inits = runner.get_spec_inits(spec) + + # Load model classes + pytorch_model_cls = runner.load_model(pytorch_file) + triton_model_cls = runner.load_model(triton_file) + + if pytorch_model_cls is None: + print(f" Could not load PyTorch model from {pytorch_file}") + return False + if triton_model_cls is None: + print(f" Could not load Triton model from {triton_file}") + return False + + print(f" Found {len(variants)} variant(s)") + + all_correct = True + for i, variant in enumerate(variants): + set_all_seeds(123) + + # Log variant details + dims = variant.get("dims", {}) + dtype = variant.get("dtype", "unknown") + print(f" Variant {i}: dtype={dtype}, dims={dims}") + + rtol = ai_hc.get_rtol(variant) + atol = ai_hc.get_atol(variant) + has_explicit_tol = "atol" in variant or "rtol" in variant + # Only apply bf16 atol floor when no explicit spec tolerance was set + if not has_explicit_tol and atol < 1e-2: + atol = 1e-2 + + # Instantiate models with variant-specific init params (eval mode for deterministic BN) + pytorch_model = runner.init_model(pytorch_model_cls, variant, spec_inits).eval() + triton_model = runner.init_model(triton_model_cls, variant, spec_inits).eval() + + # Sync weights from reference to optimized + copy_model_weights(pytorch_model, triton_model) + + # Create inputs from spec + args = ai_hc.get_inputs(variant, spec_inputs, device=device) + + with torch.no_grad(): + pytorch_output = pytorch_model(*args) + triton_output = triton_model(*args) + + # Cast to common dtype (fp32) for comparison — models may use different output dtypes + if isinstance(pytorch_output, tuple): + pytorch_output = pytorch_output[0] + if isinstance(triton_output, tuple): + triton_output = triton_output[0] + pytorch_output = pytorch_output.float() + triton_output = triton_output.float() + + correct = check_correctness(pytorch_output, triton_output, rtol, atol) + status = "PASS" if correct else "FAIL" + print(f" Variant {i}: {status} (rtol={rtol:.1e}, atol={atol:.1e})") + + if not correct: + all_correct = False + + return all_correct + + +def _run_correctness_no_spec(pytorch_file: Path, triton_file: Path, device_str: str) -> bool: + """Fallback correctness check without spec file.""" + try: + import torch + from ai_bench.harness.runner.benchmark_compare import ( + check_correctness, + copy_model_weights, + set_all_seeds, + ) + except ImportError as e: + print(f" Could not import ai_bench: {e}") + print(f" Install ai-bench: pip install -r scripts/requirements.txt") + return False + + device = torch.device(device_str) + + # Default tolerances (bf16 accumulation typically needs atol >= 1e-2) + rtol, atol = 1e-2, 1e-2 + + # Load modules directly + pytorch_mod = _load_module(pytorch_file, "pytorch_ref") + triton_mod = _load_module(triton_file, "triton_kernel") + + set_all_seeds(123) + + init_inputs = pytorch_mod.get_init_inputs() + pytorch_model = pytorch_mod.Model(*init_inputs).to(device).eval() + triton_model = triton_mod.Model(*init_inputs).to(device).eval() + + # Sync weights + copy_model_weights(pytorch_model, triton_model) + + # Run both models + inputs = pytorch_mod.get_inputs() + inputs = [inp.to(device) if hasattr(inp, "to") else inp for inp in inputs] + + with torch.no_grad(): + pytorch_output = pytorch_model(*inputs) + triton_output = triton_model(*inputs) + + # Cast to common dtype (fp32) for comparison + if isinstance(pytorch_output, tuple): + pytorch_output = pytorch_output[0] + if isinstance(triton_output, tuple): + triton_output = triton_output[0] + pytorch_output = pytorch_output.float() + triton_output = triton_output.float() + + correct = check_correctness(pytorch_output, triton_output, rtol, atol) + print(f" Result: {'PASS' if correct else 'FAIL'} (rtol={rtol:.1e}, atol={atol:.1e})") + + return correct + + +# --------------------------------------------------------------------------- +# Performance benchmark via ai-bench +# --------------------------------------------------------------------------- +def find_spec_file(triton_file: Path) -> Path | None: + """Derive the YAML spec path from the same directory as the input file.""" + stem = triton_file.stem + for suffix in ("_triton", "_optimized", "_opt", "_pytorch"): + if stem.endswith(suffix): + stem = stem[: -len(suffix)] + break + + parent = triton_file.parent + for candidate_stem in (stem, stem + "_pytorch"): + candidate = parent / (candidate_stem + ".yaml") + if candidate.exists(): + return candidate + return None + + +def run_performance( + pytorch_file: Path, + triton_file: Path, + spec_file: Path, + device_str: str, + ci: bool, + triton_baseline: bool = False, + baseline_us: list[float] | None = None, +) -> bool: + """Benchmark baseline vs optimized Triton using ai-bench KernelRunner. + + If baseline_us is provided (list of floats, one per variant), the baseline + performance measurement is skipped and the cached values are used instead. + """ + try: + import ai_bench + import torch + from ai_bench.harness import core as ai_hc + from ai_bench.harness.runner.kernel_runner import KernelRunner + except ImportError as e: + print(f" Could not import ai_bench: {e}") + print(f" Install ai-bench: pip install -r scripts/requirements.txt") + return False + + device = torch.device(device_str) + spec_type = _determine_spec_type(spec_file, ci) + + baseline_backend = ai_hc.Backend.TRITON if triton_baseline else ai_hc.Backend.PYTORCH + baseline_label = "Triton baseline" if triton_baseline else "PyTorch baseline" + cached = baseline_us is not None + + # --- Run or cache baseline --- + if cached: + print(f"\n Using cached {baseline_label} times: {baseline_us}") + pytorch_stats = [SimpleNamespace(meas_us=val) for val in baseline_us] + else: + print(f"\n Running {baseline_label}...") + pytorch_runner = KernelRunner( + spec_type=spec_type, + device=device, + backend=baseline_backend, + ) + pytorch_stats = pytorch_runner.run_kernel_spec(pytorch_file, spec_file) + + if pytorch_stats is None: + print(f" {baseline_label}: spec type '{spec_type}' not found in {spec_file.name}") + return False + + # --- Run optimized Triton kernel --- + print(f" Running optimized Triton kernel...") + triton_runner = KernelRunner( + spec_type=spec_type, + device=device, + backend=ai_hc.Backend.TRITON, + ) + triton_stats = triton_runner.run_kernel_spec(triton_file, spec_file) + + if triton_stats is None: + print(f" Triton kernel: spec type '{spec_type}' not found in {spec_file.name}") + return False + + # Validate cached baseline count matches triton variant count + if cached and len(pytorch_stats) != len(triton_stats): + print( + f" Error: --baseline-us has {len(pytorch_stats)} value(s) but spec has " + f"{len(triton_stats)} variant(s). Re-run without --baseline-us." + ) + return False + + # --- CI mode: just validate both run without error --- + if ci or spec_type == ai_hc.SpecKey.V_CI: + print( + f" CI validation: PyTorch {'OK' if pytorch_stats is not None else 'FAILED'}, " + f"Triton {'OK' if triton_stats is not None else 'FAILED'}" + ) + return pytorch_stats is not None and triton_stats is not None + + # --- Report benchmark results --- + if not pytorch_stats or not triton_stats: + print(" No benchmark stats available.") + return False + + if triton_baseline: + baseline_col = "Triton BL (us)*" if cached else "Triton BL (us)" + else: + baseline_col = "PyTorch (us)*" if cached else "PyTorch (us)" + print(f"\n {'Variant':<8} {baseline_col:>16} {'Triton (us)':>14} {'Speedup':>10}") + print(f" {'-' * 52}") + + all_faster = True + for i, (pt_stat, tr_stat) in enumerate(zip(pytorch_stats, triton_stats)): + speedup = pt_stat.meas_us / tr_stat.meas_us if tr_stat.meas_us > 0 else 0 + marker = "+" if speedup >= 1.0 else "-" + if speedup < 1.0: + all_faster = False + print( + f" {i:<8} {pt_stat.meas_us:>16.2f} {tr_stat.meas_us:>14.2f} {speedup:>9.2f}x {marker}" + ) + + if cached: + print(f"\n * baseline cached from prior trial") + + return all_faster + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- +def main(): + parser = argparse.ArgumentParser( + description="Benchmark Triton kernel against a baseline (PyTorch or Triton)" + ) + parser.add_argument( + "pytorch_file", + type=Path, + help="Baseline file (PyTorch by default, or Triton with --triton-baseline)", + ) + parser.add_argument("triton_file", type=Path, help="Optimized Triton kernel implementation") + parser.add_argument( + "--spec", + type=Path, + default=None, + help="YAML spec file (auto-detected if omitted)", + ) + parser.add_argument( + "--device", + default="xpu", + help="Target device (default: xpu, always falls back to xpu)", + ) + parser.add_argument( + "--ci", + action="store_true", + help="CI mode: quick validation only, no benchmarking", + ) + parser.add_argument( + "--triton-baseline", + action="store_true", + help="Baseline file is a Triton kernel (use Backend.TRITON)", + ) + parser.add_argument( + "--baseline-us", + type=str, + default=None, + help="Comma-separated cached baseline time(s) in microseconds (skips baseline perf run)", + ) + args = parser.parse_args() + + # Parse --baseline-us into list of floats + baseline_us = None + if args.baseline_us is not None: + try: + baseline_us = [float(v.strip()) for v in args.baseline_us.split(",")] + except ValueError: + print(f"Error: --baseline-us must be comma-separated floats, got: {args.baseline_us}") + sys.exit(1) + + if not args.pytorch_file.exists(): + print(f"Error: Baseline file not found: {args.pytorch_file}") + sys.exit(1) + if not args.triton_file.exists(): + print(f"Error: Triton file not found: {args.triton_file}") + sys.exit(1) + + # Always fall back to XPU — this project only targets Intel XPU + if args.device != "xpu": + print( + f"Warning: device '{args.device}' requested, falling back to 'xpu' (only XPU is supported)" + ) + args.device = "xpu" + + spec_file = args.spec or find_spec_file(args.triton_file) or find_spec_file(args.pytorch_file) + + baseline_label = "Triton baseline" if args.triton_baseline else "PyTorch baseline" + + print(f"\n{'=' * 70}") + print(f"Benchmark Configuration") + print(f"{'=' * 70}") + print(f"{baseline_label}: {args.pytorch_file}") + print(f"Triton kernel: {args.triton_file}") + print(f"Spec file: {spec_file or '(none)'}") + print(f"Device: {args.device}") + print(f"Mode: {'CI' if args.ci else 'Benchmark'}") + + # --- Correctness --- + print(f"\n{'=' * 70}") + print(f"Correctness Check (ai-bench)") + print(f"{'=' * 70}") + correctness_passed = run_correctness( + args.pytorch_file, args.triton_file, spec_file, args.device + ) + print(f"\n Result: {'PASSED' if correctness_passed else 'FAILED'}") + + # --- Performance --- + performance_passed = None + if spec_file and spec_file.exists(): + print(f"\n{'=' * 70}") + print(f"Performance Benchmark (ai-bench)") + print(f"{'=' * 70}") + performance_passed = run_performance( + args.pytorch_file, + args.triton_file, + spec_file, + args.device, + args.ci, + triton_baseline=args.triton_baseline, + baseline_us=baseline_us, + ) + print(f"\n Result: {'PASSED' if performance_passed else 'FAILED'}") + else: + print(f"\n Skipping performance benchmark (no spec file found)") + + # --- Summary --- + print(f"\n{'=' * 70}") + print(f"Summary") + print(f"{'=' * 70}") + print(f"Correctness: {'PASSED' if correctness_passed else 'FAILED'}") + if performance_passed is not None: + print(f"Performance: {'PASSED' if performance_passed else 'FAILED'}") + else: + print(f"Performance: SKIPPED (no spec file)") + print() + + if correctness_passed and (performance_passed is not False): + print("All checks passed!") + sys.exit(0) + else: + print("Some checks FAILED - see output above for details") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/kernel-builder/skills/xpu-kernels/scripts/benchmark_kernels.py b/kernel-builder/skills/xpu-kernels/scripts/benchmark_kernels.py new file mode 100644 index 00000000..66541f50 --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/benchmark_kernels.py @@ -0,0 +1,534 @@ +#!/usr/bin/env python3 +""" +Micro-benchmark for all 4 Triton kernels on XPU: RMSNorm, RoPE 3D, GEGLU, AdaLN. + +Measures: + 1. Correctness vs PyTorch reference + 2. Latency (custom vs baseline, warmup + averaged) + 3. Memory bandwidth utilization + +Usage: + python benchmark_kernels.py + python benchmark_kernels.py --kernel rmsnorm + python benchmark_kernels.py --kernel rope + python benchmark_kernels.py --kernel geglu + python benchmark_kernels.py --kernel adaln + python benchmark_kernels.py --dtype float16 + +Requirements: + python -m pip install -r scripts/requirements.txt +""" +import argparse +import time +from typing import Tuple + +import torch +import triton +import triton.language as tl + + +# ============================================================================ +# Kernel 1: RMSNorm +# ============================================================================ +# CRITICAL: BLOCK_D must be >= D (hidden dimension). +# Using autotune with fixed BLOCK_D configs is WRONG because autotune may +# pick BLOCK_D < D, causing only partial row processing. +# Fix: compute BLOCK_D = next_power_of_2(D) dynamically in the Python wrapper. + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, + eps, + HAS_WEIGHT: tl.constexpr, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + col_offsets = tl.arange(0, BLOCK_D) + mask = col_offsets < D + + x = tl.load(x_ptr + row * stride_x + col_offsets, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + if HAS_WEIGHT: + w = tl.load(weight_ptr + col_offsets, mask=mask, other=1.0).to(tl.float32) + result = x * rms_inv * w + else: + result = x * rms_inv + + tl.store(out_ptr + row * stride_x + col_offsets, result.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight=None, eps=1e-6): + orig_shape = x.shape + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + has_weight = weight is not None + if not has_weight: + weight = torch.empty(0, device=x.device) + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, + x_2d.stride(0), D, float(eps), has_weight, + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view(orig_shape) + + +def pytorch_rmsnorm(x, weight=None, eps=1e-6): + variance = x.pow(2).mean(dim=-1, keepdim=True) + out = x * torch.rsqrt(variance + eps) + if weight is not None: + out = out * weight + return out + + +# ============================================================================ +# Kernel 2: RoPE 3D +# ============================================================================ +# CRITICAL: cos/sin have shape [seq_len, head_dim], NOT [batch*seq_len, ...]. +# When grid is (batch * seq_len, num_heads), we must use pid_s % seq_len +# to index into cos/sin to avoid out-of-bounds access for batch > 1. + +@triton.jit +def rope_3d_fwd_kernel( + qk_ptr, cos_ptr, sin_ptr, out_ptr, + seq_len, num_heads, head_dim, + stride_s, stride_h, stride_d, + BLOCK_HD: tl.constexpr, +): + pid_s = tl.program_id(0) + pid_h = tl.program_id(1) + half_dim = head_dim // 2 + offs = tl.arange(0, BLOCK_HD) + mask = offs < half_dim + + base = pid_s * stride_s + pid_h * stride_h + x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32) + x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32) + + seq_idx = pid_s % seq_len + cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32) + sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32) + + out0 = x0 * cos_val - x1 * sin_val + out1 = x0 * sin_val + x1 * cos_val + + tl.store(out_ptr + base + offs, out0.to(x0.dtype), mask=mask) + tl.store(out_ptr + base + half_dim + offs, out1.to(x0.dtype), mask=mask) + + +def triton_rope_3d(qk, cos, sin): + qk = qk.contiguous() + out = torch.empty_like(qk) + batch, seq_len, num_heads, head_dim = qk.shape + half_dim = head_dim // 2 + qk_flat = qk.view(batch * seq_len, num_heads, head_dim) + out_flat = out.view(batch * seq_len, num_heads, head_dim) + grid = (batch * seq_len, num_heads) + BLOCK_HD = triton.next_power_of_2(half_dim) + num_warps = 4 if BLOCK_HD <= 64 else 8 + rope_3d_fwd_kernel[grid]( + qk_flat, cos, sin, out_flat, + seq_len, num_heads, head_dim, + qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2), + BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2, + ) + return out + + +def pytorch_rope(qk, cos, sin): + half = qk.shape[-1] // 2 + x0, x1 = qk[..., :half], qk[..., half:] + cos_exp = cos.unsqueeze(0).unsqueeze(2)[:, :qk.shape[1], :, :half] + sin_exp = sin.unsqueeze(0).unsqueeze(2)[:, :qk.shape[1], :, :half] + out0 = x0 * cos_exp - x1 * sin_exp + out1 = x0 * sin_exp + x1 * cos_exp + return torch.cat([out0, out1], dim=-1) + + +# ============================================================================ +# Kernel 3: GEGLU +# ============================================================================ +# Same BLOCK_SIZE fix as RMSNorm: compute dynamically, do NOT autotune. + +@triton.jit +def geglu_fwd_kernel( + input_ptr, output_ptr, + stride_in, stride_out, hidden_size, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_H) + mask = offs < hidden_size + + gate = tl.load(input_ptr + row * stride_in + offs, mask=mask, other=0.0).to(tl.float32) + value = tl.load(input_ptr + row * stride_in + hidden_size + offs, mask=mask, other=0.0).to(tl.float32) + + # GELU approx: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + # Manual tanh for portability + SQRT_2_OVER_PI = 0.7978845608028654 + tanh_arg = SQRT_2_OVER_PI * (gate + 0.044715 * gate * gate * gate) + e2x = tl.exp(2.0 * tanh_arg) + tanh_val = (e2x - 1.0) / (e2x + 1.0) + cdf = 0.5 * (1.0 + tanh_val) + gelu_gate = gate * cdf + result = gelu_gate * value + + tl.store(output_ptr + row * stride_out + offs, result.to(gate.dtype), mask=mask) + + +def triton_geglu(x): + x = x.contiguous() + *batch_dims, double_h = x.shape + hidden_size = double_h // 2 + x_2d = x.view(-1, double_h) + M = x_2d.shape[0] + out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype) + + BLOCK_H = triton.next_power_of_2(hidden_size) + num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16) + geglu_fwd_kernel[(M,)]( + x_2d, out, + x_2d.stride(0), out.stride(0), hidden_size, + BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2, + ) + return out.view(*batch_dims, hidden_size) + + +def pytorch_geglu(x): + hidden_size = x.shape[-1] // 2 + gate, value = x[..., :hidden_size], x[..., hidden_size:] + return torch.nn.functional.gelu(gate, approximate='tanh') * value + + +# ============================================================================ +# Kernel 4: AdaLN +# ============================================================================ +# Same BLOCK_D fix: compute dynamically. + +@triton.jit +def adaln_fwd_kernel( + x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr, + stride_x, stride_cond, D, + eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + x_norm = x * rms_inv + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32) + + out = x_norm * w * (1.0 + scale) + shift + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_adaln(x, weight, scale, shift, eps=1e-6): + x_flat = x.contiguous().view(-1, x.shape[-1]) + scale_flat = scale.contiguous().view(-1, x.shape[-1]) + shift_flat = shift.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_flat) + M, D = x_flat.shape + + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + adaln_fwd_kernel[(M,)]( + x_flat, weight, scale_flat, shift_flat, out, + x_flat.stride(0), scale_flat.stride(0), D, float(eps), + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +def pytorch_adaln(x, weight, scale, shift, eps=1e-6): + variance = x.pow(2).mean(dim=-1, keepdim=True) + x_norm = x * torch.rsqrt(variance + eps) + return x_norm * weight * (1.0 + scale) + shift + + +# ============================================================================ +# Benchmark Utilities +# ============================================================================ + +def benchmark_fn(func, args, warmup=20, iterations=100) -> Tuple[float, float]: + for _ in range(warmup): + func(*args) + torch.xpu.synchronize() + + times = [] + for _ in range(iterations): + torch.xpu.synchronize() + start = time.perf_counter() + func(*args) + torch.xpu.synchronize() + end = time.perf_counter() + times.append((end - start) * 1000) + + return sum(times) / len(times), min(times) + + +def check_correctness(out, ref, name, dtype): + max_abs = (out.float() - ref.float()).abs().max().item() + max_rel = ((out.float() - ref.float()).abs() / (ref.float().abs() + 1e-8)).max().item() + + # BF16 has 7-bit mantissa; for values ~8-16 the ULP is 0.0625-0.125 + # FP16 has 10-bit mantissa; tighter but RoPE trig ops can accumulate 1-2 ULP error + atol = 0.15 if dtype == torch.bfloat16 else 0.02 + passed = max_abs < atol + status = "PASS" if passed else "FAIL" + print(f" [{status}] {name}: max_abs={max_abs:.6e}, max_rel={max_rel:.6e}") + return passed + + +# ============================================================================ +# Benchmark Runners +# ============================================================================ + +def benchmark_rmsnorm(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: RMSNorm (168 instances in LTX-Video)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 2048), + (4, 1024, 2048), + (1, 4096, 2048), + (2, 4096, 3072), + (1, 8192, 2048), + (4, 4096, 3072), + ] + + print(f"\n{'Config':<25} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 70) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden, dtype=dtype, device="xpu") + w = torch.ones(hidden, dtype=dtype, device="xpu") + + ref = pytorch_rmsnorm(x, w) + out = triton_rmsnorm(x, w) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_rmsnorm, (x, w)) + p_avg, _ = benchmark_fn(pytorch_rmsnorm, (x, w)) + speedup = p_avg / t_avg + total_speedup += speedup + + print(f" [{batch}x{seq}x{hidden}]{'':<13} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + # No-weight variant + print("\n -- No-weight variant (elementwise_affine=False) --") + x = torch.randn(2, 4096, 2048, dtype=dtype, device="xpu") + ref_nw = pytorch_rmsnorm(x, None) + out_nw = triton_rmsnorm(x, None) + check_correctness(out_nw, ref_nw, "no-weight [2x4096x2048]", dtype) + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + + # Bandwidth analysis + batch, seq, hidden = 4, 4096, 3072 + x = torch.randn(batch, seq, hidden, dtype=dtype, device="xpu") + w = torch.ones(hidden, dtype=dtype, device="xpu") + bytes_per_elem = 2 if dtype in (torch.float16, torch.bfloat16) else 4 + total_bytes = batch * seq * hidden * bytes_per_elem * 2 + hidden * bytes_per_elem + t_avg, _ = benchmark_fn(triton_rmsnorm, (x, w)) + bw_gbps = (total_bytes / 1e9) / (t_avg / 1000) + print(f"\n Bandwidth analysis [{batch}x{seq}x{hidden}]:") + print(f" Data moved: {total_bytes / 1e6:.2f} MB") + print(f" Achieved: {bw_gbps:.1f} GB/s") + + return all_correct, avg_speedup + + +def benchmark_rope(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: RoPE 3D (Video Position Encoding)") + print("=" * 70) + + configs = [ + (1, 1024, 16, 64), + (1, 4096, 16, 64), + (2, 4096, 16, 128), + (1, 8192, 32, 64), + ] + + print(f"\n{'Config':<30} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 75) + + all_correct = True + total_speedup = 0 + + for batch, seq, heads, hdim in configs: + qk = torch.randn(batch, seq, heads, hdim, dtype=dtype, device="xpu") + cos = torch.randn(seq, hdim, dtype=dtype, device="xpu") + sin = torch.randn(seq, hdim, dtype=dtype, device="xpu") + + ref = pytorch_rope(qk, cos, sin) + out = triton_rope_3d(qk, cos, sin) + if not check_correctness(out, ref, f"[{batch}x{seq}x{heads}x{hdim}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_rope_3d, (qk, cos, sin)) + p_avg, _ = benchmark_fn(pytorch_rope, (qk, cos, sin)) + speedup = p_avg / t_avg + total_speedup += speedup + + cfg = f"[{batch}x{seq}x{heads}x{hdim}]" + print(f" {cfg:<28} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +def benchmark_geglu(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: GEGLU (For SD3/FLUX, NOT LTX-Video)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 4096), + (2, 4096, 3072), + (4, 4096, 4096), + ] + + print(f"\n{'Config':<30} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 75) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden * 2, dtype=dtype, device="xpu") + + ref = pytorch_geglu(x) + out = triton_geglu(x) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden*2}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_geglu, (x,)) + p_avg, _ = benchmark_fn(pytorch_geglu, (x,)) + speedup = p_avg / t_avg + total_speedup += speedup + + cfg = f"[{batch}x{seq}x{hidden*2}->{hidden}]" + print(f" {cfg:<28} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +def benchmark_adaln(dtype): + print("\n" + "=" * 70) + print("BENCHMARK: AdaLN (Fused Norm + Conditioning for DiT)") + print("=" * 70) + + configs = [ + (1, 1024, 2048), + (2, 1024, 2048), + (2, 4096, 3072), + (4, 4096, 3072), + ] + + print(f"\n{'Config':<25} {'Triton (ms)':<15} {'PyTorch (ms)':<15} {'Speedup':<10}") + print("-" * 70) + + all_correct = True + total_speedup = 0 + + for batch, seq, hidden in configs: + x = torch.randn(batch, seq, hidden, dtype=dtype, device="xpu") + w = torch.ones(hidden, dtype=dtype, device="xpu") + scale = torch.randn(batch, seq, hidden, dtype=dtype, device="xpu") * 0.1 + shift = torch.randn(batch, seq, hidden, dtype=dtype, device="xpu") * 0.1 + + ref = pytorch_adaln(x, w, scale, shift) + out = triton_adaln(x, w, scale, shift) + if not check_correctness(out, ref, f"[{batch}x{seq}x{hidden}]", dtype): + all_correct = False + + t_avg, _ = benchmark_fn(triton_adaln, (x, w, scale, shift)) + p_avg, _ = benchmark_fn(pytorch_adaln, (x, w, scale, shift)) + speedup = p_avg / t_avg + total_speedup += speedup + + print(f" [{batch}x{seq}x{hidden}]{'':<13} {t_avg:>10.3f} {p_avg:>10.3f} {speedup:>7.2f}x") + + avg_speedup = total_speedup / len(configs) + print(f"\n Average speedup: {avg_speedup:.2f}x") + return all_correct, avg_speedup + + +# ============================================================================ +# Main +# ============================================================================ + +def main(): + parser = argparse.ArgumentParser(description="Benchmark Triton kernels on XPU") + parser.add_argument("--kernel", type=str, default="all", + choices=["all", "rmsnorm", "rope", "geglu", "adaln"]) + parser.add_argument("--dtype", type=str, default="bfloat16", + choices=["bfloat16", "float16"]) + args = parser.parse_args() + + dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 + + print("=" * 70) + print("XPU Triton Kernel Micro-Benchmark") + print("=" * 70) + print(f"Device: {torch.xpu.get_device_name(0)}") + print(f"Dtype: {dtype}") + + results = {} + runners = { + "rmsnorm": benchmark_rmsnorm, + "rope": benchmark_rope, + "geglu": benchmark_geglu, + "adaln": benchmark_adaln, + } + + if args.kernel == "all": + for name, runner in runners.items(): + correct, speedup = runner(dtype) + results[name] = {"correct": correct, "speedup": speedup} + else: + correct, speedup = runners[args.kernel](dtype) + results[args.kernel] = {"correct": correct, "speedup": speedup} + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + print(f"{'Kernel':<15} {'Correct':<12} {'Avg Speedup':<15}") + print("-" * 42) + for name, r in results.items(): + status = "PASS" if r["correct"] else "FAIL" + print(f"{name:<15} {status:<12} {r['speedup']:.2f}x") + + all_pass = all(r["correct"] for r in results.values()) + print(f"\nOverall: {'ALL PASS' if all_pass else 'SOME FAILED'}") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/kernel-builder/skills/xpu-kernels/scripts/config.py b/kernel-builder/skills/xpu-kernels/scripts/config.py new file mode 100644 index 00000000..08c23ee4 --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/config.py @@ -0,0 +1,22 @@ +"""Shared config loader — reads config.yaml from project root.""" + +from pathlib import Path + +import yaml + +_CONFIG_DIR = Path(__file__).resolve().parent +_DEFAULTS = { + "max_trials": 10, + "vtune_enabled": True, + "vtune_bin": "/bin64/vtune", +} + + +def load_config() -> dict: + """Load config.yaml, falling back to defaults for missing keys.""" + config_path = _CONFIG_DIR / "config.yaml" + cfg = {} + if config_path.exists(): + with open(config_path) as f: + cfg = yaml.safe_load(f) or {} + return {**_DEFAULTS, **cfg} diff --git a/kernel-builder/skills/xpu-kernels/scripts/config.yaml b/kernel-builder/skills/xpu-kernels/scripts/config.yaml new file mode 100644 index 00000000..ae8e087d --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/config.yaml @@ -0,0 +1,4 @@ +# Project configuration — edit these values to control optimization sessions. +max_trials: 10 # Maximum number of optimization trials (3-20) +vtune_enabled: true # Set to false to skip VTune profiling entirely +vtune_bin: "/bin64/vtune" # Path to VTune binary diff --git a/kernel-builder/skills/xpu-kernels/scripts/huggingface_kernels_example.py b/kernel-builder/skills/xpu-kernels/scripts/huggingface_kernels_example.py new file mode 100644 index 00000000..4315aefd --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/huggingface_kernels_example.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +""" +Example: Using HuggingFace Kernels library to load and use optimized kernels on XPU. + +This script demonstrates how to: +1. Load kernels from the HuggingFace Hub using get_kernel() +2. Check kernel availability with has_kernel() +3. Integrate Hub kernels with transformers/diffusers models +4. Fall back to local Triton kernels when Hub builds are unavailable + +Requirements: + python -m pip install -r scripts/requirements.txt + +Usage: + python scripts/huggingface_kernels_example.py +""" + +import os +import time +from typing import Optional + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +# ============================================================================= +# Local Triton RMSNorm (fallback when Hub kernel unavailable) +# ============================================================================= + +EPS_DEFAULT = 1e-6 + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def local_triton_rmsnorm(x, weight, eps=EPS_DEFAULT): + """Local Triton RMSNorm — used as fallback when Hub kernel is unavailable.""" + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, x_2d.stride(0), D, float(eps), + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +# ============================================================================= +# Part 1: Check Environment +# ============================================================================= + +def check_environment(): + """Print environment information for debugging.""" + print("=" * 60) + print("Environment") + print("=" * 60) + print(f"PyTorch: {torch.__version__}") + print(f"XPU available: {torch.xpu.is_available()}") + if torch.xpu.is_available(): + print(f"GPU: {torch.xpu.get_device_name()}") + print() + + +# ============================================================================= +# Part 2: Basic Kernel Loading from Hub +# ============================================================================= + +def demo_basic_kernel_loading(): + """Demonstrate basic kernel loading from Hub.""" + print("=" * 60) + print("Part 1: Basic Kernel Loading from Hub") + print("=" * 60) + + try: + from kernels import get_kernel, has_kernel + + repo_id = "kernels-community/triton-layer-norm" + + print(f"\n1. Checking kernel availability: {repo_id}") + if has_kernel(repo_id): + print(" Kernel is available for this XPU environment") + + print(f"\n2. Loading kernel from Hub...") + kernel = get_kernel(repo_id) + + print(f"\n3. Available functions:") + functions = [f for f in dir(kernel) if not f.startswith('_')] + for func in functions[:10]: + print(f" - {func}") + + print(f"\n4. Testing RMSNorm kernel...") + x = torch.randn(2, 1024, 2048, dtype=torch.bfloat16, device="xpu") + w = torch.ones(2048, dtype=torch.bfloat16, device="xpu") + + rms_fn_name = None + for name in ('rms_norm', 'rms_norm_fn', 'rmsnorm'): + if hasattr(kernel, name): + rms_fn_name = name + break + + if rms_fn_name: + rms_fn = getattr(kernel, rms_fn_name) + try: + out = rms_fn(x, w, eps=1e-6) + except TypeError: + # rms_norm_fn(x, weight, bias, ...) requires bias argument + out = rms_fn(x, w, None, eps=1e-6) + print(f" Using: kernel.{rms_fn_name}()") + print(f" Input: {x.shape}, Output: {out.shape}") + print(f" Success!") + else: + print(f" No RMSNorm function found. Available: {functions}") + + return kernel + else: + print(" No compatible build for this XPU environment") + print(" Will use local Triton kernel as fallback") + return None + + except ImportError: + print("\n kernels library not installed. Install with: pip install kernels") + return None + except Exception as e: + print(f"\n Error: {e}") + return None + + +# ============================================================================= +# Part 3: Benchmark Hub Kernel vs Local Triton vs PyTorch +# ============================================================================= + +def demo_benchmark(hub_kernel): + """Benchmark Hub kernel vs local Triton vs PyTorch.""" + print("\n" + "=" * 60) + print("Part 2: Benchmark Hub vs Local Triton vs PyTorch") + print("=" * 60) + + shapes = [(2, 1024, 2048), (4, 4096, 3072)] + warmup, iterations = 20, 100 + + for shape in shapes: + x = torch.randn(shape, dtype=torch.bfloat16, device="xpu") + w = torch.ones(shape[-1], dtype=torch.bfloat16, device="xpu") + + def _call_hub(fn, x, w, eps): + try: + return fn(x, w, eps=eps) + except TypeError: + return fn(x, w, None, eps=eps) + + hub_rms_fn_raw = None + if hub_kernel: + for fn_name in ('rms_norm', 'rms_norm_fn', 'rmsnorm'): + if hasattr(hub_kernel, fn_name): + hub_rms_fn_raw = getattr(hub_kernel, fn_name) + break + + # Warmup all implementations + for _ in range(warmup): + local_triton_rmsnorm(x, w, eps=1e-6) + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + if hub_rms_fn_raw: + _call_hub(hub_rms_fn_raw, x, w, 1e-6) + torch.xpu.synchronize() + + # PyTorch baseline + start = time.perf_counter() + for _ in range(iterations): + variance = x.pow(2).mean(-1, keepdim=True) + _ = x * torch.rsqrt(variance + 1e-6) * w + torch.xpu.synchronize() + pt_ms = (time.perf_counter() - start) / iterations * 1000 + + # Local Triton + start = time.perf_counter() + for _ in range(iterations): + local_triton_rmsnorm(x, w, eps=1e-6) + torch.xpu.synchronize() + local_ms = (time.perf_counter() - start) / iterations * 1000 + + print(f"\n Shape {shape}:") + print(f" PyTorch: {pt_ms:.4f} ms") + print(f" Local Triton: {local_ms:.4f} ms (speedup: {pt_ms/local_ms:.2f}x)") + + if hub_rms_fn_raw: + start = time.perf_counter() + for _ in range(iterations): + _call_hub(hub_rms_fn_raw, x, w, 1e-6) + torch.xpu.synchronize() + hub_ms = (time.perf_counter() - start) / iterations * 1000 + print(f" Hub kernel: {hub_ms:.4f} ms (speedup: {pt_ms/hub_ms:.2f}x)") + + +# ============================================================================= +# Part 4: Model Integration with Fallback +# ============================================================================= + +def demo_model_integration(hub_kernel): + """Demonstrate integrating kernels with models, with fallback.""" + print("\n" + "=" * 60) + print("Part 3: Model Integration with Fallback") + print("=" * 60) + + class SimpleModel(nn.Module): + def __init__(self, hidden_size=2048): + super().__init__() + self.norm = nn.RMSNorm(hidden_size) + self.linear = nn.Linear(hidden_size, hidden_size) + + def forward(self, x): + return self.linear(self.norm(x)) + + model = SimpleModel().to("xpu").to(torch.bfloat16) + + # Decide which RMSNorm to use + hub_rms_fn = None + if hub_kernel: + for fn_name in ('rms_norm', 'rms_norm_fn', 'rmsnorm'): + if hasattr(hub_kernel, fn_name): + hub_rms_fn = getattr(hub_kernel, fn_name) + break + + if hub_rms_fn: + def _hub_rmsnorm(x, w, eps): + try: + return hub_rms_fn(x, w, eps=eps) + except TypeError: + return hub_rms_fn(x, w, None, eps=eps) + rmsnorm_fn = _hub_rmsnorm + source = "Hub kernel" + else: + rmsnorm_fn = local_triton_rmsnorm + source = "Local Triton" + + print(f"\n1. Using {source} for RMSNorm") + + # Patch model + for name, module in model.named_modules(): + if isinstance(module, nn.RMSNorm): + raw_eps = getattr(module, 'eps', None) + eps = float(raw_eps) if raw_eps is not None else 1e-6 + + def make_forward(mod, epsilon, fn): + def forward(x): + return fn(x, mod.weight, epsilon) + return forward + + module.forward = make_forward(module, eps, rmsnorm_fn) + print(f" Patched: {name} (eps={eps})") + + # Test + print(f"\n2. Testing forward pass...") + x = torch.randn(2, 1024, 2048, dtype=torch.bfloat16, device="xpu") + with torch.inference_mode(): + y = model(x) + print(f" Input: {x.shape} -> Output: {y.shape}") + print(f" Success!") + + +# ============================================================================= +# Part 5: Publishing Info +# ============================================================================= + +def demo_publishing_info(): + """Show information about publishing kernels to Hub.""" + print("\n" + "=" * 60) + print("Part 4: Publishing Triton Kernels to Hub") + print("=" * 60) + + print(""" + For Triton kernels (best XPU compatibility): + + 1. Create project structure: + my-triton-kernel/ + ├── build.toml + ├── kernel_src/ + │ └── rmsnorm.py # Triton kernel + └── torch-ext/ + ├── torch_binding.cpp + └── my_kernels/__init__.py + + 2. Configure build.toml with XPU support: + [general] + name = "my_kernels" + backends = ["cuda", "xpu"] + + 3. Build and publish: + $ pip install kernel-builder + $ kernel-builder build + $ huggingface-cli upload my-username/my-kernel ./dist + + See: https://huggingface.co/docs/kernels + """) + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + print("=" * 60) + print("HuggingFace Kernels Integration Example (XPU)") + print("=" * 60) + + check_environment() + + if not torch.xpu.is_available(): + print("GPU not available. This example requires an Intel GPU with XPU support.") + return + + hub_kernel = demo_basic_kernel_loading() + demo_benchmark(hub_kernel) + demo_model_integration(hub_kernel) + demo_publishing_info() + + print("\n" + "=" * 60) + print("Done!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/kernel-builder/skills/xpu-kernels/scripts/requirements.txt b/kernel-builder/skills/xpu-kernels/scripts/requirements.txt new file mode 100644 index 00000000..05aba8cf --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/requirements.txt @@ -0,0 +1,5 @@ +transformers +safetensors +huggingface-hub +kernels +ai-bench[xpu] @ git+https://github.com/libxsmm/AI-bench.git diff --git a/kernel-builder/skills/xpu-kernels/scripts/transformers_injection_example.py b/kernel-builder/skills/xpu-kernels/scripts/transformers_injection_example.py new file mode 100644 index 00000000..6492e7ce --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/transformers_injection_example.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +""" +Minimal example: Inject custom Triton kernels into HuggingFace Transformers models on XPU. + +This script demonstrates the essential pattern for integrating custom Triton kernels +with transformers models like LLaMA, Mistral, and Qwen on Intel XPU GPUs. + +Key lessons: +1. Transformers RMSNorm modules always have weights (unlike some diffusers modules) +2. Use 'RMSNorm' substring match to catch LlamaRMSNorm, MistralRMSNorm, etc. +3. Check for 'variance_epsilon' (LLaMA) or 'eps' (others) for epsilon value +4. Use Flash Attention 2 for attention optimization instead of custom processors + +Usage: + python scripts/transformers_injection_example.py + +Requirements: + python -m pip install -r scripts/requirements.txt +""" + +import sys +import time + +import torch +import torch.nn as nn +import triton +import triton.language as tl + + +# ============================================================================= +# Triton RMSNorm Kernel +# ============================================================================= + +@triton.jit +def rmsnorm_fwd_kernel( + x_ptr, weight_ptr, out_ptr, + stride_x, D, eps, + BLOCK_D: tl.constexpr, +): + row = tl.program_id(0) + offs = tl.arange(0, BLOCK_D) + mask = offs < D + x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32) + + variance = tl.sum(x * x, axis=0) / D + rms_inv = tl.rsqrt(variance + eps) + + w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32) + out = x * rms_inv * w + + tl.store(out_ptr + row * stride_x + offs, out.to(x.dtype), mask=mask) + + +def triton_rmsnorm(x, weight, eps=1e-6): + x_2d = x.contiguous().view(-1, x.shape[-1]) + out = torch.empty_like(x_2d) + M, D = x_2d.shape + BLOCK_D = triton.next_power_of_2(D) + num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16) + rmsnorm_fwd_kernel[(M,)]( + x_2d, weight, out, x_2d.stride(0), D, float(eps), + BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2, + ) + return out.view_as(x) + + +# ============================================================================= +# RMSNorm Module Patcher +# ============================================================================= + +def patch_rmsnorm_modules(model: nn.Module) -> int: + """ + Patch all RMSNorm modules to use Triton kernel on XPU. + + Works with LlamaRMSNorm, MistralRMSNorm, Qwen2RMSNorm, etc. + Unlike diffusers, transformers RMSNorm always has weights. + """ + patched_count = 0 + + for name, module in model.named_modules(): + class_name = type(module).__name__ + + if 'RMSNorm' in class_name: + eps = getattr(module, 'variance_epsilon', None) + if eps is None: + eps = getattr(module, 'eps', 1e-6) + + has_weight = hasattr(module, 'weight') and module.weight is not None + + if has_weight: + def make_patched_forward(mod, epsilon): + def patched_forward(hidden_states): + return triton_rmsnorm(hidden_states, mod.weight, eps=epsilon) + return patched_forward + module.forward = make_patched_forward(module, eps) + patched_count += 1 + else: + print(f"WARNING: {name} has no weight, skipping") + + return patched_count + + +def inject_optimized_kernels(model) -> dict: + """Inject custom Triton kernels into a transformers model.""" + stats = {'rmsnorm_modules': 0} + stats['rmsnorm_modules'] = patch_rmsnorm_modules(model) + return stats + + +# ============================================================================= +# Main +# ============================================================================= + +def main(): + from transformers import AutoModelForCausalLM, AutoTokenizer + + print("=" * 60) + print("Transformers Triton Kernel Injection (XPU)") + print("=" * 60) + + print(f"\nXPU available: {torch.xpu.is_available()}") + print(f"GPU: {torch.xpu.get_device_name()}") + + model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" + + print(f"\n1. Loading model: {model_id}...") + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + device_map="xpu" + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + + rmsnorm_count = sum(1 for _, m in model.named_modules() if 'RMSNorm' in type(m).__name__) + print(f" Found {rmsnorm_count} RMSNorm modules") + + print("\n2. Injecting optimized Triton kernels...") + stats = inject_optimized_kernels(model) + print(f" RMSNorm modules patched: {stats['rmsnorm_modules']}") + + print("\n3. Verifying injection...") + x = torch.randn(1, 10, model.config.hidden_size, device='xpu', dtype=torch.bfloat16) + for name, module in model.named_modules(): + if 'RMSNorm' in type(module).__name__: + out = module(x) + print(f" RMSNorm forward pass: {x.shape} -> {out.shape}") + break + + print("\n4. Running generation test...") + prompt = "The capital of France is" + inputs = tokenizer(prompt, return_tensors="pt").to("xpu") + + with torch.inference_mode(): + _ = model.generate(**inputs, max_new_tokens=5, do_sample=False) + + num_tokens = 50 + start_time = time.perf_counter() + with torch.inference_mode(): + outputs = model.generate( + **inputs, + max_new_tokens=num_tokens, + do_sample=False, + pad_token_id=tokenizer.eos_token_id + ) + end_time = time.perf_counter() + + elapsed = end_time - start_time + tokens_per_second = num_tokens / elapsed + + print(f" Prompt: {prompt}") + print(f" Output: {tokenizer.decode(outputs[0], skip_special_tokens=True)}") + print(f" Generated {num_tokens} tokens in {elapsed:.2f}s ({tokens_per_second:.1f} tokens/s)") + + print("\n" + "=" * 60) + print("Success! Custom Triton kernels are being used on XPU.") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/kernel-builder/skills/xpu-kernels/scripts/trial_manager.py b/kernel-builder/skills/xpu-kernels/scripts/trial_manager.py new file mode 100644 index 00000000..c1237ee4 --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/trial_manager.py @@ -0,0 +1,429 @@ +#!/usr/bin/env python3 +"""Trial Tree State Manager for iterative kernel optimization. + +Manages a tree of optimization trials for each kernel, tracking parent-child +relationships, strategies, correctness, and speedup results. Supports +branching back to the best ancestor when a trial regresses. + +Usage: + python scripts/trial_manager.py init + python scripts/trial_manager.py save --parent --strategy "description" + python scripts/trial_manager.py result --validation --correctness --speedup --baseline_us --triton_us + python scripts/trial_manager.py status + python scripts/trial_manager.py best + python scripts/trial_manager.py finalize +""" + +import argparse +import json +import os +import shutil +import sys + +TRIALS_DIR = os.path.join(os.getcwd(), "trials") +OUTPUT_DIR = os.path.join(os.getcwd(), "output") + + +def _state_path(kernel_name): + return os.path.join(TRIALS_DIR, kernel_name, "state.json") + + +def _trial_dir(kernel_name): + return os.path.join(TRIALS_DIR, kernel_name) + + +def _load_state(kernel_name): + path = _state_path(kernel_name) + if not os.path.exists(path): + print(f"Error: No trial tree found for '{kernel_name}'. Run 'init' first.", file=sys.stderr) + sys.exit(1) + with open(path) as f: + state = json.load(f) + # Backward compat: old state files lack baseline_type + state.setdefault("baseline_type", "pytorch") + # Backward compat: rename pytorch_us -> baseline_us in trials + for trial in state.get("trials", {}).values(): + if "pytorch_us" in trial and "baseline_us" not in trial: + trial["baseline_us"] = trial.pop("pytorch_us") + return state + + +def _save_state(kernel_name, state): + path = _state_path(kernel_name) + with open(path, "w") as f: + json.dump(state, f, indent=2) + + +# ============================================================================ +# Commands +# ============================================================================ + + +def cmd_init(args): + """Initialize a new trial tree for a kernel.""" + kernel_name = args.kernel_name + pytorch_file = args.pytorch_file + + trial_dir = _trial_dir(kernel_name) + if os.path.exists(_state_path(kernel_name)): + print( + f"Warning: Trial tree for '{kernel_name}' already exists. Use a different name or delete trials/{kernel_name}/." + ) + # sys.exit(1) + else: + os.makedirs(trial_dir, exist_ok=True) + + baseline_type = "triton" if args.triton_baseline else "pytorch" + state = { + "kernel_name": kernel_name, + "pytorch_file": pytorch_file, + "baseline_type": baseline_type, + "trials": {}, + "best_trial": None, + "next_id": 0, + "baseline_us": None, + } + _save_state(kernel_name, state) + baseline_label = "Triton" if baseline_type == "triton" else "PyTorch" + print(f"Initialized trial tree for '{kernel_name}' in trials/{kernel_name}/") + print(f" Baseline ({baseline_label}): {pytorch_file}") + + +def cmd_save(args): + """Save a trial by copying the kernel file into the trial directory.""" + kernel_name = args.kernel_name + trial_file = args.trial_file + parent = args.parent # None or "t0", "t1", etc. + strategy = args.strategy or "" + + if not os.path.exists(trial_file): + print(f"Error: Trial file '{trial_file}' not found.", file=sys.stderr) + sys.exit(1) + + state = _load_state(kernel_name) + + # Validate parent exists (gracefully handle first trial with mistaken --parent) + if parent is not None and parent not in state["trials"]: + if state["next_id"] == 0: + print( + f"Warning: Ignoring --parent '{parent}' for first trial (no trials exist yet).", + file=sys.stderr, + ) + parent = None + else: + print( + f"Error: Parent trial '{parent}' not found. Available: {list(state['trials'].keys())}", + file=sys.stderr, + ) + sys.exit(1) + + trial_id = f"t{state['next_id']}" + state["next_id"] += 1 + + # Copy file + dest = os.path.join(_trial_dir(kernel_name), f"{trial_id}.py") + try: + shutil.copy2(trial_file, dest) + except Exception as e: + print(f"File already written", file=sys.stderr) + + state["trials"][trial_id] = { + "parent": parent, + "file": f"{trial_id}.py", + "strategy": strategy, + "validation": None, + "correctness": None, + "speedup": None, + "baseline_us": None, + "triton_us": None, + "status": "saved", + } + _save_state(kernel_name, state) + print(f"Saved trial {trial_id}: {strategy}") + print(f" Parent: {parent or 'root'}") + print(f" File: trials/{kernel_name}/{trial_id}.py") + + +def cmd_result(args): + """Record results for a trial.""" + kernel_name = args.kernel_name + trial_id = args.trial_id + + state = _load_state(kernel_name) + + if trial_id not in state["trials"]: + args.trial_file = os.path.join(_trial_dir(kernel_name), f"{trial_id}.py") + args.parent = None # consider changing + args.strategy = None + print(f"Error: Trial '{trial_id}' not found. Saving state", file=sys.stderr) + cmd_save(args) # Auto-save if trial doesn't exist + state = _load_state(kernel_name) # Reload state after saving + # print(f"Error: Trial '{trial_id}' not found. Available: {list(state['trials'].keys())}", file=sys.stderr) + # sys.exit(1) + + trial = state["trials"][trial_id] + + if args.validation: + trial["validation"] = args.validation + if args.correctness: + trial["correctness"] = args.correctness + if args.speedup is not None: + trial["speedup"] = args.speedup + if args.baseline_us is not None: + trial["baseline_us"] = args.baseline_us + if args.triton_us is not None: + trial["triton_us"] = args.triton_us + + # Cache baseline_us at kernel level on first recording (for --baseline-us skip on later trials) + if args.baseline_us is not None and state.get("baseline_us") is None: + state["baseline_us"] = [args.baseline_us] + + # Update status + if trial["validation"] == "fail" or trial["correctness"] == "fail": + trial["status"] = "failed" + elif trial["correctness"] == "pass" and trial["speedup"] is not None: + trial["status"] = "completed" + else: + trial["status"] = "partial" + + # Update best trial (highest speedup among correct trials) + best_speedup = -1.0 + best_id = None + for tid, t in state["trials"].items(): + if t.get("correctness") == "pass" and t.get("speedup") is not None: + if t["speedup"] > best_speedup: + best_speedup = t["speedup"] + best_id = tid + state["best_trial"] = best_id + + _save_state(kernel_name, state) + + status_icon = {"completed": "+", "failed": "X", "partial": "~", "saved": "?"} + icon = status_icon.get(trial["status"], "?") + runtime_str = "" + if trial.get("baseline_us") is not None and trial.get("triton_us") is not None: + runtime_str = f", baseline={trial['baseline_us']:.2f}us, triton={trial['triton_us']:.2f}us" + print( + f"[{icon}] {trial_id}: validation={trial['validation']}, correctness={trial['correctness']}, speedup={trial['speedup']}{runtime_str}" + ) + if state["best_trial"]: + best = state["trials"][state["best_trial"]] + best_runtime = "" + if best.get("baseline_us") is not None and best.get("triton_us") is not None: + best_runtime = ( + f", baseline={best['baseline_us']:.2f}us, triton={best['triton_us']:.2f}us" + ) + print(f" Best trial: {state['best_trial']} ({best['speedup']}x{best_runtime})") + + +def cmd_status(args): + """Show trial tree status as ASCII tree.""" + kernel_name = args.kernel_name + state = _load_state(kernel_name) + + baseline_label = "Triton" if state.get("baseline_type") == "triton" else "PyTorch" + print(f"Trial tree: {state['kernel_name']}") + print(f" Baseline ({baseline_label}): {state['pytorch_file']}") + print(f" Best: {state['best_trial'] or 'none'}") + print(f" Trials: {len(state['trials'])}") + print() + + if not state["trials"]: + print(" (no trials yet)") + return + + # Build children map + children = {} + roots = [] + for tid, t in state["trials"].items(): + parent = t["parent"] + if parent is None: + roots.append(tid) + else: + children.setdefault(parent, []).append(tid) + + # Sort by trial number + def sort_key(tid): + return int(tid[1:]) + + roots.sort(key=sort_key) + for k in children: + children[k].sort(key=sort_key) + + # Print tree recursively + def print_node(tid, prefix="", is_last=True): + trial = state["trials"][tid] + connector = "└── " if is_last else "├── " + + # Status indicators + is_best = tid == state["best_trial"] + status_icon = {"completed": "+", "failed": "X", "partial": "~", "saved": "?"} + icon = status_icon.get(trial["status"], "?") + + speedup_str = f"{trial['speedup']:.2f}x" if trial["speedup"] is not None else "---" + runtime_str = "" + if trial.get("baseline_us") is not None and trial.get("triton_us") is not None: + runtime_str = f" (bl={trial['baseline_us']:.0f}us, tr={trial['triton_us']:.0f}us)" + best_marker = " <<<< BEST" if is_best else "" + strategy_short = trial["strategy"][:60] if trial["strategy"] else "" + + print( + f"{prefix}{connector}[{icon}] {tid}: {speedup_str}{runtime_str} | {strategy_short}{best_marker}" + ) + + child_prefix = prefix + (" " if is_last else "│ ") + kids = children.get(tid, []) + for i, child in enumerate(kids): + print_node(child, child_prefix, i == len(kids) - 1) + + for i, root in enumerate(roots): + print_node(root, " ", i == len(roots) - 1) + + +def cmd_best(args): + """Get the best trial info.""" + kernel_name = args.kernel_name + state = _load_state(kernel_name) + + if state["best_trial"] is None: + print("No correct trials yet.") + sys.exit(1) + + best_id = state["best_trial"] + best = state["trials"][best_id] + best_file = os.path.join(_trial_dir(kernel_name), best["file"]) + + print(f"best_trial: {best_id}") + print(f"speedup: {best['speedup']}") + if best.get("baseline_us") is not None: + print(f"baseline_us: {best['baseline_us']}") + if best.get("triton_us") is not None: + print(f"triton_us: {best['triton_us']}") + print(f"strategy: {best['strategy']}") + print(f"file: {best_file}") + print(f"parent: {best['parent'] or 'root'}") + + +def cmd_baseline_us(args): + """Print cached baseline time(s) as comma-separated floats.""" + kernel_name = args.kernel_name + state = _load_state(kernel_name) + + baseline_us = state.get("baseline_us") + if baseline_us is None: + print( + "No baseline_us cached yet. Run benchmark and record result for t0 first.", + file=sys.stderr, + ) + sys.exit(1) + + print(",".join(f"{v:.2f}" for v in baseline_us)) + + +def cmd_finalize(args): + """Copy the best correct trial to the output path. + + If output_file has no directory component it is placed inside OUTPUT_DIR + (``output/`` at the project root) which is created automatically. + """ + kernel_name = args.kernel_name + output_file = args.output_file + + state = _load_state(kernel_name) + + if state["best_trial"] is None: + print("Error: No correct trials to finalize.", file=sys.stderr) + sys.exit(1) + + best_id = state["best_trial"] + best = state["trials"][best_id] + src = os.path.join(_trial_dir(kernel_name), best["file"]) + + # Default bare filenames into output/ + if os.path.dirname(output_file) == "": + os.makedirs(OUTPUT_DIR, exist_ok=True) + output_file = os.path.join(OUTPUT_DIR, output_file) + + shutil.copy2(src, output_file) + runtime_str = "" + if best.get("baseline_us") is not None and best.get("triton_us") is not None: + runtime_str = f", baseline={best['baseline_us']:.2f}us, triton={best['triton_us']:.2f}us" + print(f"Finalized {best_id} ({best['speedup']}x{runtime_str}) -> {output_file}") + print(f" Strategy: {best['strategy']}") + + +# ============================================================================ +# CLI +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser(description="Trial Tree State Manager") + subparsers = parser.add_subparsers(dest="command", required=True) + + # init + p_init = subparsers.add_parser("init", help="Initialize trial tree") + p_init.add_argument("kernel_name", help="Kernel identifier (e.g. 39_Gemm_Scale_BatchNorm)") + p_init.add_argument("pytorch_file", help="Path to baseline file (PyTorch or Triton)") + p_init.add_argument( + "--triton-baseline", + action="store_true", + help="Baseline is a Triton kernel (default: PyTorch)", + ) + + # save + p_save = subparsers.add_parser("save", help="Save a trial") + p_save.add_argument("kernel_name", help="Kernel identifier") + p_save.add_argument("trial_file", help="Path to the trial kernel file") + p_save.add_argument("--parent", default=None, help="Parent trial ID (e.g. t0)") + p_save.add_argument("--strategy", default="", help="Description of optimization strategy") + + # result + p_result = subparsers.add_parser("result", help="Record trial results") + p_result.add_argument("kernel_name", help="Kernel identifier") + p_result.add_argument("trial_id", help="Trial ID (e.g. t0)") + p_result.add_argument("--validation", choices=["pass", "fail"], help="Validation result") + p_result.add_argument("--correctness", choices=["pass", "fail"], help="Correctness result") + p_result.add_argument("--speedup", type=float, help="Speedup over baseline") + p_result.add_argument( + "--baseline_us", + type=float, + help="Baseline runtime in microseconds (PyTorch or Triton baseline)", + ) + p_result.add_argument("--triton_us", type=float, help="Triton kernel runtime in microseconds") + + # status + p_status = subparsers.add_parser("status", help="Show trial tree status") + p_status.add_argument("kernel_name", help="Kernel identifier") + + # best + p_best = subparsers.add_parser("best", help="Get best trial info") + p_best.add_argument("kernel_name", help="Kernel identifier") + + # baseline-us + p_baseline_us = subparsers.add_parser("baseline-us", help="Print cached baseline time(s)") + p_baseline_us.add_argument("kernel_name", help="Kernel identifier") + + # finalize + p_finalize = subparsers.add_parser("finalize", help="Copy best trial to output") + p_finalize.add_argument("kernel_name", help="Kernel identifier") + p_finalize.add_argument( + "output_file", help="Output file path (bare filename defaults to output/)" + ) + + args = parser.parse_args() + + commands = { + "init": cmd_init, + "save": cmd_save, + "result": cmd_result, + "status": cmd_status, + "best": cmd_best, + "baseline-us": cmd_baseline_us, + "finalize": cmd_finalize, + } + commands[args.command](args) + + +if __name__ == "__main__": + main() diff --git a/kernel-builder/skills/xpu-kernels/scripts/validate_triton.py b/kernel-builder/skills/xpu-kernels/scripts/validate_triton.py new file mode 100755 index 00000000..53424c8a --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/validate_triton.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +""" +Validate Triton kernel for common XPU optimization issues. + +Usage: + python scripts/validate_triton.py +""" + +import re +import sys +from pathlib import Path +from typing import List, Tuple + + +class ValidationError: + def __init__(self, level: str, message: str, line_num: int = None): + self.level = level # 'ERROR', 'WARNING', 'INFO' + self.message = message + self.line_num = line_num + + def __str__(self): + prefix = {"ERROR": "❌", "WARNING": "⚠️", "INFO": "ℹ️"}[self.level] + loc = f" (line {self.line_num})" if self.line_num else "" + return f"{prefix} {self.level}: {self.message}{loc}" + + +def validate_triton_kernel(filepath: Path) -> List[ValidationError]: + """Validate Triton kernel against XPU optimization guidelines.""" + + with open(filepath, "r") as f: + source = f.read() + + lines = source.split("\n") + errors = [] + + # 1. Check for autotune parameter defaults (CRITICAL) + autotune_params = set() + in_autotune = False + for i, line in enumerate(lines): + if "@triton.autotune" in line: + in_autotune = True + if in_autotune and "Config" in line: + # Extract parameter names from Config dict + matches = re.findall(r"'(\w+)':", line) + autotune_params.update(matches) + if in_autotune and "@triton.jit" in line: + in_autotune = False + + # Check kernel signature for defaults on autotune params + in_kernel_sig = False + for i, line in enumerate(lines): + if "@triton.jit" in line: + in_kernel_sig = True + if in_kernel_sig and "def " in line and "(" in line: + in_kernel_sig = True + if in_kernel_sig: + for param in autotune_params: + if f"{param}:" in line and "=" in line: + errors.append( + ValidationError( + "ERROR", + f"Autotune parameter '{param}' has default value in kernel signature. " + f"This causes 'Conflicting meta-parameters' error. Remove the default.", + i + 1, + ) + ) + if ")" in line and in_kernel_sig: + break + + # 2. Check grid dimensionality with swizzling + has_swizzling = "GROUP_SIZE_M" in source or "swizzle" in source.lower() + has_2d_grid = False + for i, line in enumerate(lines): + if "grid" in line and "=" in line: + # Check for 2D grid pattern: (triton.cdiv(...), triton.cdiv(...)) + if line.count("triton.cdiv") >= 2 and line.count(",") >= 1: + # Try to detect if it's a tuple with 2+ elements + if "(" in line and ")" in line: + tuple_content = line[line.index("(") : line.rindex(")") + 1] + # Count commas outside of nested parens + paren_depth = 0 + comma_count = 0 + for char in tuple_content: + if char == "(": + paren_depth += 1 + elif char == ")": + paren_depth -= 1 + elif char == "," and paren_depth == 1: + comma_count += 1 + if comma_count >= 1: + has_2d_grid = True + if has_swizzling: + errors.append( + ValidationError( + "ERROR", + "Grid is 2D but tile swizzling is used. Grid must be 1D " + "when using GROUP_SIZE_M swizzling.", + i + 1, + ) + ) + + # 3. Check boundary_check format + for i, line in enumerate(lines): + if "boundary_check" in line: + # Check if it's using booleans instead of dimension indices + if "True" in line or "False" in line: + errors.append( + ValidationError( + "ERROR", + "boundary_check uses booleans. Use dimension indices (0, 1) instead.", + i + 1, + ) + ) + # Check if it's a descriptor load (descriptors don't support boundary_check) + if ".load(" in line and "boundary_check" in line: + # Check if this is a descriptor (desc.load pattern) + # Look backwards for descriptor creation + for j in range(max(0, i - 20), i): + if "make_tensor_descriptor" in lines[j]: + errors.append( + ValidationError( + "ERROR", + "Tensor descriptor .load() does NOT accept boundary_check parameter. " + "Remove it - descriptors handle boundaries internally.", + i + 1, + ) + ) + break + + # 4. Check for float64 usage (CRITICAL performance issue) + for i, line in enumerate(lines): + if "float64" in line.lower() or "tl.float64" in line: + errors.append( + ValidationError( + "WARNING", + "float64 detected. This is 5-10x slower on XPU. Use float32 unless absolutely required.", + i + 1, + ) + ) + + # 5. Check for int32 overflow in batch offset calculations + batch_offset_pattern = r"(program_id|pid|bid)\s*\*\s*stride" + for i, line in enumerate(lines): + if re.search(batch_offset_pattern, line): + if ".to(tl.int64)" not in line and "to(tl.int64)" not in line: + errors.append( + ValidationError( + "WARNING", + "Batch offset calculation may overflow int32. Cast program_id to int64: " + "offset = bid.to(tl.int64) * stride", + i + 1, + ) + ) + + # 6. Check for num_warps=32 without autotune + for i, line in enumerate(lines): + if "num_warps=32" in line or "num_warps = 32" in line: + # Check if it's in a single Config (not autotuned) + if "@triton.autotune" not in source or source.count("num_warps=32") == 1: + errors.append( + ValidationError( + "WARNING", + "num_warps=32 used without autotuning. This can hurt performance on " + "skinny-M or heavy-epilogue kernels. Sweep {4,8,16,32}.", + i + 1, + ) + ) + + # 7. Check for mixed block pointer and tensor descriptor APIs + has_block_ptr = "make_block_ptr" in source + has_tensor_desc = "make_tensor_descriptor" in source + if has_block_ptr and has_tensor_desc: + errors.append( + ValidationError( + "INFO", + "Both block pointers and tensor descriptors found. This is OK if used for " + "different operations (e.g., descriptors for loads, manual pointers for atomics). " + "But do NOT mix APIs for the same load/store operation.", + ) + ) + + # 8. Check for .item() or device-to-host sync in hot path + for i, line in enumerate(lines): + if ".item()" in line or "float(tensor" in line or "int(tensor" in line: + if "def forward" in "".join(lines[max(0, i - 20) : i]): + errors.append( + ValidationError( + "ERROR", + "Device-to-host sync (.item() or float(tensor)) detected in forward pass. " + "This forces synchronization and kills performance.", + i + 1, + ) + ) + + # 9. Check for weight transpose in forward() hot path + for i, line in enumerate(lines): + if ".t()" in line and ".contiguous()" in line: + # Check if inside forward() + in_forward = False + for j in range(max(0, i - 20), i): + if "def forward" in lines[j]: + in_forward = True + break + if in_forward: + errors.append( + ValidationError( + "WARNING", + "Weight transpose (.t().contiguous()) in forward() hot path. " + "Pre-pack once and cache to avoid per-iteration overhead.", + i + 1, + ) + ) + + # 10. Check for GEMM with reduction loop over N (serialization issue) + has_n_loop = False + for i, line in enumerate(lines): + if re.search(r"for.*range\(.*,\s*N\s*[,)]", line): + has_n_loop = True + if ( + "tl.dot" + in source[max(0, source.rfind("\n", 0, i) - 500) : source.find("\n", i) + 500] + ): + errors.append( + ValidationError( + "ERROR", + "GEMM kernel loops over N tiles inside one program. This serializes " + "parallelism. Use 2D grid (pid_m, pid_n) instead.", + i + 1, + ) + ) + + # 11. Check for tl.exp usage (prefer tl.math.exp2 on XPU) + for i, line in enumerate(lines): + if "tl.exp(" in line and "tl.math.exp2" not in source: + errors.append( + ValidationError( + "INFO", + "tl.exp() found. Consider using exp2-based implementation for better XPU performance: " + "exp(x) = exp2(x * 1.44269504)", + i + 1, + ) + ) + + # 12. Check for get_inputs / get_init_inputs (benchmark harness interface) + has_model_class = "class Model" in source + if has_model_class: + if "def get_inputs" not in source: + errors.append( + ValidationError( + "WARNING", + "Model class found but no get_inputs() function. This is required by the " + "benchmark harness (ai-bench).", + ) + ) + if "def get_init_inputs" not in source: + errors.append( + ValidationError( + "WARNING", + "Model class found but no get_init_inputs() function. This is required by the " + "benchmark harness (ai-bench).", + ) + ) + + # 15. Success indicators + if not errors: + errors.append(ValidationError("INFO", "No critical issues found! ✓")) + + # Positive feedback for good patterns + if has_block_ptr or has_tensor_desc: + errors.append( + ValidationError( + "INFO", "✓ Using modern memory access API (block pointers or tensor descriptors)" + ) + ) + if "bfloat16" in source or "float16" in source: + errors.append(ValidationError("INFO", "✓ Using reduced precision inputs (bf16/fp16)")) + if "float32" in source and "accumulator" in source.lower(): + errors.append(ValidationError("INFO", "✓ Using fp32 accumulator for numerical stability")) + return errors + + +def print_validation_results(errors: List[ValidationError], filepath: Path): + """Pretty print validation results.""" + + print(f"\n{'=' * 70}") + print(f"Validation: {filepath.name}") + print(f"{'=' * 70}\n") + + # Separate by level + error_list = [e for e in errors if e.level == "ERROR"] + warning_list = [e for e in errors if e.level == "WARNING"] + info_list = [e for e in errors if e.level == "INFO"] + + if error_list: + print("ERRORS (must fix):") + for err in error_list: + print(f" {err}") + print() + + if warning_list: + print("WARNINGS (should review):") + for err in warning_list: + print(f" {err}") + print() + + if info_list: + print("INFO:") + for err in info_list: + print(f" {err}") + print() + + # Summary + if error_list: + print(f"Status: ❌ FAILED ({len(error_list)} errors)") + return 1 + elif warning_list: + print(f"Status: ⚠️ PASSED with warnings ({len(warning_list)} warnings)") + return 0 + else: + print(f"Status: ✅ PASSED") + return 0 + + +def main(): + if len(sys.argv) != 2: + print("Usage: python scripts/validate_triton.py ") + sys.exit(1) + + filepath = Path(sys.argv[1]) + if not filepath.exists(): + print(f"Error: File not found: {filepath}") + sys.exit(1) + + errors = validate_triton_kernel(filepath) + exit_code = print_validation_results(errors, filepath) + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/kernel-builder/skills/xpu-kernels/scripts/xpu_profiler.py b/kernel-builder/skills/xpu-kernels/scripts/xpu_profiler.py new file mode 100644 index 00000000..55e7df5c --- /dev/null +++ b/kernel-builder/skills/xpu-kernels/scripts/xpu_profiler.py @@ -0,0 +1,1095 @@ +#!/usr/bin/env python3 +""" +Profile a Triton kernel using Intel VTune on XPU hardware. + +Collects GPU hardware counters (OA metrics) and maps bottlenecks to +optimization patterns. Use when speedup plateaus or you need guidance on +which optimization level to try next. + +Usage: + python scripts/xpu_profiler.py [--warmup 5] [--iters 20] + +Examples: + python scripts/xpu_profiler.py test_kernels/39_Gemm_Scale_BatchNorm_triton.py + python scripts/xpu_profiler.py output/14_Gemm_Divide_Sum_Scaling_triton.py --iters 50 +""" + +import argparse +import csv +import os +import re +import shutil +import subprocess +import sys +import tempfile +import time +from pathlib import Path + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +sys.path.insert(0, str(Path(__file__).resolve().parent)) + +from config import load_config as _load_config + +_CFG = _load_config() +VTUNE_BIN = _CFG["vtune_bin"] + +# PyTorch runtime kernel patterns — these are NOT user compute kernels +_OVERHEAD_KERNEL_PATTERNS = [ + re.compile(r"VectorizedElementwiseKernel"), + re.compile(r"UnrolledElementwiseKernel"), + re.compile(r"zeCommandListAppendMemoryCopy"), + re.compile(r"ReduceKernelEmptyFunctor"), + re.compile(r"\[Outside any task\]"), +] + +# Columns to request from the hotspots report (the key OA hardware counters) +# VTune OA counter groups conflict when L3 BW and LSC BW are requested +# together. We use two report passes to get all columns. +_HOTSPOTS_COLUMNS_PASS1 = ",".join( + [ + "Computing Task:Total Time", + "Computing Task:Average Time", + "Computing Task:Instance Count", + "Computing Task:SIMD Width", + "Computing Task:Spill Memory Size", + "Work Size:Global", + "Work Size:Local", + # XVE execution breakdown + "XVE Array:Active", + "XVE Array:Stalled", + "XVE Array:Idle", + # Occupancy: peak (auto-includes Work Size/SLM/Barriers limiters) + # NOTE: cannot combine with "XVE Threads Occupancy" — VTune counter conflict + "Peak XVE Threads Occupancy", + # Memory bandwidth (GPU VRAM) + "GPU Memory Bandwidth, GB/sec:Read", + "GPU Memory Bandwidth, GB/sec:Write", + # L3 cache (includes BW — conflicts with LSC BW) + "GPU L3:Busy", + "GPU L3:Stalled", + "GPU L3:Miss Ratio", + "GPU L3:Average Bandwidth, GB/s:Read", + "GPU L3:Average Bandwidth, GB/s:Write", + "GPU L3:Input Available", + "GPU L3:Output Ready", + # Load/Store cache: ratios + pipeline (no BW — conflicts with L3 BW) + "GPU Load Store Cache:Miss Ratio", + "GPU Load Store Cache:L3 Miss Ratio", + "GPU Load Store Cache:Input Available", + "GPU Load Store Cache:Output Ready", + "GPU Load Store Cache:Partial Writes", + # Instruction cache + "GPU Instruction cache L3 Miss Ratio", + # SLM and misc + "GPU Shared Local Memory:Bank Conflicts", + "TLB Misses", + ] +) + +# Second pass: LSC bandwidth + XVE Threads Occupancy (measured) +# These conflict with Pass 1 columns. +_HOTSPOTS_COLUMNS_PASS2 = ",".join( + [ + "Computing Task:Total Time", + "XVE Threads Occupancy", + "GPU Load Store Cache:Average Bandwidth, GB/s:Read", + "GPU Load Store Cache:Average Bandwidth, GB/s:Write", + ] +) + + +def _is_overhead_kernel(name: str) -> bool: + """Return True if *name* matches a known PyTorch runtime kernel.""" + for pat in _OVERHEAD_KERNEL_PATTERNS: + if pat.search(name): + return True + return False + + +# --------------------------------------------------------------------------- +# Runner script generation +# --------------------------------------------------------------------------- + + +def generate_runner_script( + triton_file: Path, warmup: int, iters: int, vtune_bin: str = "", result_dir: str = "" +) -> str: + # When vtune_bin and result_dir are provided, use VTune CLI to + # resume/pause collection so only the profiled loop is captured. + if vtune_bin and result_dir: + resume_pause = f""" +import subprocess +_VTUNE_BIN = "{vtune_bin}" +_RESULT_DIR = "{result_dir}" +def _vtune_cmd(cmd): + subprocess.run([_VTUNE_BIN, "-command", cmd, "-r", _RESULT_DIR], + capture_output=True, timeout=30) +""" + resume_call = "_vtune_cmd('resume')" + else: + resume_pause = "" + resume_call = "pass # no VTune pause/resume" + + return f"""\ +import torch +import importlib.util +import sys +{resume_pause} +spec = importlib.util.spec_from_file_location("triton_kernel", "{triton_file.resolve()}") +mod = importlib.util.module_from_spec(spec) +spec.loader.exec_module(mod) + +device = torch.device("xpu") +init_inputs = mod.get_init_inputs() +model = mod.Model(*init_inputs).to(device).eval() + +inputs = mod.get_inputs() +inputs = [inp.to(device) if hasattr(inp, 'to') else inp for inp in inputs] + +# Warmup (collection paused) +with torch.no_grad(): + for _ in range({warmup}): + _ = model(*inputs) + torch.xpu.synchronize() + +# Resume collection for profiled iterations +{resume_call} + +with torch.no_grad(): + for _ in range({iters}): + _ = model(*inputs) + torch.xpu.synchronize() + +""" + + +# --------------------------------------------------------------------------- +# VTune CSV parsing +# --------------------------------------------------------------------------- + + +def parse_vtune_summary_csv(csv_path: Path) -> tuple[dict[str, str], list[dict], list[dict]]: + """Parse the VTune ``-R summary`` TSV report. + + Returns (scalar_metrics, gpu_tasks, host_tasks). + """ + scalar_metrics: dict[str, str] = {} + gpu_tasks: list[dict] = [] + host_tasks: list[dict] = [] + + if not csv_path.exists(): + return scalar_metrics, gpu_tasks, host_tasks + + with open(csv_path, newline="") as f: + rows = list(csv.reader(f, delimiter="\t")) + + if not rows: + return scalar_metrics, gpu_tasks, host_tasks + + TABLE_SECTIONS = { + "Top Hotspots when GPU was idle", + "Hottest Host Tasks", + "Hottest GPU Computing Tasks", + } + INFO_SECTION = "Collection and Platform Info" + RECO_SECTION = "Recommendations:" + + def _parse_table(start_idx: int) -> tuple[list[dict], int]: + header_row = rows[start_idx] + col_names = [c.strip() for c in header_row[1:]] + result = [] + idx = start_idx + 1 + while idx < len(rows): + row = rows[idx] + if len(row) < 2: + break + name = row[1].strip() + if ( + len(row) >= 3 + and not row[2].strip() + and name in (TABLE_SECTIONS | {INFO_SECTION, RECO_SECTION}) + ): + break + if len(row) == 2 or (len(row) >= 3 and not row[2].strip()): + if name and not any(c.strip() for c in row[2:]): + break + vals = [c.strip() for c in row[1:]] + entry = {} + for j, col in enumerate(col_names): + entry[col] = vals[j] if j < len(vals) else "" + result.append(entry) + idx += 1 + return result, idx + + idx = 1 # skip header + while idx < len(rows): + row = rows[idx] + if len(row) < 2: + idx += 1 + continue + + name = row[1].strip() + has_value = len(row) >= 3 and row[2].strip() + + if name in TABLE_SECTIONS and not has_value: + idx += 1 + if idx >= len(rows): + break + entries, idx = _parse_table(idx) + if name == "Hottest GPU Computing Tasks": + gpu_tasks = entries + elif name == "Hottest Host Tasks": + host_tasks = entries + continue + + if name == INFO_SECTION and not has_value: + idx += 1 + while idx < len(rows): + r = rows[idx] + if len(r) < 2: + idx += 1 + continue + rname = r[1].strip() + if rname == RECO_SECTION or rname.startswith("Recommendations"): + break + rval = r[2].strip() if len(r) >= 3 else "" + if rval: + scalar_metrics[rname] = rval + idx += 1 + continue + + if name == RECO_SECTION or name.startswith("Recommendations"): + # Parse VTune recommendations section + idx += 1 + while idx < len(rows): + r = rows[idx] + if len(r) < 2: + idx += 1 + continue + rname = r[1].strip() + rval = r[2].strip() if len(r) >= 3 else "" + if rval and rname: + scalar_metrics[f"_reco_{rname}"] = rval + idx += 1 + continue + + if has_value: + scalar_metrics[name] = row[2].strip() + + idx += 1 + + return scalar_metrics, gpu_tasks, host_tasks + + +def parse_hotspots_csv(csv_path: Path) -> list[dict]: + """Parse the VTune ``-R hotspots -group-by computing-task`` TSV report. + + Returns a list of per-kernel dicts with OA hardware counter columns. + """ + if not csv_path.exists(): + return [] + + with open(csv_path, newline="") as f: + rows = list(csv.reader(f, delimiter="\t")) + + if len(rows) < 2: + return [] + + # VTune may prepend warning lines (e.g. "war:Column filter is ON."). + # Find the actual header row — it starts with "Computing Task". + header_idx = 0 + for i, row in enumerate(rows): + if row and row[0].strip().startswith("Computing Task"): + header_idx = i + break + + headers = [h.strip() for h in rows[header_idx]] + result = [] + for row in rows[header_idx + 1 :]: + if not row or not row[0].strip(): + continue + entry = {} + for j, h in enumerate(headers): + entry[h] = row[j].strip() if j < len(row) else "" + result.append(entry) + return result + + +def _extract(value: str) -> float | None: + """Extract a numeric value, stripping %, units, commas.""" + value = value.strip().rstrip("%").replace(",", "").strip() + try: + return float(value) + except ValueError: + return None + + +# --------------------------------------------------------------------------- +# Primary kernel identification +# --------------------------------------------------------------------------- + + +def find_primary_kernel(gpu_tasks: list[dict]) -> dict | None: + """Find the primary compute kernel, skipping PyTorch overhead kernels. + + Among non-overhead kernels with the same name, pick the variant with the + highest total time (the autotune winner). If ALL kernels are overhead, + fall back to the one with the highest time. + """ + candidates = [] + fallback = None + fallback_time = 0.0 + + for task in gpu_tasks: + name = task.get("Computing Task", "") + t = _extract(task.get("Computing Task:Total Time", task.get("Total Time", ""))) + if t is None or name.startswith("["): + continue + + if t > fallback_time: + fallback_time = t + fallback = task + + if not _is_overhead_kernel(name): + candidates.append((t, task)) + + if candidates: + candidates.sort(key=lambda x: x[0], reverse=True) + return candidates[0][1] + return fallback + + +def aggregate_kernel_variants(gpu_tasks: list[dict]) -> list[dict]: + """Group rows by kernel name and sum times / average metrics. + + VTune reports the same kernel name multiple times with different SIMD + widths (autotune configurations). This helper aggregates them for the + summary display while keeping the per-variant details available. + """ + by_name: dict[str, list[dict]] = {} + for task in gpu_tasks: + name = task.get("Computing Task", "") + by_name.setdefault(name, []).append(task) + + result = [] + for name, variants in by_name.items(): + total_time = sum( + _extract(v.get("Computing Task:Total Time", v.get("Total Time", ""))) or 0 + for v in variants + ) + total_count = sum( + int(_extract(v.get("Computing Task:Instance Count", v.get("Instance Count", ""))) or 0) + for v in variants + ) + # Use the variant with the highest time for representative metrics + best = max( + variants, + key=lambda v: ( + _extract(v.get("Computing Task:Total Time", v.get("Total Time", ""))) or 0 + ), + ) + agg = dict(best) # copy best variant's metrics + agg["_total_time"] = total_time + agg["_total_count"] = total_count + agg["_num_variants"] = len(variants) + result.append(agg) + result.sort(key=lambda x: x["_total_time"], reverse=True) + return result + + +# --------------------------------------------------------------------------- +# Display functions +# --------------------------------------------------------------------------- + + +def print_host_tasks(host_tasks: list[dict]): + if not host_tasks: + return + print("\n Hottest Host Tasks:\n") + print(f" {'Host Task':<45} {'Time (s)':>10} {'% Elapsed':>10} {'Count':>6}") + print(f" {'-' * 76}") + for task in host_tasks: + name = task.get("Host Task", "") + ttime = task.get("Task Time", "") + pct = task.get("% of Elapsed Time(%)", "") + count = task.get("Task Count", "") + if len(name) > 45: + name = name[:42] + "..." + pct_str = f"{pct}%" if pct else "" + print(f" {name:<45} {ttime:>10} {pct_str:>10} {count:>6}") + + +def _truncate(name: str, maxlen: int = 55) -> str: + if name.startswith("["): + return name + return name if len(name) <= maxlen else name[: maxlen - 3] + "..." + + +def print_gpu_tasks_summary(gpu_tasks: list[dict], has_oa: bool): + """Print a compact GPU computing tasks table.""" + if not gpu_tasks: + return + + agg = aggregate_kernel_variants(gpu_tasks) + N = 55 + + print("\n GPU Computing Tasks (by kernel name):\n") + if has_oa: + print( + f" {'Kernel':<{N}} {'Time':>7} {'Cnt':>5} {'Active':>7} {'Stall':>7} {'Idle':>7} {'Occ%':>6} {'MemR':>7} {'MemW':>7}" + ) + print(f" {'-' * (N + 62)}") + else: + print(f" {'Kernel':<{N}} {'Time':>7} {'Cnt':>5} {'SIMD':>5} {'Occ%':>7} {'Util%':>7}") + print(f" {'-' * (N + 37)}") + + def _fmt_pct(val: str) -> str: + v = _extract(val) + return f"{v:.1f}%" if v is not None else "" + + def _fmt_f(val: str) -> str: + v = _extract(val) + return f"{v:.1f}" if v is not None else "" + + for a in agg: + name = a.get("Computing Task", "") + tag = " *" if _is_overhead_kernel(name) else "" + dname = _truncate(name, N - len(tag)) + tag + tt = f"{a['_total_time']:.4f}" + cnt = str(a["_total_count"]) + + if has_oa: + active_s = _fmt_pct(a.get("XVE Array:Active(%)", "")) + stall_s = _fmt_pct(a.get("XVE Array:Stalled(%)", "")) + idle_s = _fmt_pct(a.get("XVE Array:Idle(%)", "")) + occ_s = _fmt_pct( + a.get("XVE Threads Occupancy(%)", "") or a.get("Peak XVE Threads Occupancy(%)", "") + ) + memr_s = _fmt_f(a.get("GPU Memory Bandwidth, GB/sec:Read", "")) + memw_s = _fmt_f(a.get("GPU Memory Bandwidth, GB/sec:Write", "")) + print( + f" {dname:<{N}} {tt:>7} {cnt:>5} {active_s:>7} {stall_s:>7} {idle_s:>7} {occ_s:>6} {memr_s:>7} {memw_s:>7}" + ) + else: + simd = a.get("Computing Task:SIMD Width", a.get("SIMD Width", "")) + occ = a.get("Peak XVE Threads Occupancy(%)", "") + util = a.get("SIMD Utilization(%)", "") + occ_s = f"{_extract(occ):.1f}%" if _extract(occ) is not None else occ + util_s = f"{_extract(util):.1f}%" if _extract(util) is not None else util + vn = a["_num_variants"] + simd_s = simd if vn == 1 else f"{simd}({vn})" + print(f" {dname:<{N}} {tt:>7} {cnt:>5} {simd_s:>5} {occ_s:>7} {util_s:>7}") + + print("\n (* = PyTorch runtime overhead kernel, not user compute)") + + +def print_primary_kernel_detail(primary: dict, has_oa: bool): + """Print detailed metrics for the primary compute kernel.""" + name = primary.get("Computing Task", "") + print(f"\n{'=' * 70}") + print(f"Primary Kernel Analysis: {_truncate(name, 50)}") + print(f"{'=' * 70}") + + def _row(label, key, unit="", fmt=".4f"): + val = primary.get(key, "") + v = _extract(val) + if v is not None: + print(f" {label:<42} {v:{fmt}}{unit}") + elif val: + print(f" {label:<42} {val}") + + _row("Total Time", "Computing Task:Total Time", "s") + _row("Average Time", "Computing Task:Average Time", "s", ".6f") + _row("Instance Count", "Computing Task:Instance Count", "", ".0f") + _row("SIMD Width", "Computing Task:SIMD Width", "", ".0f") + _row("Spill Memory Size", "Computing Task:Spill Memory Size", " bytes", ".0f") + + gs = primary.get("Work Size:Global", "") + ls = primary.get("Work Size:Local", "") + if gs or ls: + print(f" {'Work Size (Global / Local)':<42} {gs} / {ls}") + + if has_oa: + print() + _row("XVE Active", "XVE Array:Active(%)", "%", ".1f") + _row("XVE Stalled", "XVE Array:Stalled(%)", "%", ".1f") + _row("XVE Idle", "XVE Array:Idle(%)", "%", ".1f") + _row("XVE Threads Occupancy (measured)", "XVE Threads Occupancy(%)", "%", ".1f") + _row("Peak XVE Threads Occupancy", "Peak XVE Threads Occupancy(%)", "%", ".1f") + # Occupancy limiters — show what's capping peak occupancy + ws_lim = _extract(primary.get("Peak XVE Threads Occupancy:Work Size Limit(%)", "")) + slm_lim = _extract(primary.get("Peak XVE Threads Occupancy:SLM Use Limit(%)", "")) + bar_lim = _extract(primary.get("Peak XVE Threads Occupancy:Barriers Use Limit(%)", "")) + if ws_lim is not None: + limiter = min( + ("Work Size (grid too small)", ws_lim), + ("SLM Usage", slm_lim if slm_lim is not None else 100), + ("Barriers", bar_lim if bar_lim is not None else 100), + key=lambda x: x[1], + ) + print( + f" {'Occupancy Limiters':<42} WorkSize={ws_lim:.0f}% SLM={slm_lim:.0f}% Barriers={bar_lim:.0f}%" + ) + if limiter[1] < 100: + print(f" {' -> Bottleneck':<42} {limiter[0]}") + print() + _row("GPU Memory BW Read", "GPU Memory Bandwidth, GB/sec:Read", " GB/s", ".1f") + _row("GPU Memory BW Write", "GPU Memory Bandwidth, GB/sec:Write", " GB/s", ".1f") + print() + _row("L3 Busy", "GPU L3:Busy(%)", "%", ".1f") + _row("L3 Stalled", "GPU L3:Stalled(%)", "%", ".1f") + _row("L3 Cache Miss Ratio", "GPU L3:Miss Ratio(%)", "%", ".1f") + _row("L3 BW Read", "GPU L3:Average Bandwidth, GB/s:Read", " GB/s", ".1f") + _row("L3 BW Write", "GPU L3:Average Bandwidth, GB/s:Write", " GB/s", ".1f") + _row("L3 Input Available", "GPU L3:Input Available(%)", "%", ".1f") + _row("L3 Output Ready", "GPU L3:Output Ready(%)", "%", ".1f") + print() + _row("LSC Miss Ratio", "GPU Load Store Cache:Miss Ratio(%)", "%", ".1f") + _row("LSC -> L3 Miss Ratio", "GPU Load Store Cache:L3 Miss Ratio(%)", "%", ".1f") + _row("LSC BW Read", "GPU Load Store Cache:Average Bandwidth, GB/s:Read", " GB/s", ".1f") + _row("LSC BW Write", "GPU Load Store Cache:Average Bandwidth, GB/s:Write", " GB/s", ".1f") + _row("LSC Input Available", "GPU Load Store Cache:Input Available(%)", "%", ".1f") + _row("LSC Output Ready", "GPU Load Store Cache:Output Ready(%)", "%", ".1f") + _row("LSC Partial Writes", "GPU Load Store Cache:Partial Writes", "", ".0f") + print() + _row("Instruction Cache L3 Miss", "GPU Instruction cache L3 Miss Ratio(%)", "%", ".1f") + _row("SLM Bank Conflicts", "GPU Shared Local Memory:Bank Conflicts", "", ".0f") + _row("TLB Misses", "TLB Misses", "", ".0f") + + +def print_recommendations( + primary: dict | None, + gpu_tasks: list[dict], + host_tasks: list[dict], + scalar_metrics: dict[str, str], + has_oa: bool, +): + """Generate actionable recommendations grounded in KB optimization patterns. + + Every recommendation references a specific KB file and pattern ID. + Thresholds are based on hardware counter semantics, not arbitrary cutoffs. + """ + recommendations = [] + + # --- Host overhead check --- + # Grounded in: references/memory_patterns.yaml (no_device_to_host_scalar_sync) + total_host_time = sum(_extract(t.get("Task Time", "")) or 0 for t in host_tasks) + total_gpu_time = sum( + _extract(t.get("Computing Task:Total Time", t.get("Total Time", ""))) or 0 + for t in gpu_tasks + if not t.get("Computing Task", "").startswith("[") + ) + + if total_host_time > 0 and total_gpu_time > 0 and total_host_time > total_gpu_time * 2: + recommendations.append( + ( + f"Host overhead ({total_host_time:.3f}s) >> GPU compute ({total_gpu_time:.3f}s). " + "CPU-side dominates: check for .item()/.cpu() in hot path, " + "ensure weight packing runs at init time (not in forward()).", + "references/memory_patterns.yaml (no_device_to_host_scalar_sync)", + ) + ) + + # --- Overhead kernel dominance --- + # Grounded in: references/optimization_levels.yaml (level_2_bandwidth_reduction) + # When PyTorch Fill/Copy/Cast ops dominate, it means data type conversion + # is happening at runtime instead of at pack time. + overhead_time = sum( + _extract(t.get("Computing Task:Total Time", t.get("Total Time", ""))) or 0 + for t in gpu_tasks + if _is_overhead_kernel(t.get("Computing Task", "")) + and not t.get("Computing Task", "").startswith("[") + ) + if total_gpu_time > 0 and overhead_time / total_gpu_time > 0.30: + recommendations.append( + ( + f"Overhead kernels consume {overhead_time / total_gpu_time * 100:.0f}% of GPU time " + f"({overhead_time:.4f}s). These are PyTorch Fill/Copy/Cast ops. " + "Pre-pack weights AND inputs to bf16 at init time to eliminate them.", + "references/optimization_levels.yaml (level_2_bandwidth_reduction)", + ) + ) + + if primary is None: + _print_reco_section(recommendations) + return + + if has_oa: + active = _extract(primary.get("XVE Array:Active(%)", "")) + stalled = _extract(primary.get("XVE Array:Stalled(%)", "")) + idle = _extract(primary.get("XVE Array:Idle(%)", "")) + occupancy = _extract(primary.get("XVE Threads Occupancy(%)", "")) or _extract( + primary.get("Peak XVE Threads Occupancy(%)", "") + ) + l3_miss = _extract(primary.get("GPU L3:Miss Ratio(%)", "")) + lsc_l3_miss = _extract(primary.get("GPU Load Store Cache:L3 Miss Ratio(%)", "")) + spill = _extract(primary.get("Computing Task:Spill Memory Size", "")) + + # Occupancy limiters + ws_lim = _extract(primary.get("Peak XVE Threads Occupancy:Work Size Limit(%)", "")) + slm_lim = _extract(primary.get("Peak XVE Threads Occupancy:SLM Use Limit(%)", "")) + bar_lim = _extract(primary.get("Peak XVE Threads Occupancy:Barriers Use Limit(%)", "")) + + # --- XVE Stall-dominated: memory/dependency bound --- + # Grounded in: references/optimization_levels.yaml (level_2_bandwidth_reduction) + # and references/xpu_optimizations.yaml (xpu_descriptor_gemm_pattern) + # When XVE spends more time stalled than active, the execution units + # are waiting on memory. This maps to Level 2 bandwidth reduction. + if stalled is not None and active is not None and stalled > active: + recommendations.append( + ( + f"XVE Stalled ({stalled:.0f}%) > Active ({active:.0f}%): memory/dependency bound. " + "Use tensor descriptors for better address codegen, pre-pack to bf16 to halve bandwidth, " + "try tile swizzling for better L3 locality.", + "references/xpu_optimizations.yaml (xpu_descriptor_gemm_pattern, xpu_tile_swizzling) + " + "references/optimization_levels.yaml (level_2_bandwidth_reduction)", + ) + ) + + # --- XVE Idle-dominated: underutilization --- + # Grounded in: references/persistent_kernel_patterns.yaml (persistent_kernel_basic_tile_loop) + # High idle means the GPU has execution units with no work scheduled. + # This happens when the grid is too small or there aren't enough warps. + if idle is not None and active is not None and idle > 30 and active < 30: + recommendations.append( + ( + f"XVE Idle ({idle:.0f}%) with Active only ({active:.0f}%): GPU underutilized. " + "Grid may be too small — increase tile count or use persistent kernel pattern.", + "references/persistent_kernel_patterns.yaml (persistent_kernel_basic_tile_loop)", + ) + ) + + # --- Occupancy limiters --- + # Grounded in: references/xpu_optimizations.yaml (xpu_grf_mode, xpu_tile_swizzling) + # Check BOTH measured occupancy AND peak occupancy limiters. + # Peak < 100% means hardware CAN'T give full occupancy (structural limit). + # Measured < peak means kernel isn't filling available slots (launch config). + peak_occ = _extract(primary.get("Peak XVE Threads Occupancy(%)", "")) + occ_limited = (occupancy is not None and occupancy < 50) or ( + peak_occ is not None and peak_occ < 100 + ) + + if occ_limited: + if ( + ws_lim is not None + and ws_lim < 100 + and (slm_lim is None or ws_lim <= (slm_lim or 100)) + ): + recommendations.append( + ( + f"Peak Occupancy capped at {ws_lim:.0f}% by Work Size (grid too small). " + "Increase grid dimensions, use tile swizzling with GROUP_SIZE_M, " + "or try persistent kernel pattern to keep all XVEs busy.", + "references/xpu_optimizations.yaml (xpu_tile_swizzling) + " + "references/persistent_kernel_patterns.yaml", + ) + ) + elif slm_lim is not None and slm_lim < 100: + recommendations.append( + ( + f"Peak Occupancy capped at {slm_lim:.0f}% by SLM Usage. " + "Kernel uses too much shared local memory per work group. " + "Reduce tile sizes or try grf_mode='large' to trade SLM for registers.", + "references/xpu_optimizations.yaml (xpu_grf_mode)", + ) + ) + elif bar_lim is not None and bar_lim < 100: + recommendations.append( + ( + f"Peak Occupancy capped at {bar_lim:.0f}% by Barriers. " + "Too many barrier synchronizations. Reduce num_warps or restructure " + "the kernel to use fewer synchronization points.", + "references/xpu_optimizations.yaml (xpu_warp_count)", + ) + ) + elif occupancy is not None and occupancy < 50: + recommendations.append( + ( + f"XVE Occupancy {occupancy:.0f}%: low thread count on GPU. " + "Try larger tiles, more warps, or grf_mode='large' (256 registers).", + "references/xpu_optimizations.yaml (xpu_grf_mode)", + ) + ) + + # --- High L3 miss: data streaming from VRAM --- + # Grounded in: references/xpu_optimizations.yaml (xpu_descriptor_gemm_pattern, xpu_tile_swizzling) + # Tensor descriptors produce better address codegen; tile swizzling improves L3 reuse + # across neighboring work groups. + if l3_miss is not None and l3_miss > 50: + recommendations.append( + ( + f"L3 Miss Ratio {l3_miss:.0f}%: data is streaming from VRAM with poor reuse. " + "Use tensor descriptors (better address codegen) and tile swizzling " + "(improves L3 locality across neighboring work groups).", + "references/xpu_optimizations.yaml (xpu_descriptor_gemm_pattern, xpu_tile_swizzling)", + ) + ) + + # --- High LSC->L3 miss: L1 thrashing --- + # Grounded in: references/xpu_optimizations.yaml (xpu_descriptor_gemm_pattern) + # and references/memory_patterns.yaml (mem_block_pointers) + if lsc_l3_miss is not None and lsc_l3_miss > 50: + recommendations.append( + ( + f"LSC->L3 Miss Ratio {lsc_l3_miss:.0f}%: L1 cache not capturing data. " + "Use tensor descriptors for structured access patterns. " + "Ensure coalesced access within work groups.", + "references/xpu_optimizations.yaml (xpu_descriptor_gemm_pattern) + " + "references/memory_patterns.yaml (mem_block_pointers)", + ) + ) + + # --- Register spill --- + # Grounded in: references/memory_patterns.yaml (reduce_liveness_sink_load_and_prefetch) + # and references/xpu_optimizations.yaml (xpu_grf_mode) + # Spill means the kernel exceeds the register file. Two mitigations: + # 1. Reduce live variable count (sink loads closer to use, prefetch early) + # 2. Use large GRF mode (256 registers) to increase register budget + if spill is not None and spill > 0: + recommendations.append( + ( + f"Register Spill: {spill:.0f} bytes. Kernel exceeds register file capacity. " + "Reduce variable liveness: sink loads closer to tl.dot(), use tl.prefetch() " + "to warm cache without holding registers. Also try grf_mode='large' (256 regs).", + "references/memory_patterns.yaml (reduce_liveness_sink_load_and_prefetch) + " + "references/xpu_optimizations.yaml (xpu_grf_mode)", + ) + ) + + # --- Instruction cache misses --- + # High instruction cache L3 miss ratio means the compiled kernel binary + # is too large to stay in instruction cache. This typically happens with + # heavily unrolled or very large tile kernels. + icache_miss = _extract(primary.get("GPU Instruction cache L3 Miss Ratio(%)", "")) + if icache_miss is not None and icache_miss > 30: + recommendations.append( + ( + f"Instruction Cache L3 Miss {icache_miss:.0f}%: compiled kernel is too large. " + "Reduce tile sizes or unrolling to shrink kernel binary. " + "Smaller BLOCK_K or fewer autotune configs can help.", + "references/xpu_optimizations.yaml (xpu_descriptor_gemm_pattern)", + ) + ) + + else: + # Fallback: no OA data, use summary metrics only + occupancy = _extract(primary.get("Peak XVE Threads Occupancy(%)", "")) + + if occupancy is not None and occupancy < 50: + recommendations.append( + ( + f"Peak XVE Occupancy {occupancy:.0f}%: low. " + "Try larger tiles, more warps, or grf_mode='large'.", + "references/xpu_optimizations.yaml (xpu_grf_mode, xpu_tile_swizzling)", + ) + ) + + _print_reco_section(recommendations) + + +def _print_reco_section(recommendations): + print(f"\n{'=' * 70}") + print("Optimization Recommendations") + print(f"{'=' * 70}") + if recommendations: + for msg, ref in recommendations: + print(f"\n >> {msg}") + print(f" Reference: {ref}") + else: + print("\n No bottlenecks detected -- kernel looks well-optimized!") + print(" If speedup is still below target, consider:") + print(" - Level 3: Algebraic fusion (fold BN/scale into weights)") + print(" - Level 4: Stream K / persistent kernels") + + +# --------------------------------------------------------------------------- +# VTune execution +# --------------------------------------------------------------------------- + + +def run_vtune_collection( + vtune_bin: str, + triton_file: Path, + result_dir: str, + summary_csv: Path, + warmup: int, + iters: int, + timeout: int, +) -> subprocess.CompletedProcess: + """Run VTune gpu-offload collection and produce summary CSV.""" + runner = tempfile.NamedTemporaryFile( + mode="w", suffix=".py", prefix="vtune_runner_", delete=False + ) + runner.write( + generate_runner_script( + triton_file, warmup, iters, vtune_bin=vtune_bin, result_dir=result_dir + ) + ) + runner.close() + + try: + # Step 1: Collect with -start-paused (runner resumes/pauses around profiled loop) + collect_cmd = [ + vtune_bin, + "-collect", + "gpu-offload", + "-start-paused", + "-r", + result_dir, + "--", + sys.executable, + runner.name, + ] + result = subprocess.run(collect_cmd, capture_output=True, text=True, timeout=timeout) + if result.returncode != 0: + return result + + # Step 2: Generate summary report from collected data + report_cmd = [ + vtune_bin, + "-R", + "summary", + "-r", + result_dir, + "-format", + "csv", + "-csv-delimiter", + "tab", + "-report-output", + str(summary_csv), + ] + subprocess.run(report_cmd, capture_output=True, text=True, timeout=120) + return result + finally: + try: + os.unlink(runner.name) + except OSError: + pass + + +def _run_single_hotspots_report( + vtune_bin: str, result_dir: str, csv_path: Path, columns: str +) -> bool: + """Run one VTune hotspots report pass. Returns True on success.""" + cmd = [ + vtune_bin, + "-R", + "hotspots", + "-r", + result_dir, + "-group-by", + "computing-task", + "-format", + "csv", + "-csv-delimiter", + "tab", + "-column", + columns, + "-report-output", + str(csv_path), + ] + result = subprocess.run(cmd, capture_output=True, text=True, timeout=120) + if result.returncode != 0: + return False + return csv_path.exists() and csv_path.stat().st_size > 100 + + +def run_hotspots_report(vtune_bin: str, result_dir: str, hotspots_csv: Path) -> bool: + """Run VTune hotspots reports to extract per-kernel OA hardware counters. + + Uses two passes because some OA counter groups conflict (e.g. L3 BW + and LSC BW cannot be collected simultaneously). Pass 2 results are + merged into Pass 1 by kernel name. + + Returns True if at least Pass 1 produced useful data. + """ + ok = _run_single_hotspots_report(vtune_bin, result_dir, hotspots_csv, _HOTSPOTS_COLUMNS_PASS1) + if not ok: + return False + + # Pass 2: supplementary columns (LSC BW, measured occupancy) + pass2_csv = hotspots_csv.with_name(hotspots_csv.stem + "_pass2.csv") + ok2 = _run_single_hotspots_report(vtune_bin, result_dir, pass2_csv, _HOTSPOTS_COLUMNS_PASS2) + if ok2: + _merge_pass2(hotspots_csv, pass2_csv) + try: + pass2_csv.unlink() + except OSError: + pass + + return True + + +def _merge_pass2(main_csv: Path, pass2_csv: Path): + """Merge Pass 2 columns into the main hotspots CSV by kernel name.""" + pass2_tasks = parse_hotspots_csv(pass2_csv) + if not pass2_tasks: + return + + # Build lookup by kernel name + p2_by_name: dict[str, dict] = {} + for task in pass2_tasks: + name = task.get("Computing Task", "") + if name and name not in p2_by_name: + p2_by_name[name] = task + + # Re-read main CSV, add pass2 columns, rewrite + main_tasks = parse_hotspots_csv(main_csv) + if not main_tasks: + return + + # Find new columns from pass2 (skip "Computing Task" and duplicates) + main_keys = set(main_tasks[0].keys()) if main_tasks else set() + new_cols = [k for k in pass2_tasks[0].keys() if k not in main_keys and k != "Computing Task"] + + if not new_cols: + return + + for task in main_tasks: + name = task.get("Computing Task", "") + p2 = p2_by_name.get(name, {}) + for col in new_cols: + task[col] = p2.get(col, "") + + # Rewrite main CSV + all_cols = list(main_tasks[0].keys()) + with open(main_csv, "w", newline="") as f: + writer = csv.writer(f, delimiter="\t") + writer.writerow(all_cols) + for task in main_tasks: + writer.writerow([task.get(c, "") for c in all_cols]) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser(description="Profile Triton kernel with Intel VTune on XPU") + parser.add_argument("triton_file", type=Path, help="Triton kernel implementation") + parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations") + parser.add_argument("--iters", type=int, default=20, help="Profiled iterations") + parser.add_argument("--timeout", type=int, default=300, help="VTune timeout (s)") + args = parser.parse_args() + + if not args.triton_file.exists(): + print(f"Error: Triton file not found: {args.triton_file}") + sys.exit(1) + + vtune_bin = os.environ.get("VTUNE_BIN", VTUNE_BIN) + if not shutil.which(vtune_bin) and not Path(vtune_bin).is_file(): + print(f"Error: VTune binary not found at: {vtune_bin}") + sys.exit(1) + + kernel_name = args.triton_file.stem + timestamp = int(time.time()) + result_dir = f"/tmp/vtune_result_{kernel_name}_{timestamp}" + summary_csv = Path(f"/tmp/vtune_{kernel_name}_{timestamp}_summary.csv") + hotspots_csv = Path(f"/tmp/vtune_{kernel_name}_{timestamp}_hotspots.csv") + + print(f"\n{'=' * 70}") + print("VTune Profiling Configuration") + print(f"{'=' * 70}") + print(f" Triton kernel: {args.triton_file}") + print(" Collection: gpu-offload (with OA hardware counters)") + print(f" Warmup iters: {args.warmup}") + print(f" Profiled iters: {args.iters}") + print(f" Timeout: {args.timeout}s") + + # --- Step 1: Run VTune collection --- + print(f"\n{'=' * 70}") + print("Running VTune Collection...") + print(f"{'=' * 70}") + + try: + result = run_vtune_collection( + vtune_bin, + args.triton_file, + result_dir, + summary_csv, + args.warmup, + args.iters, + args.timeout, + ) + except subprocess.TimeoutExpired: + print(f"\nError: VTune collection timed out after {args.timeout}s") + sys.exit(1) + + if result.returncode != 0: + print(f"\nError: VTune exited with code {result.returncode}") + if result.stderr: + print(f" stderr: {result.stderr[:2000]}") + sys.exit(1) + + # --- Step 2: Parse summary report --- + scalar_metrics, gpu_tasks, host_tasks = parse_vtune_summary_csv(summary_csv) + + # --- Step 3: Run hotspots report for OA hardware counters --- + has_oa = False + hotspot_tasks: list[dict] = [] + if os.path.isdir(result_dir): + ok = run_hotspots_report(vtune_bin, result_dir, hotspots_csv) + if ok: + hotspot_tasks = parse_hotspots_csv(hotspots_csv) + # Check if OA columns are actually populated + if hotspot_tasks: + sample = hotspot_tasks[0] + has_oa = bool(_extract(sample.get("XVE Array:Active(%)", ""))) + + # Use hotspot_tasks if available (richer data), else fall back to summary gpu_tasks + display_tasks = hotspot_tasks if hotspot_tasks else gpu_tasks + + # --- Step 4: Display results --- + print(f"\n{'=' * 70}") + print("VTune Profiling Results") + print(f"{'=' * 70}") + + # Platform info (compact) + gpu_name = scalar_metrics.get("Name", "") + xve_count = scalar_metrics.get("XVE Count", "") + max_freq = scalar_metrics.get("Max Core Frequency", "") + if gpu_name: + freq_ghz = f"{int(max_freq) / 1e9:.1f} GHz" if _extract(max_freq) else "" + print(f"\n GPU: {gpu_name} XVEs: {xve_count} Max Freq: {freq_ghz}") + + elapsed = scalar_metrics.get("Elapsed Time", "") + gpu_pct = scalar_metrics.get("GPU Time, % of Elapsed time", "") + if elapsed: + print(f" Elapsed: {elapsed}s GPU Time: {gpu_pct}% of elapsed") + + # VTune's own recommendations (XVE Stalled/Idle) + xve_stall_reco = scalar_metrics.get("_reco_XVE Array Stalled/Idle", "") + if xve_stall_reco: + # Extract the percentage + match = re.search(r"([\d.]+)", xve_stall_reco) + if match: + print(f" XVE Array Stalled/Idle: {match.group(1)}% of GPU busy time") + + print_host_tasks(host_tasks) + print_gpu_tasks_summary(display_tasks, has_oa) + + # --- Step 5: Primary kernel detail --- + primary = find_primary_kernel(display_tasks) + if primary: + print_primary_kernel_detail(primary, has_oa) + + # --- Step 6: Recommendations --- + print_recommendations(primary, display_tasks, host_tasks, scalar_metrics, has_oa) + + print(f"\n{'=' * 70}") + print("Profiling Complete") + print(f"{'=' * 70}") + print(f" Summary CSV: {summary_csv}") + if hotspot_tasks: + print(f" Hotspots CSV: {hotspots_csv}") + if os.path.isdir(result_dir): + print(f" VTune result: {result_dir}") + if not has_oa: + print("\n Note: OA hardware counters not available.") + print(" To enable: echo 0 | sudo tee /proc/sys/dev/xe/observation_paranoid") + print() + + +if __name__ == "__main__": + main() diff --git a/kernel-builder/src/skills.rs b/kernel-builder/src/skills.rs index d0e7cd52..d1753452 100644 --- a/kernel-builder/src/skills.rs +++ b/kernel-builder/src/skills.rs @@ -10,10 +10,12 @@ const GITHUB_RAW_BASE_TEMPLATE: &str = "https://raw.githubusercontent.com/huggingface/kernels/main/kernel-builder/skills"; #[derive(Clone, Debug, Default, ValueEnum)] +#[allow(clippy::enum_variant_names)] pub enum SkillId { #[default] CudaKernels, RocmKernels, + XpuKernels, } impl SkillId { @@ -21,6 +23,7 @@ impl SkillId { match self { SkillId::CudaKernels => "cuda-kernels", SkillId::RocmKernels => "rocm-kernels", + SkillId::XpuKernels => "xpu-kernels", } } }