diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel deleted file mode 160000 index 345a56c55e..0000000000 --- a/3rdparty/composable_kernel +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 345a56c55ed2a1bd25618c3d2a3994cd73460581 diff --git a/aiter/ops/triton/_gluon_kernels/quant/__init__.py b/aiter/ops/triton/_gluon_kernels/quant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aiter/ops/triton/_gluon_kernels/quant/fuse_mxfp4_quant.py b/aiter/ops/triton/_gluon_kernels/quant/fuse_mxfp4_quant.py new file mode 100644 index 0000000000..e00c736eba --- /dev/null +++ b/aiter/ops/triton/_gluon_kernels/quant/fuse_mxfp4_quant.py @@ -0,0 +1,302 @@ +import triton +from triton.experimental import gluon +from aiter.ops.triton._triton_kernels.quant.fused_mxfp4_quant import _mxfp4_quant_op +from triton.experimental.gluon import language as gl + + +@gluon.jit +def _rmsnorm_op( + row, + weights, + n_cols, + epsilon, +): + + row_norm = row * row + row_norm = gl.sum(row_norm, axis=-1, keep_dims=True) + norm_factor = gl.rsqrt((row_norm / n_cols) + epsilon) + + rms_norm = row * norm_factor * weights + return rms_norm + + +@triton.heuristics( + { + "EVEN_M_N": lambda args: args["M"] % args["BLOCK_SIZE_M"] == 0 + and args["N1"] % (args["BLOCK_SIZE_N"]) == 0, + } +) +@gluon.jit +def _gluon_fused_rms_mxfp4_quant( + x1_ptr, + w1_ptr, + x2_ptr, + w2_ptr, + res1_ptr, + out1_fp4_ptr, + out1_bs_ptr, + out2_ptr, + out_res1_ptr, + out1_ptr, + eps1, + eps2, + M, + N1, + N2, + x1_stride_m, + x2_stride_m, + res1_stride_m, + out1_fp4_stride_m, + out1_bs_stride_m, + out1_bs_stride_n, + out2_stride_m, + out_res1_stride_m, + out1_stride_m, + BLOCK_SIZE_M: gl.constexpr, + BLOCK_SIZE_N: gl.constexpr, + BLOCK_SIZE_N2: gl.constexpr, + MXFP4_QUANT_BLOCK_SIZE: gl.constexpr, + HAS_SECOND_INPUT: gl.constexpr, + FIRST_INPUT_RES: gl.constexpr, + FIRST_INPUT_OUT: gl.constexpr, + SCALE_N: gl.constexpr, + SCALE_M_PAD: gl.constexpr, + SCALE_N_PAD: gl.constexpr, + SHUFFLE: gl.constexpr, + SHUFFLE_PAD: gl.constexpr, + EVEN_M_N: gl.constexpr, +): + start_pid = gl.program_id(0) + # get number of programs to determine is 1 or 2 passes + num_pid_m = gl.cdiv(M, BLOCK_SIZE_M) + + # create block layouts + gLayout2D: gl.constexpr = gl.BlockedLayout( + [1, 2], # sizePerThread + [1, 32], # threadsPerWarp + [1, 4], # warpsPerCTA + [1, 0], # order + ) + + gLayoutM: gl.constexpr = gl.SliceLayout(1, gLayout2D) + gLayoutN: gl.constexpr = gl.SliceLayout(0, gLayout2D) + + # 2D shared layout for matrix rows; 1D shared layout for weight vectors + sharedLayout2D: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[1, 0]) + sharedLayoutN: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, order=[0]) + + # Tensor descriptors for first input and its weights + x1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + x1_ptr, + [M, N1], + [x1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + + w1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + w1_ptr, + [N1], + [1], + [BLOCK_SIZE_N], + sharedLayoutN, + ) + + # Shared memory for first input and its weights + smemX1 = gl.allocate_shared_memory( + x1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + smemW1 = gl.allocate_shared_memory( + w1_ptr.dtype.element_ty, [BLOCK_SIZE_N], sharedLayoutN + ) + + # Tensor descriptor and shared memory for optional residual input + if FIRST_INPUT_RES: + res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + res1_ptr, + [M, N1], + [res1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + 
) + smemRes1 = gl.allocate_shared_memory( + res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + + # Second input path — programs with id >= num_pid_m handle x2 + if start_pid >= num_pid_m: + if HAS_SECOND_INPUT: + x2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + x2_ptr, + [M, N2], + [x2_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N2], + sharedLayout2D, + ) + w2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + w2_ptr, + [N2], + [1], + [BLOCK_SIZE_N2], + sharedLayoutN, + ) + smemX2 = gl.allocate_shared_memory( + x2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D + ) + smemW2 = gl.allocate_shared_memory( + w2_ptr.dtype.element_ty, [BLOCK_SIZE_N2], sharedLayoutN + ) + + start_pid -= num_pid_m + + # Load x2 and w2 in parallel then wait for both + gl.amd.gfx1250.tdm.async_load( + x2_desec, [start_pid * BLOCK_SIZE_M, 0], smemX2 + ) + gl.amd.gfx1250.tdm.async_load(w2_desec, [0], smemW2) + gl.amd.gfx1250.tdm.async_wait(0) + + x2 = smemX2.load(gLayout2D).to(gl.float32) + w2 = smemW2.load(gLayoutN).to(gl.float32) + w2 = w2.reshape(1, BLOCK_SIZE_N2) + w2 = gl.convert_layout(w2, gLayout2D) + norm2 = _rmsnorm_op(x2, w2, N2, eps2) + + # Store norm2 output via TDM + out2_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out2_ptr, + [M, N2], + [out2_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N2], + sharedLayout2D, + ) + smemOut2 = gl.allocate_shared_memory( + out2_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N2], sharedLayout2D + ) + smemOut2.store(norm2.to(out2_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store( + out2_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut2 + ) + gl.amd.gfx1250.tdm.async_wait(0) + return + + # First input path + NUM_QUANT_BLOCKS: gl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE + x_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M, gLayoutM) + + # Load x1 and optionally res1 in parallel, then wait + gl.amd.gfx1250.tdm.async_load(x1_desec, [start_pid * BLOCK_SIZE_M, 0], smemX1) + if FIRST_INPUT_RES: + gl.amd.gfx1250.tdm.async_load( + res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemRes1 + ) + gl.amd.gfx1250.tdm.async_wait(0) + + x1 = smemX1.load(gLayout2D).to(gl.float32) + + if FIRST_INPUT_RES: + res1_loaded = smemRes1.load(gLayout2D).to(gl.float32) + x1 = x1 + res1_loaded + + # Load w1 and wait + gl.amd.gfx1250.tdm.async_load(w1_desec, [0], smemW1) + gl.amd.gfx1250.tdm.async_wait(0) + + w1 = smemW1.load(gLayoutN).to(gl.float32) + w1 = w1.reshape(1, BLOCK_SIZE_N) + w1 = gl.convert_layout(w1, gLayout2D) + norm1 = _rmsnorm_op(x1, w1, N1, eps1) + + # Store unquantized output via TDM (optional) + if FIRST_INPUT_OUT: + out1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out1_ptr, + [M, N1], + [out1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + smemOut1 = gl.allocate_shared_memory( + out1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + smemOut1.store(norm1.to(out1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store( + out1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOut1 + ) + gl.amd.gfx1250.tdm.async_wait(0) + + out1_fp4, bs_e8m0 = _mxfp4_quant_op( + norm1, BLOCK_SIZE_N, BLOCK_SIZE_M, MXFP4_QUANT_BLOCK_SIZE + ) + out1_fp4 = gl.convert_layout(out1_fp4, gLayout2D) + + # out1_fp4 uses half-width (packed) offsets — keep as regular store + half_x_offs_n = gl.arange(0, BLOCK_SIZE_N // 2) + out_mask1 = (half_x_offs_n < (N1 // 2))[None, :] + if not EVEN_M_N: + out_mask1 = out_mask1 & (x_offs_m < M)[:, None] + gl.store( + out1_fp4_ptr + x_offs_m[:, None] * 
out1_fp4_stride_m + half_x_offs_n[None, :], + out1_fp4, + mask=out_mask1, + ) + + # shuffle + bs_offs_m = start_pid * BLOCK_SIZE_M + gl.arange(0, BLOCK_SIZE_M) + bs_offs_n = gl.arange(0, NUM_QUANT_BLOCKS) + num_bs_cols = (N1 + MXFP4_QUANT_BLOCK_SIZE - 1) // MXFP4_QUANT_BLOCK_SIZE + if SHUFFLE: + bs_offs_0 = bs_offs_m[:, None] // 32 + bs_offs_1 = bs_offs_m[:, None] % 32 + bs_offs_2 = bs_offs_1 % 16 + bs_offs_1 = bs_offs_1 // 16 + bs_offs_3 = bs_offs_n[None, :] // 8 + bs_offs_4 = bs_offs_n[None, :] % 8 + bs_offs_5 = bs_offs_4 % 4 + bs_offs_4 = bs_offs_4 // 4 + bs_offs = ( + bs_offs_1 + + bs_offs_4 * 2 + + bs_offs_2 * 2 * 2 + + bs_offs_5 * 2 * 2 * 16 + + bs_offs_3 * 2 * 2 * 16 * 4 + + bs_offs_0 * 2 * 16 * SCALE_N_PAD + ) + bs_mask_127 = (bs_offs_m < M)[:, None] & (bs_offs_n < num_bs_cols)[None, :] + bs_e8m0 = gl.where(bs_mask_127, bs_e8m0, 127) + else: + bs_offs = ( + bs_offs_m[:, None] * out1_bs_stride_m + + bs_offs_n[None, :] * out1_bs_stride_n + ) + + bs_mask = None + if not EVEN_M_N: + if not SHUFFLE_PAD: + bs_mask = (bs_offs_m < M)[:, None] & (bs_offs_n < SCALE_N)[None, :] + else: + bs_mask = (bs_offs_m < SCALE_M_PAD)[:, None] & (bs_offs_n < SCALE_N_PAD)[ + None, : + ] + + gl.store( + out1_bs_ptr + bs_offs, bs_e8m0.to(out1_bs_ptr.type.element_ty), mask=bs_mask + ) + + # Store residual output via TDM + if FIRST_INPUT_RES: + out_res1_desec = gl.amd.gfx1250.tdm.make_tensor_descriptor( + out_res1_ptr, + [M, N1], + [out_res1_stride_m, 1], + [BLOCK_SIZE_M, BLOCK_SIZE_N], + sharedLayout2D, + ) + smemOutRes1 = gl.allocate_shared_memory( + out_res1_ptr.dtype.element_ty, [BLOCK_SIZE_M, BLOCK_SIZE_N], sharedLayout2D + ) + smemOutRes1.store(x1.to(out_res1_ptr.dtype.element_ty)) + gl.amd.gfx1250.tdm.async_store( + out_res1_desec, [start_pid * BLOCK_SIZE_M, 0], smemOutRes1 + ) + gl.amd.gfx1250.tdm.async_wait(0) diff --git a/aiter/ops/triton/_triton_kernels/attention/unified_attention.py b/aiter/ops/triton/_triton_kernels/attention/unified_attention.py index 0caac96745..563156b69a 100644 --- a/aiter/ops/triton/_triton_kernels/attention/unified_attention.py +++ b/aiter/ops/triton/_triton_kernels/attention/unified_attention.py @@ -4,6 +4,7 @@ import triton.language as tl import torch from aiter.ops.triton.utils.types import e4m3_dtype +from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr float8_info = torch.finfo(e4m3_dtype) @@ -377,7 +378,22 @@ def kernel_unified_attention_2d( ) -@triton.jit +kernel_unified_attention_3d_repr = make_kernel_repr( + "kernel_unified_attention_3d", + [ + "num_query_heads", + "num_queries_per_kv", + "BLOCK_SIZE", + "TILE_SIZE", + "HEAD_SIZE", + "num_warps", + "num_stages", + "SHUFFLED_KV_CACHE", + ], +) + + +@triton.jit(repr=kernel_unified_attention_3d_repr) def kernel_unified_attention_3d( segm_output_ptr, # [num_tokens, num_query_heads, num_segments, head_size] @@ -423,12 +439,21 @@ def kernel_unified_attention_3d( num_seqs: tl.int32, BLOCK_M: tl.constexpr, # int NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + num_warps: tl.constexpr, # int + num_stages: tl.constexpr, # int ALL_DECODE: tl.constexpr = False, # bool + SHUFFLED_KV_CACHE: tl.constexpr = False, # bool ): q_block_global_idx = tl.program_id(0) kv_head_idx = tl.program_id(1) segm_idx = tl.program_id(2) + if SHUFFLED_KV_CACHE: + tl.static_assert( + TILE_SIZE == BLOCK_SIZE, + "TILE_SIZE must be equal to BLOCK_SIZE if SHUFFLED_KV_CACHE is True", + ) + # needed to use exp2 (exp2 -> exp conversion) RCP_LN2 = 1.4426950408889634 qk_scale = scale * RCP_LN2 @@ -462,6 +487,17 @@ def 
kernel_unified_attention_3d( offs_m = tl.arange(0, BLOCK_M) offs_d = tl.arange(0, HEAD_SIZE_PADDED) offs_t = tl.arange(0, TILE_SIZE) + + offs_k_t_shfl = None + offs_k_d_shfl = None + offs_v_t_shfl = None + offs_v_d_shfl = None + if SHUFFLED_KV_CACHE: + offs_k_t_shfl = tl.arange(0, TILE_SIZE // 16) + offs_k_d_shfl = tl.arange(0, HEAD_SIZE_PADDED * 16) + offs_v_t_shfl = tl.arange(0, TILE_SIZE * 16) + offs_v_d_shfl = tl.arange(0, HEAD_SIZE_PADDED // 16) + query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv query_offset_0 = cur_batch_in_all_start_index + query_pos @@ -547,34 +583,64 @@ def kernel_unified_attention_3d( min((segm_idx + 1) * tiles_per_segment, num_tiles), ): seq_offset = j * TILE_SIZE + offs_t + if TILE_SIZE == BLOCK_SIZE: tile_mask = tl.full((1,), 1, dtype=tl.int1) else: tile_mask = seq_offset < max_seq_prefix_len - physical_block_idx = tl.load( - block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE - ).to(tl.int64) + k_mask = None + v_mask = None + other = None + if SHUFFLED_KV_CACHE: + seq_offset_k_shfl = j * TILE_SIZE + offs_k_t_shfl * 16 + physical_block_idx_shfl = tl.load( + block_tables_ptr + block_table_offset + seq_offset_k_shfl // BLOCK_SIZE + ).to(tl.int64) + k_offset = ( + physical_block_idx_shfl[:, None] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + offs_k_t_shfl[:, None] * stride_k_cache_2 + + offs_k_d_shfl[None, :] * stride_k_cache_3 + ) - v_offset = ( - physical_block_idx[:, None] * stride_v_cache_0 - + kv_head_idx * stride_v_cache_2 - + offs_d[None, :] * stride_v_cache_3 - + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 - ) + seq_offset_v_shfl = j * TILE_SIZE + offs_v_t_shfl // 16 + physical_block_idx_shfl = tl.load( + block_tables_ptr + block_table_offset + seq_offset_v_shfl // BLOCK_SIZE + ).to(tl.int64) + v_offset = ( + physical_block_idx_shfl[None, :] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_v_t_shfl[None, :] * stride_v_cache_3 + + offs_v_d_shfl[:, None] * stride_v_cache_2 + ) + else: + physical_block_idx = tl.load( + block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE + ).to(tl.int64) + + k_offset = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + k_mask = dim_mask[:, None] & tile_mask[None, :] - k_offset = ( - physical_block_idx[None, :] * stride_k_cache_0 - + kv_head_idx * stride_k_cache_2 - + offs_d[:, None] * stride_k_cache_3 - + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 - ) + v_offset = ( + physical_block_idx[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_d[None, :] * stride_v_cache_3 + + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) + v_mask = dim_mask[None, :] & tile_mask[:, None] + other = 0.0 # K : (HEAD_SIZE, TILE_SIZE) K_load = tl.load( key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], - other=0.0, + mask=k_mask, + other=other, cache_modifier=KV_cache_modifier, ) @@ -585,12 +651,26 @@ def kernel_unified_attention_3d( K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) else: K = K_load + if SHUFFLED_KV_CACHE: + K = ( + K.reshape( + 1, + TILE_SIZE // 16, + HEAD_SIZE_PADDED // 16, + 2, + 16, + 8, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(TILE_SIZE, HEAD_SIZE_PADDED) + .trans(1, 0) + ) # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load( value_cache_ptr + v_offset, - mask=dim_mask[None, :] & tile_mask[:, None], - other=0.0, + mask=v_mask, + other=other, 
cache_modifier=KV_cache_modifier, ) @@ -601,6 +681,20 @@ def kernel_unified_attention_3d( V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) else: V = V_load + if SHUFFLED_KV_CACHE: + V = ( + V.reshape( + 1, + HEAD_SIZE_PADDED // 16, + TILE_SIZE // 16, + 2, + 16, + 8, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(HEAD_SIZE_PADDED, TILE_SIZE) + .trans(1, 0) + ) seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 diff --git a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py index cb2b5524b0..27df5e0895 100644 --- a/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py +++ b/aiter/ops/triton/_triton_kernels/flash_attn_triton_amd/utils.py @@ -50,7 +50,7 @@ # ------------------------------- ArchFamily = Literal["cdna", "rdna"] -CDNA_ARCHS = frozenset({"gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx950"}) +CDNA_ARCHS = frozenset({"gfx908", "gfx90a", "gfx940", "gfx941", "gfx942", "gfx950", "gfx1250"}) RDNA_ARCHS = frozenset( { "gfx1030", @@ -63,7 +63,7 @@ "gfx1201", } ) -FP8_ARCHS = frozenset({"gfx942", "gfx950"}) +FP8_ARCHS = frozenset({"gfx942", "gfx950", "gfx1250"}) _RECOMMENDED_FP8_REPLACEMENTS: dict[str, dict[torch.dtype, torch.dtype]] = { "gfx942": { diff --git a/aiter/ops/triton/attention/unified_attention.py b/aiter/ops/triton/attention/unified_attention.py index a16720a963..f5a18c83d0 100644 --- a/aiter/ops/triton/attention/unified_attention.py +++ b/aiter/ops/triton/attention/unified_attention.py @@ -82,7 +82,7 @@ def select_3d_config( "TILE_SIZE": TILE_SIZE, "NUM_SEGMENTS_PER_SEQ": num_segments, "num_warps": attn_warps, - "num_stages": 1, + "num_stages": 2, "waves_per_eu": 2, } reduce_config = { @@ -133,6 +133,7 @@ def unified_attention( qq_bias=None, # Optional tensor for sinks sinks=None, + shuffled_kv_cache: bool = False, ): assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -144,12 +145,18 @@ def unified_attention( use_qq_bias = qq_bias is not None SLIDING_WINDOW = 1 + window_size[0] - block_size = v.shape[1] + _, num_query_heads, head_size = q.shape + if shuffled_kv_cache: + # key_cache: num_blocks, num_kv_heads, block_size // 16, head_size * 16 + # value_cache: num_blocks, num_kv_heads, head_size // 16, block_size * 16 + _, num_kv_heads, block_size, _ = k.shape + block_size = block_size * 16 + else: + # key_cache and value_cache: num_blocks, block_size, num_kv_heads, head_size + _, block_size, num_kv_heads, _ = k.shape + num_seqs = len(seqused_k) - num_query_heads = q.shape[1] - num_kv_heads = k.shape[2] num_queries_per_kv = num_query_heads // num_kv_heads - head_size = q.shape[2] BLOCK_M = ( 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) @@ -320,6 +327,7 @@ def unified_attention( num_seqs=num_seqs, BLOCK_M=BLOCK_M, ALL_DECODE=ALL_DECODE, + SHUFFLED_KV_CACHE=shuffled_kv_cache, **attn_config, ) reduce_segments[(q.shape[0], num_query_heads)]( diff --git a/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-A16W16.json new file mode 100644 index 0000000000..b96bf629e9 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-A16W16.json @@ -0,0 +1,22 @@ +{ + "M_GEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT.json b/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT.json new file mode 100644 index 0000000000..43f7876e18 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-A8W8-A_PER_TOKEN_GROUP_PREQUANT_W_PER_BATCHED_TENSOR_QUANT.json @@ -0,0 +1,68 @@ +{ + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg" + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg" + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg" + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg" + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg" + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg" + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-A8W8.json b/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-A8W8.json new file mode 100644 index 0000000000..b96bf629e9 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-A8W8.json @@ -0,0 +1,22 @@ +{ + "M_GEQ_4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-AFP4WFP4.json b/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-AFP4WFP4.json new file mode 100644 index 0000000000..fb69cf643b --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM-AFP4WFP4.json @@ -0,0 +1,80 @@ +{ + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + 
"matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM_PREQUANT-AFP4WFP4.json b/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM_PREQUANT-AFP4WFP4.json new file mode 100644 index 0000000000..a0a550dcab --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-BATCHED_GEMM_PREQUANT-AFP4WFP4.json @@ -0,0 +1,80 @@ +{ + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-FF-A16W16-fused.json b/aiter/ops/triton/configs/gemm/gfx1250-FF-A16W16-fused.json new file mode 100644 index 0000000000..07fb73a6be --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-FF-A16W16-fused.json @@ -0,0 +1,50 @@ +{ + "M_LEQ_4": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + 
"BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "kpack": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8=512-N16=256-K=7168.json b/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8=512-N16=256-K=7168.json new file mode 100644 index 0000000000..ddd4570995 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16-N8=512-N16=256-K=7168.json @@ -0,0 +1,110 @@ +{ + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 8, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json b/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json new file mode 100644 index 0000000000..9fb4bf4ec6 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-A8W8_BLOCKSCALE-A16W16.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + 
"NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-AFP4WFP4-A16W16-N4=512-N16=256-K=7168.json b/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-AFP4WFP4-A16W16-N4=512-N16=256-K=7168.json new file mode 100644 index 0000000000..2473a3884e --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-AFP4WFP4-A16W16-N4=512-N16=256-K=7168.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 14 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 8 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 7 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 7 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 4 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 4 + }, + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-AFP4WFP4-A16W16.json b/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-AFP4WFP4-A16W16.json new file mode 100644 index 0000000000..3811ccc061 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-AFP4WFP4-A16W16.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-AFP4WFP4_PRESHUFFLED-A16W16.json b/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-AFP4WFP4_PRESHUFFLED-A16W16.json new file mode 100644 index 0000000000..0efdc2a24e --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-FUSED-GEMM-AFP4WFP4_PRESHUFFLED-A16W16.json @@ -0,0 +1,38 @@ +{ + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 
256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W16-ATOMIC.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W16-ATOMIC.json new file mode 100644 index 0000000000..44271f5634 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W16-ATOMIC.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "NUM_KSPLIT": 1, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "kpack": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W16-gated.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W16-gated.json new file mode 100644 index 0000000000..67448543a2 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W16-gated.json @@ -0,0 +1,74 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "kpack": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "kpack": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W16.json new file mode 100644 index 0000000000..d9b4fe159a --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W16.json @@ -0,0 +1,80 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1, + "kpack": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1, + "kpack": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1, + "kpack": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 
64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1, + "kpack": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W8_BLOCKSCALE.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W8_BLOCKSCALE.json new file mode 100644 index 0000000000..97c355f55b --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W8_BLOCKSCALE.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W8_BLOCKSCALE_PRESHUFFLED.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W8_BLOCKSCALE_PRESHUFFLED.json new file mode 100644 index 0000000000..97c355f55b --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16W8_BLOCKSCALE_PRESHUFFLED.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16WFP4.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16WFP4.json new file mode 100644 index 0000000000..6929191d87 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16WFP4.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_16": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "any": { + 
"BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16WFP4_PRESHUFFLED.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16WFP4_PRESHUFFLED.json new file mode 100644 index 0000000000..01951d60ce --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A16WFP4_PRESHUFFLED.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8.json new file mode 100644 index 0000000000..4f38cf580f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8_BLOCKSCALE.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8_BLOCKSCALE.json new file mode 100644 index 0000000000..21cff21414 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8_BLOCKSCALE.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".ca", + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8_BLOCKSCALE_PRESHUFFLED.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8_BLOCKSCALE_PRESHUFFLED.json new file mode 100644 index 0000000000..720f66dad4 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8_BLOCKSCALE_PRESHUFFLED.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8_PER_TOKEN_SCALE.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8_PER_TOKEN_SCALE.json new file mode 100644 index 0000000000..694727842a --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8W8_PER_TOKEN_SCALE.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8WFP4.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8WFP4.json new file mode 100644 index 0000000000..cd4cbd2ff8 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-A8WFP4.json @@ -0,0 +1,80 @@ +{ + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 6, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, 
+ "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "kpack": 1, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-AFP4WFP4.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-AFP4WFP4.json new file mode 100644 index 0000000000..07a2036905 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-AFP4WFP4.json @@ -0,0 +1,74 @@ +{ + "M_LEQ_16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 16 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 2, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 32, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM-AFP4WFP4_PRESHUFFLED.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-AFP4WFP4_PRESHUFFLED.json new file mode 100644 index 0000000000..a4eafd14af --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM-AFP4WFP4_PRESHUFFLED.json @@ -0,0 +1,86 @@ +{ + "M_LEQ_8": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_31": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + 
"GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + }, + "any": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1250-GEMM_PREQUANT-AFP4WFP4.json b/aiter/ops/triton/configs/gemm/gfx1250-GEMM_PREQUANT-AFP4WFP4.json new file mode 100644 index 0000000000..70042715e3 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1250-GEMM_PREQUANT-AFP4WFP4.json @@ -0,0 +1,74 @@ +{ + "M_LEQ_8": { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_32": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_64": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 4 + }, + "any": { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 512, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 4 + } +} diff --git a/aiter/ops/triton/configs/gfx1250-MHA-DEFAULT.json b/aiter/ops/triton/configs/gfx1250-MHA-DEFAULT.json new file mode 100644 index 0000000000..09f3bd7036 --- /dev/null +++ b/aiter/ops/triton/configs/gfx1250-MHA-DEFAULT.json @@ -0,0 +1,80 @@ +{ + "fwd": { + "dropout_or_fp32": { + "BLOCK_M": 32, + "BLOCK_N": 32, + "PRELOAD_V": true, + "waves_per_eu": 1, + "num_warps": 2, + "num_ctas": 1, + "num_stages": 1 + }, + "default": { + 
"BLOCK_M": 128, + "BLOCK_N": 64, + "PRELOAD_V": true, + "waves_per_eu": 2, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 1 + }, + "pe": { + "BLOCK_M": 256, + "BLOCK_N": 64, + "PRELOAD_V": true, + "waves_per_eu": 2, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 3 + }, + "pe_dropout_or_fp32": { + "BLOCK_M": 256, + "BLOCK_N": 64, + "PRELOAD_V": true, + "waves_per_eu": 2, + "num_warps": 8, + "num_ctas": 1, + "num_stages": 1 + } + }, + "bkwd_fused" : { + "preprocess_kernel": { + "PRE_BLOCK": 128 + }, + "dkdvdq_kernel_N64" : { + "BLOCK_M": 16, + "BLOCK_N": 64, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "BLK_SLICE_FACTOR": 1, + "matrix_instr_nonkdim": 16 + }, + "dkdvdq_kernel_N128" : { + "BLOCK_M": 16, + "BLOCK_N": 128, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 2, + "BLK_SLICE_FACTOR": 1, + "matrix_instr_nonkdim": 16 + } + }, + "bkwd_onekernel" : { + "preprocess_kernel": { + "PRE_BLOCK": 128 + }, + "onekernel" : { + "BLOCK_M1": 32, + "BLOCK_N1": 128, + "BLOCK_M2": 128, + "BLOCK_N2": 32, + "BLK_SLICE_FACTOR": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "num_warps": 4, + "num_ctas": 1, + "num_stages": 1 + } + } +} diff --git a/aiter/ops/triton/gemm/basic/gemm_a8w8.py b/aiter/ops/triton/gemm/basic/gemm_a8w8.py index 39b5092ecf..28c666e4cb 100644 --- a/aiter/ops/triton/gemm/basic/gemm_a8w8.py +++ b/aiter/ops/triton/gemm/basic/gemm_a8w8.py @@ -6,9 +6,10 @@ import triton from aiter.ops.triton._triton_kernels.gemm.basic.gemm_a8w8 import ( _gemm_a8w8_kernel, - _gemm_a8w8_reduce_kernel, _get_config, ) +from aiter.ops.triton.utils.device_info import get_num_xcds + from aiter.ops.triton.utils.logger import AiterTritonLogger _LOGGER = AiterTritonLogger() @@ -23,62 +24,47 @@ def gemm_a8w8( dtype: Optional[float] = torch.bfloat16, y: Optional[torch.Tensor] = None, config: Optional[dict] = None, - skip_reduce: Optional[bool] = False, ): """ Computes 8 bit matrix multiplication Y = (X @ W^T) * (x_scale * w_scale) with optional bias. INT8 inputs are scaled back to higher precision using per-tensor scale factors. Args: - x (torch.Tensor): Input matrix with shape (M, K). - w (torch.Tensor): Weight matrix with shape (N, K), internally transposed. + x (torch.Tensor): INT8 input matrix with shape (M, K). + w (torch.Tensor): INT8 weight matrix with shape (N, K), internally transposed. x_scale (torch.Tensor): Scale factor for x with shape (M, 1) or (M,). w_scale (torch.Tensor): Scale factor for w with shape (1, N) or (N,). bias (Optional[torch.Tensor]): Bias vector with shape (N,). dtype (Optional[torch.dtype]): Output datatype (BF16 or FP16). y (Optional[torch.Tensor]): Pre-allocated output tensor with shape (M, N). config (Optional[dict]): Kernel tuning parameters (BLOCK_SIZE_M, BLOCK_SIZE_N, - BLOCK_SIZE_K, GROUP_SIZE_M, NUM_KSPLIT, SPLITK_BLOCK_SIZE). - skip_reduce (Optional[bool]): Skip reduction of split-K partial results. - Enables kernel fusion with downstream operations (FP8/FP4 quantization, - RMSNorm). Returns shape (NUM_KSPLIT, M, N) instead of (M, N). + BLOCK_SIZE_K, GROUP_SIZE_M). Returns: - torch.Tensor: Output with shape (M, N) or (NUM_KSPLIT, M, N) if skip_reduce=True. + torch.Tensor: Output with shape (M, N) in higher precision format. """ _LOGGER.info( f"GEMM_A8W8: x={tuple(x.shape)} w={tuple(w.shape)} x_scale={tuple(x_scale.shape)} w_scale={tuple(w_scale.shape)}" ) + # Check constraints. assert x.shape[1] == w.shape[1], "Incompatible dimensions!!!" 
M, K = x.shape N, K = w.shape + # Transpose w (kernel expects (K, N)) w = w.T - if config is None: - config, _ = _get_config(M, N, K) - - if y is None and (config["NUM_KSPLIT"] == 1 or not skip_reduce): + if y is None: y = torch.empty((M, N), dtype=dtype, device=x.device) - if config["NUM_KSPLIT"] > 1: - y_pp = torch.empty( - (config["NUM_KSPLIT"], M, N), - dtype=torch.float32, - device=y.device if y is not None else x.device, - ) - else: - y_pp = None + if config is None: + config, _ = _get_config(M, N, K) - grid = lambda META: ( # noqa: E731 - ( - META["NUM_KSPLIT"] - * triton.cdiv(M, META["BLOCK_SIZE_M"]) - * triton.cdiv(N, META["BLOCK_SIZE_N"]) - ), + grid = ( + triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(N, config["BLOCK_SIZE_N"]), ) _gemm_a8w8_kernel[grid]( x, @@ -86,7 +72,7 @@ def gemm_a8w8( x_scale, w_scale, bias, - y if config["NUM_KSPLIT"] == 1 else y_pp, + y, M, N, K, @@ -94,41 +80,11 @@ def gemm_a8w8( x.stride(1), w.stride(0), w.stride(1), - 0 if config["NUM_KSPLIT"] == 1 else y_pp.stride(0), - y.stride(0) if config["NUM_KSPLIT"] == 1 else y_pp.stride(1), - y.stride(1) if config["NUM_KSPLIT"] == 1 else y_pp.stride(2), - (bias is not None) and (config["NUM_KSPLIT"] == 1), + y.stride(0), + y.stride(1), + bias is not None, + NUM_XCDS=get_num_xcds(), **config, ) - if config["NUM_KSPLIT"] > 1: - if skip_reduce: - return y_pp - - REDUCE_BLOCK_SIZE_M = 32 - REDUCE_BLOCK_SIZE_N = 32 - ACTUAL_KSPLIT = triton.cdiv(K, config["SPLITK_BLOCK_SIZE"]) - - grid_reduce = ( - triton.cdiv(M, REDUCE_BLOCK_SIZE_M), - triton.cdiv(N, REDUCE_BLOCK_SIZE_N), - ) - _gemm_a8w8_reduce_kernel[grid_reduce]( - y_pp, - y, - bias, - M, - N, - y_pp.stride(0), - y_pp.stride(1), - y_pp.stride(2), - y.stride(0), - y.stride(1), - bias is not None, - BLOCK_SIZE_M=REDUCE_BLOCK_SIZE_M, - BLOCK_SIZE_N=REDUCE_BLOCK_SIZE_N, - ACTUAL_KSPLIT=ACTUAL_KSPLIT, - MAX_KSPLIT=triton.next_power_of_2(config["NUM_KSPLIT"]), - ) - return y diff --git a/aiter/ops/triton/gluon/unified_attention_2d.py b/aiter/ops/triton/gluon/unified_attention_2d.py new file mode 100644 index 0000000000..b5f9ca6b52 --- /dev/null +++ b/aiter/ops/triton/gluon/unified_attention_2d.py @@ -0,0 +1,1350 @@ + +import torch +from triton.experimental import gluon +import triton.experimental.gluon.language as gl +from triton.language.core import _aggregate as aggregate +import pytest +from aiter.ops.triton.utils._triton import arch_info +import os + +PRINT_IRS = os.environ.get("PRINT_IRS", "0") == "1" + + +@aggregate +class AsyncKVLoaderConfig: + """Configuration for asynchronous KV loader.""" + + blocked_k: gl.constexpr + blocked_v: gl.constexpr + shared_k_layout: gl.constexpr + shared_v_layout: gl.constexpr + USE_LOAD_BUFFER_OP: gl.constexpr + KV_CACHE_MODIFIER: gl.constexpr + + k_reg_layout: gl.constexpr + v_reg_layout: gl.constexpr + + @gluon.constexpr_function + def __init__(self, cfg): + # Blocked layouts for global-to-shared memory loads + HEAD_SIZE_DIV = cfg.HEAD_SIZE // 8 + # gl.static_assert(WARP_SIZE % HEAD_SIZE_DIV == 0, "WARP_SIZE must be divisible by HEAD_SIZE_DIV") + self.blocked_v = gl.constexpr( + gl.BlockedLayout( + size_per_thread=[1, 8], + threads_per_warp=[cfg.WARP_SIZE // HEAD_SIZE_DIV, HEAD_SIZE_DIV], + warps_per_cta=[cfg.NUM_WARPS, 1], + order=[1, 0], + ) + ) + self.blocked_k = gl.constexpr( + gl.BlockedLayout( + size_per_thread=[8, 1], + threads_per_warp=[HEAD_SIZE_DIV, cfg.WARP_SIZE // HEAD_SIZE_DIV], + warps_per_cta=[1, cfg.NUM_WARPS], + order=[0, 1], + ) + ) + + # Swizzled shared memory layouts for K and V + self.shared_k_layout 
= gl.constexpr( + gl.SwizzledSharedLayout(vec=8, per_phase=2, max_phase=8, order=[0, 1]) + ) + self.shared_v_layout = gl.constexpr( + gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0]) + ) + + self.KV_CACHE_MODIFIER = cfg.KV_CACHE_MODIFIER + self.USE_LOAD_BUFFER_OP = cfg.USE_LOAD_BUFFER_OP + + self.k_reg_layout = gl.constexpr(cfg.k_layout) + self.v_reg_layout = gl.constexpr(cfg.v_layout) + + +@aggregate +class AsyncKVLoader: + kv_cfg: AsyncKVLoaderConfig + key_cache_ptr: gl.tensor + value_cache_ptr: gl.tensor + block_tables_ptr_shifted: gl.tensor + k_shared: gl.shared_memory_descriptor + v_shared: gl.shared_memory_descriptor + k_base_offset: gl.tensor + v_base_offset: gl.tensor + stride_k_cache_0: gl.tensor + stride_v_cache_0: gl.tensor + @gluon.constexpr_function + def __init__( + self, + kv_cfg, + key_cache_ptr, + value_cache_ptr, + block_tables_ptr_shifted, + k_shared, + v_shared, + k_base_offset, + v_base_offset, + stride_k_cache_0, + stride_v_cache_0, + ): + self.kv_cfg = kv_cfg + self.key_cache_ptr = key_cache_ptr + self.value_cache_ptr = value_cache_ptr + self.k_shared = k_shared + self.v_shared = v_shared + self.k_base_offset = k_base_offset + self.v_base_offset = v_base_offset + self.block_tables_ptr_shifted = block_tables_ptr_shifted + self.stride_k_cache_0 = stride_k_cache_0 + self.stride_v_cache_0 = stride_v_cache_0 + @gluon.jit + def initialize( + cfg, + key_cache_ptr, + value_cache_ptr, + block_tables_ptr_shifted, + kv_head_idx, + num_blocks, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + ): + kv_cfg = AsyncKVLoaderConfig(cfg) + k_shared = gl.allocate_shared_memory( + key_cache_ptr.type.element_ty, + [2, cfg.HEAD_SIZE, cfg.TILE_SIZE], + layout=kv_cfg.shared_k_layout, + ) + v_shared = gl.allocate_shared_memory( + value_cache_ptr.type.element_ty, + [2, cfg.TILE_SIZE, cfg.HEAD_SIZE], + layout=kv_cfg.shared_v_layout, + ) + + # Precompute KV load offsets (constant across tiles) + offs_d_k = gl.arange( + 0, cfg.HEAD_SIZE, layout=gl.SliceLayout(1, kv_cfg.blocked_k) + )[:, None] + offs_n_k = gl.arange( + 0, cfg.TILE_SIZE, layout=gl.SliceLayout(0, kv_cfg.blocked_k) + )[None, :] + k_base_offset = ( + kv_head_idx * stride_k_cache_2 + + offs_d_k * stride_k_cache_3 + + offs_n_k * stride_k_cache_1 + ) + + offs_d_v = gl.arange( + 0, cfg.HEAD_SIZE, layout=gl.SliceLayout(0, kv_cfg.blocked_v) + )[None, :] + offs_n_v = gl.arange( + 0, cfg.TILE_SIZE, layout=gl.SliceLayout(1, kv_cfg.blocked_v) + )[:, None] + v_base_offset = ( + kv_head_idx * stride_v_cache_2 + + offs_d_v * stride_v_cache_3 + + offs_n_v * stride_v_cache_1 + ) + + return AsyncKVLoader( + kv_cfg, + key_cache_ptr, + value_cache_ptr, + block_tables_ptr_shifted, + k_shared, + v_shared, + k_base_offset, + v_base_offset, + stride_k_cache_0, + stride_v_cache_0, + ) + + @gluon.jit + def load_k_to_shared(self, k_offset, buffer_id): + # Async copy K tile from global to shared memory + if self.kv_cfg.USE_LOAD_BUFFER_OP: + gl.amd.cdna4.async_copy.buffer_load_to_shared( + self.k_shared.index(buffer_id), + self.key_cache_ptr, + self.k_base_offset + k_offset, + cache_modifier=self.kv_cfg.KV_CACHE_MODIFIER, + ) + else: + gl.amd.cdna4.async_copy.global_load_to_shared( + self.k_shared.index(buffer_id), + self.key_cache_ptr + self.k_base_offset + k_offset, + cache_modifier=self.kv_cfg.KV_CACHE_MODIFIER, + ) + gl.amd.cdna4.async_copy.commit_group() + + @gluon.jit + def load_v_to_shared(self, v_offset, buffer_id): + # 
Async copy V tile from global to shared memory + if self.kv_cfg.USE_LOAD_BUFFER_OP: + gl.amd.cdna4.async_copy.buffer_load_to_shared( + self.v_shared.index(buffer_id), + self.value_cache_ptr, + self.v_base_offset + v_offset, + cache_modifier=self.kv_cfg.KV_CACHE_MODIFIER, + ) + else: + gl.amd.cdna4.async_copy.global_load_to_shared( + self.v_shared.index(buffer_id), + self.value_cache_ptr + self.v_base_offset + v_offset, + cache_modifier=self.kv_cfg.KV_CACHE_MODIFIER, + ) + gl.amd.cdna4.async_copy.commit_group() + + @gluon.jit + def load_k_from_shared(self, wait_count, buffer_id): + # Wait for async K copy and load from shared memory + gl.amd.cdna4.async_copy.wait_group(wait_count) + return gl.amd.cdna4.async_copy.load_shared_relaxed( + self.k_shared.index(buffer_id), self.kv_cfg.k_reg_layout + ) + + @gluon.jit + def load_v_from_shared(self, wait_count, buffer_id): + # Wait for async V copy and load from shared memory + gl.amd.cdna4.async_copy.wait_group(wait_count) + return gl.amd.cdna4.async_copy.load_shared_relaxed( + self.v_shared.index(buffer_id), self.kv_cfg.v_reg_layout + ) + + @gluon.jit + def load_block_ids(self, i): + return gl.load(self.block_tables_ptr_shifted + i) * self.stride_k_cache_0 + + +@aggregate +class TDMKVLoaderConfig: + """Configuration for TDM KV loader.""" + + shared_k_layout: gl.constexpr + shared_v_layout: gl.constexpr + USE_LOAD_BUFFER_OP: gl.constexpr + KV_CACHE_MODIFIER: gl.constexpr + + k_reg_layout: gl.constexpr + v_reg_layout: gl.constexpr + BLOCK_SIZE: gl.constexpr + + @gluon.constexpr_function + def __init__(self, cfg): + # Swizzled shared memory layouts for K and V + # self.shared_k_layout = gl.constexpr( + # gl.SwizzledSharedLayout(vec=8, per_phase=2, max_phase=8, order=[0, 1])) + # self.shared_v_layout = gl.constexpr( + # gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0])) + self.shared_k_layout = gl.constexpr( + gl.PaddedSharedLayout.with_identity_for( + [[cfg.HEAD_SIZE, 8]], [cfg.BLOCK_SIZE, cfg.HEAD_SIZE], [1, 0] + ) + ) + self.shared_v_layout = gl.constexpr( + gl.PaddedSharedLayout.with_identity_for( + [[cfg.HEAD_SIZE, 16]], [cfg.BLOCK_SIZE, cfg.HEAD_SIZE], [1, 0] + ) + ) + self.KV_CACHE_MODIFIER = cfg.KV_CACHE_MODIFIER + self.USE_LOAD_BUFFER_OP = cfg.USE_LOAD_BUFFER_OP + + self.k_reg_layout = gl.constexpr(cfg.k_layout) + self.v_reg_layout = gl.constexpr(cfg.v_layout) + self.BLOCK_SIZE = gl.constexpr(cfg.BLOCK_SIZE) + + +@aggregate +class TDMKVLoader: + kv_cfg: TDMKVLoaderConfig + block_tables_ptr_shifted: gl.tensor + k_shared: gl.shared_memory_descriptor + v_shared: gl.shared_memory_descriptor + k_desc: gl.amd.gfx1250.tdm.tensor_descriptor + v_desc: gl.amd.gfx1250.tdm.tensor_descriptor + kv_head_idx: gl.tensor + stride_k_cache_2: gl.tensor + stride_v_cache_2: gl.tensor + + @gluon.constexpr_function + def __init__( + self, + kv_cfg, + block_tables_ptr_shifted, + k_shared, + v_shared, + k_desc, + v_desc, + kv_head_idx, + stride_k_cache_2, + stride_v_cache_2, + ): + self.kv_cfg = kv_cfg + self.k_shared = k_shared + self.v_shared = v_shared + self.k_desc = k_desc + self.v_desc = v_desc + self.block_tables_ptr_shifted = block_tables_ptr_shifted + self.kv_head_idx = kv_head_idx + self.stride_k_cache_2 = stride_k_cache_2 + self.stride_v_cache_2 = stride_v_cache_2 + + @gluon.jit + def initialize( + cfg, + key_cache_ptr, + value_cache_ptr, + block_tables_ptr_shifted, + kv_head_idx, + num_blocks, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + 
stride_v_cache_3, + ): + kv_cfg = TDMKVLoaderConfig(cfg) + k_shared = gl.allocate_shared_memory( + key_cache_ptr.type.element_ty, + [2, cfg.BLOCK_SIZE, cfg.HEAD_SIZE], + layout=kv_cfg.shared_k_layout, + ) + v_shared = gl.allocate_shared_memory( + value_cache_ptr.type.element_ty, + [2, cfg.BLOCK_SIZE, cfg.HEAD_SIZE], + layout=kv_cfg.shared_v_layout, + ) + + k_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=key_cache_ptr, + shape=(num_blocks * cfg.BLOCK_SIZE, cfg.NUM_KV_HEADS * cfg.HEAD_SIZE), + strides=(stride_k_cache_1, stride_k_cache_3), + block_shape=(cfg.BLOCK_SIZE, cfg.HEAD_SIZE), + layout=kv_cfg.shared_k_layout, + ) + v_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=value_cache_ptr, + shape=(num_blocks * cfg.BLOCK_SIZE, cfg.NUM_KV_HEADS * cfg.HEAD_SIZE), + strides=(stride_v_cache_1, stride_v_cache_3), + block_shape=(cfg.BLOCK_SIZE, cfg.HEAD_SIZE), + layout=kv_cfg.shared_v_layout, + ) + + return TDMKVLoader( + kv_cfg, + block_tables_ptr_shifted, + k_shared, + v_shared, + k_desc, + v_desc, + kv_head_idx, + stride_k_cache_2, + stride_v_cache_2, + ) + + @gluon.jit + def load_k_to_shared(self, k_offset, buffer_id): + offsets = [ + (k_offset * (self.kv_cfg.BLOCK_SIZE)).to(gl.int32), + (self.kv_head_idx * self.stride_k_cache_2).to(gl.int32), + ] + gl.amd.gfx1250.tdm.async_load( + self.k_desc, offsets, self.k_shared.index(buffer_id) + ) + + @gluon.jit + def load_v_to_shared(self, v_offset, buffer_id): + offsets = [ + (v_offset * (self.kv_cfg.BLOCK_SIZE)).to(gl.int32), + (self.kv_head_idx * self.stride_v_cache_2).to(gl.int32), + ] + gl.amd.gfx1250.tdm.async_load( + self.v_desc, offsets, self.v_shared.index(buffer_id) + ) + + @gluon.jit + def load_k_from_shared(self, wait_count, buffer_id): + gl.amd.gfx1250.tdm.async_wait(wait_count) + return ( + self.k_shared.index(buffer_id) + .permute([1, 0]) + .load(layout=self.kv_cfg.k_reg_layout) + ) + + @gluon.jit + def load_v_from_shared(self, wait_count, buffer_id): + gl.amd.gfx1250.tdm.async_wait(wait_count) + return self.v_shared.index(buffer_id).load(layout=self.kv_cfg.v_reg_layout) + + @gluon.jit + def load_block_ids(self, i): + return gl.load(self.block_tables_ptr_shifted + i) + +@aggregate +class TDMGatherKVLoaderConfig: + """Configuration for TDM KV loader.""" + + shared_k_layout: gl.constexpr + shared_v_layout: gl.constexpr + USE_LOAD_BUFFER_OP: gl.constexpr + KV_CACHE_MODIFIER: gl.constexpr + + k_reg_layout: gl.constexpr + v_reg_layout: gl.constexpr + BLOCK_SIZE: gl.constexpr + HEAD_SIZE: gl.constexpr + NUM_KV_HEADS: gl.constexpr + NUM_KV_BLOCKS: gl.constexpr + TILE_SIZE: gl.constexpr + gather_ids_layout: gl.constexpr + @gluon.constexpr_function + def __init__(self, cfg): + # Swizzled shared memory layouts for K and V + self.shared_k_layout = gl.constexpr( + gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0])) + self.shared_v_layout = gl.constexpr( + gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0])) + + self.KV_CACHE_MODIFIER = cfg.KV_CACHE_MODIFIER + self.USE_LOAD_BUFFER_OP = cfg.USE_LOAD_BUFFER_OP + + self.k_reg_layout = gl.constexpr(cfg.k_layout) + self.v_reg_layout = gl.constexpr(cfg.v_layout) + self.BLOCK_SIZE = gl.constexpr(cfg.BLOCK_SIZE) + self.HEAD_SIZE = gl.constexpr(cfg.HEAD_SIZE) + self.NUM_KV_BLOCKS = gl.constexpr(cfg.NUM_KV_BLOCKS) + self.TILE_SIZE = gl.constexpr(cfg.TILE_SIZE) + self.NUM_KV_HEADS = gl.constexpr(cfg.NUM_KV_HEADS) + + self.gather_ids_layout = gl.constexpr( + gl.BlockedLayout( + size_per_thread=[cfg.NUM_KV_BLOCKS], + threads_per_warp=[cfg.WARP_SIZE], + 
warps_per_cta=[cfg.NUM_WARPS], + order=[0], + ) + ) +@aggregate +class TDMGatherKVLoader: + kv_cfg: TDMGatherKVLoaderConfig + block_tables_ptr_shifted: gl.tensor + k_shared: gl.shared_memory_descriptor + v_shared: gl.shared_memory_descriptor + k_desc: gl.amd.gfx1250.tdm.tensor_descriptor + v_desc: gl.amd.gfx1250.tdm.tensor_descriptor + kv_head_idx: gl.tensor + stride_k_cache_2: gl.tensor + stride_v_cache_2: gl.tensor + @gluon.constexpr_function + def __init__( + self, + kv_cfg, + block_tables_ptr_shifted, + k_shared, + v_shared, + k_desc, + v_desc, + kv_head_idx, + stride_k_cache_2, + stride_v_cache_2, + ): + self.kv_cfg = kv_cfg + self.k_shared = k_shared + self.v_shared = v_shared + self.k_desc = k_desc + self.v_desc = v_desc + self.block_tables_ptr_shifted = block_tables_ptr_shifted + self.kv_head_idx = kv_head_idx + self.stride_k_cache_2 = stride_k_cache_2 + self.stride_v_cache_2 = stride_v_cache_2 + + @gluon.jit + def initialize( + cfg, + key_cache_ptr, + value_cache_ptr, + block_tables_ptr_shifted, + kv_head_idx, + num_blocks, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + ): + kv_cfg = TDMGatherKVLoaderConfig(cfg) + k_shared = gl.allocate_shared_memory( + key_cache_ptr.type.element_ty, + [2, cfg.NUM_KV_BLOCKS, cfg.BLOCK_SIZE * cfg.HEAD_SIZE], + layout=kv_cfg.shared_k_layout, + ) + v_shared = gl.allocate_shared_memory( + value_cache_ptr.type.element_ty, + [2, cfg.NUM_KV_BLOCKS, cfg.BLOCK_SIZE * cfg.HEAD_SIZE], + layout=kv_cfg.shared_v_layout, + ) + + k_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=key_cache_ptr, + shape=(num_blocks * cfg.NUM_KV_HEADS, cfg.BLOCK_SIZE * cfg.HEAD_SIZE), + strides=(stride_k_cache_1, stride_k_cache_3), + block_shape=(cfg.NUM_KV_BLOCKS, cfg.BLOCK_SIZE * cfg.HEAD_SIZE), + layout=kv_cfg.shared_k_layout, + ) + v_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=value_cache_ptr, + shape=(num_blocks * cfg.NUM_KV_HEADS, cfg.BLOCK_SIZE * cfg.HEAD_SIZE), + strides=(stride_v_cache_1, stride_v_cache_3), + block_shape=(cfg.NUM_KV_BLOCKS, cfg.BLOCK_SIZE * cfg.HEAD_SIZE), + layout=kv_cfg.shared_v_layout, + ) + + return TDMGatherKVLoader( + kv_cfg, + block_tables_ptr_shifted, + k_shared, + v_shared, + k_desc, + v_desc, + kv_head_idx, + stride_k_cache_2, + stride_v_cache_2, + ) + + @gluon.jit + def load_k_to_shared(self, k_offset, buffer_id): + src_row_indices = (k_offset * self.kv_cfg.NUM_KV_HEADS + self.kv_head_idx).to( + gl.int32 + ) + + gl.amd.gfx1250.tdm.async_gather( + self.k_desc, src_row_indices, 0, self.k_shared.index(buffer_id) + ) + + @gluon.jit + def load_v_to_shared(self, v_offset, buffer_id): + src_row_indices = (v_offset * self.kv_cfg.NUM_KV_HEADS + self.kv_head_idx).to( + gl.int32 + ) + gl.amd.gfx1250.tdm.async_gather( + self.v_desc, src_row_indices, 0, self.v_shared.index(buffer_id) + ) + + @gluon.jit + def load_k_from_shared(self, wait_count, buffer_id): + gl.amd.gfx1250.tdm.async_wait(wait_count) + return ( + self.k_shared.index(buffer_id) + .reshape([self.kv_cfg.TILE_SIZE, self.kv_cfg.HEAD_SIZE]) + .permute([1, 0]) + .load(layout=self.kv_cfg.k_reg_layout) + ) + + @gluon.jit + def load_v_from_shared(self, wait_count, buffer_id): + gl.amd.gfx1250.tdm.async_wait(wait_count) + return (self.v_shared.index(buffer_id) + .reshape([self.kv_cfg.TILE_SIZE, self.kv_cfg.HEAD_SIZE]) + .load(layout=self.kv_cfg.v_reg_layout) + ) + + @gluon.jit + def load_block_ids(self, i): + offs = gl.arange(0, self.kv_cfg.NUM_KV_BLOCKS, 
layout=self.kv_cfg.gather_ids_layout) + return gl.load(self.block_tables_ptr_shifted + i * self.kv_cfg.NUM_KV_BLOCKS + offs) + + +@aggregate +class AttentionConfig: + """Configuration for unified attention layouts and derived constants (CDNA4).""" + + # Constants + ARCH_NAME: gl.constexpr + HEAD_SIZE: gl.constexpr + BLOCK_SIZE: gl.constexpr + BLOCK_M: gl.constexpr + TILE_SIZE: gl.constexpr + NUM_KV_BLOCKS: gl.constexpr + NUM_QUERY_HEADS: gl.constexpr + NUM_KV_HEADS: gl.constexpr + SLIDING_WINDOW: gl.constexpr + NUM_QUERIES_PER_KV: gl.constexpr + BLOCK_Q: gl.constexpr + RCP_LN2: gl.constexpr + QK_SCALE: gl.constexpr + WARP_SIZE: gl.constexpr + NUM_WARPS: gl.constexpr + # Operator layouts + qk_layout: gl.constexpr + pv_layout: gl.constexpr + + # Dot operand layouts + q_layout: gl.constexpr + k_layout: gl.constexpr + v_layout: gl.constexpr + p_layout: gl.constexpr + + # Blocked layouts for global-to-shared loads + blocked_q: gl.constexpr + + Q_CACHE_MODIFIER: gl.constexpr + KV_CACHE_MODIFIER: gl.constexpr + + USE_LOAD_BUFFER_OP: gl.constexpr + USE_STORE_BUFFER_OP: gl.constexpr + ALL_DECODE: gl.constexpr + @gluon.constexpr_function + def __init__( + self, + ARCH_NAME, + NUM_WARPS, + HEAD_SIZE, + BLOCK_SIZE, + TILE_SIZE, + BLOCK_M, + BLOCK_Q, + NUM_QUERY_HEADS, + NUM_KV_HEADS, + SLIDING_WINDOW, + SCALE, + USE_LOAD_BUFFER_OP, + USE_STORE_BUFFER_OP, + ALL_DECODE, + ): + + # Constants + self.HEAD_SIZE = gl.constexpr(HEAD_SIZE) + self.BLOCK_SIZE = gl.constexpr(BLOCK_SIZE) + self.BLOCK_M = gl.constexpr(BLOCK_M) + self.NUM_QUERY_HEADS = gl.constexpr(NUM_QUERY_HEADS) + self.NUM_KV_HEADS = gl.constexpr(NUM_KV_HEADS) + self.SLIDING_WINDOW = gl.constexpr(SLIDING_WINDOW) + # Derived constants + self.NUM_QUERIES_PER_KV = gl.constexpr(NUM_QUERY_HEADS // NUM_KV_HEADS) + self.BLOCK_Q = gl.constexpr(BLOCK_Q) + self.NUM_KV_BLOCKS = gl.constexpr(TILE_SIZE // BLOCK_SIZE) + self.TILE_SIZE = gl.constexpr(TILE_SIZE) + self.RCP_LN2 = gl.constexpr(1.4426950408889634) + self.QK_SCALE = gl.constexpr(SCALE * self.RCP_LN2) + self.USE_LOAD_BUFFER_OP = gl.constexpr(USE_LOAD_BUFFER_OP) + self.USE_STORE_BUFFER_OP = gl.constexpr(USE_STORE_BUFFER_OP) + self.ALL_DECODE = gl.constexpr(ALL_DECODE) + self.ARCH_NAME = gl.constexpr(ARCH_NAME) + self.WARP_SIZE = gl.constexpr(32 if ARCH_NAME == "gfx1250" else 64) + self.NUM_WARPS = gl.constexpr(NUM_WARPS) + # Operator layouts (gfx1250 WMMA) + if ARCH_NAME == "gfx1250": + assert NUM_WARPS == 4 or NUM_WARPS == 8 + + if NUM_WARPS == 4: + warp_bases = [[1, 0], [2, 0]] + else: + warp_bases = [[1, 0], [2, 0], [4, 0]] + self.qk_layout = gl.constexpr( + gl.amd.AMDWMMALayout( + version=3, + transposed=True, + instr_shape=[16, 16, 32], + warp_bases=warp_bases, + ) + ) + else: + self.qk_layout = gl.constexpr( + gl.amd.AMDMFMALayout( + version=4, + transposed=True, + instr_shape=[32, 32, 16], + warps_per_cta=[NUM_WARPS, 1], + ) + ) + self.pv_layout = self.qk_layout + + # Dot operand layouts + self.q_layout = gl.constexpr(gl.DotOperandLayout(0, self.qk_layout, 8)) + self.k_layout = gl.constexpr(gl.DotOperandLayout(1, self.qk_layout, 8)) + self.v_layout = gl.constexpr(gl.DotOperandLayout(1, self.pv_layout, 8)) + self.p_layout = gl.constexpr(gl.DotOperandLayout(0, self.pv_layout, 8)) + + # Blocked layouts for global-to-shared memory loads + HEAD_SIZE_DIV = HEAD_SIZE // 8 + self.blocked_q = gl.constexpr( + gl.BlockedLayout( + size_per_thread=[1, 8], + threads_per_warp=[self.WARP_SIZE // 8, 8], + warps_per_cta=[NUM_WARPS, 1], + order=[1, 0], + ) + ) + self.Q_CACHE_MODIFIER = gl.constexpr(".cg") + 
self.KV_CACHE_MODIFIER = gl.constexpr(".cg") if ALL_DECODE else gl.constexpr("") + + +@aggregate +class AttentionProgram: + """Program state and core operations for the unified attention kernel.""" + + cfg: AttentionConfig + + q: gl.tensor + + key_cache_ptr: gl.tensor + value_cache_ptr: gl.tensor + output_ptr: gl.tensor + + tile_start: gl.tensor + tile_end: gl.tensor + safe_tile_end: gl.tensor + # query_pos_qk: gl.tensor + query_mask_qk: gl.tensor + # context_len: gl.tensor + context_len_q_pos_qk: gl.tensor + + @gluon.constexpr_function + def __init__( + self, + cfg, + q, + key_cache_ptr, + value_cache_ptr, + output_ptr, + tile_start, + tile_end, + safe_tile_end, + query_mask_qk, + context_len_q_pos_qk, + ): + self.cfg = cfg + self.q = q + self.key_cache_ptr = key_cache_ptr + self.value_cache_ptr = value_cache_ptr + self.output_ptr = output_ptr + self.tile_start = tile_start + self.tile_end = tile_end + self.safe_tile_end = safe_tile_end + self.query_mask_qk = query_mask_qk + self.context_len_q_pos_qk = context_len_q_pos_qk + + @gluon.jit + def initialize( + cfg, + q, + key_cache_ptr, + value_cache_ptr, + output_ptr, + max_seq_prefix_len, + q_block_local_idx, + cur_batch_query_len, + context_len, + query_pos, + query_mask, + ): + # Calculate tile range + num_tiles = (max_seq_prefix_len + cfg.TILE_SIZE - 1) // cfg.TILE_SIZE + tile_start = 0 + tile_end = num_tiles + if cfg.SLIDING_WINDOW > 0: + qpos_lo = q_block_local_idx * cfg.BLOCK_Q + qpos_hi = gl.minimum( + qpos_lo + (cfg.BLOCK_M - 1) // cfg.NUM_QUERIES_PER_KV, + cur_batch_query_len - 1, + ) + first_allowed_key = context_len + qpos_lo - cfg.SLIDING_WINDOW + 1 + last_allowed_key = context_len + qpos_hi + tile_start = gl.maximum(0, first_allowed_key // cfg.TILE_SIZE) + tile_end = gl.minimum((last_allowed_key // cfg.TILE_SIZE) + 1, num_tiles) + + query_pos_qk = gl.convert_layout(query_pos, gl.SliceLayout(1, cfg.qk_layout))[ + :, None + ] + query_mask_qk = gl.convert_layout(query_mask, cfg.qk_layout) + + context_len_q_pos_qk = context_len + query_pos_qk + + # Compute the tile index beyond which causal masking is needed. + # min causal pos = context_len + first query pos in block + # Tiles j < safe_tile_end have all KV positions within causal range + # for every query row, so apply_mask_qk can be skipped. 
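+        # Worked example (hypothetical numbers, ignoring sliding window): with
+        # TILE_SIZE = 64, context_len = 130 and q_block_local_idx * BLOCK_Q = 0,
+        # min_causal_pos = 130 and safe_tile_end = 131 // 64 = 2, so tiles 0 and 1
+        # (key positions 0..127) are fully visible to every query row and skip the
+        # masking path; causal masking only kicks in from tile 2 onward.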
+ min_causal_pos = context_len + q_block_local_idx * cfg.BLOCK_Q + safe_tile_end = (min_causal_pos + 1) // cfg.TILE_SIZE + safe_tile_end = gl.minimum(safe_tile_end, tile_end) + safe_tile_end = gl.maximum(safe_tile_end, tile_start) + + return AttentionProgram( + cfg, + q, + key_cache_ptr, + value_cache_ptr, + output_ptr, + tile_start, + tile_end, + safe_tile_end, + query_mask_qk, + context_len_q_pos_qk, + ) + + + @gluon.jit + def load_q_from_global( + self, + query_ptr, + q_block_local_idx, + cur_batch_in_all_start_index, + kv_head_idx, + cur_batch_query_len, + query_stride_0, + query_stride_1, + ): + """Load Q from global memory.""" + offs_m = gl.arange( + 0, self.cfg.BLOCK_M, layout=gl.SliceLayout(1, self.cfg.q_layout) + ) + offs_d = gl.arange( + 0, self.cfg.HEAD_SIZE, layout=gl.SliceLayout(0, self.cfg.q_layout) + ) + query_pos = ( + q_block_local_idx * self.cfg.BLOCK_Q + offs_m // self.cfg.NUM_QUERIES_PER_KV + ) + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = ( + kv_head_idx * self.cfg.NUM_QUERIES_PER_KV + + offs_m % self.cfg.NUM_QUERIES_PER_KV + ) + + query_mask_0 = query_pos < cur_batch_query_len + query_mask_1 = query_offset_1 < self.cfg.NUM_QUERY_HEADS + query_mask = query_mask_0[:, None] & query_mask_1[:, None] + + q_offs = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) + if self.cfg.USE_STORE_BUFFER_OP: + q = gl.amd.cdna4.buffer_load( + query_ptr + q_offs, + mask=query_mask, + other=0.0, + cache_modifier=self.cfg.Q_CACHE_MODIFIER, + ) + else: + q = gl.load( + query_ptr + q_offs, + mask=query_mask, + other=0.0, + cache_modifier=self.cfg.Q_CACHE_MODIFIER, + ) + return q, query_pos, query_mask + + @gluon.jit + def compute_qk(self, k): + S = gl.zeros( + [self.cfg.BLOCK_M, self.cfg.TILE_SIZE], + dtype=gl.float32, + layout=self.cfg.qk_layout, + ) + if self.cfg.ARCH_NAME == "gfx1250": + return gl.amd.gfx1250.wmma(self.q, k, S) * self.cfg.QK_SCALE + else: + return gl.amd.cdna4.mfma(self.q, k, S) * self.cfg.QK_SCALE + + @gluon.jit + def apply_mask_qk(self, S, j): + seq_offset = ( + j * self.cfg.TILE_SIZE + + gl.arange( + 0, self.cfg.TILE_SIZE, layout=gl.SliceLayout(0, self.cfg.qk_layout) + )[None, :] + ) + + seq_mask = seq_offset <= self.context_len_q_pos_qk + if self.cfg.SLIDING_WINDOW > 0: + seq_mask = seq_mask & ( + (self.context_len_q_pos_qk - seq_offset) < self.cfg.SLIDING_WINDOW + ) + full_mask = seq_mask + S = gl.where(full_mask, S, float("-inf")) + return S + + @gluon.jit + def softmax_part0(self, S, M): + m_ij = gl.maximum(M, gl.max(S, axis=1)) + m_ij = gl.where(m_ij > float("-inf"), m_ij, 0.0) + p = gl.exp2(S - m_ij[:, None]) + alpha = gl.exp2(M - m_ij) + return p, alpha, m_ij + + @gluon.jit + def softmax_part1(self, p, L, acc, alpha): + l_ij = gl.sum(p, 1) + acc = acc * alpha[:, None] + p = p.to(gl.bfloat16, fp_downcast_rounding="rtz") + L = L * alpha + l_ij + return p, L, acc + + @gluon.jit + def compute_pv(self, p, v, acc): + p = gl.convert_layout(p, self.cfg.p_layout) + if self.cfg.ARCH_NAME == "gfx1250": + return gl.amd.gfx1250.wmma(p, v, acc) + else: + return gl.amd.cdna4.mfma(p, v, acc) + + @gluon.jit + def store_output( + self, + out, + q_block_local_idx, + cur_batch_in_all_start_index, + kv_head_idx, + cur_batch_query_len, + output_stride_0, + output_stride_1, + ): + offs_m_out = gl.arange( + 0, self.cfg.BLOCK_M, layout=gl.SliceLayout(1, self.cfg.blocked_q) + ) + offs_d_out = gl.arange( + 0, self.cfg.HEAD_SIZE, layout=gl.SliceLayout(0, self.cfg.blocked_q) + ) + + 
query_pos_out = ( + q_block_local_idx * self.cfg.BLOCK_Q + + offs_m_out // self.cfg.NUM_QUERIES_PER_KV + ) + query_offset_0_out = cur_batch_in_all_start_index + query_pos_out + query_offset_1_out = ( + kv_head_idx * self.cfg.NUM_QUERIES_PER_KV + + offs_m_out % self.cfg.NUM_QUERIES_PER_KV + ) + + o_offs = ( + query_offset_0_out[:, None] * output_stride_0 + + query_offset_1_out[:, None] * output_stride_1 + + offs_d_out[None, :] + ) + + query_mask_0_out = query_pos_out < cur_batch_query_len + query_mask_1_out = query_offset_1_out < self.cfg.NUM_QUERY_HEADS + o_mask = query_mask_0_out[:, None] & query_mask_1_out[:, None] + casted_out = out.to(self.output_ptr.dtype.element_ty) + casted_out = gl.convert_layout(casted_out, self.cfg.blocked_q) + if self.cfg.USE_STORE_BUFFER_OP: + gl.amd.cdna4.buffer_store(casted_out, self.output_ptr, o_offs, mask=o_mask) + else: + gl.store(self.output_ptr + o_offs, casted_out, mask=o_mask) + + +@gluon.jit +def find_seq_idx( + query_start_len_ptr, + target_idx, + num_seqs, + BLOCK_Q: gl.constexpr, +): + """Binary search to find the sequence index for a given query block index.""" + left = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = gl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + return left - 1 + + +@gluon.jit +def kernel_unified_attention_2d( + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] + sink_ptr, # [num_query_heads] + output_ptr, # [num_tokens, num_query_heads, head_size] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + query_start_len_ptr, # [num_seqs+1] + query_stride_0, + query_stride_1, + output_stride_0, + output_stride_1, + USE_SINKS: gl.constexpr, # bool + SLIDING_WINDOW: gl.constexpr, # int + num_blocks, + stride_k_cache_0: gl.int32, + stride_k_cache_1: gl.int32, + stride_k_cache_2: gl.int32, + stride_k_cache_3: gl.constexpr, + stride_v_cache_0: gl.int32, + stride_v_cache_1: gl.int32, + stride_v_cache_2: gl.int32, + stride_v_cache_3: gl.constexpr, + block_table_stride, + num_seqs: gl.constexpr, + SCALE: gl.constexpr, + NUM_QUERY_HEADS: gl.constexpr, + NUM_KV_HEADS: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + TILE_SIZE: gl.constexpr, + HEAD_SIZE: gl.constexpr, + BLOCK_Q: gl.constexpr, + BLOCK_M: gl.constexpr, + ARCH_NAME: gl.constexpr, + USE_LOAD_BUFFER_OP: gl.constexpr = False, + USE_STORE_BUFFER_OP: gl.constexpr = False, + ALL_DECODE: gl.constexpr = False, + USE_TDM: gl.constexpr = False, +): + NUM_WARPS: gl.constexpr = gl.num_warps() + # Workgroup offsets + kv_head_idx = gl.program_id(0) + q_block_global_idx = gl.num_programs(1) - 1 - gl.program_id(1) + # Build config with all layouts and derived constants + cfg = AttentionConfig( + ARCH_NAME, + NUM_WARPS, + HEAD_SIZE, + BLOCK_SIZE, + TILE_SIZE, + BLOCK_M, + BLOCK_Q, + NUM_QUERY_HEADS, + NUM_KV_HEADS, + SLIDING_WINDOW, + SCALE, + USE_LOAD_BUFFER_OP, + USE_STORE_BUFFER_OP, + ALL_DECODE, + ) + + # Cast strides to int64 when not using buffer ops + if not USE_LOAD_BUFFER_OP and not USE_TDM: + stride_k_cache_0 = stride_k_cache_0.to(gl.int64) + stride_k_cache_1 = stride_k_cache_1.to(gl.int64) + stride_k_cache_2 = stride_k_cache_2.to(gl.int64) + stride_v_cache_0 = stride_v_cache_0.to(gl.int64) + stride_v_cache_1 = stride_v_cache_1.to(gl.int64) + stride_v_cache_2 = stride_v_cache_2.to(gl.int64) + + if not 
USE_STORE_BUFFER_OP: + output_stride_0 = output_stride_0.to(gl.int64) + output_stride_1 = output_stride_1.to(gl.int64) + + # Find sequence index using binary search + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, cfg.BLOCK_Q + ) + + # Get query block start and local index + cur_batch_in_all_start_index = gl.load(query_start_len_ptr + seq_idx) + q_block_start_idx = cur_batch_in_all_start_index // cfg.BLOCK_Q + seq_idx + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_stop_index = gl.load(query_start_len_ptr + seq_idx + 1) + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + if q_block_local_idx * cfg.BLOCK_Q >= cur_batch_query_len: + return + + offs_m = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, cfg.q_layout)) + offs_d = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(0, cfg.q_layout)) + query_pos = q_block_local_idx * cfg.BLOCK_Q + offs_m // cfg.NUM_QUERIES_PER_KV + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = ( + kv_head_idx * cfg.NUM_QUERIES_PER_KV + offs_m % cfg.NUM_QUERIES_PER_KV + ) + + query_mask_0 = query_pos < cur_batch_query_len + query_mask_1 = query_offset_1 < NUM_QUERY_HEADS + query_mask = query_mask_0[:, None] & query_mask_1[:, None] + + q_offs = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) + + q = gl.load( + query_ptr + q_offs, + mask=query_mask, + other=0.0, + cache_modifier=cfg.Q_CACHE_MODIFIER, + ) + + seq_len = gl.load(seq_lens_ptr + seq_idx) + context_len = seq_len - cur_batch_query_len + block_tables_ptr_shifted = block_tables_ptr + seq_idx * block_table_stride + + # Max KV position that any query in this block attends to + max_seq_prefix_len = ( + context_len + + q_block_local_idx * cfg.BLOCK_Q + + (BLOCK_M - 1) // cfg.NUM_QUERIES_PER_KV + + 1 + ) + max_seq_prefix_len = gl.minimum(max_seq_prefix_len, seq_len) + + # build program + pgm = AttentionProgram.initialize( + cfg, + q, + key_cache_ptr, + value_cache_ptr, + output_ptr, + max_seq_prefix_len, + q_block_local_idx, + cur_batch_query_len, + context_len, + query_pos, + query_mask, + ) + if USE_TDM: + if TILE_SIZE == BLOCK_SIZE: + KVLoader: gl.constexpr = TDMKVLoader + else: + KVLoader: gl.constexpr = TDMGatherKVLoader + else: + gl.static_assert(TILE_SIZE == BLOCK_SIZE, "With async kv loader, TILE_SIZE must be equal to BLOCK_SIZE") + KVLoader: gl.constexpr = AsyncKVLoader + + kv_loader = KVLoader.initialize( + cfg, + key_cache_ptr, + value_cache_ptr, + block_tables_ptr_shifted, + kv_head_idx, + num_blocks, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + ) + + # Initialize accumulators + if not USE_SINKS: + M = gl.full( + [BLOCK_M], + float("-inf"), + dtype=gl.float32, + layout=gl.SliceLayout(1, cfg.pv_layout), + ) + else: + offs_m_pv = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, cfg.pv_layout)) + query_offset_1_pv = ( + kv_head_idx * cfg.NUM_QUERIES_PER_KV + offs_m_pv % cfg.NUM_QUERIES_PER_KV + ) + query_mask_1_pv = query_offset_1_pv < NUM_QUERY_HEADS + # Prescale with RCP_LN2, needed for exp2 + M = ( + gl.load( + sink_ptr + query_offset_1_pv, + mask=query_mask_1_pv, + other=float("-inf"), + ).to(dtype=gl.float32) + * cfg.RCP_LN2 + ) + + L = gl.full( + [BLOCK_M], 1.0, dtype=gl.float32, layout=gl.SliceLayout(1, cfg.pv_layout) + ) + acc = gl.zeros([BLOCK_M, HEAD_SIZE], dtype=gl.float32, layout=cfg.pv_layout) 
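+    # Online-softmax state: M is the running row maximum (in the exp2 domain,
+    # since scores are pre-scaled by log2(e) through QK_SCALE), L is the running
+    # sum of exponentials (the initial 1.0 is the sink's own contribution when
+    # USE_SINKS, and is scaled away by alpha = 0 on the first tile otherwise),
+    # and acc accumulates the unnormalized P @ V result; the epilogue divides
+    # acc by L before storing the output.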
+ # TODO (cagri): Assuming stride_k_cache_0 == stride_v_cache_0 + # Prologue: load first tile's block index and issue async K, V loads + physical_block_idx = kv_loader.load_block_ids(pgm.tile_start) + + # rotating buffer index logic + # TODO (cagri): Loop unrolling can get rid of this + buffer_id: gl.int32 = 0 + kv_loader.load_k_to_shared(physical_block_idx, buffer_id=buffer_id) + kv_loader.load_v_to_shared(physical_block_idx, buffer_id=buffer_id) + # Main attention loop over KV tiles (staged, num_stages=2) + for j in range(pgm.tile_start, pgm.tile_end - 1): + next_physical_block_idx = kv_loader.load_block_ids(j + 1) + + k = kv_loader.load_k_from_shared(wait_count=1, buffer_id=buffer_id) + + # Prefetch next tile (shared is free since k, v are in registers) + kv_loader.load_k_to_shared(next_physical_block_idx, buffer_id=1 - buffer_id) + kv_loader.load_v_to_shared(next_physical_block_idx, buffer_id=1 - buffer_id) + + # Compute attention for current tile + S = pgm.compute_qk(k) + if j >= pgm.safe_tile_end or SLIDING_WINDOW > 0: + S = pgm.apply_mask_qk(S, j) + p, alpha, M = pgm.softmax_part0(S, M) + p, L, acc = pgm.softmax_part1(p, L, acc, alpha) + v = kv_loader.load_v_from_shared(wait_count=2, buffer_id=buffer_id) + acc = pgm.compute_pv(p, v, acc) + buffer_id = 1 - buffer_id + + # Load k_i, v_i from shared into registers + k = kv_loader.load_k_from_shared(wait_count=1, buffer_id=buffer_id) + # Compute attention for current tile + S = pgm.compute_qk(k) + S = pgm.apply_mask_qk(S, pgm.tile_end - 1) + p, alpha, M = pgm.softmax_part0(S, M) + p, L, acc = pgm.softmax_part1(p, L, acc, alpha) + v = kv_loader.load_v_from_shared(wait_count=0, buffer_id=buffer_id) + acc = pgm.compute_pv(p, v, acc) + # Normalize and store output + l_recip = 1 / L[:, None] + acc = acc * l_recip + + pgm.store_output( + acc, + q_block_local_idx, + cur_batch_in_all_start_index, + kv_head_idx, + cur_batch_query_len, + output_stride_0, + output_stride_1, + ) + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + sinks, + new_kv_layout=False, + num_kv_blocks=1, + use_tdm=False, +): + """ + Run the unified attention kernel with paged KV cache. 
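+
+    Launches kernel_unified_attention_2d on a (num_kv_heads, total query blocks)
+    grid; each program covers BLOCK_Q query tokens for the query heads that share
+    one KV head and iterates over KV tiles of num_kv_blocks * BLOCK_SIZE positions.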
+ + Args: + q: Query tensor [num_tokens, num_query_heads, head_size] + k: Key cache [num_blks, blk_size, num_kv_heads, head_size] + v: Value cache [num_blks, blk_size, num_kv_heads, head_size] + out: Output tensor [num_tokens, num_query_heads, head_size] + cu_seqlens_q: Cumulative query lengths [num_seqs + 1] + seqused_k: Sequence lengths [num_seqs] + max_seqlen_q: Maximum query length + max_seqlen_k: Maximum key/value length + softmax_scale: Attention scale factor + causal: Whether to use causal masking + window_size: Sliding window size + block_table: Block tables [num_seqs, max_num_blocks_per_seq] + softcap: Softcap value + q_descale: Query scale + k_descale: Key scale + v_descale: Value scale + sinks: Sinks tensor [num_query_heads,] + """ + NUM_SEQS = len(seqused_k) + NUM_Q_HEADS = q.shape[1] + HEAD_SIZE = q.shape[2] + num_blocks = k.shape[0] + if new_kv_layout: + assert use_tdm, "With new kv layout, USE_TDM must be True" + BLOCK_SIZE = k.shape[2] + NUM_KV_HEADS = k.shape[1] + else: + assert num_kv_blocks == 1, "With original kv layout, num_kv_blocks must be 1" + BLOCK_SIZE = k.shape[1] + NUM_KV_HEADS = k.shape[2] + # if use_tdm: + # assert ARCH_NAME == "gfx1250", "With TDM, ARCH must be gfx1250" + BLOCK_M = 128 + SLIDING_WINDOW = 1 + window_size[0] + ALL_DECODE = max_seqlen_q == 1 + NUM_QUERIES_PER_KV = NUM_Q_HEADS // NUM_KV_HEADS + BLOCK_Q = BLOCK_M // NUM_QUERIES_PER_KV + total_query_blocks = q.shape[0] // BLOCK_Q + NUM_SEQS + assert num_kv_blocks & (num_kv_blocks - 1) == 0, "num_kv_blocks must be a power of 2" + TILE_SIZE = num_kv_blocks * BLOCK_SIZE + ARCH_NAME = arch_info.get_arch() + NUM_WARPS = 4 + kv_size = k.nelement() * k.element_size() + MAX_INT32 = 2**31 - 1 + USE_LOAD_BUFFER_OP = ARCH_NAME != "gfx1250" and kv_size <= MAX_INT32 + USE_STORE_BUFFER_OP = out.nelement() * out.element_size() <= MAX_INT32 + #waves_per_eu = 2 if HEAD_SIZE < 128 else 2 + waves_per_eu = 2 + grid = (NUM_KV_HEADS, total_query_blocks) + attn_kernel = kernel_unified_attention_2d[grid]( + query_ptr=q, + key_cache_ptr=k, + value_cache_ptr=v, + sink_ptr=sinks, + output_ptr=out, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + query_start_len_ptr=cu_seqlens_q, + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=SLIDING_WINDOW, + num_blocks=num_blocks, + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + block_table_stride=block_table.stride(0), + num_seqs=NUM_SEQS, + SCALE=softmax_scale, + NUM_QUERY_HEADS=NUM_Q_HEADS, + NUM_KV_HEADS=NUM_KV_HEADS, + BLOCK_SIZE=BLOCK_SIZE, + TILE_SIZE=TILE_SIZE, + HEAD_SIZE=HEAD_SIZE, + BLOCK_Q=BLOCK_Q, + BLOCK_M=BLOCK_M, + ARCH_NAME=ARCH_NAME, + waves_per_eu=waves_per_eu, + USE_LOAD_BUFFER_OP=USE_LOAD_BUFFER_OP, + USE_STORE_BUFFER_OP=USE_STORE_BUFFER_OP, + num_warps=NUM_WARPS, + ALL_DECODE=ALL_DECODE, + USE_TDM=use_tdm, + + ) + + if PRINT_IRS and getattr(unified_attention, "print", False) == False: + setattr(unified_attention, "print", True) + print_irs_to_files(attn_kernel, f"unified_attention_2d_gluon_block_m_{BLOCK_M}_tile_size_{TILE_SIZE}_block_size_{BLOCK_SIZE}_head_size_{HEAD_SIZE}") + return attn_kernel + + +def print_irs_to_files(compiled_kernel, prefix): + for key in compiled_kernel.asm.keys(): + with open(f"{prefix}_{key}.txt", 
"w") as fptr: + print(compiled_kernel.asm[key], file=fptr) diff --git a/aiter/ops/triton/gluon/unified_attention_3d.py b/aiter/ops/triton/gluon/unified_attention_3d.py new file mode 100644 index 0000000000..17b1de0afe --- /dev/null +++ b/aiter/ops/triton/gluon/unified_attention_3d.py @@ -0,0 +1,649 @@ +# The kernels in this file are adapted from vLLM: +# https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_unified_attention.py +import triton +import torch +from aiter.ops.triton.utils.device_info import get_num_sms +import math +from aiter.ops.triton.gluon.unified_attention_3d_kernel import ( + gluon_kernel_unified_attention_3d, + gluon_kernel_unified_attention_3d_async, + gluon_reduce_segments, +) +from aiter.ops.triton.gluon.unified_attention_3d_kernel_tdm import ( + gluon_kernel_unified_attention_3d_tdm, +) +from triton.experimental import gluon +import triton.experimental.gluon.language as gl +import aiter.ops.triton.utils._triton.arch_info as arch_info + +DEVICE_ARCH = arch_info.get_arch() +IS_DEVICE_ARCH_GFX12 = DEVICE_ARCH in ("gfx1250",) +WARP_SIZE = 32 if IS_DEVICE_ARCH_GFX12 else 64 +WAPR_SIZE_LOG2 = int(math.log2(WARP_SIZE)) +from aiter.ops.triton.utils.types import e4m3_dtype + + +def make_kv_cache_shuffled_layout( + BLOCK_SIZE_N_SHFL, + BLOCK_SIZE_INNER_DIM_SHFL, + fastest_dim_num_warps, + dtype=torch.bfloat16, +): + num_warps_log2 = int(math.log2(fastest_dim_num_warps)) + BLOCK_SIZE_N_SHFL_log2 = int(math.log2(BLOCK_SIZE_N_SHFL)) + BLOCK_SIZE_INNER_DIM_SHFL_log2 = int(math.log2(BLOCK_SIZE_INNER_DIM_SHFL)) + # TODO: support e4m3_dtype and mxfp4x2 + # assert dtype in [torch.bfloat16, e4m3_dtype, torch.uint8], f"Unsupported dtype: {dtype} for making linear layout for shuffled weights" + assert dtype in [ + torch.bfloat16 + ], f"Unsupported dtype: {dtype} for making linear layout for shuffled weights" + if dtype == torch.bfloat16: + # (8 elements per thread for BF16) + coalesced_size_log2 = 3 + elif dtype == e4m3_dtype: + # (16 elements per thread for e4m3_dtype) + coalesced_size_log2 = 4 + else: + # (16*2 elements per thread for mxfp4x2) + coalesced_size_log2 = 4 + assert ( + BLOCK_SIZE_INNER_DIM_SHFL_log2 > coalesced_size_log2 + WAPR_SIZE_LOG2 + ), "BLOCK_SIZE_INNER_DIM_SHFL_log2 must be greater than coalesced_size_log2 + WAPR_SIZE_LOG2, please increase block_size to at least 64" + reg_bases = ( + [[0, 1 << v] for v in range(coalesced_size_log2)] + + [ + [0, 1 << v] + for v in range( + coalesced_size_log2 + WAPR_SIZE_LOG2, BLOCK_SIZE_INNER_DIM_SHFL_log2 + ) + ] + + [[1 << v, 0] for v in range(num_warps_log2, BLOCK_SIZE_N_SHFL_log2)] + ) + lane_bases = [ + [0, 1 << v] + for v in range(coalesced_size_log2, coalesced_size_log2 + WAPR_SIZE_LOG2) + ] + if num_warps_log2 > 0: + warp_bases = [[1 << v, 0] for v in range(0, num_warps_log2)] + else: + warp_bases = [[0, 0]] + + layout = gl.constexpr( + gl.DistributedLinearLayout( + reg_bases=reg_bases, + lane_bases=lane_bases, + warp_bases=warp_bases, + block_bases=[], + shape=[BLOCK_SIZE_N_SHFL, BLOCK_SIZE_INNER_DIM_SHFL], + ) + ) + return layout + + +def make_layout_3d( + num_warps: int, + BLOCK_M: int, + TILE_SIZE: int, + BLOCK_SIZE: int, + NUM_BLOCKS_GATHER_PER_TILE: int, + HEAD_SIZE_PADDED: int, + shuffled_kv_cache: bool, + kv_cache_dtype: torch.dtype, + use_tdm: bool, + use_async: bool, + use_swizzle: bool = False, + use_gather: bool = False, +): + + if IS_DEVICE_ARCH_GFX12: + QK_WMMA_LAYOUT: gl.constexpr = gl.amd.AMDWMMALayout( + version=3, + transposed=True, + warp_bases=[(1 << i, 0) for i in 
range(int(math.log2(num_warps)))], + reg_bases=[], + instr_shape=[16, 16, 32], + ) + + PV_WMMA_LAYOUT: gl.constexpr = gl.amd.AMDWMMALayout( + version=3, + transposed=True, + warp_bases=[(0, 1 << i) for i in range(int(math.log2(num_warps)))], + reg_bases=[], + instr_shape=[16, 16, 32], + ) + Q_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=0, parent=QK_WMMA_LAYOUT, k_width=8 + ) + K_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=1, parent=QK_WMMA_LAYOUT, k_width=8 + ) + P_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=0, parent=PV_WMMA_LAYOUT, k_width=8 + ) + V_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=1, parent=PV_WMMA_LAYOUT, k_width=8 + ) + elif shuffled_kv_cache: + QK_WMMA_LAYOUT: gl.constexpr = gl.amd.AMDMFMALayout( + version=4, + instr_shape=[16, 16, 32], + transposed=True, + warps_per_cta=[num_warps, 1] if use_async else [1, num_warps], + # warps_per_cta=[1, num_warps], + # warps_per_cta=[num_warps, 1], + ) + PV_WMMA_LAYOUT: gl.constexpr = gl.amd.AMDMFMALayout( + version=4, + instr_shape=[16, 16, 16 if TILE_SIZE <= 16 else 32], + transposed=True, + warps_per_cta=[1, num_warps], + ) + Q_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=0, parent=QK_WMMA_LAYOUT, k_width=8 + ) + K_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=1, parent=QK_WMMA_LAYOUT, k_width=8 + ) + P_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=0, parent=PV_WMMA_LAYOUT, k_width=4 if TILE_SIZE <= 16 else 8 + ) + V_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=1, parent=PV_WMMA_LAYOUT, k_width=4 if TILE_SIZE <= 16 else 8 + ) + else: + QK_WMMA_LAYOUT: gl.constexpr = gl.amd.AMDMFMALayout( + version=4, + instr_shape=[16, 16, 32], + transposed=True, + warps_per_cta=[num_warps, 1], + ) + PV_WMMA_LAYOUT: gl.constexpr = gl.amd.AMDMFMALayout( + version=4, + instr_shape=[16, 16, 16 if TILE_SIZE <= 16 else 32], + transposed=True, + warps_per_cta=[1, num_warps], + ) + Q_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=0, parent=QK_WMMA_LAYOUT, k_width=8 + ) + K_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=1, parent=QK_WMMA_LAYOUT, k_width=8 + ) + P_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=0, parent=PV_WMMA_LAYOUT, k_width=4 + ) + V_DOT_LAYOUT: gl.constexpr = gl.DotOperandLayout( + operand_index=1, parent=PV_WMMA_LAYOUT, k_width=4 + ) + + if use_tdm or not use_swizzle: + Q_SHARED_LAYOUT: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + interval_padding_pairs=[[HEAD_SIZE_PADDED, 8]], + shape=[BLOCK_M, HEAD_SIZE_PADDED], + order=[1, 0], + ) + if use_gather: + K_SHARED_LAYOUT: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + interval_padding_pairs=[[BLOCK_SIZE * HEAD_SIZE_PADDED, 8]], + shape=( + [NUM_BLOCKS_GATHER_PER_TILE, BLOCK_SIZE * HEAD_SIZE_PADDED] + if use_tdm + else [HEAD_SIZE_PADDED, TILE_SIZE] + ), + order=[1, 0], + ) + V_SHARED_LAYOUT: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + interval_padding_pairs=[[BLOCK_SIZE * HEAD_SIZE_PADDED, 8]], + shape=[NUM_BLOCKS_GATHER_PER_TILE, BLOCK_SIZE * HEAD_SIZE_PADDED], + order=[1, 0], + ) + else: + K_SHARED_LAYOUT: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + interval_padding_pairs=[[HEAD_SIZE_PADDED, 8]], + shape=( + [TILE_SIZE, HEAD_SIZE_PADDED] + if use_tdm + else [HEAD_SIZE_PADDED, TILE_SIZE] + ), + order=[1, 0], + ) + V_SHARED_LAYOUT: gl.constexpr = gl.PaddedSharedLayout.with_identity_for( + interval_padding_pairs=[[HEAD_SIZE_PADDED, 8]], + 
shape=[TILE_SIZE, HEAD_SIZE_PADDED], + order=[1, 0], + ) + else: + Q_SHARED_LAYOUT: gl.constexpr = gl.SwizzledSharedLayout( + vec=8, per_phase=2, max_phase=8, order=[1, 0] + ) + if shuffled_kv_cache: + K_SHARED_LAYOUT: gl.constexpr = gl.SwizzledSharedLayout( + vec=1, per_phase=1, max_phase=1, order=[1, 0] + ) + V_SHARED_LAYOUT: gl.constexpr = gl.SwizzledSharedLayout( + vec=1, per_phase=1, max_phase=1, order=[1, 0] + ) + else: + K_SHARED_LAYOUT: gl.constexpr = gl.SwizzledSharedLayout( + vec=8, per_phase=2, max_phase=8, order=[0, 1] + ) + V_SHARED_LAYOUT: gl.constexpr = gl.SwizzledSharedLayout( + vec=1, per_phase=1, max_phase=1, order=[1, 0] + ) + + # size_per_thread along the fastest moving dimension is set to 8 (BF16) + size_per_thread_fastest_dim = 8 + + # size_per_thread * threads_per_warp along the fastest moving dimension is set to HEAD_SIZE_PADDED with only 1 warp_per_cta, + # therefore, threads_per_warp along the fastest moving dimension should be HEAD_SIZE_PADDED // size_per_thread_fastest_dim + # clamp the threads_per_warp along the fastest moving dimension to 1 ~ WARP_SIZE + threads_per_warp_fastest_dim = max( + min((HEAD_SIZE_PADDED // size_per_thread_fastest_dim), WARP_SIZE), 1 + ) + + # in gfx950, ttg.async_copy_global_to_local will fail if threads_per_warp=[WARP_SIZE//4, 4] is used + Q_LOAD_LAYOUT: gl.constexpr = gl.BlockedLayout( + size_per_thread=[1, size_per_thread_fastest_dim], + threads_per_warp=[ + WARP_SIZE // threads_per_warp_fastest_dim, + threads_per_warp_fastest_dim, + ], + warps_per_cta=[num_warps, 1], + order=[1, 0], + ) + if shuffled_kv_cache: + K_LOAD_LAYOUT = make_kv_cache_shuffled_layout( + TILE_SIZE // 16, + HEAD_SIZE_PADDED * 16, + 1 if (use_async or IS_DEVICE_ARCH_GFX12) else num_warps, + ) + V_LOAD_LAYOUT = make_kv_cache_shuffled_layout( + HEAD_SIZE_PADDED // 16, TILE_SIZE * 16, num_warps + ) + else: + K_LOAD_LAYOUT: gl.constexpr = gl.BlockedLayout( + size_per_thread=[size_per_thread_fastest_dim, 1], + threads_per_warp=[ + threads_per_warp_fastest_dim, + WARP_SIZE // threads_per_warp_fastest_dim, + ], + warps_per_cta=[1, num_warps], + order=[0, 1], + ) + V_LOAD_LAYOUT: gl.constexpr = gl.BlockedLayout( + size_per_thread=[1, size_per_thread_fastest_dim], + threads_per_warp=[ + WARP_SIZE // threads_per_warp_fastest_dim, + threads_per_warp_fastest_dim, + ], + warps_per_cta=[num_warps, 1], + order=[1, 0], + ) + + # TODO: for future impl + # ctas_per_cga = [1, 1] + # cga_layout_Q = make_cga_layout( + # ctasPerCga=ctas_per_cga, + # ctaSplitNum=[ctas_per_cga[0], 1], + # ctaOrder=[0, 1] + # ) + # cga_layout_K = make_cga_layout( + # ctasPerCga=ctas_per_cga, + # ctaSplitNum=[1, ctas_per_cga[1]], + # ctaOrder=[0, 1] + # ) + # cga_layout_S = make_cga_layout( + # ctasPerCga=ctas_per_cga, + # ctaSplitNum=[ctas_per_cga[0], ctas_per_cga[1]], + # ctaOrder=[0, 1] + # ) + + return { + "QK_WMMA_LAYOUT": QK_WMMA_LAYOUT, + "PV_WMMA_LAYOUT": PV_WMMA_LAYOUT, + "Q_DOT_LAYOUT": Q_DOT_LAYOUT, + "K_DOT_LAYOUT": K_DOT_LAYOUT, + "P_DOT_LAYOUT": P_DOT_LAYOUT, + "V_DOT_LAYOUT": V_DOT_LAYOUT, + "Q_SHARED_LAYOUT": Q_SHARED_LAYOUT, + "K_SHARED_LAYOUT": K_SHARED_LAYOUT, + "V_SHARED_LAYOUT": V_SHARED_LAYOUT, + "Q_LOAD_LAYOUT": Q_LOAD_LAYOUT, + "K_LOAD_LAYOUT": K_LOAD_LAYOUT, + "V_LOAD_LAYOUT": V_LOAD_LAYOUT, + } + + +def select_3d_config( + head_size, + block_size, + element_size, + max_seqlen_k, + target_num_prgms, + num_2d_prgms, + BLOCK_M: int, + HEAD_SIZE_PADDED: int, + kv_cache_dtype: torch.dtype, + use_tdm: bool = False, + num_tdm_gather: int = 1, + use_async: bool = True, + use_swizzle: 
bool = True, + shuffled_kv_cache: bool = False, +): + """ + if use_tdm is True, use_async and use_swizzle will be ignored + if use_tdm is False, num_tdm_gather will be ignored + if use_async is True, use_swizzle will be forced to True + if use_tdm and use_async are False, num_stages will be ignored, use_swizzle determines whether to use PaddedSharedLayout or SwizzledSharedLayout + """ + reduce_num_warps = 2 + attn_warps = 2 + + if shuffled_kv_cache: + assert ( + block_size >= 64 + ), "Only block_size >= 64 is supported for shuffled KV cache" + + if use_tdm and num_tdm_gather > 1: + assert num_tdm_gather in [4, 8], "num_tdm_gather must be 4 or 8" + + TILE_SIZE = block_size * num_tdm_gather + + MAX_SEGMENTS = min(128, math.ceil(max_seqlen_k / TILE_SIZE)) + num_segments = math.ceil(target_num_prgms / num_2d_prgms) + num_segments = min(num_segments, MAX_SEGMENTS) + num_segments = triton.next_power_of_2(num_segments) + num_segments = min(num_segments, 128) + MIN_SEGMENTS = 16 if TILE_SIZE <= 16 else 8 + num_segments = max(num_segments, MIN_SEGMENTS) + + # TODO: needs a better way to determine num_segments for TDM gather pipelined + if use_tdm and num_tdm_gather > 1: + num_segments = 4 + + if num_segments == MIN_SEGMENTS: + reduce_num_warps = 1 + + config_parms = ( + attn_warps, + BLOCK_M, + TILE_SIZE, + block_size, + num_tdm_gather, + HEAD_SIZE_PADDED, + shuffled_kv_cache, + kv_cache_dtype, + use_tdm, + use_async, + ) + + # num_tiles_per_seq = (max_seqlen_k // num_segments + TILE_SIZE - 1) // TILE_SIZE + + attn_stages = 1 + if use_tdm: + # With TDM async_copy pipelined, use_swizzle will be ignored (padded smem layout is used always) + attn_impl = gluon_kernel_unified_attention_3d_tdm + layout_configs = {"NUM_BLOCKS_GATHER_PER_TILE": num_tdm_gather} + attn_stages = 2 + else: + if use_async: + # With async_copy pipelined, use_swizzle should always be True + attn_impl = gluon_kernel_unified_attention_3d_async + layout_configs = make_layout_3d( + *config_parms, + use_swizzle=True, + ) + # gfx12 does not have async_copy.buffer_load_to_shared + # TODO: check KV cache size to determine if use_buffer_load is needed in gfx950 + layout_configs["USE_LOAD_BUFFER_OP"] = not IS_DEVICE_ARCH_GFX12 + attn_stages = 2 + else: + # Baseline kernel, num_stages does not matter, use_swizzle can be either True or False + attn_impl = gluon_kernel_unified_attention_3d + layout_configs = make_layout_3d( + *config_parms, + use_swizzle=use_swizzle, + ) + layout_configs["TILE_SIZE"] = TILE_SIZE + + attn_config = { + "NUM_SEGMENTS_PER_SEQ": num_segments, + "WARP_SIZE": WARP_SIZE, + "num_warps": attn_warps, + "num_stages": attn_stages, + "waves_per_eu": 2, + **layout_configs, + } + + reduce_config = { + "TILE_SIZE": TILE_SIZE, + "NUM_SEGMENTS_PER_SEQ": num_segments, + "num_warps": reduce_num_warps, + "num_stages": 1, + "waves_per_eu": 2, + } + + return attn_config, reduce_config, attn_impl + + +def use_2d_kernel( + head_size, + sliding_window, + all_decode, + max_seqlen_q, + max_seqlen_k, + target_num_prgms, + num_2d_prgms, +): + return ( + (sliding_window > 0) + or (max_seqlen_k <= 512) + or (num_2d_prgms > target_num_prgms) + ) + + +def unified_attention( + q, + k, + v, + out, + cu_seqlens_q, + max_seqlen_q, + seqused_k, + max_seqlen_k, + softmax_scale, + causal, + window_size, + block_table, + softcap, + q_descale, + k_descale, + v_descale, + alibi_slopes=None, + output_scale=None, + qq_bias=None, + # Optional tensor for sinks + sinks=None, + use_tdm: bool = False, + num_tdm_gather: int = 1, + use_async: bool = True, + 
shuffled_kv_cache: bool = False, +): + assert causal, "Only causal attention is supported" + assert q_descale is None, "Q scales not supported" + + if sinks is not None: + assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size" + + use_alibi_slopes = alibi_slopes is not None + use_qq_bias = qq_bias is not None + SLIDING_WINDOW = 1 + window_size[0] + + num_tokens, num_query_heads, head_size = q.shape + kv_cache_dtype = k.dtype + if shuffled_kv_cache: + # key_cache: num_blocks, num_kv_heads, block_size // 16, head_size * 16 + # value_cache: num_blocks, num_kv_heads, head_size // 16, block_size * 16 + num_blocks, num_kv_heads, block_size, _ = k.shape + block_size = block_size * 16 + else: + if use_tdm and num_tdm_gather > 1: + # key_cache and value_cache: num_blocks, num_kv_heads, block_size, head_size + num_blocks, num_kv_heads, block_size, _ = k.shape + else: + # key_cache and value_cache: num_blocks, block_size, num_kv_heads, head_size + num_blocks, block_size, num_kv_heads, _ = k.shape + + num_seqs = len(seqused_k) + num_queries_per_kv = num_query_heads // num_kv_heads + + BLOCK_M = ( + 16 if num_queries_per_kv <= 16 else triton.next_power_of_2(num_queries_per_kv) + ) + BLOCK_Q = BLOCK_M // num_queries_per_kv + assert BLOCK_Q >= 1 + # Ideally we would launch with kernel with: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] blocks. + # However, it is slow to realize the query_lens on cpu. + # Instead we use upper-bound: + # \sum_i[ceil(query_len[i] / BLOCK_Q)] + # <= \sum_i[floor(query_len[i] / BLOCK_Q) + 1] + # = \sum_i[floor(query_len[i] / BLOCK_Q)] + num_seqs + # <= floor(\sum_i(query_len[i]) / BLOCK_Q) + num_seqs + # = floor(num_tokens / BLOCK_Q) + num_seqs + cu_count = get_num_sms() + total_num_q_blocks = num_tokens // BLOCK_Q + num_seqs + target_num_prgms = cu_count * 4 + num_2d_prgms = total_num_q_blocks * num_kv_heads + ALL_DECODE = max_seqlen_q == 1 + # if batch contains a prefill + if use_2d_kernel( + head_size, + SLIDING_WINDOW, + ALL_DECODE, + max_seqlen_q, + max_seqlen_k, + target_num_prgms, + num_2d_prgms, + ): + raise NotImplementedError("2D Gluon Unified Attention is not yet implemented.") + else: + head_size_padded = triton.next_power_of_2(head_size) + assert head_size_padded == head_size, "head_size must be a power of 2" + + if not IS_DEVICE_ARCH_GFX12: + assert use_tdm == False, "TDM is not supported on non-GFX12 devices" + + use_swizzle = None + if use_tdm == True: # TDM + use_async = None + elif use_async == True: # ASYNC + pass + else: # Baseline (use_swizzle can be either True or False, fix to True for now) + use_swizzle = True + + attn_config, reduce_config, attn_impl = select_3d_config( + head_size, + block_size, + q.element_size(), + max_seqlen_k, + target_num_prgms, + num_2d_prgms, + BLOCK_M, + head_size_padded, + kv_cache_dtype, + use_tdm, + num_tdm_gather, + use_async, + use_swizzle, + shuffled_kv_cache, + ) + NUM_SEGMENTS = attn_config["NUM_SEGMENTS_PER_SEQ"] + segm_output = torch.empty( + num_tokens, + num_query_heads, + NUM_SEGMENTS, + triton.next_power_of_2(head_size), + dtype=torch.float32, + device=q.device, + ) + segm_max = torch.empty( + num_tokens, + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + segm_expsum = torch.empty( + num_tokens, + num_query_heads, + NUM_SEGMENTS, + dtype=torch.float32, + device=q.device, + ) + + attn_impl[(total_num_q_blocks, num_kv_heads, NUM_SEGMENTS)]( + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + query_ptr=q, + key_cache_ptr=k, + 
value_cache_ptr=v, + sink_ptr=sinks, + block_tables_ptr=block_table, + seq_lens_ptr=seqused_k, + alibi_slopes_ptr=alibi_slopes, + qq_bias_ptr=qq_bias, + k_scale=k_descale, + v_scale=v_descale, + softcap=softcap, + num_seqs=num_seqs, + num_blocks=num_blocks, + block_table_stride=block_table.stride(0), + query_stride_0=q.stride(0), + query_stride_1=q.stride(1), + qq_bias_stride_0=qq_bias.stride(0) if use_qq_bias else 0, + BLOCK_SIZE=block_size, + HEAD_SIZE=head_size, + USE_ALIBI_SLOPES=use_alibi_slopes, + USE_QQ_BIAS=use_qq_bias, + USE_SOFTCAP=(softcap > 0), + USE_SINKS=(sinks is not None), + SLIDING_WINDOW=SLIDING_WINDOW, + stride_k_cache_0=k.stride(0), + stride_k_cache_1=k.stride(1), + stride_k_cache_2=k.stride(2), + stride_k_cache_3=k.stride(3), + stride_v_cache_0=v.stride(0), + stride_v_cache_1=v.stride(1), + stride_v_cache_2=v.stride(2), + stride_v_cache_3=v.stride(3), + query_start_len_ptr=cu_seqlens_q, + SCALE=softmax_scale, + NUM_QUERY_HEADS=num_query_heads, + NUM_KV_HEADS=num_kv_heads, + BLOCK_Q=BLOCK_Q, + BLOCK_M=BLOCK_M, + ALL_DECODE=ALL_DECODE, + SHUFFLED_KV_CACHE=shuffled_kv_cache, + **attn_config, + ) + + gluon_reduce_segments[(q.shape[0], num_query_heads)]( + output_ptr=out, + segm_output_ptr=segm_output, + segm_max_ptr=segm_max, + segm_expsum_ptr=segm_expsum, + seq_lens_ptr=seqused_k, + num_seqs=num_seqs, + num_query_heads=num_query_heads, + out_scale_inv=1 / output_scale if output_scale is not None else 1.0, + output_stride_0=out.stride(0), + output_stride_1=out.stride(1), + block_table_stride=block_table.stride(0), + HEAD_SIZE=head_size, + query_start_len_ptr=cu_seqlens_q, + BLOCK_Q=BLOCK_Q, + USE_FP8=output_scale is not None, + **reduce_config, + ) diff --git a/aiter/ops/triton/gluon/unified_attention_3d_kernel.py b/aiter/ops/triton/gluon/unified_attention_3d_kernel.py new file mode 100644 index 0000000000..834cc9209b --- /dev/null +++ b/aiter/ops/triton/gluon/unified_attention_3d_kernel.py @@ -0,0 +1,2885 @@ +# The kernels in this file are adapted from vLLM: +# https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_unified_attention.py +from re import T +from shlex import join +import triton +import triton.language as tl +import torch +from aiter.ops.triton.utils.types import e4m3_dtype +from triton.experimental import gluon +import triton.experimental.gluon.language as gl +import aiter.ops.triton.utils._triton.arch_info as arch_info +from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr + +# from triton._C.libtriton.gluon_ir import make_cga_layout + +DEVICE_ARCH = arch_info.get_arch() +MMA_operation: gl.constexpr = ( + gl.amd.gfx1250.wmma + if gl.constexpr(DEVICE_ARCH in ("gfx1250",)) + else gl.amd.cdna4.mfma +) + +float8_info = torch.finfo(e4m3_dtype) + + +@gluon.jit +def fast_exp(x): + RCP_LN2: tl.constexpr = 1.4426950408889634 + return tl.math.exp2(x * RCP_LN2) + + +@gluon.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@gluon.jit +def apply_softcap(S, x): + Sdiv = S / x + p1 = tl.math.exp2(Sdiv) + p2 = tl.math.exp2(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@gluon.jit +def _find_seq_idx( + query_start_len_ptr, + target_idx, + num_seqs, + BLOCK_Q: gl.constexpr, + use_q_block_mode: gl.constexpr, +): + left: gl.int32 = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = gl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + + return left - 1 + + +@gluon.jit +def _get_q_metadata( + 
query_start_len_ptr, + seq_idx, + q_block_global_idx, + BLOCK_Q: gl.constexpr, +): + q_block_start_idx = gl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = gl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = gl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + return q_block_local_idx, cur_batch_query_len, cur_batch_in_all_start_index + + +@gluon.jit +def _get_seq_metadata( + seq_lens_ptr, + seq_idx, + TILE_SIZE: gl.constexpr, + NUM_SEGMENTS_PER_SEQ: gl.constexpr, +): + # sequence len for this particular sequence + seq_len = gl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + + return seq_len, tiles_per_segment + + +@gluon.jit +def _allocate_L_M_acc( + sink_ptr, + segm_idx, + query_offset_1, + query_mask_1, + RCP_LN2, + BLOCK_M: gl.constexpr, + HEAD_SIZE: gl.constexpr, + QK_WMMA_LAYOUT: gl.constexpr, + PV_WMMA_LAYOUT: gl.constexpr, + USE_SINKS: gl.constexpr, +): + + # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + if USE_SINKS: + if segm_idx == 0: + # Prescale with RCP_LN2, needed for exp2 + M = ( + gl.amd.cdna4.buffer_load( + ptr=sink_ptr, + offsets=query_offset_1.to(gl.int32), + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=gl.float32) + * RCP_LN2 + ) + else: + M = gl.full( + [BLOCK_M], + float("-inf"), + dtype=tl.float32, + layout=gl.SliceLayout(1, QK_WMMA_LAYOUT), + ) + else: + M = gl.full( + [BLOCK_M], + float("-inf"), + dtype=tl.float32, + layout=gl.SliceLayout(1, QK_WMMA_LAYOUT), + ) + + # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + L = gl.full( + [BLOCK_M], 1.0, dtype=tl.float32, layout=gl.SliceLayout(1, QK_WMMA_LAYOUT) + ) + # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT + acc = gl.zeros([BLOCK_M, HEAD_SIZE], dtype=tl.float32, layout=PV_WMMA_LAYOUT) + + return L, M, acc + + +@gluon.jit +def _perform_QK_wmma_and_update_L_M( + Q, + K, + L, + M, + acc, + qq_bias_row_ptrs, + seq_offset, + query_mask, + query_pos, + context_len, + alibi_slope, + qq_bias_stride_0, + qk_scale, + softcap, + RCP_LN2, + BLOCK_M: gl.constexpr, + TILE_SIZE: gl.constexpr, + USE_SOFTCAP: gl.constexpr, + SLIDING_WINDOW: gl.constexpr, + USE_ALIBI_SLOPES: gl.constexpr, + USE_QQ_BIAS: gl.constexpr, + Q_LOAD_LAYOUT: gl.constexpr, + QK_WMMA_LAYOUT: gl.constexpr, + PV_WMMA_LAYOUT: gl.constexpr, +): + # S : shape = (BLOCK_M, TILE_SIZE), layout = QK_WMMA_LAYOUT + S = gl.zeros([BLOCK_M, TILE_SIZE], dtype=tl.float32, layout=QK_WMMA_LAYOUT) + # qk_scale = scale * RCP_LN2 (log_2 e) so that we can use exp2 later + S = qk_scale * MMA_operation(Q, K, S) + # S : shape = (BLOCK_M, TILE_SIZE), layout = Q_LOAD_LAYOUT + # S = gl.convert_layout(S, layout=Q_LOAD_LAYOUT) + + if USE_SOFTCAP: + # softcap here uses exp2 and consumes RCP_LN2 conversion. + # multiply by RCP_LN2 again to be used in later exp2 + S = apply_softcap(S, softcap) * RCP_LN2 + + seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1 + + S = gl.where(query_mask & seq_mask, S, float("-inf")) + + if SLIDING_WINDOW > 0: + S = gl.where( + (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW, + S, + float("-inf"), + ) + + if USE_ALIBI_SLOPES: + # prescale w. 
RCP_LN2 for later exp2 + S += alibi_slope[:, None] * (seq_offset - context_len) * RCP_LN2 + + if USE_QQ_BIAS: + # compute key positions relative to query section + key_rel_pos = seq_offset - context_len # shape: [BLOCK_SIZE] + # load bias only for keys that correspond to queries + is_query_key = key_rel_pos >= 0 and key_rel_pos < qq_bias_stride_0 + qq_bias = gl.load( + qq_bias_row_ptrs + key_rel_pos[None, :], + mask=is_query_key[None, :], # avoid OOB for context keys + other=0.0, + ) + # prescale w. RCP_LN2 for later exp2 + S += qq_bias * RCP_LN2 + + # compute running maximum + # m_j : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + m_j = gl.maximum(M, gl.max(S, axis=1)) + + # For sliding window there's a chance the max is -inf due to masking of + # the entire row. In this case we need to set m_j 0 to avoid NaN + m_j = gl.where(m_j > float("-inf"), m_j, 0.0) + + # P : shape = (BLOCK_M, TILE_SIZE), layout = Q_LOAD_LAYOUT + P = gl.exp2(S - m_j[:, None]) + + # l_j : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + l_j = gl.sum(P, axis=1) + + # alpha : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + alpha = gl.exp2(M - m_j) + + # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT + acc = acc * gl.convert_layout(alpha[:, None], layout=PV_WMMA_LAYOUT) + + # update constants + # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + L = L * alpha + l_j + M = m_j + + return P, L, M, acc + + +@gluon.jit +def _perform_PV_wmma( + P, + V, + acc, + P_DOT_LAYOUT: gl.constexpr, +): + P = P.to(V.dtype) + P = gl.convert_layout(P, layout=P_DOT_LAYOUT) + # P : shape = (BLOCK_M, TILE_SIZE), layout = P_DOT_LAYOUT + # V : shape = (TILE_SIZE, HEAD_SIZE), layout = V_DOT_LAYOUT + # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT + acc = MMA_operation(P, V, acc) + return acc + + +# @gluon.jit +# def _tdm_async_gather_load_to_lds( +# j, +# desc, +# src_row_indices, +# src_col_offset, +# dst, +# num_stages: gl.constexpr, +# ): +# gl.amd.gfx1250.tdm.async_gather( +# desc=desc, +# src_row_indices=src_row_indices, +# src_col_offset=0, +# dst=dst.index(j % num_stages), +# ) + +# return j + 1 + + +# @gluon.jit +# def _tdm_gather_request_from_lds( +# j, +# kv_scale, +# Q_dtype, +# smem, +# asycn_wait: gl.constexpr, +# layout: gl.constexpr, +# transpose: gl.constexpr, +# num_ctas: gl.constexpr, +# num_stages: gl.constexpr, +# TILE_SIZE: gl.constexpr, +# HEAD_SIZE: gl.constexpr, +# ): +# if num_ctas > 1: +# gl.amd.gfx1250.cluster.arrive() +# gl.amd.gfx1250.tdm.async_wait(asycn_wait) +# if num_ctas > 1: +# gl.amd.gfx1250.cluster.wait() +# if transpose: +# X = ( +# smem.index(j % num_stages) +# .reshape([TILE_SIZE, HEAD_SIZE]) +# .permute([1, 0]) +# .load(layout=layout) +# ) +# else: +# X = ( +# smem.index(j % num_stages) +# .reshape([TILE_SIZE, HEAD_SIZE]) +# .load(layout=layout) +# ) + +# if X.dtype.is_fp8() and not Q_dtype.is_fp8(): +# X = (X.to(gl.float32) * gl.load(kv_scale)).to(Q_dtype) + +# return j + 1, X + + +# @gluon.jit +# def _tdm_gather_get_kv_offsets( +# j, +# offs_j, +# kv_head_idx, +# block_tables_sorted_ptr, +# block_table_offset, +# stride_k_cache_h: gl.int64, +# stride_v_cache_h: gl.int64, +# NUM_BLOCKS_GATHER_PER_TILE: gl.constexpr, +# ): +# physical_block_idx = gl.load( +# block_tables_sorted_ptr +# + block_table_offset +# + j * NUM_BLOCKS_GATHER_PER_TILE +# + offs_j +# ) + +# offs_k_gather_idx = (physical_block_idx * stride_k_cache_h + kv_head_idx).to( 
+# tl.int32 +# ) +# offs_v_gather_idx = (physical_block_idx * stride_v_cache_h + kv_head_idx).to( +# tl.int32 +# ) + +# return j + 1, offs_k_gather_idx, offs_v_gather_idx + + +# @gluon.jit +# def _tdm_gather_create_tensor_descriptors_and_allocate_lds( +# q_ptr, +# k_ptr, +# v_ptr, +# NUM_BLOCKS, +# stride_q_m: gl.int64, # int +# stride_q_d: gl.constexpr, # int +# stride_k_t: gl.int64, # int +# stride_k_d: gl.constexpr, # int +# stride_v_t: gl.int64, # int +# stride_v_d: gl.constexpr, # int +# q_shared_layout: gl.constexpr, +# k_shared_layout: gl.constexpr, +# v_shared_layout: gl.constexpr, +# NUM_KV_HEADS: gl.constexpr, +# BLOCK_M: gl.constexpr, +# HEAD_SIZE: gl.constexpr, +# BLOCK_SIZE: gl.constexpr, +# TILE_SIZE: gl.constexpr, +# HEAD_SIZE: gl.constexpr, +# NUM_BLOCKS_GATHER_PER_TILE: gl.constexpr, +# num_stages: gl.constexpr, +# ): +# gl.static_assert(stride_k_d == 1, "stride_k_d must be 1") +# gl.static_assert(stride_v_d == 1, "stride_v_d must be 1") +# # gl.static_assert(stride_k_t == BLOCK_SIZE * HEAD_SIZE, "stride_k_t must be BLOCK_SIZE * HEAD_SIZE") +# # gl.static_assert(stride_v_t == BLOCK_SIZE * HEAD_SIZE, "stride_v_t must be BLOCK_SIZE * HEAD_SIZE") + +# k_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( +# base=k_ptr, +# shape=(NUM_BLOCKS * NUM_KV_HEADS, BLOCK_SIZE * HEAD_SIZE), +# strides=(stride_k_t, stride_k_d), +# block_shape=(NUM_BLOCKS_GATHER_PER_TILE, BLOCK_SIZE * HEAD_SIZE), +# layout=k_shared_layout, +# ) + +# v_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( +# base=v_ptr, +# shape=(NUM_BLOCKS * NUM_KV_HEADS, BLOCK_SIZE * HEAD_SIZE), +# strides=(stride_v_t, stride_v_d), +# block_shape=(NUM_BLOCKS_GATHER_PER_TILE, BLOCK_SIZE * HEAD_SIZE), +# layout=v_shared_layout, +# ) + +# smem_Q = gl.allocate_shared_memory( +# q_ptr.type.element_ty, +# shape=[BLOCK_M, HEAD_SIZE], +# layout=q_shared_layout, +# ) +# smem_K = gl.allocate_shared_memory( +# k_desc.dtype, +# shape=[num_stages] + k_desc.block_shape, +# layout=k_desc.layout, +# ) +# smem_V = gl.allocate_shared_memory( +# v_desc.dtype, +# shape=[num_stages] + v_desc.block_shape, +# layout=v_desc.layout, +# ) + +# return k_desc, v_desc, smem_Q, smem_K, smem_V + + +# gluon_kernel_unified_attention_3d_tdm_gather_repr = make_kernel_repr( +# "gluon_kernel_unified_attention_3d_tdm_gather", +# [ +# "num_query_heads", +# "num_queries_per_kv", +# "BLOCK_SIZE", +# "TILE_SIZE", +# "HEAD_SIZE", +# "num_warps", +# "num_stages", +# "cache_modifier", +# ], +# ) + + +# @gluon.jit(repr=gluon_kernel_unified_attention_3d_tdm_gather_repr) +# def gluon_kernel_unified_attention_3d_tdm_gather( +# segm_output_ptr, +# # [num_tokens, num_query_heads, num_segments, head_size] +# segm_max_ptr, # [num_tokens, num_query_heads, num_segments] +# segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] +# query_ptr, # [num_tokens, num_query_heads, head_size] +# key_cache_ptr, # [num_blks, num_kv_heads, blk_size, head_size] +# value_cache_ptr, # [num_blks, num_kv_heads, blk_size, head_size] +# sink_ptr, # [num_query_heads] +# block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] +# # block_table_sorted_indices_ptr, # [num_seqs, max_num_blocks_per_seq] +# seq_lens_ptr, # [num_seqs] +# alibi_slopes_ptr, # [num_query_heads] +# qq_bias_ptr, # [num_query_tokens, num_query_tokens] +# scale, # float32 +# k_scale, # float32 +# v_scale, # float32 +# softcap, # float32 +# num_tokens, # int +# NUM_BLOCKS, # int +# num_query_heads: gl.constexpr, # int +# num_queries_per_kv: gl.constexpr, # int +# block_table_stride: gl.int64, # int +# # 
block_table_sorted_indices_stride: gl.int64, # int +# query_stride_0: gl.int64, # int +# query_stride_1: gl.int64, # int, should be equal to head_size +# qq_bias_stride_0: gl.int64, # int +# BLOCK_SIZE: gl.constexpr, # int +# TILE_SIZE: gl.constexpr, # int, must be power of 2 +# HEAD_SIZE: gl.constexpr, # int +# HEAD_SIZE: gl.constexpr, # int, must be power of 2 +# USE_ALIBI_SLOPES: gl.constexpr, # bool +# USE_QQ_BIAS: gl.constexpr, # bool +# USE_SOFTCAP: gl.constexpr, # bool +# USE_SINKS: gl.constexpr, # bool +# SLIDING_WINDOW: gl.constexpr, # int +# stride_k_cache_0: gl.int64, # int +# stride_k_cache_1: gl.int64, # int +# stride_k_cache_2: gl.int64, # int +# stride_k_cache_3: gl.constexpr, # int +# stride_v_cache_0: gl.int64, # int +# stride_v_cache_1: gl.int64, # int +# stride_v_cache_2: gl.int64, # int +# stride_v_cache_3: gl.constexpr, # int +# query_start_len_ptr, # [num_seqs+1] +# BLOCK_Q: gl.constexpr, # int +# num_seqs: gl.int32, +# BLOCK_M: gl.constexpr, # int +# NUM_SEGMENTS_PER_SEQ: gl.constexpr, # int +# num_warps: gl.constexpr, # int +# num_stages: gl.constexpr, # int +# QK_WMMA_LAYOUT: gl.constexpr, +# PV_WMMA_LAYOUT: gl.constexpr, +# Q_DOT_LAYOUT: gl.constexpr, +# K_DOT_LAYOUT: gl.constexpr, +# P_DOT_LAYOUT: gl.constexpr, +# V_DOT_LAYOUT: gl.constexpr, +# Q_SHARED_LAYOUT: gl.constexpr, +# K_SHARED_LAYOUT: gl.constexpr, +# V_SHARED_LAYOUT: gl.constexpr, +# Q_LOAD_LAYOUT: gl.constexpr, +# K_LOAD_LAYOUT: gl.constexpr, +# V_LOAD_LAYOUT: gl.constexpr, +# num_ctas: gl.constexpr = 1, # int +# ALL_DECODE: gl.constexpr = False, # bool +# SHUFFLED_KV_CACHE: gl.constexpr = False, # bool +# ): +# q_block_global_idx = gl.program_id(0) +# kv_head_idx = gl.program_id(1) +# segm_idx = gl.program_id(2) +# # num_ctas: gl.constexpr = gl.num_ctas() +# pred = 1 +# pred_i32 = pred.to(gl.int32) if hasattr(pred, "to") else pred + +# gl.static_assert( +# TILE_SIZE % BLOCK_SIZE == 0, "TILE_SIZE must be multiple of BLOCK_SIZE" +# ) +# NUM_BLOCKS_GATHER_PER_TILE: gl.constexpr = TILE_SIZE // BLOCK_SIZE +# gl.static_assert( +# NUM_BLOCKS_GATHER_PER_TILE == 4 or NUM_BLOCKS_GATHER_PER_TILE == 8, +# "NUM_BLOCKS_GATHER_PER_TILE must be 4 or 8", +# ) + +# # needed to use exp2 (exp2 -> exp conversion) +# RCP_LN2 = 1.4426950408889634 +# qk_scale = scale * RCP_LN2 + +# seq_idx = _find_seq_idx( +# query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True +# ) + +# q_block_local_idx, cur_batch_query_len, cur_batch_in_all_start_index = ( +# _get_q_metadata( +# query_start_len_ptr, +# seq_idx, +# q_block_global_idx, +# BLOCK_Q, +# ) +# ) + +# if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: +# return + +# seq_len, tiles_per_segment = _get_seq_metadata( +# seq_lens_ptr, +# seq_idx, +# TILE_SIZE, +# NUM_SEGMENTS_PER_SEQ, +# ) + +# if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: +# return + +# # block table offset for this particular sequence +# block_table_offset = seq_idx * block_table_stride + +# # context length for this particular sequence +# context_len = seq_len - cur_batch_query_len + +# offs_q_m = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT)) +# offs_q_d = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(0, Q_LOAD_LAYOUT)) + +# query_pos = q_block_local_idx * BLOCK_Q + offs_q_m // num_queries_per_kv + +# query_offset_0 = cur_batch_in_all_start_index + query_pos +# query_offset_1 = kv_head_idx * num_queries_per_kv + offs_q_m % num_queries_per_kv +# query_offset = ( +# query_offset_0[:, None] * query_stride_0 +# + query_offset_1[:, None] * query_stride_1 +# + offs_q_d[None, :] 
+# ) + +# if HEAD_SIZE != HEAD_SIZE: +# dim_mask = offs_q_d < HEAD_SIZE +# else: +# dim_mask = gl.full((1,), 1, dtype=tl.int1) + +# query_mask_0 = query_pos < cur_batch_query_len +# query_mask_1 = query_offset_1 < num_query_heads + +# NUM_KV_HEADS: gl.constexpr = num_query_heads // num_queries_per_kv + +# k_desc, v_desc, smem_Q, smem_K, smem_V = ( +# _tdm_gather_create_tensor_descriptors_and_allocate_lds( +# query_ptr, +# key_cache_ptr, +# value_cache_ptr, +# NUM_BLOCKS, +# query_stride_1, +# 1, +# stride_k_cache_1, # stride_k_cache_1 = BLOCK_SIZE * HEAD_SIZE +# stride_k_cache_3, +# stride_v_cache_1, # stride_v_cache_1 = BLOCK_SIZE * HEAD_SIZE +# stride_v_cache_3, +# Q_SHARED_LAYOUT, +# K_SHARED_LAYOUT, +# V_SHARED_LAYOUT, +# NUM_KV_HEADS, +# BLOCK_M, +# HEAD_SIZE, +# BLOCK_SIZE, +# TILE_SIZE, +# HEAD_SIZE, +# NUM_BLOCKS_GATHER_PER_TILE, +# num_stages=num_stages, +# ) +# ) + +# # Q_load : shape = (BLOCK_M, HEAD_SIZE), layout = Q_LOAD_LAYOUT +# Q_load = gl.amd.cdna4.buffer_load( +# ptr=query_ptr, +# offsets=query_offset.to(gl.int32), +# mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], +# other=0.0, +# ) +# smem_Q.store(Q_load) +# # Q : shape = (BLOCK_M, HEAD_SIZE), layout = Q_DOT_LAYOUT +# Q = smem_Q.load(layout=Q_DOT_LAYOUT) + +# offs_q_m_qk = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, QK_WMMA_LAYOUT)) +# query_pos_qk = q_block_local_idx * BLOCK_Q + offs_q_m_qk // num_queries_per_kv +# query_offset_1_qk = ( +# kv_head_idx * num_queries_per_kv + offs_q_m_qk % num_queries_per_kv +# ) +# query_mask_0_qk = query_pos_qk < cur_batch_query_len +# query_mask_1_qk = query_offset_1_qk < num_query_heads +# query_mask_qk = query_mask_1_qk[:, None] & query_mask_0_qk[:, None] + +# L, M, acc = _allocate_L_M_acc( +# sink_ptr, +# segm_idx, +# query_offset_1_qk, +# query_mask_1_qk, +# RCP_LN2, +# BLOCK_M, +# HEAD_SIZE, +# QK_WMMA_LAYOUT, +# PV_WMMA_LAYOUT, +# USE_SINKS, +# ) + +# # alibi slope for this head +# alibi_slope = None +# if USE_ALIBI_SLOPES: +# alibi_slope = tl.load( +# alibi_slopes_ptr + query_offset_1_qk, mask=query_mask_1, other=0.0 +# ) + +# # query-query attention bias +# qq_bias_row_ptrs = None +# if USE_QQ_BIAS: +# qq_bias_row_ptrs = ( +# qq_bias_ptr + query_pos_qk[:, None] * qq_bias_stride_0 +# ) # shape: [BLOCK_M] + +# # compute the length of the longest sequence prefix spanned by any +# # query token in the current q_block (q_block_local_idx) +# max_seq_prefix_len = ( +# context_len +# + q_block_local_idx * BLOCK_Q +# + (BLOCK_M - 1) // num_queries_per_kv +# + 1 +# ) + +# # adjust for potential padding in the last q_block by considering the +# # actual sequence length +# max_seq_prefix_len = gl.minimum(max_seq_prefix_len, seq_len) + +# # calculate the number of tiles that need to be processed to +# # cover the longest sequence prefix (due to causal masking, tiles beyond +# # this prefix can be skipped) +# num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + +# # KV_cache_modifier: gl.constexpr = ".cg" if ALL_DECODE else "" + +# k_from_hbm = 0 +# k_from_lds = 0 +# v_from_hbm = 0 +# v_from_lds = 0 +# GATHER_LOAD_LAYOUT: gl.constexpr = gl.BlockedLayout( +# size_per_thread=[NUM_BLOCKS_GATHER_PER_TILE], +# threads_per_warp=[32], +# warps_per_cta=[num_warps], +# order=[0], +# ) +# offs_j = gl.arange(0, NUM_BLOCKS_GATHER_PER_TILE, layout=GATHER_LOAD_LAYOUT) +# j_from_hbm = segm_idx * tiles_per_segment +# j_from_lds = segm_idx * tiles_per_segment +# seq_offset = j_from_lds * TILE_SIZE + gl.arange( +# 0, TILE_SIZE, layout=gl.SliceLayout(0, QK_WMMA_LAYOUT) +# ) + +# for 
_ in range(num_stages - 1): +# j_from_hbm, offs_k_gather_idx, offs_v_gather_idx = _tdm_gather_get_kv_offsets( +# j_from_hbm, +# offs_j, +# kv_head_idx, +# block_tables_ptr, +# block_table_offset, +# stride_k_cache_0 // stride_k_cache_1, # = NUM_KV_HEADS +# stride_v_cache_0 // stride_v_cache_1, # = NUM_KV_HEADS +# NUM_BLOCKS_GATHER_PER_TILE, +# ) +# k_from_hbm = _tdm_async_gather_load_to_lds( +# k_from_hbm, +# desc=k_desc, +# src_row_indices=offs_k_gather_idx, +# src_col_offset=0, +# dst=smem_K, +# num_stages=num_stages, +# ) +# v_from_hbm = _tdm_async_gather_load_to_lds( +# v_from_hbm, +# desc=v_desc, +# src_row_indices=offs_v_gather_idx, +# src_col_offset=0, +# dst=smem_V, +# num_stages=num_stages, +# ) + +# # iterate through tiles within current segment +# # for _ in range(tiles_per_segment - (num_stages - 1)): +# for _ in range( +# segm_idx * tiles_per_segment, +# min((segm_idx + 1) * tiles_per_segment, num_tiles) - (num_stages - 1), +# ): +# j_from_hbm, offs_k_gather_idx, offs_v_gather_idx = _tdm_gather_get_kv_offsets( +# j_from_hbm, +# offs_j, +# kv_head_idx, +# block_tables_ptr, +# block_table_offset, +# stride_k_cache_0 // stride_k_cache_1, # = NUM_KV_HEADS +# stride_v_cache_0 // stride_v_cache_1, # = NUM_KV_HEADS +# NUM_BLOCKS_GATHER_PER_TILE, +# ) +# k_from_hbm = _tdm_async_gather_load_to_lds( +# k_from_hbm, +# desc=k_desc, +# src_row_indices=offs_k_gather_idx, +# src_col_offset=0, +# dst=smem_K, +# num_stages=num_stages, +# ) +# v_from_hbm = _tdm_async_gather_load_to_lds( +# v_from_hbm, +# desc=v_desc, +# src_row_indices=offs_v_gather_idx, +# src_col_offset=0, +# dst=smem_V, +# num_stages=num_stages, +# ) + +# # K : shape = (HEAD_SIZE, TILE_SIZE), layout = K_DOT_LAYOUT +# k_from_lds, K = _tdm_gather_request_from_lds( +# k_from_lds, +# k_scale, +# Q.dtype, +# smem_K, +# asycn_wait=(num_stages - 1) * 2 + 1, +# layout=K_DOT_LAYOUT, +# transpose=True, +# num_ctas=num_ctas, +# num_stages=num_stages, +# TILE_SIZE=TILE_SIZE, +# HEAD_SIZE=HEAD_SIZE, +# ) + +# # P : shape = (BLOCK_M, TILE_SIZE), layout = Q_LOAD_LAYOUT +# # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) +# # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) +# # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT +# P, L, M, acc = _perform_QK_wmma_and_update_L_M( +# Q, +# K, +# L, +# M, +# acc, +# qq_bias_row_ptrs, +# seq_offset, +# query_mask_qk, +# query_pos_qk, +# context_len, +# alibi_slope, +# qq_bias_stride_0, +# qk_scale, +# softcap, +# RCP_LN2, +# BLOCK_M, +# TILE_SIZE, +# USE_SOFTCAP, +# SLIDING_WINDOW, +# USE_ALIBI_SLOPES, +# USE_QQ_BIAS, +# Q_LOAD_LAYOUT, +# QK_WMMA_LAYOUT, +# PV_WMMA_LAYOUT, +# ) + +# # V : shape = (TILE_SIZE, HEAD_SIZE), layout = V_DOT_LAYOUT +# v_from_lds, V = _tdm_gather_request_from_lds( +# v_from_lds, +# v_scale, +# Q.dtype, +# smem_V, +# asycn_wait=(num_stages - 1) * 2, +# layout=V_DOT_LAYOUT, +# transpose=False, +# num_ctas=num_ctas, +# num_stages=num_stages, +# TILE_SIZE=TILE_SIZE, +# HEAD_SIZE=HEAD_SIZE, +# ) + +# # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT +# acc = _perform_PV_wmma(P, V, acc, P_DOT_LAYOUT) + +# j_from_lds = j_from_lds + 1 +# seq_offset += TILE_SIZE + +# for _ in range(num_stages - 1): +# # K : shape = (HEAD_SIZE, TILE_SIZE), layout = K_DOT_LAYOUT +# k_from_lds, K = _tdm_gather_request_from_lds( +# k_from_lds, +# k_scale, +# Q.dtype, +# smem_K, +# asycn_wait=(num_stages - 2) * 2 + 1, +# layout=K_DOT_LAYOUT, +# transpose=True, +# num_ctas=num_ctas, +# num_stages=num_stages, +# TILE_SIZE=TILE_SIZE, +# 
HEAD_SIZE=HEAD_SIZE, +# ) + +# # P : shape = (BLOCK_M, TILE_SIZE), layout = Q_LOAD_LAYOUT +# # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) +# # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) +# # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT +# P, L, M, acc = _perform_QK_wmma_and_update_L_M( +# Q, +# K, +# L, +# M, +# acc, +# qq_bias_row_ptrs, +# seq_offset, +# query_mask_qk, +# query_pos_qk, +# context_len, +# alibi_slope, +# qq_bias_stride_0, +# qk_scale, +# softcap, +# RCP_LN2, +# BLOCK_M, +# TILE_SIZE, +# USE_SOFTCAP, +# SLIDING_WINDOW, +# USE_ALIBI_SLOPES, +# USE_QQ_BIAS, +# Q_LOAD_LAYOUT, +# QK_WMMA_LAYOUT, +# PV_WMMA_LAYOUT, +# ) + +# # V : shape = (TILE_SIZE, HEAD_SIZE), layout = V_DOT_LAYOUT +# v_from_lds, V = _tdm_gather_request_from_lds( +# v_from_lds, +# v_scale, +# Q.dtype, +# smem_V, +# asycn_wait=(num_stages - 2) * 2, +# layout=V_DOT_LAYOUT, +# transpose=False, +# num_ctas=num_ctas, +# num_stages=num_stages, +# TILE_SIZE=TILE_SIZE, +# HEAD_SIZE=HEAD_SIZE, +# ) + +# # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT +# acc = _perform_PV_wmma(P, V, acc, P_DOT_LAYOUT) + +# j_from_lds = j_from_lds + 1 +# seq_offset += TILE_SIZE + +# # store segm_output +# # acc : shape = (BLOCK_M, HEAD_SIZE), layout = Q_LOAD_LAYOUT +# acc = gl.convert_layout(acc, layout=Q_LOAD_LAYOUT) +# segm_output_offset = ( +# query_offset_0[:, None] +# * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE) +# + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE) +# + segm_idx * HEAD_SIZE +# + offs_q_d[None, :] +# ) +# gl.amd.cdna4.buffer_store( +# stored_value=acc, +# ptr=segm_output_ptr, +# offsets=segm_output_offset, +# mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], +# ) + +# # store segm_max and segm_expsum +# # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, QK_WMMA_LAYOUT) +# # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, QK_WMMA_LAYOUT) +# segm_offset = ( +# query_offset_0 * (num_query_heads * NUM_SEGMENTS_PER_SEQ) +# + query_offset_1 * NUM_SEGMENTS_PER_SEQ +# + segm_idx +# ) +# L = gl.convert_layout(L, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT)) +# M = gl.convert_layout(M, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT)) +# gl.amd.cdna4.buffer_store( +# stored_value=M, +# ptr=segm_max_ptr, +# offsets=segm_offset, +# mask=query_mask_0 & query_mask_1, +# ) +# gl.amd.cdna4.buffer_store( +# stored_value=L, +# ptr=segm_expsum_ptr, +# offsets=segm_offset, +# mask=query_mask_0 & query_mask_1, +# ) + + +# @gluon.jit +# def _tdm_get_kv_offsets( +# j, +# kv_head_idx, +# block_tables_ptr, +# block_table_offset, +# stride_k_cache_t: gl.int64, +# stride_k_cache_d: gl.int64, +# stride_v_cache_t: gl.int64, +# stride_v_cache_d: gl.int64, +# ): +# physical_block_idx = gl.load(block_tables_ptr + block_table_offset + j) + +# offs_k_t = (physical_block_idx * stride_k_cache_t).to(tl.int32) +# offs_k_d = (kv_head_idx * stride_k_cache_d).to(tl.int32) + +# offs_v_t = (physical_block_idx * stride_v_cache_t).to(tl.int32) +# offs_v_d = (kv_head_idx * stride_v_cache_d).to(tl.int32) + +# return j + 1, offs_k_t, offs_k_d, offs_v_t, offs_v_d + + +# @gluon.jit +# def _tdm_async_load_to_lds( +# j, +# src, +# offsets, +# dest, +# pred_i32, +# num_stages: gl.constexpr, +# ): +# # gl.amd.gfx1250.tdm.prefetch( +# # src=k_desc, +# # offsets=[ +# # 0, +# # offs_kv_t_starts, +# # ], +# # pred=pred.to(gl.int1) +# # ) +# gl.amd.gfx1250.tdm.async_load( +# src=src, +# offsets=offsets, +# dest=dest.index(j % num_stages), +# pred=pred_i32, +# ) + 
+# return j + 1 + + +# @gluon.jit +# def _tdm_request_from_lds( +# j, +# kv_scale, +# Q_dtype, +# smem, +# asycn_wait: gl.constexpr, +# layout: gl.constexpr, +# transpose: gl.constexpr, +# num_ctas: gl.constexpr, +# num_stages: gl.constexpr, +# ): +# if num_ctas > 1: +# gl.amd.gfx1250.cluster.arrive() +# gl.amd.gfx1250.tdm.async_wait(asycn_wait) +# if num_ctas > 1: +# gl.amd.gfx1250.cluster.wait() +# if transpose: +# X = smem.index(j % num_stages).permute([1, 0]).load(layout=layout) +# else: +# X = smem.index(j % num_stages).load(layout=layout) + +# if X.dtype.is_fp8() and not Q_dtype.is_fp8(): +# X = (X.to(gl.float32) * gl.load(kv_scale)).to(Q_dtype) + +# return j + 1, X + + +# @gluon.jit +# def _tdm_create_tensor_descriptors_and_allocate_lds( +# q_ptr, +# k_ptr, +# v_ptr, +# NUM_BLOCKS, +# stride_q_m: gl.int64, # int +# stride_q_d: gl.constexpr, # int +# stride_k_t: gl.int64, # int +# stride_k_d: gl.constexpr, # int +# stride_v_t: gl.int64, # int +# stride_v_d: gl.constexpr, # int +# q_shared_layout: gl.constexpr, +# k_shared_layout: gl.constexpr, +# v_shared_layout: gl.constexpr, +# NUM_KV_HEADS: gl.constexpr, +# BLOCK_M: gl.constexpr, +# HEAD_SIZE: gl.constexpr, +# TILE_SIZE: gl.constexpr, +# HEAD_SIZE: gl.constexpr, +# num_stages: gl.constexpr, +# ): +# gl.static_assert(stride_q_d == 1, "stride_q_d must be 1") +# gl.static_assert(stride_k_d == 1, "stride_k_d must be 1") +# gl.static_assert(stride_v_d == 1, "stride_v_d must be 1") +# # q_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( +# # base=q_ptr, +# # shape=(M, HEAD_SIZE), +# # strides=(stride_q_m, stride_q_d), +# # block_shape=(BLOCK_M, HEAD_SIZE), +# # layout=q_shared_layout, +# # ) + +# k_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( +# base=k_ptr, +# shape=(NUM_BLOCKS * TILE_SIZE, NUM_KV_HEADS * HEAD_SIZE), +# strides=(stride_k_t, stride_k_d), +# block_shape=(TILE_SIZE, HEAD_SIZE), +# layout=k_shared_layout, +# ) + +# v_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( +# base=v_ptr, +# shape=(NUM_BLOCKS * TILE_SIZE, NUM_KV_HEADS * HEAD_SIZE), +# strides=(stride_v_t, stride_v_d), +# block_shape=(TILE_SIZE, HEAD_SIZE), +# layout=v_shared_layout, +# ) + +# # smem_Q = gl.allocate_shared_memory( +# # q_desc.dtype, +# # shape=q_desc.block_shape, +# # layout=q_shared_layout, +# # ) +# smem_Q = gl.allocate_shared_memory( +# q_ptr.type.element_ty, +# shape=[BLOCK_M, HEAD_SIZE], +# layout=q_shared_layout, +# ) +# smem_K = gl.allocate_shared_memory( +# k_desc.dtype, +# shape=[num_stages] + k_desc.block_shape, +# layout=k_desc.layout, +# ) +# smem_V = gl.allocate_shared_memory( +# v_desc.dtype, +# shape=[num_stages] + v_desc.block_shape, +# layout=v_desc.layout, +# ) + +# return k_desc, v_desc, smem_Q, smem_K, smem_V + + +# gluon_kernel_unified_attention_3d_tdm_repr = make_kernel_repr( +# "gluon_kernel_unified_attention_3d_tdm", +# [ +# "num_query_heads", +# "num_queries_per_kv", +# "BLOCK_SIZE", +# "TILE_SIZE", +# "HEAD_SIZE", +# "num_warps", +# "num_stages", +# "cache_modifier", +# ], +# ) + + +# @gluon.jit(repr=gluon_kernel_unified_attention_3d_tdm_repr) +# def gluon_kernel_unified_attention_3d_tdm( +# segm_output_ptr, +# # [num_tokens, num_query_heads, num_segments, head_size] +# segm_max_ptr, # [num_tokens, num_query_heads, num_segments] +# segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] +# query_ptr, # [num_tokens, num_query_heads, head_size] +# key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] +# value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] +# sink_ptr, # [num_query_heads] +# 
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] +# seq_lens_ptr, # [num_seqs] +# alibi_slopes_ptr, # [num_query_heads] +# qq_bias_ptr, # [num_query_tokens, num_query_tokens] +# scale, # float32 +# k_scale, # float32 +# v_scale, # float32 +# softcap, # float32 +# num_tokens, # int +# NUM_BLOCKS, # int +# num_query_heads: gl.constexpr, # int +# num_queries_per_kv: gl.constexpr, # int +# block_table_stride: gl.int64, # int +# query_stride_0: gl.int64, # int +# query_stride_1: gl.int64, # int, should be equal to head_size +# qq_bias_stride_0: gl.int64, # int +# BLOCK_SIZE: gl.constexpr, # int +# TILE_SIZE: gl.constexpr, # int, must be power of 2 +# HEAD_SIZE: gl.constexpr, # int +# HEAD_SIZE: gl.constexpr, # int, must be power of 2 +# USE_ALIBI_SLOPES: gl.constexpr, # bool +# USE_QQ_BIAS: gl.constexpr, # bool +# USE_SOFTCAP: gl.constexpr, # bool +# USE_SINKS: gl.constexpr, # bool +# SLIDING_WINDOW: gl.constexpr, # int +# stride_k_cache_0: gl.int64, # int +# stride_k_cache_1: gl.int64, # int +# stride_k_cache_2: gl.int64, # int +# stride_k_cache_3: gl.constexpr, # int +# stride_v_cache_0: gl.int64, # int +# stride_v_cache_1: gl.int64, # int +# stride_v_cache_2: gl.int64, # int +# stride_v_cache_3: gl.constexpr, # int +# query_start_len_ptr, # [num_seqs+1] +# BLOCK_Q: gl.constexpr, # int +# num_seqs: gl.int32, +# BLOCK_M: gl.constexpr, # int +# NUM_SEGMENTS_PER_SEQ: gl.constexpr, # int +# num_warps: gl.constexpr, # int +# num_stages: gl.constexpr, # int +# QK_WMMA_LAYOUT: gl.constexpr, +# PV_WMMA_LAYOUT: gl.constexpr, +# Q_DOT_LAYOUT: gl.constexpr, +# K_DOT_LAYOUT: gl.constexpr, +# P_DOT_LAYOUT: gl.constexpr, +# V_DOT_LAYOUT: gl.constexpr, +# Q_SHARED_LAYOUT: gl.constexpr, +# K_SHARED_LAYOUT: gl.constexpr, +# V_SHARED_LAYOUT: gl.constexpr, +# Q_LOAD_LAYOUT: gl.constexpr, +# K_LOAD_LAYOUT: gl.constexpr, +# V_LOAD_LAYOUT: gl.constexpr, +# num_ctas: gl.constexpr = 1, # int +# ALL_DECODE: gl.constexpr = False, # bool +# SHUFFLED_KV_CACHE: gl.constexpr = False, # bool +# ): +# q_block_global_idx = gl.program_id(0) +# kv_head_idx = gl.program_id(1) +# segm_idx = gl.program_id(2) +# # num_ctas: gl.constexpr = gl.num_ctas() +# pred = 1 +# pred_i32 = pred.to(gl.int32) if hasattr(pred, "to") else pred + +# gl.static_assert( +# TILE_SIZE == BLOCK_SIZE, "TILE_SIZE must be the same as BLOCK_SIZE" +# ) + +# # needed to use exp2 (exp2 -> exp conversion) +# RCP_LN2 = 1.4426950408889634 +# qk_scale = scale * RCP_LN2 + +# seq_idx = _find_seq_idx( +# query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True +# ) + +# q_block_local_idx, cur_batch_query_len, cur_batch_in_all_start_index = ( +# _get_q_metadata( +# query_start_len_ptr, +# seq_idx, +# q_block_global_idx, +# BLOCK_Q, +# ) +# ) + +# if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: +# return + +# seq_len, tiles_per_segment = _get_seq_metadata( +# seq_lens_ptr, +# seq_idx, +# TILE_SIZE, +# NUM_SEGMENTS_PER_SEQ, +# ) + +# if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: +# return + +# # block table offset for this particular sequence +# block_table_offset = seq_idx * block_table_stride + +# # context length for this particular sequence +# context_len = seq_len - cur_batch_query_len + +# offs_q_m = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT)) +# offs_q_d = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(0, Q_LOAD_LAYOUT)) + +# query_pos = q_block_local_idx * BLOCK_Q + offs_q_m // num_queries_per_kv + +# query_offset_0 = cur_batch_in_all_start_index + query_pos +# query_offset_1 = kv_head_idx * num_queries_per_kv + 
offs_q_m % num_queries_per_kv +# query_offset = ( +# query_offset_0[:, None] * query_stride_0 +# + query_offset_1[:, None] * query_stride_1 +# + offs_q_d[None, :] +# ) + +# if HEAD_SIZE != HEAD_SIZE: +# dim_mask = offs_q_d < HEAD_SIZE +# else: +# dim_mask = gl.full((1,), 1, dtype=tl.int1) + +# query_mask_0 = query_pos < cur_batch_query_len +# query_mask_1 = query_offset_1 < num_query_heads + +# NUM_KV_HEADS: gl.constexpr = num_query_heads // num_queries_per_kv + +# k_desc, v_desc, smem_Q, smem_K, smem_V = ( +# _tdm_create_tensor_descriptors_and_allocate_lds( +# query_ptr, +# key_cache_ptr, +# value_cache_ptr, +# NUM_BLOCKS, +# query_stride_1, +# 1, +# stride_k_cache_1, # stride_k_cache_1 = HEAD_SIZE * num_kv_heads +# stride_k_cache_3, +# stride_v_cache_1, +# stride_v_cache_3, +# Q_SHARED_LAYOUT, +# K_SHARED_LAYOUT, +# V_SHARED_LAYOUT, +# NUM_KV_HEADS, +# BLOCK_M, +# HEAD_SIZE, +# TILE_SIZE, +# HEAD_SIZE, +# num_stages=num_stages, +# ) +# ) + +# # Q_load : shape = (BLOCK_M, HEAD_SIZE), layout = Q_LOAD_LAYOUT +# Q_load = gl.amd.cdna4.buffer_load( +# ptr=query_ptr, +# offsets=query_offset.to(gl.int32), +# mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], +# other=0.0, +# ) +# smem_Q.store(Q_load) +# # Q : shape = (BLOCK_M, HEAD_SIZE), layout = Q_DOT_LAYOUT +# Q = smem_Q.load(layout=Q_DOT_LAYOUT) + +# offs_q_m_qk = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, QK_WMMA_LAYOUT)) +# query_pos_qk = q_block_local_idx * BLOCK_Q + offs_q_m_qk // num_queries_per_kv +# query_offset_1_qk = ( +# kv_head_idx * num_queries_per_kv + offs_q_m_qk % num_queries_per_kv +# ) +# query_mask_0_qk = query_pos_qk < cur_batch_query_len +# query_mask_1_qk = query_offset_1_qk < num_query_heads +# query_mask_qk = query_mask_1_qk[:, None] & query_mask_0_qk[:, None] + +# L, M, acc = _allocate_L_M_acc( +# sink_ptr, +# segm_idx, +# query_offset_1_qk, +# query_mask_1_qk, +# RCP_LN2, +# BLOCK_M, +# HEAD_SIZE, +# QK_WMMA_LAYOUT, +# PV_WMMA_LAYOUT, +# USE_SINKS, +# ) + +# # alibi slope for this head +# alibi_slope = None +# if USE_ALIBI_SLOPES: +# alibi_slope = tl.load( +# alibi_slopes_ptr + query_offset_1_qk, mask=query_mask_1, other=0.0 +# ) + +# # query-query attention bias +# qq_bias_row_ptrs = None +# if USE_QQ_BIAS: +# qq_bias_row_ptrs = ( +# qq_bias_ptr + query_pos_qk[:, None] * qq_bias_stride_0 +# ) # shape: [BLOCK_M] + +# # compute the length of the longest sequence prefix spanned by any +# # query token in the current q_block (q_block_local_idx) +# max_seq_prefix_len = ( +# context_len +# + q_block_local_idx * BLOCK_Q +# + (BLOCK_M - 1) // num_queries_per_kv +# + 1 +# ) + +# # adjust for potential padding in the last q_block by considering the +# # actual sequence length +# max_seq_prefix_len = gl.minimum(max_seq_prefix_len, seq_len) + +# # calculate the number of tiles that need to be processed to +# # cover the longest sequence prefix (due to causal masking, tiles beyond +# # this prefix can be skipped) +# num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + +# # KV_cache_modifier: gl.constexpr = ".cg" if ALL_DECODE else "" + +# k_from_hbm = 0 +# k_from_lds = 0 +# v_from_hbm = 0 +# v_from_lds = 0 +# j_from_hbm = segm_idx * tiles_per_segment +# j_from_lds = segm_idx * tiles_per_segment +# seq_offset = j_from_lds * TILE_SIZE + gl.arange( +# 0, TILE_SIZE, layout=gl.SliceLayout(0, QK_WMMA_LAYOUT) +# ) + +# for _ in range(num_stages - 1): +# j_from_hbm, offs_k_t, offs_k_d, offs_v_t, offs_v_d = _tdm_get_kv_offsets( +# j_from_hbm, +# kv_head_idx, +# block_tables_ptr, +# block_table_offset, +# 
stride_k_cache_0 // stride_k_cache_1, # = BLOCK_SIZE +# stride_k_cache_2, +# stride_v_cache_0 // stride_v_cache_1, # = BLOCK_SIZE +# stride_v_cache_2, +# ) +# k_from_hbm = _tdm_async_load_to_lds( +# k_from_hbm, +# src=k_desc, +# offsets=[offs_k_t, offs_k_d], +# dest=smem_K, +# pred_i32=pred_i32, +# num_stages=num_stages, +# ) +# v_from_hbm = _tdm_async_load_to_lds( +# v_from_hbm, +# src=v_desc, +# offsets=[offs_v_t, offs_v_d], +# dest=smem_V, +# pred_i32=pred_i32, +# num_stages=num_stages, +# ) + +# # iterate through tiles within current segment +# # for _ in range(tiles_per_segment - (num_stages - 1)): +# for _ in range( +# segm_idx * tiles_per_segment, +# min((segm_idx + 1) * tiles_per_segment, num_tiles) - (num_stages - 1), +# ): +# j_from_hbm, offs_k_t, offs_k_d, offs_v_t, offs_v_d = _tdm_get_kv_offsets( +# j_from_hbm, +# kv_head_idx, +# block_tables_ptr, +# block_table_offset, +# stride_k_cache_0 // stride_k_cache_1, # = BLOCK_SIZE +# stride_k_cache_2, +# stride_v_cache_0 // stride_v_cache_1, # = BLOCK_SIZE +# stride_v_cache_2, +# ) +# k_from_hbm = _tdm_async_load_to_lds( +# k_from_hbm, +# src=k_desc, +# offsets=[offs_k_t, offs_k_d], +# dest=smem_K, +# pred_i32=pred_i32, +# num_stages=num_stages, +# ) +# v_from_hbm = _tdm_async_load_to_lds( +# v_from_hbm, +# src=v_desc, +# offsets=[offs_v_t, offs_v_d], +# dest=smem_V, +# pred_i32=pred_i32, +# num_stages=num_stages, +# ) + +# # K : shape = (HEAD_SIZE, TILE_SIZE), layout = K_DOT_LAYOUT +# k_from_lds, K = _tdm_request_from_lds( +# k_from_lds, +# k_scale, +# Q.dtype, +# smem_K, +# asycn_wait=(num_stages - 1) * 2 + 1, +# layout=K_DOT_LAYOUT, +# transpose=True, +# num_ctas=num_ctas, +# num_stages=num_stages, +# ) + +# # P : shape = (BLOCK_M, TILE_SIZE), layout = Q_LOAD_LAYOUT +# # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) +# # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) +# # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT +# P, L, M, acc = _perform_QK_wmma_and_update_L_M( +# Q, +# K, +# L, +# M, +# acc, +# qq_bias_row_ptrs, +# seq_offset, +# query_mask_qk, +# query_pos_qk, +# context_len, +# alibi_slope, +# qq_bias_stride_0, +# qk_scale, +# softcap, +# RCP_LN2, +# BLOCK_M, +# TILE_SIZE, +# USE_SOFTCAP, +# SLIDING_WINDOW, +# USE_ALIBI_SLOPES, +# USE_QQ_BIAS, +# Q_LOAD_LAYOUT, +# QK_WMMA_LAYOUT, +# PV_WMMA_LAYOUT, +# ) + +# # V : shape = (TILE_SIZE, HEAD_SIZE), layout = V_DOT_LAYOUT +# v_from_lds, V = _tdm_request_from_lds( +# v_from_lds, +# v_scale, +# Q.dtype, +# smem_V, +# asycn_wait=(num_stages - 1) * 2, +# layout=V_DOT_LAYOUT, +# transpose=False, +# num_ctas=num_ctas, +# num_stages=num_stages, +# ) + +# # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT +# acc = _perform_PV_wmma(P, V, acc, P_DOT_LAYOUT) + +# j_from_lds = j_from_lds + 1 +# seq_offset += TILE_SIZE + +# for _ in range(num_stages - 1): +# # K : shape = (HEAD_SIZE, TILE_SIZE), layout = K_DOT_LAYOUT +# k_from_lds, K = _tdm_request_from_lds( +# k_from_lds, +# k_scale, +# Q.dtype, +# smem_K, +# asycn_wait=(num_stages - 2) * 2 + 1, +# layout=K_DOT_LAYOUT, +# transpose=True, +# num_ctas=num_ctas, +# num_stages=num_stages, +# ) + +# # P : shape = (BLOCK_M, TILE_SIZE), layout = Q_LOAD_LAYOUT +# # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) +# # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) +# # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT +# P, L, M, acc = _perform_QK_wmma_and_update_L_M( +# Q, +# K, +# L, +# M, +# acc, +# qq_bias_row_ptrs, +# seq_offset, 
+# query_mask_qk, +# query_pos_qk, +# context_len, +# alibi_slope, +# qq_bias_stride_0, +# qk_scale, +# softcap, +# RCP_LN2, +# BLOCK_M, +# TILE_SIZE, +# USE_SOFTCAP, +# SLIDING_WINDOW, +# USE_ALIBI_SLOPES, +# USE_QQ_BIAS, +# Q_LOAD_LAYOUT, +# QK_WMMA_LAYOUT, +# PV_WMMA_LAYOUT, +# ) + +# # V : shape = (TILE_SIZE, HEAD_SIZE), layout = V_DOT_LAYOUT +# v_from_lds, V = _tdm_request_from_lds( +# v_from_lds, +# v_scale, +# Q.dtype, +# smem_V, +# asycn_wait=(num_stages - 2) * 2, +# layout=V_DOT_LAYOUT, +# transpose=False, +# num_ctas=num_ctas, +# num_stages=num_stages, +# ) + +# # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT +# acc = _perform_PV_wmma(P, V, acc, P_DOT_LAYOUT) + +# j_from_lds = j_from_lds + 1 +# seq_offset += TILE_SIZE + +# # store segm_output +# # acc : shape = (BLOCK_M, HEAD_SIZE), layout = Q_LOAD_LAYOUT +# acc = gl.convert_layout(acc, layout=Q_LOAD_LAYOUT) +# segm_output_offset = ( +# query_offset_0[:, None] +# * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE) +# + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE) +# + segm_idx * HEAD_SIZE +# + offs_q_d[None, :] +# ) +# gl.amd.cdna4.buffer_store( +# stored_value=acc, +# ptr=segm_output_ptr, +# offsets=segm_output_offset, +# mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], +# ) + +# # store segm_max and segm_expsum +# # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, QK_WMMA_LAYOUT) +# # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, QK_WMMA_LAYOUT) +# segm_offset = ( +# query_offset_0 * (num_query_heads * NUM_SEGMENTS_PER_SEQ) +# + query_offset_1 * NUM_SEGMENTS_PER_SEQ +# + segm_idx +# ) +# L = gl.convert_layout(L, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT)) +# M = gl.convert_layout(M, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT)) +# gl.amd.cdna4.buffer_store( +# stored_value=M, +# ptr=segm_max_ptr, +# offsets=segm_offset, +# mask=query_mask_0 & query_mask_1, +# ) +# gl.amd.cdna4.buffer_store( +# stored_value=L, +# ptr=segm_expsum_ptr, +# offsets=segm_offset, +# mask=query_mask_0 & query_mask_1, +# ) + + +@gluon.jit +def _async_load_to_lds( + from_hbm, + dest, + ptr, + offsets, + mask, + num_stages: gl.constexpr, + cache_modifier: gl.constexpr, + use_buffer_load: gl.constexpr = False, +): + if use_buffer_load: + gl.amd.cdna4.async_copy.buffer_load_to_shared( + dest=dest.index(from_hbm % num_stages), + ptr=ptr, + offsets=offsets.to(gl.int32), + mask=mask, + cache_modifier=cache_modifier, + ) + else: + gl.amd.cdna4.async_copy.global_load_to_shared( + dest=dest.index(from_hbm % num_stages), + ptr=ptr + offsets, + mask=mask, + cache_modifier=cache_modifier, + ) + gl.amd.cdna4.async_copy.commit_group() + return from_hbm + 1 + + +@gluon.jit +def _request_from_lds( + from_lds, + kv_scale, + Q_dtype, + smem, + wait_group, + LOAD_LAYOUT: gl.constexpr, + DOT_LAYOUT: gl.constexpr, + SHUFFLED_KV_CACHE: gl.constexpr, + num_stages: gl.constexpr, +): + gl.amd.cdna4.async_copy.wait_group(wait_group) + if SHUFFLED_KV_CACHE: + KV = gl.amd.cdna4.async_copy.load_shared_relaxed( + smem.index(from_lds % num_stages), layout=LOAD_LAYOUT + ) + else: + KV = gl.amd.cdna4.async_copy.load_shared_relaxed( + smem.index(from_lds % num_stages), layout=DOT_LAYOUT + ) + if KV.dtype.is_fp8() and not Q_dtype.is_fp8(): + KV = (KV.to(gl.float32) * gl.load(kv_scale)).to(Q_dtype) + return KV, from_lds + 1 + + +@gluon.jit +def _get_kv_offsets( + j, + kv_head_idx, + block_tables_ptr, + block_table_offset, + offs_k_t, + offs_k_d, + offs_v_t, + offs_v_d, + max_seq_prefix_len, + stride_k_cache_0: gl.int64, + 
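+    # Illustrative note (not part of the original kernel): this helper resolves
+    # paged-KV-cache addresses through the block table.  For a key/value at
+    # sequence position t (non-shuffled path):
+    #     physical_block = block_tables[seq][t // BLOCK_SIZE]
+    #     in_block_row   = t % BLOCK_SIZE
+    # and the element offset is assembled from the cache strides passed below.
+    # Hypothetical example: BLOCK_SIZE = 16 and t = 37 -> block-table entry 2,
+    # in-block row 5.  The shuffled path instead loads a single block-table
+    # entry per tile index j and addresses the whole physical block at once.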
stride_k_cache_1: gl.int64, + stride_k_cache_2: gl.int64, + stride_k_cache_3: gl.constexpr, + stride_v_cache_0: gl.int64, + stride_v_cache_1: gl.int64, + stride_v_cache_2: gl.int64, + stride_v_cache_3: gl.constexpr, + K_LOAD_LAYOUT: gl.constexpr, + V_LOAD_LAYOUT: gl.constexpr, + TILE_SIZE: gl.constexpr, + BLOCK_SIZE: gl.constexpr, + SHUFFLED_KV_CACHE: gl.constexpr, +): + # seq_k_offset : shape = (TILE_SIZE, ), layout = gl.SliceLayout(0, K_LOAD_LAYOUT) + # seq_v_offset : shape = (TILE_SIZE, ), layout = gl.SliceLayout(1, V_LOAD_LAYOUT) + + if SHUFFLED_KV_CACHE: + physical_block_idx = gl.load(block_tables_ptr + block_table_offset + j).to( + tl.int64 + ) + + k_offset = ( + physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + offs_k_t[:, None] * stride_k_cache_2 + + offs_k_d[None, :] * stride_k_cache_3 + ) + v_offset = ( + physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_v_t[None, :] * stride_v_cache_3 + + offs_v_d[:, None] * stride_v_cache_2 + ) + return j + 1, k_offset, v_offset, None, None + else: + seq_k_offset = j * TILE_SIZE + offs_k_t + seq_v_offset = j * TILE_SIZE + offs_v_t + + if TILE_SIZE == BLOCK_SIZE: + tile_k_mask = gl.full( + (1,), 1, dtype=tl.int1, layout=gl.SliceLayout(0, K_LOAD_LAYOUT) + ) + tile_v_mask = gl.full( + (1,), 1, dtype=tl.int1, layout=gl.SliceLayout(1, V_LOAD_LAYOUT) + ) + else: + tile_k_mask = seq_k_offset < max_seq_prefix_len + tile_v_mask = seq_v_offset < max_seq_prefix_len + + physical_block_idx_k = gl.amd.cdna4.buffer_load( + ptr=block_tables_ptr, + offsets=(block_table_offset + seq_k_offset // BLOCK_SIZE).to(gl.int32), + ).to(tl.int64) + + physical_block_idx_v = gl.amd.cdna4.buffer_load( + ptr=block_tables_ptr, + offsets=(block_table_offset + seq_v_offset // BLOCK_SIZE).to(gl.int32), + ).to(tl.int64) + + k_offset = ( + physical_block_idx_k[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_k_d[:, None] * stride_k_cache_3 + + (seq_k_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + v_offset = ( + physical_block_idx_v[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_v_d[None, :] * stride_v_cache_3 + + (seq_v_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) + return j + 1, k_offset, v_offset, tile_k_mask, tile_v_mask + + +gluon_kernel_unified_attention_3d_async_repr = make_kernel_repr( + "gluon_kernel_unified_attention_3d_async", + [ + "num_query_heads", + "num_queries_per_kv", + "BLOCK_SIZE", + "TILE_SIZE", + "HEAD_SIZE", + "num_warps", + "num_stages", + "use_buffer_load", + "ALL_DECODE", + "SHUFFLED_KV_CACHE", + ], +) + + +@gluon.jit(repr=gluon_kernel_unified_attention_3d_async_repr) +def gluon_kernel_unified_attention_3d_async( + segm_output_ptr, # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, blk_size, head_size] + value_cache_ptr, # [num_blks, num_kv_heads, blk_size, head_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_seqs: gl.int32, # int + num_blocks: gl.int32, # int + query_stride_0: gl.int32, # int + query_stride_1: gl.int32, # int, should be equal to head_size + 
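+    # Illustrative note (not part of the original kernel): the launch grid is
+    #     (num q-blocks, num KV heads, NUM_SEGMENTS_PER_SEQ)
+    # i.e. program_id(0) selects a q-block, program_id(1) a KV head and
+    # program_id(2) a segment of the key/value sequence.  Each segment writes an
+    # un-normalized partial result plus its running max / exp-sum
+    # (segm_output / segm_max / segm_expsum); the separate reduction kernel
+    # (gluon_reduce_segments, launched by the host code) is then expected to
+    # combine them flash-decoding style, roughly
+    #     m   = max_s M_s
+    #     out = sum_s exp2(M_s - m) * acc_s / sum_s exp2(M_s - m) * L_s
+    # (sketch only; the exact arithmetic lives in the reduce-segments kernel).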
qq_bias_stride_0: gl.int32, # int + USE_ALIBI_SLOPES: gl.constexpr, # bool + USE_QQ_BIAS: gl.constexpr, # bool + USE_SOFTCAP: gl.constexpr, # bool + USE_SINKS: gl.constexpr, # bool + SLIDING_WINDOW: gl.constexpr, # int + stride_k_cache_0: gl.int64, # int + stride_k_cache_1: gl.int64, # int + stride_k_cache_2: gl.int64, # int + stride_k_cache_3: gl.constexpr, # int + stride_v_cache_0: gl.int64, # int + stride_v_cache_1: gl.int64, # int + stride_v_cache_2: gl.int64, # int + stride_v_cache_3: gl.constexpr, # int + block_table_stride: gl.int64, # int + query_start_len_ptr, # [num_seqs+1] + SCALE: gl.constexpr, # float32 + NUM_QUERY_HEADS: gl.constexpr, # int + NUM_KV_HEADS: gl.constexpr, # int + BLOCK_SIZE: gl.constexpr, # int + TILE_SIZE: gl.constexpr, # int + HEAD_SIZE: gl.constexpr, # int + BLOCK_Q: gl.constexpr, # int + BLOCK_M: gl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: gl.constexpr, # int + WARP_SIZE: gl.constexpr, # int + num_warps: gl.constexpr, # int + waves_per_eu: gl.constexpr, # int + num_stages: gl.constexpr, # int + QK_WMMA_LAYOUT: gl.constexpr, + PV_WMMA_LAYOUT: gl.constexpr, + Q_DOT_LAYOUT: gl.constexpr, + K_DOT_LAYOUT: gl.constexpr, + P_DOT_LAYOUT: gl.constexpr, + V_DOT_LAYOUT: gl.constexpr, + Q_SHARED_LAYOUT: gl.constexpr, + K_SHARED_LAYOUT: gl.constexpr, + V_SHARED_LAYOUT: gl.constexpr, + Q_LOAD_LAYOUT: gl.constexpr, + K_LOAD_LAYOUT: gl.constexpr, + V_LOAD_LAYOUT: gl.constexpr, + ALL_DECODE: gl.constexpr = False, # bool + SHUFFLED_KV_CACHE: gl.constexpr = False, # bool + USE_LOAD_BUFFER_OP: gl.constexpr = False, # bool +): + q_block_global_idx = gl.program_id(0) + kv_head_idx = gl.program_id(1) + segm_idx = gl.program_id(2) + num_ctas: gl.constexpr = gl.num_ctas() + + # needed to use exp2 (exp2 -> exp conversion) + RCP_LN2 = 1.4426950408889634 + qk_scale = SCALE * RCP_LN2 + + seq_idx = _find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) + + q_block_local_idx, cur_batch_query_len, cur_batch_in_all_start_index = ( + _get_q_metadata( + query_start_len_ptr, + seq_idx, + q_block_global_idx, + BLOCK_Q, + ) + ) + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + seq_len, tiles_per_segment = _get_seq_metadata( + seq_lens_ptr, + seq_idx, + TILE_SIZE, + NUM_SEGMENTS_PER_SEQ, + ) + + if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: + return + + # block table offset for this particular sequence + block_table_offset = seq_idx * block_table_stride + + # context length for this particular sequence + context_len = seq_len - cur_batch_query_len + + smem_Q = gl.allocate_shared_memory( + query_ptr.type.element_ty, [BLOCK_M, HEAD_SIZE], layout=Q_SHARED_LAYOUT + ) + smem_K = None + smem_V = None + if SHUFFLED_KV_CACHE: + smem_K = gl.allocate_shared_memory( + key_cache_ptr.type.element_ty, + [num_stages, TILE_SIZE // 16, HEAD_SIZE * 16], + layout=K_SHARED_LAYOUT, + ) + smem_V = gl.allocate_shared_memory( + value_cache_ptr.type.element_ty, + [num_stages, HEAD_SIZE // 16, TILE_SIZE * 16], + layout=V_SHARED_LAYOUT, + ) + else: + smem_K = gl.allocate_shared_memory( + key_cache_ptr.type.element_ty, + [num_stages, HEAD_SIZE, TILE_SIZE], + layout=K_SHARED_LAYOUT, + ) + smem_V = gl.allocate_shared_memory( + value_cache_ptr.type.element_ty, + [num_stages, TILE_SIZE, HEAD_SIZE], + layout=V_SHARED_LAYOUT, + ) + + offs_q_m = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT)) + offs_q_d = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(0, Q_LOAD_LAYOUT)) + + if SHUFFLED_KV_CACHE: + offs_k_t = gl.arange( + 0, TILE_SIZE // 16, 
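+            # Illustrative note (not part of the original kernel): in the
+            # shuffled layout each loaded "row" spans 16 consecutive rows of the
+            # cache tile, hence the TILE_SIZE // 16 and HEAD_SIZE * 16 extents
+            # built here; `_unshuffle_kv_cache` (defined further down) undoes
+            # this packing in registers before the WMMA.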
layout=gl.SliceLayout(1, K_LOAD_LAYOUT) + ) + offs_k_d = gl.arange(0, HEAD_SIZE * 16, layout=gl.SliceLayout(0, K_LOAD_LAYOUT)) + offs_v_t = gl.arange(0, TILE_SIZE * 16, layout=gl.SliceLayout(0, V_LOAD_LAYOUT)) + offs_v_d = gl.arange( + 0, HEAD_SIZE // 16, layout=gl.SliceLayout(1, V_LOAD_LAYOUT) + ) + else: + offs_k_t = gl.arange(0, TILE_SIZE, layout=gl.SliceLayout(0, K_LOAD_LAYOUT)) + offs_k_d = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(1, K_LOAD_LAYOUT)) + offs_v_t = gl.arange(0, TILE_SIZE, layout=gl.SliceLayout(1, V_LOAD_LAYOUT)) + offs_v_d = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(0, V_LOAD_LAYOUT)) + + num_queries_per_kv: gl.constexpr = NUM_QUERY_HEADS // NUM_KV_HEADS + query_pos = q_block_local_idx * BLOCK_Q + offs_q_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_q_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_q_d[None, :] + ) + + dim_mask = gl.full((1,), 1, dtype=tl.int1) + + query_mask_0 = query_pos < cur_batch_query_len + query_mask_1 = query_offset_1 < NUM_QUERY_HEADS + + # Q_load : shape = (BLOCK_M, HEAD_SIZE), layout = Q_LOAD_LAYOUT + Q_load = gl.amd.cdna4.buffer_load( + ptr=query_ptr, + offsets=query_offset.to(gl.int32), + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + smem_Q.store(Q_load) + # Q : shape = (BLOCK_M, HEAD_SIZE), layout = Q_DOT_LAYOUT + Q = smem_Q.load(layout=Q_DOT_LAYOUT) + + offs_q_m_qk = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, QK_WMMA_LAYOUT)) + query_pos_qk = q_block_local_idx * BLOCK_Q + offs_q_m_qk // num_queries_per_kv + query_offset_1_qk = ( + kv_head_idx * num_queries_per_kv + offs_q_m_qk % num_queries_per_kv + ) + query_mask_0_qk = query_pos_qk < cur_batch_query_len + query_mask_1_qk = query_offset_1_qk < NUM_QUERY_HEADS + query_mask_qk = query_mask_1_qk[:, None] & query_mask_0_qk[:, None] + + L, M, acc = _allocate_L_M_acc( + sink_ptr, + segm_idx, + query_offset_1_qk, + query_mask_1_qk, + RCP_LN2, + BLOCK_M, + HEAD_SIZE, + QK_WMMA_LAYOUT, + PV_WMMA_LAYOUT, + USE_SINKS, + ) + + # alibi slope for this head + alibi_slope = None + if USE_ALIBI_SLOPES: + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1_qk, mask=query_mask_1, other=0.0 + ) + + # query-query attention bias + qq_bias_row_ptrs = None + if USE_QQ_BIAS: + qq_bias_row_ptrs = ( + qq_bias_ptr + query_pos_qk[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] + + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = gl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + + KV_cache_modifier: gl.constexpr = ".cg" if ALL_DECODE else "" + + k_from_hbm = 0 + k_from_lds = 0 + v_from_hbm = 0 + v_from_lds = 0 + j_from_hbm = segm_idx * tiles_per_segment + seq_offset = j_from_hbm * TILE_SIZE + gl.arange( + 0, TILE_SIZE, layout=gl.SliceLayout(0, QK_WMMA_LAYOUT) + ) + + for _ in range(num_stages - 1): + j_from_hbm, k_offset, v_offset, 
tile_k_mask, tile_v_mask = _get_kv_offsets( + j_from_hbm, + kv_head_idx, + block_tables_ptr, + block_table_offset, + offs_k_t, + offs_k_d, + offs_v_t, + offs_v_d, + max_seq_prefix_len, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + K_LOAD_LAYOUT, + V_LOAD_LAYOUT, + TILE_SIZE, + BLOCK_SIZE, + SHUFFLED_KV_CACHE, + ) + k_mask = None + v_mask = None + if not SHUFFLED_KV_CACHE: + k_mask = dim_mask[:, None] & tile_k_mask[None, :] + v_mask = dim_mask[None, :] & tile_v_mask[:, None] + + k_from_hbm = _async_load_to_lds( + k_from_hbm, + dest=smem_K, + ptr=key_cache_ptr, + offsets=k_offset, + mask=k_mask, + num_stages=num_stages, + cache_modifier=KV_cache_modifier, + use_buffer_load=USE_LOAD_BUFFER_OP, + ) + v_from_hbm = _async_load_to_lds( + v_from_hbm, + dest=smem_V, + ptr=value_cache_ptr, + offsets=v_offset, + mask=v_mask, + num_stages=num_stages, + cache_modifier=KV_cache_modifier, + use_buffer_load=USE_LOAD_BUFFER_OP, + ) + + # iterate through tiles within current segment + for _ in range( + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles) - (num_stages - 1), + ): + # K, k_from_lds = _request_from_lds( + # k_from_lds, + # k_scale, + # Q.dtype, + # smem_K, + # wait_group=(num_stages - 2) * 2 + 1, + # LOAD_LAYOUT=K_LOAD_LAYOUT, + # DOT_LAYOUT=K_DOT_LAYOUT, + # SHUFFLED_KV_CACHE=SHUFFLED_KV_CACHE, + # num_stages=num_stages, + # ) + # if SHUFFLED_KV_CACHE: + # K = _unshuffle_kv_cache(K, TILE_SIZE, HEAD_SIZE) + # K = gl.convert_layout(value=K, layout=K_DOT_LAYOUT, assert_trivial=True) + + j_from_hbm, k_offset, v_offset, tile_k_mask, tile_v_mask = _get_kv_offsets( + j_from_hbm, + kv_head_idx, + block_tables_ptr, + block_table_offset, + offs_k_t, + offs_k_d, + offs_v_t, + offs_v_d, + max_seq_prefix_len, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + K_LOAD_LAYOUT, + V_LOAD_LAYOUT, + TILE_SIZE, + BLOCK_SIZE, + SHUFFLED_KV_CACHE, + ) + k_mask = None + v_mask = None + if not SHUFFLED_KV_CACHE: + k_mask = dim_mask[:, None] & tile_k_mask[None, :] + v_mask = dim_mask[None, :] & tile_v_mask[:, None] + + K, k_from_lds = _request_from_lds( + k_from_lds, + k_scale, + Q.dtype, + smem_K, + wait_group=(num_stages - 2) * 2 + 1, + LOAD_LAYOUT=K_LOAD_LAYOUT, + DOT_LAYOUT=K_DOT_LAYOUT, + SHUFFLED_KV_CACHE=SHUFFLED_KV_CACHE, + num_stages=num_stages, + ) + if SHUFFLED_KV_CACHE: + K = _unshuffle_kv_cache(K, TILE_SIZE, HEAD_SIZE) + K = gl.convert_layout(value=K, layout=K_DOT_LAYOUT, assert_trivial=True) + + # K_load : shape = (HEAD_SIZE, TILE_SIZE), layout = K_LOAD_LAYOUT + k_from_hbm = _async_load_to_lds( + k_from_hbm, + dest=smem_K, + ptr=key_cache_ptr, + offsets=k_offset, + mask=k_mask, + num_stages=num_stages, + cache_modifier=KV_cache_modifier, + use_buffer_load=USE_LOAD_BUFFER_OP, + ) + + # K, k_from_lds = _request_from_lds( + # k_from_lds, + # k_scale, + # Q.dtype, + # smem_K, + # wait_group=(num_stages - 1) * 2, + # LOAD_LAYOUT=K_LOAD_LAYOUT, + # DOT_LAYOUT=K_DOT_LAYOUT, + # SHUFFLED_KV_CACHE=SHUFFLED_KV_CACHE, + # num_stages=num_stages, + # ) + # if SHUFFLED_KV_CACHE: + # K = _unshuffle_kv_cache(K, TILE_SIZE, HEAD_SIZE) + # K = gl.convert_layout(value=K, layout=K_DOT_LAYOUT, assert_trivial=True) + + # V_load : shape = (TILE_SIZE, HEAD_SIZE), layout = Q_LOAD_LAYOUT + v_from_hbm = _async_load_to_lds( + v_from_hbm, + dest=smem_V, + ptr=value_cache_ptr, + 
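+            # Illustrative note (not part of the original kernel): K/V tiles
+            # are software-pipelined through `num_stages` LDS buffers.  The
+            # prologue above issues num_stages - 1 asynchronous K/V loads; in
+            # this steady-state loop each iteration issues one more K and V
+            # load and calls _request_from_lds with a wait_group budget chosen
+            # so that the buffer about to be consumed has already landed while
+            # newer prefetches stay in flight; the epilogue below drains the
+            # remaining tiles without issuing further loads.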
offsets=v_offset, + mask=v_mask, + num_stages=num_stages, + cache_modifier=KV_cache_modifier, + use_buffer_load=USE_LOAD_BUFFER_OP, + ) + + # K, k_from_lds = _request_from_lds( + # k_from_lds, + # k_scale, + # Q.dtype, + # smem_K, + # wait_group=(num_stages - 1) * 2 + 1, + # LOAD_LAYOUT=K_LOAD_LAYOUT, + # DOT_LAYOUT=K_DOT_LAYOUT, + # SHUFFLED_KV_CACHE=SHUFFLED_KV_CACHE, + # num_stages=num_stages, + # ) + # if SHUFFLED_KV_CACHE: + # K = _unshuffle_kv_cache(K, TILE_SIZE, HEAD_SIZE) + # K = gl.convert_layout(value=K, layout=K_DOT_LAYOUT, assert_trivial=True) + + # P : shape = (BLOCK_M, TILE_SIZE), layout = Q_LOAD_LAYOUT + # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT + P, L, M, acc = _perform_QK_wmma_and_update_L_M( + Q, + K, + L, + M, + acc, + qq_bias_row_ptrs, + seq_offset, + query_mask_qk, + query_pos_qk, + context_len, + alibi_slope, + qq_bias_stride_0, + qk_scale, + softcap, + RCP_LN2, + BLOCK_M, + TILE_SIZE, + USE_SOFTCAP, + SLIDING_WINDOW, + USE_ALIBI_SLOPES, + USE_QQ_BIAS, + Q_LOAD_LAYOUT, + QK_WMMA_LAYOUT, + PV_WMMA_LAYOUT, + ) + + # V : shape = (TILE_SIZE, HEAD_SIZE), layout = V_DOT_LAYOUT + V, v_from_lds = _request_from_lds( + v_from_lds, + v_scale, + Q.dtype, + smem_V, + wait_group=(num_stages - 1) * 2, + LOAD_LAYOUT=V_LOAD_LAYOUT, + DOT_LAYOUT=V_DOT_LAYOUT, + SHUFFLED_KV_CACHE=SHUFFLED_KV_CACHE, + num_stages=num_stages, + ) + if SHUFFLED_KV_CACHE: + V = _unshuffle_kv_cache(V, HEAD_SIZE, TILE_SIZE) + V = gl.convert_layout(value=V, layout=V_DOT_LAYOUT, assert_trivial=True) + + # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT + acc = _perform_PV_wmma(P, V, acc, P_DOT_LAYOUT) + + seq_offset += TILE_SIZE + + for _ in range(num_stages - 1): + # K : shape = (HEAD_SIZE, TILE_SIZE), layout = K_DOT_LAYOUT + K, k_from_lds = _request_from_lds( + k_from_lds, + k_scale, + Q.dtype, + smem_K, + wait_group=(num_stages - 2) * 2 + + 1, # there is no async_copy in the epilogue, hence num_stages - 2 + # wait_group=0, + LOAD_LAYOUT=K_LOAD_LAYOUT, + DOT_LAYOUT=K_DOT_LAYOUT, + SHUFFLED_KV_CACHE=SHUFFLED_KV_CACHE, + num_stages=num_stages, + ) + if SHUFFLED_KV_CACHE: + K = _unshuffle_kv_cache(K, TILE_SIZE, HEAD_SIZE) + K = gl.convert_layout(value=K, layout=K_DOT_LAYOUT, assert_trivial=True) + + # P : shape = (BLOCK_M, TILE_SIZE), layout = Q_LOAD_LAYOUT + # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT + P, L, M, acc = _perform_QK_wmma_and_update_L_M( + Q, + K, + L, + M, + acc, + qq_bias_row_ptrs, + seq_offset, + query_mask_qk, + query_pos_qk, + context_len, + alibi_slope, + qq_bias_stride_0, + qk_scale, + softcap, + RCP_LN2, + BLOCK_M, + TILE_SIZE, + USE_SOFTCAP, + SLIDING_WINDOW, + USE_ALIBI_SLOPES, + USE_QQ_BIAS, + Q_LOAD_LAYOUT, + QK_WMMA_LAYOUT, + PV_WMMA_LAYOUT, + ) + + # V : shape = (TILE_SIZE, HEAD_SIZE), layout = V_DOT_LAYOUT + V, v_from_lds = _request_from_lds( + v_from_lds, + v_scale, + Q.dtype, + smem_V, + wait_group=(num_stages - 2) + * 2, # there is no async_copy in the epilogue, hence num_stages - 2 + # wait_group=0, + LOAD_LAYOUT=V_LOAD_LAYOUT, + DOT_LAYOUT=V_DOT_LAYOUT, + SHUFFLED_KV_CACHE=SHUFFLED_KV_CACHE, + num_stages=num_stages, + ) + if SHUFFLED_KV_CACHE: + V = _unshuffle_kv_cache(V, HEAD_SIZE, TILE_SIZE) + V = gl.convert_layout(value=V, layout=V_DOT_LAYOUT, 
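+            # Illustrative note (not part of the original kernel): the update in
+            # _perform_QK_wmma_and_update_L_M is the standard online-softmax
+            # recurrence, carried out in base 2 (scores are pre-scaled by
+            # log2(e) = RCP_LN2 so exp2 can be used):
+            #     m_new = max(m_old, rowmax(S))
+            #     P     = exp2(S - m_new)
+            #     alpha = exp2(m_old - m_new)
+            #     acc   = acc * alpha + P @ V
+            #     L     = L * alpha + rowsum(P)
+            # The division by L (softmax normalization) is deferred to the
+            # segment-reduction step.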
assert_trivial=True) + + # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT + acc = _perform_PV_wmma(P, V, acc, P_DOT_LAYOUT) + + seq_offset += TILE_SIZE + + # store segm_output + # acc : shape = (BLOCK_M, HEAD_SIZE), layout = Q_LOAD_LAYOUT + acc = gl.convert_layout(acc, layout=Q_LOAD_LAYOUT) + segm_output_offset = ( + query_offset_0[:, None] * (NUM_QUERY_HEADS * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE) + + query_offset_1[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE) + + segm_idx * HEAD_SIZE + + offs_q_d[None, :] + ) + gl.amd.cdna4.buffer_store( + stored_value=acc, + ptr=segm_output_ptr, + offsets=segm_output_offset, + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + ) + + # store segm_max and segm_expsum + # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, QK_WMMA_LAYOUT) + # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, QK_WMMA_LAYOUT) + segm_offset = ( + query_offset_0 * (NUM_QUERY_HEADS * NUM_SEGMENTS_PER_SEQ) + + query_offset_1 * NUM_SEGMENTS_PER_SEQ + + segm_idx + ) + L = gl.convert_layout(L, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT)) + M = gl.convert_layout(M, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT)) + gl.amd.cdna4.buffer_store( + stored_value=M, + ptr=segm_max_ptr, + offsets=segm_offset, + mask=query_mask_0 & query_mask_1, + ) + gl.amd.cdna4.buffer_store( + stored_value=L, + ptr=segm_expsum_ptr, + offsets=segm_offset, + mask=query_mask_0 & query_mask_1, + ) + + +@gluon.jit +def _unshuffle_kv_cache( + X, + BLOCK_SIZE_N: gl.constexpr, + BLOCK_SIZE_INNER_DIM: gl.constexpr, +): + return ( + X.reshape( + 1, + BLOCK_SIZE_N // 16, + BLOCK_SIZE_INNER_DIM // 16, + 2, + 16, + 8, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_INNER_DIM) + .trans(1, 0) + ) + + +@gluon.jit +def _buffer_load_to_reg( + x_scale, + Q_dtype, + ptr, + offsets, + mask, + other, + cache_modifier: gl.constexpr, + SHUFFLED_KV_CACHE: gl.constexpr = False, +): + X = gl.amd.cdna4.buffer_load( + ptr=ptr, + offsets=offsets.to(gl.int32), + mask=mask, + other=other, + cache=cache_modifier, + ) + if X.dtype.is_fp8() and not Q_dtype.is_fp8(): + X = (X.to(gl.float32) * gl.load(x_scale)).to(Q_dtype) + return X + + +gluon_kernel_unified_attention_3d_repr = make_kernel_repr( + "gluon_kernel_unified_attention_3d", + [ + "num_query_heads", + "num_queries_per_kv", + "BLOCK_SIZE", + "TILE_SIZE", + "HEAD_SIZE", + "num_warps", + "num_stages", + "ALL_DECODE", + "SHUFFLED_KV_CACHE", + ], +) + + +@gluon.jit(repr=gluon_kernel_unified_attention_3d_repr) +def gluon_kernel_unified_attention_3d( + segm_output_ptr, # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, blk_size, head_size] + value_cache_ptr, # [num_blks, num_kv_heads, blk_size, head_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # [num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_seqs: gl.int32, # int + num_blocks: gl.int32, # int + query_stride_0: gl.int32, # int + query_stride_1: gl.int32, # int, should be equal to head_size + qq_bias_stride_0: gl.int32, # int + USE_ALIBI_SLOPES: gl.constexpr, # bool + USE_QQ_BIAS: gl.constexpr, # bool + USE_SOFTCAP: gl.constexpr, # bool + USE_SINKS: gl.constexpr, # bool + 
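+    # Illustrative note (not part of the original kernel): when
+    # SLIDING_WINDOW > 0, a query at absolute position p only attends to keys
+    # k with p - SLIDING_WINDOW < k <= p (see the masking in
+    # _perform_QK_wmma_and_update_L_M).  Hypothetical example:
+    # SLIDING_WINDOW = 4 and p = 10 restrict attention to keys 7..10.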
SLIDING_WINDOW: gl.constexpr, # int + stride_k_cache_0: gl.int64, # int + stride_k_cache_1: gl.int64, # int + stride_k_cache_2: gl.int64, # int + stride_k_cache_3: gl.constexpr, # int + stride_v_cache_0: gl.int64, # int + stride_v_cache_1: gl.int64, # int + stride_v_cache_2: gl.int64, # int + stride_v_cache_3: gl.constexpr, # int + block_table_stride: gl.int64, # int + query_start_len_ptr, # [num_seqs+1] + SCALE: gl.constexpr, # float32 + NUM_QUERY_HEADS: gl.constexpr, # int + NUM_KV_HEADS: gl.constexpr, # int + BLOCK_SIZE: gl.constexpr, # int + TILE_SIZE: gl.constexpr, # int + HEAD_SIZE: gl.constexpr, # int + BLOCK_Q: gl.constexpr, # int + BLOCK_M: gl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: gl.constexpr, # int + WARP_SIZE: gl.constexpr, # int + num_warps: gl.constexpr, # int + waves_per_eu: gl.constexpr, # int + num_stages: gl.constexpr, # int + QK_WMMA_LAYOUT: gl.constexpr, + PV_WMMA_LAYOUT: gl.constexpr, + Q_DOT_LAYOUT: gl.constexpr, + K_DOT_LAYOUT: gl.constexpr, + P_DOT_LAYOUT: gl.constexpr, + V_DOT_LAYOUT: gl.constexpr, + Q_SHARED_LAYOUT: gl.constexpr, + K_SHARED_LAYOUT: gl.constexpr, + V_SHARED_LAYOUT: gl.constexpr, + Q_LOAD_LAYOUT: gl.constexpr, + K_LOAD_LAYOUT: gl.constexpr, + V_LOAD_LAYOUT: gl.constexpr, + ALL_DECODE: gl.constexpr = False, # bool + SHUFFLED_KV_CACHE: gl.constexpr = False, # bool + USE_LOAD_BUFFER_OP: gl.constexpr = False, # bool +): + q_block_global_idx = gl.program_id(0) + kv_head_idx = gl.program_id(1) + segm_idx = gl.program_id(2) + + # needed to use exp2 (exp2 -> exp conversion) + RCP_LN2 = 1.4426950408889634 + qk_scale = SCALE * RCP_LN2 + + seq_idx = _find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True + ) + + q_block_local_idx, cur_batch_query_len, cur_batch_in_all_start_index = ( + _get_q_metadata( + query_start_len_ptr, + seq_idx, + q_block_global_idx, + BLOCK_Q, + ) + ) + + if q_block_local_idx * BLOCK_Q >= cur_batch_query_len: + return + + seq_len, tiles_per_segment = _get_seq_metadata( + seq_lens_ptr, + seq_idx, + TILE_SIZE, + NUM_SEGMENTS_PER_SEQ, + ) + + if segm_idx * tiles_per_segment * TILE_SIZE >= seq_len: + return + + # block table offset for this particular sequence + block_table_offset = seq_idx * block_table_stride + + # context length for this particular sequence + context_len = seq_len - cur_batch_query_len + + smem_Q = gl.allocate_shared_memory( + query_ptr.type.element_ty, [BLOCK_M, HEAD_SIZE], layout=Q_SHARED_LAYOUT + ) + smem_K = None + smem_V = None + if SHUFFLED_KV_CACHE: + # pass + smem_K = gl.allocate_shared_memory( + key_cache_ptr.type.element_ty, + [TILE_SIZE // 16, HEAD_SIZE * 16], + layout=K_SHARED_LAYOUT, + ) + smem_V = gl.allocate_shared_memory( + value_cache_ptr.type.element_ty, + [HEAD_SIZE // 16, TILE_SIZE * 16], + layout=V_SHARED_LAYOUT, + ) + else: + smem_K = gl.allocate_shared_memory( + key_cache_ptr.type.element_ty, + [HEAD_SIZE, TILE_SIZE], + layout=K_SHARED_LAYOUT, + ) + smem_V = gl.allocate_shared_memory( + value_cache_ptr.type.element_ty, + [TILE_SIZE, HEAD_SIZE], + layout=V_SHARED_LAYOUT, + ) + + offs_q_m = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT)) + offs_q_d = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(0, Q_LOAD_LAYOUT)) + + if SHUFFLED_KV_CACHE: + offs_k_t = gl.arange( + 0, TILE_SIZE // 16, layout=gl.SliceLayout(1, K_LOAD_LAYOUT) + ) + offs_k_d = gl.arange(0, HEAD_SIZE * 16, layout=gl.SliceLayout(0, K_LOAD_LAYOUT)) + offs_v_t = gl.arange(0, TILE_SIZE * 16, layout=gl.SliceLayout(0, V_LOAD_LAYOUT)) + offs_v_d = gl.arange( + 0, HEAD_SIZE // 16, 
layout=gl.SliceLayout(1, V_LOAD_LAYOUT) + ) + else: + offs_k_t = gl.arange(0, TILE_SIZE, layout=gl.SliceLayout(0, K_LOAD_LAYOUT)) + offs_k_d = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(1, K_LOAD_LAYOUT)) + offs_v_t = gl.arange(0, TILE_SIZE, layout=gl.SliceLayout(1, V_LOAD_LAYOUT)) + offs_v_d = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(0, V_LOAD_LAYOUT)) + + num_queries_per_kv: gl.constexpr = NUM_QUERY_HEADS // NUM_KV_HEADS + query_pos = q_block_local_idx * BLOCK_Q + offs_q_m // num_queries_per_kv + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = kv_head_idx * num_queries_per_kv + offs_q_m % num_queries_per_kv + query_offset = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_q_d[None, :] + ) + + dim_mask = gl.full((1,), 1, dtype=tl.int1) + + query_mask_0 = query_pos < cur_batch_query_len + query_mask_1 = query_offset_1 < NUM_QUERY_HEADS + + # Q_load : shape = (BLOCK_M, HEAD_SIZE), layout = Q_LOAD_LAYOUT + Q_load = gl.amd.cdna4.buffer_load( + ptr=query_ptr, + offsets=query_offset.to(gl.int32), + mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + smem_Q.store(Q_load) + # Q : shape = (BLOCK_M, HEAD_SIZE), layout = Q_DOT_LAYOUT + Q = smem_Q.load(layout=Q_DOT_LAYOUT) + + offs_q_m_qk = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, QK_WMMA_LAYOUT)) + offs_q_d_qk = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(0, QK_WMMA_LAYOUT)) + query_pos_qk = q_block_local_idx * BLOCK_Q + offs_q_m_qk // num_queries_per_kv + query_offset_0_qk = cur_batch_in_all_start_index + query_pos_qk + query_offset_1_qk = ( + kv_head_idx * num_queries_per_kv + offs_q_m_qk % num_queries_per_kv + ) + query_mask_0_qk = query_pos_qk < cur_batch_query_len + query_mask_1_qk = query_offset_1_qk < NUM_QUERY_HEADS + query_mask_qk = query_mask_1_qk[:, None] & query_mask_0_qk[:, None] + offs_seq_t = gl.arange(0, TILE_SIZE, layout=gl.SliceLayout(0, QK_WMMA_LAYOUT)) + + L, M, acc = _allocate_L_M_acc( + sink_ptr, + segm_idx, + query_offset_1_qk, + query_mask_1_qk, + RCP_LN2, + BLOCK_M, + HEAD_SIZE, + QK_WMMA_LAYOUT, + PV_WMMA_LAYOUT, + USE_SINKS, + ) + + # alibi slope for this head + alibi_slope = None + if USE_ALIBI_SLOPES: + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1_qk, mask=query_mask_1, other=0.0 + ) + + # query-query attention bias + qq_bias_row_ptrs = None + if USE_QQ_BIAS: + qq_bias_row_ptrs = ( + qq_bias_ptr + query_pos_qk[:, None] * qq_bias_stride_0 + ) # shape: [BLOCK_M] + + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = ( + context_len + + q_block_local_idx * BLOCK_Q + + (BLOCK_M - 1) // num_queries_per_kv + + 1 + ) + + # adjust for potential padding in the last q_block by considering the + # actual sequence length + max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len) + + # calculate the number of tiles that need to be processed to + # cover the longest sequence prefix (due to causal masking, tiles beyond + # this prefix can be skipped) + num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE) + + KV_cache_modifier: tl.constexpr = ".cg" if ALL_DECODE else "" + # iterate through tiles within current segment + for j in range( + segm_idx * tiles_per_segment, + min((segm_idx + 1) * tiles_per_segment, num_tiles), + ): + # seq_k_offset : shape = (TILE_SIZE if not SHUFFLED_KV_CACHE else TILE_SIZE // 16, ), layout = gl.SliceLayout(0 if not SHUFFLED_KV_CACHE else 1, K_LOAD_LAYOUT) + # 
seq_v_offset : shape = (TILE_SIZE, ), layout = gl.SliceLayout(1, Q_LOAD_LAYOUT) + # seq_offset : shape = (TILE_SIZE, ), layout = gl.SliceLayout(0, QK_WMMA_LAYOUT) + seq_offset = j * TILE_SIZE + offs_seq_t + + k_mask = None + v_mask = None + other = None + if SHUFFLED_KV_CACHE: + # seq_k_offset = j * TILE_SIZE + offs_k_t * 16 + # seq_v_offset = j * TILE_SIZE + offs_v_t // 16 + # physical_block_idx_k = gl.amd.cdna4.buffer_load( + # ptr=block_tables_ptr, + # offsets=(block_table_offset + seq_k_offset // BLOCK_SIZE).to(gl.int32), + # ).to(tl.int64) + + # physical_block_idx_v = gl.amd.cdna4.buffer_load( + # ptr=block_tables_ptr, + # offsets=(block_table_offset + seq_v_offset // BLOCK_SIZE).to(gl.int32), + # ).to(tl.int64) + physical_block_idx = gl.load(block_tables_ptr + block_table_offset + j).to( + tl.int64 + ) + + k_offset = ( + # physical_block_idx_k[:, None] * stride_k_cache_0 + physical_block_idx * stride_k_cache_0 + + kv_head_idx * stride_k_cache_1 + + offs_k_t[:, None] * stride_k_cache_2 + + offs_k_d[None, :] * stride_k_cache_3 + ) + v_offset = ( + # physical_block_idx_v[None, :] * stride_v_cache_0 + physical_block_idx * stride_v_cache_0 + + kv_head_idx * stride_v_cache_1 + + offs_v_t[None, :] * stride_v_cache_3 + + offs_v_d[:, None] * stride_v_cache_2 + ) + else: + seq_k_offset = j * TILE_SIZE + offs_k_t + seq_v_offset = j * TILE_SIZE + offs_v_t + + if TILE_SIZE == BLOCK_SIZE: + tile_k_mask = gl.full( + (1,), 1, dtype=tl.int1, layout=gl.SliceLayout(0, K_LOAD_LAYOUT) + ) + tile_v_mask = gl.full( + (1,), 1, dtype=tl.int1, layout=gl.SliceLayout(1, Q_LOAD_LAYOUT) + ) + else: + tile_k_mask = seq_k_offset < max_seq_prefix_len + tile_v_mask = seq_v_offset < max_seq_prefix_len + + physical_block_idx_k = gl.amd.cdna4.buffer_load( + ptr=block_tables_ptr, + offsets=(block_table_offset + seq_k_offset // BLOCK_SIZE).to(gl.int32), + ).to(tl.int64) + + physical_block_idx_v = gl.amd.cdna4.buffer_load( + ptr=block_tables_ptr, + offsets=(block_table_offset + seq_v_offset // BLOCK_SIZE).to(gl.int32), + ).to(tl.int64) + + k_offset = ( + physical_block_idx_k[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_k_d[:, None] * stride_k_cache_3 + + (seq_k_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + k_mask = dim_mask[:, None] & tile_k_mask[None, :] + v_offset = ( + physical_block_idx_v[:, None] * stride_v_cache_0 + + kv_head_idx * stride_v_cache_2 + + offs_v_d[None, :] * stride_v_cache_3 + + (seq_v_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 + ) + v_mask = dim_mask[None, :] & tile_v_mask[:, None] + other = 0.0 + + # K_load : shape = (HEAD_SIZE, TILE_SIZE), layout = K_LOAD_LAYOUT + K = _buffer_load_to_reg( + k_scale, + Q.dtype, + key_cache_ptr, + k_offset.to(gl.int32), + k_mask, + other, + KV_cache_modifier, + SHUFFLED_KV_CACHE, + ) + if SHUFFLED_KV_CACHE: + # smem_K.store(K) + pass + else: + smem_K.store(K) + + # V_load : shape = (TILE_SIZE, HEAD_SIZE), layout = Q_LOAD_LAYOUT + V = _buffer_load_to_reg( + v_scale, + Q.dtype, + value_cache_ptr, + v_offset.to(gl.int32), + v_mask, + other, + KV_cache_modifier, + ) + if SHUFFLED_KV_CACHE: + # smem_V.store(V) + pass + else: + smem_V.store(V) + + if SHUFFLED_KV_CACHE: + # K = smem_K.load(layout=K_LOAD_LAYOUT) + K = _unshuffle_kv_cache(K, TILE_SIZE, HEAD_SIZE) + K = gl.convert_layout(value=K, layout=K_DOT_LAYOUT, assert_trivial=True) + else: + K = smem_K.load(layout=K_DOT_LAYOUT) + P, L, M, acc = _perform_QK_wmma_and_update_L_M( + Q, + K, + L, + M, + acc, + qq_bias_row_ptrs, + seq_offset, + query_mask_qk, + query_pos_qk, + 
context_len, + alibi_slope, + qq_bias_stride_0, + qk_scale, + softcap, + RCP_LN2, + BLOCK_M, + TILE_SIZE, + USE_SOFTCAP, + SLIDING_WINDOW, + USE_ALIBI_SLOPES, + USE_QQ_BIAS, + Q_LOAD_LAYOUT, + QK_WMMA_LAYOUT, + PV_WMMA_LAYOUT, + ) + + if SHUFFLED_KV_CACHE: + # V = smem_V.load(layout=V_LOAD_LAYOUT) + V = _unshuffle_kv_cache(V, HEAD_SIZE, TILE_SIZE) + V = gl.convert_layout(value=V, layout=V_DOT_LAYOUT, assert_trivial=True) + else: + V = smem_V.load(layout=V_DOT_LAYOUT) + # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT + acc = _perform_PV_wmma(P, V, acc, P_DOT_LAYOUT) + + # store segm_output + # acc : shape = (BLOCK_M, HEAD_SIZE), layout = PV_WMMA_LAYOUT + offs_q_m_pv = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, PV_WMMA_LAYOUT)) + offs_q_d_pv = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(0, PV_WMMA_LAYOUT)) + query_pos_pv = q_block_local_idx * BLOCK_Q + offs_q_m_pv // num_queries_per_kv + query_offset_0_pv = cur_batch_in_all_start_index + query_pos_pv + query_offset_1_pv = ( + kv_head_idx * num_queries_per_kv + offs_q_m_pv % num_queries_per_kv + ) + query_mask_0_pv = query_pos_pv < cur_batch_query_len + query_mask_1_pv = query_offset_1_pv < NUM_QUERY_HEADS + query_mask_pv = query_mask_1_pv[:, None] & query_mask_0_pv[:, None] + segm_output_offset = ( + query_offset_0_pv[:, None] + * (NUM_QUERY_HEADS * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE) + + query_offset_1_pv[:, None] * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE) + + segm_idx * HEAD_SIZE + + offs_q_d_pv[None, :] + ) + gl.amd.cdna4.buffer_store( + stored_value=acc, + ptr=segm_output_ptr, + offsets=segm_output_offset, + mask=dim_mask[None, :] & query_mask_pv, + ) + + # store segm_max and segm_expsum + # L : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, QK_WMMA_LAYOUT) + # M : shape = (BLOCK_M, ), layout = gl.SliceLayout(1, QK_WMMA_LAYOUT) + segm_offset = ( + query_offset_0_qk * (NUM_QUERY_HEADS * NUM_SEGMENTS_PER_SEQ) + + query_offset_1_qk * NUM_SEGMENTS_PER_SEQ + + segm_idx + ) + gl.amd.cdna4.buffer_store( + stored_value=M, + ptr=segm_max_ptr, + offsets=segm_offset, + mask=query_mask_0_qk & query_mask_1_qk, + ) + gl.amd.cdna4.buffer_store( + stored_value=L, + ptr=segm_expsum_ptr, + offsets=segm_offset, + mask=query_mask_0_qk & query_mask_1_qk, + ) + + +@triton.jit +def gluon_reduce_segments( + output_ptr, # [num_tokens, num_query_heads, head_size] + segm_output_ptr, + # [num_tokens, num_query_heads, max_num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] + seq_lens_ptr, # [num_seqs] + num_seqs, # int + num_query_heads: tl.constexpr, # int + out_scale_inv, # float32 + output_stride_0: tl.int64, # int + output_stride_1: tl.int64, # int, should be equal to head_size + block_table_stride: tl.int64, # int + TILE_SIZE: tl.constexpr, # int + HEAD_SIZE: tl.constexpr, # int, must be power of 2 + query_start_len_ptr, # [num_seqs+1] + BLOCK_Q: tl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int + USE_FP8: tl.constexpr, # bool + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, +): + query_token_idx = tl.program_id(0) + query_head_idx = tl.program_id(1) + + seq_idx = _find_seq_idx( + query_start_len_ptr, query_token_idx, num_seqs, BLOCK_Q, False + ) + + # sequence len for this particular sequence + seq_len = tl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + 
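+    # The reduction below combines per-segment softmax partials with the
+    # standard log-sum-exp trick, expressed in base 2 to match the exp2 used
+    # by the attention kernel:
+    #   overall_max    = max_i(segm_max_i)
+    #   overall_expsum = sum_i(segm_expsum_i * 2^(segm_max_i - overall_max))
+    #   out            = sum_i(segm_output_i * 2^(segm_max_i - overall_max)) / overall_expsum
+    # Illustrative example (assumed values, not from a real run): with
+    # seq_len=1000, TILE_SIZE=64, NUM_SEGMENTS_PER_SEQ=16, tiles_per_segment is
+    # cdiv(1000, 16*64) = 1, so act_num_segments below is cdiv(1000, 64) = 16.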
+ # create masks for subsequent loads + act_num_segments = cdiv_fn(seq_len, tiles_per_segment * TILE_SIZE) + segm_mask = tl.arange(0, NUM_SEGMENTS_PER_SEQ) < tl.full( + [NUM_SEGMENTS_PER_SEQ], act_num_segments, dtype=tl.int32 + ) + + dim_mask = tl.full((1,), 1, dtype=tl.int1) + + # load segment maxima + segm_offset = ( + query_token_idx.to(tl.int64) * (num_query_heads * NUM_SEGMENTS_PER_SEQ) + + query_head_idx * NUM_SEGMENTS_PER_SEQ + + tl.arange(0, NUM_SEGMENTS_PER_SEQ) + ) + segm_max = tl.load(segm_max_ptr + segm_offset, mask=segm_mask, other=float("-inf")) + overall_max = tl.max(segm_max) + + # load and rescale segment exp sums + segm_expsum = tl.load(segm_expsum_ptr + segm_offset, mask=segm_mask, other=0.0) + segm_expsum = segm_expsum * tl.math.exp2(segm_max - overall_max) + overall_expsum = tl.sum(segm_expsum) + + # load, rescale, and add segment attention outputs + segm_output_offset = ( + query_token_idx.to(tl.int64) + * (num_query_heads * NUM_SEGMENTS_PER_SEQ * HEAD_SIZE) + + query_head_idx * (NUM_SEGMENTS_PER_SEQ * HEAD_SIZE) + + tl.arange(0, NUM_SEGMENTS_PER_SEQ)[:, None] * HEAD_SIZE + + tl.arange(0, HEAD_SIZE)[None, :] + ) + segm_output = tl.load( + segm_output_ptr + segm_output_offset, + mask=segm_mask[:, None] & dim_mask[None, :], + other=0.0, + ) + segm_output *= tl.math.exp2(segm_max - overall_max)[:, None] + acc_sum = tl.sum(segm_output, axis=0) + # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 + acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) + + if USE_FP8: + acc = acc * tl.load(out_scale_inv) + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) + + # write result + output_offset = ( + query_token_idx * output_stride_0 + + query_head_idx * output_stride_1 + + tl.arange(0, HEAD_SIZE) + ) + acc = acc.to(output_ptr.type.element_ty) + tl.store(output_ptr + output_offset, acc, mask=dim_mask) diff --git a/aiter/ops/triton/gluon/unified_attention_3d_kernel_tdm.py b/aiter/ops/triton/gluon/unified_attention_3d_kernel_tdm.py new file mode 100644 index 0000000000..e10192b777 --- /dev/null +++ b/aiter/ops/triton/gluon/unified_attention_3d_kernel_tdm.py @@ -0,0 +1,1649 @@ +# The kernels in this file are adapted from vLLM: +# https://github.com/vllm-project/vllm/blob/main/vllm/attention/ops/triton_unified_attention.py +from re import T +import triton +import triton.language as tl +import torch +from aiter.ops.triton.utils.types import e4m3_dtype +from triton.experimental import gluon +import triton.experimental.gluon.language as gl +import aiter.ops.triton.utils._triton.arch_info as arch_info +from triton.language.core import _aggregate as aggregate +from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr + +import math + +# from triton._C.libtriton.gluon_ir import make_cga_layout + +DEVICE_ARCH = arch_info.get_arch() +IS_DEVICE_ARCH_GFX12 = DEVICE_ARCH in ("gfx1250",) +MMA_operation: gl.constexpr = ( + gl.amd.gfx1250.wmma if gl.constexpr(IS_DEVICE_ARCH_GFX12) else gl.amd.cdna4.mfma +) +WARP_SIZE = 32 if IS_DEVICE_ARCH_GFX12 else 64 +WAPR_SIZE_LOG2 = int(math.log2(WARP_SIZE)) + +float8_info = torch.finfo(e4m3_dtype) + + +@gluon.jit +def apply_softcap(S, x): + Sdiv = S / x + p1 = tl.math.exp2(Sdiv) + p2 = tl.math.exp2(-Sdiv) + return x * (p1 - p2) / (p1 + p2) + + +@aggregate +class AttentionConfig: + """Configuration for unified attention layouts and derived constants.""" + + # Core dimensions + HEAD_SIZE: gl.constexpr + BLOCK_SIZE: gl.constexpr + NUM_BLOCKS_GATHER_PER_TILE: gl.constexpr + NUM_SEGMENTS_PER_SEQ: gl.constexpr + BLOCK_M: 
gl.constexpr + NUM_QUERY_HEADS: gl.constexpr + NUM_KV_HEADS: gl.constexpr + SLIDING_WINDOW: gl.constexpr + + # Derived constants + TILE_SIZE: gl.constexpr + NUM_QUERIES_PER_KV: gl.constexpr + BLOCK_Q: gl.constexpr + RCP_LN2: gl.constexpr + QK_SCALE: gl.constexpr + + # Operator layouts (CDNA4 MFMA) + QK_WMMA_LAYOUT: gl.constexpr + PV_WMMA_LAYOUT: gl.constexpr + + # Dot operand layouts + Q_DOT_LAYOUT: gl.constexpr + K_DOT_LAYOUT: gl.constexpr + V_DOT_LAYOUT: gl.constexpr + P_DOT_LAYOUT: gl.constexpr + + # Layout for loading Q + Q_LOAD_LAYOUT: gl.constexpr + + # Shared memory layouts + Q_SHARED_LAYOUT: gl.constexpr + K_SHARED_LAYOUT: gl.constexpr + V_SHARED_LAYOUT: gl.constexpr + GATHER_BLOCKED_LAYOUT: gl.constexpr + K_LOAD_LAYOUT: gl.constexpr + V_LOAD_LAYOUT: gl.constexpr + + q_cache_modifier: gl.constexpr + kv_cache_modifier: gl.constexpr + + USE_ALIBI_SLOPES: gl.constexpr + USE_QQ_BIAS: gl.constexpr + USE_SOFTCAP: gl.constexpr + USE_SINKS: gl.constexpr + USE_LOAD_BUFFER_OP: gl.constexpr + USE_STORE_BUFFER_OP: gl.constexpr + + NUM_STAGES: gl.constexpr + SHUFFLED_KV_CACHE: gl.constexpr + + @gluon.constexpr_function + def __init__( + self, + HEAD_SIZE, + BLOCK_SIZE, + NUM_BLOCKS_GATHER_PER_TILE, + NUM_SEGMENTS_PER_SEQ, + BLOCK_M, + BLOCK_Q, + NUM_QUERY_HEADS, + NUM_KV_HEADS, + SLIDING_WINDOW, + NUM_WARPS, + WARP_SIZE, + NUM_STAGES, + SCALE, + USE_ALIBI_SLOPES, + USE_QQ_BIAS, + USE_SOFTCAP, + USE_SINKS, + USE_LOAD_BUFFER_OP, + USE_STORE_BUFFER_OP, + SHUFFLED_KV_CACHE, + ): + # Constants + self.HEAD_SIZE = gl.constexpr(HEAD_SIZE) + self.BLOCK_SIZE = gl.constexpr(BLOCK_SIZE) + self.NUM_BLOCKS_GATHER_PER_TILE = gl.constexpr(NUM_BLOCKS_GATHER_PER_TILE) + self.NUM_SEGMENTS_PER_SEQ = gl.constexpr(NUM_SEGMENTS_PER_SEQ) + self.BLOCK_M = gl.constexpr(BLOCK_M) + self.NUM_QUERY_HEADS = gl.constexpr(NUM_QUERY_HEADS) + self.NUM_KV_HEADS = gl.constexpr(NUM_KV_HEADS) + self.SLIDING_WINDOW = gl.constexpr(SLIDING_WINDOW) + self.NUM_STAGES = gl.constexpr(NUM_STAGES) + self.SHUFFLED_KV_CACHE = gl.constexpr(SHUFFLED_KV_CACHE) + # Derived constants + self.TILE_SIZE = gl.constexpr(BLOCK_SIZE * NUM_BLOCKS_GATHER_PER_TILE) + self.NUM_QUERIES_PER_KV = gl.constexpr(NUM_QUERY_HEADS // NUM_KV_HEADS) + self.BLOCK_Q = gl.constexpr(BLOCK_Q) + self.RCP_LN2 = gl.constexpr(1.4426950408889634) + self.QK_SCALE = gl.constexpr(SCALE * self.RCP_LN2) + self.USE_ALIBI_SLOPES = gl.constexpr(USE_ALIBI_SLOPES) + self.USE_QQ_BIAS = gl.constexpr(USE_QQ_BIAS) + self.USE_SOFTCAP = gl.constexpr(USE_SOFTCAP) + self.USE_SINKS = gl.constexpr(USE_SINKS) + self.USE_LOAD_BUFFER_OP = gl.constexpr(USE_LOAD_BUFFER_OP) + self.USE_STORE_BUFFER_OP = gl.constexpr(USE_STORE_BUFFER_OP) + + # gl.static_assert(NUM_WARPS == 2 or NUM_WARPS == 4, "NUM_WARPS must be 2 or 4") + assert NUM_WARPS == 2 or NUM_WARPS == 4 + + if NUM_WARPS == 2: + warp_bases_qk = [(1, 0)] + warp_bases_pv = [(0, 1)] + elif NUM_WARPS == 4: + warp_bases_qk = [(1, 0), (2, 0)] + warp_bases_pv = [(0, 1), (0, 2)] + + # gl.static_assert( + # WARP_SIZE == 32 or WARP_SIZE == 64, "WARP_SIZE must be 32 or 64" + # ) + assert WARP_SIZE == 32 + + self.QK_WMMA_LAYOUT = gl.constexpr( + gl.amd.AMDWMMALayout( + version=3, + transposed=True, + warp_bases=warp_bases_qk, + # warp_bases=[(1 << i, 0) for i in range(int(math.log2(NUM_WARPS)))], + reg_bases=[], + instr_shape=[16, 16, 32], + ) + ) + + self.PV_WMMA_LAYOUT = gl.constexpr( + gl.amd.AMDWMMALayout( + version=3, + transposed=True, + warp_bases=warp_bases_pv, + # warp_bases=[(0, 1 << i) for i in range(int(math.log2(NUM_WARPS)))], + reg_bases=[], + 
instr_shape=[16, 16, 32], + ) + ) + self.Q_DOT_LAYOUT = gl.constexpr( + gl.DotOperandLayout(operand_index=0, parent=self.QK_WMMA_LAYOUT, k_width=8) + ) + self.K_DOT_LAYOUT = gl.constexpr( + gl.DotOperandLayout(operand_index=1, parent=self.QK_WMMA_LAYOUT, k_width=8) + ) + self.P_DOT_LAYOUT = gl.constexpr( + gl.DotOperandLayout(operand_index=0, parent=self.PV_WMMA_LAYOUT, k_width=8) + ) + self.V_DOT_LAYOUT = gl.constexpr( + gl.DotOperandLayout(operand_index=1, parent=self.PV_WMMA_LAYOUT, k_width=8) + ) + + # gl.static_assert( + # NUM_BLOCKS_GATHER_PER_TILE == 1 + # or NUM_BLOCKS_GATHER_PER_TILE == 4 + # or NUM_BLOCKS_GATHER_PER_TILE == 8, + # "NUM_BLOCKS_GATHER_PER_TILE must be 1, 4, or 8", + # ) + assert ( + NUM_BLOCKS_GATHER_PER_TILE == 1 + or NUM_BLOCKS_GATHER_PER_TILE == 4 + or NUM_BLOCKS_GATHER_PER_TILE == 8 + ) + + self.Q_SHARED_LAYOUT = gl.constexpr( + gl.PaddedSharedLayout.with_identity_for( + interval_padding_pairs=[[HEAD_SIZE, 8]], + shape=[BLOCK_M, HEAD_SIZE], + order=[1, 0], + ) + ) + + if self.SHUFFLED_KV_CACHE: + self.K_SHARED_LAYOUT = gl.constexpr( + gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0]) + ) + self.V_SHARED_LAYOUT = gl.constexpr( + gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0]) + ) + if NUM_BLOCKS_GATHER_PER_TILE == 1: + self.GATHER_BLOCKED_LAYOUT = gl.constexpr(None) + else: + self.GATHER_BLOCKED_LAYOUT = gl.constexpr( + gl.BlockedLayout( + size_per_thread=[NUM_BLOCKS_GATHER_PER_TILE], + threads_per_warp=[WARP_SIZE], + warps_per_cta=[NUM_WARPS], + order=[0], + ) + ) + elif NUM_BLOCKS_GATHER_PER_TILE == 1: + self.K_SHARED_LAYOUT = gl.constexpr( + gl.PaddedSharedLayout.with_identity_for( + interval_padding_pairs=[[HEAD_SIZE, 8]], + shape=([BLOCK_SIZE, HEAD_SIZE]), + order=[1, 0], + ) + ) + self.V_SHARED_LAYOUT = gl.constexpr( + gl.PaddedSharedLayout.with_identity_for( + interval_padding_pairs=[[HEAD_SIZE, 8]], + shape=[BLOCK_SIZE, HEAD_SIZE], + order=[1, 0], + ) + ) + self.GATHER_BLOCKED_LAYOUT = gl.constexpr(None) + else: + self.K_SHARED_LAYOUT = gl.constexpr( + gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0]) + ) + self.V_SHARED_LAYOUT = gl.constexpr( + gl.SwizzledSharedLayout(vec=1, per_phase=1, max_phase=1, order=[1, 0]) + ) + # TODO Disabled PaddedSharedLayout for now + # self.K_SHARED_LAYOUT = gl.constexpr( + # gl.PaddedSharedLayout.with_identity_for( + # interval_padding_pairs=[[BLOCK_SIZE * HEAD_SIZE, 8]], + # shape=([NUM_BLOCKS_GATHER_PER_TILE, BLOCK_SIZE * HEAD_SIZE]), + # order=[1, 0], + # ) + # ) + # self.V_SHARED_LAYOUT = gl.constexpr( + # gl.PaddedSharedLayout.with_identity_for( + # interval_padding_pairs=[[BLOCK_SIZE * HEAD_SIZE, 8]], + # shape=[NUM_BLOCKS_GATHER_PER_TILE, BLOCK_SIZE * HEAD_SIZE], + # order=[1, 0], + # ) + # ) + self.GATHER_BLOCKED_LAYOUT = gl.constexpr( + gl.BlockedLayout( + size_per_thread=[NUM_BLOCKS_GATHER_PER_TILE], + threads_per_warp=[WARP_SIZE], + warps_per_cta=[NUM_WARPS], + order=[0], + ) + ) + + # size_per_thread along the fastest moving dimension is set to 8 (BF16) + size_per_thread_fastest_dim = gl.constexpr(8) + # size_per_thread * threads_per_warp along the fastest moving dimension is set to HEAD_SIZE with only 1 warp_per_cta, + # therefore, threads_per_warp along the fastest moving dimension should be HEAD_SIZE // size_per_thread_fastest_dim + # clamp the threads_per_warp along the fastest moving dimension to 1 ~ WARP_SIZE + threads_per_warp_fastest_dim = max( + min((HEAD_SIZE // size_per_thread_fastest_dim), WARP_SIZE), 1 + ) + + self.Q_LOAD_LAYOUT = 
gl.constexpr( + gl.BlockedLayout( + size_per_thread=[1, size_per_thread_fastest_dim], + threads_per_warp=[ + WARP_SIZE // threads_per_warp_fastest_dim, + threads_per_warp_fastest_dim, + ], + warps_per_cta=[NUM_WARPS, 1], + order=[1, 0], + ) + ) + if self.SHUFFLED_KV_CACHE: + if self.NUM_BLOCKS_GATHER_PER_TILE == 1: + # self.K_LOAD_LAYOUT = self.make_kv_cache_shuffled_layout( + # self.TILE_SIZE // 16, + # self.HEAD_SIZE * 16, + # 1, + # ) + # self.V_LOAD_LAYOUT = self.make_kv_cache_shuffled_layout( + # self.HEAD_SIZE // 16, + # self.TILE_SIZE * 16, + # NUM_WARPS, + # ) + self.K_LOAD_LAYOUT = gl.constexpr(None) + self.V_LOAD_LAYOUT = gl.constexpr(None) + else: + self.K_LOAD_LAYOUT = gl.constexpr(None) + self.V_LOAD_LAYOUT = gl.constexpr(None) + else: + self.K_LOAD_LAYOUT = gl.constexpr(None) + self.V_LOAD_LAYOUT = gl.constexpr(None) + + self.q_cache_modifier = gl.constexpr(".cg") + self.kv_cache_modifier = gl.constexpr("") + + @gluon.constexpr_function + def make_kv_cache_shuffled_layout( + self, + BLOCK_SIZE_N_SHFL, + BLOCK_SIZE_INNER_DIM_SHFL, + fastest_dim_num_warps, + dtype=torch.bfloat16, + ): + num_warps_log2 = int(math.log2(fastest_dim_num_warps)) + BLOCK_SIZE_N_SHFL_log2 = int(math.log2(BLOCK_SIZE_N_SHFL)) + BLOCK_SIZE_INNER_DIM_SHFL_log2 = int(math.log2(BLOCK_SIZE_INNER_DIM_SHFL)) + # TODO: support e4m3_dtype and mxfp4x2 + # assert dtype in [torch.bfloat16, e4m3_dtype, torch.uint8], f"Unsupported dtype: {dtype} for making linear layout for shuffled weights" + assert dtype in [ + torch.bfloat16 + ], f"Unsupported dtype: {dtype} for making linear layout for shuffled weights" + if dtype == torch.bfloat16: + # (8 elements per thread for BF16) + coalesced_size_log2 = 3 + elif dtype == e4m3_dtype: + # (16 elements per thread for e4m3_dtype) + coalesced_size_log2 = 4 + else: + # (16*2 elements per thread for mxfp4x2) + coalesced_size_log2 = 4 + assert ( + BLOCK_SIZE_INNER_DIM_SHFL_log2 > coalesced_size_log2 + WAPR_SIZE_LOG2 + ), "BLOCK_SIZE_INNER_DIM_SHFL_log2 must be greater than coalesced_size_log2 + WAPR_SIZE_LOG2, please increase block_size to at least 64" + reg_bases = ( + [[0, 1 << v] for v in range(coalesced_size_log2)] + + [ + [0, 1 << v] + for v in range( + coalesced_size_log2 + WAPR_SIZE_LOG2, BLOCK_SIZE_INNER_DIM_SHFL_log2 + ) + ] + + [ + [0, 1 << v] + for v in range( + num_warps_log2 + BLOCK_SIZE_INNER_DIM_SHFL_log2, + BLOCK_SIZE_INNER_DIM_SHFL_log2 + BLOCK_SIZE_N_SHFL_log2, + ) + ] + ) + lane_bases = [ + [0, 1 << v] + for v in range(coalesced_size_log2, coalesced_size_log2 + WAPR_SIZE_LOG2) + ] + if num_warps_log2 > 0: + warp_bases = [ + [0, 1 << v] + for v in range( + BLOCK_SIZE_INNER_DIM_SHFL_log2, + num_warps_log2 + BLOCK_SIZE_INNER_DIM_SHFL_log2, + ) + ] + else: + warp_bases = [[0, 0]] + + layout = gl.constexpr( + gl.DistributedLinearLayout( + reg_bases=reg_bases, + lane_bases=lane_bases, + warp_bases=warp_bases, + block_bases=[], + shape=[1, BLOCK_SIZE_N_SHFL * BLOCK_SIZE_INNER_DIM_SHFL], + ) + ) + return layout + + +@aggregate +class AttentionProgram: + """Program state and core operations for the unified attention kernel.""" + + cfg: AttentionConfig + + q: gl.tensor + k_shared: gl.shared_memory_descriptor + v_shared: gl.shared_memory_descriptor + + key_cache_ptr: gl.tensor + value_cache_ptr: gl.tensor + output_ptr: gl.tensor + # segm_output_ptr: gl.tensor + segm_max_ptr: gl.tensor + segm_expsum_ptr: gl.tensor + + tile_start: gl.tensor + tile_end: gl.tensor + safe_tile_end: gl.tensor + kv_head_idx: gl.tensor + query_mask_qk: gl.tensor + context_len: gl.tensor + 
context_len_q_pos_qk: gl.tensor + query_pos_qk: gl.tensor + query_mask_qk: gl.tensor + query_offset_0_qk: gl.tensor + query_offset_1_qk: gl.tensor + query_mask_0_qk: gl.tensor + query_mask_1_qk: gl.tensor + query_offset_0_pv: gl.tensor + query_offset_1_pv: gl.tensor + query_mask_0_pv: gl.tensor + query_mask_1_pv: gl.tensor + + k_desc: gl.amd.gfx1250.tdm.tensor_descriptor + v_desc: gl.amd.gfx1250.tdm.tensor_descriptor + stride_k_cache_0: gl.tensor + stride_k_cache_1: gl.tensor + stride_k_cache_2: gl.tensor + stride_k_cache_3: gl.tensor + stride_v_cache_0: gl.tensor + stride_v_cache_1: gl.tensor + stride_v_cache_2: gl.tensor + stride_v_cache_3: gl.tensor + + qq_bias_stride_0: gl.tensor + softcap: gl.tensor + + @gluon.constexpr_function + def __init__( + self, + cfg, + q, + k_shared, + v_shared, + key_cache_ptr, + value_cache_ptr, + output_ptr, + segm_max_ptr, + segm_expsum_ptr, + tile_start, + tile_end, + safe_tile_end, + kv_head_idx, + context_len, + context_len_q_pos_qk, + query_pos_qk, + query_mask_qk, + query_offset_0_qk, + query_offset_1_qk, + query_mask_0_qk, + query_mask_1_qk, + query_offset_0_pv, + query_offset_1_pv, + query_mask_0_pv, + query_mask_1_pv, + k_desc, + v_desc, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + qq_bias_stride_0, + softcap, + ): + self.cfg = cfg + self.q = q + self.key_cache_ptr = key_cache_ptr + self.value_cache_ptr = value_cache_ptr + self.output_ptr = output_ptr + self.segm_max_ptr = segm_max_ptr + self.segm_expsum_ptr = segm_expsum_ptr + self.k_shared = k_shared + self.v_shared = v_shared + self.k_desc = k_desc + self.v_desc = v_desc + self.tile_start = tile_start + self.tile_end = tile_end + self.safe_tile_end = safe_tile_end + self.context_len = context_len + self.context_len_q_pos_qk = context_len_q_pos_qk + self.query_pos_qk = query_pos_qk + self.query_mask_qk = query_mask_qk + self.query_offset_0_qk = query_offset_0_qk + self.query_offset_1_qk = query_offset_1_qk + self.query_mask_0_qk = query_mask_0_qk + self.query_mask_1_qk = query_mask_1_qk + self.query_offset_0_pv = query_offset_0_pv + self.query_offset_1_pv = query_offset_1_pv + self.query_mask_0_pv = query_mask_0_pv + self.query_mask_1_pv = query_mask_1_pv + self.kv_head_idx = kv_head_idx + self.stride_k_cache_0 = stride_k_cache_0 + self.stride_k_cache_1 = stride_k_cache_1 + self.stride_k_cache_2 = stride_k_cache_2 + self.stride_k_cache_3 = stride_k_cache_3 + self.stride_v_cache_0 = stride_v_cache_0 + self.stride_v_cache_1 = stride_v_cache_1 + self.stride_v_cache_2 = stride_v_cache_2 + self.stride_v_cache_3 = stride_v_cache_3 + self.qq_bias_stride_0 = qq_bias_stride_0 + self.softcap = softcap + + @gluon.jit + def initialize( + cfg: AttentionConfig, + q, + key_cache_ptr, + value_cache_ptr, + output_ptr, + segm_max_ptr, + segm_expsum_ptr, + max_seq_prefix_len, + q_block_local_idx, + cur_batch_query_len, + context_len, + kv_head_idx, + num_blocks, + query_pos_qk, + query_mask_qk, + query_offset_0_qk, + query_offset_1_qk, + query_mask_0_qk, + query_mask_1_qk, + query_offset_0_pv, + query_offset_1_pv, + query_mask_0_pv, + query_mask_1_pv, + segm_idx, + tiles_per_segment, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + qq_bias_stride_0, + softcap, + ): + # the last dimension of the stride should always be 1 + # gl.static_assert(stride_k_cache_3 == 1) + # 
gl.static_assert(stride_v_cache_3 == 1) + # if cfg.NUM_BLOCKS_GATHER_PER_TILE == 1: + # # in TDM mode, KV cache shape should be [num_blocks, BLOCK_SIZE, NUM_KV_HEADS, HEAD_SIZE] + # gl.static_assert(stride_k_cache_0 // stride_k_cache_1 == cfg.BLOCK_SIZE) + # gl.static_assert(stride_v_cache_0 // stride_v_cache_1 == cfg.BLOCK_SIZE) + # gl.static_assert(stride_k_cache_2 == cfg.HEAD_SIZE) + # gl.static_assert(stride_v_cache_2 == cfg.HEAD_SIZE) + # else: + # # in TDM gather mode, KV cache shape should be [num_blocks, NUM_KV_HEADS, BLOCK_SIZE, HEAD_SIZE] + # gl.static_assert(stride_k_cache_0 // stride_k_cache_1 == cfg.NUM_KV_HEADS) + # gl.static_assert(stride_v_cache_0 // stride_v_cache_1 == cfg.NUM_KV_HEADS) + + if cfg.SHUFFLED_KV_CACHE: + if cfg.NUM_BLOCKS_GATHER_PER_TILE == 1: + k_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=key_cache_ptr, + shape=( + num_blocks * cfg.NUM_KV_HEADS, + cfg.BLOCK_SIZE * cfg.HEAD_SIZE, + ), + strides=(stride_k_cache_1, 1), + block_shape=(gl.constexpr(1), cfg.BLOCK_SIZE * cfg.HEAD_SIZE), + layout=cfg.K_SHARED_LAYOUT, + ) + v_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=value_cache_ptr, + shape=( + num_blocks * cfg.NUM_KV_HEADS, + cfg.HEAD_SIZE * cfg.BLOCK_SIZE, + ), + strides=(stride_v_cache_1, 1), + block_shape=(gl.constexpr(1), cfg.HEAD_SIZE * cfg.BLOCK_SIZE), + layout=cfg.V_SHARED_LAYOUT, + ) + else: + k_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=key_cache_ptr, + shape=( + num_blocks * cfg.NUM_KV_HEADS, + cfg.BLOCK_SIZE * cfg.HEAD_SIZE, + ), + strides=(stride_k_cache_1, 1), + block_shape=( + cfg.NUM_BLOCKS_GATHER_PER_TILE, + cfg.BLOCK_SIZE * cfg.HEAD_SIZE, + ), + layout=cfg.K_SHARED_LAYOUT, + ) + v_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=value_cache_ptr, + shape=( + num_blocks * cfg.NUM_KV_HEADS, + cfg.HEAD_SIZE * cfg.BLOCK_SIZE, + ), + strides=(stride_v_cache_1, 1), + block_shape=( + cfg.NUM_BLOCKS_GATHER_PER_TILE, + cfg.HEAD_SIZE * cfg.BLOCK_SIZE, + ), + layout=cfg.V_SHARED_LAYOUT, + ) + elif cfg.NUM_BLOCKS_GATHER_PER_TILE == 1: + k_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=key_cache_ptr, + shape=(num_blocks * cfg.BLOCK_SIZE, cfg.NUM_KV_HEADS * cfg.HEAD_SIZE), + strides=(stride_k_cache_1, 1), + block_shape=(cfg.BLOCK_SIZE, cfg.HEAD_SIZE), + layout=cfg.K_SHARED_LAYOUT, + ) + v_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=value_cache_ptr, + shape=(num_blocks * cfg.BLOCK_SIZE, cfg.NUM_KV_HEADS * cfg.HEAD_SIZE), + strides=(stride_v_cache_1, 1), + block_shape=(cfg.BLOCK_SIZE, cfg.HEAD_SIZE), + layout=cfg.V_SHARED_LAYOUT, + ) + else: + k_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=key_cache_ptr, + shape=(num_blocks * cfg.NUM_KV_HEADS, cfg.BLOCK_SIZE * cfg.HEAD_SIZE), + strides=(stride_k_cache_1, 1), + block_shape=( + cfg.NUM_BLOCKS_GATHER_PER_TILE, + cfg.BLOCK_SIZE * cfg.HEAD_SIZE, + ), + layout=cfg.K_SHARED_LAYOUT, + ) + v_desc = gl.amd.gfx1250.tdm.make_tensor_descriptor( + base=value_cache_ptr, + shape=(num_blocks * cfg.NUM_KV_HEADS, cfg.BLOCK_SIZE * cfg.HEAD_SIZE), + strides=(stride_v_cache_1, 1), + block_shape=( + cfg.NUM_BLOCKS_GATHER_PER_TILE, + cfg.BLOCK_SIZE * cfg.HEAD_SIZE, + ), + layout=cfg.V_SHARED_LAYOUT, + ) + + k_shared = gl.allocate_shared_memory( + k_desc.dtype, + [cfg.NUM_STAGES] + k_desc.block_shape, + layout=cfg.K_SHARED_LAYOUT, + ) + v_shared = gl.allocate_shared_memory( + v_desc.dtype, + [cfg.NUM_STAGES] + v_desc.block_shape, + layout=cfg.V_SHARED_LAYOUT, + ) + + # Calculate tile range + num_tiles = (max_seq_prefix_len + cfg.BLOCK_SIZE - 1) // 
cfg.BLOCK_SIZE + tile_start = segm_idx * tiles_per_segment + tile_end = min((segm_idx + 1) * tiles_per_segment, num_tiles) + if cfg.SLIDING_WINDOW > 0: + qpos_lo = q_block_local_idx * cfg.BLOCK_Q + qpos_hi = gl.minimum( + qpos_lo + (cfg.BLOCK_M - 1) // cfg.NUM_QUERIES_PER_KV, + cur_batch_query_len - 1, + ) + first_allowed_key = context_len + qpos_lo - cfg.SLIDING_WINDOW + 1 + last_allowed_key = context_len + qpos_hi + tile_start = gl.maximum(0, first_allowed_key // cfg.BLOCK_SIZE) + tile_end = gl.minimum((last_allowed_key // cfg.BLOCK_SIZE) + 1, num_tiles) + + query_pos_qk = gl.convert_layout( + query_pos_qk, gl.SliceLayout(1, cfg.QK_WMMA_LAYOUT) + )[:, None] + query_mask_qk = gl.convert_layout(query_mask_qk, cfg.QK_WMMA_LAYOUT) + + context_len_q_pos_qk = context_len + query_pos_qk + + # Compute the tile index beyond which causal masking is needed. + # min causal pos = context_len + first query pos in block + # Tiles j < safe_tile_end have all KV positions within causal range + # for every query row, so apply_mask_qk can be skipped. + min_causal_pos = context_len + q_block_local_idx * cfg.BLOCK_Q + safe_tile_end = (min_causal_pos + 1) // cfg.BLOCK_SIZE + safe_tile_end = gl.minimum(safe_tile_end, tile_end) + safe_tile_end = gl.maximum(safe_tile_end, tile_start) + + return AttentionProgram( + cfg, + q, + k_shared, + v_shared, + key_cache_ptr, + value_cache_ptr, + output_ptr, + # segm_output_ptr, + segm_max_ptr, + segm_expsum_ptr, + tile_start, + tile_end, + safe_tile_end, + kv_head_idx, + context_len, + context_len_q_pos_qk, + query_pos_qk, + query_mask_qk, + query_offset_0_qk, + query_offset_1_qk, + query_mask_0_qk, + query_mask_1_qk, + query_offset_0_pv, + query_offset_1_pv, + query_mask_0_pv, + query_mask_1_pv, + k_desc, + v_desc, + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + qq_bias_stride_0, + softcap, + ) + + @gluon.jit + def get_next_buffer_id(self, buffer_id): + if self.cfg.NUM_STAGES == 2: + return 1 - buffer_id + else: + return (buffer_id + 1) % self.cfg.NUM_STAGES + + @gluon.jit + def allocate_accumulator( + self, + sink_ptr, + segm_idx, + query_offset_1, + query_mask_1, + ): + if self.cfg.USE_SINKS: + if segm_idx == 0: + # Prescale with RCP_LN2, needed for exp2 + M = ( + gl.amd.cdna4.buffer_load( + ptr=sink_ptr, + offsets=query_offset_1.to(gl.int32), + mask=query_mask_1, + other=float("-inf"), + ).to(dtype=gl.float32) + * self.cfg.RCP_LN2 + ) + else: + M = gl.full( + [self.cfg.BLOCK_M], + float("-inf"), + dtype=tl.float32, + layout=gl.SliceLayout(1, self.cfg.QK_WMMA_LAYOUT), + ) + else: + M = gl.full( + [self.cfg.BLOCK_M], + float("-inf"), + dtype=tl.float32, + layout=gl.SliceLayout(1, self.cfg.QK_WMMA_LAYOUT), + ) + + L = gl.full( + [self.cfg.BLOCK_M], + 1.0, + dtype=tl.float32, + layout=gl.SliceLayout(1, self.cfg.QK_WMMA_LAYOUT), + ) + acc = gl.zeros( + [self.cfg.BLOCK_M, self.cfg.HEAD_SIZE], + dtype=tl.float32, + layout=self.cfg.PV_WMMA_LAYOUT, + ) + + return L, M, acc + + @gluon.jit + def load_physical_block_idx(self, j, offs_j, block_tables_ptr_shifted): + if self.cfg.NUM_BLOCKS_GATHER_PER_TILE == 1: + # TDM load + physical_block_idx = gl.load(block_tables_ptr_shifted + j) + else: + # TDM gather + offs_j = gl.arange( + 0, + self.cfg.NUM_BLOCKS_GATHER_PER_TILE, + layout=self.cfg.GATHER_BLOCKED_LAYOUT, + ) + physical_block_idx = gl.load( + block_tables_ptr_shifted + + j * self.cfg.NUM_BLOCKS_GATHER_PER_TILE + + offs_j + ) + + return j + 1, physical_block_idx + + 
@gluon.jit + def load_q_from_global( + self, + query_ptr, + q_block_local_idx, + cur_batch_in_all_start_index, + kv_head_idx, + cur_batch_query_len, + query_stride_0, + query_stride_1, + ): + """Load Q from global memory.""" + offs_m = gl.arange( + 0, self.cfg.BLOCK_M, layout=gl.SliceLayout(1, self.cfg.Q_DOT_LAYOUT) + ) + offs_d = gl.arange( + 0, self.cfg.HEAD_SIZE, layout=gl.SliceLayout(0, self.cfg.Q_DOT_LAYOUT) + ) + query_pos = ( + q_block_local_idx * self.cfg.BLOCK_Q + offs_m // self.cfg.NUM_QUERIES_PER_KV + ) + + query_offset_0 = cur_batch_in_all_start_index + query_pos + query_offset_1 = ( + kv_head_idx * self.cfg.NUM_QUERIES_PER_KV + + offs_m % self.cfg.NUM_QUERIES_PER_KV + ) + + query_mask_0 = query_pos < cur_batch_query_len + query_mask_1 = query_offset_1 < self.cfg.NUM_QUERY_HEADS + query_mask = query_mask_0[:, None] & query_mask_1[:, None] + + q_offs = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d[None, :] + ) + if self.cfg.USE_STORE_BUFFER_OP: + q = gl.amd.cdna4.buffer_load( + query_ptr + q_offs, + mask=query_mask, + other=0.0, + cache_modifier=self.cfg.q_cache_modifier, + ) + else: + q = gl.load( + query_ptr + q_offs, + mask=query_mask, + other=0.0, + cache_modifier=self.cfg.q_cache_modifier, + ) + return q, query_pos, query_mask + + @gluon.jit + def unshuffle_k(self, K): + K = ( + K.reshape( + 1, + self.cfg.TILE_SIZE // 16, + self.cfg.HEAD_SIZE // 16, + 2, + 16, + 8, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(self.cfg.TILE_SIZE, self.cfg.HEAD_SIZE) + .trans(1, 0) + ) + return gl.convert_layout( + value=K, layout=self.cfg.K_DOT_LAYOUT, assert_trivial=True + ) + + @gluon.jit + def unshuffle_v(self, V): + V = ( + V.reshape( + 1, + self.cfg.HEAD_SIZE // 16, + self.cfg.TILE_SIZE // 16, + 2, + 16, + 8, + ) + .permute(0, 1, 4, 2, 3, 5) + .reshape(self.cfg.HEAD_SIZE, self.cfg.TILE_SIZE) + .trans(1, 0) + ) + return gl.convert_layout( + value=V, layout=self.cfg.V_DOT_LAYOUT, assert_trivial=True + ) + + @gluon.jit + def lds_unshuffle_k(self, buffer_id): + return ( + self.k_shared.index(buffer_id) + .reshape( + ( + self.cfg.NUM_BLOCKS_GATHER_PER_TILE, + self.cfg.BLOCK_SIZE // 16, + self.cfg.HEAD_SIZE // 16, + 2, + 16, + 8, + ) + ) + .permute((0, 1, 4, 2, 3, 5)) + .reshape((self.cfg.TILE_SIZE, self.cfg.HEAD_SIZE)) + .permute((1, 0)) + ) + + @gluon.jit + def lds_unshuffle_v(self, buffer_id): + return ( + self.v_shared.index(buffer_id) + .reshape( + ( + self.cfg.NUM_BLOCKS_GATHER_PER_TILE, + self.cfg.HEAD_SIZE // 16, + self.cfg.BLOCK_SIZE // 16, + 2, + 16, + 8, + ) + ) + .permute((0, 1, 4, 2, 3, 5)) + .reshape( + ( + self.cfg.NUM_BLOCKS_GATHER_PER_TILE, + self.cfg.HEAD_SIZE, + self.cfg.BLOCK_SIZE, + ) + ) + .permute((1, 0, 2)) + .reshape((self.cfg.HEAD_SIZE, self.cfg.TILE_SIZE)) + .permute((1, 0)) + ) + + @gluon.jit + def tdm_shared_load_k(self, wait_count, buffer_id): + gl.amd.gfx1250.tdm.async_wait(wait_count) + if self.cfg.SHUFFLED_KV_CACHE: + return self.lds_unshuffle_k(buffer_id).load(layout=self.cfg.K_DOT_LAYOUT) + # K = self.k_shared.index(buffer_id).load(layout=self.cfg.K_LOAD_LAYOUT) + # return self.unshuffle_k(K) + + elif self.cfg.NUM_BLOCKS_GATHER_PER_TILE == 1: + return ( + self.k_shared.index(buffer_id) + .permute([1, 0]) + .load(layout=self.cfg.K_DOT_LAYOUT) + ) + else: + return ( + self.k_shared.index(buffer_id) + .reshape([self.cfg.TILE_SIZE, self.cfg.HEAD_SIZE]) + .permute([1, 0]) + .load(layout=self.cfg.K_DOT_LAYOUT) + ) + + @gluon.jit + def tdm_shared_load_v(self, wait_count, buffer_id): + 
gl.amd.gfx1250.tdm.async_wait(wait_count) + if self.cfg.SHUFFLED_KV_CACHE: + return self.lds_unshuffle_v(buffer_id).load(layout=self.cfg.V_DOT_LAYOUT) + # V = self.v_shared.index(buffer_id).load(layout=self.cfg.V_LOAD_LAYOUT) + # return self.unshuffle_v(V) + else: + if self.cfg.NUM_BLOCKS_GATHER_PER_TILE == 1: + return self.v_shared.index(buffer_id).load(layout=self.cfg.V_DOT_LAYOUT) + else: + return ( + self.v_shared.index(buffer_id) + .reshape([self.cfg.TILE_SIZE, self.cfg.HEAD_SIZE]) + .load(layout=self.cfg.V_DOT_LAYOUT) + ) + + @gluon.jit + def tdm_load_global_to_shared_k(self, block_idx, buffer_id): + if self.cfg.NUM_BLOCKS_GATHER_PER_TILE == 1: + if self.cfg.SHUFFLED_KV_CACHE: + offsets = [ + (block_idx * self.cfg.NUM_KV_HEADS + self.kv_head_idx).to(gl.int32), + 0, + ] + gl.amd.gfx1250.tdm.async_load( + self.k_desc, offsets, self.k_shared.index(buffer_id) + ) + else: + offsets = [ + (block_idx * self.cfg.BLOCK_SIZE).to(gl.int32), + (self.kv_head_idx * self.stride_k_cache_2).to(gl.int32), + ] + gl.amd.gfx1250.tdm.async_load( + self.k_desc, offsets, self.k_shared.index(buffer_id) + ) + else: + # TDM gather handles both shuffled and unshuffled cases in the same way + src_row_indices = (block_idx * self.cfg.NUM_KV_HEADS + self.kv_head_idx).to( + gl.int32 + ) + gl.amd.gfx1250.tdm.async_gather( + self.k_desc, + src_row_indices, + 0, + self.k_shared.index(buffer_id), + ) + + @gluon.jit + def tdm_load_global_to_shared_v(self, block_idx, buffer_id): + if self.cfg.NUM_BLOCKS_GATHER_PER_TILE == 1: + if self.cfg.SHUFFLED_KV_CACHE: + offsets = [ + (block_idx * self.cfg.NUM_KV_HEADS + self.kv_head_idx).to(gl.int32), + 0, + ] + gl.amd.gfx1250.tdm.async_load( + self.v_desc, offsets, self.v_shared.index(buffer_id) + ) + else: + offsets = [ + (block_idx * self.cfg.BLOCK_SIZE).to(gl.int32), + (self.kv_head_idx * self.stride_v_cache_2).to(gl.int32), + ] + gl.amd.gfx1250.tdm.async_load( + self.v_desc, offsets, self.v_shared.index(buffer_id) + ) + else: + # TDM gather handles both shuffled and unshuffled cases in the same way + src_row_indices = (block_idx * self.cfg.NUM_KV_HEADS + self.kv_head_idx).to( + gl.int32 + ) + gl.amd.gfx1250.tdm.async_gather( + self.v_desc, + src_row_indices, + 0, + self.v_shared.index(buffer_id), + ) + + @gluon.jit + def compute_qk(self, k): + S = gl.zeros( + [self.cfg.BLOCK_M, self.cfg.TILE_SIZE], + dtype=gl.float32, + layout=self.cfg.QK_WMMA_LAYOUT, + ) + return gl.amd.gfx1250.wmma(self.q, k, S) * self.cfg.QK_SCALE + + @gluon.jit + def apply_softcap(self, S): + if self.cfg.USE_SOFTCAP: + S = apply_softcap(S, self.softcap) * self.cfg.RCP_LN2 + return S + + @gluon.jit + def apply_mask_qk_3D(self, S, seq_offset, alibi_slope, qq_bias_row_ptrs): + seq_mask = seq_offset[None, :] < self.context_len + self.query_pos_qk + 1 + S = gl.where(self.query_mask_qk & seq_mask, S, float("-inf")) + if self.cfg.SLIDING_WINDOW > 0: + S = gl.where( + (self.context_len + self.query_pos_qk - seq_offset) + < self.cfg.SLIDING_WINDOW, + S, + float("-inf"), + ) + + if self.cfg.USE_ALIBI_SLOPES: + # prescale w. 
RCP_LN2 for later exp2 + S += ( + alibi_slope[:, None] + * (seq_offset - self.context_len) + * self.cfg.RCP_LN2 + ) + + if self.cfg.USE_QQ_BIAS: + # compute key positions relative to query section + key_rel_pos = seq_offset - self.context_len # shape: [BLOCK_SIZE] + # load bias only for keys that correspond to queries + is_query_key = key_rel_pos >= 0 and key_rel_pos < self.qq_bias_stride_0 + qq_bias = gl.load( + qq_bias_row_ptrs + key_rel_pos[None, :], + mask=is_query_key[None, :], # avoid OOB for context keys + other=0.0, + ) + # prescale w. RCP_LN2 for later exp2 + S += qq_bias * self.cfg.RCP_LN2 + + return S + + @gluon.jit + def apply_mask_qk(self, S, j): + seq_offset = ( + j * self.cfg.TILE_SIZE + + gl.arange( + 0, + self.cfg.TILE_SIZE, + layout=gl.SliceLayout(0, self.cfg.QK_WMMA_LAYOUT), + )[None, :] + ) + + seq_mask = seq_offset <= self.context_len_q_pos_qk + if self.cfg.SLIDING_WINDOW > 0: + seq_mask = seq_mask & ( + (self.context_len_q_pos_qk - seq_offset) < self.cfg.SLIDING_WINDOW + ) + full_mask = seq_mask + S = gl.where(full_mask, S, float("-inf")) + return S + + @gluon.jit + def softmax_part0(self, S, M): + m_ij = gl.maximum(M, gl.max(S, axis=1)) + m_ij = gl.where(m_ij > float("-inf"), m_ij, 0.0) + p = gl.exp2(S - m_ij[:, None]) + alpha = gl.exp2(M - m_ij) + return p, alpha, m_ij + + @gluon.jit + def softmax_part1(self, p, L, acc, alpha): + l_ij = gl.sum(p, 1) + acc = acc * gl.convert_layout(alpha[:, None], layout=self.cfg.PV_WMMA_LAYOUT) + p = p.to(gl.bfloat16, fp_downcast_rounding="rtz") + L = L * alpha + l_ij + return p, L, acc + + @gluon.jit + def compute_pv(self, p, v, acc): + p = gl.convert_layout(p, self.cfg.P_DOT_LAYOUT) + return gl.amd.gfx1250.wmma(p, v, acc) + + @gluon.jit + def store_output_3D(self, acc, M, L, segm_idx): + # acc = gl.convert_layout(acc, layout=self.cfg.PV_WMMA_LAYOUT) + offs_q_d = gl.arange( + 0, self.cfg.HEAD_SIZE, layout=gl.SliceLayout(0, self.cfg.PV_WMMA_LAYOUT) + ) + dim_mask = gl.full((1,), 1, dtype=tl.int1) + + segm_output_offset = ( + self.query_offset_0_pv[:, None] + * ( + self.cfg.NUM_QUERY_HEADS + * self.cfg.NUM_SEGMENTS_PER_SEQ + * self.cfg.HEAD_SIZE + ) + + self.query_offset_1_pv[:, None] + * (self.cfg.NUM_SEGMENTS_PER_SEQ * self.cfg.HEAD_SIZE) + + segm_idx * self.cfg.HEAD_SIZE + + offs_q_d[None, :] + ) + if self.cfg.USE_STORE_BUFFER_OP: + gl.amd.cdna4.buffer_store( + stored_value=acc, + ptr=self.output_ptr, + offsets=segm_output_offset, + mask=dim_mask[None, :] + & self.query_mask_0_pv[:, None] + & self.query_mask_1_pv[:, None], + ) + else: + gl.store( + self.output_ptr + segm_output_offset.to(gl.int64), + acc, + mask=dim_mask[None, :] + & self.query_mask_0_pv[:, None] + & self.query_mask_1_pv[:, None], + ) + + segm_offset = ( + self.query_offset_0_qk + * (self.cfg.NUM_QUERY_HEADS * self.cfg.NUM_SEGMENTS_PER_SEQ) + + self.query_offset_1_qk * self.cfg.NUM_SEGMENTS_PER_SEQ + + segm_idx + ) + L = gl.convert_layout(L, layout=gl.SliceLayout(1, self.cfg.QK_WMMA_LAYOUT)) + M = gl.convert_layout(M, layout=gl.SliceLayout(1, self.cfg.QK_WMMA_LAYOUT)) + + if self.cfg.USE_STORE_BUFFER_OP: + gl.amd.cdna4.buffer_store( + stored_value=M, + ptr=self.segm_max_ptr, + offsets=segm_offset.to(gl.int32), + mask=self.query_mask_0_qk & self.query_mask_1_qk, + ) + gl.amd.cdna4.buffer_store( + stored_value=L, + ptr=self.segm_expsum_ptr, + offsets=segm_offset.to(gl.int32), + mask=self.query_mask_0_qk & self.query_mask_1_qk, + ) + else: + gl.store( + self.segm_max_ptr + segm_offset.to(gl.int64), + M, + mask=self.query_mask_0_qk & self.query_mask_1_qk, + ) + 
gl.store( + self.segm_expsum_ptr + segm_offset.to(gl.int64), + L, + mask=self.query_mask_0_qk & self.query_mask_1_qk, + ) + + @gluon.jit + def store_output( + self, + out, + q_block_local_idx, + cur_batch_in_all_start_index, + kv_head_idx, + cur_batch_query_len, + output_stride_0, + output_stride_1, + ): + offs_m_out = gl.arange( + 0, self.cfg.BLOCK_M, layout=gl.SliceLayout(1, self.cfg.PV_WMMA_LAYOUT) + ) + offs_d_out = gl.arange( + 0, self.cfg.HEAD_SIZE, layout=gl.SliceLayout(0, self.cfg.PV_WMMA_LAYOUT) + ) + + query_pos_out = ( + q_block_local_idx * self.cfg.BLOCK_Q + + offs_m_out // self.cfg.NUM_QUERIES_PER_KV + ) + query_offset_0_out = cur_batch_in_all_start_index + query_pos_out + query_offset_1_out = ( + kv_head_idx * self.cfg.NUM_QUERIES_PER_KV + + offs_m_out % self.cfg.NUM_QUERIES_PER_KV + ) + + o_offs = ( + query_offset_0_out[:, None] * output_stride_0 + + query_offset_1_out[:, None] * output_stride_1 + + offs_d_out[None, :] + ) + + query_mask_0_out = query_pos_out < cur_batch_query_len + query_mask_1_out = query_offset_1_out < self.cfg.NUM_QUERY_HEADS + o_mask = query_mask_0_out[:, None] & query_mask_1_out[:, None] + casted_out = out.to(self.output_ptr.dtype.element_ty) + if self.cfg.USE_STORE_BUFFER_OP: + gl.amd.cdna4.buffer_store(casted_out, self.output_ptr, o_offs, mask=o_mask) + else: + gl.store(self.output_ptr + o_offs, casted_out, mask=o_mask) + + +@gluon.jit +def find_seq_idx( + query_start_len_ptr, + target_idx, + num_seqs, + BLOCK_Q: gl.constexpr, + use_q_block_mode: gl.constexpr = True, +): + """Binary search to find the sequence index for a given query block index.""" + left = 0 + right = num_seqs + while left < right: + mid = (left + right) // 2 + val = gl.load(query_start_len_ptr + mid) + mid_val = val // BLOCK_Q + mid if use_q_block_mode else val + if mid_val <= target_idx: + left = mid + 1 + else: + right = mid + return left - 1 + + +@gluon.jit +def get_q_metadata( + query_start_len_ptr, + seq_idx, + q_block_global_idx, + BLOCK_Q: gl.constexpr, +): + q_block_start_idx = gl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx + + q_block_local_idx = q_block_global_idx - q_block_start_idx + + cur_batch_in_all_start_index = gl.load(query_start_len_ptr + seq_idx) + cur_batch_in_all_stop_index = gl.load(query_start_len_ptr + seq_idx + 1) + + cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index + + return q_block_local_idx, cur_batch_query_len, cur_batch_in_all_start_index + + +@gluon.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@gluon.jit +def get_seq_metadata( + seq_lens_ptr, + seq_idx, + TILE_SIZE: gl.constexpr, + NUM_SEGMENTS_PER_SEQ: gl.constexpr, +): + # sequence len for this particular sequence + seq_len = gl.load(seq_lens_ptr + seq_idx) + + # number of segments for this particular sequence + num_segments = NUM_SEGMENTS_PER_SEQ + tiles_per_segment = cdiv_fn(seq_len, num_segments * TILE_SIZE) + + return seq_len, tiles_per_segment + + +@gluon.jit +def gluon_kernel_unified_attention_3d_tdm( + segm_output_ptr, # [num_tokens, num_query_heads, num_segments, head_size] + segm_max_ptr, # [num_tokens, num_query_heads, num_segments] + segm_expsum_ptr, # [num_tokens, num_query_heads, num_segments] + query_ptr, # [num_tokens, num_query_heads, head_size] + key_cache_ptr, # [num_blks, num_kv_heads, blk_size, head_size] + value_cache_ptr, # [num_blks, num_kv_heads, blk_size, head_size] + sink_ptr, # [num_query_heads] + block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] + seq_lens_ptr, # [num_seqs] + alibi_slopes_ptr, # 
[num_query_heads] + qq_bias_ptr, # [num_query_tokens, num_query_tokens] + k_scale, # float32 + v_scale, # float32 + softcap, # float32 + num_seqs: gl.int32, # int + num_blocks: gl.int32, # int + query_stride_0: gl.int32, # int + query_stride_1: gl.int32, # int, should be equal to head_size + qq_bias_stride_0: gl.int32, # int + USE_ALIBI_SLOPES: gl.constexpr, # bool + USE_QQ_BIAS: gl.constexpr, # bool + USE_SOFTCAP: gl.constexpr, # bool + USE_SINKS: gl.constexpr, # bool + SLIDING_WINDOW: gl.constexpr, # int + stride_k_cache_0: gl.int32, # int + stride_k_cache_1: gl.int32, # int + stride_k_cache_2: gl.int32, # int + stride_k_cache_3: gl.int32, # int + stride_v_cache_0: gl.int32, # int + stride_v_cache_1: gl.int32, # int + stride_v_cache_2: gl.int32, # int + stride_v_cache_3: gl.int32, # int + block_table_stride: gl.int64, # int + query_start_len_ptr, # [num_seqs+1] + SCALE: gl.constexpr, # float32 + NUM_QUERY_HEADS: gl.constexpr, # int + NUM_KV_HEADS: gl.constexpr, # int + BLOCK_SIZE: gl.constexpr, # int + HEAD_SIZE: gl.constexpr, # int + BLOCK_Q: gl.constexpr, # int + BLOCK_M: gl.constexpr, # int + NUM_SEGMENTS_PER_SEQ: gl.constexpr, # int + WARP_SIZE: gl.constexpr, # int + num_warps: gl.constexpr, # int + waves_per_eu: gl.constexpr, # int + num_stages: gl.constexpr, # int + num_ctas: gl.constexpr = 1, # int + NUM_BLOCKS_GATHER_PER_TILE: gl.constexpr = 1, # int NUM_BLOCKS_GATHER_PER_TILE > 1 for TDM gather mode + ALL_DECODE: gl.constexpr = False, # bool + SHUFFLED_KV_CACHE: gl.constexpr = False, # bool + USE_LOAD_BUFFER_OP: gl.constexpr = False, # bool + USE_STORE_BUFFER_OP: gl.constexpr = False, # bool +): + # Build config with all layouts and derived constants + cfg = AttentionConfig( + HEAD_SIZE, + BLOCK_SIZE, + NUM_BLOCKS_GATHER_PER_TILE, + NUM_SEGMENTS_PER_SEQ, + BLOCK_M, + BLOCK_Q, + NUM_QUERY_HEADS, + NUM_KV_HEADS, + SLIDING_WINDOW, + num_warps, + WARP_SIZE, + num_stages, + SCALE, + USE_ALIBI_SLOPES, + USE_QQ_BIAS, + USE_SOFTCAP, + USE_SINKS, + USE_LOAD_BUFFER_OP, + USE_STORE_BUFFER_OP, + SHUFFLED_KV_CACHE, + ) + + # Workgroup offsets + q_block_global_idx = gl.program_id(0) + kv_head_idx = gl.program_id(1) + segm_idx = gl.program_id(2) + + # Find sequence index using binary search + seq_idx = find_seq_idx( + query_start_len_ptr, q_block_global_idx, num_seqs, cfg.BLOCK_Q, True + ) + + # Get query block start and local index + q_block_local_idx, cur_batch_query_len, cur_batch_in_all_start_index = ( + get_q_metadata( + query_start_len_ptr, + seq_idx, + q_block_global_idx, + cfg.BLOCK_Q, + ) + ) + + if q_block_local_idx * cfg.BLOCK_Q >= cur_batch_query_len: + return + + seq_len, tiles_per_segment = get_seq_metadata( + seq_lens_ptr, + seq_idx, + cfg.TILE_SIZE, + cfg.NUM_SEGMENTS_PER_SEQ, + ) + + if segm_idx * tiles_per_segment * cfg.TILE_SIZE >= seq_len: + return + + context_len = seq_len - cur_batch_query_len + block_tables_ptr_shifted = block_tables_ptr + seq_idx * block_table_stride + + # load Q + offs_q_m_load = gl.arange(0, BLOCK_M, layout=gl.SliceLayout(1, cfg.Q_LOAD_LAYOUT)) + offs_q_d_load = gl.arange(0, HEAD_SIZE, layout=gl.SliceLayout(0, cfg.Q_LOAD_LAYOUT)) + query_pos_load = ( + q_block_local_idx * BLOCK_Q + offs_q_m_load // cfg.NUM_QUERIES_PER_KV + ) + query_offset_0_load = cur_batch_in_all_start_index + query_pos_load + query_offset_1_load = ( + kv_head_idx * cfg.NUM_QUERIES_PER_KV + offs_q_m_load % cfg.NUM_QUERIES_PER_KV + ) + query_offset_load = ( + query_offset_0_load[:, None] * query_stride_0 + + query_offset_1_load[:, None] * query_stride_1 + + offs_q_d_load[None, :] + ) 
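+    # Each of the BLOCK_M rows packs NUM_QUERIES_PER_KV query heads per query
+    # token (GQA packing): row m maps to query token
+    #   q_block_local_idx * BLOCK_Q + m // NUM_QUERIES_PER_KV
+    # and to query head
+    #   kv_head_idx * NUM_QUERIES_PER_KV + m % NUM_QUERIES_PER_KV.
+    # Illustrative example (assumed values): with NUM_QUERY_HEADS=32 and
+    # NUM_KV_HEADS=8, NUM_QUERIES_PER_KV=4; for BLOCK_M=16 a block then covers
+    # BLOCK_Q=4 query tokens, and row m=5 addresses token offset 1, head
+    # offset 1 within this kv head's group.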
+ dim_mask_load = gl.full((1,), 1, dtype=tl.int1) + query_mask_0_load = query_pos_load < cur_batch_query_len + query_mask_1_load = query_offset_1_load < cfg.NUM_QUERY_HEADS + q_shared = gl.allocate_shared_memory( + query_ptr.type.element_ty, + shape=[BLOCK_M, HEAD_SIZE], + layout=cfg.Q_SHARED_LAYOUT, + ) + Q_load = gl.amd.cdna4.buffer_load( + ptr=query_ptr, + offsets=query_offset_load.to(gl.int32), + mask=dim_mask_load[None, :] + & query_mask_0_load[:, None] + & query_mask_1_load[:, None], + other=0.0, + ) + q_shared.store(Q_load) + Q = q_shared.load(layout=cfg.Q_DOT_LAYOUT) + + # define offsets and masks in QK WMMA_LAYOUT + offs_q_m_qk = gl.arange( + 0, cfg.BLOCK_M, layout=gl.SliceLayout(1, cfg.QK_WMMA_LAYOUT) + ) + query_pos_qk = ( + q_block_local_idx * cfg.BLOCK_Q + offs_q_m_qk // cfg.NUM_QUERIES_PER_KV + ) + query_offset_0_qk = cur_batch_in_all_start_index + query_pos_qk + query_offset_1_qk = ( + kv_head_idx * cfg.NUM_QUERIES_PER_KV + offs_q_m_qk % cfg.NUM_QUERIES_PER_KV + ) + query_mask_0_qk = query_pos_qk < cur_batch_query_len + query_mask_1_qk = query_offset_1_qk < cfg.NUM_QUERY_HEADS + query_mask_qk = query_mask_1_qk[:, None] & query_mask_0_qk[:, None] + + query_offset_0_pv = gl.convert_layout( + query_offset_0_qk, layout=gl.SliceLayout(1, cfg.PV_WMMA_LAYOUT) + ) + query_offset_1_pv = gl.convert_layout( + query_offset_1_qk, layout=gl.SliceLayout(1, cfg.PV_WMMA_LAYOUT) + ) + query_mask_0_pv = gl.convert_layout( + query_mask_0_qk, layout=gl.SliceLayout(1, cfg.PV_WMMA_LAYOUT) + ) + query_mask_1_pv = gl.convert_layout( + query_mask_1_qk, layout=gl.SliceLayout(1, cfg.PV_WMMA_LAYOUT) + ) + + # compute the length of the longest sequence prefix spanned by any + # query token in the current q_block (q_block_local_idx) + max_seq_prefix_len = ( + context_len + + q_block_local_idx * cfg.BLOCK_Q + + (cfg.BLOCK_M - 1) // cfg.NUM_QUERIES_PER_KV + + 1 + ) + max_seq_prefix_len = gl.minimum(max_seq_prefix_len, seq_len) + + # TODO: resume from here + # build program + pgm: AttentionProgram = AttentionProgram.initialize( + cfg, + Q, + key_cache_ptr, + value_cache_ptr, + segm_output_ptr, + segm_max_ptr, + segm_expsum_ptr, + max_seq_prefix_len, + q_block_local_idx, + cur_batch_query_len, + context_len, + kv_head_idx, + num_blocks, + query_pos_qk, + query_mask_qk, + query_offset_0_qk, + query_offset_1_qk, + query_mask_0_qk, + query_mask_1_qk, + query_offset_0_pv, + query_offset_1_pv, + query_mask_0_pv, + query_mask_1_pv, + segm_idx, # for 2D, segm_idx = 0 + tiles_per_segment, # for 2D, tiles_per_segment = num_tiles = (max_seq_prefix_len + cfg.BLOCK_SIZE - 1) // cfg.BLOCK_SIZE + stride_k_cache_0, + stride_k_cache_1, + stride_k_cache_2, + stride_k_cache_3, + stride_v_cache_0, + stride_v_cache_1, + stride_v_cache_2, + stride_v_cache_3, + qq_bias_stride_0, + softcap, + ) + + # alibi slope for this head + alibi_slope = None + if cfg.USE_ALIBI_SLOPES: + alibi_slope = tl.load( + alibi_slopes_ptr + query_offset_1_qk, mask=query_mask_1_qk, other=0.0 + ) + + # query-query attention bias + qq_bias_row_ptrs = None + if cfg.USE_QQ_BIAS: + qq_bias_row_ptrs = qq_bias_ptr + query_pos_qk[:, None] * qq_bias_stride_0 + + L, M, acc = pgm.allocate_accumulator( + sink_ptr, + segm_idx, + query_offset_1_qk, + query_mask_1_qk, + ) + + j_from_hbm: gl.int32 = segm_idx * tiles_per_segment + buffer_id: gl.int32 = 0 + seq_offset = j_from_hbm * cfg.TILE_SIZE + gl.arange( + 0, cfg.TILE_SIZE, layout=gl.SliceLayout(0, cfg.QK_WMMA_LAYOUT) + ) + + for _ in range(cfg.NUM_STAGES - 1): + j_from_hbm, physical_block_idx = 
pgm.load_physical_block_idx( + j_from_hbm, kv_head_idx, block_tables_ptr_shifted + ) + pgm.tdm_load_global_to_shared_k(physical_block_idx, buffer_id=buffer_id) + pgm.tdm_load_global_to_shared_v(physical_block_idx, buffer_id=buffer_id) + + # Main attention loop over KV tiles (staged, num_stages=2) + for j in range(pgm.tile_start, pgm.tile_end - (cfg.NUM_STAGES - 1)): + j_from_hbm, next_physical_block_idx = pgm.load_physical_block_idx( + j_from_hbm, kv_head_idx, block_tables_ptr_shifted + ) + k = pgm.tdm_shared_load_k( + wait_count=(cfg.NUM_STAGES - 2) * 2 + 1, buffer_id=buffer_id + ) + + next_buffer_id = pgm.get_next_buffer_id(buffer_id) + # Prefetch next tile (shared is free since k, v are in registers) + pgm.tdm_load_global_to_shared_k( + next_physical_block_idx, buffer_id=next_buffer_id + ) + pgm.tdm_load_global_to_shared_v( + next_physical_block_idx, buffer_id=next_buffer_id + ) + + # Compute attention for current tile + S = pgm.compute_qk(k) + + S = pgm.apply_softcap(S) + S = pgm.apply_mask_qk_3D(S, seq_offset, alibi_slope, qq_bias_row_ptrs) + # if j >= pgm.safe_tile_end or SLIDING_WINDOW > 0: + # S = pgm.apply_mask_qk(S, j) + + p, alpha, M = pgm.softmax_part0(S, M) + p, L, acc = pgm.softmax_part1(p, L, acc, alpha) + v = pgm.tdm_shared_load_v( + wait_count=(cfg.NUM_STAGES - 1) * 2, buffer_id=buffer_id + ) + p = p.to(v.dtype) + acc = pgm.compute_pv(p, v, acc) + + buffer_id = next_buffer_id + seq_offset += cfg.TILE_SIZE + + for _ in range(cfg.NUM_STAGES - 1): + # Load k_i, v_i from shared into registers + k = pgm.tdm_shared_load_k( + wait_count=(cfg.NUM_STAGES - 2) * 2 + 1, buffer_id=buffer_id + ) + # Compute attention for current tile + S = pgm.compute_qk(k) + + S = pgm.apply_softcap(S) + S = pgm.apply_mask_qk_3D(S, seq_offset, alibi_slope, qq_bias_row_ptrs) + # S = pgm.apply_mask_qk(S, pgm.tile_end - 1) + + p, alpha, M = pgm.softmax_part0(S, M) + p, L, acc = pgm.softmax_part1(p, L, acc, alpha) + v = pgm.tdm_shared_load_v( + wait_count=(cfg.NUM_STAGES - 2) * 2, buffer_id=buffer_id + ) + p = p.to(v.dtype) + acc = pgm.compute_pv(p, v, acc) + + seq_offset += cfg.TILE_SIZE + + # Normalize and store output, this is done in reduce kernel for 3D + # l_recip = 1 / L[:, None] + # acc = acc * l_recip + + pgm.store_output_3D( + acc, + M, + L, + segm_idx, + ) diff --git a/aiter/ops/triton/quant/fused_mxfp4_quant.py b/aiter/ops/triton/quant/fused_mxfp4_quant.py index c774b8944b..36994b7a87 100644 --- a/aiter/ops/triton/quant/fused_mxfp4_quant.py +++ b/aiter/ops/triton/quant/fused_mxfp4_quant.py @@ -4,6 +4,7 @@ import triton.language as tl from typing import Optional from aiter.utility import dtypes +import aiter.ops.triton.utils._triton.arch_info as arch_info from aiter.ops.triton._triton_kernels.quant.fused_mxfp4_quant import ( _fused_rms_mxfp4_quant_kernel, _fused_flatten_mxfp4_quant, @@ -11,6 +12,9 @@ _fused_reduce_rms_mxfp4_quant_kernel, _fused_dynamic_mxfp4_quant_moe_sort_kernel, ) +from aiter.ops.triton._gluon_kernels.quant.fuse_mxfp4_quant import ( + _gluon_fused_rms_mxfp4_quant, +) from aiter.ops.triton._triton_kernels.activation import ( _get_activation_from_str, ) @@ -30,6 +34,7 @@ def fused_rms_mxfp4_quant( shuffle: Optional[bool] = False, scale_shuffle_padding: Optional[bool] = False, output_unquantized_inp1=False, + impl: str = "auto", ): """ This op contains several steps: @@ -105,43 +110,97 @@ def fused_rms_mxfp4_quant( out2_stride_m = out2.stride(0) grid = (triton.cdiv(M, BLOCK_SIZE_M) * (2 if (x2 is not None) else 1),) - _fused_rms_mxfp4_quant_kernel[grid]( - x1, - x1_weight, - x2, 
- x2_weight, - res1, - out1_fp4, - out1_bs, - out2, - out_res1, - out1, - x1_epsilon, - x2_epsilon, - M, - N1, - N2, - x1.stride(0), - x2_stride_m, - res1_stride_m, - out1_fp4.stride(0), - *out1_bs.stride(), - out2_stride_m, - out_res1_stride_m, - out1_stride_m, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_N2=BLOCK_SIZE_N2, - MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE, - HAS_SECOND_INPUT=(x2 is not None), - FIRST_INPUT_RES=(res1 is not None), - FIRST_INPUT_OUT=output_unquantized_inp1, - SCALE_N=SCALE_N_valid, - SCALE_M_PAD=(SCALE_M if use_scale_shuffle_padding else 1), - SCALE_N_PAD=SCALE_N, - SHUFFLE=shuffle, - SHUFFLE_PAD=use_scale_shuffle_padding, - ) + + _arch = arch_info.get_arch() + if impl == "auto": + _use_gluon = _arch == "gfx1250" + elif impl == "gluon": + if _arch != "gfx1250": + raise RuntimeError( + f"Gluon kernel requires gfx1250, current arch is {_arch!r}" + ) + _use_gluon = True + elif impl == "triton": + _use_gluon = False + else: + raise ValueError(f"Unknown impl {impl!r}. Choose 'auto', 'triton', or 'gluon'") + + if _use_gluon: + _gluon_fused_rms_mxfp4_quant[grid]( + x1, + x1_weight, + x2, + x2_weight, + res1, + out1_fp4, + out1_bs, + out2, + out_res1, + out1, + x1_epsilon, + x2_epsilon, + M, + N1, + N2, + x1.stride(0), + x2_stride_m, + res1_stride_m, + out1_fp4.stride(0), + *out1_bs.stride(), + out2_stride_m, + out_res1_stride_m, + out1_stride_m, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_N2=BLOCK_SIZE_N2, + MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE, + HAS_SECOND_INPUT=(x2 is not None), + FIRST_INPUT_RES=(res1 is not None), + FIRST_INPUT_OUT=output_unquantized_inp1, + SCALE_N=SCALE_N_valid, + SCALE_M_PAD=(SCALE_M if use_scale_shuffle_padding else 1), + SCALE_N_PAD=SCALE_N, + SHUFFLE=shuffle, + SHUFFLE_PAD=use_scale_shuffle_padding, + ) + else: + _fused_rms_mxfp4_quant_kernel[grid]( + x1, + x1_weight, + x2, + x2_weight, + res1, + out1_fp4, + out1_bs, + out2, + out_res1, + out1, + x1_epsilon, + x2_epsilon, + M, + N1, + N2, + x1.stride(0), + x2_stride_m, + res1_stride_m, + out1_fp4.stride(0), + *out1_bs.stride(), + out2_stride_m, + out_res1_stride_m, + out1_stride_m, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_N2=BLOCK_SIZE_N2, + MXFP4_QUANT_BLOCK_SIZE=MXFP4_QUANT_BLOCK_SIZE, + HAS_SECOND_INPUT=(x2 is not None), + FIRST_INPUT_RES=(res1 is not None), + FIRST_INPUT_OUT=output_unquantized_inp1, + SCALE_N=SCALE_N_valid, + SCALE_M_PAD=(SCALE_M if use_scale_shuffle_padding else 1), + SCALE_N_PAD=SCALE_N, + SHUFFLE=shuffle, + SHUFFLE_PAD=use_scale_shuffle_padding, + ) return (out1_fp4, out1_bs), out1, out2, out_res1 diff --git a/op_tests/triton_tests/attention/test_chunked_pa_prefill.py b/op_tests/triton_tests/attention/test_chunked_pa_prefill.py index 424f5a1147..5379ca8dbe 100644 --- a/op_tests/triton_tests/attention/test_chunked_pa_prefill.py +++ b/op_tests/triton_tests/attention/test_chunked_pa_prefill.py @@ -172,7 +172,8 @@ def _get_alibi_slopes(total_num_heads: int, device: torch.device) -> torch.Tenso def seed_everything(seed): random.seed(seed) - torch.manual_seed(seed) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(seed) def input_helper( diff --git a/op_tests/triton_tests/attention/test_extend_attention.py b/op_tests/triton_tests/attention/test_extend_attention.py index 7fbe8f0135..c7e86f8a0b 100644 --- a/op_tests/triton_tests/attention/test_extend_attention.py +++ b/op_tests/triton_tests/attention/test_extend_attention.py @@ -20,7 +20,8 @@ def 
input_helper( equal_seqlens=False, requires_grad=False, ): - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) if not equal_seqlens: max_extend_length = extend_length @@ -164,7 +165,8 @@ def test_op_fwd( logit_cap=0.0, device="cuda", ): - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) torch.set_default_device(device) torch.set_default_dtype(dtype) diff --git a/op_tests/triton_tests/attention/test_fp8_mqa_logits.py b/op_tests/triton_tests/attention/test_fp8_mqa_logits.py index 76561cfca3..9df6b3e1e5 100644 --- a/op_tests/triton_tests/attention/test_fp8_mqa_logits.py +++ b/op_tests/triton_tests/attention/test_fp8_mqa_logits.py @@ -96,7 +96,8 @@ def test_fp8_mqa_logits( head_dim: int, disable_cp: bool, ) -> None: - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) if s_q > s_k: pytest.skip() q = torch.randn(s_q, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) diff --git a/op_tests/triton_tests/attention/test_hstu_attn.py b/op_tests/triton_tests/attention/test_hstu_attn.py index 5990b7dc22..aff6272ba0 100644 --- a/op_tests/triton_tests/attention/test_hstu_attn.py +++ b/op_tests/triton_tests/attention/test_hstu_attn.py @@ -27,7 +27,8 @@ def generate_sparse_seq_len( sparsity: float, device: torch.device, ) -> torch.Tensor: - torch.manual_seed(1) # for reproducibility + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(1) # for reproducibility if sparsity == 0.0: return torch.zeros(size=(size,), device=device, dtype=torch.int) @@ -166,7 +167,8 @@ def test_hstu_attention( alpha = 1.0 / attn_dim * 10000 # generate inputs - torch.manual_seed(1001) # for reproducibility + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(1001) # for reproducibility lengths = generate_sparse_seq_len( size=batch_size, max_seq_len=max_seq_len, diff --git a/op_tests/triton_tests/attention/test_la.py b/op_tests/triton_tests/attention/test_la.py index ad38626027..3d2cf3b488 100644 --- a/op_tests/triton_tests/attention/test_la.py +++ b/op_tests/triton_tests/attention/test_la.py @@ -347,7 +347,8 @@ def test_persistent_lean_attention( ): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) # Long seqlen (>512K) can hit memory access fault. Suspect compiler issue # WA with shorter d and longer BLOCK_N if any(item > 524288 for item in n_ctx): @@ -453,7 +454,8 @@ def test_persistent_lean_attention_outer( causal, RAGGED_BATCH, ): - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) config = _get_config( batch_size=batch, diff --git a/op_tests/triton_tests/attention/test_la_paged.py b/op_tests/triton_tests/attention/test_la_paged.py index 521161b9bb..5bf3aae458 100644 --- a/op_tests/triton_tests/attention/test_la_paged.py +++ b/op_tests/triton_tests/attention/test_la_paged.py @@ -56,7 +56,8 @@ def test_persistent_lean_attention( ): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) # Long seqlen (>512K) can hit memory access fault. 
Suspect compiler issue # WA with shorter d and longer BLOCK_N if any(item > 524288 for item in n_ctx): diff --git a/op_tests/triton_tests/attention/test_mha.py b/op_tests/triton_tests/attention/test_mha.py index a41e4084e9..fe2746fc9e 100644 --- a/op_tests/triton_tests/attention/test_mha.py +++ b/op_tests/triton_tests/attention/test_mha.py @@ -4,6 +4,7 @@ import torch import pytest import logging +import numpy as np from aiter.ops.triton.attention.mha import ( flash_attn_func, flash_attn_varlen_func, @@ -27,60 +28,8 @@ logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) DEBUG_MODE = False - - -def _attention_ref_with_tol(q, k, v, do, is_fp8=False, **kwargs): - """Run attention reference and compute adaptive tolerances. - - Follows the upstream flash attention tolerance pattern - (see tests/test_flash_attn.py in Dao-AILab/flash-attention). Runs two - PyTorch references (upcast and non-upcast) and uses the gap between them as - a baseline for tolerance. - - Returns (out, (dq, dk, dv), fwd_tol, [dq_tol, dk_tol, dv_tol]) - where each tol is (atol, rtol). - """ - has_dropout = kwargs.get("dropout_p", 0.0) > 0.0 - - def _run_ref(upcast, reorder_ops=False): - q_ = q.detach().clone().requires_grad_(True) - k_ = k.detach().clone().requires_grad_(True) - v_ = v.detach().clone().requires_grad_(True) - with torch.enable_grad(): - out, _, _ = attention_ref( - q_, k_, v_, upcast=upcast, reorder_ops=reorder_ops, **kwargs - ) - dq, dk, dv = torch.autograd.grad(out, (q_, k_, v_), do) - return out, dq, dk, dv - - def _tol(ref_val, pt_val, is_forward=False): - baseline = (pt_val - ref_val).abs().max().item() - if is_fp8: - mult = 4 - atol_floor = 3e-1 if is_forward else 1.0 - rtol_floor = 1e-1 - elif has_dropout: - # Dropout scaling (1/(1-p)) amplifies precision errors in the - # fused kernel differently than in the reference. The baseline - # between two references uses the same mask so it underestimates - # the kernel-vs-reference gap. 
- mult = 2 - atol_floor = 1e-1 if is_forward else 2.0 - rtol_floor = 1e-1 - else: - mult = 2 - atol_floor = 1e-2 if is_forward else 1.5e-2 - rtol_floor = 1e-5 - atol = max(mult * baseline, atol_floor) - return atol, rtol_floor - - out, dq, dk, dv = _run_ref(upcast=True) - out_pt, dq_pt, dk_pt, dv_pt = _run_ref(upcast=False, reorder_ops=True) - - fwd_tol = _tol(out, out_pt, is_forward=True) - bwd_tols = [_tol(dq, dq_pt), _tol(dk, dk_pt), _tol(dv, dv_pt)] - - return out, (dq, dk, dv), fwd_tol, bwd_tols +ATOL_fp8 = 3.0e-1 +RTOL_fp8 = 2.5e-1 def pad_rearrange_dropout_mask( @@ -114,27 +63,45 @@ def pad_rearrange_dropout_mask( return padded_dropout_mask -def assert_cosine_similarity(actual, expected, threshold=0.96, norm_floor=1e-3): - """Assert that two tensors have high cosine similarity.""" - a = actual.float().flatten() - b = expected.float().flatten() - # NOTE: cosine similarity is unstable for near-zero tensors - if b.norm().item() > norm_floor: - cos_sim = torch.nn.functional.cosine_similarity( - a.unsqueeze(0), b.unsqueeze(0) - ).item() - assert cos_sim >= threshold, f"Cosine similarity {cos_sim:.6f} < {threshold}" +def fp8_assert_close( + tensor_a, tensor_b, atol=ATOL_fp8, rtol=RTOL_fp8, max_diff_percentage=0.5 +): + """Assert tensors are close with tolerance for small percentage of elements""" + # standard comparison + abs_diff = torch.abs(tensor_a - tensor_b) + rel_diff = abs_diff / torch.abs(tensor_b.clamp(min=1e-6)) + # calculate elements that exceed tolerance + abs_check = abs_diff > atol + rel_check = rel_diff > rtol + failed_check = torch.logical_and(abs_check, rel_check) -def fp8_assert_close(tensor_a, tensor_b, atol=1.0, cos_sim_threshold=0.96): - """FP8 quality check: max absolute error + cosine similarity.""" - a = tensor_a.float().flatten() - b = tensor_b.float().flatten() + # calculate percentage of failed elements + failed_percentage = failed_check.sum().item() / failed_check.numel() * 100 - max_abs = (a - b).abs().max().item() - assert max_abs <= atol, f"Max absolute error {max_abs:.4f} > {atol}" + # if percentage is small enough, test passes + if failed_percentage <= max_diff_percentage: + return True + + # Otherwise, provide diagnostic information + max_abs_idx = torch.argmax(abs_diff).item() + max_rel_idx = torch.argmax(rel_diff).item() + + flat_to_idx = lambda flat_idx, shape: np.unravel_index( # noqa: E731 + flat_idx, shape + ) - assert_cosine_similarity(tensor_a, tensor_b, cos_sim_threshold) + max_abs_pos = flat_to_idx(max_abs_idx, tensor_a.shape) + max_rel_pos = flat_to_idx(max_rel_idx, tensor_a.shape) + + max_abs_diff = abs_diff.flatten()[max_abs_idx].item() + max_rel_diff = rel_diff.flatten()[max_rel_idx].item() + + raise AssertionError( + f"Tensors not close enough! {failed_percentage:.6f}% elements exceed tolerance.\n" + f"Greatest absolute difference: {max_abs_diff} at index {max_abs_pos} (up to {atol} allowed)\n" + f"Greatest relative difference: {max_rel_diff} at index {max_rel_pos} (up to {rtol} allowed)" + ) @pytest.mark.parametrize("BATCH", [1, 4, 57, 128]) @@ -230,7 +197,9 @@ def test_mha( ) if FP8: - fp8_assert_close(triton_out, torch_out.to(triton_out.dtype)) + fp8_assert_close( + triton_out, torch_out.to(triton_out.dtype), atol=ATOL_fp8, rtol=RTOL_fp8 + ) else: torch.testing.assert_close(triton_out, torch_out, atol=1e-2, rtol=1e-2) @@ -262,7 +231,8 @@ def test_mha_int64_strides( In the absence of strides being int64, parts of the offset computation is done in 32 bit and overflows resulting in segfaults. 
""" torch.cuda.empty_cache() - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) # use int64 strides. mha_set_use_int64_strides( True @@ -365,7 +335,8 @@ def test_mha_varlen( ): torch.set_printoptions(threshold=10000) torch.cuda.empty_cache() - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) q = torch.randn((BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ), device="cuda", dtype=dtype) k = torch.randn((BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ), device="cuda", dtype=dtype) v = torch.randn((BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ), device="cuda", dtype=dtype) @@ -497,27 +468,29 @@ def test_mha_varlen( ) if FP8: - fp8_assert_close(triton_out, torch_out.to(triton_out.dtype)) + torch.testing.assert_close( + triton_out, torch_out.to(triton_out.dtype), atol=ATOL_fp8, rtol=RTOL_fp8 + ) else: torch.testing.assert_close( triton_out, torch_out.to(triton_out.dtype), atol=1e-1, rtol=1e-1 ) -# Production shapes based on real models: -# HQ=32, HK=8: Llama 3 8B (GQA 4:1) -# HQ=64, HK=8: Llama 3 70B (GQA 8:1) -# HQ=32, HK=32: Llama 2 7B (MHA) -@pytest.mark.parametrize("BATCH", [1, 4]) -@pytest.mark.parametrize("SEQLEN_Q", [512, 1024, 2048]) -@pytest.mark.parametrize("SEQLEN_K", [512, 1024, 2048]) -@pytest.mark.parametrize("NUM_Q_HEADS", [32, 64]) -@pytest.mark.parametrize("NUM_K_HEADS", [8]) -@pytest.mark.parametrize("HEAD_SZ", [128]) -@pytest.mark.parametrize("CAUSAL", [True, False]) -@pytest.mark.parametrize("DROPOUT", [0.0, 0.2]) +@pytest.mark.parametrize("BATCH", [1, 4, 57, 128]) +@pytest.mark.parametrize( + "SEQLEN_Q, SEQLEN_K", + [(1, 1), (4, 4), (128, 128), (2, 1), (1, 2), (32, 16), (64, 128)], +) +@pytest.mark.parametrize("DROPOUT, CAUSAL", [(0.0, False), (0.0, True), (0.2, False)]) +# @pytest.mark.parametrize('DROPOUT, CAUSAL',[(0.0, False),(0.0, True),(0.2, False),(0.2, True)]) #Debug Causal + Dropout. fails for seq >= 64 +@pytest.mark.parametrize( + "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (16, 16), (2, 1), (48, 8)] +) +@pytest.mark.parametrize("HEAD_SZ", [8, 32, 128]) +@pytest.mark.parametrize("FP8", [False]) @pytest.mark.parametrize("FUSED", [False, True]) -@pytest.mark.parametrize("FP8", [True, False]) +# @pytest.mark.parametrize('FP8',[(False), (True)]) #TODO Debug FP8 def test_mha_backward( BATCH: int, SEQLEN_Q: int, @@ -525,37 +498,63 @@ def test_mha_backward( NUM_Q_HEADS: int, NUM_K_HEADS: int, HEAD_SZ: int, - CAUSAL: bool, DROPOUT: float, - FUSED: bool, + CAUSAL: bool, FP8: bool, + FUSED: bool, dtype=torch.float16, ): - HAS_DROPOUT = DROPOUT > 0.0 torch.cuda.empty_cache() - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) + + # TODO: Enable these tests once this is fixed + # As of torch 2.9.1+rocm7.1.1, these test cases aren't working + # on gfx942 machines. They are confirmed to work on torch 2.7.1+rocm 7.0. 
+ # This was tested with the same Triton compiler version: + # https://github.com/triton-lang/triton/commit/ecbb77c + if ( + arch == "gfx942" + and not FUSED + and HEAD_SZ == 128 + and (DROPOUT, CAUSAL) == (0.2, False) + and (SEQLEN_Q, SEQLEN_K) in [(4, 4), (2, 1)] + ): + pytest.skip( + "triton_dv and torch_dv are not matching for these test cases on gfx942 architecture" + ) if FUSED and CAUSAL: pytest.skip("FUSED+CAUSAL results in NaNs") - if FP8 and HAS_DROPOUT: - pytest.skip("FP8 does not support dropout") - if CAUSAL and HAS_DROPOUT: - pytest.skip("CAUSAL+DROPOUT backward results in NaNs") mha_set_use_fused_bwd_kernel(FUSED) - q = torch.randn(BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ, device="cuda", dtype=dtype) - k = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype) - v = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype) + q = torch.randn((BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ), device="cuda", dtype=dtype) + k = torch.randn((BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ), device="cuda", dtype=dtype) + v = torch.randn((BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ), device="cuda", dtype=dtype) q.requires_grad = True k.requires_grad = True v.requires_grad = True + do = torch.randn_like(q) - # Triton forward + backward + if DEBUG_MODE: + print("--------------Triton----------------") + print(f"q.shape={q.shape} q={q}") + print(f"k.shape={k.shape} k={k}") + print(f"v.shape={v.shape} v={v}") + print(f"do.shape={do.shape} do={do}") + with torch.enable_grad(): if FP8: - triton_out = flash_attn_fp8_func(q, k, v, causal=CAUSAL) - dropout_mask = None + if DROPOUT > 0.0: + pytest.skip("FP8 does not support dropout_p") + triton_out = flash_attn_fp8_func( + q, + k, + v, + causal=CAUSAL, + ) + lse, sd_mask = None, None else: triton_out = flash_attn_func( q, @@ -563,51 +562,92 @@ def test_mha_backward( v, dropout_p=DROPOUT, causal=CAUSAL, - return_lse=HAS_DROPOUT, - return_attn_probs=HAS_DROPOUT, + return_lse=True, + return_attn_probs=True, ) - if HAS_DROPOUT: - dropout_mask = triton_out[2] >= 0 - triton_out = triton_out[0] - else: - dropout_mask = None + + assert len(triton_out) == 3 + triton_out, lse, sd_mask = triton_out[0], triton_out[1], triton_out[2] + + if DROPOUT > 0.0: + dropout_mask = sd_mask >= 0 + else: + dropout_mask = None + triton_dq, triton_dk, triton_dv = torch.autograd.grad( triton_out, (q, k, v), do.clone() ) - # Reference forward + backward with adaptive tolerances - torch_out, torch_grads, fwd_tol, bwd_tols = _attention_ref_with_tol( - q, - k, - v, - do, - is_fp8=FP8, - dropout_p=DROPOUT, - dropout_mask=dropout_mask, - causal=CAUSAL, + if DEBUG_MODE: + print(f"triton_out={triton_out}") + print(f"triton_lse={lse}") + print(f"sd_mask={sd_mask}") + print(f"triton_dq.shape={triton_dq.shape} triton_dq={triton_dq}") + print(f"triton_dk.shape={triton_dk.shape} triton_dk={triton_dk}") + print(f"triton_dv.shape={triton_dv.shape} triton_dv={triton_dv}") + print(f"dropout_mask={dropout_mask}") + + if DEBUG_MODE: + print("--------------Torch----------------") + print(f"q.shape={q.shape} q={q}") + print(f"k.shape={k.shape} k={k}") + print(f"v.shape={v.shape} v={v}") + print(f"do.shape={do.shape} do={do}") + with torch.enable_grad(): + torch_out = attention_ref( + q, k, v, dropout_p=DROPOUT, dropout_mask=dropout_mask, causal=CAUSAL + ) + torch_out, attention_scores, _ = torch_out + + torch.testing.assert_close( + triton_out, torch_out.to(triton_out.dtype), atol=1e-2, rtol=1e-2 ) - torch_dq, torch_dk, torch_dv = torch_grads - # Check quality - triton_vals = 
[triton_out, triton_dq, triton_dk, triton_dv] - ref_vals = [torch_out, torch_dq, torch_dk, torch_dv] - tols = [fwd_tol] + bwd_tols - for tri, ref, (atol, rtol) in zip(triton_vals, ref_vals, tols): - torch.testing.assert_close(tri, ref.to(tri.dtype), atol=atol, rtol=rtol) - if FP8: - assert_cosine_similarity(tri, ref) + torch_dq, torch_dk, torch_dv = torch.autograd.grad(torch_out, (q, k, v), do) + if DEBUG_MODE: + print(f"torch_out={torch_out}") + print(f"torch_attn_scores={attention_scores}") + print(f"torch_dq.shape={torch_dq.shape} torch_dq={torch_dq}") + print(f"torch_dk.shape={torch_dk.shape} torch_dk={torch_dk}") + print(f"torch_dv.shape={torch_dv.shape} torch_dv={torch_dv}") -@pytest.mark.parametrize("BATCH", [1, 4]) -@pytest.mark.parametrize("SEQLEN_Q", [512, 1024, 2048]) -@pytest.mark.parametrize("SEQLEN_K", [512, 1024, 2048]) -@pytest.mark.parametrize("NUM_Q_HEADS", [32, 64]) -@pytest.mark.parametrize("NUM_K_HEADS", [8]) -@pytest.mark.parametrize("HEAD_SZ", [128]) -@pytest.mark.parametrize("CAUSAL", [True, False]) -@pytest.mark.parametrize("DROPOUT", [0.0, 0.2]) + if FP8: + fp8_assert_close( + triton_dq, torch_dq.to(triton_dq.dtype), atol=ATOL_fp8, rtol=RTOL_fp8 + ) + fp8_assert_close( + triton_dk, torch_dk.to(triton_dk.dtype), atol=ATOL_fp8, rtol=RTOL_fp8 + ) + fp8_assert_close( + triton_dv, torch_dv.to(triton_dv.dtype), atol=ATOL_fp8, rtol=RTOL_fp8 + ) + else: + torch.testing.assert_close( + triton_dq, torch_dq.to(triton_out.dtype), atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close( + triton_dk, torch_dk.to(triton_out.dtype), atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close( + triton_dv, torch_dv.to(triton_out.dtype), atol=1e-2, rtol=1e-2 + ) + + +@pytest.mark.parametrize("BATCH", [1, 4, 57, 128]) +@pytest.mark.parametrize( + "SEQLEN_Q, SEQLEN_K", + [(1, 1), (4, 4), (128, 128), (2, 1), (1, 2), (32, 16), (64, 128)], +) +@pytest.mark.parametrize("DROPOUT, CAUSAL", [(0.0, False), (0.0, True)]) +# @pytest.mark.parametrize('DROPOUT, CAUSAL',[(0.0, False),(0.0, True),(0.2, False),(0.2, True)]) #Debug Causal + Dropout. 
Fails for seq >=64 +@pytest.mark.parametrize( + "NUM_Q_HEADS, NUM_K_HEADS", [(1, 1), (16, 16), (2, 1), (48, 8)] +) +@pytest.mark.parametrize("HEAD_SZ", [8, 32, 128]) +@pytest.mark.parametrize("FP8", [False]) @pytest.mark.parametrize("FUSED", [False, True]) -@pytest.mark.parametrize("FP8", [True, False]) +# @pytest.mark.parametrize('FP8',[(False), (True)]) #TODO Debug FP8 def test_mha_backward_varlen( BATCH: int, SEQLEN_Q: int, @@ -615,27 +655,23 @@ def test_mha_backward_varlen( NUM_Q_HEADS: int, NUM_K_HEADS: int, HEAD_SZ: int, - CAUSAL: bool, DROPOUT: float, - FUSED: bool, + CAUSAL: bool, FP8: bool, + FUSED: bool, dtype=torch.float16, ): - HAS_DROPOUT = DROPOUT > 0.0 torch.cuda.empty_cache() - torch.manual_seed(20) - + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) + # pytest.skip("Backward accuracy issues due to Triton compiler") if FUSED and CAUSAL: pytest.skip("FUSED+CAUSAL results in NaNs") - if FP8 and HAS_DROPOUT: - pytest.skip("FP8 does not support dropout") - if CAUSAL and HAS_DROPOUT: - pytest.skip("CAUSAL+DROPOUT backward results in NaNs") mha_set_use_fused_bwd_kernel(FUSED) - q = torch.randn(BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ, device="cuda", dtype=dtype) - k = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype) - v = torch.randn(BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ, device="cuda", dtype=dtype) + q = torch.randn((BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ), device="cuda", dtype=dtype) + k = torch.randn((BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ), device="cuda", dtype=dtype) + v = torch.randn((BATCH, SEQLEN_K, NUM_K_HEADS, HEAD_SZ), device="cuda", dtype=dtype) q.requires_grad = True k.requires_grad = True v.requires_grad = True @@ -665,88 +701,122 @@ def test_mha_backward_varlen( q_unpad.requires_grad = True k_unpad.requires_grad = True v_unpad.requires_grad = True + if DEBUG_MODE: + print( + f"query_padding_mask.shape={query_padding_mask.shape} query_padding_mask={query_padding_mask}" + ) + print( + f"key_padding_mask.shape={key_padding_mask.shape} key_padding_mask={key_padding_mask}" + ) + + print(f"q.shape={q.shape} q={q}") + print(f"k.shape={k.shape} k={k}") + print(f"v.shape={v.shape} v={v}") + print(f"q_unpad.shape={q_unpad.shape} q_unpad={q_unpad}") + print(f"k_unpad.shape={k_unpad.shape} k_unpad={k_unpad}") + print(f"v_unpad.shape={v_unpad.shape} v_unpad={v_unpad}") + print(f"max_seqlens_q={max_seqlen_q }") + print(f"max_seqlens_k={max_seqlen_k }") + print(f"cu_seqlens_q={cu_seqlens_q }") + print(f"cu_seqlens_k={cu_seqlens_k }") do = torch.randn_like(q) - # Triton varlen forward + backward + if DEBUG_MODE: + print("--------------Triton----------------") + print(f"do.shape={do.shape} do={do}") + with torch.enable_grad(): - if FP8: - triton_out = flash_attn_varlen_fp8_func( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - causal=CAUSAL, - ) - dropout_mask = None - else: - triton_out = flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - dropout_p=DROPOUT, - causal=CAUSAL, - return_lse=HAS_DROPOUT, - return_attn_probs=HAS_DROPOUT, - ) - if HAS_DROPOUT: - dropout_mask = ( - pad_rearrange_dropout_mask( - triton_out[2] >= 0, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - SEQLEN_Q, - SEQLEN_K, - NUM_Q_HEADS, - ) - > 0 - ) - triton_out = triton_out[0] - else: - dropout_mask = None + triton_out = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, 
+ max_seqlen_q, + max_seqlen_k, + dropout_p=DROPOUT, + causal=CAUSAL, + return_lse=True, + return_attn_probs=True, + ) + + assert len(triton_out) == 3 + triton_out, lse, sd_mask = triton_out[0], triton_out[1], triton_out[2] + + if DROPOUT > 0.0: + dropout_mask = sd_mask >= 0 + dropout_mask = pad_rearrange_dropout_mask( + dropout_mask, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + SEQLEN_Q, + SEQLEN_K, + NUM_Q_HEADS, + ) + dropout_mask = dropout_mask > 0 + else: + dropout_mask = None + triton_out = output_pad_fn(triton_out) triton_dq, triton_dk, triton_dv = torch.autograd.grad( triton_out, (q_unpad, k_unpad, v_unpad), do.clone() ) + triton_dq = dq_pad_fn(triton_dq) triton_dk = dk_pad_fn(triton_dk) triton_dv = dk_pad_fn(triton_dv) + if DEBUG_MODE: + print(f"triton_out={triton_out}") + print(f"triton_lse.shape={lse.shape} triton_lse={lse}") + print(f"triton_dq.shape={triton_dq.shape} triton_dq={triton_dq}") + print(f"triton_dk.shape={triton_dk.shape} triton_dk={triton_dk}") + print(f"triton_dv.shape={triton_dv.shape} triton_dv={triton_dv}") + print(f"dropout_mask={dropout_mask}") - # Reference forward + backward with adaptive tolerances - torch_out, torch_grads, fwd_tol, bwd_tols = _attention_ref_with_tol( - q, - k, - v, - do, - is_fp8=FP8, - query_padding_mask=query_padding_mask, - key_padding_mask=key_padding_mask, - dropout_p=DROPOUT, - dropout_mask=dropout_mask, - causal=CAUSAL, + if DEBUG_MODE: + print("--------------Torch----------------") + print(f"do.shape={do.shape} do={do}") + with torch.enable_grad(): + torch_out = attention_ref( + q, + k, + v, + query_padding_mask=query_padding_mask, + key_padding_mask=key_padding_mask, + dropout_p=DROPOUT, + dropout_mask=dropout_mask, + causal=CAUSAL, + ) + torch_out, attention_scores, _ = torch_out + + torch.testing.assert_close( + triton_out, torch_out.to(triton_out.dtype), atol=1e-2, rtol=1e-2 ) - torch_dq, torch_dk, torch_dv = torch_grads - # Check quality - triton_vals = [triton_out, triton_dq, triton_dk, triton_dv] - ref_vals = [torch_out, torch_dq, torch_dk, torch_dv] - tols = [fwd_tol] + bwd_tols - for tri, ref, (atol, rtol) in zip(triton_vals, ref_vals, tols): - torch.testing.assert_close(tri, ref.to(tri.dtype), atol=atol, rtol=rtol) - if FP8: - assert_cosine_similarity(tri, ref) + torch_dq, torch_dk, torch_dv = torch.autograd.grad(torch_out, (q, k, v), do) + + if DEBUG_MODE: + print(f"torch_out={torch_out}") + print(f"torch_attn_scores={attention_scores}") + print(f"torch_dq.shape={torch_dq.shape} torch_dq={torch_dq}") + print(f"torch_dk.shape={torch_dk.shape} torch_dk={torch_dk}") + print(f"torch_dv.shape={torch_dv.shape} torch_dv={torch_dv}") + + torch.testing.assert_close( + triton_dq, torch_dq.to(triton_out.dtype), atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close( + triton_dk, torch_dk.to(triton_out.dtype), atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close( + triton_dv, torch_dv.to(triton_out.dtype), atol=1e-2, rtol=1e-2 + ) # Run PE tests with: -# pytest op_tests/triton_tests/attention/test_mha.py -k with_pe +# pytest op_tests/triton_tests/test_mha.py -k with_pe @pytest.mark.parametrize("BATCH", [1, 3]) @@ -781,7 +851,8 @@ def test_mha_with_pe( # Generate tensors torch.cuda.empty_cache() - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) q = torch.randn( (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), device=device, dtype=dtype ) @@ -855,7 +926,8 @@ def test_mha_varlen_with_pe( # Generate tensors torch.cuda.empty_cache() - torch.manual_seed(77) + 
# TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(77) q = torch.randn( (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), device=device, dtype=dtype ) @@ -973,7 +1045,8 @@ def test_mha_backward_with_pe( # Generate tensors torch.cuda.empty_cache() - torch.manual_seed(63) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(63) q = torch.randn( (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), device=device, @@ -1102,7 +1175,8 @@ def test_mha_backward_varlen_with_pe( # Generate tensors torch.cuda.empty_cache() - torch.manual_seed(133) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(133) q = torch.randn( (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ_QK), device=device, @@ -1240,7 +1314,7 @@ def test_mha_backward_varlen_with_pe( # Run sink tests with: -# pytest op_tests/triton_tests/attention/test_mha.py -k with_sink +# pytest op_tests/triton_tests/test_mha.py -k with_sink @pytest.mark.parametrize("BATCH", [1, 3]) @@ -1268,7 +1342,8 @@ def test_mha_with_sink( # Generate tensors torch.cuda.empty_cache() - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) q = torch.randn( (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ), device=device, @@ -1413,7 +1488,8 @@ def test_mha_varlen_with_sink( # Generate tensors torch.cuda.empty_cache() - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) q = torch.randn( (BATCH, SEQLEN_Q, NUM_Q_HEADS, HEAD_SZ), device=device, diff --git a/op_tests/triton_tests/attention/test_mla_decode_rope.py b/op_tests/triton_tests/attention/test_mla_decode_rope.py index 2f5da62071..824e0a8a56 100644 --- a/op_tests/triton_tests/attention/test_mla_decode_rope.py +++ b/op_tests/triton_tests/attention/test_mla_decode_rope.py @@ -284,7 +284,8 @@ def test_op_fwd_rope( device="cuda", ): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) kv_indptr, kv_indices, q, kv_cache, attn_logits, rotary_emb, positions, _ = ( input_helper( @@ -386,7 +387,8 @@ def test_op_fwd_rope_neox( device="cuda", ): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) kv_indptr, kv_indices, q, kv_cache, attn_logits, rotary_emb, positions, _ = ( input_helper( @@ -498,7 +500,8 @@ def test_op_fwd_rope_integration( device="cuda", ): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) kv_indptr, kv_indices, q, kv_cache, attn_logits, rotary_emb, positions, _ = ( input_helper( diff --git a/op_tests/triton_tests/attention/test_pa_decode.py b/op_tests/triton_tests/attention/test_pa_decode.py index bb696bd577..e9c7e941e2 100644 --- a/op_tests/triton_tests/attention/test_pa_decode.py +++ b/op_tests/triton_tests/attention/test_pa_decode.py @@ -1,14 +1,17 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
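
The test_pa_decode.py hunk below — like the later hunks in test_fused_bmm_rope_kv_cache.py, test_fused_kv_cache.py and test_moe.py — replaces hard-coded float8_e4m3fn / float8_e4m3fnuz branches with the arch-dependent `fp8` alias imported from `aiter.utility.dtypes`. The alias's actual definition is not part of this diff; the sketch below only illustrates the behaviour it is assumed to provide, inferred from the per-arch branches being removed (the helper name is hypothetical):

```python
# Assumed behaviour of the arch-dependent fp8 alias used by the test changes
# below; the real definition lives in aiter/utility/dtypes.py and may differ.
import torch

def pick_fp8_dtype(arch: str) -> torch.dtype:
    # gfx950 (CDNA4) uses the OCP e4m3fn encoding, while gfx942 (CDNA3) keeps
    # the legacy e4m3fnuz variant -- mirroring the branches removed in the
    # hunks further down.
    return torch.float8_e4m3fn if arch == "gfx950" else torch.float8_e4m3fnuz

fp8 = pick_fp8_dtype("gfx950")   # e.g. torch.float8_e4m3fn on CDNA4
```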
-import triton.language as tl -import torch +import aiter.ops.triton.utils._triton.arch_info as arch_info import pytest import random +import torch +import triton.language as tl from aiter.ops.triton.attention.pa_decode import paged_attention_decode from aiter import pertoken_quant +from aiter.utility.dtypes import fp8 DEBUG_MODE = False +DEVICE = arch_info.get_arch() def paged_attention_decode_ref( @@ -148,6 +151,7 @@ def input_helper( ) +@pytest.mark.skipif(DEVICE == "gfx1250", reason="PA decode not supported on gfx1250; use unified attention instead") @pytest.mark.parametrize("B", [1, 4, 27]) @pytest.mark.parametrize("H_Q, H_KV", [(1, 1), (16, 16), (24, 4)]) @pytest.mark.parametrize("D", [1, 64, 128]) @@ -166,9 +170,9 @@ def input_helper( [ (torch.float16, torch.float16, tl.float16, torch.float16), (torch.bfloat16, torch.bfloat16, tl.bfloat16, torch.bfloat16), - (torch.bfloat16, torch.float8_e4m3fnuz, tl.bfloat16, torch.bfloat16), + (torch.bfloat16, fp8, tl.bfloat16, torch.bfloat16), (torch.bfloat16, torch.int8, tl.bfloat16, torch.bfloat16), - (torch.float8_e4m3fnuz, torch.float8_e4m3fnuz, tl.bfloat16, torch.bfloat16), + (fp8, fp8, tl.bfloat16, torch.bfloat16), (torch.int8, torch.int8, tl.bfloat16, torch.bfloat16), ], ) @@ -189,6 +193,7 @@ def test_paged_attn( torch.cuda.empty_cache() # Helps avoid hangs in large tests if SEQ_LEN >= 8192 and B >= 16: pytest.skip("B>={4} and SEQ_LEN>={8192} tests are too slow") + # Remap fnuz to arch-appropriate fp8 dtype torch.set_printoptions(threshold=100000) num_blocks = NUM_BLK @@ -248,6 +253,7 @@ def test_paged_attn( torch.testing.assert_close(triton_output, torch_output, rtol=1e-02, atol=1e-02) +@pytest.mark.skipif(DEVICE == "gfx1250", reason="PA decode not supported on gfx1250; use unified attention instead") @pytest.mark.parametrize("B", [1, 4, 57, 64]) # @pytest.mark.parametrize("H_Q, H_KV", [(1,1), (16, 16), (2,1), (24,4)]) #TODO: GQA failing @pytest.mark.parametrize("H_Q, H_KV", [(1, 1), (16, 16)]) @@ -314,13 +320,13 @@ def test_paged_attn_per_token_quant( key_cache_tri_quant, k_scale, ) = pertoken_quant( - key_cache_tri, scale_dtype=torch.float32, quant_dtype=torch.float8_e4m3fnuz + key_cache_tri, scale_dtype=torch.float32, quant_dtype=fp8 ) ( value_cache_tri_quant, v_scale, ) = pertoken_quant( - value_cache_tri, scale_dtype=torch.float32, quant_dtype=torch.float8_e4m3fnuz + value_cache_tri, scale_dtype=torch.float32, quant_dtype=fp8 ) paged_attention_decode( diff --git a/op_tests/triton_tests/attention/test_pa_prefill.py b/op_tests/triton_tests/attention/test_pa_prefill.py index 996d2b7acd..0726d516e4 100644 --- a/op_tests/triton_tests/attention/test_pa_prefill.py +++ b/op_tests/triton_tests/attention/test_pa_prefill.py @@ -172,7 +172,8 @@ def _get_alibi_slopes(total_num_heads: int, device: torch.tensor) -> torch.Tenso def seed_everything(seed): random.seed(seed) - torch.manual_seed(seed) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(seed) def input_helper( diff --git a/op_tests/triton_tests/attention/test_prefill_attention.py b/op_tests/triton_tests/attention/test_prefill_attention.py index b7da56bc4d..cb68abbd1d 100644 --- a/op_tests/triton_tests/attention/test_prefill_attention.py +++ b/op_tests/triton_tests/attention/test_prefill_attention.py @@ -83,7 +83,8 @@ def varlen_input_helper( @pytest.mark.parametrize("varlen", [True, False]) def test_op_fwd(Z, H, SEQLEN, HEAD_DIM, causal, absorb, varlen, dtype=torch.float16): torch.cuda.empty_cache() # Helps avoid hangs in large tests - 
torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) if varlen: q, k, v, b_seq_len, b_start_loc = varlen_input_helper( Z, SEQLEN, H, HEAD_DIM, dtype, absorb diff --git a/op_tests/triton_tests/attention/test_unified_attention.py b/op_tests/triton_tests/attention/test_unified_attention.py index 3c2c304f30..4965a9c320 100644 --- a/op_tests/triton_tests/attention/test_unified_attention.py +++ b/op_tests/triton_tests/attention/test_unified_attention.py @@ -1,23 +1,105 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +# import hip +# hip.hip.hipInit(0) +from re import T from typing import Optional import pytest import torch from aiter.ops.triton.attention.unified_attention import unified_attention +from aiter.ops.triton.gluon.unified_attention_3d import ( + unified_attention as gluon_unified_attention, +) +from aiter.ops.triton.gluon.unified_attention_2d import ( + unified_attention as gluon_unified_attention_2d, +) from aiter.ops.triton.utils.types import e4m3_dtype +import aiter.ops.triton.utils._triton.arch_info as arch_info + + +def shuffle_kv_cache( + key_cache: torch.Tensor, + value_cache: torch.Tensor, + layout=(16, 16), # (num_lanes, bytes_per_thread) +): + """ + Shuffle key and value cache layout for optimized memory access. + 16x16x32 for BF16 + 16x16x64 for FP8 + """ + dtype = key_cache.dtype + dtype_v = value_cache.dtype + assert dtype in (torch.bfloat16, e4m3_dtype) + num_blocks, block_size, num_kv_heads, head_size = key_cache.shape + num_blocks_v, block_size_v, num_kv_heads_v, head_size_v = value_cache.shape + assert block_size >= 16 + assert dtype == dtype_v + assert num_blocks == num_blocks_v + assert num_kv_heads == num_kv_heads_v + assert head_size == head_size_v + assert block_size == block_size_v + + num_lanes, bytes_per_thread = layout + num_elements_per_thread = ( + bytes_per_thread // dtype.itemsize + ) # there are 16 bytes every 4 VGPRs + + key_cache_shuffled = key_cache.view( + -1, block_size, num_kv_heads, head_size + ).permute(0, 2, 1, 3) + key_cache_shuffled = key_cache_shuffled.view( + -1, + num_kv_heads, + block_size // num_lanes, + num_lanes, + head_size // (2 * num_elements_per_thread), + 2, # there are 2 groups of threads, t0 ~ t15 and t16 ~ t31 + num_elements_per_thread, + ) + key_cache_shuffled = key_cache_shuffled.permute(0, 1, 2, 4, 5, 3, 6).contiguous() + key_cache_shuffled = key_cache_shuffled.view( + -1, num_kv_heads, block_size // 16, head_size * 16 + ) + + value_cache_shuffled = value_cache.view( + -1, block_size, num_kv_heads, head_size + ).permute(0, 2, 1, 3) + value_cache_shuffled = value_cache_shuffled.view( + -1, + num_kv_heads, + block_size // (2 * num_elements_per_thread), + 2, + num_elements_per_thread, + head_size // num_lanes, + num_lanes, + ) + value_cache_shuffled = value_cache_shuffled.permute( + 0, 1, 5, 2, 3, 6, 4 + ).contiguous() + value_cache_shuffled = value_cache_shuffled.view( + -1, num_kv_heads, head_size // 16, block_size * 16 + ) + + return key_cache_shuffled, value_cache_shuffled + + +DEVICE_ARCH = arch_info.get_arch() -NUM_HEADS = [(4, 4), (8, 2), (16, 2)] -HEAD_SIZES = [128, 256] -BLOCK_SIZES = [16, 64, 48] +NUM_HEADS = [(64, 8)] +HEAD_SIZES = [64, 128] +BLOCK_SIZES = [16, 64] -DTYPES = [torch.float16, torch.bfloat16] -QDTYPES = [None, e4m3_dtype] + +DTYPES = [torch.bfloat16] +QDTYPES = [None] # one value large enough to test overflow in index calculation. 
# one value small enough to test the schema op check -NUM_BLOCKS = [32768, 2048] +NUM_BLOCKS = [ + 4096, +] +SLIDING_WINDOWS = [None] def ref_paged_attn( @@ -84,17 +166,50 @@ def ref_paged_attn( return torch.cat(outputs, dim=0) +# @pytest.mark.parametrize( +# "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] +# ) +# @pytest.mark.parametrize("num_heads", NUM_HEADS) +# @pytest.mark.parametrize("head_size", HEAD_SIZES) +# @pytest.mark.parametrize("block_size", BLOCK_SIZES) +# @pytest.mark.parametrize("sliding_window", [None, 256]) +# @pytest.mark.parametrize("dtype", DTYPES) +# @pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +# @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +# @pytest.mark.parametrize("q_dtype", QDTYPES) @pytest.mark.parametrize( - "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] + "seq_lens", + [ + [(1, 1328)], + # [(1, 8192)], + # [(1, 8192)] * 4, + # [(1, 8192)] * 8, + # [(1, 8192)] * 16, + # [(1, 32768)], + # [(1, 523), (1, 37), (1, 2011)], + # [(1, 1328), (1, 523), (1, 37), (1, 2011), (1, 8192)], + ], ) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) -@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOWS) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("soft_cap", [None]) @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("q_dtype", QDTYPES) +@pytest.mark.parametrize("shuffled_kv_cache", [True, False]) +@pytest.mark.parametrize( + "backend, use_tdm, num_tdm_gather, use_async", + [ + ("triton", False, 1, False), # use triton + ("gluon", False, 1, False), # use gluon baseline + ("gluon", False, 1, True), # use gluon simple async_copy + ("gluon", True, 1, False), # use gluon TDM async_copy + ("gluon", True, 4, False), # use gluon TDM gather pipelined + ("gluon", True, 8, False), # use gluon TDM gather pipelined + ], +) @torch.inference_mode() def test_triton_unified_attn( seq_lens: list[tuple[int, int]], @@ -106,11 +221,47 @@ def test_triton_unified_attn( soft_cap: Optional[float], num_blocks: int, q_dtype: Optional[torch.dtype], + shuffled_kv_cache: bool, + backend: str, + use_tdm: bool, + num_tdm_gather: int, + use_async: bool, ) -> None: if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: pytest.skip("block size must be at least 32 for fp8") - torch.manual_seed(0) + if DEVICE_ARCH not in ( + "gfx950", + "gfx1250", + ): + pytest.skip(f"skip {DEVICE_ARCH}") + + if DEVICE_ARCH not in ("gfx1250",) and use_tdm == True: + pytest.skip(f"{DEVICE_ARCH} does not have TDM") + + if backend == "gluon": + if shuffled_kv_cache: + if block_size < 64: + pytest.skip( + "Only block size >= 64 is supported for shuffled KV cache with gluon backend" + ) + + num_stage_assume = 2 if (use_tdm or use_async) else 1 + kv_cache_shared_mem_size = ( + 2 + * num_stage_assume + * (num_tdm_gather if use_tdm else 1) + * block_size + * head_size + * (torch.finfo(dtype).bits // 8) + ) + if kv_cache_shared_mem_size > 327680: + pytest.skip( + f"skipping test for KV cache LDS required memory = {kv_cache_shared_mem_size/1024} kB > 320 kB" + ) + + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] kv_lens = [x[1] for x in seq_lens] @@ -162,24 +313,232 @@ def 
test_triton_unified_attn( k_descale = torch.rand(scale_shape, dtype=torch.float32, device="cuda") v_descale = torch.rand(scale_shape, dtype=torch.float32, device="cuda") - unified_attention( - q=maybe_quantized_query, - k=maybe_quantized_key_cache, - v=maybe_quantized_value_cache, - out=output, - cu_seqlens_q=cu_query_lens, - seqused_k=kv_lens, + if backend == "triton": + if shuffled_kv_cache: + maybe_shuffled_qnatized_key_cache, maybe_shuffled_quantized_value_cache = ( + shuffle_kv_cache(maybe_quantized_key_cache, maybe_quantized_value_cache) + ) + else: + maybe_shuffled_qnatized_key_cache = maybe_quantized_key_cache + maybe_shuffled_quantized_value_cache = maybe_quantized_value_cache + + unified_attention( + q=maybe_quantized_query, + k=maybe_shuffled_qnatized_key_cache, + v=maybe_shuffled_quantized_value_cache, + out=output, + cu_seqlens_q=cu_query_lens, + seqused_k=kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + softcap=soft_cap if soft_cap is not None else 0, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + sinks=sinks, + shuffled_kv_cache=shuffled_kv_cache, + ) + else: + if shuffled_kv_cache: + maybe_sorted_block_tables = block_tables + maybe_shuffled_qnatized_key_cache, maybe_shuffled_quantized_value_cache = ( + shuffle_kv_cache(maybe_quantized_key_cache, maybe_quantized_value_cache) + ) + elif use_tdm and num_tdm_gather > 1: + # note: random gather is not yet hardware verified + # maybe_sorted_block_tables = torch.sort(block_tables, dim=-1)[0] + maybe_sorted_block_tables = block_tables + maybe_shuffled_qnatized_key_cache = maybe_quantized_key_cache.permute( + 0, 2, 1, 3 + ).contiguous() + maybe_shuffled_quantized_value_cache = maybe_quantized_value_cache.permute( + 0, 2, 1, 3 + ).contiguous() + else: + maybe_sorted_block_tables = block_tables + maybe_shuffled_qnatized_key_cache = maybe_quantized_key_cache + maybe_shuffled_quantized_value_cache = maybe_quantized_value_cache + + gluon_unified_attention( + q=maybe_quantized_query, + k=maybe_shuffled_qnatized_key_cache, + v=maybe_shuffled_quantized_value_cache, + out=output, + cu_seqlens_q=cu_query_lens, + seqused_k=kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=maybe_sorted_block_tables, + softcap=soft_cap if soft_cap is not None else 0, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + sinks=sinks, + use_tdm=use_tdm, + num_tdm_gather=num_tdm_gather, + use_async=use_async, + shuffled_kv_cache=shuffled_kv_cache, + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + sinks=sinks, + ) + + atol, rtol = 1.5e-2, 1e-2 + if q_dtype is not None: + atol, rtol = 1.5e-1, 1.5e-1 + torch.testing.assert_close( + output, ref_output, atol=atol, rtol=rtol + ), f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize( + "seq_lens", [[(1, 1328), (5, 18), (129, 463)], [(1, 523), (1, 37), (1, 2011)]] +) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", [64, 16]) +@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize( + "soft_cap", + [ + 
None, + ], +) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("q_dtype", QDTYPES) +@torch.inference_mode() +@pytest.mark.parametrize( + "use_tdm, num_kv_blocks", + [ + (False, 1), + (True, 1), + (True, 4), + ], +) +@torch.inference_mode() +def test_gluon_unified_attn_2d( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], + num_blocks: int, + q_dtype: Optional[torch.dtype], + use_tdm: bool, + num_kv_blocks: int, +) -> None: + if DEVICE_ARCH not in ( + "gfx950", + "gfx1250", + ): + pytest.skip(f"{DEVICE_ARCH} is not supported") + if DEVICE_ARCH not in ("gfx1250",) and use_tdm == True: + pytest.skip(f"{DEVICE_ARCH} does not have TDM") + if num_kv_blocks > 1 and DEVICE_ARCH not in ("gfx1250",): + pytest.skip(f"{DEVICE_ARCH} does not have TDM gather") + if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: + pytest.skip("block size must be at least 32 for fp8") + torch.manual_seed(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1) + scale = head_size**-0.5 + + query = torch.randn( + sum(query_lens), num_query_heads, head_size, dtype=dtype, device="cpu" + ) + key_cache = torch.randn( + num_blocks, block_size, num_kv_heads, head_size, dtype=dtype, device="cpu" + ) + value_cache = torch.randn_like(key_cache) + cu_query_lens = torch.tensor( + [0] + query_lens, dtype=torch.int32, device="cpu" + ).cumsum(dim=0, dtype=torch.int32) + kv_lens = torch.tensor(kv_lens, dtype=torch.int32, device="cpu") + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + max_num_blocks_per_seq = ( + min(max_num_blocks_per_seq * num_seqs, num_blocks) // num_seqs + ) + + block_tables = torch.randint( + 0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32, + device="cpu", + ) + sinks = torch.randn(num_query_heads, dtype=torch.bfloat16, device="cpu") + output = torch.empty_like(query) + + maybe_quantized_query = query + maybe_quantized_key_cache = key_cache + maybe_quantized_value_cache = value_cache + q_descale = None + k_descale = None + v_descale = None + if q_dtype is not None: + # QKV are drawn from N(0, 1): no need for a fp8 scaling factor + maybe_quantized_query = query.to(q_dtype) + maybe_quantized_key_cache = key_cache.to(q_dtype) + maybe_quantized_value_cache = value_cache.to(q_dtype) + + scale_shape = (num_seqs, num_kv_heads) + q_descale = None # Not yet supported + k_descale = torch.rand(scale_shape, dtype=torch.float32, device="cpu") + v_descale = torch.rand(scale_shape, dtype=torch.float32, device="cpu") + + if num_kv_blocks > 1: + maybe_quantized_key_cache = maybe_quantized_key_cache.permute(0, 2, 1, 3).contiguous() + maybe_quantized_value_cache = maybe_quantized_value_cache.permute(0, 2, 1, 3).contiguous() + output_cuda = output.cuda() + gluon_unified_attention_2d( + q=maybe_quantized_query.cuda(), + k=maybe_quantized_key_cache.cuda(), + v=maybe_quantized_value_cache.cuda(), + out=output_cuda, + cu_seqlens_q=cu_query_lens.cuda(), + seqused_k=kv_lens.cuda(), max_seqlen_q=max_query_len, max_seqlen_k=max_kv_len, softmax_scale=scale, causal=True, window_size=window_size, - 
block_table=block_tables, + block_table=block_tables.cuda(), softcap=soft_cap if soft_cap is not None else 0, q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, - sinks=sinks, + sinks=sinks.cuda(), + new_kv_layout=num_kv_blocks > 1, + num_kv_blocks=num_kv_blocks, + use_tdm=use_tdm, ) ref_output = ref_paged_attn( @@ -197,6 +556,7 @@ def test_triton_unified_attn( atol, rtol = 1.5e-2, 1e-2 if q_dtype is not None: atol, rtol = 1.5e-1, 1.5e-1 + output = output_cuda.cpu() torch.testing.assert_close( output, ref_output, atol=atol, rtol=rtol - ), f"{torch.max(torch.abs(output - ref_output))}" + ), f"{torch.max(torch.abs(output - ref_output))}" \ No newline at end of file diff --git a/op_tests/triton_tests/attention/test_unified_attention_sparse_mla.py b/op_tests/triton_tests/attention/test_unified_attention_sparse_mla.py index 43d396acf7..eeabda982c 100644 --- a/op_tests/triton_tests/attention/test_unified_attention_sparse_mla.py +++ b/op_tests/triton_tests/attention/test_unified_attention_sparse_mla.py @@ -52,7 +52,8 @@ def generate_test_data( Pay attention: This function changes the random seed """ random.seed(t.seed) - torch.manual_seed(t.seed) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(t.seed) torch.cuda.manual_seed(t.seed) torch.backends.cudnn.deterministic = True diff --git a/op_tests/triton_tests/fusions/test_fused_bmm_rope_kv_cache.py b/op_tests/triton_tests/fusions/test_fused_bmm_rope_kv_cache.py index 8e3e56745c..d55d223d9d 100644 --- a/op_tests/triton_tests/fusions/test_fused_bmm_rope_kv_cache.py +++ b/op_tests/triton_tests/fusions/test_fused_bmm_rope_kv_cache.py @@ -83,10 +83,8 @@ def test_fused_fp4_bmm_rope_cat_and_cache_mla( ) if cache_dtype == torch.uint8: - if arch_info.get_arch() in ["gfx950"]: - cache_dtype_actual = torch.float8_e4m3fn - else: - cache_dtype_actual = torch.float8_e4m3fnuz + from aiter.utility.dtypes import fp8 + cache_dtype_actual = fp8 kv_cache = torch.zeros( (num_kv_cahce_tokens, KH, D_lora + D), dtype=cache_dtype, device="cuda" @@ -258,10 +256,8 @@ def test_fused_fp8_bmm_rope_cat_and_cache_mla( ) if cache_dtype == torch.uint8: - if arch_info.get_arch() in ["gfx950"]: - cache_dtype_actual = torch.float8_e4m3fn - else: - cache_dtype_actual = torch.float8_e4m3fnuz + from aiter.utility.dtypes import fp8 + cache_dtype_actual = fp8 kv_cache = torch.zeros( (num_kv_cahce_tokens, KH, D_q_nope + D), dtype=cache_dtype, device="cuda" diff --git a/op_tests/triton_tests/fusions/test_fused_kv_cache.py b/op_tests/triton_tests/fusions/test_fused_kv_cache.py index 04176b2f54..c4f6bb0f32 100644 --- a/op_tests/triton_tests/fusions/test_fused_kv_cache.py +++ b/op_tests/triton_tests/fusions/test_fused_kv_cache.py @@ -61,10 +61,8 @@ def test_fused_qk_rope_cat_and_cache_mla( ) if cache_dtype == torch.uint8: - if arch_info.get_arch() in ["gfx950"]: - cache_dtype_actual = torch.float8_e4m3fn - else: - cache_dtype_actual = torch.float8_e4m3fnuz + from aiter.utility.dtypes import fp8 + cache_dtype_actual = fp8 kv_cache = torch.zeros( (num_kv_cahce_tokens, KH, D_lora + D), dtype=cache_dtype, device="cuda" @@ -226,11 +224,8 @@ def test_fused_qk_rope_reshape_and_cache( v = torch.randn_like(k) if cache_dtype == torch.uint8: - if arch_info.get_arch() in ["gfx950"]: - cache_dtype_actual = torch.float8_e4m3fn - else: - cache_dtype_actual = torch.float8_e4m3fnuz - pytest.skip("Skipping FP8 dtype cases non-gfx950") + from aiter.utility.dtypes import fp8 + cache_dtype_actual = fp8 if cache_flash: key_cache = torch.zeros( @@ -750,10 +745,8 @@ def 
test_fused_qk_rope_cosine_cache_llama( v = torch.randn_like(k) if cache_dtype == torch.uint8: - if arch_info.get_arch() in ["gfx950"]: - cache_dtype_actual = torch.float8_e4m3fn - else: - cache_dtype_actual = torch.float8_e4m3fnuz + from aiter.utility.dtypes import fp8 + cache_dtype_actual = fp8 if cache_flash: key_cache = torch.zeros( diff --git a/op_tests/triton_tests/fusions/test_fused_mul_add.py b/op_tests/triton_tests/fusions/test_fused_mul_add.py index d4917c4b78..7b1352742a 100644 --- a/op_tests/triton_tests/fusions/test_fused_mul_add.py +++ b/op_tests/triton_tests/fusions/test_fused_mul_add.py @@ -34,7 +34,7 @@ def run_torch(x, a, b): @pytest.mark.parametrize( - "shape", [(1,), (8,), (500,), (10000,), (32, 7168), (16, 50, 4186)] + "shape", [(1,), (8,), (500,), (10000,), (16, 50, 4186)] ) @pytest.mark.parametrize( "a_type_is_scalar", diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py index 6075164918..0f36545ddf 100644 --- a/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py +++ b/op_tests/triton_tests/gemm/basic/test_gemm_a16wfp4.py @@ -22,7 +22,8 @@ def generate_gemm_a16wfp4_inputs( layout: str = "TN", shuffle: bool = False, ): - torch.manual_seed(5) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(5) # 34 is two packed e2m1 values 0010 which is 1.0. if layout[0] == "T": x_low = torch.randint(0, 16, (M, K // 2), dtype=torch.uint8, device="cuda") diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_a8wfp4.py b/op_tests/triton_tests/gemm/basic/test_gemm_a8wfp4.py index b3dc421912..72a308bf7c 100644 --- a/op_tests/triton_tests/gemm/basic/test_gemm_a8wfp4.py +++ b/op_tests/triton_tests/gemm/basic/test_gemm_a8wfp4.py @@ -377,7 +377,8 @@ def test_gemm_a8wfp4( ): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(42) # for reproducibility + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(42) # for reproducibility if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") diff --git a/op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py b/op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py index 76c7ad2dde..78b557bff2 100644 --- a/op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py +++ b/op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py @@ -53,7 +53,8 @@ def generate_gemm_afp4wfp4_inputs( shuffle_scales_fg ), "weight shuffling is only supported with scale shuffling" - torch.manual_seed(5) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(5) if isinstance(dtype, str): dtype = str_to_torch_dtype[dtype] diff --git a/op_tests/triton_tests/gemm/batched/test_batched_gemm_a16wfp4.py b/op_tests/triton_tests/gemm/batched/test_batched_gemm_a16wfp4.py index 548638b37e..5e27b955be 100644 --- a/op_tests/triton_tests/gemm/batched/test_batched_gemm_a16wfp4.py +++ b/op_tests/triton_tests/gemm/batched/test_batched_gemm_a16wfp4.py @@ -17,7 +17,8 @@ def generate_batched_gemm_a16wfp4_inputs(B, M, N, K, dtype, layout="TN", output= - x_scales: (B, M, K // SCALE_GROUP_SIZE) - w_scales: (B, N, K // SCALE_GROUP_SIZE) """ - torch.manual_seed(5) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(5) if layout[0] == "T": # 34 is two packed e2m1 values 0010 which is 1.0. 
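
The "34 is two packed e2m1 values 0010 which is 1.0" comment above can be checked by hand: e2m1 stores 1 sign bit, 2 exponent bits (bias 1) and 1 mantissa bit, so the nibble 0b0010 decodes to 1.0, and a byte holding that nibble in both halves is 0x22 = 34. A quick self-contained check (the decoder helper is hypothetical, written only for this illustration):

```python
# Verify the "34 == two packed e2m1 values of 1.0" comment above.
# e2m1 (the MXFP4 element type): 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit.
def decode_e2m1(nibble: int) -> float:
    sign = -1.0 if (nibble >> 3) & 1 else 1.0
    exp = (nibble >> 1) & 0b11
    mant = nibble & 0b1
    if exp == 0:                       # subnormal: 0.0 or 0.5
        return sign * 0.5 * mant
    return sign * (1.0 + 0.5 * mant) * 2.0 ** (exp - 1)

assert decode_e2m1(0b0010) == 1.0      # exp=01, mant=0 -> 1.0
packed = (0b0010 << 4) | 0b0010        # two e2m1 nibbles per uint8
assert packed == 34                    # 0x22, as stated in the comment
```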
x_low = torch.randint(0, 16, (B, M, K // 2), dtype=torch.uint8, device="cuda") diff --git a/op_tests/triton_tests/gemm/batched/test_batched_gemm_afp4wfp4.py b/op_tests/triton_tests/gemm/batched/test_batched_gemm_afp4wfp4.py index 55117b72ad..aed3b85edf 100644 --- a/op_tests/triton_tests/gemm/batched/test_batched_gemm_afp4wfp4.py +++ b/op_tests/triton_tests/gemm/batched/test_batched_gemm_afp4wfp4.py @@ -24,7 +24,8 @@ def generate_batched_gemm_afp4wfp4_inputs( - x_scales: shape (B, M, K // SCALE_GROUP_SIZE) - w_scales: shape (B, N, K // SCALE_GROUP_SIZE) """ - torch.manual_seed(5) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(5) if layout[0] == "T": # 34 is two packed e2m1 values 0010 which is 1.0. x_low = torch.randint(0, 16, (B, M, K // 2), dtype=torch.uint8, device="cuda") diff --git a/op_tests/triton_tests/gemm/fused/test_fused_gemm_afp4wfp4_split_cat.py b/op_tests/triton_tests/gemm/fused/test_fused_gemm_afp4wfp4_split_cat.py index e05dd8d762..a3d7a4bf28 100644 --- a/op_tests/triton_tests/gemm/fused/test_fused_gemm_afp4wfp4_split_cat.py +++ b/op_tests/triton_tests/gemm/fused/test_fused_gemm_afp4wfp4_split_cat.py @@ -101,7 +101,8 @@ def generate_fused_gemm_afp4wfp4_split_cat_inputs( - y: (M, D, S3) """ - torch.manual_seed(5) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(5) if isinstance(dtype, str): dtype = str_to_torch_dtype[dtype] diff --git a/op_tests/triton_tests/moe/test_moe.py b/op_tests/triton_tests/moe/test_moe.py index 35faae453e..6259e7fe04 100644 --- a/op_tests/triton_tests/moe/test_moe.py +++ b/op_tests/triton_tests/moe/test_moe.py @@ -323,11 +323,8 @@ def get_default_config_moe_e2e(persistent: bool) -> Dict[str, int]: def quantize_fp8( tensor: torch.Tensor, dim=() ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - dev = arch_info.get_arch() - if dev == "gfx950": - fp8_type = torch.float8_e4m3fn - else: - fp8_type = torch.float8_e4m3fnuz + from aiter.utility.dtypes import fp8 + fp8_type = fp8 quantize_dim = [i for i in range(tensor.dim()) if i not in dim] max_vals = tensor.abs().amax(dim=quantize_dim, keepdim=True) @@ -643,7 +640,8 @@ def test_fused_moe( dtype, ): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) torch.set_printoptions(threshold=100000) if persistent: ( @@ -791,7 +789,8 @@ def test_fused_moe_int4_w4a16( ): pytest.skip("Results in accuracy failure because of Triton compiler change") - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) ( a, b, @@ -920,7 +919,8 @@ def test_fused_moe_gelu( dtype, ): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) torch.set_printoptions(threshold=100000) if persistent: triton_moe_gelu_set_use_persistent_kernel(True) @@ -1044,7 +1044,8 @@ def test_moe_e2e( ): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) torch.set_printoptions(threshold=100000) if persistent: triton_e2e_moe_set_use_persistent_kernel(True) diff --git a/op_tests/triton_tests/moe/test_moe_gemm_a4w4.py b/op_tests/triton_tests/moe/test_moe_gemm_a4w4.py index 6a10fe1ac1..00707f83c4 100644 --- a/op_tests/triton_tests/moe/test_moe_gemm_a4w4.py +++ 
b/op_tests/triton_tests/moe/test_moe_gemm_a4w4.py @@ -68,7 +68,8 @@ def init_compute_data( has_y_gammas, device="cuda", ): - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) in_m = m * (n_expts_act if gindx is None else 1) shape_x = (in_m, k) x = alloc_rand(shape_x, device=device, dtype=act_dtype) @@ -235,7 +236,8 @@ def test_op( f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU" ) - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) act_mxfp4 = "mxfloat4_e2m1" weight_mxfp4 = "mxfloat4_e2m1" diff --git a/op_tests/triton_tests/moe/test_moe_gemm_a8w4.py b/op_tests/triton_tests/moe/test_moe_gemm_a8w4.py index e75c73a815..cce4c5ebf2 100644 --- a/op_tests/triton_tests/moe/test_moe_gemm_a8w4.py +++ b/op_tests/triton_tests/moe/test_moe_gemm_a8w4.py @@ -68,7 +68,7 @@ def init_compute_data( has_y_gammas, device="cuda", ): - torch.manual_seed(0) + # torch.manual_seed(0) in_m = m * (n_expts_act if gindx is None else 1) shape_x = (in_m, k) x = alloc_rand(shape_x, device=device, dtype=act_dtype) @@ -182,17 +182,8 @@ class Case: [ tuple(getattr(case, f.name) for f in fields(Case)) for case in [ - Case(32, 6144, 3072, "float8_e4m3fn", 128, 4, hbm_swizzling=True), - Case(8192, 3072, 3072, "float8_e4m3fn", 128, 4, hbm_swizzling=True), - Case(4, 1024, 3072, "float8_e4m3fn", 128, 4, hbm_swizzling=True), - Case(1024, 3072, 512, "float8_e4m3fn", 128, 4, hbm_swizzling=True), - Case(4096, 3072, 3072, "float8_e4m3fn", 128, 4), - Case(16, 1024, 1024, "mxfloat8_e4m3fn", 128, 4, hbm_swizzling=True), - Case(4096, 1024, 1024, "mxfloat8_e4m3fn", 128, 4), Case(16, 256, 256, "mxfloat8_e4m3fn", 128, 4, hbm_swizzling=True), - Case(4096, 256, 256, "mxfloat8_e4m3fn", 128, 4), - Case(1000, 704, 800, "mxfloat8_e4m3fn", 8, 2), - Case(300, 400, 800, "mxfloat8_e4m3fn", 8, 4), + Case(300, 400, 800, "mxfloat8_e4m3fn", 8, 4) ] ], ) @@ -224,23 +215,23 @@ def test_op( device="cuda", ): - if get_arch() != "gfx950": - pytest.skip("float8 x mx only supported on CDNA4") + # if get_arch() != "gfx950": + # pytest.skip("float8 x mx only supported on CDNA4") - if "float8_e4m3fnuz" in act_dtype_str and get_arch() != "gfx942": - pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform") + # if "float8_e4m3fnuz" in act_dtype_str and get_arch() != "gfx942": + # pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform") if hbm_swizzling: - if get_arch() != "gfx950": - pytest.skip( - "Scale preshuffling on AMD GPU has not been emulated on non-CDNA4 arch yet." - ) + # if get_arch() != "gfx950": + # pytest.skip( + # "Scale preshuffling on AMD GPU has not been emulated on non-CDNA4 arch yet." 
+ # ) if n % 32 != 0 or k % (32 * 8) != 0: pytest.skip( f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU" ) - torch.manual_seed(0) + # torch.manual_seed(0) weight_dtype_str = "mxfloat4_e2m1" weight_mxfp = weight_dtype_str.startswith("mx") diff --git a/op_tests/triton_tests/moe/test_moe_gemm_a8w8.py b/op_tests/triton_tests/moe/test_moe_gemm_a8w8.py index 2fed591681..d21159b42e 100644 --- a/op_tests/triton_tests/moe/test_moe_gemm_a8w8.py +++ b/op_tests/triton_tests/moe/test_moe_gemm_a8w8.py @@ -69,7 +69,8 @@ def init_compute_data( has_y_gammas, device="cuda", ): - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) in_m = m * (n_expts_act if gindx is None else 1) shape_x = (in_m, k) x = alloc_rand(shape_x, device=device, dtype=act_dtype) @@ -187,97 +188,17 @@ class Case: # TP1 Case( 16, - 4096, - 7168, - "mxfloat8_e4m3fn", - "mxfloat8_e4m3fn", - 256, - 8, - hbm_swizzling=True, - ), - Case( - 1024, - 7168, - 2048, - "mxfloat8_e4m3fn", - "mxfloat8_e4m3fn", - 256, - 8, - hbm_swizzling=True, - ), - Case( - 4096, - 4096, - 7168, - "mxfloat8_e4m3fn", - "mxfloat8_e4m3fn", - 256, - 8, - hbm_swizzling=True, - ), - Case( - 8192, - 7168, - 2048, - "mxfloat8_e4m3fn", - "mxfloat8_e4m3fn", - 256, - 8, - hbm_swizzling=True, - ), - # TP8 - Case( - 16, + 128, 512, - 7168, "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 256, 8, hbm_swizzling=True, ), - Case( - 1024, - 7168, - 256, - "mxfloat8_e4m3fn", - "mxfloat8_e4m3fn", - 256, - 8, - hbm_swizzling=True, - ), - Case( - 4096, - 512, - 7168, - "mxfloat8_e4m3fn", - "mxfloat8_e4m3fn", - 256, - 8, - hbm_swizzling=True, - ), - Case( - 8192, - 7168, - 256, - "mxfloat8_e4m3fn", - "mxfloat8_e4m3fn", - 256, - 8, - hbm_swizzling=True, - ), - # Precision combinations - Case(4096, 7168, 4096, "float8_e4m3fn", "float8_e4m3fn", 256, 8), - Case(4096, 7168, 4096, "mxfloat8_e4m3fn", "float8_e4m3fn", 256, 8), - Case(4096, 7168, 4096, "float8_e4m3fn", "mxfloat8_e4m3fn", 256, 8), - Case(4096, 7168, 4096, "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 256, 8), + # TP8 # edges Case(300, 400, 400, "float8_e4m3fn", "float8_e4m3fn", 8, 2), - Case(300, 400, 400, "float8_e4m3fn", "mxfloat8_e4m3fn", 8, 2), - Case(300, 400, 400, "mxfloat8_e4m3fn", "float8_e4m3fn", 8, 2), - Case(300, 400, 400, "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 2), - Case(1000, 704, 2048, "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), - Case(8192, 7168, 4096, "mxfloat8_e4m3fn", "mxfloat8_e4m3fn", 8, 4), ] ], ) @@ -310,23 +231,14 @@ def test_op( device="cuda", ): - if get_arch() != "gfx950": - pytest.skip("float8 x mx only supported on CDNA4") - - if "float8_e4m3fnuz" in act_dtype_str and get_arch() != "gfx942": - pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform") - if hbm_swizzling: - if get_arch() != "gfx950": - pytest.skip( - "Scale preshuffling on AMD GPU has not been emulated on non-CDNA4 arch yet." 
- ) if n % 32 != 0 or k % (32 * 8) != 0: pytest.skip( f"Shape {m}x{n}x{k} is not supported for scale swizzling on AMD GPU" ) - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) weight_mxfp8 = weight_dtype_str.startswith("mx") if weight_mxfp8: diff --git a/op_tests/triton_tests/moe/test_moe_gemm_a8w8_blockscale.py b/op_tests/triton_tests/moe/test_moe_gemm_a8w8_blockscale.py index 5953c0b4f0..ba56e71619 100644 --- a/op_tests/triton_tests/moe/test_moe_gemm_a8w8_blockscale.py +++ b/op_tests/triton_tests/moe/test_moe_gemm_a8w8_blockscale.py @@ -58,7 +58,8 @@ def init_compute_data( is_x_blockscale=False, is_w_blockscale=False, ): - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) in_m = m * (n_expts_act if gindx is None else 1) shape_x = (in_m, k) x = (torch.randn(shape_x, dtype=torch.bfloat16, device=device) / 10).to(act_dtype) @@ -256,7 +257,8 @@ def test_op( per_row_x_scale, device="cuda", ): - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) m, rdata, gindx, sindx = init_routing_data( m, n_expts_tot, n_expts_act, do_gather, do_scatter, device=device diff --git a/op_tests/triton_tests/moe/test_moe_mx.py b/op_tests/triton_tests/moe/test_moe_mx.py index 74d2796e25..5e83d29933 100644 --- a/op_tests/triton_tests/moe/test_moe_mx.py +++ b/op_tests/triton_tests/moe/test_moe_mx.py @@ -305,7 +305,8 @@ def test_fused_moe( swizzle_mx_scale: bool, ): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") pytest.skip("MXFP4 not supported on this architecture") diff --git a/op_tests/triton_tests/moe/test_moe_routing.py b/op_tests/triton_tests/moe/test_moe_routing.py index 8cff9aa092..da12cf5b0a 100644 --- a/op_tests/triton_tests/moe/test_moe_routing.py +++ b/op_tests/triton_tests/moe/test_moe_routing.py @@ -101,7 +101,8 @@ def test_op(n_tokens, n_expts_tot, n_expts_act, sm_first): pytest.skip("MOE stack not fully implemented on non-CDNA4 arch yet.") device = "cuda" - torch.manual_seed(2) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(2) n_gates_raw = n_tokens * n_expts_act tri_logits = init_data( n_tokens, n_expts_tot, device=device, dtype=torch.float32 diff --git a/op_tests/triton_tests/moe/test_moe_routing_sigmoid_top1_fused.py b/op_tests/triton_tests/moe/test_moe_routing_sigmoid_top1_fused.py index 9c16966aeb..469f08b4b4 100644 --- a/op_tests/triton_tests/moe/test_moe_routing_sigmoid_top1_fused.py +++ b/op_tests/triton_tests/moe/test_moe_routing_sigmoid_top1_fused.py @@ -47,7 +47,8 @@ def test_routing_sigmoid_top1(M, N, K, dtype): TOPK = 1 - torch.manual_seed(7) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(7) device = "cuda" diff --git a/op_tests/triton_tests/normalization/test_fused_add_rmsnorm_pad.py b/op_tests/triton_tests/normalization/test_fused_add_rmsnorm_pad.py index d56596cf24..183666dd0a 100644 --- a/op_tests/triton_tests/normalization/test_fused_add_rmsnorm_pad.py +++ b/op_tests/triton_tests/normalization/test_fused_add_rmsnorm_pad.py @@ -30,8 +30,8 @@ def run_torch(x, weight, eps=1e-6, res=None, pad_to_multiple=0): return x -@pytest.mark.parametrize("M", [1, 4, 8, 16, 32, 256, 8192]) -@pytest.mark.parametrize("N", [4, 16, 320, 640, 
2880]) +@pytest.mark.parametrize("M", [1, 4, 8, 16, 32, 256]) +@pytest.mark.parametrize("N", [4, 16, 320]) @pytest.mark.parametrize("has_res", [False, True]) @pytest.mark.parametrize("pad_to_multiple", [0, 256]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) diff --git a/op_tests/triton_tests/normalization/test_layernorm.py b/op_tests/triton_tests/normalization/test_layernorm.py index 9fbc25f44f..b6654dd310 100644 --- a/op_tests/triton_tests/normalization/test_layernorm.py +++ b/op_tests/triton_tests/normalization/test_layernorm.py @@ -133,7 +133,8 @@ def get_vals(): ) def test_layernorm(M, N, dtype_str, eps=1e-5): dtype = str_to_torch_dtype[dtype_str] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) w_shape = (N,) b = torch.rand(w_shape, dtype=dtype, device="cuda", requires_grad=True) @@ -175,7 +176,8 @@ def test_layernorm(M, N, dtype_str, eps=1e-5): ) def test_fused_add_layernorm(M, N, dtype_str, eps=1e-5): dtype = str_to_torch_dtype[dtype_str] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) res = torch.randn(M, N, device="cuda", dtype=dtype) w_shape = (N,) @@ -220,7 +222,8 @@ def test_fused_add_layernorm(M, N, dtype_str, eps=1e-5): def test_layernorm_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1e-5): dtype = str_to_torch_dtype[dtype_str] scale_dtype = str_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) w_shape = (N,) @@ -261,7 +264,8 @@ def test_layernorm_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1e-5): def test_layernorm_dynamicquant(M, N, dtype_str, scale_dtype_str, eps=1e-3): dtype = str_to_torch_dtype[dtype_str] scale_dtype = str_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) w_shape = (N,) @@ -298,7 +302,8 @@ def test_layernorm_dynamicquant(M, N, dtype_str, scale_dtype_str, eps=1e-3): def test_layernorm_fused_add_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1e-5): dtype = str_to_torch_dtype[dtype_str] scale_dtype = str_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) res = torch.randn(M, N, device="cuda", dtype=dtype) @@ -341,7 +346,8 @@ def test_layernorm_fused_add_smoothquant(M, N, dtype_str, scale_dtype_str, eps=1 def test_layernorm_fused_add_dynamicquant(M, N, dtype_str, scale_dtype_str, eps=1e-3): dtype = str_to_torch_dtype[dtype_str] scale_dtype = str_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=dtype) res = torch.randn(M, N, device="cuda", dtype=dtype) diff --git a/op_tests/triton_tests/normalization/test_rmsnorm.py b/op_tests/triton_tests/normalization/test_rmsnorm.py index a4d02dcfd6..d0fdabe7dc 100644 --- a/op_tests/triton_tests/normalization/test_rmsnorm.py +++ b/op_tests/triton_tests/normalization/test_rmsnorm.py @@ -101,22 +101,12 @@ def get_vals(): vals = [ (1, 4), (2, 10), - (256, 4096), - (4096, 8192), - (1, 31744), - (8192, 
65536), (873, 1245), - (4096, 5120), - (8192, 8192), - (2048, 4096), (768, 2048), (256, 1024), (128, 768), (64, 512), (173, 409), - (71, 3571), - (364800, 128), - (16380, 1536), # (29, 17389), // Temporarily disable this test due to abort issues on CI ] @@ -132,7 +122,8 @@ def test_rmsnorm(M, N, in_dtype_str): in_dtype = str_to_torch_dtype[in_dtype_str] out_dtype = in_dtype - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x, weight = generate_rmsnorm_inputs(M, N, in_dtype) @@ -183,7 +174,8 @@ def test_fused_add_rmsnorm(M, N, in_dtype_str): in_dtype = str_to_torch_dtype[in_dtype_str] out_dtype = in_dtype - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=in_dtype) weight = torch.randn(N, device="cuda", dtype=in_dtype) @@ -239,7 +231,8 @@ def test_rmsnorm_smoothquant(M, N, in_dtype_str, scale_dtype_str): in_dtype = str_to_torch_dtype[in_dtype_str] scale_dtype = str_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=in_dtype) weight = torch.randn(N, device="cuda", dtype=in_dtype) @@ -267,7 +260,8 @@ def test_rmsnorm_dynamicquant(M, N, in_dtype_str, scale_dtype_str): in_dtype = str_to_torch_dtype[in_dtype_str] scale_dtype = str_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=in_dtype) weight = torch.randn(N, device="cuda", dtype=in_dtype) @@ -292,7 +286,8 @@ def test_rmsnorm_fused_add_smoothquant(M, N, in_dtype_str, scale_dtype_str): in_dtype = str_to_torch_dtype[in_dtype_str] scale_dtype = str_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=in_dtype) weight = torch.randn(N, device="cuda", dtype=in_dtype) @@ -322,7 +317,8 @@ def test_rmsnorm_fused_add_dynamicquant(M, N, in_dtype_str, scale_dtype_str): in_dtype = str_to_torch_dtype[in_dtype_str] scale_dtype = str_to_torch_dtype[scale_dtype_str] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(0) x = torch.randn(M, N, device="cuda", dtype=in_dtype) weight = torch.randn(N, device="cuda", dtype=in_dtype) @@ -358,7 +354,8 @@ def test_rms_norm_dynamic_per_token_fp8_quant( ) EPS = 1e-6 - quant_dtype = torch.float8_e4m3fnuz + from aiter.utility.dtypes import fp8 + quant_dtype = fp8 xq_fused_triton = torch.empty(x.shape, dtype=quant_dtype, device="cuda") x_scale_fused = torch.empty(x.shape[0], 1, dtype=torch.float32, device="cuda") diff --git a/op_tests/triton_tests/quant/test_fused_fp8_quant.py b/op_tests/triton_tests/quant/test_fused_fp8_quant.py index 39b88079c1..29674c80e3 100644 --- a/op_tests/triton_tests/quant/test_fused_fp8_quant.py +++ b/op_tests/triton_tests/quant/test_fused_fp8_quant.py @@ -12,6 +12,7 @@ from aiter.test_common import ( checkAllclose, ) + import aiter import torch.nn.functional as F @@ -19,8 +20,6 @@ rocm_aiter_fp8_dtype = rocm_aiter.dtypes.fp8 -torch.manual_seed(0) - def rmsnorm(input, weight, eps=1e-6): row_norm = input * input @@ -100,7 +99,7 @@ def run_torch_rms_fp8_per_tensor_static_quant( @pytest.mark.parametrize("M", [1, 32, 256]) -@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), 
(7168, 7168)]) +@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 256), (512, 512)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_fused_rms_fp8_per_tensor_static_quant(M: int, N1: int, N2: int, dtype): dtype_quant = aiter.dtypes.fp8 @@ -144,7 +143,7 @@ def test_fused_rms_fp8_per_tensor_static_quant(M: int, N1: int, N2: int, dtype): @pytest.mark.parametrize("M", [1, 32, 256]) -@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), (7168, 7168)]) +@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 512), (512, 512)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_fused_rms_fp8_group_quant(M: int, N1: int, N2: int, dtype): group_size = 128 @@ -217,7 +216,7 @@ def triton_rmsnorm_fp8_quantization_fuse(x, w, x_scale, eps, rocm_fp8_dtype): @pytest.mark.parametrize( - "m, n", [(m, n) for m in [1, 2, 4, 8, 256, 1024, 8192] for n in [128, 4096, 8192]] + "m, n", [(m, n) for m in [1, 2, 4, 8, 256] for n in [128, 512]] ) def test_rmsnorm_quant_fuse(m, n): eps = 0.0012 @@ -253,7 +252,7 @@ def test_rmsnorm_quant_fuse(m, n): @pytest.mark.parametrize("M", [1, 32, 256]) -@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 7168), (7168, 7168)]) +@pytest.mark.parametrize("N1, N2", [(128, 128), (128, 256), (512, 512)]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_fused_rms_fp8_group_quant_transpose_scale(M: int, N1: int, N2: int, dtype): """Test that transpose_scale parameter returns scale with transposed memory layout.""" @@ -403,9 +402,9 @@ def generate_fused_reduce_act_mul_fp8_group_quant( return x, x2 -@pytest.mark.parametrize("M", [1, 32, 256, 131072]) +@pytest.mark.parametrize("M", [1, 32, 256]) @pytest.mark.parametrize("N1, N2", [(256, 256)]) -@pytest.mark.parametrize("SPK", [1, 4, 14]) +@pytest.mark.parametrize("SPK", [1, 4]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("activation", ["silu", "gelu"]) def test_fused_reduce_act_mul_fp8_group_quant( @@ -481,11 +480,11 @@ def generate_fused_reduce_rms_quant_data(M, N1, N2, N3, SPK, dtype=torch.bfloat1 return x1, w1, x2, w2, res1, x3 -@pytest.mark.parametrize("M", [1, 32, 256, 8192]) +@pytest.mark.parametrize("M", [1, 32, 256]) @pytest.mark.parametrize( - "N1, N2, N3", [(128, 128, 128), (1536, 512, 64), (7168, 7168, 7168)] + "N1, N2, N3", [(128, 128, 128), (256, 128, 64), (512, 512, 512)] ) -@pytest.mark.parametrize("SPK", [1, 4, 14]) +@pytest.mark.parametrize("SPK", [1, 4]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_fused_reduce_rms_fp8_group_quant( M: int, N1: int, N2: int, N3: int, SPK: int, dtype @@ -544,11 +543,11 @@ def test_fused_reduce_rms_fp8_group_quant( torch.testing.assert_close(y3_torch, y3_triton, atol=0.1, rtol=0.1) -@pytest.mark.parametrize("M", [1, 32, 256, 8192]) +@pytest.mark.parametrize("M", [1, 32, 256]) @pytest.mark.parametrize( - "N1, N2, N3", [(128, 128, 128), (1536, 512, 64), (7168, 7168, 7168)] + "N1, N2, N3", [(128, 128, 128), (256, 128, 64), (512, 512, 512)] ) -@pytest.mark.parametrize("SPK", [1, 4, 14]) +@pytest.mark.parametrize("SPK", [1, 4]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_fused_reduce_rms_fp8_group_quant_transpose_scale( M: int, N1: int, N2: int, N3: int, SPK: int, dtype @@ -663,7 +662,7 @@ def triton_silu_mul_fp8_quantization_fuse(x, x_scale, rocm_fp8_dtype): @pytest.mark.parametrize( - "m, n", [(m, n) for m in [1, 2, 4, 8, 256, 1024, 8192] for n in [128, 4096, 8192]] + "m, n", [(m, n) for m in [1, 
2, 4, 8, 256] for n in [128, 512]] ) def test_silu_mul_quant_fuse(m, n): rocm_fp8_dtype = rocm_aiter_fp8_dtype @@ -693,4 +692,4 @@ def test_silu_mul_quant_fuse(m, n): fp8_x_ref = silu_mul_fp8_quantization_ref(x, x_scale, rocm_fp8_dtype) fp8_x = triton_silu_mul_fp8_quantization_fuse(x, x_scale, rocm_fp8_dtype) - checkAllclose(fp8_x.to(torch.float32), fp8_x_ref.to(torch.float32)) + checkAllclose(fp8_x.to(torch.float32), fp8_x_ref.to(torch.float32)) \ No newline at end of file diff --git a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py index ea43b7a358..af59a305c1 100644 --- a/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py +++ b/op_tests/triton_tests/quant/test_fused_mxfp4_quant.py @@ -22,7 +22,9 @@ from aiter.ops.quant import per_1x32_f4_quant_hip from aiter.utility.fp4_utils import moe_mxfp4_sort, dynamic_mxfp4_quant -torch.manual_seed(0) +# TODO: Uncomment after pytorch adds support for manual_seed +# torch.manual_seed(0) +DEVICE_ARCH = arch_info.get_arch() def rmsnorm(input, weight, eps=1e-6): @@ -158,6 +160,7 @@ def test_flatten_quant(B: int, M: int, N: int, dtype): @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("scale_shuffle_padding", [True, False]) +@pytest.mark.parametrize("impl", ["triton", "gluon"]) def test_fused_rms_quant( M: int, N1: int, @@ -168,12 +171,15 @@ def test_fused_rms_quant( dtype, shuffle: bool, scale_shuffle_padding: bool, + impl: str, ): - if not (arch_info.is_fp4_avail()): + if not arch_info.is_fp4_avail(): pytest.skip("MXFP4 not supported on this architecture") + if impl == "gluon" and DEVICE_ARCH != "gfx1250": + pytest.skip(f"Gluon kernel requires gfx1250, current arch is {DEVICE_ARCH!r}") - torch.cuda.empty_cache() # Helps avoid hangs in large tests + torch.cuda.empty_cache() x1, x2, rms1_w, rms2_w, resid1 = generate_fused_rms_quant_data( x1_shape=(M, N1), x2_shape=(M, N2), @@ -189,46 +195,41 @@ def test_fused_rms_quant( ) ) - (y1_fp4_triton, y1_scales_triton), y1_triton, y2_triton, y1_res_triton = ( - fused_rms_mxfp4_quant( - x1, - rms1_w, - 1e-6, - x2, - rms2_w, - 1e-6, - resid1, - shuffle=shuffle, - scale_shuffle_padding=scale_shuffle_padding, - output_unquantized_inp1=True, - ) + (y1_fp4, y1_scales), y1, y2, y1_res = fused_rms_mxfp4_quant( + x1, + rms1_w, + 1e-6, + x2, + rms2_w, + 1e-6, + resid1, + shuffle=shuffle, + scale_shuffle_padding=scale_shuffle_padding, + output_unquantized_inp1=True, + impl=impl, ) - if y1_triton is not None: - torch.testing.assert_close(y1_torch, y1_triton) + if y1 is not None: + torch.testing.assert_close(y1_torch, y1) if shuffle: - y1_scales_triton = un_shuffle_scales( - y1_scales_triton.view(y1_scales_triton.shape[0] // 32, -1) - ) + y1_scales = un_shuffle_scales(y1_scales.view(y1_scales.shape[0] // 32, -1)) y1_scales_torch = un_shuffle_scales( y1_scales_torch.view(y1_scales_torch.shape[0] // 32, -1) ) scaleN_valid = (N1 + 31) // 32 - y1_scales_triton = y1_scales_triton[:M, :scaleN_valid] + y1_scales = y1_scales[:M, :scaleN_valid] y1_scales_torch = y1_scales_torch[:M, :scaleN_valid] - if y2_triton is not None: - torch.testing.assert_close(y2_torch, y2_triton) - - if y1_res_triton is not None: - torch.testing.assert_close(y1_res_torch, y1_res_triton) + if y2 is not None: + torch.testing.assert_close(y2_torch, y2) + if y1_res is not None: + torch.testing.assert_close(y1_res_torch, y1_res) y1_fp32_torch = convert_mxfp4_to_fp32(y1_fp4_torch, y1_scales_torch) - 
y1_fp32_triton = convert_mxfp4_to_fp32(y1_fp4_triton, y1_scales_triton) - - torch.testing.assert_close(y1_fp32_torch, y1_fp32_triton) + y1_fp32 = convert_mxfp4_to_fp32(y1_fp4, y1_scales) + torch.testing.assert_close(y1_fp32_torch, y1_fp32) def run_torch_reduce_act_mul_mxfp4_group_quant(x, x2, activation, dtype, shuffle): diff --git a/op_tests/triton_tests/quant/test_quant.py b/op_tests/triton_tests/quant/test_quant.py index b25dbf64a1..743f26489f 100644 --- a/op_tests/triton_tests/quant/test_quant.py +++ b/op_tests/triton_tests/quant/test_quant.py @@ -36,7 +36,8 @@ def torch_static_per_tensor_quant_fp8_i8(out, x, scale, dtype_quant): @pytest.mark.parametrize("dtype_in", [torch.float16, torch.bfloat16, torch.float32]) @pytest.mark.parametrize("dtype_quant", [torch.int8, get_fp8_e4m3_dtype()]) def test_static_per_tensor_quant(M: int, N: int, dtype_in, dtype_quant): - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) x = torch.randn((M, N), dtype=dtype_in, device="cuda") scale = torch.randn(1, dtype=torch.float32, device="cuda") @@ -77,7 +78,8 @@ def torch_dynamic_per_tensor_quant_fp8_i8(x, dtype_quant): @pytest.mark.parametrize("dtype_in", [torch.float16, torch.bfloat16, torch.float32]) @pytest.mark.parametrize("dtype_quant", [torch.int8, get_fp8_e4m3_dtype()]) def test_dynamic_per_tensor_quant(M: int, N: int, dtype_in, dtype_quant): - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) x = torch.randn((M, N), dtype=dtype_in, device="cuda") torch_out, torch_scale_out = torch_dynamic_per_tensor_quant_fp8_i8(x, dtype_quant) @@ -137,7 +139,8 @@ def torch_dynamic_per_token_quant_fp8_i8(x, dtype_quant): @pytest.mark.parametrize("dtype_in", [torch.float16, torch.bfloat16, torch.float32]) @pytest.mark.parametrize("dtype_quant", [torch.int8, get_fp8_e4m3_dtype()]) def test_dynamic_per_token_quant(M: int, N: int, dtype_in, dtype_quant): - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) torch.set_printoptions(precision=7, threshold=4000) x = torch.rand((M, N), dtype=dtype_in, device="cuda") diff --git a/op_tests/triton_tests/quant/test_quant_mxfp4.py b/op_tests/triton_tests/quant/test_quant_mxfp4.py index 04218ac4cf..9a3de0c882 100644 --- a/op_tests/triton_tests/quant/test_quant_mxfp4.py +++ b/op_tests/triton_tests/quant/test_quant_mxfp4.py @@ -187,7 +187,8 @@ def torch_dynamic_mxfp4_quant( @pytest.mark.parametrize("dtype", [torch.bfloat16]) def test_dynamic_mxfp4_quant(M: int, N: int, dtype): torch.cuda.empty_cache() # Helps avoid hangs in large tests - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) x = torch.randn((M, N), dtype=dtype, device="cuda") if DEBUG_MODE: diff --git a/op_tests/triton_tests/rope/test_fused_qkv_split_qk_rope.py b/op_tests/triton_tests/rope/test_fused_qkv_split_qk_rope.py index b93cfd0ae5..631a519667 100644 --- a/op_tests/triton_tests/rope/test_fused_qkv_split_qk_rope.py +++ b/op_tests/triton_tests/rope/test_fused_qkv_split_qk_rope.py @@ -58,10 +58,10 @@ def run_torch( # @pytest.mark.parametrize("QH_PER_KH", [8]) # @pytest.mark.parametrize("KH", [8]) # @pytest.mark.parametrize("D", [64]) -@pytest.mark.parametrize("B", [1, 4, 8, 16, 32]) -@pytest.mark.parametrize("QH_PER_KH", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("B", [1, 4, 8, 16]) +@pytest.mark.parametrize("QH_PER_KH", [1, 2, 4, 8]) @pytest.mark.parametrize("KH", [1, 4]) 
-@pytest.mark.parametrize("D", [64, 128]) +@pytest.mark.parametrize("D", [64]) @pytest.mark.parametrize("rotate_style", [RotateStyle.GPTJ, RotateStyle.NEOX]) @pytest.mark.parametrize("max_embed_positions", [131072]) @pytest.mark.parametrize( diff --git a/op_tests/triton_tests/rope/test_rope.py b/op_tests/triton_tests/rope/test_rope.py index 04134bbaeb..787f23805c 100644 --- a/op_tests/triton_tests/rope/test_rope.py +++ b/op_tests/triton_tests/rope/test_rope.py @@ -57,7 +57,8 @@ def generate_rope_inputs( dtype: torch.dtype, bwd: bool = False, ): - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) random.seed(20) device = "cuda" @@ -1037,7 +1038,8 @@ def test_rope_2d_fwd( inplace: bool, dtype: torch.dtype, ): - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + # torch.manual_seed(20) x = torch.randn((B, height * width, H, D), dtype=dtype, device="cuda") diff --git a/op_tests/triton_tests/test_activation.py b/op_tests/triton_tests/test_activation.py index 38dbd4779d..fa33a8433e 100644 --- a/op_tests/triton_tests/test_activation.py +++ b/op_tests/triton_tests/test_activation.py @@ -96,13 +96,18 @@ def test_act_mul_and_mxfp4_quant( M: int, N: int, dtype, activation: str, shuffle: bool, scale_shuffle_padding: bool ): + # FIXME: Remove when faster + if arch_info.get_arch() == "gfx1250" and M * N >= 57344: + pytest.skip() + if not (arch_info.is_fp4_avail()): pytest.skip("MXFP4 not supported on this architecture") if shuffle and N % 512 != 0: pytest.skip() - torch.manual_seed(20) + # TODO: Uncomment after pytorch adds support for manual_seed + #torch.manual_seed(20) x = torch.randn((M, N), dtype=dtype, device="cuda") if DEBUG_MODE: diff --git a/op_tests/triton_tests/test_softmax.py b/op_tests/triton_tests/test_softmax.py index ce59fea818..c981e19010 100644 --- a/op_tests/triton_tests/test_softmax.py +++ b/op_tests/triton_tests/test_softmax.py @@ -2,6 +2,7 @@ import pytest from aiter.ops.triton.softmax import softmax from aiter.ops.triton.utils.types import str_to_torch_dtype +import aiter.ops.triton.utils._triton.arch_info as arch_info # pytest @@ -22,8 +23,14 @@ ], ) def test_softmax(M, N, dtype): + + # FIXME: Remove when faster + if arch_info.get_arch() == "gfx1250" and M * N >= 2048*1024: + pytest.skip() + dtype = str_to_torch_dtype[dtype] - torch.manual_seed(0) + # TODO: Uncomment after pytorch adds support for manual_seed + #torch.manual_seed(0) x = torch.randn(M, N, dtype=dtype, device="cuda") y_triton = softmax(x) y_torch = torch.softmax(x, axis=1)