122 changes: 111 additions & 11 deletions atom/model_ops/layernorm.py
@@ -5,6 +5,8 @@

import aiter
import torch
import triton
import triton.language as tl
from aiter import (
    QuantType,
    layernorm2d_fwd,
@@ -288,6 +290,61 @@ def forward(
        return x, residual


# decode: one program per (token, head) pair, each normalizing one 128-wide row
@triton.jit
def _rmsnorm_gated_contiguous_128_kernel(
    x_ptr,
    z_ptr,
    weight_ptr,
    out_ptr,
    num_heads: tl.constexpr,
    eps: tl.constexpr,
):
    token_id = tl.program_id(0)
    head_id = tl.program_id(1)
    offsets = tl.arange(0, 128)
    row_offset = (token_id * num_heads + head_id) * 128

    x = tl.load(x_ptr + row_offset + offsets, cache_modifier=".ca").to(tl.float32)
    z = tl.load(z_ptr + row_offset + offsets, cache_modifier=".ca").to(tl.float32)
    weight = tl.load(weight_ptr + offsets, cache_modifier=".ca").to(tl.float32)

    variance = tl.sum(x * x, axis=0) * 0.0078125  # 0.0078125 == 1 / 128 (mean over the head dim)
    inv_rms = tl.rsqrt(variance + eps)
    gate = z * tl.sigmoid(z)  # SiLU gate
    out = x * inv_rms * weight * gate

    tl.store(out_ptr + row_offset + offsets, out)
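For reference, both kernels compute plain gated RMSNorm over a 128-wide head: normalize `x` by its root mean square, scale by `weight`, then apply the SiLU gate `z * sigmoid(z)`. A minimal eager sketch of the same math (the helper `rmsnorm_gated_ref` is illustrative, not part of this diff):

```python
import torch

def rmsnorm_gated_ref(x: torch.Tensor, z: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # x, z: (..., 128); weight: (128,). Compute in fp32, as the kernels do.
    xf, zf = x.float(), z.float()
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    gate = zf * torch.sigmoid(zf)  # SiLU gate
    return (xf * inv_rms * weight.float() * gate).to(x.dtype)
```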


# prefill: each program handles a (block_rows, 128) tile of rows
@triton.jit
def _rmsnorm_gated_contiguous_128_tiled_rows_kernel(
    x_ptr,
    z_ptr,
    weight_ptr,
    out_ptr,
    num_rows: tl.constexpr,
    eps: tl.constexpr,
    block_rows: tl.constexpr,
):
    row_offsets = tl.program_id(0) * block_rows + tl.arange(0, block_rows)
    dim_offsets = tl.arange(0, 128)
    mask_rows = row_offsets < num_rows
    offsets = row_offsets[:, None] * 128 + dim_offsets[None, :]

    x = tl.load(x_ptr + offsets, mask=mask_rows[:, None], other=0.0).to(tl.float32)
    z = tl.load(z_ptr + offsets, mask=mask_rows[:, None], other=0.0).to(tl.float32)
    weight = tl.load(weight_ptr + dim_offsets, cache_modifier=".ca").to(tl.float32)

    variance = tl.sum(x * x, axis=1) * 0.0078125  # 0.0078125 == 1 / 128
    inv_rms = tl.rsqrt(variance + eps)
    gate = z * tl.sigmoid(z)  # SiLU gate
    out = x * inv_rms[:, None] * weight[None, :] * gate

    tl.store(out_ptr + offsets, out, mask=mask_rows[:, None])
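A hypothetical smoke test for the decode kernel against the `rmsnorm_gated_ref` sketch above (assumes a CUDA/ROCm device and the shapes used here; not part of the diff):

```python
x = torch.randn(4, 8, 128, device="cuda", dtype=torch.float16)
z = torch.randn_like(x)
w = torch.ones(128, device="cuda", dtype=torch.float16)
out = torch.empty(4, 8 * 128, device="cuda", dtype=torch.float16)
# Grid = (num_tokens, num_heads): one program per 128-wide row.
_rmsnorm_gated_contiguous_128_kernel[(4, 8)](x, z, w, out, 8, 1e-6, num_warps=1, num_stages=1)
# Loose tolerances account for the fp16 store at the end of the kernel.
torch.testing.assert_close(
    out.view(4, 8, 128), rmsnorm_gated_ref(x, z, w, 1e-6), rtol=2e-2, atol=2e-2
)
```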


class RMSNormGated(nn.Module):
"""RMS Normalization with optional gating.

@@ -360,6 +417,55 @@ def __init__(
    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward_triton(self, x: torch.Tensor, z: torch.Tensor):
        if (
            z is None
            or x.ndim != 3
            or self.group_size is not None
            or not self.norm_before_gate
            or x.shape[-1] != 128
            or not x.is_contiguous()
            or not z.is_contiguous()
        ):
            return self.forward_native(x, z)

        num_tokens, num_heads, head_dim = x.shape
        out = torch.empty(
            (num_tokens, num_heads * head_dim),
            dtype=x.dtype,
            device=x.device,
        )

        num_rows = num_tokens * num_heads
        # Prefill-sized workloads take the row-tiled kernel; decode-sized ones
        # get one program per (token, head) row.
        if num_rows >= 65536:
            block_rows = 32
            _rmsnorm_gated_contiguous_128_tiled_rows_kernel[
                (triton.cdiv(num_rows, block_rows),)
            ](
                x,
                z,
                self.weight,
                out,
                num_rows,
                self.eps,
                block_rows,
                num_warps=4,
                num_stages=1,
            )
        else:
            _rmsnorm_gated_contiguous_128_kernel[(num_tokens, num_heads)](
                x,
                z,
                self.weight,
                out,
                num_heads,
                self.eps,
                num_warps=1,
                num_stages=1,
            )

        return (out, None)

    def forward_native(
        self, x: torch.Tensor, z: torch.Tensor
    ) -> tuple[torch.Tensor, None]:
@@ -479,7 +585,7 @@ def forward(
        if self.use_fused_fp8_quant:
            return self.forward_fused_fp8(x, z)

        return self.forward_native(x, z)
        return self.forward_triton(x, z)


class GemmaRMSNorm(nn.Module):
@@ -547,13 +653,11 @@ def forward_cuda(
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if torch.compiler.is_compiling():
            return self.forward_native(x, residual)
        from atom.model_ops.triton_gemma_rmsnorm import gemma_rmsnorm_triton

        if not getattr(self, "_is_compiled", False):
            self.forward_static = torch.compile(self.forward_static)  # type: ignore
            self._is_compiled = True
        return self.forward_native(x, residual)
        return gemma_rmsnorm_triton(
            x, self.weight.data, self.variance_epsilon, residual
        )
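For context, `gemma_rmsnorm_triton` (defined in `atom/model_ops/triton_gemma_rmsnorm.py`, not shown in this diff) replaces the lazy `torch.compile` path. A minimal sketch of the semantics it is expected to match, assuming the usual Gemma convention of scaling by `1 + weight` and the common fused residual-add pattern (both are assumptions, not confirmed by this hunk):

```python
def gemma_rmsnorm_ref(x, weight, eps, residual=None):
    # Assumed semantics: optional fused residual add, then RMSNorm with a (1 + weight) scale.
    if residual is not None:
        x = x + residual
        residual = x
    xf = x.float()
    xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    out = (xf * (1.0 + weight.float())).to(x.dtype)
    return out if residual is None else (out, residual)
```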

    def _forward_fused_fp8(self, x, residual=None):
        from aiter.ops.fused_qk_rmsnorm_group_quant import fused_qk_rmsnorm_group_quant
@@ -605,10 +709,6 @@ def forward(
# ---------------------------------------------------------------------------
# Fused Q/K RMSNorm Triton kernel
# ---------------------------------------------------------------------------
import triton # noqa: E402
import triton.language as tl # noqa: E402


@triton.jit
def _fused_qk_norm_single_kernel(
    q_ptr,