diff --git a/kernels/custom_all_reduce.py b/kernels/custom_all_reduce.py new file mode 100644 index 000000000..3658aed6f --- /dev/null +++ b/kernels/custom_all_reduce.py @@ -0,0 +1,878 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""Custom all-reduce kernel + Python-facing shim. + +Provides FlyDSL-generated allreduce kernels with cross-GPU signal +protocol for multi-GPU communication on ROCm. +""" + +from contextlib import contextmanager +import torch + +_KMAXBLOCKS = 80 +_DEFAULT_MAX_SIZE = 8192 * 1024 * 8 * 2 # 128 MB + + +def meta_size() -> int: + """Return meta buffer size (for API compatibility).""" + return 0 + + +def _is_weak_contiguous(t) -> bool: + """Check if tensor occupies a single dense range in storage.""" + try: + if t.is_contiguous(): + return True + storage = t.untyped_storage() + return int(storage.nbytes()) - int(t.storage_offset()) * int(t.element_size()) == int(t.numel()) * int(t.element_size()) + except Exception: + return False + + +_FLYDSL_AITER_GLOO_GROUP = None + + +def init_custom_ar(meta, rank_data, handles, offsets, rank: int, full_nvlink: bool=True, out=None, max_size: int = _DEFAULT_MAX_SIZE, backend=None): + """Initialize allreduce backend. + + Backend controlled by env var FLYDSL_AITER_IMPL: + - "flydsl" (default): use FlyDSL kernel + - "aiter": use aiter kernel (requires aiter package) + """ + import os + import torch.distributed as dist + + _ = meta + world_size = len(offsets) + if world_size > 8: + raise ValueError("world size > 8 is not supported") + if world_size > 1 and (world_size % 2 != 0): + raise ValueError("Odd num gpus is not supported for now") + if world_size != len(handles): + raise ValueError("handles length should equal to offsets length") + if rank < 0 or rank >= world_size: + raise ValueError("invalid rank passed in") + + impl = str(os.environ.get("FLYDSL_AITER_IMPL", "flydsl")).strip().lower() + if impl not in {"aiter", "flydsl"}: + raise ValueError(f"unsupported FLYDSL_AITER_IMPL={impl!r}") + + if not dist.is_initialized(): + raise RuntimeError("torch.distributed must be initialized") + + global _FLYDSL_AITER_GLOO_GROUP + if _FLYDSL_AITER_GLOO_GROUP is None: + try: + _FLYDSL_AITER_GLOO_GROUP = dist.new_group(backend="gloo") + except Exception: + _FLYDSL_AITER_GLOO_GROUP = dist.group.WORLD + + dev = getattr(rank_data, "device", None) or torch.device(f"cuda:{rank}") + + if backend is not None: + return backend( + group=_FLYDSL_AITER_GLOO_GROUP, + device=dev, + max_size=max_size, + world_size=world_size, + rank=rank, + full_nvlink=bool(full_nvlink), + ) + + if impl == "flydsl": + return FlyDSLAllreduce( + group=_FLYDSL_AITER_GLOO_GROUP, + device=dev, + max_size=max_size, + world_size=world_size, + rank=rank, + full_nvlink=bool(full_nvlink), + ) + + try: + from aiter.dist.device_communicators.custom_all_reduce import CustomAllreduce as AIterCustomAllreduce + except ModuleNotFoundError: + try: + from aiter.dist.custom_all_reduce import CustomAllreduce as AIterCustomAllreduce + except ModuleNotFoundError as e: + raise ModuleNotFoundError("Cannot import AIter CustomAllreduce") from e + + aiter_ar = AIterCustomAllreduce(_FLYDSL_AITER_GLOO_GROUP, dev, max_size=max_size) + try: + if hasattr(rank_data, "is_cuda") and bool(rank_data.is_cuda): + aiter_ar.register_input_buffer(rank_data) + if out is not None and hasattr(out, "is_cuda") and bool(out.is_cuda): + aiter_ar.register_output_buffer(out) + except Exception: + pass + return aiter_ar + + +class FlyDSLAllreduce: + """FlyDSL allreduce kernels 
with cross-GPU signal protocol on ROCm.""" + + _HIP_IPC_HANDLE_BYTES = 64 + _HIP_IPC_MEM_LAZY_ENABLE_PEER_ACCESS = 0x1 + _HIP_DEVICE_MALLOC_UNCACHED = 0x3 + _hip = None + _hipIpcMemHandle_t = None + _gpu_arch = None + + # Signal struct layout (each field alignas(128)): + # uint32_t start[_MAX_BLOCKS][8] -> _MAX_BLOCKS * 8 * 4 + # uint32_t end[_MAX_BLOCKS][8] -> _MAX_BLOCKS * 8 * 4 + # uint32_t _flag[_MAX_BLOCKS] -> _MAX_BLOCKS * 4 + # Struct size padded to 128-byte alignment. + _SIGNAL_SIZE = ((_KMAXBLOCKS * 8 * 4) * 2 + _KMAXBLOCKS * 4 + 127) & ~127 + + @classmethod + def _get_gpu_arch(cls) -> str: + """Return current GPU architecture name (cached). + + Uses arch name (e.g. 'gfx942') to decide write-mode eligibility. + """ + if cls._gpu_arch is not None: + return cls._gpu_arch + arch = "" + try: + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + arch = getattr(props, "gcnArchName", "") or "" + except Exception: + pass + if not arch: + try: + import subprocess + r = subprocess.run(["rocminfo"], capture_output=True, text=True, timeout=10) + for line in r.stdout.splitlines(): + if "Name:" in line and "gfx" in line.lower(): + arch = line.split(":")[-1].strip() + break + except Exception: + pass + cls._gpu_arch = arch + return arch + + + @classmethod + def _load_hip(cls): + if cls._hip is not None: + return cls._hip + import ctypes + for name in ("libamdhip64.so", "libamdhip64.so.6", "libamdhip64.so.5"): + try: + cls._hip = ctypes.CDLL(name) + break + except OSError: + continue + if cls._hip is None: + raise RuntimeError("Failed to load HIP runtime library") + + class hipIpcMemHandle_t(ctypes.Structure): + _fields_ = [("reserved", ctypes.c_byte * cls._HIP_IPC_HANDLE_BYTES)] + cls._hipIpcMemHandle_t = hipIpcMemHandle_t + + cls._hip.hipIpcGetMemHandle.restype = ctypes.c_int + cls._hip.hipIpcGetMemHandle.argtypes = [ctypes.POINTER(hipIpcMemHandle_t), ctypes.c_void_p] + cls._hip.hipIpcOpenMemHandle.restype = ctypes.c_int + cls._hip.hipIpcOpenMemHandle.argtypes = [ctypes.POINTER(ctypes.c_void_p), hipIpcMemHandle_t, ctypes.c_uint] + cls._hip.hipIpcCloseMemHandle.restype = ctypes.c_int + cls._hip.hipIpcCloseMemHandle.argtypes = [ctypes.c_void_p] + cls._hip.hipGetErrorString.restype = ctypes.c_char_p + cls._hip.hipGetErrorString.argtypes = [ctypes.c_int] + cls._hip.hipExtMallocWithFlags.restype = ctypes.c_int + cls._hip.hipExtMallocWithFlags.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t, ctypes.c_uint] + cls._hip.hipFree.restype = ctypes.c_int + cls._hip.hipFree.argtypes = [ctypes.c_void_p] + cls._hip.hipMemset.restype = ctypes.c_int + cls._hip.hipMemset.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t] + return cls._hip + + @classmethod + def _hip_check(cls, err: int, *, what: str): + if int(err) == 0: + return + hip = cls._load_hip() + try: + s = hip.hipGetErrorString(int(err)) + msg = s.decode("utf-8", errors="replace") if s else f"hipError({err})" + except Exception: + msg = f"hipError({err})" + raise RuntimeError(f"{what} failed: {msg}") + + @classmethod + def _get_mem_handle_bytes(cls, base_ptr: int) -> bytes: + import ctypes + hip = cls._load_hip() + h = cls._hipIpcMemHandle_t() + err = hip.hipIpcGetMemHandle(ctypes.byref(h), ctypes.c_void_p(int(base_ptr))) + cls._hip_check(err, what="hipIpcGetMemHandle") + return bytes(ctypes.string_at(ctypes.byref(h), cls._HIP_IPC_HANDLE_BYTES)) + + @classmethod + def _open_mem_handle(cls, handle_bytes: bytes) -> int: + import ctypes + if len(handle_bytes) != cls._HIP_IPC_HANDLE_BYTES: + raise 
ValueError(f"Expected {cls._HIP_IPC_HANDLE_BYTES}B handle") + hip = cls._load_hip() + h = cls._hipIpcMemHandle_t() + ctypes.memmove(ctypes.byref(h), bytes(handle_bytes), cls._HIP_IPC_HANDLE_BYTES) + out_ptr = ctypes.c_void_p() + err = hip.hipIpcOpenMemHandle(ctypes.byref(out_ptr), h, ctypes.c_uint(int(cls._HIP_IPC_MEM_LAZY_ENABLE_PEER_ACCESS))) + cls._hip_check(err, what="hipIpcOpenMemHandle") + return int(out_ptr.value) + + @classmethod + def _close_mem_handle(cls, base_ptr: int) -> None: + import ctypes + hip = cls._load_hip() + err = hip.hipIpcCloseMemHandle(ctypes.c_void_p(int(base_ptr))) + cls._hip_check(err, what="hipIpcCloseMemHandle") + + @classmethod + def _alloc_uncached(cls, size: int) -> int: + """Allocate zero-initialised uncached device memory (hipDeviceMallocUncached). + + Returns the raw device pointer as int. + """ + import ctypes + hip = cls._load_hip() + buf = ctypes.c_void_p() + err = hip.hipExtMallocWithFlags(ctypes.byref(buf), ctypes.c_size_t(size), + ctypes.c_uint(cls._HIP_DEVICE_MALLOC_UNCACHED)) + cls._hip_check(err, what="hipExtMallocWithFlags") + err = hip.hipMemset(buf, 0, ctypes.c_size_t(size)) + cls._hip_check(err, what="hipMemset") + return int(buf.value) + + @classmethod + def _free_device_mem(cls, ptr: int) -> None: + import ctypes + hip = cls._load_hip() + err = hip.hipFree(ctypes.c_void_p(ptr)) + cls._hip_check(err, what="hipFree") + + @staticmethod + def _gather_object_list_via_broadcast(group, shard_data): + import torch.distributed as dist + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + all_data = [[None] for _ in range(world_size)] + all_data[rank][0] = shard_data + ranks = sorted(dist.get_process_group_ranks(group=group)) + for i, r in enumerate(ranks): + dist.broadcast_object_list(all_data[i], src=r, group=group, device="cpu") + return [all_data[i][0] for i in range(world_size)] + + def __init__(self, *, group, device, max_size: int, world_size: int, rank: int, full_nvlink: bool): + import os + import torch.distributed as dist + + self.group = group + self.device = device + self.max_size = int(max_size) + self.world_size = int(world_size) + self.rank = int(rank) + self.full_nvlink = bool(full_nvlink) + + if not dist.is_initialized(): + raise RuntimeError("torch.distributed must be initialized") + if self.world_size <= 1: + raise ValueError("world_size must be > 1") + + alloc_size = self._SIGNAL_SIZE + int(self.max_size) + self._meta_ptr = self._alloc_uncached(alloc_size) + + my_meta_bytes = self._get_mem_handle_bytes(self._meta_ptr) + all_meta = self._gather_object_list_via_broadcast(self.group, (my_meta_bytes, 0)) + + self._meta_bases = [None] * self.world_size + self._sg_ptrs = [0] * 8 + self._tmp_ptrs = [0] * 8 + for r in range(self.world_size): + hb, off = all_meta[r] + base_ptr = self._meta_ptr if r == self.rank else int(self._open_mem_handle(bytes(hb))) + if r != self.rank: + self._meta_bases[r] = base_ptr + sg_ptr = base_ptr + off + tmp_ptr = sg_ptr + self._SIGNAL_SIZE + if r < 8: + self._sg_ptrs[r] = sg_ptr + self._tmp_ptrs[r] = tmp_ptr + for i in range(self.world_size, 8): + self._sg_ptrs[i] = self._sg_ptrs[0] + self._tmp_ptrs[i] = self._tmp_ptrs[0] + self._self_sg = self._sg_ptrs[self.rank] + self._gpu_sg_ptrs_array = torch.tensor(self._sg_ptrs[:8], dtype=torch.int64, device=self.device) + + self.input_buffer = torch.empty(self.max_size, dtype=torch.uint8, device=self.device) + self.output_buffer = torch.empty(self.max_size, dtype=torch.uint8, device=self.device) + + inp_buf_base = 
int(self.input_buffer.untyped_storage().data_ptr()) + inp_buf_off = int(self.input_buffer.data_ptr()) - inp_buf_base + my_inp_buf_h = self._get_mem_handle_bytes(inp_buf_base) + all_inp_buf = self._gather_object_list_via_broadcast(self.group, (my_inp_buf_h, inp_buf_off)) + self._input_buffer_bases = [None] * self.world_size + self._input_buffer_ptrs = [0] * 8 + for r in range(self.world_size): + hb, off = all_inp_buf[r] + if r == self.rank: + self._input_buffer_ptrs[r] = int(self.input_buffer.data_ptr()) + else: + peer_base = int(self._open_mem_handle(bytes(hb))) + self._input_buffer_bases[r] = peer_base + self._input_buffer_ptrs[r] = peer_base + off + for i in range(self.world_size, 8): + self._input_buffer_ptrs[i] = self._input_buffer_ptrs[0] + + ws, rk = self.world_size, self.rank + rotated_input_buf_ptrs = [self._input_buffer_ptrs[(rk + i) % ws] for i in range(8)] + self._gpu_input_buffer_ptrs_array = torch.tensor(rotated_input_buf_ptrs, dtype=torch.int64, device=self.device) + + rotated_tmp_ptrs = [self._tmp_ptrs[(rk + i) % ws] for i in range(8)] + self._gpu_tmp_ptrs_array = torch.tensor(rotated_tmp_ptrs, dtype=torch.int64, device=self.device) + + out_buf_base = int(self.output_buffer.untyped_storage().data_ptr()) + out_buf_off = int(self.output_buffer.data_ptr()) - out_buf_base + my_out_buf_h = self._get_mem_handle_bytes(out_buf_base) + all_out_buf = self._gather_object_list_via_broadcast(self.group, (my_out_buf_h, out_buf_off)) + self._output_buffer_bases = [None] * self.world_size + self._output_buffer_ptrs = [0] * 8 + for r in range(self.world_size): + hb, off = all_out_buf[r] + if r == self.rank: + self._output_buffer_ptrs[r] = int(self.output_buffer.data_ptr()) + else: + peer_base = int(self._open_mem_handle(bytes(hb))) + self._output_buffer_bases[r] = peer_base + self._output_buffer_ptrs[r] = peer_base + off + for i in range(self.world_size, 8): + self._output_buffer_ptrs[i] = self._output_buffer_ptrs[0] + + self._gpu_output_buffer_ptrs_array = torch.tensor(self._output_buffer_ptrs[:8], dtype=torch.int64, device=self.device) + self._gpu_tmp_ptrs_nonrotated_array = torch.tensor(self._tmp_ptrs[:8], dtype=torch.int64, device=self.device) + + self._IS_CAPTURING = False + self._graph_inp = None + self._graph_out = None + self._graph_use_write_mode = False + self._gpu_graph_in_ptrs_array = torch.tensor(rotated_input_buf_ptrs, dtype=torch.int64, device=self.device) + self._graph_in_bases = [] # flat list of opened peer IPC base ptrs (for cleanup) + self._gpu_graph_out_ptrs_array = torch.tensor(self._output_buffer_ptrs[:8], dtype=torch.int64, device=self.device) + self._graph_out_bases = [] + # List-based cudagraph registration: [(tensor, per_call_ptrs, rotated), ...] + # rotated=True → inp, rotate by rank before writing ptrs + # rotated=False → out (write-mode), use rank-order ptrs + # ONE collective registers all entries at once after capture. + self._pending_graph_entries: list = [] + # Per-capture cache: data_ptr -> per_call_ptrs tensor already queued. + # Prevents duplicate pending entries when the same tensor appears in + # multiple allreduce calls within one graph capture. 
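+        # Illustrative shapes (a sketch; tensors and values hypothetical):
+        #   _pending_graph_entries = [(inp_t, ptrs8_a, True),
+        #                             (out_t, ptrs8_b, False)]
+        #   _graph_ptrs_cache     = {inp_t.data_ptr(): ptrs8_a,
+        #                            out_t.data_ptr(): ptrs8_b}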
+ self._graph_ptrs_cache: dict = {} + # Cache for eagerly-registered user output IPC ptrs (key: data_ptr int) + self._out_ptrs_cache: dict | None = None + + self._exe_cache = {} + self._threads = 512 + self._grid_x_cache = {} + + self._reuse_out_default = str(os.environ.get("FLYDSL_AITER_REUSE_OUT", "0")).strip().lower() in {"1", "true", "yes", "y"} + self._cached_out = None + + def close(self): + """Release IPC memory handles for peer GPU buffers.""" + for bases in [self._meta_bases, self._input_buffer_bases, self._output_buffer_bases, self._graph_out_bases]: + for b in bases: + if b is not None: + self._close_mem_handle(int(b)) + # _graph_in_bases is a flat list of opened peer IPC bases + for b in self._graph_in_bases: + if b is not None: + self._close_mem_handle(int(b)) + # eager write-mode out-ptrs cache + if self._out_ptrs_cache: + for b in self._out_ptrs_cache.get('bases', []): + try: + self._close_mem_handle(int(b)) + except Exception: + pass + self._out_ptrs_cache = None + self._meta_bases = [] + self._input_buffer_bases = [] + self._output_buffer_bases = [] + self._graph_in_bases = [] + self._graph_out_bases = [] + if getattr(self, '_meta_ptr', None): + try: + self._free_device_mem(self._meta_ptr) + except Exception: + pass + self._meta_ptr = None + + @contextmanager + def capture(self): + """Context manager for CUDA graph capture.""" + try: + self._IS_CAPTURING = True + self._graph_inp = None + self._graph_out = None + self._graph_use_write_mode = False + self._pending_graph_entries = [] # reset per-capture list + self._graph_ptrs_cache = {} # reset per-capture ptrs cache + yield + finally: + self._IS_CAPTURING = False + # List-based batch registration: one collective for all captured tensors. + # Covers BOTH write-mode (out entries, rotated=False) and + # non-write-mode (inp entries, rotated=True). + if self._pending_graph_entries: + self._register_graph_tensors() + + @classmethod + def _get_alloc_base_ptr(cls, dev_ptr: int) -> int: + """Get the hipMalloc allocation base for a device pointer.""" + import ctypes + hip = cls._load_hip() + base = ctypes.c_void_p() + _RANGE_START_ADDR = 11 + if not hasattr(hip, '_pga_setup'): + hip.hipPointerGetAttribute.restype = ctypes.c_int + hip.hipPointerGetAttribute.argtypes = [ctypes.c_void_p, ctypes.c_int, ctypes.c_void_p] + hip._pga_setup = True + err = hip.hipPointerGetAttribute( + ctypes.byref(base), + ctypes.c_int(_RANGE_START_ADDR), + ctypes.c_void_p(int(dev_ptr)), + ) + cls._hip_check(err, what="hipPointerGetAttribute(RANGE_START_ADDR)") + return int(base.value) + + def _exchange_out_ptrs(self, out: "torch.Tensor") -> "torch.Tensor": + """Register user output tensor via IPC and return gpu_out_ptrs_array. + + Result is in rank-order (NOT rotated), matching write-mode kernel expectation. + Cached by data_ptr so repeated eager calls with the same buffer are free. 
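+
+        Example (a minimal sketch; ``ar`` is a FlyDSLAllreduce instance and
+        ``out`` a CUDA tensor, both hypothetical):
+
+            out = torch.empty(4096, dtype=torch.float16, device=ar.device)
+            arr = ar._exchange_out_ptrs(out)    # one collective + IPC opens
+            arr2 = ar._exchange_out_ptrs(out)   # cache hit: no collective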
+ """ + ptr = int(out.data_ptr()) + if self._out_ptrs_cache is not None and self._out_ptrs_cache.get("ptr") == ptr: + return self._out_ptrs_cache["arr"] + + ws, rk = self.world_size, self.rank + alloc_base = self._get_alloc_base_ptr(ptr) + off = ptr - alloc_base + handle = self._get_mem_handle_bytes(alloc_base) + all_out = self._gather_object_list_via_broadcast(self.group, (handle, off)) + + out_ptrs = [0] * 8 + new_bases: list = [] + for r in range(ws): + hb, o = all_out[r] + if r == rk: + out_ptrs[r] = ptr + else: + peer_base = int(self._open_mem_handle(bytes(hb))) + new_bases.append(peer_base) + out_ptrs[r] = peer_base + o + for i in range(ws, 8): + out_ptrs[i] = out_ptrs[0] + + arr = torch.tensor(out_ptrs[:8], dtype=torch.int64, device=self.device) + + # Release old cached bases before replacing + if self._out_ptrs_cache: + for b in self._out_ptrs_cache.get("bases", []): + try: + self._close_mem_handle(int(b)) + except Exception: + pass + self._out_ptrs_cache = {"ptr": ptr, "arr": arr, "bases": new_bases} + return arr + + def _get_or_create_graph_ptrs(self, tensor, rotated: bool): + """Return per-call ptrs tensor for cudagraph recording. + + Checks two caches in priority order: + 1. _out_ptrs_cache (write-mode only): IPC-registered real ptrs from + warmup; if the out address is already known, use it immediately + without queuing any deferred registration. + 2. _graph_ptrs_cache: per-call placeholder tensors already queued this + capture; reuse instead of creating a duplicate pending entry. + If neither hits, allocate a new placeholder, enqueue in + _pending_graph_entries, and store in _graph_ptrs_cache. + + Args: + tensor: inp tensor (rotated=True) or out tensor (rotated=False). + rotated: True -> rotate ptrs by rank (inp, non-write-mode). + False -> rank-order ptrs (out, write-mode). + """ + ptr = int(tensor.data_ptr()) + + # Write-mode out: check IPC registration cache first. + if not rotated: + _ipc = self._out_ptrs_cache + if _ipc is not None and _ipc.get("ptr") == ptr: + return _ipc["arr"] + + # Check per-capture graph ptrs cache. + cached = self._graph_ptrs_cache.get(ptr) + if cached is not None: + return cached + + # First occurrence: allocate placeholder and queue for batch registration. + per_call_ptrs = torch.empty(8, dtype=torch.int64, device=self.device) + self._pending_graph_entries.append((tensor, per_call_ptrs, rotated)) + self._graph_ptrs_cache[ptr] = per_call_ptrs + return per_call_ptrs + + def _register_graph_tensors(self): + """Batch-register IPC handles for all captured input tensors in ONE collective. + + Compared to the old two-collective approach (one for inp, one for out), + this collects all (handle, offset) pairs into a single list and calls + _gather_object_list_via_broadcast once, reducing inter-rank synchronisation. + + Each entry in self._pending_graph_entries is (inp, per_call_in_ptrs) where + per_call_in_ptrs is the per-call GPU tensor that was passed to _run_kernel + during graph recording and whose values are updated here for replay. + """ + ws, rk = self.world_size, self.rank + entries = self._pending_graph_entries + if not entries: + return + + # 1. Collect handle+offset for EVERY captured inp into ONE list + my_handle_list = [] + for tensor, _, _rotated in entries: + alloc_base = self._get_alloc_base_ptr(int(tensor.data_ptr())) + off = int(tensor.data_ptr()) - alloc_base + handle = self._get_mem_handle_bytes(alloc_base) + my_handle_list.append((handle, off)) + + # 2. 
ONE collective — each rank sends its full list, receives all others' + all_ranks_handles = self._gather_object_list_via_broadcast( + self.group, my_handle_list + ) + + # 3. For each entry, build pointer array and update in-place + # rotated=True → inp: rotate by rank (read from peer GPU inputs) + # rotated=False → out: rank-order (write-mode broadcasts to all outs) + self._graph_in_bases = [] # flat list of opened peer bases (for cleanup) + for entry_idx, (tensor, per_call_ptrs, rotated) in enumerate(entries): + ptrs = [0] * 8 + for r in range(ws): + hb, o = all_ranks_handles[r][entry_idx] + if r == rk: + ptrs[r] = int(tensor.data_ptr()) + else: + peer_base = int(self._open_mem_handle(bytes(hb))) + self._graph_in_bases.append(peer_base) + ptrs[r] = peer_base + o + for i in range(ws, 8): + ptrs[i] = ptrs[0] + if rotated: + final = [ptrs[(rk + i) % ws] for i in range(8)] + else: + final = ptrs[:8] + per_call_ptrs.copy_( + torch.tensor(final, dtype=torch.int64, device=self.device) + ) + + + def __del__(self): + try: + self.close() + except Exception: + pass + + _SUPPORTED_WORLD_SIZES = {2, 4, 8} + _SUPPORTED_DTYPES = {torch.float32, torch.float16, torch.bfloat16} + + def should_custom_ar(self, inp, *, open_fp8_quant: bool = False) -> bool: + """Check whether the custom allreduce kernel can handle this input. + + Returns False (caller should fall back to NCCL) when any of these + conditions is violated: + 1. world_size ∈ {2, 4, 8} + 2. inp byte-size is a multiple of 16 + 3. dtype ∈ {float32, float16, bfloat16} + 4. inp byte-size ≤ max_size / 2 (2-stage write-mode uses 2× tmp) + 5. fp8 quantisation is not requested + 6. full_nvlink (fully_connected) is True, or world_size == 2 + """ + from flydsl.utils import log + + if self.world_size not in self._SUPPORTED_WORLD_SIZES: + log().error("custom allreduce unsupported: world_size=%d, " + "expected one of %s", self.world_size, + sorted(self._SUPPORTED_WORLD_SIZES)) + return False + + inp_size = int(inp.numel()) * int(inp.element_size()) + if inp_size % 16 != 0: + log().error("custom allreduce unsupported: inp_size=%d " + "is not a multiple of 16", inp_size) + return False + + if inp.dtype not in self._SUPPORTED_DTYPES: + log().error("custom allreduce unsupported: dtype=%s, " + "expected one of {%s}", inp.dtype, + ", ".join(str(d) for d in sorted(self._SUPPORTED_DTYPES, key=str))) + return False + + if inp_size > self.max_size // 2: + log().error("custom allreduce unsupported: inp_size=%d " + "exceeds max_size/2=%d", inp_size, self.max_size // 2) + return False + + if open_fp8_quant: + log().error("custom allreduce unsupported: fp8 quantisation " + "is not supported") + return False + + if self.world_size > 2 and not self.full_nvlink: + log().error("custom allreduce unsupported: fully_connected=false " + "is not supported for world_size=%d", self.world_size) + return False + + return True + + _DTYPE_STR_CACHE = {} + + def _dtype_str(self, t) -> str: + dtype = getattr(t, "dtype", None) + if dtype in self._DTYPE_STR_CACHE: + return self._DTYPE_STR_CACHE[dtype] + name = str(dtype) + if "bfloat16" in name: + result = "bf16" + elif "float16" in name: + result = "f16" + elif "float32" in name: + result = "f32" + else: + raise ValueError(f"unsupported dtype: {name}") + self._DTYPE_STR_CACHE[dtype] = result + return result + + def _compile(self, *, N: int, dtype_str: str): + from kernels.custom_all_reduce_kernel import make_allreduce_kernels + + key = (N, dtype_str, self.world_size) + fns = self._exe_cache.get(key) + if fns is not None: + return fns + fns = 
make_allreduce_kernels( + N=N, + dtype_str=dtype_str, + world_size=self.world_size, + threads=self._threads, + ) + self._exe_cache[key] = fns + return fns + + def _run_kernel( + self, + N: int, + dtype_str: str, + *, + gpu_in_ptrs_array=None, + gpu_out_ptrs_array=None, + inp_ptr: int = 0, + out_ptr: int = 0, + use_write_mode: bool = False, + stream_ptr: int | None = None, + ): + """Launch allreduce kernel (auto-selects 1-stage or 2-stage by data size).""" + from flydsl.expr.typing import Int32, Int64, Stream + + # Auto-select stage by data size: + # world_size == 2 → always 1-stage + # world_size <= 4, bytes < 160KB → 1-stage + # world_size <= 8, bytes < 80KB → 1-stage + # otherwise → 2-stage + elem_bytes = 2 if dtype_str in ("f16", "bf16") else 4 + bytes_n = N * elem_bytes + if self.world_size == 2: + _stage = "1" + elif (self.world_size <= 4 and bytes_n < 160 * 1024) or bytes_n < 80 * 1024: + _stage = "1" + else: + _stage = "2" + + try: + grid_x = self._grid_x_cache[(int(N), str(dtype_str), _stage)] + except Exception: + pack_elems = 8 if dtype_str in ("f16", "bf16") else 4 + num_packs = int(N) // int(pack_elems) + if _stage == "1": + # 1-stage: tnum_gpu threads per warp handle one pack each + tnum_gpu = self._threads // self.world_size + grid_x = int(max(1, min(_KMAXBLOCKS, (num_packs + tnum_gpu - 1) // tnum_gpu))) + else: + part_p = int(num_packs) // int(self.world_size) + tnum_gpu = self._threads // self.world_size + grid_x = int(max(1, min(_KMAXBLOCKS, (max(1, part_p) + tnum_gpu - 1) // tnum_gpu))) + self._grid_x_cache[(int(N), str(dtype_str), _stage)] = int(grid_x) + + if stream_ptr is None: + stream_obj = torch.cuda.current_stream() + else: + stream_obj = torch.cuda.ExternalStream(stream_ptr) + + fns = self._compile(N=N, dtype_str=dtype_str) + + if _stage == "1" and not use_write_mode: + fns["run_1stage_arr"]( + Int32(self.rank), + Int32(grid_x), + Int64(self._self_sg), + Int64(int(self._gpu_sg_ptrs_array.data_ptr())), + Int64(int(gpu_in_ptrs_array.data_ptr())), + Int64(int(out_ptr)), + stream=stream_obj, + ) + elif use_write_mode: + fns["run_2stage_write_mode"]( + Int32(self.rank), + Int32(grid_x), + Int64(self._self_sg), + Int64(int(self._gpu_sg_ptrs_array.data_ptr())), + Int64(int(inp_ptr)), + Int64(int(gpu_out_ptrs_array.data_ptr())), + Int64(int(self._gpu_tmp_ptrs_nonrotated_array.data_ptr())), + stream=stream_obj, + ) + else: + fns["run_2stage_arr"]( + Int32(self.rank), + Int32(grid_x), + Int64(self._self_sg), + Int64(int(self._gpu_sg_ptrs_array.data_ptr())), + Int64(int(gpu_in_ptrs_array.data_ptr())), + Int64(int(self._gpu_tmp_ptrs_array.data_ptr())), + Int64(int(out_ptr)), + stream=stream_obj, + ) + + def custom_all_reduce( + self, + inp, + *, + out=None, + use_new: bool = True, + open_fp8_quant: bool = False, + validate: bool = True, + stream_ptr: int | None = None, + ): + """Unified all-reduce (eager and cudagraph). + + Returns None when the input is not supported by the custom kernel + (caller should fall back to NCCL). + Selects write_mode kernel when N > 512*4096 and world_size == 8. 
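+
+        Example (a minimal sketch; ``ar`` is a FlyDSLAllreduce built by
+        init_custom_ar, with torch.distributed already initialized):
+
+            inp = torch.randn(8192, dtype=torch.float16, device=ar.device)
+            out = ar.custom_all_reduce(inp)
+            if out is None:                # unsupported case
+                dist.all_reduce(inp)       # caller falls back to NCCL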
+ """ + if not self.should_custom_ar(inp, open_fp8_quant=open_fp8_quant): + return None + + if out is None: + if self._reuse_out_default and (self._cached_out is not None) and self._cached_out.shape == inp.shape and self._cached_out.dtype == inp.dtype and self._cached_out.device == inp.device: + out = self._cached_out + else: + out = torch.empty_like(inp) + if self._reuse_out_default: + self._cached_out = out + + if validate: + if int(inp.numel()) != int(out.numel()): + raise ValueError("inp.numel must equal out.numel") + if not _is_weak_contiguous(out): + raise ValueError("output tensor must be weak-contiguous") + dtype_str = self._dtype_str(inp) + if dtype_str != self._dtype_str(out): + raise ValueError("inp/out dtype mismatch") + bytes_n = int(inp.numel()) * int(inp.element_size()) + if bytes_n % 16 != 0: + raise ValueError("byte size must be multiple of 16") + if bytes_n > self.max_size: + raise ValueError(f"input bytes {bytes_n} exceed max_size {self.max_size}") + else: + dtype_str = self._dtype_str(inp) + bytes_n = int(inp.numel()) * int(inp.element_size()) + N = int(out.numel()) + + # Write-mode only on CDNA3 (gfx942), ws=8, large tensors + use_write_mode = ( + bytes_n > 512 * 4096 * 2 + and self.world_size == 8 + and "gfx942" in self._get_gpu_arch() + ) + + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + self._graph_inp = inp + self._graph_out = out + self._graph_bytes_n = bytes_n + + if use_write_mode: + self._graph_use_write_mode = True + self._run_kernel( + N, dtype_str, + gpu_out_ptrs_array=self._get_or_create_graph_ptrs(out, False), + inp_ptr=int(inp.data_ptr()), + use_write_mode=True, + stream_ptr=stream_ptr, + ) + else: + self._graph_use_write_mode = False + self._run_kernel( + N, dtype_str, + gpu_in_ptrs_array=self._get_or_create_graph_ptrs(inp, True), + out_ptr=int(out.data_ptr()), + use_write_mode=False, + stream_ptr=stream_ptr, + ) + return out + else: + # IS_CAPTURING=True but stream is not recording: warmup-inside-capture + # is not a supported usage path. Return zeros to keep all ranks in sync + # without issuing any kernel or collective. 
+ out.zero_() + return out + + if use_write_mode: + self._run_kernel( + N, dtype_str, + gpu_out_ptrs_array=self._gpu_output_buffer_ptrs_array, + inp_ptr=int(inp.data_ptr()), + use_write_mode=True, + stream_ptr=stream_ptr, + ) + out.view(torch.uint8)[:bytes_n].copy_(self.output_buffer[:bytes_n]) + else: + self.input_buffer[:bytes_n].copy_(inp.view(torch.uint8)) + self._run_kernel( + N, dtype_str, + gpu_in_ptrs_array=self._gpu_input_buffer_ptrs_array, + out_ptr=int(out.data_ptr()), + use_write_mode=False, + stream_ptr=stream_ptr, + ) + return out + + def all_reduce_reg(self, inp, out, open_fp8_quant: bool = False): + if isinstance(inp, (list, tuple)): + import functools + result = functools.reduce(torch.add, inp) + out.copy_(result) + return out + return self.custom_all_reduce(inp, out=out, open_fp8_quant=open_fp8_quant) + + def all_gather_reg(self, inp, out): + if isinstance(inp, (list, tuple)): + stacked = torch.stack(list(inp), dim=0) + out.copy_(stacked) + elif self.world_size == 1: + out.copy_(inp) + else: + import torch.distributed as dist + dist.all_gather_into_tensor(out, inp, group=self.group) + return out diff --git a/kernels/custom_all_reduce_kernel.py b/kernels/custom_all_reduce_kernel.py new file mode 100644 index 000000000..811d0347f --- /dev/null +++ b/kernels/custom_all_reduce_kernel.py @@ -0,0 +1,999 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""FlyDSL all-reduce kernels using signal protocol for multi-GPU communication. + +Implements 1-stage and 2-stage (reduce-scatter + all-gather) kernels. +Signal buffers are hipDeviceMallocUncached (bypasses L1/TCP cache). +Memory ordering uses GFX942 inline assembly for XGMI/HBM visibility. +""" + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.expr.typing import T +from flydsl.expr.meta import traced_op +from flydsl.expr import arith as ea, gpu, range_constexpr +from flydsl._mlir.dialects import gpu as _raw_gpu, llvm, rocdl, arith as _mlir_arith +from flydsl.expr.typing import T, Int32, Int64, Stream +from flydsl.expr.buffer_ops import _unwrap_value +from flydsl._mlir import ir +from flydsl.expr.arith import ArithValue + + +def extui(self, target_type, *, loc=None): + """Zero-extend integer to wider type (e.g. i32 → i64).""" + return ea.ExtUIOp(target_type, self, loc=loc).result + + +def extsi(self, target_type, *, loc=None): + """Sign-extend integer to wider type (e.g. i32 → i64).""" + return ea.ExtSIOp(target_type, self, loc=loc).result + + +def trunci(self, target_type, *, loc=None): + """Truncate integer to narrower type (e.g. i64 → i32).""" + return ea.TruncIOp(target_type, self, loc=loc).result + + +ArithValue.extui = extui +ArithValue.extsi = extsi +ArithValue.trunci = trunci + + +@traced_op +def select_by_index(index_val, values): + """Select one of *values* by integer *index_val* via chained ``arith.select``. + + Equivalent to a compile-time switch: returns ``values[index_val]``. + + Args: + index_val: Integer index (i32 ``ir.Value``). + values: List of ``ir.Value`` to select from. + + Returns: + The selected ``ir.Value``. 
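+
+    Example (a sketch of the select chain emitted for three values):
+
+        out = select(idx == 2, values[2],
+                     select(idx == 1, values[1], values[0]))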
+ """ + out = values[0] + for i in range(1, len(values)): + pred = _mlir_arith.CmpIOp( + _mlir_arith.CmpIPredicate.eq, index_val, ea.constant(i, type=index_val.type) + ).result + out = _mlir_arith.SelectOp(pred, values[i], out).result + return out + + +ea.select_by_index = select_by_index + + +# --------------------------------------------------------------------------- +# Uncached i32 operations (system-scope coherent, for signal buffers) +# --------------------------------------------------------------------------- + +@traced_op +def load_i32_uncached(addr_i64): + """Load i32 from global address, bypassing L1 cache (system-scope). + + Emits ``global_load_dword ... sc1`` on GFX942. + Typically used to poll cross-GPU signal buffers. + """ + v = llvm.InlineAsmOp( + T.i32, [addr_i64], + "global_load_dword $0, $1, off sc1", "=v,v", + has_side_effects=True, + ).result + rocdl.s_waitcnt(0) + rocdl.sched_barrier(0) + return v + + +@traced_op +def store_i32_uncached_flush(addr_i64, val_i32): + """Store i32 with L2 flush + system-scope coherence for XGMI visibility. + + Emits ``buffer_wbl2 sc0 sc1`` followed by ``global_store_dword ... sc0 sc1``. + Use after cached data stores (``store_v4i32``) to ensure L2 dirty lines + reach HBM before the signal becomes visible to peer GPUs. + """ + llvm.InlineAsmOp(None, [], "buffer_wbl2 sc0 sc1", "", has_side_effects=True) + llvm.InlineAsmOp( + None, [addr_i64, val_i32], + "global_store_dword $0, $1, off sc0 sc1", "v,v", + has_side_effects=True, + ) + rocdl.s_waitcnt(0) + rocdl.sched_barrier(0) + + +@traced_op +def store_i32_uncached(addr_i64, val_i32): + """Store i32 with system-scope coherence (no L2 flush). + + Emits ``global_store_dword ... sc0 sc1``. + Use after nontemporal data stores (``store_v4i32_nt``) which already + bypass L2 — no ``buffer_wbl2`` is needed. + """ + llvm.InlineAsmOp( + None, [addr_i64, val_i32], + "global_store_dword $0, $1, off sc0 sc1", "v,v", + has_side_effects=True, + ) + rocdl.s_waitcnt(0) + rocdl.sched_barrier(0) + + +@traced_op +def store_i32(addr_i64, val_i32): + """Store i32 to global address (normal cached store). + + Emits ``global_store_dword ... off`` with no cache coherence flags. + Use for writes visible only to the local GPU (e.g. updating own signal). + """ + llvm.InlineAsmOp( + None, [addr_i64, val_i32], + "global_store_dword $0, $1, off", "v,v", + has_side_effects=True, + ) + rocdl.s_waitcnt(0) + rocdl.sched_barrier(0) + + +# --------------------------------------------------------------------------- +# v4i32 (16-byte) vector operations +# --------------------------------------------------------------------------- + +@traced_op +def load_v4i32(addr_i64): + """Load 16 bytes (``vector<4xi32>``) from global address. + + Emits ``flat_load_dwordx4``. + """ + v = llvm.InlineAsmOp( + T.i32x4, [addr_i64], + "flat_load_dwordx4 $0, $1", "=v,v", + has_side_effects=True, + ).result + rocdl.s_waitcnt(0) + rocdl.sched_barrier(0) + return v + + +@traced_op +def store_v4i32(addr_i64, v4i32_val): + """Store 16 bytes (``vector<4xi32>``) to global address. + + Emits ``global_store_dwordx4 ... off``. + """ + llvm.InlineAsmOp( + None, [addr_i64, v4i32_val], + "global_store_dwordx4 $0, $1, off", "v,v", + has_side_effects=True, + ) + rocdl.s_waitcnt(0) + rocdl.sched_barrier(0) + + +@traced_op +def store_v4i32_nt(addr_i64, v4i32_val): + """Store 16 bytes with nontemporal hint, bypassing L1/L2 cache. + + Emits ``flat_store_dwordx4 ... nt``. 
+ Suitable for large data writes across XGMI — works on any memory type + (regular ``hipMalloc``, IPC-mapped coarse-grained memory). + """ + llvm.InlineAsmOp( + None, [addr_i64, v4i32_val], + "flat_store_dwordx4 $0, $1 nt", "v,v", + has_side_effects=True, + ) + rocdl.s_waitcnt(0) + rocdl.sched_barrier(0) + + +# --------------------------------------------------------------------------- +# Pointer helpers +# --------------------------------------------------------------------------- + +@traced_op +def load_device_ptr(array_base_i64, index): + """Load an i64 pointer from a device-side pointer array. + + Computes ``base + index * 8``, casts to ``!llvm.ptr``, and loads i64. + + Args: + array_base_i64: Base address of the pointer array (i64). + index: Array index (i32 or i64). + """ + + i64 = T.i64 + if hasattr(index, 'type') and str(index.type) == 'i32': + index = _mlir_arith.ExtUIOp(i64, index).result + elem_addr = array_base_i64 + index * ea.constant(8, type=i64) + ptr = llvm.IntToPtrOp(ir.Type.parse("!llvm.ptr"), elem_addr).result + return llvm.LoadOp(i64, ptr).result + + +@traced_op +def invalidate_l1(): + """Invalidate L1 scalar cache (``buffer_inv sc1``). + + Use inside a polling loop after a remote-visible load to discard stale + L1 cache lines so the next iteration sees fresh data from L2/HBM. + """ + llvm.InlineAsmOp(None, [], "buffer_inv sc1", "", has_side_effects=True) + + +import flydsl.compiler as flyc +from flydsl.expr import arith as ea, gpu, range_constexpr, vector as ev +from flydsl.expr.typing import T, Int32, Int64, Stream +from flydsl._mlir import ir +from flydsl._mlir.dialects import scf +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + + +# Signal buffer layout offsets (bytes) within the per-rank signal buffer +_SG_START_OFF_B = 0 +_SG_END_OFF_B = 2560 +_SG_FLAG_OFF_B = 5120 + +_MAX_BLOCKS = 80 + + +# --------------------------------------------------------------------------- +# Element type helpers +# --------------------------------------------------------------------------- + +def _elem_type(dtype_str: str) -> ir.Type: + d = (dtype_str or "").strip().lower() + if d in {"f16", "fp16"}: + return T.f16 + if d in {"bf16"}: + return T.bf16 + if d in {"f32", "fp32"}: + return T.f32 + raise ValueError(f"unsupported dtype_str: {dtype_str!r}") + + +def _pack_elems(dtype_str: str) -> int: + d = (dtype_str or "").strip().lower() + if d in {"f32", "fp32"}: + return 4 + if d in {"f16", "fp16", "bf16"}: + return 8 + raise ValueError(f"unsupported dtype_str: {dtype_str!r}") + + +def _u(v): + """Tag ArithValue as unsigned for //, %, <, <=, >, >=, >> ops.""" + return v.with_signedness(False) + + +# --------------------------------------------------------------------------- +# Signal synchronization primitives +# --------------------------------------------------------------------------- + +def _signal_start_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, ngpus: int): + """Start-sync: write start flag to all peers, wait for all to arrive.""" + + + i32, i64 = T.i32, T.i64 + + flag_addr = (self_sg_i64 + ea.constant(_SG_FLAG_OFF_B, type=i64) + + bid_i32.extui(i64) * ea.constant(4, type=i64)) + flag = load_i32_uncached(flag_addr) + ea.constant(1, type=i32) + + bid8 = bid_i32 * ea.constant(8, type=i32) + lin_lane = bid8 + lane_i32 + start_wait_addr = (self_sg_i64 + ea.constant(_SG_START_OFF_B, type=i64) + + lin_lane.extui(i64) * ea.constant(4, type=i64)) + lin_rank = bid8 + rank_i32 + start_rank_off = (ea.constant(_SG_START_OFF_B, type=i64) + + lin_rank.extui(i64) 
* ea.constant(4, type=i64)) + + is_lane = _u(lane_i32) < ea.constant(ngpus, type=i32) + if_op = scf.IfOp(is_lane, results_=[], has_else=False) + with ir.InsertionPoint(if_op.then_block): + peer_sg = ea.select_by_index(lane_i32, sgs_i64) + store_i32_uncached_flush(peer_sg + start_rank_off, flag) + init_cur = load_i32_uncached(start_wait_addr) + w = scf.WhileOp([i32], [init_cur]) + wb = ir.Block.create_at_start(w.before, [i32]) + wa = ir.Block.create_at_start(w.after, [i32]) + with ir.InsertionPoint(wb): + cur = wb.arguments[0] + need_wait = _u(cur) < flag + scf.ConditionOp(need_wait, [cur]) + with ir.InsertionPoint(wa): + scf.YieldOp([load_i32_uncached(start_wait_addr)]) + scf.YieldOp([]) + + gpu.barrier() + is_t0 = lane_i32 == ea.constant(0, type=i32) + if_t0 = scf.IfOp(is_t0, results_=[], has_else=False) + with ir.InsertionPoint(if_t0.then_block): + store_i32(flag_addr, flag) + scf.YieldOp([]) + return flag_addr + + +def _signal_end_sync(*, lane_i32, rank_i32, bid_i32, self_sg_i64, sgs_i64, + ngpus: int, need_wbl2: bool = False): + """End-sync: write end flag to all peers, wait for all to finish. + + Args: + need_wbl2: True → use st_xgmi_u32 (buffer_wbl2 + signal store). + Required after cached stores (st_global_16b) so + that L2 dirty lines reach HBM before the signal. + False → use st_signal_u32 (signal store only, no wbl2). + For nt data stores (st_nt_16b) which already bypass + L2; uses ATOMIC_RELAXED + MEMORY_SCOPE_SYSTEM. + """ + + + i32, i64 = T.i32, T.i64 + + gpu.barrier() + flag_addr = (self_sg_i64 + ea.constant(_SG_FLAG_OFF_B, type=i64) + + bid_i32.extui(i64) * ea.constant(4, type=i64)) + flag = load_i32_uncached(flag_addr) + ea.constant(1, type=i32) + + bid8 = bid_i32 * ea.constant(8, type=i32) + lin_lane = bid8 + lane_i32 + end_wait_addr = (self_sg_i64 + ea.constant(_SG_END_OFF_B, type=i64) + + lin_lane.extui(i64) * ea.constant(4, type=i64)) + lin_rank = bid8 + rank_i32 + end_rank_off = (ea.constant(_SG_END_OFF_B, type=i64) + + lin_rank.extui(i64) * ea.constant(4, type=i64)) + + is_lane = _u(lane_i32) < ea.constant(ngpus, type=i32) + if_op = scf.IfOp(is_lane, results_=[], has_else=False) + with ir.InsertionPoint(if_op.then_block): + peer_sg = ea.select_by_index(lane_i32, sgs_i64) + if need_wbl2: + store_i32_uncached_flush(peer_sg + end_rank_off, flag) + else: + store_i32_uncached(peer_sg + end_rank_off, flag) + init_cur = load_i32_uncached(end_wait_addr) + w = scf.WhileOp([i32], [init_cur]) + wb = ir.Block.create_at_start(w.before, [i32]) + wa = ir.Block.create_at_start(w.after, [i32]) + with ir.InsertionPoint(wb): + cur = wb.arguments[0] + need_wait = _u(cur) < flag + scf.ConditionOp(need_wait, [cur]) + with ir.InsertionPoint(wa): + nxt = load_i32_uncached(end_wait_addr) + invalidate_l1() + scf.YieldOp([nxt]) + scf.YieldOp([]) + + gpu.barrier() + is_t0 = lane_i32 == ea.constant(0, type=i32) + if_t0 = scf.IfOp(is_t0, results_=[], has_else=False) + with ir.InsertionPoint(if_t0.then_block): + store_i32(flag_addr, flag) + scf.YieldOp([]) + + +# --------------------------------------------------------------------------- +# Kernel work group size attribute helper +# --------------------------------------------------------------------------- + +def _set_workgroup_size(threads: int): + """Set rocdl work group size attributes on the enclosing gpu.func.""" + entry_block = ir.InsertionPoint.current.block + gpu_func_op = entry_block.owner + gpu_func_op.operation.attributes["rocdl.reqd_work_group_size"] = ir.DenseI32ArrayAttr.get([threads, 1, 1]) + 
gpu_func_op.operation.attributes["rocdl.flat_work_group_size"] = ir.StringAttr.get(f"{threads},{threads}") + return gpu_func_op + + +# --------------------------------------------------------------------------- +# Main entry point +# --------------------------------------------------------------------------- + +def make_allreduce_kernels(*, N: int, dtype_str: str, world_size: int, threads: int = 512): + """Build and return compiled allreduce launcher functions. + + Captures compile-time constants as closures, returns a dict with: + "run_1stage_arr" -- CUDAGraph-compatible 1-stage allreduce (small N) + "run_2stage_arr" -- CUDAGraph-compatible 2-stage allreduce + "run_2stage_write_mode" -- Large-tensor 2-stage allreduce (N > 512*4096, ws=8) + + Args: + N: Total number of elements to reduce. + dtype_str: "f16" or "f32". + world_size: Number of GPUs (2, 4, 6, or 8). + threads: Threads per block (must be divisible by world_size). + """ + if world_size not in {2, 4, 6, 8}: + raise ValueError(f"world_size must be one of {{2,4,6,8}}, got {world_size}") + if threads <= 0 or threads % world_size != 0: + raise ValueError(f"threads={threads} must be > 0 and divisible by world_size={world_size}") + + pack_elems = _pack_elems(dtype_str) + if N <= 0 or N % pack_elems != 0: + raise ValueError(f"N={N} must be > 0 and a multiple of pack_elems={pack_elems}") + + # Compile-time constants captured by closures + num_packs = N // pack_elems + part_p = num_packs // world_size + largest_part_p = part_p + (num_packs % world_size) + tnum_gpu = threads // world_size + is_f32 = dtype_str.lower().strip() in {"f32", "fp32"} + is_bf16 = dtype_str.lower().strip() in {"bf16"} + # Vectorized gather path: requires perfect partition + no world_size=6 + vec_ok = (num_packs % world_size == 0) and (world_size != 6) + + # Adaptive LDS buffer strategy for 2-stage Stage 1: + # Single buffer (8KB, 2 barriers/iter): halves LDS usage, doubles block + # occupancy per CU, improves latency-hiding for many-iteration workloads. + # Double buffer (16KB, 1 barrier/iter): saves 1 barrier per iteration, + # better for small tensors where the kernel runs only 1-2 iterations and + # occupancy is already saturated by register usage rather than LDS. + # Threshold: use single buffer when estimated iterations per block >= 3. + _est_iters_2stage = max(1, (max(1, part_p) + _MAX_BLOCKS * tnum_gpu - 1) + // (_MAX_BLOCKS * tnum_gpu)) + _use_single_buf_2stage = (_est_iters_2stage >= 3) + + # ----------------------------------------------------------------------- + # GPU Kernel: 1-stage arr (full allreduce in one pass, CUDAGraph-compatible) + # ----------------------------------------------------------------------- + @flyc.kernel + def allreduce_1stage_arr( + rank: Int32, + self_sg: Int64, + sg_ptrs: Int64, + in_ptrs: Int64, + out_ptr: Int64, + ): + """1-stage allreduce using shared memory. + + Each warp loads data from one rank into shared memory, then warp 0 + reduces across all warps and writes the result to global memory. 
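+
+        Per-pack data flow (a sketch; world_size=2, tnum = threads//2):
+
+            warp0: smem[0:tnum]      <- load_v4i32(in_ptrs[0] + off)
+            warp1: smem[tnum:2*tnum] <- load_v4i32(in_ptrs[1] + off)
+            barrier()
+            warp0: out[pack] <- smem[lane] + smem[tnum + lane]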
+ """ + + + i32, i64 = T.i32, T.i64 + idx = ir.IndexType.get() + v4i32 = T.i32x4 + if is_f32: + v4f32 = T.f32x4 + else: + v8half = T.bf16x8 if is_bf16 else T.f16x8 + v8f32 = T.vec(8, T.f32) + + gpu_func_op = _set_workgroup_size(threads) + + lane_i32 = ea.index_cast(i32, gpu.thread_id("x")) + bid_i32 = ea.index_cast(i32, gpu.block_id("x")) + rank_i32 = rank.ir_value() + self_sg_i64 = self_sg.ir_value() + sg_ptrs_i64 = sg_ptrs.ir_value() + in_ptrs_i64 = in_ptrs.ir_value() + out_ptr_i64 = out_ptr.ir_value() + + sgs = [load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + in_ptrs_arr = [load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + + smem_sym = f"allreduce_1s_smem_ws{world_size}_t{threads}" + n_smem = 2 * threads + allocator = SmemAllocator(None, global_sym_name=smem_sym) + smem_off = allocator._align(allocator.ptr, 16) + allocator.ptr = smem_off + n_smem * 16 + with ir.InsertionPoint.at_block_begin(gpu_func_op.operation.block): + allocator.finalize() + smem_ptr = SmemPtr(allocator.get_base(), smem_off, v4i32, shape=(n_smem,)) + smem_ptr.get() + + tnum_gpu_i32 = ea.constant(tnum_gpu, type=i32) + warp_id = _u(lane_i32) // tnum_gpu_i32 + lane_id = _u(lane_i32) % tnum_gpu_i32 + + _signal_start_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, + self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) + + # Grid-stride loop: each warp loads from its assigned rank, + # then warp 0 reduces and writes output. + tid_pack = bid_i32 * tnum_gpu_i32 + lane_id + stride_pack = gpu.grid_dim.x.ir_value() * tnum_gpu_i32 + + loop = scf.WhileOp([i32, i32], [tid_pack, ea.constant(0, type=i32)]) + bfor = ir.Block.create_at_start(loop.before, [i32, i32]) + afor = ir.Block.create_at_start(loop.after, [i32, i32]) + with ir.InsertionPoint(bfor): + p = bfor.arguments[0] + cond = _u(p) < ea.constant(num_packs, type=i32) + scf.ConditionOp(cond, [p, bfor.arguments[1]]) + with ir.InsertionPoint(afor): + p = afor.arguments[0] + parity = afor.arguments[1] + + # Each warp loads data from its rank into shared memory + in_base = ea.select_by_index(warp_id, in_ptrs_arr) + off16 = p.extui(i64) * ea.constant(16, type=i64) + raw = load_v4i32(in_base + off16) + sm_base = parity * ea.constant(threads, type=i32) + sm_idx = ea.index_cast(idx, sm_base + lane_i32) + smem_ptr.store(raw, [sm_idx]) + gpu.barrier() + + # Warp 0 reduces across all warps and writes to output + is_w0 = warp_id == ea.constant(0, type=i32) + ifw0 = scf.IfOp(is_w0, results_=[], has_else=False) + with ir.InsertionPoint(ifw0.then_block): + acc = None + for wi in range_constexpr(world_size): + sm_i_idx = ea.index_cast( + idx, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + sm_base) + raw_i = smem_ptr.load([sm_i_idx]) + if is_f32: + vf = raw_i.bitcast(v4f32) + acc = vf if acc is None else acc + vf + else: + v16 = ev.bitcast(v8half, raw_i) + v32 = v16.extf(v8f32) + acc = v32 if acc is None else acc + v32 + if is_f32: + out_bits = acc.bitcast(v4i32) + else: + out_bits = ev.bitcast(v4i32, acc.truncf(v8half)) + dst_off = p.extui(i64) * ea.constant(16, type=i64) + store_v4i32(out_ptr_i64 + dst_off, out_bits) + scf.YieldOp([]) + + # No barrier 2 needed: parity double-buffer ensures next iteration + # writes to the opposite smem half, so warp-0 reads from parity_N half + # are disjoint from all-warp writes to (1-parity_N) half in the next + # iteration. The barrier at the top of the next iteration guarantees + # warp-0 finishes before any thread reads the new data. 
+ scf.YieldOp([p + stride_pack, ea.constant(1, type=i32) - parity]) + + # 1-stage does not use end_sync to avoid hangs. + + # ----------------------------------------------------------------------- + # GPU Kernel: 2-stage arr (reduce-scatter + all-gather) + # ----------------------------------------------------------------------- + @flyc.kernel + def allreduce_2stage_arr( + rank: Int32, + self_sg: Int64, + sg_ptrs: Int64, + in_ptrs: Int64, + tmp_ptrs: Int64, + out_ptr: Int64, + ): + + + i32, i64 = T.i32, T.i64 + idx = ir.IndexType.get() + v4i32 = T.i32x4 + if is_f32: + v4f32 = T.f32x4 + else: + v8half = T.bf16x8 if is_bf16 else T.f16x8 + v8f32 = T.vec(8, T.f32) + + gpu_func_op = _set_workgroup_size(threads) + + lane_i32 = ea.index_cast(i32, gpu.thread_id("x")) + bid_i32 = ea.index_cast(i32, gpu.block_id("x")) + rank_i32 = rank.ir_value() + self_sg_i64 = self_sg.ir_value() + sg_ptrs_i64 = sg_ptrs.ir_value() + in_ptrs_i64 = in_ptrs.ir_value() + tmp_ptrs_i64 = tmp_ptrs.ir_value() + out_ptr_i64 = out_ptr.ir_value() + + sgs = [load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + in_ptrs_arr = [load_device_ptr(in_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + tmp_ptrs_arr = [load_device_ptr(tmp_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + + # Compute pack range for this rank's reduce-scatter partition + start_p = rank_i32 * ea.constant(part_p, type=i32) + is_last = rank_i32 == ea.constant(world_size - 1, type=i32) + end_p = ea.select(is_last, ea.constant(num_packs, type=i32), + start_p + ea.constant(part_p, type=i32)) + + _signal_start_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, + self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) + + tnum_gpu_i32 = ea.constant(tnum_gpu, type=i32) + warp_id = _u(lane_i32) // tnum_gpu_i32 + lane_id = _u(lane_i32) % tnum_gpu_i32 + tid_pack = bid_i32 * tnum_gpu_i32 + lane_id + stride_pack = gpu.grid_dim.x.ir_value() * tnum_gpu_i32 + + _buf_tag = "1b" if _use_single_buf_2stage else "2b" + smem_sym = f"allreduce_smem_ws{world_size}_t{threads}_{_buf_tag}" + smem_slots = threads if _use_single_buf_2stage else 2 * threads + allocator = SmemAllocator(None, global_sym_name=smem_sym) + smem_off = allocator._align(allocator.ptr, 16) + allocator.ptr = smem_off + smem_slots * 16 + with ir.InsertionPoint.at_block_begin(gpu_func_op.operation.block): + allocator.finalize() + smem_ptr = SmemPtr(allocator.get_base(), smem_off, v4i32, shape=(smem_slots,)) + smem_ptr.get() + tmp_out_i64 = tmp_ptrs_arr[0] + + # ---- Stage 1: reduce-scatter ---- + # Two implementations selected at compile time via _use_single_buf_2stage: + # Single-buffer (large tensor): 8KB LDS, 2 barriers/iter, higher occupancy. + # Double-buffer (small tensor): 16KB LDS, 1 barrier/iter (parity trick). 
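+        # Worked example (illustrative numbers, threads=512, ws=8):
+        # N = 8 Mi f16 elems -> num_packs = 1 Mi, part_p = 128 Ki;
+        # tnum_gpu = 64, so one grid pass covers 80 * 64 = 5120 packs,
+        # giving ~26 iters >= 3 -> single buffer wins on occupancy.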
+ + def _build_reduce_body(cur, smem_base_expr=None): + """Emit reduce body: load → smem → barrier1 → warp0 reduce → [barrier2].""" + in_base = ea.select_by_index(warp_id, in_ptrs_arr) + raw = load_v4i32(in_base + cur.extui(i64) * ea.constant(16, type=i64)) + if smem_base_expr is None: + sm_idx = ea.index_cast(idx, lane_i32) + else: + sm_idx = ea.index_cast(idx, smem_base_expr + lane_i32) + smem_ptr.store(raw, [sm_idx]) + gpu.barrier() # barrier 1: all warps have written smem + + is_w0 = warp_id == ea.constant(0, type=i32) + ifw0 = scf.IfOp(is_w0, results_=[], has_else=False) + with ir.InsertionPoint(ifw0.then_block): + acc = None + for wi in range_constexpr(world_size): + if smem_base_expr is None: + sm_r_idx = ea.index_cast(idx, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id) + else: + sm_r_idx = ea.index_cast(idx, ea.constant(wi, type=i32) * tnum_gpu_i32 + lane_id + smem_base_expr) + raw_i = smem_ptr.load([sm_r_idx]) + if is_f32: + vf = raw_i.bitcast(v4f32) + acc = vf if acc is None else acc + vf + else: + v16 = ev.bitcast(v8half, raw_i) + v32 = v16.extf(v8f32) + acc = v32 if acc is None else acc + v32 + if is_f32: + out_raw = acc.bitcast(v4i32) + else: + out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) + rel_p = cur - start_p + store_v4i32(tmp_out_i64 + rel_p.extui(i64) * ea.constant(16, type=i64), + out_raw) + scf.YieldOp([]) + + idx_p = start_p + tid_pack + if _use_single_buf_2stage: + # Single buffer: 8KB LDS, 2 barriers per iteration. + loop1 = scf.WhileOp([i32], [idx_p]) + b1 = ir.Block.create_at_start(loop1.before, [i32]) + a1 = ir.Block.create_at_start(loop1.after, [i32]) + with ir.InsertionPoint(b1): + cur = b1.arguments[0] + cond = _u(cur) < end_p + scf.ConditionOp(cond, [cur]) + with ir.InsertionPoint(a1): + cur = a1.arguments[0] + _build_reduce_body(cur, smem_base_expr=None) + gpu.barrier() # barrier 2: protect smem before next iter's writes + scf.YieldOp([cur + stride_pack]) + else: + # Double buffer: 16KB LDS, 1 barrier per iteration (parity trick). + # The parity alternates between the two smem halves so warp-0 reads + # from half-A while all warps write the next pack to half-B. + loop1 = scf.WhileOp([i32, i32], [idx_p, ea.constant(0, type=i32)]) + b1 = ir.Block.create_at_start(loop1.before, [i32, i32]) + a1 = ir.Block.create_at_start(loop1.after, [i32, i32]) + with ir.InsertionPoint(b1): + cur = b1.arguments[0] + cond = _u(cur) < end_p + scf.ConditionOp(cond, [cur, b1.arguments[1]]) + with ir.InsertionPoint(a1): + cur = a1.arguments[0] + parity = a1.arguments[1] + sm_base = parity * ea.constant(threads, type=i32) + _build_reduce_body(cur, smem_base_expr=sm_base) + # No barrier 2: parity ensures next iteration writes to opposite + # smem half, so warp-0 reads and all-warp writes are disjoint. 
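+                # Safety hinges on barrier 1 inside _build_reduce_body: it
+                # orders this iteration's smem writes before warp-0's reads;
+                # parity then steers the next iteration's writes to the
+                # other half.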
+ scf.YieldOp([cur + stride_pack, ea.constant(1, type=i32) - parity]) + + _signal_end_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, + self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) + + # ---- Stage 2: all-gather ---- + if vec_ok: + tid_pack2 = bid_i32 * tnum_gpu_i32 + lane_id + stride_pack2 = gpu.grid_dim.x.ir_value() * tnum_gpu_i32 + + loop2 = scf.WhileOp([i32], [tid_pack2]) + b2 = ir.Block.create_at_start(loop2.before, [i32]) + a2 = ir.Block.create_at_start(loop2.after, [i32]) + with ir.InsertionPoint(b2): + cur = b2.arguments[0] + cond = _u(cur) < ea.constant(part_p, type=i32) + scf.ConditionOp(cond, [cur]) + with ir.InsertionPoint(a2): + cur = a2.arguments[0] + sum_rw = rank_i32 + warp_id + if world_size in {2, 4, 8}: + dst_rank = sum_rw & ea.constant(world_size - 1, type=i32) + else: + dst_rank = _u(sum_rw) % ea.constant(world_size, type=i32) + tmp_base = ea.select_by_index(warp_id, tmp_ptrs_arr) + raw = load_v4i32(tmp_base + cur.extui(i64) * ea.constant(16, type=i64)) + dst_pack = dst_rank * ea.constant(part_p, type=i32) + cur + store_v4i32(out_ptr_i64 + dst_pack.extui(i64) * ea.constant(16, type=i64), + raw) + scf.YieldOp([cur + stride_pack2]) + else: + # Non-vectorized fallback (world_size=6 or num_packs % world_size != 0) + tid_i32 = bid_i32 * ea.constant(threads, type=i32) + lane_i32 + stride_i32 = gpu.grid_dim.x.ir_value() * ea.constant(threads, type=i32) + + loop2 = scf.WhileOp([i32], [tid_i32]) + b2 = ir.Block.create_at_start(loop2.before, [i32]) + a2 = ir.Block.create_at_start(loop2.after, [i32]) + with ir.InsertionPoint(b2): + cur = b2.arguments[0] + cond = _u(cur) < ea.constant(largest_part_p, type=i32) + scf.ConditionOp(cond, [cur]) + with ir.InsertionPoint(a2): + cur = a2.arguments[0] + for p in range_constexpr(world_size): + if p == world_size - 1: + ok = ea.constant(1, type=T.bool()) + else: + ok = _u(cur) < ea.constant(part_p, type=i32) + ifp = scf.IfOp(ok, results_=[], has_else=False) + with ir.InsertionPoint(ifp.then_block): + src_off = cur.extui(i64) * ea.constant(16, type=i64) + raw = load_v4i32(tmp_ptrs_arr[p] + src_off) + dst_pack_idx = ea.constant(p * part_p, type=i32) + cur + dst_off = dst_pack_idx.extui(i64) * ea.constant(16, type=i64) + store_v4i32(out_ptr_i64 + dst_off, raw) + scf.YieldOp([]) + scf.YieldOp([cur + stride_i32]) + + # ----------------------------------------------------------------------- + # GPU Kernel: 2-stage write-mode (large tensors, writes reduced result + # directly to REMOTE output buffers via XGMI) + # ----------------------------------------------------------------------- + @flyc.kernel + def allreduce_2stage_write_mode( + rank: Int32, + self_sg: Int64, + sg_ptrs: Int64, + inp_ptr: Int64, + out_ptrs: Int64, + tmp_ptrs: Int64, + ): + import math + + + i32, i64 = T.i32, T.i64 + idx = ir.IndexType.get() + v4i32 = T.i32x4 + if is_f32: + v4f32 = T.f32x4 + else: + v8half = T.bf16x8 if is_bf16 else T.f16x8 + v8f32 = T.vec(8, T.f32) + + gpu_func_op = _set_workgroup_size(threads) + + lane_i32 = ea.index_cast(i32, gpu.thread_id("x")) + bid_i32 = ea.index_cast(i32, gpu.block_id("x")) + rank_i32 = rank.ir_value() + self_sg_i64 = self_sg.ir_value() + sg_ptrs_i64 = sg_ptrs.ir_value() + inp_ptr_i64 = inp_ptr.ir_value() + out_ptrs_i64 = out_ptrs.ir_value() + tmp_ptrs_i64 = tmp_ptrs.ir_value() + + sgs = [load_device_ptr(sg_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + out_ptrs_arr = [load_device_ptr(out_ptrs_i64, ea.constant(i, type=i32)) for i in range(8)] + + tnum_gpu_i32 = ea.constant(tnum_gpu, type=i32) + 
log2_tnum = int(math.log2(tnum_gpu)) + warp_id = _u(lane_i32) >> ea.constant(log2_tnum, type=i32) + warp_base = warp_id * tnum_gpu_i32 + lane_id = lane_i32 - warp_base + tid_pack = bid_i32 * tnum_gpu_i32 + lane_id + stride_pack = gpu.grid_dim.x.ir_value() * tnum_gpu_i32 + + smem_sym_wm = f"allreduce_smem_wm_ws{world_size}_t{threads}" + n_smem_wm = 2 * threads + allocator_wm = SmemAllocator(None, global_sym_name=smem_sym_wm) + smem_wm_off = allocator_wm._align(allocator_wm.ptr, 16) + allocator_wm.ptr = smem_wm_off + n_smem_wm * 16 + with ir.InsertionPoint.at_block_begin(gpu_func_op.operation.block): + allocator_wm.finalize() + smem_ptr = SmemPtr(allocator_wm.get_base(), smem_wm_off, v4i32, shape=(n_smem_wm,)) + smem_ptr.get() + tmp_out_i64 = load_device_ptr(tmp_ptrs_i64, rank_i32) + + # ---- Stage 1: scatter local input to REMOTE tmp buffers ---- + start_w = warp_id * ea.constant(part_p, type=i32) + is_last_w = warp_id == ea.constant(world_size - 1, type=i32) + end_w_if = scf.IfOp(is_last_w, results_=[i32], has_else=True) + with ir.InsertionPoint(end_w_if.then_block): + scf.YieldOp([ea.constant(num_packs, type=i32)]) + with ir.InsertionPoint(end_w_if.else_block): + scf.YieldOp([start_w + ea.constant(part_p, type=i32)]) + end_w = end_w_if.results[0] + + idx_s1 = start_w + tid_pack + loop_s1 = scf.WhileOp([i32, i32], [idx_s1, stride_pack]) + bs1 = ir.Block.create_at_start(loop_s1.before, [i32, i32]) + as1 = ir.Block.create_at_start(loop_s1.after, [i32, i32]) + with ir.InsertionPoint(bs1): + cur = bs1.arguments[0] + cond = _u(cur) < end_w + scf.ConditionOp(cond, [cur, bs1.arguments[1]]) + with ir.InsertionPoint(as1): + cur = as1.arguments[0] + stride_s1 = as1.arguments[1] + raw = load_v4i32(inp_ptr_i64 + cur.extui(i64) * ea.constant(16, type=i64)) + rel_idx = cur - start_w + dst_off = rank_i32 * ea.constant(part_p, type=i32) + rel_idx + dst_tmp = load_device_ptr(tmp_ptrs_i64, warp_id) + tmp_addr = dst_tmp + dst_off.extui(i64) * ea.constant(16, type=i64) + is_tmp_null = dst_tmp == ea.constant(0, type=i64) + tmp_low4 = tmp_addr & ea.constant(0xF, type=i64) + is_tmp_misaligned = tmp_low4 != ea.constant(0, type=i64) + bad_tmp_addr = is_tmp_null | is_tmp_misaligned + if_tmp_ok = scf.IfOp(bad_tmp_addr, results_=[], has_else=True) + with ir.InsertionPoint(if_tmp_ok.then_block): + scf.YieldOp([]) + with ir.InsertionPoint(if_tmp_ok.else_block): + store_v4i32(tmp_addr, raw) + scf.YieldOp([]) + scf.YieldOp([cur + stride_s1, stride_s1]) + + # Signal all ranks that stage 1 is complete + _signal_start_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, + self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) + + # ---- Stage 2: reduce local tmp and write to REMOTE outputs ---- + part_p_i32 = ea.constant(part_p, type=i32) + loop_s2 = scf.WhileOp([i32, i32], [tid_pack, stride_pack]) + bs2 = ir.Block.create_at_start(loop_s2.before, [i32, i32]) + as2 = ir.Block.create_at_start(loop_s2.after, [i32, i32]) + with ir.InsertionPoint(bs2): + cur = bs2.arguments[0] + cond = _u(cur) < part_p_i32 + scf.ConditionOp(cond, [cur, bs2.arguments[1]]) + with ir.InsertionPoint(as2): + cur = as2.arguments[0] + stride_s2 = as2.arguments[1] + + src_off = warp_id * ea.constant(part_p, type=i32) + cur + load_addr = tmp_out_i64 + src_off.extui(i64) * ea.constant(16, type=i64) + is_tmpout_null = tmp_out_i64 == ea.constant(0, type=i64) + load_low4 = load_addr & ea.constant(0xF, type=i64) + is_load_misaligned = load_low4 != ea.constant(0, type=i64) + bad_load_addr = is_tmpout_null | is_load_misaligned + raw_if = 
scf.IfOp(bad_load_addr, results_=[v4i32], has_else=True) + with ir.InsertionPoint(raw_if.then_block): + scf.YieldOp([ea.constant_vector(0, v4i32)]) + with ir.InsertionPoint(raw_if.else_block): + scf.YieldOp([load_v4i32(load_addr)]) + raw = raw_if.results[0] + + sm_idx = ea.index_cast(idx, lane_i32) + smem_ptr.store(raw, [sm_idx]) + gpu.barrier() + + warp_id_local = _u(lane_i32) >> ea.constant(log2_tnum, type=i32) + lane_id_local = lane_i32 - warp_id_local * ea.constant(tnum_gpu, type=i32) + + raw_vals = [] + for wi in range_constexpr(world_size): + sm_i_idx = ea.index_cast(idx, ea.constant(wi * tnum_gpu, type=i32) + lane_id_local) + raw_vals.append(smem_ptr.load([sm_i_idx])) + + acc = None + for wi in range_constexpr(world_size): + raw_i = raw_vals[wi] + if is_f32: + vf = raw_i.bitcast(v4f32) + acc = vf if acc is None else acc + vf + else: + v16 = ev.bitcast(v8half, raw_i) + v32 = v16.extf(v8f32) + acc = v32 if acc is None else acc + v32 + if is_f32: + out_raw = acc.bitcast(v4i32) + else: + out_raw = ev.bitcast(v4i32, acc.truncf(v8half)) + + dst_out_off = rank_i32 * ea.constant(part_p, type=i32) + cur + dst_byte_off = dst_out_off.extui(i64) * ea.constant(16, type=i64) + + # Each warp writes its reduced partition directly to the target + # output via flat_store_dwordx4 nt. The nt hint bypasses L1/L2 + # and works for all memory types (including IPC-mapped addresses). + dst_ptr = out_ptrs_arr[0] + for w in range_constexpr(1, world_size): + is_warp_w = warp_id_local == ea.constant(w, type=i32) + dst_ptr = ea.select(is_warp_w, out_ptrs_arr[w], dst_ptr) + out_addr = dst_ptr + dst_byte_off + is_out_null = dst_ptr == ea.constant(0, type=i64) + out_low4 = out_addr & ea.constant(0xF, type=i64) + is_out_misaligned = out_low4 != ea.constant(0, type=i64) + bad_out_addr = is_out_null | is_out_misaligned + if_out_ok = scf.IfOp(bad_out_addr, results_=[], has_else=True) + with ir.InsertionPoint(if_out_ok.then_block): + scf.YieldOp([]) + with ir.InsertionPoint(if_out_ok.else_block): + store_v4i32_nt(out_addr, out_raw) + scf.YieldOp([]) + + scf.YieldOp([cur + stride_s2, stride_s2]) + + _signal_end_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, + self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) + + # ----------------------------------------------------------------------- + # Host launchers (@flyc.jit) + # ----------------------------------------------------------------------- + + @flyc.jit + def run_1stage_arr( + rank: Int32, + grid_x: Int32, + self_sg: Int64, + sg_ptrs: Int64, + in_ptrs: Int64, + out_ptr: Int64, + stream: Stream = Stream(None), + ): + allreduce_1stage_arr(rank, self_sg, sg_ptrs, in_ptrs, out_ptr).launch( + grid=(grid_x, 1, 1), + block=(threads, 1, 1), + stream=stream, + ) + + @flyc.jit + def run_2stage_arr( + rank: Int32, + grid_x: Int32, + self_sg: Int64, + sg_ptrs: Int64, + in_ptrs: Int64, + tmp_ptrs: Int64, + out_ptr: Int64, + stream: Stream = Stream(None), + ): + """Launch 2-stage allreduce (arr variant, CUDAGraph-compatible).""" + allreduce_2stage_arr(rank, self_sg, sg_ptrs, in_ptrs, tmp_ptrs, out_ptr).launch( + grid=(grid_x, 1, 1), + block=(threads, 1, 1), + stream=stream, + ) + + @flyc.jit + def run_2stage_write_mode( + rank: Int32, + grid_x: Int32, + self_sg: Int64, + sg_ptrs: Int64, + inp_ptr: Int64, + out_ptrs: Int64, + tmp_ptrs: Int64, + stream: Stream = Stream(None), + ): + """Launch 2-stage write-mode allreduce (large tensors).""" + allreduce_2stage_write_mode(rank, self_sg, sg_ptrs, inp_ptr, out_ptrs, tmp_ptrs).launch( + grid=(grid_x, 1, 1), + 
block=(threads, 1, 1), + stream=stream, + ) + + # Unique function names per (N, dtype_str, world_size, threads) to prevent + # file-cache collisions (N is baked into kernel body, not the cache key). + _suffix = f"_N{N}_{dtype_str}_ws{world_size}_t{threads}" + run_1stage_arr.func.__name__ = f"run_1stage_arr{_suffix}" + run_2stage_arr.func.__name__ = f"run_2stage_arr{_suffix}" + run_2stage_write_mode.func.__name__ = f"run_2stage_write_mode{_suffix}" + + return { + "run_1stage_arr": run_1stage_arr, + "run_2stage_arr": run_2stage_arr, + "run_2stage_write_mode": run_2stage_write_mode, + } diff --git a/kernels/hgemm_ar.py b/kernels/hgemm_ar.py new file mode 100644 index 000000000..c252ae838 --- /dev/null +++ b/kernels/hgemm_ar.py @@ -0,0 +1,985 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +import torch +import functools +import numpy as np +import torch.nn.functional as F +import torch.distributed as dist +from abc import ABC, abstractmethod + +import flydsl +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.expr.typing import T, Int32, Int64, Stream +from flydsl.expr import range_constexpr, arith, vector, gpu, rocdl +from flydsl._mlir import ir +from flydsl.runtime.device import get_rocm_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr +from flydsl.compiler.kernel_function import CompilationContext +from flydsl._mlir.dialects import llvm, fly, memref, scf +from flydsl.compiler.protocol import fly_values +from flydsl.expr.buffer_ops import _unwrap_value + +from kernels.custom_all_reduce import init_custom_ar, FlyDSLAllreduce +from kernels.custom_all_reduce_kernel import _signal_start_sync, _signal_end_sync, load_device_ptr, select_by_index, load_v4i32, store_v4i32, store_v4i32_nt +from kernels.tensor_shim import get_dtype_in_kernel, GTensor, STensor, _to_raw +fm_fast = arith.FastMathFlags.fast + + +SPLIT_K_COUNTER_MAX_LEN = 128 + + +def swizzle_xor16(row, col_in_bytes, k_blocks16): + return col_in_bytes ^ ((row % k_blocks16) * 16) + + +class WmmaHalfBase(ABC): + @abstractmethod + def __init__(self, dtype: str): + pass + + @abstractmethod + def __call__(self, a_frag, b_frag, c_frag): + pass + + +class WmmaHalf_m16n16k16(WmmaHalfBase): + WMMA_M = 16 + WMMA_N = 16 + WMMA_K = 16 + WMMA_A_FRAG_VALUES = 4 + WMMA_B_FRAG_VALUES = 4 + WMMA_C_FRAG_VALUES = 4 + + def __init__(self, dtype: str): + self.dtype = dtype + + def __call__(self, a_frag, b_frag, c_frag): + if self.dtype == 'bf16': + a_frag_vi16 = vector.bitcast(T.vec(self.WMMA_A_FRAG_VALUES, T.i16), a_frag) + b_frag_vi16 = vector.bitcast(T.vec(self.WMMA_B_FRAG_VALUES, T.i16), b_frag) + c_frag_new = rocdl.mfma_f32_16x16x16bf16_1k(T.f32x4, [a_frag_vi16, b_frag_vi16, c_frag, 0, 0, 0]) + return c_frag_new + else: + c_frag_new = rocdl.mfma_f32_16x16x16f16(T.vec(self.WMMA_C_FRAG_VALUES, T.f32), [a_frag, b_frag, c_frag, 0, 0, 0]) + return c_frag_new + + +class WmmaHalf_m16n16k32(WmmaHalfBase): + WMMA_M = 16 + WMMA_N = 16 + WMMA_K = 32 + WMMA_A_FRAG_VALUES = 8 + WMMA_B_FRAG_VALUES = 8 + WMMA_C_FRAG_VALUES = 4 + + def __init__(self, dtype: str): + self.dtype = dtype + + def __call__(self, a_frag, b_frag, c_frag): + if self.dtype == 'bf16': + c_frag_new = rocdl.mfma_f32_16x16x32_bf16(T.vec(self.WMMA_C_FRAG_VALUES, T.f32), a_frag, b_frag, c_frag, 0, 0, 0).res + return c_frag_new + else: + c_frag_new = rocdl.mfma_f32_16x16x32_f16(T.vec(self.WMMA_C_FRAG_VALUES, T.f32), a_frag, b_frag, c_frag, 0, 0, 0).res + return c_frag_new + + +class OnlineScheduler: + def __init__(self, 
total_signals: int, init_count: int = 0): + self.total_signals = total_signals + self.current_signal_id = init_count + self.remaining = init_count + + def release(self, count: int): + count = min(count, self.total_signals - self.current_signal_id) + self.current_signal_id += count + self.remaining += count + + def consume(self, count: int): + count = min(count, self.remaining) + self.remaining -= count + return count + + +@functools.lru_cache(maxsize=1024) +def compile_hgemm_ar_kernel( + world_size: int, + dtype: str, + n: int, + k: int, + TILE_M: int = 128, + TILE_N: int = 128, + TILE_K: int = 64, + SPLIT_K: int = 1, + BLOCK_M_WARPS: int = 1, + BLOCK_N_WARPS: int = 4, + B_PRE_SHUFFLE: bool = False, + B_TO_LDS: bool = False, + USE_CROSS_DEVICE_ATOMIC: bool = False, +): + IS_SPLIT_K = SPLIT_K > 1 + BLOCK_K = TILE_K + assert (k % SPLIT_K == 0) and (k // SPLIT_K >= 1) + ks = k // SPLIT_K + assert (ks % BLOCK_K == 0) and (ks // BLOCK_K >= 1) + assert BLOCK_K >= 32 + if B_PRE_SHUFFLE == True: + B_TO_LDS = False + assert B_TO_LDS == False + GPU_ARCH = get_rocm_arch() + if GPU_ARCH == 'gfx942': + WMMA_IMPL = WmmaHalf_m16n16k16(dtype) + DMA_BYTES = 4 + MFMA_PER_WARP_K = 2 + ASYNC_COPY = False + else: + WMMA_IMPL = WmmaHalf_m16n16k32(dtype) + DMA_BYTES = 16 + MFMA_PER_WARP_K = 1 + ASYNC_COPY = True + + # Fixed parameters: + WARP_SIZE = 64 + DTYPE_BYTES = 2 + LDG_VEC_SIZE = 8 + STAGES = 2 + + # Propagated parameters: + WMMA_M = WMMA_IMPL.WMMA_M + WMMA_N = WMMA_IMPL.WMMA_N + WMMA_K = WMMA_IMPL.WMMA_K + WMMA_A_FRAG_VALUES = WMMA_IMPL.WMMA_A_FRAG_VALUES + WMMA_B_FRAG_VALUES = WMMA_IMPL.WMMA_B_FRAG_VALUES + WMMA_C_FRAG_VALUES = WMMA_IMPL.WMMA_C_FRAG_VALUES + WARP_ATOM_M = WMMA_M + WARP_ATOM_N = WMMA_N + WARP_ATOM_K = WMMA_K * MFMA_PER_WARP_K + BLOCK_K_LOOPS = ks // BLOCK_K + WARP_K_STEPS = BLOCK_K // WARP_ATOM_K + assert (BLOCK_K % WARP_ATOM_K == 0) and (WARP_K_STEPS >= 1) + BLOCK_THREADS = BLOCK_M_WARPS * BLOCK_N_WARPS * WARP_SIZE + WARP_M_STEPS = TILE_M // BLOCK_M_WARPS // WARP_ATOM_M + WARP_N_STEPS = TILE_N // BLOCK_N_WARPS // WARP_ATOM_N + assert (WARP_M_STEPS >= 1) and (WARP_N_STEPS >= 1) + assert TILE_M % (BLOCK_M_WARPS * WARP_ATOM_M) == 0 + assert TILE_N % (BLOCK_N_WARPS * WARP_ATOM_N) == 0 + WARP_M = WARP_M_STEPS * WARP_ATOM_M + WARP_N = WARP_N_STEPS * WARP_ATOM_N + BLOCK_M = BLOCK_M_WARPS * WARP_M + BLOCK_N = BLOCK_N_WARPS * WARP_N + assert (n >= BLOCK_N) and (n % BLOCK_N == 0) + BLOCK_MK_SIZE = BLOCK_M * BLOCK_K + BLOCK_NK_SIZE = BLOCK_N * BLOCK_K + BLOCK_MN_SIZE = BLOCK_M * BLOCK_N + LDG_A_X_THREADS = BLOCK_K // LDG_VEC_SIZE + LDG_B_X_THREADS = BLOCK_K // LDG_VEC_SIZE + LDG_C_X_THREADS = BLOCK_N // LDG_VEC_SIZE + BLOCK_VECS = LDG_VEC_SIZE * BLOCK_THREADS + LDG_REG_A_COUNT = BLOCK_MK_SIZE // BLOCK_VECS + LDG_REG_B_COUNT = BLOCK_NK_SIZE // BLOCK_VECS + LDG_REG_C_COUNT = BLOCK_MN_SIZE // BLOCK_VECS + assert (LDG_REG_A_COUNT >= 1) and (LDG_REG_B_COUNT >= 1) and (LDG_REG_C_COUNT >= 1) + assert (BLOCK_MK_SIZE % BLOCK_VECS == 0) + assert (BLOCK_NK_SIZE % BLOCK_VECS == 0) + assert (BLOCK_MN_SIZE % BLOCK_VECS == 0) + BLOCK_K_BYTES = BLOCK_K * DTYPE_BYTES + + # LDS parameters: + allocator = SmemAllocator(None, arch=GPU_ARCH, global_sym_name="smem") + smem_a_offset = allocator._align(allocator.ptr, 16) + AS_BYTES = STAGES * BLOCK_M * BLOCK_K * DTYPE_BYTES + AS_BYTES = max(AS_BYTES, BLOCK_M * BLOCK_N * DTYPE_BYTES) + allocator.ptr = smem_a_offset + AS_BYTES + SMEM_USE = AS_BYTES + if B_TO_LDS: + smem_b_offset = allocator._align(allocator.ptr, 16) + allocator.ptr = smem_b_offset + STAGES * BLOCK_N * 
BLOCK_K * DTYPE_BYTES + SMEM_USE += STAGES * BLOCK_N * BLOCK_K * DTYPE_BYTES + assert SMEM_USE <= 163840 + LDG_ASYNC_VEC_SIZE = DMA_BYTES // DTYPE_BYTES + LDG_A_X_THREADS_AS = BLOCK_K // LDG_ASYNC_VEC_SIZE + LDG_REG_A_COUNT_AS = BLOCK_MK_SIZE // LDG_ASYNC_VEC_SIZE // BLOCK_THREADS + LDG_B_X_THREADS_AS = BLOCK_K // LDG_ASYNC_VEC_SIZE + LDG_REG_B_COUNT_AS = BLOCK_NK_SIZE // LDG_ASYNC_VEC_SIZE // BLOCK_THREADS + + KERNEL_NAME = f"hgemm_ar_{dtype}_{BLOCK_M}x{BLOCK_N}x{BLOCK_K}_S{STAGES}TN" + KERNEL_NAME += "_NA" if not ASYNC_COPY else "_AS" + if B_PRE_SHUFFLE: + KERNEL_NAME += "_BP" + if IS_SPLIT_K: + KERNEL_NAME += f"_SPK{SPLIT_K}" + USE_CROSS_DEVICE_ATOMIC = True + if B_TO_LDS: + KERNEL_NAME += f"_BS" + if USE_CROSS_DEVICE_ATOMIC: + KERNEL_NAME += f"_CATOM" + + @flyc.kernel + def hgemm_ar_kernel( + rank: Int32, + self_sg: Int64, + sg_ptrs: Int64, + tmp_ptrs: Int64, + out_ptrs: Int64, + C: fx.Tensor, + A: fx.Tensor, + B: fx.Tensor, + m: fx.Int32, + COUNTER: fx.Tensor, + signal_state: fx.Int32, + ): + dtype_ = get_dtype_in_kernel(dtype) + _ptr_type = ir.Type.parse("!llvm.ptr<1>") + _i64_type = T.i64 + c_zero_d = arith.constant(0.0, type=dtype_) + acc_init = arith.constant_vector(0.0, T.vec(WMMA_C_FRAG_VALUES, T.f32)) + + A_ = GTensor(A, dtype=dtype_, shape=(-1, k)) + B_ = GTensor(B, dtype=dtype_, shape=(n, k)) + C_ = GTensor(C, dtype=dtype_, shape=(-1, n)) + base_ptr = allocator.get_base() + smem_a_ptr = SmemPtr(base_ptr, smem_a_offset, dtype_, shape=(STAGES * BLOCK_M * BLOCK_K,)) + as_ = STensor(smem_a_ptr, dtype_, shape=(STAGES, BLOCK_M, BLOCK_K)) + if B_TO_LDS: + smem_b_ptr = SmemPtr(base_ptr, smem_b_offset, dtype_, shape=(STAGES * BLOCK_N * BLOCK_K,)) + bs_ = STensor(smem_b_ptr, dtype_, shape=(STAGES, BLOCK_N, BLOCK_K)) + smem_c_ptr = SmemPtr(base_ptr, smem_a_offset, dtype_, shape=(BLOCK_M * BLOCK_N,)) + cs_ = STensor(smem_c_ptr, dtype_, shape=(BLOCK_M, BLOCK_N)) + if B_PRE_SHUFFLE: + # origin: n // WARP_ATOM_N, WARP_ATOM_N, k // WARP_ATOM_K, WARP_ATOM_K // LDG_VEC_SIZE, LDG_VEC_SIZE + SHUFFLED_B_ = GTensor(B, dtype=dtype_, shape=( + n // WARP_ATOM_N, k // WARP_ATOM_K, WARP_ATOM_K // LDG_VEC_SIZE, WARP_ATOM_N, LDG_VEC_SIZE)) + if IS_SPLIT_K: + COUNTER_ = GTensor(COUNTER, dtype=T.i32, shape=(-1,)) + + tid = fx.Int32(fx.thread_idx.x) + wid = tid // WARP_SIZE + w_tid = tid % WARP_SIZE + block_m_idx = fx.block_idx.x + block_n_idx = fx.block_idx.y + ks_idx = fx.Index(fx.block_idx.z) + ks_begin = arith.index_cast(T.i32, ks_idx * ks) + counter_idx = fx.Int32(signal_state * SPLIT_K_COUNTER_MAX_LEN) + fx.block_idx.x * fx.Int32(n // BLOCK_N) + fx.block_idx.y + + m_offset = fx.Index(block_m_idx * BLOCK_M) + n_offset = fx.Index(block_n_idx * BLOCK_N) + k_blocks16 = fx.Int32(BLOCK_K_BYTES // 16) + + warp_m_idx = wid // BLOCK_N_WARPS * WARP_M + warp_n_idx = wid % BLOCK_N_WARPS * WARP_N + ldmatrix_a_m_idx = w_tid % WMMA_M + ldmatrix_a_k_vec_idx = w_tid // WMMA_M * WMMA_A_FRAG_VALUES * MFMA_PER_WARP_K + ldmatrix_b_n_idx = w_tid % WMMA_N + ldmatrix_b_k_vec_idx = w_tid // WMMA_N * WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K + A_FRAGS_LEN = WARP_K_STEPS * WARP_M_STEPS + B_FRAGS_LEN = WARP_K_STEPS * WARP_N_STEPS + C_FRAGS_LEN = WARP_M_STEPS * WARP_N_STEPS + c_frags = [acc_init] * C_FRAGS_LEN + + # communication vars + bid_linear = (fx.block_idx.x * (n // BLOCK_N) + fx.block_idx.y) * SPLIT_K + fx.block_idx.z + rank_i32 = _unwrap_value(rank) + self_sg_i64 = _unwrap_value(self_sg) + sg_ptrs_i64 = _unwrap_value(sg_ptrs) + tmp_ptrs_i64 = _unwrap_value(tmp_ptrs) + out_ptrs_i64 = _unwrap_value(out_ptrs) + bid_i32 = 
arith.index_cast(T.i32, fx.Index(bid_linear)) + lane_i32 = arith.index_cast(T.i32, fx.Index(fx.thread_idx.x)) + sgs = [load_device_ptr(sg_ptrs_i64, arith.constant(i, type=T.i32)) for i in range(8)] + tmp_ptrs_arr = [load_device_ptr(tmp_ptrs_i64, arith.constant(i, type=T.i32)) for i in range(8)] + out_ptrs_arr = [load_device_ptr(out_ptrs_i64, arith.constant(i, type=T.i32)) for i in range(8)] + self_tmp_ptr = select_by_index(arith.constant(0, type=T.i32), tmp_ptrs_arr) + self_out_ptr = select_by_index(rank_i32, out_ptrs_arr) + + def zero_c(): + # zero c + cond_ks0 = arith.cmpi(arith.CmpIPredicate.eq, ks_idx, fx.Index(0)) + cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) + with ir.InsertionPoint(cond_ks0_if.then_block): + zero_vec = vector.broadcast(T.vec(LDG_VEC_SIZE, dtype_), c_zero_d) + vec_i32x4 = vector.bitcast(T.i32x4, zero_vec) + for i in range_constexpr(LDG_REG_C_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = global_tid // LDG_C_X_THREADS + n_local_idx = global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE + row_idx = m_offset + fx.Index(m_local_idx) + cond_boundary = arith.cmpi(arith.CmpIPredicate.ult, row_idx, fx.Index(m)) + cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) + with ir.InsertionPoint(cond_boundary_if.then_block): + linear_byte_offset = C_.linear_offset((row_idx, n_offset + n_local_idx)) * DTYPE_BYTES + byte_offset_i64 = arith.index_cast(T.i64, linear_byte_offset) + if USE_CROSS_DEVICE_ATOMIC: + store_v4i32_nt(self_out_ptr + byte_offset_i64, vec_i32x4) + else: + store_v4i32_nt(self_tmp_ptr + byte_offset_i64, vec_i32x4) + # C_.vec_store((row_idx, n_offset + n_local_idx), zero_vec, LDG_VEC_SIZE) + scf.YieldOp([]) + scf.YieldOp([]) + if IS_SPLIT_K: + rocdl.sched_barrier(0) + gpu.barrier() + # write flag + cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) + with ir.InsertionPoint(cond_ks0_if.then_block): + is_t0_cond = arith.cmpi(arith.CmpIPredicate.eq, fx.Index(tid), fx.Index(0)) + is_t0_cond_if = scf.IfOp(is_t0_cond, results_=[], has_else=False) + with ir.InsertionPoint(is_t0_cond_if.then_block): + counter_base_ptr = fly.extract_aligned_pointer_as_index(_ptr_type, fly_values(COUNTER)[0]) + counter_base_ptr = llvm.PtrToIntOp(_i64_type, counter_base_ptr).result + counter_byte_offset = arith.index_cast(T.i64, fx.Index(counter_idx) * fx.Index(4)) + counter_ptr = llvm.AddOp(counter_base_ptr, counter_byte_offset, llvm.IntegerOverflowFlags(0)).result + counter_ptr = llvm.IntToPtrOp(_ptr_type, counter_ptr).result + counter_ptr_v = counter_ptr._value if hasattr(counter_ptr, "_value") else counter_ptr + llvm.InlineAsmOp(None, [], "buffer_wbl2 sc0 sc1", "", has_side_effects=True) + llvm.InlineAsmOp( + None, [counter_ptr_v, arith.constant(1, type=T.i32)], + "global_store_dword $0, $1, off sc0 sc1", "v,v", + has_side_effects=True, + ) + rocdl.s_waitcnt(0) + scf.YieldOp([]) + scf.YieldOp([]) + rocdl.sched_barrier(0) + gpu.barrier() + # zero signal + cond_ks0_if = scf.IfOp(cond_ks0, results_=[], has_else=False) + with ir.InsertionPoint(cond_ks0_if.then_block): + clean_cond = arith.cmpi(arith.CmpIPredicate.ult, fx.Index(tid), fx.Index(SPLIT_K_COUNTER_MAX_LEN)) + clean_cond_if = scf.IfOp(clean_cond, results_=[], has_else=False) + with ir.InsertionPoint(clean_cond_if.then_block): + clean_counter_idx = fx.Int32(((signal_state + 2) % 3) * SPLIT_K_COUNTER_MAX_LEN) + fx.Index(tid) + COUNTER_[fx.Index(clean_counter_idx)] = arith.constant(0, type=T.i32) + scf.YieldOp([]) + scf.YieldOp([]) + rocdl.sched_barrier(0) + gpu.barrier() + + def 
split_k_barrier(): + if True: + init_cur = arith.constant(0, type=T.i32) + w = scf.WhileOp([T.i32], [init_cur]) + before = ir.Block.create_at_start(w.before, [T.i32]) + after = ir.Block.create_at_start(w.after, [T.i32]) + with ir.InsertionPoint(before): + cur = before.arguments[0] + need_wait = arith.CmpIOp(arith.CmpIPredicate.eq, cur, arith.constant(0, type=T.i32)).result + scf.ConditionOp(need_wait, [cur]) + with ir.InsertionPoint(after): + counter_base_ptr = fly.extract_aligned_pointer_as_index(_ptr_type, fly_values(COUNTER)[0]) + counter_base_ptr = llvm.PtrToIntOp(_i64_type, counter_base_ptr).result + counter_byte_offset = arith.index_cast(T.i64, fx.Index(counter_idx) * fx.Index(4)) + counter_ptr = llvm.AddOp(counter_base_ptr, counter_byte_offset, llvm.IntegerOverflowFlags(0)).result + counter_ptr = llvm.IntToPtrOp(_ptr_type, counter_ptr).result + counter_ptr_v = counter_ptr._value if hasattr(counter_ptr, "_value") else counter_ptr + data = llvm.InlineAsmOp( + T.i32, [counter_ptr_v], + "global_load_dword $0, $1, off sc1", "=v,v", + has_side_effects=True, + ).result + rocdl.s_waitcnt(0) + scf.YieldOp([data]) + gpu.barrier() + + def ldg_a(k_offset): + vecs = [] + for i in range_constexpr(LDG_REG_A_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = global_tid // LDG_A_X_THREADS + k_local_idx = global_tid % LDG_A_X_THREADS * LDG_VEC_SIZE + row_idx = m_offset + fx.Index(m_local_idx) + safe_row_idx = arith.select( + arith.cmpi(arith.CmpIPredicate.ult, row_idx, fx.Index(m)), + row_idx, + fx.Index(0), + ) + col_idx = fx.Index(k_offset + k_local_idx) + vec = A_.vec_load((safe_row_idx, col_idx), LDG_VEC_SIZE) + vecs.append(vec) + return vecs + + def sts_a(vecs, lds_stage): + for i in range_constexpr(LDG_REG_A_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = global_tid // LDG_A_X_THREADS + k_local_idx = global_tid % LDG_A_X_THREADS * LDG_VEC_SIZE + col_in_bytes = k_local_idx * DTYPE_BYTES + col_in_bytes = swizzle_xor16(m_local_idx, col_in_bytes, k_blocks16) + as_.vec_store((fx.Index(lds_stage), m_local_idx, col_in_bytes // DTYPE_BYTES), vecs[i], LDG_VEC_SIZE) + + def ldg_b(k_offset): + vecs = [] + for i in range_constexpr(LDG_REG_B_COUNT): + global_tid = BLOCK_THREADS * i + tid + n_local_idx = global_tid // LDG_B_X_THREADS + k_local_idx = global_tid % LDG_B_X_THREADS * LDG_VEC_SIZE + row_idx = n_offset + fx.Index(n_local_idx) + safe_row_idx = arith.select( + arith.cmpi(arith.CmpIPredicate.ult, row_idx, fx.Index(n)), + row_idx, + fx.Index(0), + ) + col_idx = fx.Index(k_offset + k_local_idx) + vec = B_.vec_load((safe_row_idx, col_idx), LDG_VEC_SIZE) + vecs.append(vec) + return vecs + + def sts_b(vecs, lds_stage): + for i in range_constexpr(LDG_REG_B_COUNT): + global_tid = BLOCK_THREADS * i + tid + n_local_idx = global_tid // LDG_B_X_THREADS + k_local_idx = global_tid % LDG_B_X_THREADS * LDG_VEC_SIZE + col_in_bytes = k_local_idx * DTYPE_BYTES + col_in_bytes = swizzle_xor16(n_local_idx, col_in_bytes, k_blocks16) + bs_.vec_store((fx.Index(lds_stage), n_local_idx, col_in_bytes // DTYPE_BYTES), vecs[i], LDG_VEC_SIZE) + + def ldg_sts_a_async(k_offset, lds_stage): + for i in range_constexpr(LDG_REG_A_COUNT_AS): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = global_tid // LDG_A_X_THREADS_AS + k_local_idx = global_tid % LDG_A_X_THREADS_AS * LDG_ASYNC_VEC_SIZE + col_in_bytes = k_local_idx * DTYPE_BYTES + col_in_bytes = swizzle_xor16(m_local_idx, col_in_bytes, k_blocks16) + row_idx = m_offset + fx.Index(m_local_idx) + safe_row_idx = arith.select( + 
arith.cmpi(arith.CmpIPredicate.ult, row_idx, fx.Index(m)), + row_idx, + fx.Index(0), + ) + col_idx = fx.Index(k_offset + col_in_bytes // DTYPE_BYTES) + # get offset + global_offset = A_.linear_offset((safe_row_idx, col_idx)) * DTYPE_BYTES + global_offset = arith.index_cast(T.i32, global_offset) + lds_offset = as_.linear_offset((fx.Index(lds_stage), m_local_idx, k_local_idx)) * DTYPE_BYTES + # get lds ptr + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_addr = memref.extract_aligned_pointer_as_index(as_.memptr) + lds_offset + lds_addr_ = rocdl.readfirstlane(T.i64, arith.index_cast(T.i64, lds_addr)) + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_addr_) + # dma copy + rocdl.raw_ptr_buffer_load_lds( + A_.rsrc, + lds_ptr, + arith.constant(DMA_BYTES, type=T.i32), + global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(1, type=T.i32), + ) + + def ldg_sts_b_async(k_offset, lds_stage): + for i in range_constexpr(LDG_REG_B_COUNT_AS): + global_tid = BLOCK_THREADS * i + tid + n_local_idx = global_tid // LDG_B_X_THREADS_AS + k_local_idx = global_tid % LDG_B_X_THREADS_AS * LDG_ASYNC_VEC_SIZE + col_in_bytes = k_local_idx * DTYPE_BYTES + col_in_bytes = swizzle_xor16(n_local_idx, col_in_bytes, k_blocks16) + row_idx = n_offset + fx.Index(n_local_idx) + safe_row_idx = arith.select( + arith.cmpi(arith.CmpIPredicate.ult, row_idx, fx.Index(n)), + row_idx, + fx.Index(0), + ) + col_idx = fx.Index(k_offset + col_in_bytes // DTYPE_BYTES) + # get offset + global_offset = B_.linear_offset((safe_row_idx, col_idx)) * DTYPE_BYTES + global_offset = arith.index_cast(T.i32, global_offset) + lds_offset = bs_.linear_offset((fx.Index(lds_stage), n_local_idx, k_local_idx)) * DTYPE_BYTES + # get lds ptr + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_addr = memref.extract_aligned_pointer_as_index(bs_.memptr) + lds_offset + lds_addr_ = rocdl.readfirstlane(T.i64, arith.index_cast(T.i64, lds_addr)) + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_addr_) + # dma copy + rocdl.raw_ptr_buffer_load_lds( + B_.rsrc, + lds_ptr, + arith.constant(DMA_BYTES, type=T.i32), + global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(1, type=T.i32), + ) + + def lds_matrix_a(lds_stage): + s = fx.Index(lds_stage) + a_frags = [0] * (WARP_K_STEPS * WARP_M_STEPS) + for ii in range_constexpr(WARP_M_STEPS): + warp_atom_m_idx = warp_m_idx + ii * WARP_ATOM_M + for kk in range_constexpr(WARP_K_STEPS): + warp_atom_k_idx = kk * WARP_ATOM_K + row = warp_atom_m_idx + ldmatrix_a_m_idx + col_in_bytes = (warp_atom_k_idx + ldmatrix_a_k_vec_idx) * DTYPE_BYTES + col_in_bytes = swizzle_xor16(row, col_in_bytes, k_blocks16) + vec = as_.vec_load((s, row, col_in_bytes // DTYPE_BYTES), WMMA_A_FRAG_VALUES * MFMA_PER_WARP_K) + a_frags[kk * WARP_M_STEPS + ii] = vec + return a_frags + + def lds_matrix_b(lds_stage): + s = fx.Index(lds_stage) + b_frags = [0] * (WARP_K_STEPS * WARP_N_STEPS) + for ii in range_constexpr(WARP_N_STEPS): + warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N + for kk in range_constexpr(WARP_K_STEPS): + warp_atom_k_idx = kk * WARP_ATOM_K + row = warp_atom_n_idx + ldmatrix_b_n_idx + col_in_bytes = (warp_atom_k_idx + ldmatrix_b_k_vec_idx) * DTYPE_BYTES + col_in_bytes = swizzle_xor16(row, col_in_bytes, k_blocks16) + vec = bs_.vec_load((s, row, col_in_bytes // DTYPE_BYTES), WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K) + b_frags[kk * WARP_N_STEPS + ii] = vec + return b_frags + + def ldg_matrix_b(k_offset): + vecs = [] + b_n_intra_base = ldmatrix_b_n_idx + b_k_intra_vec = 
ldmatrix_b_k_vec_idx // LDG_VEC_SIZE + b_n0_base = n_offset // WARP_ATOM_N + warp_n_idx // WARP_ATOM_N + b_k0_base = k_offset // WARP_ATOM_K + for kk in range_constexpr(WARP_K_STEPS): + b_k0 = b_k0_base + kk + for ii in range_constexpr(WARP_N_STEPS): + b_n0 = b_n0_base + ii + if not B_PRE_SHUFFLE: + warp_atom_n_idx = warp_n_idx + ii * WARP_ATOM_N + warp_atom_k_idx = kk * WARP_ATOM_K + n_idx = n_offset + warp_atom_n_idx + ldmatrix_b_n_idx + k_idx = k_offset + warp_atom_k_idx + ldmatrix_b_k_vec_idx + vec = B_.vec_load((n_idx, k_idx), WMMA_B_FRAG_VALUES * MFMA_PER_WARP_K) + vecs.append(vec) + else: + b_n_intra = b_n_intra_base # idx_1 + vec = SHUFFLED_B_.vec_load((b_n0, b_k0, b_k_intra_vec, b_n_intra, 0), LDG_VEC_SIZE) + vecs.append(vec) + return vecs + + def block_mma_sync(a_frags, b_frags, c_frags): + # wmma + c_frags_new = [cx for cx in c_frags] + for kk in range_constexpr(WARP_K_STEPS): + for ii in range_constexpr(WARP_M_STEPS): + a_frag = a_frags[kk * WARP_M_STEPS + ii] + for jj in range_constexpr(WARP_N_STEPS): + b_frag = b_frags[kk * WARP_N_STEPS + jj] + if MFMA_PER_WARP_K == 2: + # split a + a_i64x2 = vector.bitcast(T.i64x2, a_frag) + a0_i64 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) + a1_i64 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) + a_v0 = vector.bitcast(T.f16x4, vector.from_elements(T.vec(1, T.i64), [a0_i64])) + a_v1 = vector.bitcast(T.f16x4, vector.from_elements(T.vec(1, T.i64), [a1_i64])) + # split b + b_i64x2 = vector.bitcast(T.i64x2, b_frag) + b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) + b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) + b_v0 = vector.bitcast(T.f16x4, vector.from_elements(T.vec(1, T.i64), [b0_i64])) + b_v1 = vector.bitcast(T.f16x4, vector.from_elements(T.vec(1, T.i64), [b1_i64])) + # wmma + c_idx = ii * WARP_N_STEPS + jj + acc_in = c_frags_new[c_idx] + acc_mid = WMMA_IMPL(a_v0, b_v0, acc_in) + c_frags_new[c_idx] = WMMA_IMPL(a_v1, b_v1, acc_mid) + elif MFMA_PER_WARP_K == 1: + c_idx = ii * WARP_N_STEPS + jj + c_frags_new[c_idx] = WMMA_IMPL(a_frag, b_frag, c_frags_new[c_idx]) + else: + raise NotImplementedError(f"MFMA_PER_WARP_K={MFMA_PER_WARP_K} not supported") + return c_frags_new + + zero_c() + + if B_TO_LDS: + + ldg_sts_a_async(ks_begin, 0) + ldg_sts_b_async(ks_begin, 0) + gpu.barrier() + def hot_loop_scheduler(): + MFMA_TOTAL = WARP_K_STEPS * WARP_M_STEPS * WARP_N_STEPS * MFMA_PER_WARP_K + LDG_REG_A_COUNT_ = LDG_REG_A_COUNT_AS if ASYNC_COPY else LDG_REG_A_COUNT + LDG_REG_B_COUNT_ = LDG_REG_B_COUNT_AS if ASYNC_COPY else LDG_REG_B_COUNT + LDG_TOTAL = LDG_REG_A_COUNT_ + LDG_REG_B_COUNT_ + LDS_TOTAL = WARP_K_STEPS * (WARP_M_STEPS + WARP_N_STEPS) + LD_TOTAL = LDG_TOTAL + LDS_TOTAL + # ================ Ordered ================ + for i in range_constexpr(WARP_K_STEPS * WARP_M_STEPS): + rocdl.sched_dsrd(1) # lds_matrix_a current + for i in range_constexpr(WARP_K_STEPS * WARP_N_STEPS): + rocdl.sched_dsrd(1) # lds_matrix_b current + for i in range_constexpr(LDG_REG_A_COUNT_): + rocdl.sched_vmem(1) # ldg_sts_a_async next + rocdl.sched_mfma(2) + for i in range_constexpr(LDG_REG_B_COUNT_): + rocdl.sched_vmem(1) # ldg_sts_b_async next + rocdl.sched_mfma(2) + REMAINING = max(MFMA_TOTAL - (LDG_REG_A_COUNT_ + LDG_REG_B_COUNT_) * 2, 0) + for i in range_constexpr(REMAINING): + rocdl.sched_mfma(1) + # ================ Reordered ================ + rocdl.sched_barrier(0) + UNROLL = 8 + init_state = [ks_begin, arith.constant(0, index=True)] + c_frags + for bki, state in 
range(0, BLOCK_K_LOOPS - 1, UNROLL, init=init_state): + k_offset = state[0] + current_stage = fx.Index(state[1]) + c_frags = state[2 : 2 + C_FRAGS_LEN] + for unroll_i in range_constexpr(UNROLL): + cond = arith.cmpi(arith.CmpIPredicate.ult, fx.Index(bki + unroll_i), fx.Index(BLOCK_K_LOOPS - 1)) + cond_if = scf.IfOp(cond, results_=[T.vec(WMMA_C_FRAG_VALUES, T.f32)] * C_FRAGS_LEN + [T.index, T.i32], has_else=True) + with ir.InsertionPoint(cond_if.then_block): + next_stage = 1 - current_stage + a_frags = lds_matrix_a(current_stage) + b_frags = lds_matrix_b(current_stage) + ldg_sts_a_async(k_offset + BLOCK_K, next_stage) + ldg_sts_b_async(k_offset + BLOCK_K, next_stage) + c_frags_new = block_mma_sync(a_frags, b_frags, c_frags) + hot_loop_scheduler() + gpu.barrier() + k_offset_next = k_offset + fx.Int32(BLOCK_K) + current_stage_next = 1 - current_stage + scf.YieldOp(c_frags_new + [_to_raw(current_stage_next), k_offset_next]) + with ir.InsertionPoint(cond_if.else_block): + scf.YieldOp(c_frags + [_to_raw(current_stage), k_offset]) + c_frags = [cond_if.results[i] for i in range(C_FRAGS_LEN)] + current_stage = cond_if.results[C_FRAGS_LEN] + k_offset = cond_if.results[C_FRAGS_LEN + 1] + results = yield [k_offset, current_stage] + c_frags + current_stage = results[1] + c_frags = results[2 : 2 + C_FRAGS_LEN] + a_frags = lds_matrix_a(current_stage) + b_frags = lds_matrix_b(current_stage) + c_frags = block_mma_sync(a_frags, b_frags, c_frags) + + else: + + sts_a(ldg_a(ks_begin), 0) + gpu.barrier() + a_frags = lds_matrix_a(0) + b_frags = ldg_matrix_b(ks_begin) + rocdl.sched_barrier(0) + def hot_loop_scheduler(): + MFMA_TOTAL = WARP_K_STEPS * WARP_M_STEPS * WARP_N_STEPS * MFMA_PER_WARP_K + LDG_REG_A_COUNT_ = LDG_REG_A_COUNT_AS if ASYNC_COPY else LDG_REG_A_COUNT + LDG_TOTAL = LDG_REG_A_COUNT_ + WARP_K_STEPS * WARP_N_STEPS + mfma_ = OnlineScheduler(MFMA_TOTAL, MFMA_TOTAL) + ldg_ = OnlineScheduler(LDG_TOTAL, LDG_TOTAL) + # ================ Ordered ================ + # for i in range_constexpr(LDG_REG_A_COUNT_AS or LDG_REG_A_COUNT): + # rocdl.sched_vmem(1) # ldg_sts_a_async next + # for i in range_constexpr(WARP_K_STEPS * WARP_N_STEPS): + # rocdl.sched_vmem(1) # ldg_matrix_b next + # for i in range_constexpr(WARP_K_STEPS * WARP_M_STEPS * WARP_N_STEPS * MFMA_PER_WARP_K): + # rocdl.sched_mfma(1) + # ================ Reordered ================ + if ASYNC_COPY: + AVG_MFMA_COUNT = (MFMA_TOTAL + LDG_TOTAL - 1) // LDG_TOTAL + for i in range_constexpr(LDG_TOTAL): + rocdl.sched_vmem(ldg_.consume(1)) + rocdl.sched_mfma(mfma_.consume(AVG_MFMA_COUNT)) + else: + LDG_STS_TOTAL = LDG_TOTAL + LDG_REG_A_COUNT_ + AVG_MFMA_COUNT = (MFMA_TOTAL + LDG_STS_TOTAL - 1) // LDG_STS_TOTAL + for i in range_constexpr(LDG_TOTAL): + rocdl.sched_vmem(ldg_.consume(1)) + rocdl.sched_mfma(mfma_.consume(AVG_MFMA_COUNT)) + for i in range_constexpr(LDG_REG_A_COUNT_): + rocdl.sched_dswr(1) + rocdl.sched_mfma(mfma_.consume(AVG_MFMA_COUNT)) + rocdl.sched_barrier(0) + init_state = [ks_begin, arith.constant(0, index=True)] + c_frags + a_frags + b_frags + for bki, state in range(1, BLOCK_K_LOOPS, init=init_state): + k_offset = state[0] + current_stage = fx.Index(state[1]) + next_stage = 1 - current_stage + c_frags = state[2 : 2 + C_FRAGS_LEN] + a_frags = state[2 + C_FRAGS_LEN : 2 + C_FRAGS_LEN + A_FRAGS_LEN] + b_frags = state[2 + C_FRAGS_LEN + A_FRAGS_LEN : 2 + C_FRAGS_LEN + A_FRAGS_LEN + B_FRAGS_LEN] + if ASYNC_COPY: + ldg_sts_a_async(k_offset + BLOCK_K, next_stage) + else: + a_regs_next = ldg_a(k_offset + BLOCK_K) + b_frags_next = ldg_matrix_b(k_offset 
+ BLOCK_K) + c_frags = block_mma_sync(a_frags, b_frags, c_frags) + if not ASYNC_COPY: + sts_a(a_regs_next, next_stage) + hot_loop_scheduler() + gpu.barrier() + a_frags_next = lds_matrix_a(next_stage) + k_offset = k_offset + fx.Int32(BLOCK_K) + rocdl.sched_barrier(0) + results = yield [k_offset, next_stage] + c_frags + a_frags_next + b_frags_next + c_frags = results[2 : 2 + C_FRAGS_LEN] + a_frags = results[2 + C_FRAGS_LEN : 2 + C_FRAGS_LEN + A_FRAGS_LEN] + b_frags = results[2 + C_FRAGS_LEN + A_FRAGS_LEN : 2 + C_FRAGS_LEN + A_FRAGS_LEN + B_FRAGS_LEN] + c_frags = block_mma_sync(a_frags, b_frags, c_frags) + + # write to lds + stmatrix_c_m_vec_idx = w_tid // WMMA_N * WMMA_C_FRAG_VALUES + stmatrix_c_n_idx = w_tid % WMMA_N + gpu.barrier() + for ii in range_constexpr(WARP_M_STEPS): + warp_atom_m_idx = warp_m_idx + ii * WARP_ATOM_M + for jj in range_constexpr(WARP_N_STEPS): + warp_atom_n_idx = warp_n_idx + jj * WARP_ATOM_N + for kk in range_constexpr(WMMA_C_FRAG_VALUES): + lds_m_idx = fx.Index(warp_atom_m_idx + stmatrix_c_m_vec_idx + kk) + lds_n_idx = fx.Index(warp_atom_n_idx + stmatrix_c_n_idx) + val = vector.extract(c_frags[ii * WARP_N_STEPS + jj], static_position=[kk], dynamic_position=[]) + cs_[lds_m_idx, lds_n_idx] = val.truncf(dtype_) + + # write back to global + + if IS_SPLIT_K: + split_k_barrier() + else: + gpu.barrier() + + if USE_CROSS_DEVICE_ATOMIC: + + # NOTE: Low performance for atomic impl + + _signal_start_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) + + for i in range_constexpr(LDG_REG_C_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = fx.Index(global_tid // LDG_C_X_THREADS) + n_local_idx = fx.Index(global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE) + m_global_idx = m_offset + m_local_idx + cond_boundary = arith.cmpi(arith.CmpIPredicate.ult, m_global_idx, fx.Index(m)) + cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) + with ir.InsertionPoint(cond_boundary_if.then_block): + for wi in range_constexpr(world_size): + out_memref = select_by_index(arith.constant(wi, type=T.i32), out_ptrs_arr) + linear_bytes_offset = C_.linear_offset((m_global_idx, n_offset + n_local_idx)) * DTYPE_BYTES + pk_val = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) + # split to vec2s + vec2_ty = T.vec(2, dtype_) + for vec_idx in range_constexpr(LDG_VEC_SIZE // 2): + e0 = vector.extract(pk_val, static_position=[vec_idx * 2], dynamic_position=[]) + e1 = vector.extract(pk_val, static_position=[vec_idx * 2 + 1], dynamic_position=[]) + pair = vector.from_elements(vec2_ty, [e0, e1]) + pair_byte_offset = arith.index_cast(T.i64, linear_bytes_offset + fx.Index(vec_idx * 2 * DTYPE_BYTES)) + pair_addr_i64 = llvm.AddOp(out_memref, pair_byte_offset, llvm.IntegerOverflowFlags(0)).result + pair_ptr = llvm.IntToPtrOp(_ptr_type, pair_addr_i64).result + pair_ptr_v = pair_ptr._value if hasattr(pair_ptr, "_value") else pair_ptr + pair_v = pair._value if hasattr(pair, "_value") else pair + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + pair_ptr_v, + pair_v, + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=4, + ) + # C_.vec_store((m_global_idx, n_offset + n_local_idx), vec, LDG_VEC_SIZE) + scf.YieldOp([]) + + else: + + # FIXME: For some reasons splitk doesn't work + + for i in range_constexpr(LDG_REG_C_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = fx.Index(global_tid // LDG_C_X_THREADS) + n_local_idx = fx.Index(global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE) + m_global_idx = m_offset + 
m_local_idx + n_global_idx = n_offset + n_local_idx + cond_boundary = arith.cmpi(arith.CmpIPredicate.ult, m_global_idx, fx.Index(m)) + cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) + with ir.InsertionPoint(cond_boundary_if.then_block): + vec = cs_.vec_load((m_local_idx, n_local_idx), LDG_VEC_SIZE) + vec_i32x4 = vector.bitcast(T.i32x4, vec) + linear_bytes_offset = C_.linear_offset((m_global_idx, n_global_idx)) * DTYPE_BYTES + store_v4i32_nt(self_tmp_ptr + arith.index_cast(T.i64, linear_bytes_offset), vec_i32x4) + scf.YieldOp([]) + + _signal_start_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) + + for i in range_constexpr(LDG_REG_C_COUNT): + global_tid = BLOCK_THREADS * i + tid + m_local_idx = fx.Index(global_tid // LDG_C_X_THREADS) + n_local_idx = fx.Index(global_tid % LDG_C_X_THREADS * LDG_VEC_SIZE) + m_global_idx = m_offset + m_local_idx + cond_boundary = arith.cmpi(arith.CmpIPredicate.ult, m_global_idx, fx.Index(m)) + cond_boundary_if = scf.IfOp(cond_boundary, results_=[], has_else=False) + with ir.InsertionPoint(cond_boundary_if.then_block): + linear_bytes_offset = C_.linear_offset((m_global_idx, n_offset + n_local_idx)) * DTYPE_BYTES + final_vec = arith.constant_vector(0.0, T.vec(LDG_VEC_SIZE, T.f32)) + for wi in range_constexpr(world_size): + vec_v4i32 = load_v4i32(select_by_index(arith.constant(wi, type=T.i32), tmp_ptrs_arr) + arith.index_cast(T.i64, linear_bytes_offset)) + vec = vector.bitcast(T.vec(LDG_VEC_SIZE, dtype_), vec_v4i32) + final_vec = final_vec + vec.extf(T.vec(LDG_VEC_SIZE, T.f32)) + final_vec_i32x4 = vector.bitcast(T.i32x4, final_vec.truncf(T.vec(LDG_VEC_SIZE, dtype_))) + store_v4i32(self_out_ptr + arith.index_cast(T.i64, linear_bytes_offset), final_vec_i32x4) + scf.YieldOp([]) + + # _signal_end_sync(lane_i32=lane_i32, rank_i32=rank_i32, bid_i32=bid_i32, self_sg_i64=self_sg_i64, sgs_i64=sgs, ngpus=world_size) + + return + + @flyc.jit + def launch_hgemm_ar_kernel( + rank: Int32, + self_sg: Int64, + sg_ptrs: Int64, + tmp_ptrs: Int64, + out_ptrs: Int64, + C: fx.Tensor, + A: fx.Tensor, + B: fx.Tensor, + m: fx.Int32, + COUNTER: fx.Tensor, + signal_state: fx.Int32, + stream: fx.Stream = fx.Stream(None), + ): + allocator.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator.finalize() + + bm = (m + BLOCK_M - 1) // BLOCK_M + bn = n // BLOCK_N + hgemm_ar_kernel._func.__name__ = KERNEL_NAME + hgemm_ar_kernel( + rank, self_sg, sg_ptrs, tmp_ptrs, out_ptrs, + C, A, B, m, COUNTER, signal_state).launch(grid=(bm, bn, SPLIT_K), block=(BLOCK_THREADS, 1, 1), stream=stream) + + return launch_hgemm_ar_kernel + + +def hgemm_shuffle_b(x, layout=(16, 16), k_steps=2): + x_shape = x.shape + VEC_SIZE = 16 // x.element_size() + BN = layout[0] + BK = layout[1] * k_steps + assert x.shape[-2] % BN == 0, f"{x.shape[-2]} % {BN} == {x.shape[-2] % BN }" + assert x.shape[-1] % BK == 0, f"{x.shape[-1]} % {BK} == {x.shape[-1] % BK }" + x = x.view(-1, x.shape[-2] // BN, BN, x.shape[-1] // BK, BK // VEC_SIZE, VEC_SIZE) + x = x.permute(0, 1, 3, 4, 2, 5).contiguous() + x = x.view(*x_shape) + x.is_shuffled = True + return x + + +def get_default_kwargs(m, n, k): + kwargs = { + 'TILE_M': 128, + 'TILE_N': 128, + 'TILE_K': 64, + 'SPLIT_K': 1, + 'BLOCK_M_WARPS': 2, + 'BLOCK_N_WARPS': 2, + 'B_PRE_SHUFFLE': False, + 'B_TO_LDS': True, + 'USE_CROSS_DEVICE_ATOMIC': False, + } + if m == 2048 and n == 2048 and k == 2048: + kwargs['TILE_M'] = 128 + kwargs['TILE_N'] = 
64
+        kwargs['TILE_K'] = 64
+        kwargs['SPLIT_K'] = 1
+    elif m <= 32 and n == 384 and k == 7168:
+        kwargs['TILE_M'] = 32
+        kwargs['TILE_N'] = 64
+        kwargs['TILE_K'] = 64
+        kwargs['SPLIT_K'] = 8
+    elif m <= 32 and n == 7168 and k == 2048:
+        kwargs['TILE_M'] = 32
+        kwargs['TILE_N'] = 64
+        kwargs['TILE_K'] = 128
+        kwargs['SPLIT_K'] = 1
+    elif m <= 32 and n == 384 and k == 16384:
+        kwargs['TILE_M'] = 32
+        kwargs['TILE_N'] = 64
+        kwargs['TILE_K'] = 128
+        kwargs['SPLIT_K'] = 16
+    return kwargs
+
+
+selections = {
+    'TILE_M': [16, 32, 48, 64, 96, 128],
+    'TILE_N': [64, 128, 256],
+    'TILE_K': [64, 128],
+    'SPLIT_K': [1, 2, 4, 8, 16],
+}
+
+
+SPLIT_K_GLOBAL_SEMAPHORE = {}
+SPLIT_K_GLOBAL_SEMAPHORE_STATE = {}
+def hgemm_ar_(
+    world_size: int,
+    rank: Int32,
+    self_sg: Int64,
+    sg_ptrs: Int64,
+    tmp_ptrs: Int64,
+    out_ptrs: Int64,
+    c: torch.Tensor,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    shuffle_b: bool = False,
+    hgemm_kwargs: dict = None,
+    stream: torch.cuda.Stream = None,
+):
+    global SPLIT_K_COUNTER_MAX_LEN
+    global SPLIT_K_GLOBAL_SEMAPHORE
+    global SPLIT_K_GLOBAL_SEMAPHORE_STATE
+    # Resolve defaults at call time: a mutable {} default would be shared
+    # across calls, and a torch.cuda.current_stream() default would be
+    # evaluated once at import instead of per call.
+    if hgemm_kwargs is None:
+        hgemm_kwargs = {}
+    if stream is None:
+        stream = torch.cuda.current_stream()
+    if SPLIT_K_GLOBAL_SEMAPHORE.get(stream, None) is None:
+        SPLIT_K_GLOBAL_SEMAPHORE[stream] = torch.zeros(
+            (3 * SPLIT_K_COUNTER_MAX_LEN,), dtype=torch.int32, device=a.device)
+        SPLIT_K_GLOBAL_SEMAPHORE_STATE[stream] = int(0)
+    signal_state = SPLIT_K_GLOBAL_SEMAPHORE_STATE[stream]
+    k = a.shape[-1]
+    a = a.view(-1, k)
+    m = a.shape[0]
+    n = b.shape[0]
+    assert b.shape[1] == k
+    c = c.view(-1, n)
+    assert c.shape[0] == m
+    kwargs = get_default_kwargs(m, n, k)
+    kwargs.update(hgemm_kwargs)
+    if a.dtype == torch.half:
+        exe = compile_hgemm_ar_kernel(world_size, 'f16', n, k, **kwargs)
+    elif a.dtype == torch.bfloat16:
+        exe = compile_hgemm_ar_kernel(world_size, 'bf16', n, k, **kwargs)
+    else:
+        raise NotImplementedError(f"unsupported dtype: {a.dtype}")
+    if kwargs['B_PRE_SHUFFLE'] and shuffle_b:
+        b = hgemm_shuffle_b(b)
+    semaphore = SPLIT_K_GLOBAL_SEMAPHORE[stream]
+    bm = (m + kwargs['TILE_M'] - 1) // kwargs['TILE_M']
+    bn = n // kwargs['TILE_N']
+    assert bm * bn * kwargs['SPLIT_K'] <= min(80, SPLIT_K_COUNTER_MAX_LEN)
+    exe(rank, self_sg, sg_ptrs, tmp_ptrs, out_ptrs, c, a, b, m, semaphore, signal_state, stream)
+    if kwargs['SPLIT_K'] > 1:
+        SPLIT_K_GLOBAL_SEMAPHORE_STATE[stream] = (signal_state + 1) % 3
+
+
+class GEMMARBackend(FlyDSLAllreduce):
+    def hgemm_ar_fusion(self, a, b, c, kwargs: dict = None):
+        world_size = self.world_size
+        m, k = a.shape
+        n = b.shape[0]
+        bytes_mn = m * n * 2
+        assert bytes_mn <= self.max_size, f"Output {bytes_mn}B exceeds max_size {self.max_size}B"
+        rank = Int32(self.rank)
+        self_sg = Int64(self._self_sg)
+        sg_ptrs = Int64(int(self._gpu_sg_ptrs_array.data_ptr()))
+        tmp_ptrs = Int64(int(self._gpu_tmp_ptrs_array.data_ptr()))
+        self._graph_use_write_mode = False
+        if self._IS_CAPTURING:
+            if torch.cuda.is_current_stream_capturing():
+                self._graph_inp = None
+                self._graph_out = c.view(-1)
+                self._graph_bytes_n = bytes_mn
+                out_ptrs = Int64(int(self._gpu_graph_out_ptrs_array.data_ptr()))
+                hgemm_ar_(world_size, rank, self_sg, sg_ptrs, tmp_ptrs, out_ptrs, c, a, b, hgemm_kwargs=kwargs)
+                return c
+            else:
+                out_ptrs = Int64(int(self._gpu_output_buffer_ptrs_array.data_ptr()))
+                hgemm_ar_(world_size, rank, self_sg, sg_ptrs, tmp_ptrs, out_ptrs, c, a, b, hgemm_kwargs=kwargs)
+                c.view(-1).view(torch.uint8)[:bytes_mn].copy_(self.output_buffer[:bytes_mn])
+                return c
+        else:
+            out_ptrs = Int64(int(self._gpu_output_buffer_ptrs_array.data_ptr()))
+            hgemm_ar_(world_size, rank, self_sg, sg_ptrs, tmp_ptrs, out_ptrs, c, a, b, 
hgemm_kwargs=kwargs) + c.view(-1).view(torch.uint8)[:bytes_mn].copy_(self.output_buffer[:bytes_mn]) + return c diff --git a/tests/kernels/test_hgemm_ar.py b/tests/kernels/test_hgemm_ar.py new file mode 100644 index 000000000..9bf174c30 --- /dev/null +++ b/tests/kernels/test_hgemm_ar.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python3 + +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +import os +import sys +import logging +import flydsl.compiler as flyc + +import torch +import torch.multiprocessing as mp +import torch.distributed as dist +import torch.nn.functional as F +import pytest +import pandas as pd + +from dataclasses import dataclass +from torch.profiler import profile, ProfilerActivity +from kernels.custom_all_reduce import init_custom_ar + + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +_PYFLYDSL_SRC = os.path.join(_REPO_ROOT, "flydsl", "src") +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) +if _PYFLYDSL_SRC not in sys.path: + sys.path.insert(0, _PYFLYDSL_SRC) + +from kernels.hgemm_ar import hgemm_ar_, hgemm_shuffle_b, GEMMARBackend +from tests.test_common import run_perftest, verify_output +from flydsl.runtime.device import get_rocm_arch + +logging.basicConfig(level=logging.INFO) +ARCH = str(get_rocm_arch()) + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) + +try: + import aiter + HAS_AITER = True +except Exception: + HAS_AITER = False + + +DEFAULT_BENCH_ITERS = 50 +DEFAULT_BENCH_WARMUP = 3 + + +@dataclass +class Args: + dtype: torch.dtype + m: int + n: int + k: int + tile_m: int + tile_n: int + tile_k: int + split_k: int + num_devices: int + parts: int + nsamples: int + + +def init_world(device_id, num_devices, parts, port=24327): + torch.cuda.set_device(device_id) + dist.init_process_group( + backend="nccl", + init_method=f"tcp://127.0.0.1:{port}", + rank=device_id, + world_size=num_devices, + device_id=device_id, + ) + group_size = num_devices // parts + group_id = device_id // group_size + group_ranks = list(range(group_id * group_size, (group_id + 1) * group_size)) + group = dist.new_group(ranks=group_ranks) + print(f"[init_world] device_id:{device_id}, group_ranks:{group_ranks}", flush=True) + return group + + +def create_inputs(args): + group_size = args.num_devices // args.parts + inputs = [] + for part in range(args.parts): + for rank in range(group_size): + device_id = part * group_size + rank + for i in range(args.nsamples): + a = torch.empty((args.m, args.k), dtype=args.dtype, device=f'cuda:{device_id}') + a.uniform_(-1, 1) + b = torch.empty((args.n, args.k), dtype=args.dtype, device=f'cuda:{device_id}') + b.uniform_(-1, 1) + inputs.append([a, b]) + return inputs + + +def create_outputs(args): + group_size = args.num_devices // args.parts + outputs = [] + for part in range(args.parts): + for rank in range(group_size): + device_id = part * group_size + rank + for i in range(args.nsamples): + c = torch.randn((args.m, args.n), dtype=args.dtype, device=f"cuda:{device_id}") + outputs.append(c) + return outputs + + +def ref_worker(device_id, num_devices, parts, nsamples, inputs, outputs): + warmup_iter = 4 + group = init_world(device_id, num_devices, parts) + for i in range(warmup_iter): + input = inputs[device_id * nsamples + i] + output = outputs[device_id * nsamples + i] + F.linear(input[0], input[1], out=output) + dist.all_reduce(output, group=group) + torch.cuda.synchronize() + dist.barrier(group=group) + 
with profile( + activities=[ProfilerActivity.CUDA], + profile_memory=False, + with_stack=True, + with_modules=True + ) as prof: + for i in range(warmup_iter, nsamples): + input = inputs[device_id * nsamples + i] + output = outputs[device_id * nsamples + i] + F.linear(input[0], input[1], out=output) + dist.all_reduce(output, group=group) + torch.cuda.synchronize() + table = prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1) + print(table) + dist.barrier(group=group) + dist.destroy_process_group() + + +def ref_func(args, inputs, outputs): + mp.spawn( + ref_worker, + args=(args.num_devices, args.parts, args.nsamples, inputs, outputs), + nprocs=args.num_devices, + join=True + ) + + +def worker(device_id, num_devices, parts, nsamples, inputs, outputs, kwargs): + warmup_iter = 4 + group = init_world(device_id, num_devices, parts) + rank = dist.get_rank(group=group) + world_size = dist.get_world_size(group=group) + meta = torch.empty((0,), device=device_id, dtype=torch.int8) + rank_data = inputs[device_id * nsamples] + handles = [torch.empty((1,), device="cpu", dtype=torch.uint8) for _ in range(world_size)] + offsets = [0 for _ in range(world_size)] + fa = init_custom_ar(meta, rank_data, handles, offsets, rank=rank, backend=GEMMARBackend) + for i in range(warmup_iter): + input = inputs[device_id * nsamples + i] + output = outputs[device_id * nsamples + i] + fa.hgemm_ar_fusion(input[0], input[1], output, kwargs) + torch.cuda.synchronize() + dist.barrier(group=group) + with profile( + activities=[ProfilerActivity.CUDA], + profile_memory=False, + with_stack=True, + with_modules=True + ) as prof: + for i in range(warmup_iter, nsamples): + input = inputs[device_id * nsamples + i] + output = outputs[device_id * nsamples + i] + fa.hgemm_ar_fusion(input[0], input[1], output, kwargs) + torch.cuda.synchronize() + table = prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1) + print(table) + dist.barrier(group=group) + dist.destroy_process_group() + + +def func(args, inputs, outputs, kwargs): + mp.spawn( + worker, + args=(args.num_devices, args.parts, args.nsamples, inputs, outputs, kwargs), + nprocs=args.num_devices, + join=True + ) + + +# @pytest.mark.parametrize("dtype", ["fp16", "bf16"]) +@pytest.mark.parametrize("dtype", ["bf16"]) +@pytest.mark.parametrize( + "m, n, k, TILE_M, TILE_N, TILE_K, SPLIT_K, world_size", + [ + (32, 7168, 2048, 32, 128, 128, 1, 4), + # (32, 384, 7168, 32, 64, 64, 8, 4), + # (4, 384, 7168, 32, 64, 64, 8, 8), + # (65, 1024, 8192, 64, 64, 128, 2, 2), + ] +) +# @pytest.mark.parametrize("test_graph", [ +# pytest.param(False, id="eager"), +# pytest.param(True, id="graph"), +# ]) +@pytest.mark.parametrize("test_graph", [ + pytest.param(False, id="eager"), +]) +def test_mfma_flyc_hgemm_ar( + dtype, + m, n, k, + TILE_M, TILE_N, TILE_K, SPLIT_K, world_size, + *, + test_graph, + bench_iters: int = DEFAULT_BENCH_ITERS, + bench_warmup: int = DEFAULT_BENCH_WARMUP, +): + global ARCH + if not (ARCH in ["gfx950", "gfx942"]): + pytest.skip(f"Skip hgemm test: ARCH={ARCH}") + + print("=" * 80) + print( + f"[flyc] MFMA {dtype.upper()} HGEMM+Allreduce Test" + ) + print("=" * 80) + + bench_iters = max(2, int(bench_iters)) + bench_warmup = int(bench_warmup) + + torch_dtype = torch.bfloat16 if dtype == "bf16" else torch.float16 + + args = Args( + dtype=torch_dtype, + m=m, + n=n, + k=k, + tile_m=TILE_M, + tile_n=TILE_N, + tile_k=TILE_K, + split_k=SPLIT_K, + num_devices=world_size, + parts=1, + nsamples=50, + ) + kwargs = { + 'TILE_M': args.tile_m, + 'TILE_N': 
args.tile_n,
+        'TILE_K': args.tile_k,
+        'SPLIT_K': args.split_k,
+    }
+
+    inputs = create_inputs(args)
+    outputs = create_outputs(args)
+    ref_outputs = create_outputs(args)
+    func(args, inputs, outputs, kwargs)
+    ref_func(args, inputs, ref_outputs)
+    max_diff_global = -1.0
+    for output, ref_output in zip(outputs, ref_outputs):
+        # Strict torch.allclose is intentionally not asserted; the bf16/fp16
+        # GEMM + cross-device reduction only needs to stay within the
+        # k-scaled bound checked below.
+        maxdiff_out = (output - ref_output).abs().max().item()
+        max_diff_global = max(max_diff_global, maxdiff_out)
+    print(f"max_diff_global:{max_diff_global}")
+    # Tolerance scales with k: each output element accumulates k products.
+    assert max_diff_global < 1e-3 * args.k
+
+    print("===================== [REF] =====================")
+    ref_func(args, inputs, ref_outputs)
+
+    print("===================== [FLYDSL] =====================")
+    func(args, inputs, outputs, kwargs)
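+
+
+if __name__ == "__main__":
+    # Standalone entry point (a sketch; pytest is the intended driver).
+    # Runs the default parametrization above: bf16, m=32, n=7168, k=2048,
+    # TILE 32x128x128, SPLIT_K=1 on world_size=4. mp.spawn requires a
+    # __main__ guard, so this also documents how to launch the test directly.
+    if torch.cuda.device_count() >= 4:
+        test_mfma_flyc_hgemm_ar(
+            "bf16", 32, 7168, 2048, 32, 128, 128, 1, 4,
+            test_graph=False,
+        )
+    else:
+        print("Skipping: the default hgemm_ar configuration needs >= 4 GPUs.")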