diff --git a/src/liger_kernel/ops/rms_norm.py b/src/liger_kernel/ops/rms_norm.py index e5cab72ea..69936fe20 100644 --- a/src/liger_kernel/ops/rms_norm.py +++ b/src/liger_kernel/ops/rms_norm.py @@ -57,10 +57,10 @@ def _rms_norm_forward_kernel( casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out elementwise_affine: tl.constexpr, BLOCK_SIZE: tl.constexpr, + STORE_RSTD: tl.constexpr, ): """ y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N) - Reference: 1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html 2. https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/rms_layernorm.py#L22 @@ -98,9 +98,9 @@ def _rms_norm_forward_kernel( rstd = rsqrt(mean_square + eps) # We can save time by caching rms with minimal memory overhead - # because rms is much smaller compared to X_row, as rms is for each row. # However, on the computation side, it can save 4 operations (*, sum, /, sqrt). - tl.store(rstd_base, rstd) + if STORE_RSTD: + tl.store(rstd_base, rstd) X_row = X_row * rstd @@ -230,6 +230,7 @@ def _block_rms_norm_forward_kernel( elementwise_affine: tl.constexpr, BLOCK_SIZE: tl.constexpr, BLOCK_ROW: tl.constexpr, + STORE_RSTD: tl.constexpr, ): """ y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N) @@ -272,9 +273,9 @@ def _block_rms_norm_forward_kernel( rstd = rsqrt(mean_square + eps) # We can save time by caching rms with minimal memory overhead - # because rms is much smaller compared to X_row, as rms is for each row. # However, on the computation side, it can save 4 operations (*, sum, /, sqrt). 
- tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask) + if STORE_RSTD: + tl.store(RSTD_ptr + row_idx * RSTD_row_stride, rstd, row_mask) X_row = X_row * rstd[:, None] @@ -406,7 +407,10 @@ def _block_rms_norm_backward_kernel( } -def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode): +def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode, inference_mode=False): + """ + If inference_mode is True, rstd is still computed per row but is not stored to RSTD. + """ if not isinstance(casting_mode, int): assert casting_mode in _str_to_casting_mode, f"Invalid casting mode: {casting_mode}" casting_mode = _str_to_casting_mode[casting_mode] @@ -420,10 +424,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode): BLOCK_SIZE, num_warps = calculate_settings(n_cols) Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device) - # RSTD is to cache rstd for each row - # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode - rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype - RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device) + + # Allocate the RSTD cache only when NOT in inference mode + if not inference_mode: + rstd_dtype = torch.float32 if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value) else X.dtype + RSTD = torch.empty(n_rows, dtype=rstd_dtype, device=X.device) + else: + # Dummy tensor (0-size) to satisfy pointer arguments, won't be written to. + RSTD = torch.empty(0, device=X.device) if W is not None: # Check constraints. 
@@ -447,13 +455,14 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode): W, W.stride(0) if elementwise_affine else 0, RSTD, - RSTD.stride(0), + RSTD.stride(0) if not inference_mode else 0, n_cols, eps, offset, casting_mode, elementwise_affine=elementwise_affine, BLOCK_SIZE=BLOCK_SIZE, + STORE_RSTD=(not inference_mode), num_warps=num_warps, **kernel_args, # XPU-specific optimization ) @@ -468,7 +477,7 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode): W, W.stride(0) if elementwise_affine else 0, RSTD, - RSTD.stride(0), + RSTD.stride(0) if not inference_mode else 0, n_rows, n_cols, eps, @@ -476,9 +485,15 @@ def rms_norm_forward(X, W, eps, offset, casting_mode, row_mode): casting_mode, elementwise_affine=elementwise_affine, BLOCK_SIZE=BLOCK_SIZE, + STORE_RSTD=(not inference_mode), num_warps=num_warps, **kernel_args, # XPU-specific optimization ) + + # Return Logic Update + if inference_mode: + return Y.view(*shape), X, None, BLOCK_SIZE, num_warps, casting_mode + return Y.view(*shape), X, RSTD, BLOCK_SIZE, num_warps, casting_mode @@ -494,7 +509,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp elif X.device.type == "xpu": sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count elif X.device.type == "npu": - sm_count = get_npu_core_count() + sm_count = get_npu_multi_processor_count() if W is not None: # fp32 for numerical stability especially. @@ -605,18 +620,12 @@ class LigerRMSNormFunction(torch.autograd.Function): @staticmethod @ensure_contiguous def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row_mode=None): - """ - X: (B, T, H) or (BxT, H) - W: (H,) - """ if isinstance(X, torch.distributed.tensor.DTensor): - # Input tensor is output of a tensor parallel module and - # needs to be gathered to a local tensor to compute - # RMSE layer norm on each TP worker. - # TODO: support CP. 
X = X.full_tensor() + # NOTE: Default inference_mode=False ensures existing training behavior is preserved Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = rms_norm_forward(X, W, eps, offset, casting_mode, row_mode) + ctx.offset = offset ctx.casting_mode = casting_mode ctx.in_place = in_place @@ -633,9 +642,6 @@ def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama", in_place=True, row @staticmethod @ensure_contiguous def backward(ctx, dY): - """ - Y: (B, T, H) or (BxT, H) - """ if ctx.elementwise_affine: X, W, RSTD = ctx.saved_tensors else: @@ -643,12 +649,9 @@ def backward(ctx, dY): W = None if isinstance(dY, torch.distributed.tensor.DTensor): - # Gradients are output of a tensor parallel module and - # needs to be gathered to a local tensor for computing RMSE layer. - # TODO: support CP. dY = dY.full_tensor() dX, dW = rms_norm_backward( dY, X, W, RSTD, ctx.offset, ctx.casting_mode, ctx.BLOCK_SIZE, ctx.num_warps, ctx.in_place, ctx.row_mode ) - return dX, dW, None, None, None, None, None + return dX, dW, None, None, None, None, None \ No newline at end of file