Skip to content

Commit 6e1ec4a

Browse files
authored
[NPU]:Added support for the jsd operator (#1134)
## Summary

- Grid-stride loop optimization: efficient multi-row processing with automatic grid-size tuning.
- Memory access optimization: column-blocked processing with a configurable BLOCK_N, and dynamic block-size selection based on tensor width.
- The grid size should not exceed the number of NPU cores, so that the advantages of the NPU are fully leveraged.

## Testing Done

<img width="1080" height="602" alt="image" src="https://github.com/user-attachments/assets/58628ed3-5aa5-4c60-903e-6ee75bff4b89" />

- Hardware Type: Atlas 800I A2
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent c3dc732 commit 6e1ec4a

2 files changed

Lines changed: 235 additions & 0 deletions

File tree

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
2727
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
2828
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
29+
from liger_kernel.ops.backends._ascend.ops.jsd import LigerJSDFunction
30+
from liger_kernel.ops.backends._ascend.ops.jsd import jsd_backward
31+
from liger_kernel.ops.backends._ascend.ops.jsd import jsd_forward
2932
from liger_kernel.ops.backends._ascend.ops.kl_div import LigerKLDivLossFunction
3033
from liger_kernel.ops.backends._ascend.ops.kl_div import kldiv_backward_triton
3134
from liger_kernel.ops.backends._ascend.ops.kl_div import kldiv_forward_triton
@@ -94,4 +97,7 @@
9497
"LigerSoftmaxFunction",
9598
"softmax_forward",
9699
"softmax_backward",
100+
"LigerJSDFunction",
101+
"jsd_forward",
102+
"jsd_backward",
97103
]
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
from typing import Optional
2+
3+
import torch
4+
import triton
5+
import triton.language as tl
6+
7+
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
8+
from liger_kernel.ops.utils import ensure_contiguous
9+
from liger_kernel.ops.utils import get_npu_core_count
10+
11+
12+
@triton.jit
def _jsd_kernel(
    X_ptr,  # input in logspace, X = log Q
    X_stride,
    Y_ptr,  # ground truth in logspace, Y = log P
    Y_stride,
    loss_ptr,
    loss_stride,
    dX_ptr,
    dX_stride,
    label_ptr,
    beta: tl.constexpr,
    n_non_ignore: int,
    ignore_index: tl.constexpr,
    n_rows: tl.constexpr,
    n_cols: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    HAS_LABEL: tl.constexpr,
):
    # Element-wise generalized JSD loss and its gradient w.r.t. X, written to
    # loss_ptr and dX_ptr respectively. Both outputs are divided by n_non_ignore.
    #
    # For beta = 0.5:
    #   JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X)
    #               = sum(P * log P + Q * log Q - 2 * M * log M) / 2
    #               = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2
    #   grad_x_i = 0.5 * Q * (X - log_M)
    # beta == 0.0 and beta == 1.0 are handled by dedicated branches below.

    program_id = tl.program_id(0)
    program_count = tl.num_programs(0)

    # Program `program_id` handles rows program_id, program_id + program_count, ...
    for row in range(program_id, n_rows, program_count):
        x_base = X_ptr + row * X_stride
        y_base = Y_ptr + row * Y_stride
        loss_base = loss_ptr + row * loss_stride
        grad_base = dX_ptr + row * dX_stride

        # A row is skipped only when labels are supplied and its label equals ignore_index.
        skip_row = False
        if HAS_LABEL:
            skip_row = tl.load(label_ptr + row) == ignore_index

        if skip_row:
            # Ignored rows contribute zero loss and zero gradient.
            for col_start in range(0, n_cols, BLOCK_SIZE):
                cols = col_start + tl.arange(0, BLOCK_SIZE)
                in_bounds = cols < n_cols
                tl.store(grad_base + cols, 0.0, mask=in_bounds)
                tl.store(loss_base + cols, 0.0, mask=in_bounds)
        else:
            # Walk the row in tiles of BLOCK_SIZE columns.
            for col_start in range(0, n_cols, BLOCK_SIZE):
                cols = col_start + tl.arange(0, BLOCK_SIZE)
                in_bounds = cols < n_cols
                # Out-of-range lanes are loaded as -inf; they are never stored back.
                log_q = tl.load(x_base + cols, mask=in_bounds, other=float("-inf")).to(tl.float32)
                log_p = tl.load(y_base + cols, mask=in_bounds, other=float("-inf")).to(tl.float32)

                if beta == 0.0:  # forward KL
                    p_peak = tl.max(log_p, axis=0)
                    # exp of the shifted value, multiplied back by exp(peak)
                    p = tl.exp(log_p - p_peak) * tl.exp(p_peak)
                    tile_loss = p * (log_p - log_q)
                    tile_grad = -p
                elif beta == 1.0:  # reverse KL
                    q_peak = tl.max(log_q, axis=0)
                    # exp of the shifted value, multiplied back by exp(peak)
                    q = tl.exp(log_q - q_peak) * tl.exp(q_peak)
                    tile_loss = q * (log_q - log_p)
                    tile_grad = tile_loss + q
                else:
                    # Shared shift: the larger of the two tile maxima.
                    peak = tl.maximum(tl.max(log_q, axis=0), tl.max(log_p, axis=0))
                    peak_exp = tl.exp(peak)  # reused for both Q and P

                    q = tl.exp(log_q - peak) * peak_exp  # = exp(X)
                    p = tl.exp(log_p - peak) * peak_exp  # = exp(Y)

                    # Mixture M = beta * P + (1 - beta) * Q
                    weighted_p = beta * p
                    weighted_q = (1 - beta) * q
                    mixture = weighted_p + weighted_q
                    log_mixture = tl.log(mixture)

                    tile_loss = weighted_p * log_p + weighted_q * log_q - mixture * log_mixture
                    tile_grad = weighted_q * (log_q - log_mixture)

                # Normalize by the number of non-ignored rows.
                inv_count = 1.0 / n_non_ignore
                tile_loss = tile_loss * inv_count
                tile_grad = tile_grad * inv_count

                tl.store(loss_base + cols, tile_loss, mask=in_bounds)
                tl.store(grad_base + cols, tile_grad, mask=in_bounds)
104+
105+
106+
def get_optimal_block_size(total_elements):
    """
    Choose the kernel block size via compute_default_tiling_strategy.

    Args:
        total_elements: length of the single dimension to tile.

    Returns:
        The first entry of the first tile shape returned by
        compute_default_tiling_strategy, or 2048 when it returns nothing.
    """
    tiling = compute_default_tiling_strategy(
        safety_margin=0.9,
        dtype_size=4,
        memory_multiplier=8.0,
        shapes=((total_elements,),),
        tiling_dims=(0,),
    )

    # Fall back to a fixed block size when no tiling could be computed.
    if not tiling:
        return 2048
    return tiling[0][0]
119+
120+
121+
def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label):
    """
    Launch the JSD kernel and reduce the element-wise loss to a scalar.

    Args:
        _input: 2D tensor (unpacked as BT x V); passed to the kernel as X (log Q).
        target: tensor passed to the kernel as Y (log P).
        shift_labels: per-row labels; only read when ``has_label`` is true.
        beta: JSD coefficient forwarded to the kernel.
        ignore_index: label value whose rows the kernel zeroes out.
        has_label: whether ``shift_labels`` should be used.

    Returns:
        Tuple ``(loss, dX)``: the summed loss cast to ``_input.dtype`` and the
        gradient buffer written by the kernel (same shape/dtype as ``_input``).
    """
    n_rows, vocab_size = _input.shape
    block_size = get_optimal_block_size(vocab_size)

    # Unreduced loss is accumulated in float32; dX is fully written by the kernel.
    loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device)
    dX = torch.empty_like(_input)

    if has_label:
        n_non_ignore = (shift_labels != ignore_index).sum().item()
        label_arg = shift_labels
    else:
        n_non_ignore = n_rows
        # Placeholder pointer; the kernel does not read it when HAS_LABEL is false.
        label_arg = torch.empty(1, device=_input.device)

    # One program per row, capped at the NPU core count.
    grid = (min(get_npu_core_count(), n_rows),)

    _jsd_kernel[grid](
        X_ptr=_input,
        X_stride=_input.stride(-2),
        Y_ptr=target,
        Y_stride=target.stride(-2),
        loss_ptr=loss,
        loss_stride=loss.stride(-2),
        dX_ptr=dX,
        dX_stride=dX.stride(-2),
        label_ptr=label_arg,
        beta=beta,
        n_non_ignore=n_non_ignore,
        ignore_index=ignore_index,
        n_rows=n_rows,
        n_cols=vocab_size,
        BLOCK_SIZE=block_size,
        HAS_LABEL=has_label,
    )

    total_loss = torch.sum(loss)
    return total_loss.to(_input.dtype), dX
160+
161+
162+
def jsd_backward(dX, grad_output):
    """
    Scale the stored gradient by the upstream gradient.

    Args:
        dX: gradient computed during the forward pass.
        grad_output: upstream gradient.

    Returns:
        ``dX`` itself when ``grad_output`` equals the scalar 1.0 (the case when
        JSD is the last layer, so the multiply is skipped), otherwise
        ``grad_output * dX``.
    """
    unit_grad = torch.tensor(1.0, device=grad_output.device)
    if not torch.equal(grad_output, unit_grad):
        return grad_output * dX
    return dX
168+
169+
170+
class LigerJSDFunction(torch.autograd.Function):
    r"""
    Autograd function for the generalized Jensen-Shannon Divergence.

    .. math::
        JSD(\beta)(P || Q)
            = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q))

    .. note::
        Following the convention of the other PyTorch losses, the first argument,
        :attr:`_input`, holds the predictions (student model output) in log-space
        and the second, :attr:`target`, holds the observations (teacher model output) in log-space.
        This is the reverse of the usual mathematical notation :math:`JSD(P || Q)`, in which
        :math:`P` is the teacher model and :math:`Q` is the student model.
    """

    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        _input: torch.Tensor,
        target: torch.Tensor,
        shift_labels: Optional[torch.Tensor] = None,
        beta: float = 0.5,
        ignore_index: int = -100,
    ) -> torch.Tensor:
        """
        Compute the loss and stash the input gradient for the backward pass.

        Args:
            _input (torch.Tensor): predict values with shape (BT, V) in logspace
            target (torch.Tensor): ground truth values with shape (BT, V) in logspace
            shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1].
            beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5`
            ignore_index (int): the index to ignore. Default: -100

        Returns:
            loss (torch.Tensor): generalized JSD
        """
        has_label = shift_labels is not None
        if has_label:
            assert shift_labels.shape == (_input.shape[0],), (
                f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}"
            )
            shift_labels = shift_labels.contiguous()

        loss, grad_input = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label)
        # The gradient is produced together with the loss; keep it for backward.
        ctx.save_for_backward(grad_input)
        return loss

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """
        Return the gradient for ``_input``; the other four forward arguments get None.
        """
        (saved_grad,) = ctx.saved_tensors
        grad_input = jsd_backward(saved_grad, grad_output)
        return (grad_input, None, None, None, None)

0 commit comments

Comments
 (0)