|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | + |
| 5 | +from triton.language.math import tanh |
| 6 | + |
| 7 | +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy |
| 8 | +from liger_kernel.ops.utils import ensure_contiguous |
| 9 | +from liger_kernel.ops.utils import get_npu_core_count |
| 10 | + |
| 11 | +# ----------------------------------------------------------------------------- |
| 12 | +# Forward Kernel |
| 13 | +# ----------------------------------------------------------------------------- |
| 14 | + |
| 15 | + |
| 16 | +@triton.jit |
| 17 | +def _dyt_fwd_kernel( |
| 18 | + X, |
| 19 | + Y, |
| 20 | + Alpha, |
| 21 | + Gamma, |
| 22 | + Beta, |
| 23 | + HAVE_BETA: tl.constexpr, |
| 24 | + M: tl.constexpr, |
| 25 | + N: tl.constexpr, |
| 26 | + BLOCK_N: tl.constexpr, |
| 27 | +): |
| 28 | + """ |
| 29 | + Forward kernel for DYT: y = tanh(α·x) · γ + β |
| 30 | +
|
| 31 | + Grid: (num_col_blocks, num_row_programs) |
| 32 | + Each program processes multiple rows using grid-stride loop |
| 33 | + """ |
| 34 | + pid_n = tl.program_id(0) |
| 35 | + pid_m = tl.program_id(1) |
| 36 | + num_row_programs = tl.num_programs(1) |
| 37 | + |
| 38 | + col_start = pid_n * BLOCK_N |
| 39 | + col_offsets = col_start + tl.arange(0, BLOCK_N) |
| 40 | + col_mask = col_offsets < N |
| 41 | + |
| 42 | + alpha = tl.load(Alpha).to(tl.float32) |
| 43 | + gamma = tl.load(Gamma + col_offsets, mask=col_mask, other=0.0).to(tl.float32) |
| 44 | + if HAVE_BETA: |
| 45 | + beta = tl.load(Beta + col_offsets, mask=col_mask, other=0.0).to(tl.float32) |
| 46 | + |
| 47 | + # Grid-stride loop over rows |
| 48 | + for row_idx in range(pid_m, M, num_row_programs): |
| 49 | + row_offset = row_idx * N |
| 50 | + |
| 51 | + x = tl.load(X + row_offset + col_offsets, mask=col_mask, other=0.0).to(tl.float32) |
| 52 | + |
| 53 | + # Compute: y = tanh(α·x) · γ + β |
| 54 | + tanh_x = tanh(alpha * x) |
| 55 | + y = tanh_x * gamma |
| 56 | + |
| 57 | + if HAVE_BETA: |
| 58 | + y += beta |
| 59 | + |
| 60 | + tl.store(Y + row_offset + col_offsets, y, mask=col_mask) |
| 61 | + |
| 62 | + |
| 63 | +# ----------------------------------------------------------------------------- |
| 64 | +# Backward Kernel |
| 65 | +# ----------------------------------------------------------------------------- |
| 66 | + |
| 67 | + |
| 68 | +@triton.jit |
| 69 | +def _dyt_bwd_kernel( |
| 70 | + DY, |
| 71 | + DX, |
| 72 | + DA, |
| 73 | + DG, |
| 74 | + DB, |
| 75 | + X, |
| 76 | + Alpha, |
| 77 | + Gamma, |
| 78 | + HAVE_BETA: tl.constexpr, |
| 79 | + M: tl.constexpr, |
| 80 | + N: tl.constexpr, |
| 81 | + BLOCK_N: tl.constexpr, |
| 82 | +): |
| 83 | + """ |
| 84 | + Backward kernel for DYT |
| 85 | +
|
| 86 | + Grid: (num_col_blocks, num_row_programs) |
| 87 | + Each program processes multiple rows using grid-stride loop |
| 88 | + Accumulates gradients in local buffers, then stores to global memory |
| 89 | + """ |
| 90 | + pid_n = tl.program_id(0) |
| 91 | + pid_m = tl.program_id(1) |
| 92 | + num_row_programs = tl.num_programs(1) |
| 93 | + |
| 94 | + col_start = pid_n * BLOCK_N |
| 95 | + col_offsets = col_start + tl.arange(0, BLOCK_N) |
| 96 | + col_mask = col_offsets < N |
| 97 | + |
| 98 | + alpha = tl.load(Alpha).to(tl.float32) |
| 99 | + gamma = tl.load(Gamma + col_offsets, mask=col_mask, other=0.0).to(tl.float32) |
| 100 | + |
| 101 | + da_vec = tl.zeros((BLOCK_N,), dtype=tl.float32) |
| 102 | + dg_acc = tl.zeros((BLOCK_N,), dtype=tl.float32) |
| 103 | + if HAVE_BETA: |
| 104 | + db_acc = tl.zeros((BLOCK_N,), dtype=tl.float32) |
| 105 | + |
| 106 | + # Grid-stride loop over rows |
| 107 | + for row_idx in range(pid_m, M, num_row_programs): |
| 108 | + row_offset = row_idx * N |
| 109 | + |
| 110 | + x = tl.load(X + row_offset + col_offsets, mask=col_mask, other=0.0).to(tl.float32) |
| 111 | + dy = tl.load(DY + row_offset + col_offsets, mask=col_mask, other=0.0).to(tl.float32) |
| 112 | + |
| 113 | + tanh_x = tanh(alpha * x) |
| 114 | + |
| 115 | + if HAVE_BETA: |
| 116 | + db_acc += dy |
| 117 | + |
| 118 | + dg_acc += dy * tanh_x |
| 119 | + |
| 120 | + # Compute intermediate: tmp = (1 - tanh²) · dy · γ |
| 121 | + tmp = (1.0 - tanh_x * tanh_x) * dy * gamma |
| 122 | + |
| 123 | + # Accumulate dα = Σ(x · tmp) |
| 124 | + da_vec += x * tmp |
| 125 | + |
| 126 | + # Compute dx = α · tmp |
| 127 | + dx = alpha * tmp |
| 128 | + tl.store(DX + row_offset + col_offsets, dx, mask=col_mask) |
| 129 | + |
| 130 | + da_acc = tl.sum(da_vec, 0) |
| 131 | + da_offset = pid_m * triton.cdiv(N, BLOCK_N) + pid_n |
| 132 | + tl.store(DA + da_offset, da_acc) |
| 133 | + |
| 134 | + dg_offset = pid_m * N + col_offsets |
| 135 | + tl.store(DG + dg_offset, dg_acc, mask=col_mask) |
| 136 | + |
| 137 | + if HAVE_BETA: |
| 138 | + db_offset = pid_m * N + col_offsets |
| 139 | + tl.store(DB + db_offset, db_acc, mask=col_mask) |
| 140 | + |
| 141 | + |
| 142 | +def get_optimal_block_size(total_elements, is_backward=False): |
| 143 | + """ |
| 144 | + Calculate optimal Block Size using compute_default_tiling_strategy |
| 145 | + """ |
| 146 | + multiplier = 8.0 if is_backward else 4.0 |
| 147 | + |
| 148 | + tile_shapes = compute_default_tiling_strategy( |
| 149 | + safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((total_elements,),), tiling_dims=(0,) |
| 150 | + ) |
| 151 | + |
| 152 | + if tile_shapes and len(tile_shapes) > 0: |
| 153 | + block_size = tile_shapes[0][0] |
| 154 | + return block_size |
| 155 | + else: |
| 156 | + return 2048 |
| 157 | + |
| 158 | + |
| 159 | +def _compute_grid_size(n_cols, n_rows, block_n): |
| 160 | + """ |
| 161 | + Compute grid size to avoid launching idle programs |
| 162 | +
|
| 163 | + Args: |
| 164 | + n_cols: Number of columns |
| 165 | + n_rows: Number of rows |
| 166 | + block_n: Block size for column dimension |
| 167 | +
|
| 168 | + Returns: |
| 169 | + (num_col_blocks, num_row_programs) |
| 170 | + """ |
| 171 | + num_cores = get_npu_core_count() |
| 172 | + num_col_blocks = triton.cdiv(n_cols, block_n) |
| 173 | + num_row_blocks = n_rows |
| 174 | + |
| 175 | + num_row_programs = min(max(1, (num_cores // num_col_blocks)), num_row_blocks) |
| 176 | + |
| 177 | + return num_col_blocks, num_row_programs |
| 178 | + |
| 179 | + |
| 180 | +# ----------------------------------------------------------------------------- |
| 181 | +# Python Wrapper Functions |
| 182 | +# ----------------------------------------------------------------------------- |
| 183 | + |
| 184 | + |
| 185 | +def liger_dyt_fwd(x, alpha, gamma, beta): |
| 186 | + """ |
| 187 | + Forward pass of DYT: y = tanh(α·x) · γ + β |
| 188 | +
|
| 189 | + Args: |
| 190 | + x: Input tensor of shape [..., N] |
| 191 | + alpha: Scalar parameter |
| 192 | + gamma: Vector parameter of shape [N] |
| 193 | + beta: Vector parameter of shape [N] (optional) |
| 194 | +
|
| 195 | + Returns: |
| 196 | + y: Output tensor of same shape as x |
| 197 | + """ |
| 198 | + assert x.is_contiguous() |
| 199 | + HAVE_BETA = beta is not None |
| 200 | + |
| 201 | + # Flatten to 2D |
| 202 | + input_shape = x.shape |
| 203 | + x = x.view(-1, input_shape[-1]) |
| 204 | + M, N = x.shape |
| 205 | + |
| 206 | + # Allocate output |
| 207 | + y = torch.empty_like(x) |
| 208 | + |
| 209 | + block_n = get_optimal_block_size(N, is_backward=False) |
| 210 | + |
| 211 | + # Compute grid size |
| 212 | + num_col_blocks, num_row_programs = _compute_grid_size(N, M, block_n) |
| 213 | + grid = (num_col_blocks, num_row_programs) |
| 214 | + |
| 215 | + # Launch kernel |
| 216 | + _dyt_fwd_kernel[grid](x, y, alpha, gamma, beta, HAVE_BETA, M, N, BLOCK_N=block_n) |
| 217 | + |
| 218 | + return y.view(input_shape) |
| 219 | + |
| 220 | + |
| 221 | +def liger_dyt_bwd(dy, x, alpha, gamma, beta): |
| 222 | + """ |
| 223 | + Backward pass of DYT |
| 224 | +
|
| 225 | + Args: |
| 226 | + dy: Upstream gradient of shape [..., N] |
| 227 | + x: Input tensor of shape [..., N] |
| 228 | + alpha: Scalar parameter |
| 229 | + gamma: Vector parameter of shape [N] |
| 230 | + beta: Vector parameter of shape [N] (optional) |
| 231 | +
|
| 232 | + Returns: |
| 233 | + dx: Gradient w.r.t. x |
| 234 | + dalpha: Gradient w.r.t. alpha |
| 235 | + dgamma: Gradient w.r.t. gamma |
| 236 | + dbeta: Gradient w.r.t. beta (or None) |
| 237 | + """ |
| 238 | + assert dy.is_contiguous() |
| 239 | + HAVE_BETA = beta is not None |
| 240 | + |
| 241 | + # Flatten to 2D |
| 242 | + input_shape = x.shape |
| 243 | + x = x.view(-1, input_shape[-1]) |
| 244 | + dy = dy.view(-1, input_shape[-1]) |
| 245 | + M, N = x.shape |
| 246 | + |
| 247 | + block_n = get_optimal_block_size(N, is_backward=True) |
| 248 | + |
| 249 | + # Compute grid size |
| 250 | + num_col_blocks, num_row_programs = _compute_grid_size(N, M, block_n) |
| 251 | + grid = (num_col_blocks, num_row_programs) |
| 252 | + |
| 253 | + da = torch.zeros(num_row_programs, triton.cdiv(N, block_n), dtype=torch.float32, device=x.device) |
| 254 | + dg = torch.empty(num_row_programs, N, dtype=torch.float32, device=x.device) |
| 255 | + db = torch.empty(num_row_programs, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None |
| 256 | + dx = torch.empty_like(dy) |
| 257 | + |
| 258 | + _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, BLOCK_N=block_n) |
| 259 | + |
| 260 | + da = da.sum().to(x.dtype).unsqueeze(0) |
| 261 | + dg = dg.sum(0).to(gamma.dtype) |
| 262 | + db = db.sum(0).to(x.dtype) if HAVE_BETA else None |
| 263 | + |
| 264 | + return dx.view(input_shape), da, dg, db |
| 265 | + |
| 266 | + |
| 267 | +# ----------------------------------------------------------------------------- |
| 268 | +# Autograd Function |
| 269 | +# ----------------------------------------------------------------------------- |
| 270 | + |
| 271 | + |
| 272 | +class LigerDyTFunction(torch.autograd.Function): |
| 273 | + @staticmethod |
| 274 | + @ensure_contiguous |
| 275 | + def forward(ctx, x, alpha, gamma, beta): |
| 276 | + y = liger_dyt_fwd(x, alpha, gamma, beta) |
| 277 | + ctx.save_for_backward(x, alpha, gamma, beta) |
| 278 | + return y |
| 279 | + |
| 280 | + @staticmethod |
| 281 | + @ensure_contiguous |
| 282 | + def backward(ctx, dy): |
| 283 | + x, alpha, gamma, beta = ctx.saved_tensors |
| 284 | + dx, dalpha, dgamma, dbeta = liger_dyt_bwd(dy, x, alpha, gamma, beta) |
| 285 | + return dx, dalpha, dgamma, dbeta |
0 commit comments