From be2e93e2c38c344a4b3052ab6b3ad4eef49309e4 Mon Sep 17 00:00:00 2001
From: Umang Singh
Date: Thu, 5 Feb 2026 20:55:31 +0530
Subject: [PATCH 1/2] inference-optimize

---
 src/liger_kernel/ops/rms_norm.py | 111 +++++++++++++------------------
 1 file changed, 46 insertions(+), 65 deletions(-)

diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py
index e5cab72ea..3aad4debd 100644
--- a/src/liger_kernel/ops/rms_norm.py
+++ b/src/liger_kernel/ops/rms_norm.py
@@ -57,14 +57,11 @@ def _rms_norm_forward_kernel(
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
     elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
+    # [OPTIMIZATION] Added Switch
+    STORE_RSTD: tl.constexpr,
 ):
     """
     y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
-
-    Reference:
-    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
-    3. https://arxiv.org/pdf/1910.07467
     """
 
     row_idx = tl.program_id(0).to(tl.int64)
@@ -97,10 +94,9 @@ def _rms_norm_forward_kernel(
     mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
     rstd = rsqrt(mean_square + eps)
 
-    # We can save time by caching rms with minimal memory overhead
-    # because rms is much smaller compared to X_row, as rms is for each row.
-    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-    tl.store(rstd_base, rstd)
+    #[OPTIMIZATION] Only store RSTD if needed (Training)
+    if STORE_RSTD:
+        tl.store(rstd_base, rstd)
 
     X_row = X_row * rstd
 
@@ -143,8 +139,7 @@ def _rms_norm_backward_kernel(
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whileas dot means dot product
-    dw = sum(dy * (x / RMS)). summation over BxT dimension
+    Backward kernel remains unchanged.
     """
 
     row_block_id = tl.program_id(0).to(tl.int64)
@@ -230,14 +225,11 @@ def _block_rms_norm_forward_kernel(
     elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
+    #[OPTIMIZATION] Added Switch
+    STORE_RSTD: tl.constexpr,
 ):
     """
-    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
-
-    Reference:
-    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
-    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
-    3. https://arxiv.org/pdf/1910.07467
+    Block implementation optimization.
     """
 
     row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
@@ -271,10 +263,9 @@ def _block_rms_norm_forward_kernel(
     mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
     rstd = rsqrt(mean_square + eps)
 
-    # We can save time by caching rms with minimal memory overhead
-    # because rms is much smaller compared to X_row, as rms is for each row.
-    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
-    tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
+    #[OPTIMIZATION] Only store RSTD if needed
+    if STORE_RSTD:
+        tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
 
     X_row = X_row * rstd[:, None]
 
@@ -320,10 +311,6 @@ def _block_rms_norm_backward_kernel(
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
-    """
-    dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whileas dot means dot product
-    dw = sum(dy * (x / RMS)). summation over BxT dimension
-    """
 
     pid = tl.program_id(0).cast(tl.int64)
     NUM_SMS = tl.num_programs(0)
@@ -356,7 +343,7 @@ def _block_rms_norm_backward_kernel(
 
         X_row = X_row.to(tl.float32)
 
-        # Different bacward graphs for different casting modes
+        # Different backward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
             if elementwise_affine:
                 m = (dY_row * W_row[None, :]).to(tl.float32)
@@ -406,7 +393,7 @@ def _block_rms_norm_backward_kernel(
 }
 
 
-def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
+def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode, inference_mode=False):
     if not isinstance(casting_mode, int):
         assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
@@ -420,10 +407,16 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
 
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
-    # RSTD is to cache rstd for each row
-    # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
-    rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
-    RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
+
+    #[OPTIMIZATION] Smart Allocation for RSTD
+    # Only allocate real memory if we are NOT in inference mode (i.e., we need it for training)
+    # If inference_mode is True, we pass a dummy empty tensor to avoid allocation overhead.
+    if not inference_mode:
+        rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
+        RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
+    else:
+        # Dummy tensor (0-size) to satisfy pointer arguments, won't be written to.
+        RSTD = torch.empty(0, device=X.device)
 
     if W is not None:
         # Check constraints.
@@ -438,6 +431,8 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
     kernel_args = {}
     if X.device.type == "xpu":
         set_large_grf_mode(kernel_args)
+
+    # Decide which kernel to launch
     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
         _rms_norm_forward_kernel[(n_rows,)](
             Y,
@@ -447,15 +442,17 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             W,
             W.stride(0) if elementwise_affine else 0,
             RSTD,
-            RSTD.stride(0),
+            RSTD.stride(0) if not inference_mode else 0,  # Safety stride 0
             n_cols,
             eps,
             offset,
             casting_mode,
             elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
+            #[OPTIMIZATION] Pass flag
+            STORE_RSTD=(not inference_mode),
             num_warps=num_warps,
-            **kernel_args,  # XPU-specific optimization
+            **kernel_args,
         )
     else:
         BLOCK_ROW = 16
@@ -468,7 +465,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             W,
             W.stride(0) if elementwise_affine else 0,
             RSTD,
-            RSTD.stride(0),
+            RSTD.stride(0) if not inference_mode else 0,  # Safety stride 0
             n_rows,
             n_cols,
             eps,
@@ -476,13 +473,21 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode):
             casting_mode,
             elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
+            #[OPTIMIZATION] Pass flag
+            STORE_RSTD=(not inference_mode),
             num_warps=num_warps,
-            **kernel_args,  # XPU-specific optimization
+            **kernel_args,
         )
+
+    #[OPTIMIZATION] Return None for RSTD if in inference mode
+    if inference_mode:
+        return Y.view(*shape), X, None, BLOCK_SIZE, num_warps, casting_mode
+
     return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode
 
 
 def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
+    # Backward function remains exactly the same as original
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -494,7 +499,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
     elif X.device.type == "npu":
-        sm_count = get_npu_core_count()
+        sm_count = get_npu_multi_processor_count()
 
     if W is not None:
         # fp32 for numerical stability especially.
@@ -582,24 +587,8 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
 
 class LigerRMSNormFunction(torch.autograd.Function):
     """
-    Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
-    weight tensor `W`, with an optional offset and casting mode.
-
-    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
-    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
-    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
-
-    In addition, different models cast their inputs at different places during RMSNorm computation. For
-    example, Gemma casts everything to fp32 nefore starting the computation, while Llama casts only the
-    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
-    support the following casting modes (they match HuggingFace Transformers' implementations):
-    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
-    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
-    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
-
-    `in_place` option means whether to in_place modify dY to store dX. This is default to `True` to save memory. However, under certain cases, it can produce incorrect inputs.
-    For example, gemma2 uses two rmsnorm sequentially with residual in between. The resesidual part needs dY so it cannot be modified in-place.
-    Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`
+    Autograd FunctionWrapper.
+    We don't expose inference_mode here because autograd implies training/backward needs.
     """
 
     @staticmethod
@@ -610,13 +599,11 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row
         W: (H,)
         """
         if isinstance(X, torch.distributed.tensor.DTensor):
-            # Input tensor is output of a tensor parallel module and
-            # needs to be gathered to a local tensor to compute
-            # RMSE layer norm on each TP worker.
-            # TODO: support CP.
            X = X.full_tensor()
 
+        # Call with default inference_mode=False (Training behavior)
         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
+
         ctx.offset = offset
         ctx.casting_mode = casting_mode
         ctx.in_place = in_place
@@ -633,9 +620,6 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row
     @staticmethod
     @ensure_contiguous
     def backward(ctx, dY):
-        """
-        Y: (B, T, H) or (BxT, H)
-        """
         if ctx.elementwise_affine:
             X, W, RSTD = ctx.saved_tensors
         else:
@@ -643,12 +627,9 @@ def backward(ctx, dY):
             W = None
 
         if isinstance(dY, torch.distributed.tensor.DTensor):
-            # Gradients are output of a tensor parallel module and
-            # needs to be gathered to a local tensor for computing RMSE layer.
-            # TODO: support CP.
             dY = dY.full_tensor()
 
         dX, dW = rms_norm_backward(
             dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode
         )
-        return dX, dW, None, None, None, None, None
+        return dX, dW, None, None, None, None, None
\ No newline at end of file

From 82d05c1c34766d94a07163b951e91020baa3ad17 Mon Sep 17 00:00:00 2001
From: Umang Singh
Date: Thu, 5 Feb 2026 21:16:40 +0530
Subject: [PATCH 2/2] fixed comment

---
 src/liger_kernel/ops/rms_norm.py | 78 ++++++++++++++++++++------------
 1 file changed, 50 insertions(+), 28 deletions(-)

diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py
index 3aad4debd..69936fe20 100644
--- a/src/liger_kernel/ops/rms_norm.py
+++ b/src/liger_kernel/ops/rms_norm.py
@@ -57,11 +57,14 @@ def _rms_norm_forward_kernel(
     casting_mode: tl.constexpr,  # constexpr so the `if` blocks can be optimized out
     elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
-    # [OPTIMIZATION] Added Switch
     STORE_RSTD: tl.constexpr,
 ):
     """
     y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
+    Reference:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
     """
 
     row_idx = tl.program_id(0).to(tl.int64)
@@ -94,7 +97,8 @@ def _rms_norm_forward_kernel(
     mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
     rstd = rsqrt(mean_square + eps)
 
-    #[OPTIMIZATION] Only store RSTD if needed (Training)
+    # We can save time by caching rms with minimal memory overhead
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
     if STORE_RSTD:
         tl.store(rstd_base, rstd)
 
@@ -139,7 +143,8 @@ def _rms_norm_backward_kernel(
     BLOCK_SIZE: tl.constexpr,
 ):
     """
-    Backward kernel remains unchanged.
+    dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
     """
 
     row_block_id = tl.program_id(0).to(tl.int64)
@@ -225,11 +230,15 @@ def _block_rms_norm_forward_kernel(
     elementwise_affine: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
-    #[OPTIMIZATION] Added Switch
     STORE_RSTD: tl.constexpr,
 ):
     """
-    Block implementation optimization.
+    y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)
+
+    Reference:
+    1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22
+    3. https://arxiv.org/pdf/1910.07467
     """
 
     row_idx = tl.program_id(0) * BLOCK_ROW + tl.arange(0, BLOCK_ROW)
@@ -263,7 +272,8 @@ def _block_rms_norm_forward_kernel(
     mean_square = tl.sum(X_row * X_row, axis=1) / n_cols
     rstd = rsqrt(mean_square + eps)
 
-    #[OPTIMIZATION] Only store RSTD if needed
+    # We can save time by caching rms with minimal memory overhead
+    # However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
     if STORE_RSTD:
         tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask)
 
@@ -311,6 +321,10 @@ def _block_rms_norm_backward_kernel(
     BLOCK_SIZE: tl.constexpr,
     BLOCK_ROW: tl.constexpr,
 ):
+    """
+    dx = (1 / RMS) * [dy * (w + offset - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
+    dw = sum(dy * (x / RMS)). summation over BxT dimension
+    """
 
     pid = tl.program_id(0).cast(tl.int64)
     NUM_SMS = tl.num_programs(0)
@@ -343,7 +357,7 @@ def _block_rms_norm_backward_kernel(
 
         X_row = X_row.to(tl.float32)
 
-        # Different backward graphs for different casting modes
+        # Different bacward graphs for different casting modes
         if casting_mode == _CASTING_MODE_LLAMA:
             if elementwise_affine:
                 m = (dY_row * W_row[None, :]).to(tl.float32)
@@ -394,6 +408,9 @@ def _block_rms_norm_backward_kernel(
 }
 
 
 def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode, inference_mode=False):
+    """
+    Added inference_mode argument. If True, skips RSTD calculation storage.
+    """
     if not isinstance(casting_mode, int):
         assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}"
         casting_mode = _str_to_casting_mode[casting_mode]
@@ -408,9 +425,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode, inference_mode=F
     BLOCK_SIZE, num_warps = calculate_settings(n_cols)
 
     Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
 
-    #[OPTIMIZATION] Smart Allocation for RSTD
-    # Only allocate real memory if we are NOT in inference mode (i.e., we need it for training)
-    # If inference_mode is True, we pass a dummy empty tensor to avoid allocation overhead.
+    # Only create RSTD if we are NOT in inference mode
     if not inference_mode:
         rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype
         RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device)
     else:
@@ -431,8 +446,6 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode, inference_mode=F
     kernel_args = {}
     if X.device.type == "xpu":
         set_large_grf_mode(kernel_args)
-
-    # Decide which kernel to launch
     if BLOCK_SIZE > 256 or n_rows < 4096 * 8 or row_mode:
         _rms_norm_forward_kernel[(n_rows,)](
             Y,
@@ -442,17 +455,16 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode, inference_mode=F
             W,
             W.stride(0) if elementwise_affine else 0,
             RSTD,
-            RSTD.stride(0) if not inference_mode else 0,  # Safety stride 0
+            RSTD.stride(0) if not inference_mode else 0,
             n_cols,
             eps,
             offset,
             casting_mode,
             elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
-            #[OPTIMIZATION] Pass flag
             STORE_RSTD=(not inference_mode),
             num_warps=num_warps,
-            **kernel_args,
+            **kernel_args,  # XPU-specific optimization
         )
     else:
         BLOCK_ROW = 16
@@ -465,7 +477,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode, inference_mode=F
             W,
             W.stride(0) if elementwise_affine else 0,
             RSTD,
-            RSTD.stride(0) if not inference_mode else 0,  # Safety stride 0
+            RSTD.stride(0) if not inference_mode else 0,
             n_rows,
             n_cols,
             eps,
@@ -473,13 +485,12 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode, inference_mode=F
             casting_mode,
             elementwise_affine=elementwise_affine,
             BLOCK_SIZE=BLOCK_SIZE,
-            #[OPTIMIZATION] Pass flag
             STORE_RSTD=(not inference_mode),
             num_warps=num_warps,
-            **kernel_args,
+            **kernel_args,  # XPU-specific optimization
         )
-
-    #[OPTIMIZATION] Return None for RSTD if in inference mode
+
+    # Return Logic Update
     if inference_mode:
         return Y.view(*shape), X, None, BLOCK_SIZE, num_warps, casting_mode
 
@@ -487,7 +498,6 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode, inference_mode=F
 
 
 def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps, in_place, row_mode):
-    # Backward function remains exactly the same as original
     shape = dY.shape
     dim = shape[-1]
     dY = dY.view(-1, dim)
@@ -587,21 +597,33 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
 
 class LigerRMSNormFunction(torch.autograd.Function):
     """
-    Autograd FunctionWrapper.
-    We don't expose inference_mode here because autograd implies training/backward needs.
+    Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
+    weight tensor `W`, with an optional offset and casting mode.
+
+    Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
+    uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
+    `(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.
+
+    In addition, different models cast their inputs at different places during RMSNorm computation. For
+    example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
+    inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
+    support the following casting modes (they match HuggingFace Transformers' implementations):
+    - 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
+    - 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
+    - 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
+
+    `in_place` option means whether to in_place modify dY to store dX. This defaults to `True` to save memory. However, under certain cases, it can produce incorrect inputs.
+    For example, gemma2 uses two rmsnorm sequentially with residual in between. The residual part needs dY so it cannot be modified in-place.
+    Therefore, for the patching of RMSNorm in gemma2, we set `in_place` to `False`
     """
 
     @staticmethod
     @ensure_contiguous
     def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None):
-        """
-        X: (B, T, H) or (BxT, H)
-        W: (H,)
-        """
         if isinstance(X, torch.distributed.tensor.DTensor):
             X = X.full_tensor()
 
-        # Call with default inference_mode=False (Training behavior)
+        # NOTE: Default inference_mode=False ensures existing training behavior is preserved
         Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode)
 
         ctx.offset = offset
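
A minimal usage sketch of the `inference_mode` flag added by this series. The `rms_norm_forward` signature and module path are taken from the diff above; the device, dtype, tensor shapes, and the use of torch.inference_mode() are illustrative assumptions only, not part of the patch:

    import torch

    from liger_kernel.ops.rms_norm import rms_norm_forward  # path as in this patch

    # Hypothetical smoke test: with inference_mode=True the kernels skip storing RSTD,
    # no per-row rstd buffer is allocated, and the third return value is None.
    X = torch.randn(2, 16, 4096, device="cuda", dtype=torch.bfloat16)
    W = torch.ones(4096, device="cuda", dtype=torch.bfloat16)

    with torch.inference_mode():
        Y, _, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(
            X, W, 1e-6, 0.0, "llama", None, inference_mode=True
        )

    assert Y.shape == X.shape
    assert RSTD is None  # nothing is cached for a backward pass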