|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +import torch |
| 4 | +import triton |
| 5 | +import triton.language as tl |
| 6 | + |
| 7 | +from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy |
| 8 | +from liger_kernel.ops.utils import ensure_contiguous |
| 9 | +from liger_kernel.ops.utils import get_npu_core_count |
| 10 | + |
| 11 | + |
| 12 | +@triton.jit |
| 13 | +def _jsd_kernel( |
| 14 | + X_ptr, # input in logspace, X = log Q |
| 15 | + X_stride, |
| 16 | + Y_ptr, # ground truth in logspace, Y = log P |
| 17 | + Y_stride, |
| 18 | + loss_ptr, |
| 19 | + loss_stride, |
| 20 | + dX_ptr, |
| 21 | + dX_stride, |
| 22 | + label_ptr, |
| 23 | + beta: tl.constexpr, |
| 24 | + n_non_ignore: int, |
| 25 | + ignore_index: tl.constexpr, |
| 26 | + n_rows: tl.constexpr, |
| 27 | + n_cols: tl.constexpr, |
| 28 | + BLOCK_SIZE: tl.constexpr, |
| 29 | + HAS_LABEL: tl.constexpr, |
| 30 | +): |
| 31 | + # JSD(P || Q) = (KL(P || M) + KL(Q || M)) / 2, M = (1/2) * (P + Q) = (1/2) * (e ^ Y + e ^ X) |
| 32 | + # = sum(P * log P + Q * log Q - 2 * M * log M) / 2 |
| 33 | + # = sum(e ^ Y * Y + e ^ X * X - 2 * M * log M) / 2 |
| 34 | + # grad_x_i = 0.5 * Q * (X - log_M) |
| 35 | + |
| 36 | + pid = tl.program_id(0) |
| 37 | + num_progs = tl.num_programs(0) |
| 38 | + |
| 39 | + # Grid-Stride Loop - each kernel processes multiple rows |
| 40 | + for row_idx in range(pid, n_rows, num_progs): |
| 41 | + X_row_ptr = X_ptr + row_idx * X_stride |
| 42 | + Y_row_ptr = Y_ptr + row_idx * Y_stride |
| 43 | + loss_row_ptr = loss_ptr + row_idx * loss_stride |
| 44 | + dX_row_ptr = dX_ptr + row_idx * dX_stride |
| 45 | + |
| 46 | + should_skip = False |
| 47 | + if HAS_LABEL: |
| 48 | + label = tl.load(label_ptr + row_idx) |
| 49 | + should_skip = label == ignore_index |
| 50 | + |
| 51 | + if should_skip: |
| 52 | + for i in range(0, n_cols, BLOCK_SIZE): |
| 53 | + offsets = i + tl.arange(0, BLOCK_SIZE) |
| 54 | + mask = offsets < n_cols |
| 55 | + tl.store(dX_row_ptr + offsets, 0.0, mask=mask) |
| 56 | + tl.store(loss_row_ptr + offsets, 0.0, mask=mask) |
| 57 | + else: |
| 58 | + for i in range(0, n_cols, BLOCK_SIZE): |
| 59 | + offsets = i + tl.arange(0, BLOCK_SIZE) |
| 60 | + mask = offsets < n_cols |
| 61 | + X = tl.load(X_row_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) |
| 62 | + Y = tl.load(Y_row_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32) |
| 63 | + |
| 64 | + if beta == 0.0: # forward KL |
| 65 | + Y_max = tl.max(Y, axis=0) |
| 66 | + Y_shifted = Y - Y_max |
| 67 | + Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max) # Compensate for the shift |
| 68 | + loss = Y_prob * (Y - X) |
| 69 | + dX = -Y_prob |
| 70 | + elif beta == 1.0: # reverse KL |
| 71 | + X_max = tl.max(X, axis=0) |
| 72 | + X_shifted = X - X_max |
| 73 | + X_prob = tl.exp(X_shifted) * tl.exp(X_max) # Compensate for the shift |
| 74 | + loss = X_prob * (X - Y) |
| 75 | + dX = loss + X_prob |
| 76 | + else: |
| 77 | + max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0)) |
| 78 | + X_shifted = X - max_val |
| 79 | + Y_shifted = Y - max_val |
| 80 | + |
| 81 | + # Pre-compute exp(max_val) since it's used twice |
| 82 | + exp_max = tl.exp(max_val) |
| 83 | + |
| 84 | + # Compute exp terms with compensation |
| 85 | + Q = tl.exp(X_shifted) * exp_max # = exp(X) |
| 86 | + P = tl.exp(Y_shifted) * exp_max # = exp(Y) |
| 87 | + |
| 88 | + # Pre-compute common terms |
| 89 | + beta_P = beta * P |
| 90 | + one_minus_beta_Q = (1 - beta) * Q |
| 91 | + M = beta_P + one_minus_beta_Q |
| 92 | + log_M = tl.log(M) |
| 93 | + |
| 94 | + loss = beta_P * Y + one_minus_beta_Q * X - M * log_M |
| 95 | + dX = one_minus_beta_Q * (X - log_M) |
| 96 | + |
| 97 | + # Pre-compute scaling factor |
| 98 | + scale = 1.0 / n_non_ignore |
| 99 | + loss = loss * scale |
| 100 | + dX = dX * scale |
| 101 | + |
| 102 | + tl.store(loss_row_ptr + offsets, loss, mask=mask) |
| 103 | + tl.store(dX_row_ptr + offsets, dX, mask=mask) |
| 104 | + |
| 105 | + |
| 106 | +def get_optimal_block_size(total_elements): |
| 107 | + """ |
| 108 | + Calculate optimal Block Size using compute_default_tiling_strategy |
| 109 | + """ |
| 110 | + tile_shapes = compute_default_tiling_strategy( |
| 111 | + safety_margin=0.9, dtype_size=4, memory_multiplier=8.0, shapes=((total_elements,),), tiling_dims=(0,) |
| 112 | + ) |
| 113 | + |
| 114 | + if tile_shapes and len(tile_shapes) > 0: |
| 115 | + block_size = tile_shapes[0][0] |
| 116 | + return block_size |
| 117 | + else: |
| 118 | + return 2048 |
| 119 | + |
| 120 | + |
| 121 | +def jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label): |
| 122 | + BT, V = _input.shape |
| 123 | + n_rows = BT |
| 124 | + BLOCK_SIZE = get_optimal_block_size(V) |
| 125 | + |
| 126 | + # non reduction loss |
| 127 | + loss = torch.zeros(_input.shape, dtype=torch.float32, device=_input.device) |
| 128 | + dX = torch.empty_like(_input) |
| 129 | + |
| 130 | + if has_label: |
| 131 | + n_non_ignore = (shift_labels != ignore_index).sum().item() |
| 132 | + else: |
| 133 | + n_non_ignore = BT |
| 134 | + |
| 135 | + # Use NPU core count for grid size |
| 136 | + num_cores = get_npu_core_count() |
| 137 | + grid_size = min(num_cores, n_rows) |
| 138 | + |
| 139 | + _jsd_kernel[(grid_size,)]( |
| 140 | + X_ptr=_input, |
| 141 | + X_stride=_input.stride(-2), |
| 142 | + Y_ptr=target, |
| 143 | + Y_stride=target.stride(-2), |
| 144 | + loss_ptr=loss, |
| 145 | + loss_stride=loss.stride(-2), |
| 146 | + dX_ptr=dX, |
| 147 | + dX_stride=dX.stride(-2), |
| 148 | + label_ptr=(shift_labels if has_label else torch.empty(1, device=_input.device)), |
| 149 | + beta=beta, |
| 150 | + n_non_ignore=n_non_ignore, |
| 151 | + ignore_index=ignore_index, |
| 152 | + n_rows=n_rows, |
| 153 | + n_cols=V, |
| 154 | + BLOCK_SIZE=BLOCK_SIZE, |
| 155 | + HAS_LABEL=has_label, |
| 156 | + ) |
| 157 | + |
| 158 | + loss = torch.sum(loss) |
| 159 | + return loss.to(_input.dtype), dX |
| 160 | + |
| 161 | + |
| 162 | +def jsd_backward(dX, grad_output): |
| 163 | + # If jsd is the last layer, grad_output is 1.0. Skip the mul to save time |
| 164 | + if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): |
| 165 | + return dX |
| 166 | + else: |
| 167 | + return grad_output * dX |
| 168 | + |
| 169 | + |
| 170 | +class LigerJSDFunction(torch.autograd.Function): |
| 171 | + r""" |
| 172 | + This class implements the forward and backward pass for the generalized Jensen-Shannon Divergence. |
| 173 | + .. math:: |
| 174 | + JSD(\beta)(P || Q) |
| 175 | + = \beta * KLDiv(P || (\beta * P + (1 - \beta) * Q)) + (1 - \beta) * KLDiv(Q || (\beta * P + (1 - \beta) * Q)) |
| 176 | +
|
| 177 | + .. note:: |
| 178 | + As all the other losses in PyTorch, this function expects the first argument, |
| 179 | + :attr:`_input`, to be the predictions, the output of the student model, in log-space |
| 180 | + and the second, :attr:`target`, to be the observations, the output of the teacher model, in log-space. |
| 181 | + This differs from the standard mathematical notation :math:`JSD(P || Q)` where |
| 182 | + :math:`P` denotes the teacher model and :math:`Q` denotes the student model. |
| 183 | + """ |
| 184 | + |
| 185 | + @staticmethod |
| 186 | + @ensure_contiguous |
| 187 | + def forward( |
| 188 | + ctx, |
| 189 | + _input: torch.Tensor, |
| 190 | + target: torch.Tensor, |
| 191 | + shift_labels: Optional[torch.Tensor] = None, |
| 192 | + beta: float = 0.5, |
| 193 | + ignore_index: int = -100, |
| 194 | + ) -> torch.Tensor: |
| 195 | + """ |
| 196 | + Args: |
| 197 | + _input (torch.Tensor): predict values with shape (BT, V) in logspace |
| 198 | + target (torch.Tensor): ground truth values with shape (BT, V) in logspace |
| 199 | + shift_labels (Optional[torch.LongTensor]): indicator of next predicted vocab with shape (BT) where each value is in [0, V-1]. |
| 200 | + beta (float): coefficient beta of generalized JSD in the interval [0, 1]. It implements forward/reverse KL when beta equals 0 and 1 respectively. Default: `0.5` |
| 201 | + ignore_index (int): the index to ignore. Default: -100 |
| 202 | +
|
| 203 | + Returns: |
| 204 | + loss (torch.Tensor): generalized JSD |
| 205 | + """ |
| 206 | + has_label = False |
| 207 | + if shift_labels is not None: |
| 208 | + assert shift_labels.shape == (_input.shape[0],), ( |
| 209 | + f"the shape of shift_labels must be (BT,). Got: {shift_labels.shape}" |
| 210 | + ) |
| 211 | + shift_labels = shift_labels.contiguous() |
| 212 | + has_label = True |
| 213 | + |
| 214 | + loss, dX = jsd_forward(_input, target, shift_labels, beta, ignore_index, has_label) |
| 215 | + ctx.save_for_backward(dX) |
| 216 | + return loss |
| 217 | + |
| 218 | + @staticmethod |
| 219 | + @ensure_contiguous |
| 220 | + def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: |
| 221 | + (dX,) = ctx.saved_tensors |
| 222 | + dX = jsd_backward(dX, grad_output) |
| 223 | + return ( |
| 224 | + dX, |
| 225 | + None, |
| 226 | + None, |
| 227 | + None, |
| 228 | + None, |
| 229 | + ) |
0 commit comments