Skip to content

Commit c3dc732

Browse files
authored
[NPU]:Added support for the dyt operator (#1124)
## Summary
<!--- This is a required section; please describe the main purpose of this proposed code change. --->
- Grid-stride loop optimization: efficient multi-row processing with automatic grid-size tuning.
- Memory access optimization: column-blocked processing with a configurable BLOCK_N, and dynamic block-size selection (1024-2048) based on tensor width.

<!--- ## Details
This is an optional section; is there anything specific that reviewers should be aware of? --->

## Testing Done
<!--- This is a required section; please describe how this change was tested. --->
<img width="1135" height="610" alt="image" src="https://github.com/user-attachments/assets/7191daae-e912-41dc-98f2-1f130c3ec86e" />

<!-- Replace BLANK with your device type. For example, A100-80G-PCIe
Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. -->
- Hardware Type: Atlas 800I A2
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 8e50f2a commit c3dc732

2 files changed

Lines changed: 291 additions & 0 deletions

File tree

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414
If __all__ is not defined, all public symbols will be auto-discovered.
1515
"""
1616

17+
from liger_kernel.ops.backends._ascend.ops.embedding import LigerDyTFunction
1718
from liger_kernel.ops.backends._ascend.ops.embedding import LigerEmbeddingFunction
1819
from liger_kernel.ops.backends._ascend.ops.embedding import embedding_backward
1920
from liger_kernel.ops.backends._ascend.ops.embedding import embedding_forward
21+
from liger_kernel.ops.backends._ascend.ops.embedding import liger_dyt_bwd
22+
from liger_kernel.ops.backends._ascend.ops.embedding import liger_dyt_fwd
2023
from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
2124
from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import fused_add_rms_norm_backward
2225
from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import fused_add_rms_norm_forward
@@ -79,6 +82,9 @@
7982
"LigerLlama4RopeFunction",
8083
"llama4_rope_forward",
8184
"llama4_rope_backward",
85+
"LigerDyTFunction",
86+
"liger_dyt_fwd",
87+
"liger_dyt_bwd",
8288
"LigerKLDivLossFunction",
8389
"kldiv_forward_triton",
8490
"kldiv_backward_triton",
Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
5+
from triton.language.math import tanh
6+
7+
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
8+
from liger_kernel.ops.utils import ensure_contiguous
9+
from liger_kernel.ops.utils import get_npu_core_count
10+
11+
# -----------------------------------------------------------------------------
12+
# Forward Kernel
13+
# -----------------------------------------------------------------------------
14+
15+
16+
@triton.jit
def _dyt_fwd_kernel(
    X,
    Y,
    Alpha,
    Gamma,
    Beta,
    HAVE_BETA: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    DyT forward kernel: Y = tanh(Alpha * X) * Gamma, plus Beta when HAVE_BETA.

    Launch grid: (column blocks, row programs).
      - axis 0 picks a tile of BLOCK_N columns;
      - axis 1 picks the first row; the program then advances by the number
        of programs on axis 1 until it passes row M - 1.

    X and Y are addressed as M rows of N elements each (row r starts at
    element r * N). Inputs are converted to float32 before the arithmetic.
    Beta is only dereferenced when HAVE_BETA is true.
    """
    col_block = tl.program_id(0)
    first_row = tl.program_id(1)
    row_step = tl.num_programs(1)

    # Column tile owned by this program; the mask guards the ragged last tile.
    cols = col_block * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = cols < N

    # Row-invariant operands are loaded once, outside the row loop.
    scale = tl.load(Alpha).to(tl.float32)
    weight = tl.load(Gamma + cols, mask=in_bounds, other=0.0).to(tl.float32)
    if HAVE_BETA:
        bias = tl.load(Beta + cols, mask=in_bounds, other=0.0).to(tl.float32)

    for row in range(first_row, M, row_step):
        row_base = row * N

        x_tile = tl.load(X + row_base + cols, mask=in_bounds, other=0.0).to(tl.float32)

        # y = tanh(alpha * x) * gamma (+ beta)
        squashed = tanh(scale * x_tile)
        y_tile = squashed * weight

        if HAVE_BETA:
            y_tile += bias

        tl.store(Y + row_base + cols, y_tile, mask=in_bounds)
61+
62+
63+
# -----------------------------------------------------------------------------
64+
# Backward Kernel
65+
# -----------------------------------------------------------------------------
66+
67+
68+
@triton.jit
def _dyt_bwd_kernel(
    DY,
    DX,
    DA,
    DG,
    DB,
    X,
    Alpha,
    Gamma,
    HAVE_BETA: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    DyT backward kernel for y = tanh(alpha * x) * gamma (+ beta).

    Launch grid: (column blocks, row programs). Each program owns one tile of
    BLOCK_N columns and visits rows first_row, first_row + row_step, ... < M,
    where row_step is the number of programs on grid axis 1.

    Per element, with t = tanh(alpha * x):
      dx     = alpha * (1 - t^2) * dy * gamma     (written to DX per row)
      dgamma += dy * t                            (accumulated over rows)
      dbeta  += dy                                (only when HAVE_BETA)
      dalpha += x * (1 - t^2) * dy * gamma        (accumulated, then summed
                                                   over the column tile)

    The kernel only emits per-program partial results:
      DA[row_program * num_col_blocks + col_block]  one scalar per program
      DG[row_program * N + col]                     one row per row program
      DB[row_program * N + col]                     same layout, if HAVE_BETA
    The caller is responsible for reducing them across row programs.
    """
    col_block = tl.program_id(0)
    row_program = tl.program_id(1)
    row_step = tl.num_programs(1)

    # Column tile owned by this program; the mask guards the ragged last tile.
    cols = col_block * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = cols < N

    scale = tl.load(Alpha).to(tl.float32)
    weight = tl.load(Gamma + cols, mask=in_bounds, other=0.0).to(tl.float32)

    # float32 accumulators that live across the whole row loop.
    dalpha_partial = tl.zeros((BLOCK_N,), dtype=tl.float32)
    dgamma_partial = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAVE_BETA:
        dbeta_partial = tl.zeros((BLOCK_N,), dtype=tl.float32)

    for row in range(row_program, M, row_step):
        row_base = row * N

        x_tile = tl.load(X + row_base + cols, mask=in_bounds, other=0.0).to(tl.float32)
        dy_tile = tl.load(DY + row_base + cols, mask=in_bounds, other=0.0).to(tl.float32)

        squashed = tanh(scale * x_tile)

        if HAVE_BETA:
            dbeta_partial += dy_tile

        dgamma_partial += dy_tile * squashed

        # Shared factor (1 - tanh^2) * dy * gamma used by both dalpha and dx.
        common = (1.0 - squashed * squashed) * dy_tile * weight

        dalpha_partial += x_tile * common

        dx_tile = scale * common
        tl.store(DX + row_base + cols, dx_tile, mask=in_bounds)

    # One dalpha scalar per (row program, column block).
    dalpha_scalar = tl.sum(dalpha_partial, 0)
    dalpha_slot = row_program * triton.cdiv(N, BLOCK_N) + col_block
    tl.store(DA + dalpha_slot, dalpha_scalar)

    # One dgamma row per row program.
    dgamma_slots = row_program * N + cols
    tl.store(DG + dgamma_slots, dgamma_partial, mask=in_bounds)

    if HAVE_BETA:
        dbeta_slots = row_program * N + cols
        tl.store(DB + dbeta_slots, dbeta_partial, mask=in_bounds)
140+
141+
142+
def get_optimal_block_size(total_elements, is_backward=False):
    """
    Pick the column block size by delegating to compute_default_tiling_strategy.

    Args:
        total_elements: Extent of the single dimension being tiled.
        is_backward: Selects the memory multiplier passed to the tiling
            helper: 8.0 for the backward pass, 4.0 for the forward pass.

    Returns:
        The first entry of the first tile shape reported by the tiling helper,
        or 2048 when the helper returns nothing usable.
    """
    memory_multiplier = 8.0 if is_backward else 4.0

    # NOTE(review): the exact meaning of safety_margin / dtype_size is defined
    # by compute_default_tiling_strategy; dtype_size=4 presumably reflects the
    # float32 working precision of the kernels - confirm against ub_manager.
    tiling = compute_default_tiling_strategy(
        safety_margin=0.9,
        dtype_size=4,
        memory_multiplier=memory_multiplier,
        shapes=((total_elements,),),
        tiling_dims=(0,),
    )

    if tiling and len(tiling) > 0:
        return tiling[0][0]

    # Fallback when no tiling strategy is available.
    return 2048
157+
158+
159+
def _compute_grid_size(n_cols, n_rows, block_n):
    """
    Derive the 2D launch grid for the DyT kernels.

    The column axis gets one program per block of ``block_n`` columns. The row
    axis gets the core count divided by the number of column blocks, clamped
    to at least 1 and at most ``n_rows``, so no row program is launched
    without a row to process.

    Args:
        n_cols: Number of columns.
        n_rows: Number of rows.
        block_n: Block size for the column dimension.

    Returns:
        (num_col_blocks, num_row_programs)
    """
    core_count = get_npu_core_count()
    col_blocks = triton.cdiv(n_cols, block_n)

    # Share the available cores between the column blocks.
    cores_per_col_block = core_count // col_blocks
    row_programs = min(max(1, cores_per_col_block), n_rows)

    return col_blocks, row_programs
178+
179+
180+
# -----------------------------------------------------------------------------
181+
# Python Wrapper Functions
182+
# -----------------------------------------------------------------------------
183+
184+
185+
def liger_dyt_fwd(x, alpha, gamma, beta):
    """
    Forward pass of DyT: y = tanh(alpha * x) * gamma + beta.

    Args:
        x: Contiguous input tensor of shape [..., N].
        alpha: Tensor whose first element is read as the scalar alpha.
        gamma: Vector parameter of shape [N].
        beta: Vector parameter of shape [N], or None to skip the bias term.

    Returns:
        y: Output tensor with the same shape and dtype as x.

    Raises:
        AssertionError: If x is not contiguous.
    """
    assert x.is_contiguous()
    HAVE_BETA = beta is not None

    # Empty input: there is nothing to compute. Returning here avoids
    # `view(-1, 0)` (which cannot infer the row count) and avoids launching a
    # kernel over a grid with a zero-sized axis.
    if x.numel() == 0:
        return torch.empty_like(x)

    # Flatten every leading dimension into a single row axis.
    input_shape = x.shape
    x = x.view(-1, input_shape[-1])
    M, N = x.shape

    # Allocate output
    y = torch.empty_like(x)

    block_n = get_optimal_block_size(N, is_backward=False)

    # Grid is (column blocks, row programs); each row program strides over rows.
    num_col_blocks, num_row_programs = _compute_grid_size(N, M, block_n)
    grid = (num_col_blocks, num_row_programs)

    # Launch kernel
    _dyt_fwd_kernel[grid](x, y, alpha, gamma, beta, HAVE_BETA, M, N, BLOCK_N=block_n)

    return y.view(input_shape)
219+
220+
221+
def liger_dyt_bwd(dy, x, alpha, gamma, beta):
    """
    Backward pass of DyT for y = tanh(alpha * x) * gamma + beta.

    The kernel writes float32 partial results per row program; they are
    reduced here and each parameter gradient is cast to the dtype of the
    parameter it belongs to.

    Args:
        dy: Contiguous upstream gradient of shape [..., N].
        x: Input tensor of shape [..., N] (must be viewable as [-1, N]).
        alpha: Tensor whose first element is read as the scalar alpha.
        gamma: Vector parameter of shape [N].
        beta: Vector parameter of shape [N], or None when there is no bias.

    Returns:
        dx: Gradient w.r.t. x, same shape as x and same dtype as dy.
        dalpha: Gradient w.r.t. alpha, shape [1], dtype of alpha.
        dgamma: Gradient w.r.t. gamma, shape [N], dtype of gamma.
        dbeta: Gradient w.r.t. beta, shape [N], dtype of beta (None if beta is None).

    Raises:
        AssertionError: If dy is not contiguous.
    """
    assert dy.is_contiguous()
    HAVE_BETA = beta is not None

    input_shape = x.shape

    # Empty input: every gradient sum is over zero terms. Returning here avoids
    # `view(-1, 0)` (which cannot infer the row count) and avoids launching a
    # kernel over a grid with a zero-sized axis.
    if x.numel() == 0:
        n_cols = input_shape[-1]
        dx = torch.empty_like(dy).view(input_shape)
        da = torch.zeros(1, dtype=alpha.dtype, device=x.device)
        dg = torch.zeros(n_cols, dtype=gamma.dtype, device=x.device)
        db = torch.zeros(n_cols, dtype=beta.dtype, device=x.device) if HAVE_BETA else None
        return dx, da, dg, db

    # Flatten every leading dimension into a single row axis.
    x = x.view(-1, input_shape[-1])
    dy = dy.view(-1, input_shape[-1])
    M, N = x.shape

    block_n = get_optimal_block_size(N, is_backward=True)

    # Grid is (column blocks, row programs); each row program strides over rows.
    num_col_blocks, num_row_programs = _compute_grid_size(N, M, block_n)
    grid = (num_col_blocks, num_row_programs)

    # Per-program partial buffers (float32), reduced on the host side below:
    #   da: one scalar per (row program, column block)
    #   dg / db: one length-N row per row program
    da = torch.zeros(num_row_programs, triton.cdiv(N, block_n), dtype=torch.float32, device=x.device)
    dg = torch.empty(num_row_programs, N, dtype=torch.float32, device=x.device)
    db = torch.empty(num_row_programs, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
    dx = torch.empty_like(dy)

    _dyt_bwd_kernel[grid](dy, dx, da, dg, db, x, alpha, gamma, HAVE_BETA, M, N, BLOCK_N=block_n)

    # Cast each parameter gradient to its own parameter's dtype, so float32
    # parameters keep the float32 accumulation when activations are
    # half-precision instead of being rounded through x.dtype.
    da = da.sum().to(alpha.dtype).unsqueeze(0)
    dg = dg.sum(0).to(gamma.dtype)
    db = db.sum(0).to(beta.dtype) if HAVE_BETA else None

    return dx.view(input_shape), da, dg, db
265+
266+
267+
# -----------------------------------------------------------------------------
268+
# Autograd Function
269+
# -----------------------------------------------------------------------------
270+
271+
272+
class LigerDyTFunction(torch.autograd.Function):
    """
    Autograd binding for DyT: y = tanh(alpha * x) * gamma + beta.

    forward delegates to liger_dyt_fwd and stashes its inputs; backward
    delegates to liger_dyt_bwd and returns one gradient per forward input
    (x, alpha, gamma, beta). Both entry points are wrapped with
    ensure_contiguous.
    """

    @staticmethod
    @ensure_contiguous
    def forward(ctx, x, alpha, gamma, beta):
        """Compute the DyT output and save the inputs needed by backward."""
        output = liger_dyt_fwd(x, alpha, gamma, beta)
        ctx.save_for_backward(x, alpha, gamma, beta)
        return output

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dy):
        """Return (dx, dalpha, dgamma, dbeta) for the upstream gradient dy."""
        saved_x, saved_alpha, saved_gamma, saved_beta = ctx.saved_tensors
        grad_x, grad_alpha, grad_gamma, grad_beta = liger_dyt_bwd(
            dy, saved_x, saved_alpha, saved_gamma, saved_beta
        )
        return grad_x, grad_alpha, grad_gamma, grad_beta

0 commit comments

Comments
 (0)