diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index cf5f9a0db..48399d6da 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -12,61 +12,48 @@ Iris is a Triton-based framework for Remote Memory Access (RMA) operations on AM
 **FOLLOW THESE INSTRUCTIONS EXACTLY. Reference these instructions first before using search or bash commands.**
 
-## Prerequisites
+## GPU Environment
 
-- **GPU**: AMD GPUs with ROCm compatibility (tested on MI300X, MI350X & MI355X)
-  > **Note**: See below for instructions on development without AMD GPU access
-- **ROCm/HIP Toolkit**: Required for building C++/HIP components
-- **Docker/Apptainer**: Recommended for containerized development
+**You have AMD GPU access.** Your runner has 8x MI325X GPUs (gfx942) with ROCm installed. Do NOT skip GPU operations or say you lack GPU access. You can and should:
 
-## Build
+- Run benchmarks and post results (charts, tables) in PR comments
+- Run tests with `torchrun --nproc_per_node=N` for multi-GPU tests
+- Use `rocm-smi` to verify GPU status
+- Run `python -c "import torch; print(torch.cuda.device_count())"` to confirm GPU count
 
-### Docker Development Environment (Recommended)
+When asked to run a benchmark, **run it and post the output**. Do not say you cannot.
+
+### Running multi-GPU tests and benchmarks
+
+Multi-GPU tests require `torch.distributed` initialization, so launch pytest via `torchrun`:
 
 ```bash
-# Build and start development container (takes 45-60 minutes - NEVER CANCEL)
-docker compose up --build -d
+# Single GPU
+pytest tests/unittests/ -v --tb=short
 
-# Attach to running container
-docker attach iris-dev
+# Multi-GPU (N = number of GPUs)
+torchrun --nproc_per_node=N -m pytest tests/ -v --tb=short
 
-# Install Iris in development mode
-cd iris && pip install -e ".[dev]"
+# Benchmarks use the iris.bench framework
+torchrun --nproc_per_node=8 benchmark/ops/bench_<name>.py
 ```
 
-### Alternative Docker Setup
-```bash
-# Build Docker image manually
-./docker/build.sh  # Takes 45-60 minutes
+### iris.bench framework
 
-# Run container
-./docker/run.sh
+Benchmarks use the declarative `iris.bench` framework. See existing `benchmark/ops/bench_*.py` files for examples. Output includes latency, throughput, and bandwidth tables. When posting benchmark results in PR comments, format as markdown tables.
 
-# Install Iris
-cd iris && pip install -e ".[dev]"
-```
+## Prerequisites
 
-### Apptainer Setup
-```bash
-# Build and run Apptainer image
-./apptainer/build.sh
-./apptainer/run.sh
+- **GPU**: AMD GPUs with ROCm compatibility (tested on MI300X, MI325X, MI350X & MI355X)
+- **ROCm/HIP Toolkit**: Required for building C++/HIP components
+- **Docker/Apptainer**: Recommended for containerized development
 
-# Install Iris
-pip install -e ".[dev]"
-```
+## Build
 
-### Local Development (Not Recommended)
+Iris is already installed in your environment via `pip install -e .` in the setup steps. You do not need to build or install anything. If you need to reinstall after modifying `setup.py` or C extensions:
 
 ```bash
-# Requires ROCm/HIP toolkit installation
 pip install -e ".[dev]"
 ```
 
-### Development Without AMD GPU
-If you don't have access to AMD GPUs, you can still contribute to the project:
-- **Code Editing**: Start editing code directly in your local environment
-- **CI Testing**: The project has comprehensive CI pipelines that will test your changes automatically. You can check the CI logs if your changes fail to understand what went wrong.
-- **Local Validation**: Run linting and formatting locally: `ruff check . --fix && ruff format .` - ## Run ### Testing diff --git a/.gitignore b/.gitignore index d8f9754f7..0bc6bbc55 100644 --- a/.gitignore +++ b/.gitignore @@ -28,6 +28,8 @@ omni*.pdf slurm*.out *.egg-info +*.backup +*.with_chunked examples/gemm/results/* asm/ @@ -57,4 +59,8 @@ gpucore.* logs/ *.cap hsakmt_counters.csv -core \ No newline at end of file +core +.intellikit/ +.github/agents/docs/benchmark-results/ +.github/agents/ +docs/benchmark-results/*.png diff --git a/benchmark/ops/all_gather_matmul/__init__.py b/benchmark/ops/all_gather_matmul/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmark/ops/all_gather_matmul/auto_config.py b/benchmark/ops/all_gather_matmul/auto_config.py new file mode 100644 index 000000000..0e8990886 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/auto_config.py @@ -0,0 +1,582 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. + +""" +Auto-selection mechanism for fused AG+MM kernel configurations. + +Given problem dimensions (M, N, K), transpose mode, world_size, and GPU +architecture, this module selects the best known configuration or returns +a sensible default. For world sizes where iris AG+MM is known to lose +against PyTorch (ws<8), the default disables iris and signals fallback. + +Config files live under: + configs/{arch}/{transpose}/ws{N}.json + +Each config file contains: + - Per-shape champion configs with all kernel parameters in a flat "params" dict + - A "default_params" dict with architecture-appropriate defaults + - Params include FusedConfig fields (block_size_m, etc.) and HBM buffer + kernel params (k_per_flag, num_fetch_sms, num_warps, num_stages, etc.) + +Transpose coverage: + The iris AG+MM kernel (`_fused_all_gather_matmul_kernel`) uses stride-based + addressing (`stride_am, stride_ak, stride_bk, stride_bn`), so transpose + layouts are handled implicitly by tensor strides. Config files exist for + all four layouts (NN, TN, NT, TT) under each architecture directory. + Only NN has per-shape champion configs from benchmarking (3,489 trials). + TN/NT/TT files contain heuristic defaults only (empty shapes dict) and are + marked enabled at ws>=8 to allow heuristic fallback. All transposes at ws<8 + are disabled (NO-GO based on NN benchmarks). + +Usage: + >>> from auto_config import select_ag_mm_config + >>> result = select_ag_mm_config(M=131072, N=16384, K=16384, world_size=8) + >>> if result.enabled: + ... config = result.to_fused_config() + ... hbm_params = result.hbm_buffer_params # k_per_flag, num_fetch_sms, etc. + ... shmem.ops.all_gather_matmul(output, A, B, config=config) + ... else: + ... # Fallback to PyTorch all_gather + matmul + ... ... 
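+
+    >>> # arch is auto-detected by default; pass arch= to inspect another
+    >>> # GPU's tuned configs explicitly (illustrative call, same API as above)
+    >>> result = select_ag_mm_config(M=131072, N=16384, K=16384, world_size=8, arch="mi355x")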
+
+    >>> # List all regression test sizes
+    >>> from auto_config import load_regression_sizes
+    >>> sizes = load_regression_sizes()
+"""
+
+import json
+import os
+import subprocess
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+from iris.ops.config import FusedConfig
+
+# Config files live alongside this module
+_CONFIGS_DIR = Path(__file__).parent / "configs"
+
+# FusedConfig field names — everything else in "params" is an HBM buffer param
+_FUSED_CONFIG_FIELDS = {f.name for f in FusedConfig.__dataclass_fields__.values()}
+
+# HBM buffer param names (kernel launch params, not FusedConfig fields)
+_HBM_BUFFER_FIELDS = {
+    "k_per_flag",
+    "num_fetch_sms",
+    "num_fetch_stages",
+    "first_stage_fetch_sms",
+    "fetch_block_m",
+    "fetch_block_k",
+    "num_warps",
+    "num_stages",
+}
+
+# In-memory cache: (arch, transpose, world_size) -> loaded JSON data
+_config_cache: Dict[Tuple[str, str, int], dict] = {}
+
+# Cached GPU architecture detection result
+_detected_arch: Optional[str] = None
+
+# Transpose modes with per-shape tuned configs. The kernel handles TN/NT/TT
+# implicitly via tensor strides (see module docstring), but those layouts
+# currently ship heuristic defaults only.
+SUPPORTED_TRANSPOSES = ("NN",)
+
+# Supported GPU architectures with tuned configs
+SUPPORTED_ARCHITECTURES = ("mi300x", "mi355x")
+
+# Map gfx target IDs to architecture names used in config paths
+_GFX_TO_ARCH = {
+    "gfx942": "mi300x",  # MI300X, MI300A
+    "gfx950": "mi355x",  # MI350X, MI355X
+}
+
+
+def detect_gpu_arch() -> str:
+    """Auto-detect GPU architecture from the current system.
+
+    Detection order:
+    1. IRIS_GPU_ARCH environment variable (override)
+    2. rocminfo gfx target parsing
+    3. Falls back to "mi300x" (most common deployment target)
+
+    Returns:
+        Architecture string (e.g., "mi300x") suitable for config lookup.
+    """
+    global _detected_arch
+    if _detected_arch is not None:
+        return _detected_arch
+
+    # 1. Environment variable override
+    env_arch = os.environ.get("IRIS_GPU_ARCH", "").strip().lower()
+    if env_arch:
+        _detected_arch = env_arch
+        return _detected_arch
+
+    # 2. Try rocminfo for gfx target
+    try:
+        result = subprocess.run(
+            ["rocminfo"],
+            capture_output=True,
+            text=True,
+            timeout=5,
+        )
+        if result.returncode == 0:
+            for line in result.stdout.splitlines():
+                line_stripped = line.strip().lower()
+                if "name:" in line_stripped and "gfx" in line_stripped:
+                    for gfx_id, arch_name in _GFX_TO_ARCH.items():
+                        if gfx_id in line_stripped:
+                            _detected_arch = arch_name
+                            return _detected_arch
+    except (FileNotFoundError, subprocess.TimeoutExpired, OSError):
+        pass
+
+    # 3. Fallback to MI300X (most common deployment target)
+    _detected_arch = "mi300x"
+    return _detected_arch
+
+
+@dataclass
+class AutoConfigResult:
+    """Result of auto-config lookup.
+
+    Attributes:
+        enabled: If False, iris AG+MM should NOT be used; fallback to PyTorch.
+        config_params: Dict of FusedConfig parameters (only valid if enabled=True).
+        hbm_buffer_params: Dict of HBM buffer-specific kernel params
+            (k_per_flag, num_fetch_sms, num_fetch_stages, first_stage_fetch_sms).
+        source: Human-readable description of where this config came from.
+        shape_key: The MxNxK key that matched (None if heuristic/default).
+        expected_iris_ms: Expected kernel time in ms on target GPU (None if unknown).
+ """ + + enabled: bool = False + config_params: Dict = field(default_factory=dict) + hbm_buffer_params: Dict = field(default_factory=dict) + source: str = "default" + shape_key: Optional[str] = None + expected_iris_ms: Optional[float] = None + + def to_fused_config(self) -> FusedConfig: + """Convert to FusedConfig for use with iris.ops functions. + + Raises: + RuntimeError: If this config is disabled (enabled=False). + """ + if not self.enabled: + raise RuntimeError( + f"Cannot create FusedConfig: iris AG+MM is disabled for this " + f"configuration. Reason: {self.source}. " + f"Use PyTorch all_gather + matmul instead." + ) + # Filter to only fields FusedConfig accepts + valid_fields = {f.name for f in FusedConfig.__dataclass_fields__.values()} + filtered = {k: v for k, v in self.config_params.items() if k in valid_fields} + return FusedConfig(**filtered) + + +def _split_params(params: Dict) -> Tuple[Dict, Dict]: + """Split a flat params dict into (config_params, hbm_buffer_params). + + FusedConfig fields go into config_params. + Everything else (num_warps, num_stages, k_per_flag, etc.) goes into hbm_buffer_params. + """ + config_params = {} + hbm_params = {} + for k, v in params.items(): + if k in _FUSED_CONFIG_FIELDS: + config_params[k] = v + else: + hbm_params[k] = v + return config_params, hbm_params + + +def _extract_shape_params(shape_data: Dict) -> Tuple[Dict, Dict]: + """Extract config_params and hbm_buffer_params from shape data. + + Supports both the new flat "params" format and the legacy split + "config" + "hbm_buffer_params" format for backward compatibility. + """ + if "params" in shape_data: + return _split_params(shape_data["params"]) + return shape_data.get("config", {}), shape_data.get("hbm_buffer_params", {}) + + +def _extract_default_params(data: Dict) -> Optional[Tuple[Dict, Dict]]: + """Extract default config_params and hbm_buffer_params from file-level defaults. + + Supports both "default_params" (flat) and legacy "default_config" + "default_hbm_buffer_params". + Returns None if no defaults are available. + """ + if "default_params" in data and data["default_params"] is not None: + return _split_params(data["default_params"]) + default_config = data.get("default_config") + if default_config: + return default_config, data.get("default_hbm_buffer_params", {}) + return None + + +def _load_config_file(arch: str, transpose: str, world_size: int) -> Optional[dict]: + """Load and cache a config JSON file. + + Args: + arch: GPU architecture identifier (e.g., "mi300x"). + transpose: Transpose mode (e.g., "NN", "NT", "TN", "TT"). + world_size: Number of ranks. + + Returns: + Parsed JSON dict, or None if file doesn't exist. + """ + cache_key = (arch, transpose, world_size) + if cache_key in _config_cache: + return _config_cache[cache_key] + + config_path = _CONFIGS_DIR / arch / transpose / f"ws{world_size}.json" + if not config_path.exists(): + _config_cache[cache_key] = None + return None + + with open(config_path, "r") as f: + data = json.load(f) + + _config_cache[cache_key] = data + return data + + +def _load_default_config() -> dict: + """Load the global default config.""" + default_path = _CONFIGS_DIR / "default_config.json" + if default_path.exists(): + with open(default_path, "r") as f: + return json.load(f) + return {} + + +def _find_nearest_shape(M: int, N: int, K: int, shapes: dict, tolerance: float = 0.15) -> Optional[str]: + """Find the nearest matching shape in the config database. 
+
+    Uses log-space geometric distance to find shapes that are structurally
+    similar (within `tolerance` ratio per dimension). This avoids falling
+    back to the heuristic when the user's problem is close to a champion shape.
+
+    Args:
+        M, N, K: Target dimensions.
+        shapes: Dict of shape_key -> shape_data from the config file.
+        tolerance: Max fractional distance per dimension (default 15%).
+
+    Returns:
+        The shape_key of the nearest match, or None if no shape is close enough.
+    """
+    import math
+
+    best_key = None
+    best_dist = float("inf")
+
+    for shape_key, shape_data in shapes.items():
+        sm, sn, sk = shape_data["M"], shape_data["N"], shape_data["K"]
+
+        # Check per-dimension ratio tolerance
+        if sm == 0 or sn == 0 or sk == 0:
+            continue
+        rm = abs(M - sm) / sm
+        rn = abs(N - sn) / sn
+        rk = abs(K - sk) / sk
+
+        if rm > tolerance or rn > tolerance or rk > tolerance:
+            continue
+
+        # Geometric distance in log space
+        dist = math.sqrt(
+            math.log(max(M, 1) / max(sm, 1)) ** 2
+            + math.log(max(N, 1) / max(sn, 1)) ** 2
+            + math.log(max(K, 1) / max(sk, 1)) ** 2
+        )
+        if dist < best_dist:
+            best_dist = dist
+            best_key = shape_key
+
+    return best_key
+
+
+def _apply_heuristic(M: int, N: int, K: int, arch: str = "mi300x") -> Tuple[Dict, Dict]:
+    """Apply heuristic rules to generate config + HBM buffer params.
+
+    Based on optimization data:
+    - MI300X: 3,489 measured trials
+    - MI355X: Optuna TPE + broad sweep
+
+    Args:
+        M: Rows dimension.
+        N: Columns dimension.
+        K: Reduction dimension.
+        arch: GPU architecture for arch-specific heuristics.
+
+    Returns:
+        Tuple of (config_params dict, hbm_buffer_params dict).
+    """
+    bk = 64
+    num_k_blocks = K // bk
+
+    if arch == "mi355x":
+        bm = 256
+        gm = 4 if M <= 32768 else 8
+        config_params = {
+            "block_size_m": bm,
+            "block_size_n": 256,
+            "block_size_k": bk,
+            "group_size_m": gm,
+            "num_warps": 8,
+            "num_stages": 2,
+            "num_xcds": 8,
+            "allow_tf32": True,
+        }
+        kpf = 8 if num_k_blocks <= 512 else 16
+        while num_k_blocks % kpf != 0 and kpf > 1:
+            kpf //= 2
+        hbm_params = {
+            "k_per_flag": kpf,
+            "num_fetch_sms": 16,
+            "num_fetch_stages": 1,
+            "first_stage_fetch_sms": 52,
+        }
+        return config_params, hbm_params
+
+    # MI300X heuristics
+    if M <= 16384:
+        bm = 128
+    else:
+        bm = 256
+
+    num_m_tiles = M // bm
+
+    if M <= 8192:
+        gm = 8
+    elif M <= 16384:
+        gm = 16
+    else:
+        gm = 24
+
+    config_params = {
+        "block_size_m": bm,
+        "block_size_n": 256,
+        "block_size_k": bk,
+        "group_size_m": gm,
+        "num_warps": 8,
+        "num_stages": 2,
+        "num_xcds": 8,
+        "allow_tf32": True,
+    }
+
+    if num_k_blocks >= 512:
+        kpf = 64
+    elif num_k_blocks >= 128:
+        kpf = 16
+    elif num_k_blocks >= 64:
+        kpf = 8
+    else:
+        kpf = 4
+    while num_k_blocks % kpf != 0 and kpf > 1:
+        kpf //= 2
+
+    if num_m_tiles <= 8:
+        fs = 4
+    elif num_m_tiles <= 32:
+        fs = 16
+    elif num_m_tiles <= 128:
+        fs = 32
+    else:
+        fs = 52
+
+    if num_m_tiles >= 512:
+        nfs = 4
+    elif num_m_tiles >= 64:
+        nfs = 2
+    else:
+        nfs = 1
+
+    hbm_params = {
+        "k_per_flag": kpf,
+        "num_fetch_sms": fs,
+        "num_fetch_stages": nfs,
+        "first_stage_fetch_sms": 64,
+    }
+
+    return config_params, hbm_params
+
+
+def select_ag_mm_config(
+    M: int,
+    N: int,
+    K: int,
+    world_size: int,
+    transpose: str = "NN",
+    arch: str = "auto",
+) -> AutoConfigResult:
+    """Select the best AG+MM config for the given problem.
+
+    Lookup order:
+    1. Exact shape match in benchmark/ops/all_gather_matmul/configs/{arch}/{transpose}/ws{world_size}.json
+    2. Nearest champion shape in the same file (within 15% per dimension)
+    3. Heuristic-based config merged with the same file's defaults
+    4. 
Global default from benchmark/ops/all_gather_matmul/configs/default_config.json + + For world sizes where iris is known to lose (ws<8 on MI300X), returns + a disabled result signaling fallback to PyTorch. + + Args: + M: Number of rows (or M_local * world_size for AG+MM). + N: Number of columns. + K: Reduction dimension. + world_size: Number of ranks in the communicator. + transpose: Transpose mode ("NN", "NT", "TN", "TT"). Default "NN". + arch: GPU architecture ("mi300x", etc.) or "auto" to auto-detect. + Default "auto". Set IRIS_GPU_ARCH env var to override. + + Returns: + AutoConfigResult with .enabled indicating whether to use iris, + .to_fused_config() to get the FusedConfig if enabled, and + .hbm_buffer_params with kernel-specific parameters. + + Example: + >>> result = select_ag_mm_config(131072, 16384, 16384, world_size=8) + >>> result.enabled + True + >>> config = result.to_fused_config() + >>> result.hbm_buffer_params + {'k_per_flag': 32, 'num_fetch_sms': 4, 'num_fetch_stages': 64, 'first_stage_fetch_sms': 52} + + >>> result = select_ag_mm_config(4096, 4096, 4096, world_size=2) + >>> result.enabled + False + """ + transpose = transpose.upper() + if arch == "auto": + arch = detect_gpu_arch() + else: + arch = arch.lower() + + # Step 1: Try to load the specific config file + data = _load_config_file(arch, transpose, world_size) + + if data is not None: + # Check if this world_size is enabled + if not data.get("enabled", True): + return AutoConfigResult( + enabled=False, + source=f"Disabled by config: {arch}/{transpose}/ws{world_size}.json — {data.get('reason', 'no reason given')}", + ) + + # Look for exact shape match + shape_key = f"{M}x{N}x{K}" + shapes = data.get("shapes", {}) + if shape_key in shapes: + shape_data = shapes[shape_key] + cfg, hbm = _extract_shape_params(shape_data) + return AutoConfigResult( + enabled=True, + config_params=cfg, + hbm_buffer_params=hbm, + source=f"Exact match: {arch}/{transpose}/ws{world_size}.json [{shape_data.get('label', shape_key)}]", + shape_key=shape_key, + expected_iris_ms=shape_data.get("expected_iris_ms"), + ) + + # No exact match — try nearest champion shape (within 15% per dim) + nearest_key = _find_nearest_shape(M, N, K, shapes) + if nearest_key is not None: + nearest_data = shapes[nearest_key] + cfg, hbm = _extract_shape_params(nearest_data) + return AutoConfigResult( + enabled=True, + config_params=cfg, + hbm_buffer_params=hbm, + source=f"Nearest match: {arch}/{transpose}/ws{world_size}.json [{nearest_data.get('label', nearest_key)}] (target {M}x{N}x{K} ≈ {nearest_key})", + shape_key=nearest_key, + expected_iris_ms=nearest_data.get("expected_iris_ms"), + ) + + # No nearby match — use heuristic + file defaults + defaults = _extract_default_params(data) + if defaults is not None: + file_default_config, file_default_hbm = defaults + heuristic_config, heuristic_hbm = _apply_heuristic(M, N, K, arch=arch) + merged_config = {**file_default_config, **heuristic_config} + merged_hbm = {**file_default_hbm, **heuristic_hbm} + return AutoConfigResult( + enabled=True, + config_params=merged_config, + hbm_buffer_params=merged_hbm, + source=f"Heuristic (no exact shape match in {arch}/{transpose}/ws{world_size}.json)", + ) + + # Step 2: No config file found — check global default + default_data = _load_default_config() + ws_gate = default_data.get("world_size_gate", {}) + min_ws = ws_gate.get("min_world_size", 8) + + if world_size < min_ws: + return AutoConfigResult( + enabled=False, + source=f"world_size={world_size} < min_world_size={min_ws} (global 
default). {ws_gate.get('reason', '')}", + ) + + # World size OK but no specific config — apply heuristic + heuristic_config, heuristic_hbm = _apply_heuristic(M, N, K, arch=arch) + return AutoConfigResult( + enabled=True, + config_params=heuristic_config, + hbm_buffer_params=heuristic_hbm, + source=f"Heuristic fallback (no config file for {arch}/{transpose}/ws{world_size})", + ) + + +def list_known_shapes( + world_size: int, + transpose: str = "NN", + arch: str = "mi300x", +) -> list: + """List all known shape configurations for a given world_size/transpose/arch. + + Returns: + List of dicts with keys: shape_key, label, M, N, K, expected_iris_ms. + """ + data = _load_config_file(arch, transpose.upper(), world_size) + if data is None or not data.get("enabled", True): + return [] + + result = [] + for shape_key, shape_data in data.get("shapes", {}).items(): + result.append( + { + "shape_key": shape_key, + "label": shape_data.get("label", ""), + "M": shape_data["M"], + "N": shape_data["N"], + "K": shape_data["K"], + "expected_iris_ms": shape_data.get("expected_iris_ms"), + } + ) + + result.sort(key=lambda x: x.get("expected_iris_ms") or float("inf")) + return result + + +def load_regression_sizes() -> List[Dict]: + """Load regression test sizes from the JSON config file. + + Returns: + List of regression size dicts, each with: name, M, N, K, tier, + description, world_sizes, expected, regression_threshold_pct. + """ + reg_path = _CONFIGS_DIR / "regression_sizes.json" + if not reg_path.exists(): + return [] + with open(reg_path, "r") as f: + data = json.load(f) + return data.get("sizes", []) + + +def clear_config_cache(): + """Clear the in-memory config cache. Useful after modifying config files.""" + _config_cache.clear() diff --git a/benchmark/ops/all_gather_matmul/configs/default_config.json b/benchmark/ops/all_gather_matmul/configs/default_config.json new file mode 100644 index 000000000..ff96ac1f0 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/default_config.json @@ -0,0 +1,27 @@ +{ + "_meta": { + "description": "Global default fallback config for AG+MM operations. Disables iris AG+MM for ws<8 (fallback to PyTorch).", + "source": "benchmarking on MI300X (gfx942), 3489 measured trials", + "date": "2026-04-13" + }, + "world_size_gate": { + "min_world_size": 8, + "reason": "ws=2 best 0.89x, ws=4 best 0.86x vs PyTorch. Only ws>=8 is production-ready." + }, + "config": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_warps": 8, + "num_stages": 2, + "num_xcds": 8, + "allow_tf32": true + }, + "hbm_buffer_params": { + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws2.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws2.json new file mode 100644 index 000000000..f5f0fd87e --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws2.json @@ -0,0 +1,159 @@ +{ + "_meta": { + "description": "AG+MM ws=2 on MI300X \u2014 DISABLED (loses vs PyTorch on all shapes)", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=2 AG transfers from 1 peer only. GEMM dominates latency. Fetch SM overhead exceeds overlap benefit. 
LDS overflow forces ns=1, imposing 15-35% penalty.", + "shapes": { + "8192x8192x262144": { + "label": "g5", + "M": 8192, + "N": 8192, + "K": 262144, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 4, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 186.062 + }, + "16384x16384x131072": { + "label": "g1", + "M": 16384, + "N": 16384, + "K": 131072, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 16, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 8, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 153.042 + }, + "4096x14336x4096": { + "label": "mixtral_gate", + "M": 4096, + "N": 14336, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 2.334 + }, + "4096x11008x4096": { + "label": "llama7b_gate", + "M": 4096, + "N": 11008, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 1.784 + }, + "4096x4096x4096": { + "label": "pow2_4k", + "M": 4096, + "N": 4096, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 1.109 + }, + "5120x13824x5120": { + "label": "llama13b_gate", + "M": 5120, + "N": 13824, + "K": 5120, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 4.144 + }, + "4096x4096x11008": { + "label": "llama7b_down", + "M": 4096, + "N": 4096, + "K": 11008, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 2.477 + } + }, + "default_params": null +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws4.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws4.json new file mode 100644 index 000000000..30b7a6bef --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws4.json @@ -0,0 +1,201 @@ +{ + "_meta": { + "description": "AG+MM ws=4 on MI300X \u2014 DISABLED (loses vs PyTorch on all shapes)", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=4 loses on all tested shapes. K=4096 shapes crash at ns=2 due to LDS overflow (65540>65536). 
ns=1 workaround constrains pipelining depth below break-even.", + "shapes": { + "262144x8192x8192": { + "label": "g6", + "M": 262144, + "N": 8192, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 64, + "num_fetch_sms": 52, + "num_fetch_stages": 4, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 161.027 + }, + "8192x8192x262144": { + "label": "g5", + "M": 8192, + "N": 8192, + "K": 262144, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 4, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 167.944 + }, + "131072x16384x16384": { + "label": "g2", + "M": 131072, + "N": 16384, + "K": 16384, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 4, + "num_fetch_stages": 64, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 209.556 + }, + "16384x16384x131072": { + "label": "g1", + "M": 16384, + "N": 16384, + "K": 131072, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 16, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 16, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 239.757 + }, + "4096x14336x4096": { + "label": "mixtral_gate", + "M": 4096, + "N": 14336, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 16 + }, + "expected_iris_ms": 2.192 + }, + "4096x11008x4096": { + "label": "llama7b_gate", + "M": 4096, + "N": 11008, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 16 + }, + "expected_iris_ms": 2.163 + }, + "4096x4096x4096": { + "label": "pow2_4k", + "M": 4096, + "N": 4096, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 16 + }, + "expected_iris_ms": 1.494 + }, + "5120x13824x5120": { + "label": "llama13b_gate", + "M": 5120, + "N": 13824, + "K": 5120, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 16 + }, + "expected_iris_ms": 3.257 + }, + "4096x4096x11008": { + "label": "llama7b_down", + "M": 4096, + "N": 4096, + "K": 11008, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + 
"num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 16 + }, + "expected_iris_ms": 2.578 + } + }, + "default_params": null +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws8.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws8.json new file mode 100644 index 000000000..2c518f1df --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NN/ws8.json @@ -0,0 +1,290 @@ +{ + "_meta": { + "description": "Champion configs for HBM buffer AG+MM ws=8 on MI300X (gfx942)", + "source": "sweep (3489 trials), optimize-loop iter3", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13", + "convention": "Shapes are (M, N, K) for col-parallel (M-sharded) AG+MM" + }, + "enabled": true, + "shapes": { + "262144x8192x8192": { + "label": "g6", + "description": "Llama-70B MLP hidden x hidden \u2014 M-dominant", + "M": 262144, + "N": 8192, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 64, + "num_fetch_sms": 52, + "num_fetch_stages": 4, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 139.069 + }, + "131072x16384x16384": { + "label": "g2", + "description": "Llama MLP variant \u2014 balanced large", + "M": 131072, + "N": 16384, + "K": 16384, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 4, + "num_fetch_stages": 64, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 167.345 + }, + "147456x28672x4096": { + "label": "g14", + "description": "Llama-70B up-projection medium batch", + "M": 147456, + "N": 28672, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 59, + "num_fetch_stages": 36, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 74.244 + }, + "229376x28672x4096": { + "label": "g16", + "description": "Llama-70B up-projection mid batch", + "M": 229376, + "N": 28672, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 56, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 114.265 + }, + "327680x28672x4096": { + "label": "g15", + "description": "Llama-70B up-projection large batch", + "M": 327680, + "N": 28672, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 32, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 162.136 + }, + "8192x8192x262144": { + "label": "g5", + "description": "K-dominant square", + "M": 8192, + "N": 8192, + "K": 262144, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 4, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 217.725 + }, + 
"16384x16384x131072": { + "label": "g1", + "description": "K-dominant large", + "M": 16384, + "N": 16384, + "K": 131072, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 16, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 16, + "num_fetch_stages": 8, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 223.748 + }, + "196608x18432x16384": { + "label": "g9", + "description": "Large balanced shape", + "M": 196608, + "N": 18432, + "K": 16384, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 32, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 32 + }, + "expected_iris_ms": 266.608 + }, + "262144x28672x8192": { + "label": "g8", + "description": "Large wide shape", + "M": 262144, + "N": 28672, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 128, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 32 + }, + "expected_iris_ms": 278.546 + }, + "4096x14336x4096": { + "label": "mixtral_gate", + "description": "Mixtral gate projection", + "M": 4096, + "N": 14336, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 16, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "expected_iris_ms": 1.933 + }, + "4096x11008x4096": { + "label": "llama7b_gate", + "description": "Llama-7B gate projection", + "M": 4096, + "N": 11008, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 128, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 16, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "expected_iris_ms": 1.946 + }, + "4096x4096x4096": { + "label": "pow2_4k", + "description": "Small power-of-2 square shape", + "M": 4096, + "N": 4096, + "K": 4096, + "params": { + "block_size_m": 128, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 8, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "expected_iris_ms": 1.512 + } + }, + "default_params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws2.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws2.json new file mode 100644 index 000000000..897be3f2c --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws2.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=2 NT transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=2 loses vs PyTorch on all tested shapes. 
LDS overflow forces ns=1, imposing 15-35% perf penalty." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws4.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws4.json new file mode 100644 index 000000000..cc1f9d297 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws4.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=4 NT transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=4 loses vs PyTorch on all tested shapes. Best measured: 0.856x. LDS overflow at K=4096." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws8.json b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws8.json new file mode 100644 index 000000000..873cb76e1 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/NT/ws8.json @@ -0,0 +1,31 @@ +{ + "_meta": { + "description": "AG+MM ws=8 NT transpose on MI300X — heuristic defaults (no per-shape benchmarks yet)", + "source": "heuristic extrapolation from NN transpose champion data", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13", + "data_tag": "heuristic", + "convention": "Shapes are (M, N, K) for col-parallel (M-sharded) AG+MM, B transposed (K×N → N×K)" + }, + "enabled": true, + "shapes": {}, + "default_config": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_warps": 8, + "num_stages": 2, + "num_xcds": 8, + "allow_tf32": true + }, + "default_hbm_buffer_params": { + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "heuristic_rules": { + "note": "Uses same heuristic as NN transpose. Shape-specific tuning pending." + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws2.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws2.json new file mode 100644 index 000000000..2fe67e154 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws2.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=2 TN transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=2 loses vs PyTorch on all tested shapes. LDS overflow forces ns=1, imposing 15-35% perf penalty." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws4.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws4.json new file mode 100644 index 000000000..c8977d5f0 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws4.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=4 TN transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=4 loses vs PyTorch on all tested shapes. Best measured: 0.856x. LDS overflow at K=4096." 
+} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws8.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws8.json new file mode 100644 index 000000000..df9a5b3f9 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TN/ws8.json @@ -0,0 +1,31 @@ +{ + "_meta": { + "description": "AG+MM ws=8 TN transpose on MI300X — heuristic defaults (no per-shape benchmarks yet)", + "source": "heuristic extrapolation from NN transpose champion data", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13", + "data_tag": "heuristic", + "convention": "Shapes are (M, N, K) for col-parallel (M-sharded) AG+MM, A transposed (M×K → K×M)" + }, + "enabled": true, + "shapes": {}, + "default_config": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_warps": 8, + "num_stages": 2, + "num_xcds": 8, + "allow_tf32": true + }, + "default_hbm_buffer_params": { + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "heuristic_rules": { + "note": "Uses same heuristic as NN transpose. Shape-specific tuning pending." + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws2.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws2.json new file mode 100644 index 000000000..cc2c2497c --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws2.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=2 TT transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=2 loses vs PyTorch on all tested shapes. LDS overflow forces ns=1, imposing 15-35% perf penalty." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws4.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws4.json new file mode 100644 index 000000000..55ee5f423 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws4.json @@ -0,0 +1,10 @@ +{ + "_meta": { + "description": "AG+MM ws=4 TT transpose on MI300X — DISABLED", + "source": "ws<8 is NO-GO across all transposes", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13" + }, + "enabled": false, + "reason": "ws=4 loses vs PyTorch on all tested shapes. Best measured: 0.856x. LDS overflow at K=4096." +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws8.json b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws8.json new file mode 100644 index 000000000..a184b41a4 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi300x/TT/ws8.json @@ -0,0 +1,31 @@ +{ + "_meta": { + "description": "AG+MM ws=8 TT transpose on MI300X — heuristic defaults (no per-shape benchmarks yet)", + "source": "heuristic extrapolation from NN transpose champion data", + "gpu": "AMD Instinct MI300X (gfx942)", + "date": "2026-04-13", + "data_tag": "heuristic", + "convention": "Shapes are (M, N, K) for col-parallel (M-sharded) AG+MM, both A and B transposed" + }, + "enabled": true, + "shapes": {}, + "default_config": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 24, + "num_warps": 8, + "num_stages": 2, + "num_xcds": 8, + "allow_tf32": true + }, + "default_hbm_buffer_params": { + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 64 + }, + "heuristic_rules": { + "note": "Uses same heuristic as NN transpose. Shape-specific tuning pending." 
+ } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws2.json b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws2.json new file mode 100644 index 000000000..9c07592f1 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws2.json @@ -0,0 +1,17 @@ +{ + "_meta": { + "description": "AG+MM ws=2 on MI355X (gfx950) — defaults only, needs tuning", + "gpu": "AMD Instinct MI355X (gfx950)", + "date": "2026-04-15", + "validated": "unvalidated — no shape-specific tuning yet" + }, + "enabled": true, + "shapes": {}, + "default_params": { + "block_size_m": 256, "block_size_n": 256, "block_size_k": 64, + "group_size_m": 4, "num_xcds": 8, "allow_tf32": true, + "num_warps": 8, "num_stages": 2, + "k_per_flag": 16, "num_fetch_sms": 4, + "num_fetch_stages": 1, "first_stage_fetch_sms": 52 + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws4.json b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws4.json new file mode 100644 index 000000000..3d64610a2 --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws4.json @@ -0,0 +1,17 @@ +{ + "_meta": { + "description": "AG+MM ws=4 on MI355X (gfx950) — defaults only, needs tuning", + "gpu": "AMD Instinct MI355X (gfx950)", + "date": "2026-04-15", + "validated": "unvalidated — no shape-specific tuning yet" + }, + "enabled": true, + "shapes": {}, + "default_params": { + "block_size_m": 256, "block_size_n": 256, "block_size_k": 64, + "group_size_m": 4, "num_xcds": 8, "allow_tf32": true, + "num_warps": 8, "num_stages": 2, + "k_per_flag": 16, "num_fetch_sms": 4, + "num_fetch_stages": 1, "first_stage_fetch_sms": 52 + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws8.json b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws8.json new file mode 100644 index 000000000..17fa7051a --- /dev/null +++ b/benchmark/ops/all_gather_matmul/configs/mi355x/NN/ws8.json @@ -0,0 +1,246 @@ +{ + "_meta": { + "description": "Champion configs for HBM buffer AG+MM ws=8 on MI355X (gfx950)", + "source": "Optuna TPE + broad sweep", + "gpu": "AMD Instinct MI355X (gfx950)", + "date": "2026-04-15", + "convention": "Shapes are (M, N, K) for col-parallel (M-sharded) AG+MM" + }, + "enabled": true, + "shapes": { + "262144x8192x8192": { + "label": "g6", + "description": "Llama-70B MLP hidden x hidden \u2014 M-dominant", + "M": 262144, + "N": 8192, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 2, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 16, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 37.558 + }, + "131072x4096x4096": { + "label": "a3", + "description": "Output proj, 128K seq, 4K hidden", + "M": 131072, + "N": 4096, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 16, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 6.304 + }, + "65536x8192x28672": { + "label": "f4", + "description": "Llama 70B down, 64K seq", + "M": 65536, + "N": 8192, + "K": 28672, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + 
"expected_iris_ms": 35.959 + }, + "32768x8192x8192": { + "label": "l1", + "description": "Training batch 32K", + "M": 32768, + "N": 8192, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 2, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 5.445 + }, + "65536x8192x4096": { + "label": "a2", + "description": "QKV proj, 64K seq, 8K hidden", + "M": 65536, + "N": 8192, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 2, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 32, + "num_fetch_stages": 2, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 5.775 + }, + "32768x4096x14336": { + "label": "f2", + "description": "Llama 7B down, 32K seq", + "M": 32768, + "N": 4096, + "K": 14336, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 1, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 5.961 + }, + "32768x4096x4096": { + "label": "a1", + "description": "QKV proj, 32K seq, 4K hidden", + "M": 32768, + "N": 4096, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 4, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 16, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 256 + }, + "expected_iris_ms": 1.864 + }, + "65536x28672x8192": { + "label": "f3", + "description": "Llama 70B gate/up, 64K seq", + "M": 65536, + "N": 28672, + "K": 8192, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 32.303 + }, + "16384x28672x4096": { + "label": "m2", + "description": "Large FFN up (Llama 70B-like)", + "M": 16384, + "N": 28672, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 2, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 8, + "num_fetch_sms": 2, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 256 + }, + "expected_iris_ms": 4.658 + }, + "32768x14336x4096": { + "label": "f1", + "description": "Llama 7B gate/up, 32K seq", + "M": 32768, + "N": 14336, + "K": 4096, + "params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 8, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + }, + "expected_iris_ms": 4.86 + } + }, + "default_params": { + "block_size_m": 256, + "block_size_n": 256, + "block_size_k": 64, + "group_size_m": 4, + "num_xcds": 8, + "allow_tf32": true, + "num_warps": 8, + "num_stages": 2, + "k_per_flag": 16, + "num_fetch_sms": 4, + "num_fetch_stages": 1, + "first_stage_fetch_sms": 52 + } +} diff --git a/benchmark/ops/all_gather_matmul/configs/regression_sizes.json 
b/benchmark/ops/all_gather_matmul/configs/regression_sizes.json
new file mode 100644
index 000000000..40a497b0a
--- /dev/null
+++ b/benchmark/ops/all_gather_matmul/configs/regression_sizes.json
@@ -0,0 +1,101 @@
+{
+  "_meta": {
+    "description": "Regression test sizes for HBM buffer AG+MM kernel across ws=2/4/8",
+    "source": "28 measured shapes (12 ws=8, 9 ws=4, 7 ws=2) from 3489 trials",
+    "gpu_target": "MI300X (gfx942)",
+    "date": "2026-04-13",
+    "usage": "from auto_config import load_regression_sizes"
+  },
+  "sizes": [
+    {
+      "name": "g2_ws8",
+      "label": "g2",
+      "description": "Llama MLP variant — balanced large (highest speedup)",
+      "M": 131072, "N": 16384, "K": 16384,
+      "tier": "champion",
+      "world_sizes": [8],
+      "expected": {"ws8_speedup": 1.343, "ws8_tflops": 420.5},
+      "regression_threshold_pct": 10
+    },
+    {
+      "name": "g15_ws8",
+      "label": "g15",
+      "description": "Llama-70B up-projection large batch — highest TFLOPS",
+      "M": 327680, "N": 28672, "K": 4096,
+      "tier": "champion",
+      "world_sizes": [8],
+      "expected": {"ws8_speedup": 1.284, "ws8_tflops": 474.7},
+      "regression_threshold_pct": 10
+    },
+    {
+      "name": "g14_ws8",
+      "label": "g14",
+      "description": "Llama-70B up-projection medium batch",
+      "M": 147456, "N": 28672, "K": 4096,
+      "tier": "champion",
+      "world_sizes": [8],
+      "expected": {"ws8_speedup": 1.288, "ws8_tflops": 466.5},
+      "regression_threshold_pct": 10
+    },
+    {
+      "name": "g16_ws8",
+      "label": "g16",
+      "description": "Llama-70B up-projection mid batch",
+      "M": 229376, "N": 28672, "K": 4096,
+      "tier": "champion",
+      "world_sizes": [8],
+      "expected": {"ws8_speedup": 1.277, "ws8_tflops": 471.5},
+      "regression_threshold_pct": 10
+    },
+    {
+      "name": "g5_ws8",
+      "label": "g5",
+      "description": "K-dominant square — M-small, needs bm=128",
+      "M": 8192, "N": 8192, "K": 262144,
+      "tier": "champion",
+      "world_sizes": [8],
+      "expected": {"ws8_speedup": 1.224, "ws8_tflops": 161.6},
+      "regression_threshold_pct": 10
+    },
+    {
+      "name": "g6_ws8",
+      "label": "g6",
+      "description": "Llama-70B MLP hidden x hidden — M-dominant",
+      "M": 262144, "N": 8192, "K": 8192,
+      "tier": "champion",
+      "world_sizes": [8],
+      "expected": {"ws8_speedup": 1.200, "ws8_tflops": 253.0},
+      "regression_threshold_pct": 10
+    },
+    {
+      "name": "g1_ws8",
+      "label": "g1",
+      "description": "K-dominant large — parity shape",
+      "M": 16384, "N": 16384, "K": 131072,
+      "tier": "champion",
+      "world_sizes": [8],
+      "expected": {"ws8_speedup": 1.136, "ws8_tflops": 314.5},
+      "regression_threshold_pct": 10
+    },
+    {
+      "name": "g5_ws2_disabled",
+      "label": "g5",
+      "description": "Best ws=2 shape — still loses vs PyTorch (0.887x). Verifies fallback.",
+      "M": 8192, "N": 8192, "K": 262144,
+      "tier": "disabled",
+      "world_sizes": [2],
+      "expected": {"ws2_speedup": 0.887, "ws2_disabled": true},
+      "regression_threshold_pct": null
+    },
+    {
+      "name": "g6_ws4_disabled",
+      "label": "g6",
+      "description": "Best ws=4 shape — still loses vs PyTorch (0.856x). Verifies fallback.",
+      "M": 262144, "N": 8192, "K": 8192,
+      "tier": "disabled",
+      "world_sizes": [4],
+      "expected": {"ws4_speedup": 0.856, "ws4_disabled": true},
+      "regression_threshold_pct": null
+    }
+  ]
+}
diff --git a/benchmark/ops/bench_all_gather_matmul.py b/benchmark/ops/bench_all_gather_matmul.py
index 9a50d3180..daf98492e 100644
--- a/benchmark/ops/bench_all_gather_matmul.py
+++ b/benchmark/ops/bench_all_gather_matmul.py
@@ -2,11 +2,26 @@
 # SPDX-License-Identifier: MIT
 # Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved.
-"""Benchmark for fused all-gather + GEMM (iris.ops).""" +"""Benchmark for all-gather + GEMM: RCCL baseline vs iris HBM-buffer prefetch. + +The HBM-buffer benchmark automatically loads tuned kernel parameters from +configs/{arch}/{transpose}/ws{N}.json when available. Run with --list-configs +to see which shapes have tuned configs for the current GPU. +""" + +import sys +import os import torch +import torch.distributed as dist import iris.bench as bench -from iris.ops import FusedConfig, all_gather_matmul_preamble +from iris.ops.all_gather_matmul_hbm_buffer import ( + all_gather_matmul_hbm_buffer as _hbm_buffer, + all_gather_matmul_hbm_buffer_preamble, +) + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "all_gather_matmul")) +from auto_config import select_ag_mm_config @bench.register @@ -15,25 +30,88 @@ @bench.axis("N", [3584]) @bench.axis("K", [8192]) @bench.axis("dtype", [torch.float16]) -def all_gather_matmul(state, ctx): +def rccl_all_gather_matmul(state, ctx): + """PyTorch/RCCL baseline: all_gather (along K) + torch.mm""" + M, N, K = state["M"], state["N"], state["K"] + dtype = state["dtype"] + world_size = dist.get_world_size() + rank = dist.get_rank() + K_local = K // world_size + + # Per-rank seed for A (each rank holds different shards); shared seed for B + A_sharded = torch.randn( + (M, K_local), device="cuda", dtype=dtype, generator=torch.Generator("cuda").manual_seed(42 + rank) + ) + B = torch.randn((K, N), device="cuda", dtype=dtype, generator=torch.Generator("cuda").manual_seed(123)) + A_gathered_list = [torch.empty((M, K_local), device="cuda", dtype=dtype) for _ in range(world_size)] + C = torch.empty((M, N), device="cuda", dtype=dtype) + + state.set_flops(2 * M * N * K) + state.set_bytes((world_size - 1) * M * K_local * A_sharded.element_size()) + + def _run(): + dist.all_gather(A_gathered_list, A_sharded) + A_gathered = torch.cat(A_gathered_list, dim=1) + torch.mm(A_gathered, B, out=C) + + state.exec(_run) + + +@bench.register +@bench.axis("num_ranks", [2, 4, 8]) +@bench.axis("M", [1024, 4096, 16384]) +@bench.axis("N", [3584]) +@bench.axis("K", [8192]) +@bench.axis("dtype", [torch.float16]) +def all_gather_matmul_hbm_buffer(state, ctx): + """Iris HBM-buffer AG+MM with auto-tuned config from configs/ JSON files.""" M, N, K = state["M"], state["N"], state["K"] dtype = state["dtype"] world_size = ctx.get_num_ranks() K_local = K // world_size - A_sharded = ctx.zeros((M, K_local), dtype=dtype) - A_sharded.fill_(1.0) - B = torch.randn((K, N), device="cuda", dtype=dtype) - C = torch.zeros((M, N), device="cuda", dtype=dtype) + result = select_ag_mm_config(M, N, K, world_size=world_size) + if not result.enabled: + state.skip(f"iris disabled for ws={world_size}: {result.source}") + return + config = result.to_fused_config() + hbm = result.hbm_buffer_params - config = FusedConfig() - workspace = all_gather_matmul_preamble(ctx, A_sharded, B, config) + rank = ctx.get_rank() + # Per-rank seed for A (each rank holds different shards); shared seed for B + A_sharded = ctx.randn((M, K_local), dtype=dtype, generator=torch.Generator("cuda").manual_seed(42 + rank)) + B = torch.randn((K, N), device="cuda", dtype=dtype, generator=torch.Generator("cuda").manual_seed(123)) + C = ctx.zeros((M, N), dtype=dtype) + + workspace = all_gather_matmul_hbm_buffer_preamble( + ctx, + A_sharded, + B, + config, + k_per_flag=hbm.get("k_per_flag", 8), + ) state.set_flops(2 * M * N * K) state.set_bytes((world_size - 1) * M * K_local * A_sharded.element_size()) state.exec( - lambda: 
ctx.ops.all_gather_matmul(C, A_sharded, B, config=config, workspace=workspace),
+        lambda: _hbm_buffer(
+            ctx,
+            C,
+            A_sharded,
+            B,
+            config=config,
+            workspace=workspace,
+            num_fetch_sms=hbm.get("num_fetch_sms", 16),
+            k_per_flag=hbm.get("k_per_flag", 8),
+            fetch_block_m=hbm.get("fetch_block_m"),
+            fetch_block_k=hbm.get("fetch_block_k"),
+            num_warps=hbm.get("num_warps", 8),
+            num_stages=hbm.get("num_stages", 2),
+            num_fetch_stages=hbm.get("num_fetch_stages"),
+            first_stage_fetch_sms=hbm.get("first_stage_fetch_sms"),
+        ),
+        preamble_fn=lambda: (C.zero_(), workspace.locks.zero_()),
     )
diff --git a/docs/benchmark-results/bar_chart_mi300x_ws8.png b/docs/benchmark-results/bar_chart_mi300x_ws8.png
new file mode 100644
index 000000000..eb3e89e4d
Binary files /dev/null and b/docs/benchmark-results/bar_chart_mi300x_ws8.png differ
diff --git a/docs/benchmark-results/bar_chart_ws8_corrected_rccl.png b/docs/benchmark-results/bar_chart_ws8_corrected_rccl.png
new file mode 100644
index 000000000..358cd2321
Binary files /dev/null and b/docs/benchmark-results/bar_chart_ws8_corrected_rccl.png differ
diff --git a/docs/benchmark-results/latency_comparison.png b/docs/benchmark-results/latency_comparison.png
new file mode 100644
index 000000000..288fad091
Binary files /dev/null and b/docs/benchmark-results/latency_comparison.png differ
diff --git a/docs/benchmark-results/tflops_comparison.png b/docs/benchmark-results/tflops_comparison.png
new file mode 100644
index 000000000..c33582ec7
Binary files /dev/null and b/docs/benchmark-results/tflops_comparison.png differ
diff --git a/iris/fd_passing.py b/iris/fd_passing.py
index 4e8c13f44..1f71290fa 100644
--- a/iris/fd_passing.py
+++ b/iris/fd_passing.py
@@ -140,6 +140,55 @@ def setup_fd_mesh(rank: int, world_size: int, all_paths: Dict[int, str]) -> Dict
     return conns


+def _allgather_paths_tensor(my_path: str, num_ranks: int):
+    """
+    Exchange socket paths across ranks using a fixed-size tensor all_gather.
+
+    Uses ``dist.all_gather`` with a fixed-size uint8 tensor instead of
+    ``dist.all_gather_object`` to avoid injecting extra NCCL collective
+    calls (``all_gather_object`` internally issues two NCCL all_gathers for
+    size+data). At ws<8 the additional collectives can interleave with
+    data-plane ``all_gather_into_tensor`` calls on the same process group,
+    causing a rank-asymmetric collective-ordering deadlock.
+
+    AF_UNIX paths are at most 108 bytes; we use a 256-byte buffer for safety.
+    """
+    import torch
+    import torch.distributed as dist
+
+    _PATH_BUF_LEN = 256
+    path_bytes = my_path.encode("utf-8")
+    if len(path_bytes) >= _PATH_BUF_LEN:
+        raise ValueError(f"Socket path too long ({len(path_bytes)} bytes, max {_PATH_BUF_LEN - 1}): {my_path}")
+
+    # Encode into a fixed-size uint8 tensor (CPU for gloo, GPU for nccl).
+    # uint8 matches the [0,255] byte range; NCCL supports it natively.
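+    # e.g. "/tmp/iris-dmabuf-0.sock" encodes to [47, 116, 109, 112, ...] padded
+    # with zeros up to 256 bytes; the first zero byte delimits the path again
+    # on the decode side. (Illustrative path; the actual format comes from
+    # make_rank_sock_path.)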
+ buf = torch.zeros(_PATH_BUF_LEN, dtype=torch.uint8) + for i, b in enumerate(path_bytes): + buf[i] = b + + backend = str(dist.get_backend()).lower() + if backend == "nccl" and torch.cuda.is_available(): + device = torch.device("cuda", torch.cuda.current_device()) + buf = buf.to(device) + # else: keep on CPU (gloo) + + gathered = [torch.zeros_like(buf) for _ in range(num_ranks)] + dist.all_gather(gathered, buf) + + all_paths = {} + for r in range(num_ranks): + raw = gathered[r].cpu().tolist() + # Find null terminator (first 0) + try: + end = raw.index(0) + except ValueError: + end = _PATH_BUF_LEN + all_paths[r] = bytes(raw[:end]).decode("utf-8") + + return all_paths + + def setup_fd_infrastructure(cur_rank: int, num_ranks: int): """ Setup FD passing infrastructure for multi-rank communication. @@ -156,15 +205,17 @@ def setup_fd_infrastructure(cur_rank: int, num_ranks: int): if num_ranks <= 1: return None - import torch.distributed as dist from iris._distributed_helpers import distributed_barrier # Setup socket mesh for FD passing prefix = "iris-dmabuf" my_path = make_rank_sock_path(prefix, cur_rank) - obj_list = [None for _ in range(num_ranks)] - dist.all_gather_object(obj_list, my_path) - all_paths = {r: obj_list[r] for r in range(num_ranks)} + + # Use tensor-based all_gather instead of all_gather_object to avoid + # injecting extra NCCL collectives that can deadlock with data-plane + # all_gather_into_tensor at ws<8 (see _allgather_paths_tensor docstring). + all_paths = _allgather_paths_tensor(my_path, num_ranks) + distributed_barrier() fd_conns = setup_fd_mesh(cur_rank, num_ranks, all_paths) distributed_barrier() diff --git a/iris/iris.py b/iris/iris.py index 8c750ba67..52b91293c 100644 --- a/iris/iris.py +++ b/iris/iris.py @@ -1581,6 +1581,7 @@ def store(self, pointer, value, to_rank, mask=None, cache_modifier=None, hint: t value (Block): The tensor of elements to be stored. to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not store the data at address pointer[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. cache_modifier (str, optional): Controls cache behavior of the store. Supported values are: - None: *(default)* — Same as ".wb". Uses write-back caching at all levels (CU, L2, LLC) with LRU policy. @@ -1623,6 +1624,7 @@ def get( to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer to local memory in current rank where the data will be written. from_rank (int): The rank ID from which to read the data. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. @@ -1672,6 +1674,7 @@ def put( to_ptr (triton.PointerType, or block of dtype=triton.PointerType): Pointer in the current rank's address space that references memory in `to_rank`. 
to_rank (int): The rank ID to which the data will be written. mask (Block of triton.int1, optional): If mask[idx] is false, do not load from from_ptr[idx] and do not store to to_ptr[idx]. Defaults to None. + hint (int or tuple, optional): Vectorization hint passed to tl.multiple_of / tl.max_contiguous on the translated pointer. Defaults to None. other (Block, optional): Value to return for masked-out elements during the load operation. If not provided, the result for masked-out elements is undefined. Defaults to None. load_cache_modifier (str, optional): Controls cache behavior of the load. Supported values are: - None: *(default)* — Same as ".ca". Uses cache at all levels (CU, L2, LLC) with LRU policy. diff --git a/iris/ops/__init__.py b/iris/ops/__init__.py index e0d12ba51..647ff91de 100644 --- a/iris/ops/__init__.py +++ b/iris/ops/__init__.py @@ -36,6 +36,7 @@ # from .matmul import matmul # Simple single-GPU GEMM - TODO: implement from .matmul_all_reduce import matmul_all_reduce, matmul_all_reduce_preamble from .all_gather_matmul import all_gather_matmul, all_gather_matmul_preamble +from .all_gather_matmul_hbm_buffer import all_gather_matmul_hbm_buffer, all_gather_matmul_hbm_buffer_preamble from .matmul_all_gather import matmul_all_gather from .matmul_reduce_scatter import matmul_reduce_scatter, matmul_reduce_scatter_preamble @@ -180,6 +181,8 @@ def matmul_reduce_scatter(self, output_tensor, A, B, bias=None, async_op=False, "matmul_all_reduce_preamble", "all_gather_matmul", "all_gather_matmul_preamble", + "all_gather_matmul_hbm_buffer", + "all_gather_matmul_hbm_buffer_preamble", "matmul_all_gather", "matmul_reduce_scatter", "matmul_reduce_scatter_preamble", diff --git a/iris/ops/all_gather_matmul.py b/iris/ops/all_gather_matmul.py index 5d700206c..041ff6a0c 100644 --- a/iris/ops/all_gather_matmul.py +++ b/iris/ops/all_gather_matmul.py @@ -164,7 +164,7 @@ def all_gather_matmul_preamble( B: torch.Tensor, config: Optional[FusedConfig] = None, ) -> FusedWorkspace: - """Allocate workspace for all_gather_matmul (none needed for pull pattern).""" + """Allocate workspace for all_gather_matmul.""" if config is None: config = FusedConfig() @@ -175,7 +175,7 @@ def all_gather_matmul_preamble( expected_K = world_size * K_local assert K == expected_K, f"K ({K}) must equal world_size ({world_size}) * K_local ({K_local})" - return FusedWorkspace( + ws = FusedWorkspace( operation="all_gather_matmul", shape=(M, N, K), dtype=A_sharded.dtype, @@ -183,6 +183,8 @@ def all_gather_matmul_preamble( prepared=True, ) + return ws + def all_gather_matmul( shmem, @@ -208,17 +210,6 @@ def all_gather_matmul( assert output_tensor.shape == (M, N), f"Output must be ({M}, {N}), got {output_tensor.shape}" # Validate problem size against block sizes - assert M >= config.block_size_m, ( - f"M ({M}) must be >= block_size_m ({config.block_size_m}). Use smaller block sizes for small problems." - ) - assert K_local >= config.block_size_k, ( - f"K_local ({K_local}) must be >= block_size_k ({config.block_size_k}). " - f"Use smaller block sizes for small problems." - ) - assert N >= config.block_size_n, ( - f"N ({N}) must be >= block_size_n ({config.block_size_n}). Use smaller block sizes for small problems." 
- )
-
     if workspace is None:
         workspace = all_gather_matmul_preamble(shmem, A_sharded, B, config)

@@ -245,7 +236,8 @@ def all_gather_matmul(
     even_k = K_local % config.block_size_k == 0
     num_k_blocks_local = (K_local + config.block_size_k - 1) // config.block_size_k

-    # Launch single fused kernel
+    num_tiles_m = (M + config.block_size_m - 1) // config.block_size_m
+    num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n
     grid = (num_sms,)
     _fused_all_gather_matmul_kernel[grid](
         A_sharded,
diff --git a/iris/ops/all_gather_matmul_hbm_buffer.py b/iris/ops/all_gather_matmul_hbm_buffer.py
new file mode 100644
index 000000000..37fe99ea2
--- /dev/null
+++ b/iris/ops/all_gather_matmul_hbm_buffer.py
@@ -0,0 +1,719 @@
+# SPDX-License-Identifier: MIT
+# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved.
+
+"""
+Fused all-gather + GEMM using a local HBM staging buffer with dedicated
+fetcher and GEMM workgroups, launched data-parallel.
+
+Supports a configurable staged_a buffer layout (M-contiguous or K-contiguous)
+and B layout to match optimal tritonblas conventions (TN, TT, NT, NN).
+"""
+
+from typing import Optional
+import torch
+import triton
+import triton.language as tl
+import iris
+import iris.x
+
+from iris.tracing.events import TraceEvent
+from .config import FusedConfig
+from .workspace import FusedWorkspace
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Auto-config: shape-adaptive parameter selection for HBM buffer kernel
+# Source: K-021 sweep data (1076+ trials, 7 verified champion shapes)
+# ──────────────────────────────────────────────────────────────────────
+
+# Verified champion configs from IRIS-0018/0019 sweeps + optimize-loop iter3.
+# Key: (M, N, K) -> kernel params that beat PyTorch.
+# bm/bn/bk = block_size_m/n/k, gm = group_size_m, kpf = k_per_flag,
+# fs = num_fetch_sms, nfs = num_fetch_stages, fsf = first_stage_fetch_sms.
+_CHAMPION_CONFIGS = {
+    (262144, 8192, 8192): dict(bm=256, bn=256, bk=64, gm=24, kpf=64, fs=52, nfs=128, fsf=304),
+    (131072, 16384, 16384): dict(bm=256, bn=256, bk=64, gm=24, kpf=32, fs=4, nfs=64, fsf=52),
+    (147456, 28672, 4096): dict(bm=256, bn=256, bk=64, gm=24, kpf=16, fs=59, nfs=36, fsf=52),
+    (229376, 28672, 4096): dict(bm=256, bn=256, bk=64, gm=24, kpf=16, fs=4, nfs=56, fsf=52),
+    (327680, 28672, 4096): dict(bm=256, bn=256, bk=64, gm=24, kpf=16, fs=4, nfs=32, fsf=52),
+    (8192, 8192, 262144): dict(bm=128, bn=256, bk=64, gm=8, kpf=32, fs=4, nfs=8, fsf=52),
+    (16384, 16384, 131072): dict(bm=128, bn=256, bk=64, gm=16, kpf=16, fs=16, nfs=8, fsf=52),
+}
+
+
+def _auto_config(M: int, N: int, K: int, world_size: int = 8):
+    """
+    Select optimal HBM buffer kernel parameters for a given shape.
+
+    Returns (FusedConfig, k_per_flag, num_fetch_sms, num_fetch_stages,
+    first_stage_fetch_sms) — ready to pass to the kernel.
+
+    Priority order:
+    1. Exact match in champion configs (verified 1.12-1.44x vs PyTorch)
+    2. Shape-heuristic derivation from 1076-trial sweep principles
+
+    Heuristics (from K-021 sweep analysis):
+    - k_per_flag is the #1 knob (52% of perf range). Maximize it.
+ - bm=256 for M%256==0 and M>=8K; bm=128 otherwise + - bn=256 always (bn=128 is 15-35% worse) + - bk=64 always (bk=128 exceeds 64KB LDS on MI300X) + - num_stages=2 always (num_stages=3 crashes — 98KB LDS needed) + - num_warps=8 always (fewer warps = 22% worse) + - group_size_m: 1 for small M, 24 for large M (L2 locality) + """ + key = (M, N, K) + if key in _CHAMPION_CONFIGS: + c = _CHAMPION_CONFIGS[key] + # Validate kpf for this world_size + num_k_blocks = K // c["bk"] + kpf = c["kpf"] + while num_k_blocks % kpf != 0 and kpf > 1: + kpf //= 2 + config = FusedConfig( + block_size_m=c["bm"], + block_size_n=c["bn"], + block_size_k=c["bk"], + group_size_m=c["gm"], + ) + return config, kpf, c["fs"], c["nfs"], c["fsf"] + + # Derive from heuristics + num_k_blocks = K // 64 + + # Block sizes + bm = 256 if (M % 256 == 0 and M >= 8192) else 128 + num_m_tiles = M // bm + + # k_per_flag: maximize for throughput + if num_k_blocks >= 512: + kpf = 64 + elif num_k_blocks >= 128: + kpf = 16 + elif num_k_blocks >= 64: + kpf = 8 + else: + kpf = 4 + while num_k_blocks % kpf != 0 and kpf > 1: + kpf //= 2 + + # num_fetch_sms: scale with M-tiles (more tiles → more fetchers) + if num_m_tiles <= 8: + fs = 4 + elif num_m_tiles <= 32: + fs = 16 + elif num_m_tiles <= 128: + fs = 32 + else: + fs = 52 + + # num_fetch_stages + if num_m_tiles >= 512: + nfs = 4 + elif num_m_tiles >= 64: + nfs = 2 + else: + nfs = 1 + + # group_size_m + gm = 24 if num_m_tiles >= 64 else (8 if num_m_tiles >= 16 else 1) + + config = FusedConfig( + block_size_m=bm, + block_size_n=256, + block_size_k=64, + group_size_m=gm, + ) + return config, kpf, fs, nfs, 64 + + +@triton.jit +def _hbm_buffer_all_gather_matmul_kernel( + A_sharded, + B, + C, + bias_ptr, + staged_a, + flags_ptr, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_sa_m, # staged_a stride in M dim + stride_sa_k, # staged_a stride in K dim + stride_bias, + context_tensor: tl.tensor, + cur_rank: tl.constexpr, + world_size: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_FETCH_SMS: tl.constexpr, + NUM_M_TILES: tl.constexpr, + NUM_TILES_N: tl.constexpr, + NUM_K_BLOCKS: tl.constexpr, + NUM_K_BLOCKS_LOCAL: tl.constexpr, + K_PER_FLAG: tl.constexpr, + NUM_FLAG_GROUPS_K: tl.constexpr, + TOTAL_GATHER_TILES: tl.constexpr, + BIAS: tl.constexpr, + ALLOW_TF32: tl.constexpr, + NUM_FETCH_STAGES: tl.constexpr, + GEMM_TILES_PER_STAGE: tl.constexpr, + FIRST_STAGE_FETCH_SMS: tl.constexpr, + TRACE: tl.constexpr, +): + pid = tl.program_id(0) + acc_dtype = tl.int32 if C.type.element_ty == tl.int8 else tl.float32 + zero = tl.program_id(0) * 0 + + ctx = iris.DeviceContext.initialize(context_tensor, cur_rank, world_size, tracing=TRACE) + + # Interleaved layout with asymmetric first stage: + # [fetch0 (P)] [gemm0 (G)] [fetch1 (F)] [gemm1 (G)] ... 
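+    # i.e. pids [0, P) fetch for stage 0, the next G pids compute stage 0,
+    # and every later stage owns F fetcher pids followed by G GEMM pids.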
+ # P = FIRST_STAGE_FETCH_SMS, F = NUM_FETCH_SMS, G = GEMM_TILES_PER_STAGE + FIRST_STAGE_SIZE: tl.constexpr = FIRST_STAGE_FETCH_SMS + GEMM_TILES_PER_STAGE + REST_STAGE_SIZE: tl.constexpr = NUM_FETCH_SMS + GEMM_TILES_PER_STAGE + M_PER_STAGE: tl.constexpr = (NUM_M_TILES + NUM_FETCH_STAGES - 1) // NUM_FETCH_STAGES + + # Two-phase decode: stage 0 has a different size than subsequent stages + if pid < FIRST_STAGE_SIZE: + my_stage = zero + local_pid = pid + fetch_threshold = zero + FIRST_STAGE_FETCH_SMS + else: + adjusted = pid - FIRST_STAGE_SIZE + my_stage = 1 + adjusted // REST_STAGE_SIZE + local_pid = adjusted % REST_STAGE_SIZE + fetch_threshold = zero + NUM_FETCH_SMS + + if local_pid < fetch_threshold: + # ============================================================== + # FETCHER — stage 0 uses FIRST_STAGE_FETCH_SMS WGs, + # later stages use NUM_FETCH_SMS WGs + # ============================================================== + stage_pid = local_pid + + if TRACE: + _trace_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().fetch, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=my_stage, + ) + + src_view = iris.x.make_tensor_view(A_sharded, M, K_local, stride_am, stride_ak) + + tiles_per_m_group = NUM_FLAG_GROUPS_K * GROUP_SIZE_M + + for const_stage in range(NUM_FETCH_STAGES): + if my_stage == const_stage: + stage_fetch_sms = FIRST_STAGE_FETCH_SMS if const_stage == 0 else NUM_FETCH_SMS + stage_m_start = const_stage * M_PER_STAGE + stage_m_count = min(M_PER_STAGE, NUM_M_TILES - stage_m_start) + total_fg_stage = NUM_FLAG_GROUPS_K * stage_m_count + + for fg_idx in range(stage_pid, total_fg_stage, stage_fetch_sms): + m_group = fg_idx // tiles_per_m_group + within_group = fg_idx % tiles_per_m_group + k_flag_group = within_group // GROUP_SIZE_M + m_in_group = within_group % GROUP_SIZE_M + m_tile = stage_m_start + m_group * GROUP_SIZE_M + m_in_group + m_tile = min(m_tile, NUM_M_TILES - 1) + k_block_start = k_flag_group * K_PER_FLAG + + rm = m_tile * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + + for k_off in range(K_PER_FLAG): + k_block_global = k_block_start + k_off + + src_rank_idx = k_block_global // NUM_K_BLOCKS_LOCAL + k_block_local = k_block_global % NUM_K_BLOCKS_LOCAL + + pid_m_t = zero + m_tile + tile_k_t = zero + k_block_local + k_tile = iris.x.TileView(pid_m_t, tile_k_t, BLOCK_SIZE_M, BLOCK_SIZE_K) + + rk = k_block_global * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) + staged_ptrs = staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk[None, :] * stride_sa_k + + for compile_rank in range(world_size): + if src_rank_idx == compile_rank: + a_tile = iris.x.gather(k_tile, src_view, compile_rank, ctx, hint=(1, BLOCK_SIZE_K)) + tl.store(staged_ptrs, a_tile, cache_modifier=".cg") + + flag_idx = m_tile * NUM_FLAG_GROUPS_K + k_flag_group + tl.debug_barrier() # ensure all per-block stores are visible before setting the flag + tl.atomic_xchg(flags_ptr + flag_idx, 1, sem="release", scope="gpu") + + if TRACE: + ctx.tracing.record_event_end(_trace_handle) + + else: + # ============================================================== + # GEMM — gemm_local_id indexes into this stage's M-tile range + # ============================================================== + gemm_local_id = local_pid - fetch_threshold + stage_m_start = my_stage * M_PER_STAGE + + num_pid_in_group = GROUP_SIZE_M * NUM_TILES_N + group_id = 
gemm_local_id // num_pid_in_group + first_pid_m = stage_m_start + group_id * GROUP_SIZE_M + first_pid_m = min(first_pid_m, NUM_M_TILES - 1) + group_sz = min(NUM_M_TILES - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((gemm_local_id % num_pid_in_group) % group_sz) + pid_n = (gemm_local_id % num_pid_in_group) // group_sz + pid_m = min(pid_m, NUM_M_TILES - 1) + + rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + rm = tl.max_contiguous(tl.multiple_of(rm, BLOCK_SIZE_M), BLOCK_SIZE_M) + rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + rn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_SIZE_N), BLOCK_SIZE_N) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) + + if TRACE: + _trace_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().compute, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=my_stage, + ) + + for k_fg in range(NUM_FLAG_GROUPS_K): + if TRACE: + _wait_handle = ctx.tracing.record_event_start( + event_id=TraceEvent().wait, + target_rank=cur_rank, + address=flags_ptr + tl.arange(0, 1), + pid_m=pid, + pid_n=k_fg, + ) + + flag_idx = pid_m * NUM_FLAG_GROUPS_K + k_fg + while tl.atomic_add(flags_ptr + flag_idx, 0, sem="acquire", scope="gpu") == 0: + pass + + if TRACE: + ctx.tracing.record_event_end(_wait_handle) + + k_block_base = k_fg * K_PER_FLAG + for k_off in range(K_PER_FLAG): + k_block = k_block_base + k_off + rk = k_block * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + rk = tl.max_contiguous(tl.multiple_of(rk, BLOCK_SIZE_K), BLOCK_SIZE_K) + + a_ptrs = staged_a + rm.to(tl.int64)[:, None] * stride_sa_m + rk[None, :] * stride_sa_k + a = tl.load(a_ptrs) + + B_ptrs = B + rk[:, None] * stride_bk + rn[None, :] * stride_bn + b = tl.load(B_ptrs) + + if ALLOW_TF32: + acc = tl.dot(a, b, acc, allow_tf32=True) + else: + acc += tl.dot(a, b, allow_tf32=False) + + if BIAS: + bias_val = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0) + acc = acc + bias_val[:, None] + + c = acc.to(C.type.element_ty) + C_ptrs = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn + c_mask = (rm[:, None] < M) & (rn[None, :] < N) + tl.store(C_ptrs, c, mask=c_mask, cache_modifier=".wt") + + if TRACE: + ctx.tracing.record_event_end(_trace_handle) + + +# ========================================================================== +# Python API +# ========================================================================== + + +def all_gather_matmul_hbm_buffer_preamble( + ctx, + A_sharded: torch.Tensor, + B: torch.Tensor, + config: Optional[FusedConfig] = None, + k_per_flag: Optional[int] = None, + staged_a_layout: str = "k_contiguous", +) -> FusedWorkspace: + """ + Allocate workspace. + + Args: + staged_a_layout: "k_contiguous" (default, row-major (M,K)) or + "m_contiguous" (col-major, stored as (K,M) transposed). 
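+        k_per_flag: Number of consecutive K blocks covered by one readiness
+            flag. Must evenly divide K // config.block_size_k; when None it is
+            taken from the auto-config, falling back to 8.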
+ """ + M, K_local = A_sharded.shape + K, N = B.shape + world_size = ctx.get_num_ranks() + + if config is None: + auto_cfg, auto_kpf, _, _, _ = _auto_config(M, N, K, world_size) + config = auto_cfg + if k_per_flag is None: + k_per_flag = auto_kpf + if k_per_flag is None: + k_per_flag = 8 # Safety default; see K-021 best_configs.json for peak perf + + assert world_size * K_local == K + assert K_local % config.block_size_k == 0 + assert K % config.block_size_k == 0 + assert M % config.block_size_m == 0 + + num_m_tiles = M // config.block_size_m + num_k_blocks = K // config.block_size_k + assert num_k_blocks % k_per_flag == 0 + num_flag_groups_k = num_k_blocks // k_per_flag + + ws = FusedWorkspace( + operation="all_gather_matmul_hbm_buffer", + shape=(M, N, K), + dtype=A_sharded.dtype, + world_size=world_size, + variant=f"hbm_buffer_{staged_a_layout}", + prepared=True, + ) + + if staged_a_layout == "m_contiguous": + # Allocate (K, M) row-major, .T gives (M, K) with stride_m=1, stride_k=M + storage = ctx.zeros((K, M), dtype=A_sharded.dtype) + ws.aux_buffer = storage.T # (M, K) view, M-contiguous + else: + # Default: (M, K) row-major, stride_m=K, stride_k=1 + ws.aux_buffer = ctx.zeros((M, K), dtype=A_sharded.dtype) + + ws.locks = ctx.zeros((num_m_tiles * num_flag_groups_k,), dtype=torch.int32) + + buffer_mb = M * K * A_sharded.element_size() / (1024**2) + sa_stride_m, sa_stride_k = ws.aux_buffer.stride() + ctx.info( + f"HBM buffer: staged_a=({M},{K}) [{buffer_mb:.1f} MB] " + f"layout={staged_a_layout} strides=({sa_stride_m},{sa_stride_k}), " + f"flags={num_m_tiles}x{num_flag_groups_k}, k_per_flag={k_per_flag}" + ) + + ctx.barrier() + return ws + + +_EID_FETCH = 1024 # TraceEvent().fetch +_EID_COMPUTE = 2048 # TraceEvent().compute +_EID_WAIT = 3072 # TraceEvent().wait + + +def _extract_wg_trace(ctx, grid_size, **metadata): + """Reconstruct per-workgroup trace arrays from DeviceTracing events.""" + import numpy as np + + bufs = ctx.tracing.trace_buffers + n = min(ctx.tracing.trace_counter.item(), ctx.tracing.max_events) + + event_ids = bufs["event_id"][:n].cpu().numpy() + pids = bufs["pid"][:n].cpu().numpy() + timestamps = bufs["timestamp"][:n].cpu().numpy().astype(np.int64) + # Note: despite the field name, "duration_cycles" stores the absolute end timestamp + # (set by record_event_end). The actual duration is end_ts - start_ts. 
+ end_timestamps = bufs["duration_cycles"][:n].cpu().numpy().astype(np.int64) + xcc_ids = bufs["xcc_id"][:n].cpu().numpy().astype(np.int32) + + starts = torch.zeros(grid_size, dtype=torch.int64) + ends = torch.zeros(grid_size, dtype=torch.int64) + waits = torch.zeros(grid_size, dtype=torch.int64) + xcds = torch.zeros(grid_size, dtype=torch.int32) + + for i in range(n): + eid = int(event_ids[i]) + wg = int(pids[i]) + if wg >= grid_size: + continue + if eid == _EID_FETCH or eid == _EID_COMPUTE: + starts[wg] = int(timestamps[i]) + ends[wg] = int(end_timestamps[i]) + xcds[wg] = int(xcc_ids[i]) + elif eid == _EID_WAIT: + waits[wg] += int(end_timestamps[i]) - int(timestamps[i]) + + return {"start": starts, "end": ends, "wait": waits, "xcd": xcds, "grid_size": grid_size, **metadata} + + +def all_gather_matmul_hbm_buffer( + ctx, + output_tensor: torch.Tensor, + A_sharded: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + async_op: bool = False, + config: Optional[FusedConfig] = None, + workspace: Optional[FusedWorkspace] = None, + num_fetch_sms: Optional[int] = None, + k_per_flag: Optional[int] = None, + fetch_block_m: Optional[int] = None, + fetch_block_k: Optional[int] = None, + staged_a_layout: str = "k_contiguous", + num_warps: Optional[int] = 8, + num_stages: Optional[int] = 2, + num_fetch_stages: Optional[int] = None, + first_stage_fetch_sms: Optional[int] = None, + trace: bool = False, +) -> FusedWorkspace: + """ + All-gather + matmul with dedicated fetcher/GEMM workgroups. + + When ``config`` is None, uses ``_auto_config()`` to select shape-optimal + parameters from verified sweep data (K-021). This gives up to 1.44× + speedup over PyTorch on champion shapes without any manual tuning. + + Args: + staged_a_layout: Buffer layout for gathered A. + "k_contiguous" — (M,K) row-major, K is fast dim. Matches NN convention. + "m_contiguous" — (M,K) with M as fast dim. Matches TN convention (best for tritonblas). + """ + M, K_local = A_sharded.shape + K, N = B.shape + world_size = ctx.get_num_ranks() + + if config is None: + # Shape-adaptive auto-config from K-021 sweep data + auto_cfg, auto_kpf, auto_fs, auto_nfs, auto_fsf = _auto_config(M, N, K, world_size) + config = auto_cfg + if k_per_flag is None: + k_per_flag = auto_kpf + if num_fetch_sms is None: + num_fetch_sms = auto_fs + if num_fetch_stages is None: + num_fetch_stages = auto_nfs + if first_stage_fetch_sms is None: + first_stage_fetch_sms = auto_fsf + + # Apply defaults for any remaining None values (when config is explicit + # but some params are left at None). + # kpf=8 is the safety default: +4.3% vs kpf=16 on g6 (IRIS-0018, 934 trials) + # and avoids kpf=16 validation failures on 2/8 ranks at M=262144. + # For peak performance on known shapes, use best_configs.json from K-021. 
+    if k_per_flag is None:
+        k_per_flag = 8
+    if num_fetch_sms is None:
+        num_fetch_sms = 32
+    if num_fetch_stages is None:
+        num_fetch_stages = 1
+    if first_stage_fetch_sms is None:
+        first_stage_fetch_sms = 256
+
+    rank = ctx.get_rank()
+
+    assert world_size * K_local == K
+    assert output_tensor.shape == (M, N)
+    assert M % config.block_size_m == 0
+    assert K % config.block_size_k == 0
+    assert K_local % config.block_size_k == 0
+
+    if fetch_block_m is None:
+        fetch_block_m = config.block_size_m
+    if fetch_block_k is None:
+        fetch_block_k = config.block_size_k
+
+    num_k_blocks = K // config.block_size_k
+    assert num_k_blocks % k_per_flag == 0
+
+    if workspace is None:
+        workspace = all_gather_matmul_hbm_buffer_preamble(ctx, A_sharded, B, config, k_per_flag, staged_a_layout)
+
+    workspace.locks.zero_()
+
+    stride_am, stride_ak = A_sharded.stride()
+    stride_bk, stride_bn = B.stride()
+    stride_cm, stride_cn = output_tensor.stride()
+    stride_sa_m, stride_sa_k = workspace.aux_buffer.stride()
+
+    if bias is not None:
+        assert bias.shape[0] == M
+        bias_ptr = bias
+        stride_bias = bias.stride()[0] if bias.dim() > 0 else 1
+        use_bias = True
+    else:
+        bias_ptr = output_tensor
+        stride_bias = 1
+        use_bias = False
+
+    device = A_sharded.device
+    num_sms = config.num_sms
+    if num_sms is None:
+        props = torch.cuda.get_device_properties(device)
+        num_sms = props.multi_processor_count
+
+    num_m_tiles = M // config.block_size_m
+    num_tiles_n = (N + config.block_size_n - 1) // config.block_size_n
+    total_gemm_tiles = num_m_tiles * num_tiles_n
+    num_k_blocks_local = K_local // config.block_size_k
+    num_flag_groups_k = num_k_blocks // k_per_flag
+    total_gather_tiles = num_m_tiles * num_k_blocks
+
+    # All knobs were defaulted above, so they are guaranteed non-None here.
+    assert 0 < num_fetch_sms
+    assert num_fetch_stages >= 1
+
+    # The first stage can use more fetcher WGs (first_stage_fetch_sms) to
+    # fill the first GPU wave.
+    # Interleaved layout: [fetch0 (P)] [gemm0 (G)] [fetch1 (F)] [gemm1 (G)] ...
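+    # Worked example (champion M=131072, N=16384, K=16384, bm=bn=256, F=4,
+    # P=52, 64 fetch stages): 512 M-tiles / 64 stages = 8 M-tiles per stage,
+    # G = 8 * 64 N-tiles = 512 GEMM WGs per stage, so
+    # grid_size = (52 + 512) + 63 * (4 + 512) = 33072 workgroups.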
+ m_per_stage = (num_m_tiles + num_fetch_stages - 1) // num_fetch_stages + gemm_tiles_per_stage = m_per_stage * num_tiles_n + first_stage_size = first_stage_fetch_sms + gemm_tiles_per_stage + rest_stage_size = num_fetch_sms + gemm_tiles_per_stage + total_fetch_wgs = first_stage_fetch_sms + num_fetch_sms * max(0, num_fetch_stages - 1) + grid_size = first_stage_size + rest_stage_size * max(0, num_fetch_stages - 1) + + if trace: + max_trace_events = grid_size * 4 + if not ctx.tracing.enabled: + ctx.tracing.enable(max_events=max_trace_events) + else: + ctx.tracing.reset() + + launch_kwargs = {"matrix_instr_nonkdim": 16} + if num_warps is not None: + launch_kwargs["num_warps"] = num_warps + if num_stages is not None: + launch_kwargs["num_stages"] = num_stages + + _hbm_buffer_all_gather_matmul_kernel[(grid_size,)]( + A_sharded, + B, + output_tensor, + bias_ptr, + workspace.aux_buffer, + workspace.locks, + M, + N, + K, + K_local, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_sa_m, + stride_sa_k, + stride_bias, + ctx.get_device_context(), + rank, + world_size, + config.block_size_m, + config.block_size_n, + config.block_size_k, + config.group_size_m, + num_fetch_sms, + num_m_tiles, + num_tiles_n, + num_k_blocks, + num_k_blocks_local, + k_per_flag, + num_flag_groups_k, + total_gather_tiles, + use_bias, + config.allow_tf32, + num_fetch_stages, + gemm_tiles_per_stage, + first_stage_fetch_sms, + trace, + **launch_kwargs, + ) + + if not async_op: + ctx.barrier() + + if trace: + torch.cuda.synchronize() + workspace.trace_data = _extract_wg_trace( + ctx, + grid_size, + num_fetch_sms=num_fetch_sms, + num_fetch_stages=num_fetch_stages, + total_fetch_wgs=total_fetch_wgs, + num_m_tiles=num_m_tiles, + num_tiles_n=num_tiles_n, + first_stage_fetch_sms=first_stage_fetch_sms, + first_stage_size=first_stage_size, + rest_stage_size=rest_stage_size, + gemm_tiles_per_stage=gemm_tiles_per_stage, + ) + + return workspace diff --git a/iris/ops/config.py b/iris/ops/config.py index 3ca085c31..530df7816 100644 --- a/iris/ops/config.py +++ b/iris/ops/config.py @@ -32,7 +32,7 @@ class FusedConfig: CCL Parameters (for operations that need collective communication): all_reduce_variant: All-reduce algorithm variant. Options: "atomic", "ring", - "one_shot", "two_shot", "spinlock". Default: "one_shot". + "one_shot", "two_shot", "spinlock". Default: "two_shot". all_reduce_num_rings: Number of concurrent rings (for ring variant). Default: 1. Example: diff --git a/iris/tracing/core.py b/iris/tracing/core.py index 317fc0bbf..57c007625 100644 --- a/iris/tracing/core.py +++ b/iris/tracing/core.py @@ -208,7 +208,7 @@ def export(self, filename="trace.json", merge=False): "traceEvents": trace_events, "displayTimeUnit": "ns", "metadata": { - "schema_version": "1.1", + "schema_version": "1.2", "num_events": num_events, "rank": self.iris.cur_rank, "world_size": self.iris.num_ranks, @@ -285,7 +285,7 @@ def export(self, filename="trace.json", merge=False): "traceEvents": all_events, "displayTimeUnit": "ns", "metadata": { - "schema_version": "1.1", + "schema_version": "1.2", "total_events": len(all_events), "max_events": self.max_events, "time_unit": "cycles (s_memrealtime @ 100MHz)", diff --git a/iris/tracing/events.py b/iris/tracing/events.py index 4838c09d6..b8adb8bf1 100644 --- a/iris/tracing/events.py +++ b/iris/tracing/events.py @@ -2,6 +2,12 @@ Trace event type IDs and Triton-side enumeration. EVENT_NAMES and TraceEvent must stay in sync: same IDs for the same operations. 
+ +Event ID ranges: + 0–1023 iris ops (data movement, atomics) + 1024–2047 user data movement (fetch/prefetch) + 2048–3071 user compute (compute, reduce) + 3072–4095 synchronization (wait, barrier) """ import triton @@ -12,6 +18,7 @@ # Event type IDs to names mapping (used for export / display). # Keep in sync with TraceEvent below. EVENT_NAMES = { + # iris ops (0–1023) 0: "load", 1: "store", 2: "get", @@ -26,45 +33,58 @@ 11: "atomic_or", 12: "atomic_min", 13: "atomic_max", + # User data movement (1024–2047) + 1024: "fetch", + # User compute (2048–3071) + 2048: "compute", + 2049: "reduce", + # Synchronization (3072–4095) + 3072: "wait", + 3073: "barrier", } @aggregate class TraceEvent: """ - Trace event type enumeration for iris remote memory operations. + Trace event type enumeration for iris operations and kernel instrumentation. + + Event ID ranges: + 0–1023 iris ops (data movement, atomics) + 1024–2047 user data movement (fetch/prefetch) + 2048–3071 user compute (compute, reduce) + 3072–4095 synchronization (wait, barrier) Usage: >>> ctx.record_event(event_id=TraceEvent().put, target_rank=1, address=ptr) Available event types: - Data Movement: + iris ops (0–1023): - load (0): Remote load operation - store (1): Remote store operation - get (2): Remote read (pull from remote to local) - put (3): Remote write (push from local to remote) - copy (4): Peer-to-peer copy between ranks + - atomic_add (5) .. atomic_max (13): Atomic operations + + User data movement (1024–2047): + - fetch (1024): Prefetching / staging data - Atomic Operations: - - atomic_add (5): Atomic addition - - atomic_sub (6): Atomic subtraction - - atomic_cas (7): Atomic compare-and-swap - - atomic_xchg (8): Atomic exchange - - atomic_xor (9): Atomic XOR - - atomic_and (10): Atomic AND - - atomic_or (11): Atomic OR - - atomic_min (12): Atomic minimum - - atomic_max (13): Atomic maximum + User compute (2048–3071): + - compute (2048): Kernel compute phase (GEMM, FFT, etc.) 
+ - reduce (2049): Reduction operation + + Synchronization (3072–4095): + - wait (3072): Stalled on a dependency + - barrier (3073): Synchronization point """ - # Data movement operations + # iris ops (0–1023) load: tl.constexpr store: tl.constexpr get: tl.constexpr put: tl.constexpr copy: tl.constexpr - - # Atomic operations atomic_add: tl.constexpr atomic_sub: tl.constexpr atomic_cas: tl.constexpr @@ -75,16 +95,25 @@ class TraceEvent: atomic_min: tl.constexpr atomic_max: tl.constexpr + # User data movement (1024–2047) + fetch: tl.constexpr + + # User compute (2048–3071) + compute: tl.constexpr + reduce: tl.constexpr + + # Synchronization (3072–4095) + wait: tl.constexpr + barrier: tl.constexpr + @triton.constexpr_function def __init__(self): - # Data movement + # iris ops (0–1023) self.load = tl.constexpr(0) self.store = tl.constexpr(1) self.get = tl.constexpr(2) self.put = tl.constexpr(3) self.copy = tl.constexpr(4) - - # Atomics self.atomic_add = tl.constexpr(5) self.atomic_sub = tl.constexpr(6) self.atomic_cas = tl.constexpr(7) @@ -94,3 +123,14 @@ def __init__(self): self.atomic_or = tl.constexpr(11) self.atomic_min = tl.constexpr(12) self.atomic_max = tl.constexpr(13) + + # User data movement (1024–2047) + self.fetch = tl.constexpr(1024) + + # User compute (2048–3071) + self.compute = tl.constexpr(2048) + self.reduce = tl.constexpr(2049) + + # Synchronization (3072–4095) + self.wait = tl.constexpr(3072) + self.barrier = tl.constexpr(3073) diff --git a/iris/x/gather.py b/iris/x/gather.py index ca8bd4f9c..4e2b10cc9 100644 --- a/iris/x/gather.py +++ b/iris/x/gather.py @@ -24,6 +24,7 @@ def gather( src_view: TensorView, source_rank: tl.constexpr, ctx: DeviceContext, + hint: tl.constexpr = None, ): """ Tile-level gather from a specific rank. @@ -37,6 +38,9 @@ def gather( src_view: TensorView for source tensor on source_rank. source_rank: Specific rank to load from (constexpr). ctx: DeviceContext with rank, world_size, and heap_bases. + hint: Vectorization hint passed to tl.multiple_of / tl.max_contiguous on + the translated pointer. Use a scalar (e.g. 16) or a tuple + (e.g. (1, 16)) to indicate alignment. Defaults to None (no hint). Returns: Loaded tile data as a tensor. @@ -61,6 +65,7 @@ def gather( source_rank, # from_rank (source rank) ctx.heap_bases, mask=mask, + hint=hint, ) return tile_data diff --git a/tests/ops/test_all_gather_matmul.py b/tests/ops/test_all_gather_matmul.py index 193505011..afe503728 100644 --- a/tests/ops/test_all_gather_matmul.py +++ b/tests/ops/test_all_gather_matmul.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: MIT -# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026 Advanced Micro Devices, Inc. All rights reserved. """ -Tests for fused all_gather + matmul operation. +Tests for fused all_gather + matmul operations. Each rank has A_sharded (M x K_local), B is replicated. The operation gathers A from all ranks and computes C = A_gathered @ B. +Covers both the baseline pull kernel and the HBM-buffered kernel. 
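+
+These tests skip unless torch.distributed is already initialized, so run them
+under torchrun with one process per GPU.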
""" import pytest @@ -13,6 +14,30 @@ import torch.distributed as dist import iris +from iris.ops.all_gather_matmul_hbm_buffer import ( + all_gather_matmul_hbm_buffer, + all_gather_matmul_hbm_buffer_preamble, +) +from iris.ops.config import FusedConfig + + +def _make_reference(rank, world_size, M, K_local, N, dtype): + """Build a torch reference output for all_gather + matmul.""" + device = f"cuda:{rank}" + K = K_local * world_size + + torch.manual_seed(42 + rank) + A_sharded = torch.randn(M, K_local, dtype=dtype, device=device) + + torch.manual_seed(123) + B = torch.randn(K, N, dtype=dtype, device=device) + + A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=device) for _ in range(world_size)] + dist.all_gather(A_gathered_list, A_sharded) + A_gathered_ref = torch.cat(A_gathered_list, dim=1) + ref_output = torch.matmul(A_gathered_ref, B) + torch.cuda.synchronize() + return A_sharded, B, ref_output @pytest.mark.parametrize( @@ -28,81 +53,200 @@ (256, 64, 128), ], ) -def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N): - """Test all_gather_matmul against torch all_gather + matmul.""" +def test_all_gather_matmul_baseline(dtype, atol, rtol, M, K_local, N): + """Test baseline all_gather_matmul against torch all_gather + matmul.""" if not dist.is_initialized(): pytest.skip("torch.distributed not initialized") heap_size = 2**33 - shmem = iris.iris(heap_size) - rank = shmem.get_rank() - world_size = shmem.get_num_ranks() - - K = K_local * world_size # Full K dimension - - # Skip if problem size is too small for world_size or block sizes - # With default or custom configs, we need at least one tile - min_block_size = 32 # Smallest block size we use - if M < min_block_size: - pytest.skip(f"M={M} too small (need >= {min_block_size})") - if K_local < min_block_size: - pytest.skip(f"K_local={K_local} too small (need >= {min_block_size})") - if N < min_block_size: - pytest.skip(f"N={N} too small (need >= {min_block_size})") - - # Seed for reproducibility - different seed per rank for A_sharded - torch.manual_seed(42 + rank) - A_sharded = torch.randn(M, K_local, dtype=dtype, device=f"cuda:{rank}") + ctx = iris.iris(heap_size) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() - # B must be identical on all ranks - torch.manual_seed(123) - B = torch.randn(K, N, dtype=dtype, device=f"cuda:{rank}") + K = K_local * world_size + + min_block_size = 32 + if M < min_block_size or K_local < min_block_size or N < min_block_size: + pytest.skip(f"Problem too small for min block size {min_block_size}") + + A_sharded, B, ref_output = _make_reference(rank, world_size, M, K_local, N, dtype) + device = f"cuda:{rank}" + + A_sharded_shmem = ctx.zeros((M, K_local), dtype=dtype) + A_sharded_shmem.copy_(A_sharded) + B_shmem = ctx.zeros((K, N), dtype=dtype) + B_shmem.copy_(B) + output = ctx.zeros((M, N), dtype=dtype) + + ctx.barrier() + + config = ( + FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) + if M <= 256 or K_local <= 64 or N <= 128 + else FusedConfig() + ) + + assert M >= config.block_size_m + assert K_local >= config.block_size_k + assert N >= config.block_size_n + + ctx.ops.all_gather_matmul(output, A_sharded_shmem, B_shmem, config=config) - # Reference: torch all_gather + matmul - A_gathered_list = [torch.zeros(M, K_local, dtype=dtype, device=f"cuda:{rank}") for _ in range(world_size)] - dist.all_gather(A_gathered_list, A_sharded) - A_gathered_ref = torch.cat(A_gathered_list, dim=1) # (M, K) - ref_output = torch.matmul(A_gathered_ref, B) torch.cuda.synchronize() + 
ctx.barrier() + + max_diff = (output - ref_output).abs().max().item() + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( + f"Rank {rank}: Max diff {max_diff}, expected < {atol}" + ) - # Create shmem tensors directly - A_sharded_shmem = shmem.zeros((M, K_local), dtype=dtype) + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 1e-2, 1e-2), + ], +) +@pytest.mark.parametrize( + "M,K_local,N", + [ + (128, 32, 64), + (256, 64, 128), + ], +) +@pytest.mark.parametrize( + "staged_a_layout", + [ + "k_contiguous", + "m_contiguous", + ], +) +def test_all_gather_matmul_hbm_buffer(dtype, atol, rtol, M, K_local, N, staged_a_layout): + """Test all_gather_matmul_hbm_buffer against torch all_gather + matmul.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 + ctx = iris.iris(heap_size) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + K = K_local * world_size + + A_sharded, B, ref_output = _make_reference(rank, world_size, M, K_local, N, dtype) + + A_sharded_shmem = ctx.zeros((M, K_local), dtype=dtype) A_sharded_shmem.copy_(A_sharded) - B_shmem = shmem.zeros((K, N), dtype=dtype) + B_shmem = ctx.zeros((K, N), dtype=dtype) B_shmem.copy_(B) - output = shmem.zeros((M, N), dtype=dtype) + output = ctx.zeros((M, N), dtype=dtype) - shmem.barrier() + ctx.barrier() - # Run fused all_gather + matmul using shmem.ops API - from iris.ops.config import FusedConfig + config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) - # Use appropriate block sizes based on problem size - # For small problems, use smaller blocks - if M <= 256 or K_local <= 64 or N <= 128: - config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) - else: - config = FusedConfig() + # k_per_flag must divide num_k_blocks = K // block_size_k; use 1 for small shapes + num_k_blocks = K // config.block_size_k + k_per_flag = 1 + while k_per_flag * 2 <= 8 and num_k_blocks % (k_per_flag * 2) == 0: + k_per_flag *= 2 - # Validate config against problem size - assert M >= config.block_size_m, f"M ({M}) must be >= block_size_m ({config.block_size_m})" - assert K_local >= config.block_size_k, f"K_local ({K_local}) must be >= block_size_k ({config.block_size_k})" - assert N >= config.block_size_n, f"N ({N}) must be >= block_size_n ({config.block_size_n})" + workspace = all_gather_matmul_hbm_buffer_preamble( + ctx, A_sharded_shmem, B_shmem, config=config, staged_a_layout=staged_a_layout, k_per_flag=k_per_flag + ) - shmem.ops.all_gather_matmul(output, A_sharded_shmem, B_shmem, config=config) + all_gather_matmul_hbm_buffer( + ctx, + output, + A_sharded_shmem, + B_shmem, + config=config, + workspace=workspace, + k_per_flag=k_per_flag, + staged_a_layout=staged_a_layout, + trace=False, + ) torch.cuda.synchronize() - shmem.barrier() + ctx.barrier() max_diff = (output - ref_output).abs().max().item() + assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( + f"Rank {rank}: Max diff {max_diff}, expected < {atol} " + f"(staged_a_layout={staged_a_layout}, M={M}, K_local={K_local}, N={N})" + ) + + +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + (torch.float16, 1e-2, 1e-2), + ], +) +@pytest.mark.parametrize( + "M,K_local,N", + [ + (128, 32, 64), + ], +) +def test_all_gather_matmul_hbm_buffer_with_bias(dtype, atol, rtol, M, K_local, N): + """Test all_gather_matmul_hbm_buffer with a bias vector.""" + if not dist.is_initialized(): + pytest.skip("torch.distributed not initialized") + + heap_size = 2**33 + ctx = 
iris.iris(heap_size) + rank = ctx.get_rank() + world_size = ctx.get_num_ranks() + + K = K_local * world_size + + A_sharded, B, ref_output_no_bias = _make_reference(rank, world_size, M, K_local, N, dtype) + device = f"cuda:{rank}" + torch.manual_seed(77) + bias = torch.randn(M, dtype=dtype, device=device) + ref_output = ref_output_no_bias + bias[:, None] + + A_sharded_shmem = ctx.zeros((M, K_local), dtype=dtype) + A_sharded_shmem.copy_(A_sharded) + B_shmem = ctx.zeros((K, N), dtype=dtype) + B_shmem.copy_(B) + bias_shmem = ctx.zeros((M,), dtype=dtype) + bias_shmem.copy_(bias) + output = ctx.zeros((M, N), dtype=dtype) + + ctx.barrier() + + config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32) + + # k_per_flag must divide num_k_blocks = K // block_size_k; use 1 for small shapes + num_k_blocks = K // config.block_size_k + k_per_flag = 1 + while k_per_flag * 2 <= 8 and num_k_blocks % (k_per_flag * 2) == 0: + k_per_flag *= 2 + + all_gather_matmul_hbm_buffer( + ctx, + output, + A_sharded_shmem, + B_shmem, + bias=bias_shmem, + config=config, + k_per_flag=k_per_flag, + trace=False, + ) + + torch.cuda.synchronize() + ctx.barrier() + + max_diff = (output - ref_output).abs().max().item() assert torch.allclose(output, ref_output, atol=atol, rtol=rtol), ( - f"Rank {rank}: Max diff {max_diff}, expected < {atol}" + f"Rank {rank}: Max diff {max_diff}, expected < {atol} (with bias)" ) if __name__ == "__main__": - # For quick debugging import sys if not dist.is_initialized(): @@ -111,7 +255,4 @@ def test_all_gather_matmul(dtype, atol, rtol, M, K_local, N): rank = dist.get_rank() torch.cuda.set_device(rank) - - print(f"[Rank {rank}] Testing all_gather_matmul...") - test_all_gather_matmul(torch.float16, 128, 32, 64) - print(f"[Rank {rank}] ✓ Test passed!") + print(f"[Rank {rank}] Tests in this file require pytest + torchrun. See tests/run_tests_distributed.py")
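For reference, a minimal end-to-end invocation of the new HBM-buffer op, mirroring the small-shape configuration the tests above use. It assumes a torchrun launch with an already-initialized process group; shapes, seeds, heap size, and block sizes are taken from the tests, everything else is a sketch:

```python
# Minimal usage sketch (assumes torchrun launch + initialized process group).
import torch
import torch.distributed as dist

import iris
from iris.ops.all_gather_matmul_hbm_buffer import (
    all_gather_matmul_hbm_buffer,
    all_gather_matmul_hbm_buffer_preamble,
)
from iris.ops.config import FusedConfig

assert dist.is_initialized()
ctx = iris.iris(2**33)
rank, world_size = ctx.get_rank(), ctx.get_num_ranks()

M, K_local, N = 256, 64, 128
K = K_local * world_size
dtype = torch.float16

torch.manual_seed(42 + rank)  # each rank owns a different shard of A
A_sharded = ctx.zeros((M, K_local), dtype=dtype)
A_sharded.copy_(torch.randn(M, K_local, dtype=dtype, device=f"cuda:{rank}"))

torch.manual_seed(123)  # B must be identical on every rank
B = ctx.zeros((K, N), dtype=dtype)
B.copy_(torch.randn(K, N, dtype=dtype, device=f"cuda:{rank}"))

C = ctx.zeros((M, N), dtype=dtype)
ctx.barrier()

config = FusedConfig(block_size_m=64, block_size_n=64, block_size_k=32)
workspace = all_gather_matmul_hbm_buffer_preamble(ctx, A_sharded, B, config=config, k_per_flag=1)
all_gather_matmul_hbm_buffer(ctx, C, A_sharded, B, config=config, workspace=workspace, k_per_flag=1)
ctx.barrier()
```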