
Commit f0513d7

TimDettmers and claude committed
feat: Add CUDA cross-entropy loss forward+backward kernel
Cross-entropy loss with numerically stable logsumexp:
- Forward: per-sample loss via max-stabilized logsumexp
- Backward: softmax(logits) - one_hot(labels), scaled by grad_output
- One block per row, shared memory reduction for max and sum
- Supports ignore_index (-100 convention)
- Supports large vocabularies (tested up to 100K)
- fp16 and bf16 via C++ templates

Autograd wrapper computes mean loss with proper ignore_index handling and per-sample gradient scaling.

10 tests pass covering forward correctness, gradient correctness, ignore_index handling, large vocab, and both dtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 0cc1e5b commit f0513d7

File tree: 6 files changed, +470 -2 lines changed
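
For reference, the math the new kernel implements, written as a plain PyTorch sketch (fp32 accumulation; this is a reference implementation, not the CUDA path added in this commit):

    import torch

    def reference_cross_entropy(logits, labels, ignore_index=-100):
        # Forward: loss_i = logsumexp(logits[i]) - logits[i, y_i], with a max shift for stability.
        x = logits.float()
        m = x.amax(dim=-1, keepdim=True)
        lse = (x - m).exp().sum(dim=-1).log() + m.squeeze(-1)
        valid = labels != ignore_index
        picked = x.gather(-1, labels.clamp_min(0).unsqueeze(-1)).squeeze(-1)
        losses = torch.where(valid, lse - picked, torch.zeros_like(lse))
        # Backward (per row): d loss_i / d logits[i] = softmax(logits[i]) - one_hot(y_i);
        # autograd derives that automatically for this reference.
        return losses.sum() / valid.sum().clamp_min(1)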

bitsandbytes/_ops.py

Lines changed: 34 additions & 0 deletions
@@ -782,3 +782,37 @@ def _(
@register_fake("bitsandbytes::rope_forward")
def _(q: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, n_heads: int) -> None:
    pass


# Cross-Entropy Loss forward: per-sample loss + logsumexp for backward
torch.library.define(
    "bitsandbytes::cross_entropy_forward",
    "(Tensor logits, Tensor labels, int ignore_index) -> (Tensor, Tensor)",
)


@register_fake("bitsandbytes::cross_entropy_forward")
def _(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check(logits.dim() == 2, lambda: "logits must be 2D [N, V]")
    N = logits.shape[0]
    losses = torch.empty(N, device=logits.device, dtype=torch.float32)
    logsumexp = torch.empty(N, device=logits.device, dtype=torch.float32)
    return losses, logsumexp


# Cross-Entropy Loss backward: grad_logits from grad_output
torch.library.define(
    "bitsandbytes::cross_entropy_backward",
    "(Tensor logits, Tensor labels, Tensor grad_output, Tensor logsumexp, int ignore_index) -> Tensor",
)


@register_fake("bitsandbytes::cross_entropy_backward")
def _(
    logits: torch.Tensor,
    labels: torch.Tensor,
    grad_output: torch.Tensor,
    logsumexp: torch.Tensor,
    ignore_index: int,
) -> torch.Tensor:
    return torch.empty_like(logits)
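
The fake (meta) implementations above only propagate shapes and dtypes, which is what FakeTensor tracing and torch.compile need. A minimal sketch of exercising that path without a GPU, assuming that importing bitsandbytes runs these registrations and that the fake impl also serves meta tensors (as in recent PyTorch):

    import torch
    import bitsandbytes  # noqa: F401  (runs the op registrations)

    logits = torch.empty(8, 32000, device="meta", dtype=torch.float16)
    labels = torch.empty(8, device="meta", dtype=torch.int64)
    losses, lse = torch.ops.bitsandbytes.cross_entropy_forward(logits, labels, -100)
    assert losses.shape == (8,) and losses.dtype == torch.float32
    assert lse.shape == (8,) and lse.dtype == torch.float32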

bitsandbytes/autograd/training_kernels.py

Lines changed: 73 additions & 0 deletions
@@ -145,3 +145,76 @@ def rope(
        Rotated query tensor (same shape).
    """
    return RoPEFunction.apply(q, cos_cache, sin_cache, n_heads)


class CrossEntropyFunction(torch.autograd.Function):
    """Cross-entropy loss using CUDA kernel.

    Forward: loss = -log_softmax(logits)[label] per row
    Backward: grad_logits = (softmax(logits) - one_hot(label)) * grad_output

    Stores logsumexp from forward for efficient backward (avoids recomputing).
    """

    @staticmethod
    def forward(ctx, logits, labels, ignore_index=-100):
        # Flatten to 2D for the CUDA kernel
        orig_shape = logits.shape
        logits_2d = logits.reshape(-1, logits.shape[-1]).contiguous()
        labels_flat = labels.reshape(-1)

        losses, logsumexp = torch.ops.bitsandbytes.cross_entropy_forward(
            logits_2d, labels_flat, ignore_index,
        )

        ctx.save_for_backward(logits_2d, labels_flat, logsumexp)
        ctx.ignore_index = ignore_index
        ctx.orig_shape = orig_shape  # needed so backward can return a gradient matching the input shape

        # Compute mean loss (ignoring padding)
        valid_mask = labels_flat != ignore_index
        n_valid = valid_mask.sum()
        if n_valid > 0:
            mean_loss = losses[valid_mask].sum() / n_valid.float()
        else:
            mean_loss = losses.sum() * 0.0  # zero but with grad

        return mean_loss

    @staticmethod
    def backward(ctx, grad_output):
        logits_2d, labels_flat, logsumexp = ctx.saved_tensors

        # Expand scalar grad_output to a per-sample scale (1/n_valid for the mean reduction)
        N = logits_2d.shape[0]
        valid_mask = labels_flat != ctx.ignore_index
        n_valid = valid_mask.sum()

        grad_per_sample = torch.zeros(N, device=logits_2d.device, dtype=torch.float32)
        if n_valid > 0:
            grad_per_sample[valid_mask] = grad_output.float() / n_valid.float()

        grad_logits = torch.ops.bitsandbytes.cross_entropy_backward(
            logits_2d, labels_flat, grad_per_sample, logsumexp, ctx.ignore_index,
        )

        # Reshape back so the gradient matches the caller's logits shape (e.g. [batch, seq, vocab])
        return grad_logits.reshape(ctx.orig_shape), None, None


def cross_entropy(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Cross-entropy loss with CUDA kernel (autograd support).

    Uses a numerically stable logsumexp-based implementation.

    Args:
        logits: Logit tensor (*, vocab_size), fp16 or bf16.
        labels: Label tensor (*), int64.
        ignore_index: Label value to ignore (default: -100).

    Returns:
        Scalar mean loss.
    """
    return CrossEntropyFunction.apply(logits, labels, ignore_index)
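
A usage sketch for the wrapper above (assumes a CUDA build of bitsandbytes that ships these kernels):

    import torch
    from bitsandbytes.autograd.training_kernels import cross_entropy

    V = 32000
    logits = torch.randn(4, 128, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    labels = torch.randint(0, V, (4, 128), device="cuda")
    labels[:, :8] = -100  # padding positions are ignored

    loss = cross_entropy(logits, labels)  # scalar mean over valid tokens
    loss.backward()                       # logits.grad matches logits' shape

    ref = torch.nn.functional.cross_entropy(
        logits.detach().float().reshape(-1, V), labels.reshape(-1), ignore_index=-100
    )
    print(loss.item(), ref.item())  # should agree to roughly 1e-2 in bf16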

bitsandbytes/backends/cuda/ops.py

Lines changed: 63 additions & 0 deletions
@@ -1372,3 +1372,66 @@ def _(
        ct.c_int(n_heads),
        ct.c_int(head_dim),
    )


@register_kernel("bitsandbytes::cross_entropy_forward", "cuda")
def _(
    logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check(logits.dim() == 2, lambda: "logits must be 2D [N, V]")
    torch._check(logits.is_contiguous(), lambda: "logits must be contiguous")
    torch._check(
        logits.dtype in (torch.float16, torch.bfloat16),
        lambda: f"cross_entropy supports float16/bfloat16, got {logits.dtype}",
    )

    N, V = logits.shape
    losses = torch.empty(N, device=logits.device, dtype=torch.float32)
    logsumexp = torch.empty(N, device=logits.device, dtype=torch.float32)
    dtype_suffix = "fp16" if logits.dtype == torch.float16 else "bf16"

    with _cuda_device_of(logits):
        fn = getattr(lib, f"ccross_entropy_forward_{dtype_suffix}_c")
        fn(
            get_ptr(logits),
            get_ptr(labels),
            get_ptr(losses),
            get_ptr(logsumexp),
            ct.c_int(N),
            ct.c_int(V),
            ct.c_int(ignore_index),
        )

    return losses, logsumexp


@register_kernel("bitsandbytes::cross_entropy_backward", "cuda")
def _(
    logits: torch.Tensor,
    labels: torch.Tensor,
    grad_output: torch.Tensor,
    logsumexp: torch.Tensor,
    ignore_index: int,
) -> torch.Tensor:
    torch._check(logits.is_contiguous(), lambda: "logits must be contiguous")

    N, V = logits.shape
    grad_logits = torch.empty_like(logits)
    dtype_suffix = "fp16" if logits.dtype == torch.float16 else "bf16"

    with _cuda_device_of(logits):
        fn = getattr(lib, f"ccross_entropy_backward_{dtype_suffix}_c")
        fn(
            get_ptr(logits),
            get_ptr(labels),
            get_ptr(grad_output),
            get_ptr(logsumexp),
            get_ptr(grad_logits),
            ct.c_int(N),
            ct.c_int(V),
            ct.c_int(ignore_index),
        )

    return grad_logits
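
A spot-check sketch of the raw ops registered above against a PyTorch reference (again assuming a CUDA build; tolerances are loose to allow for fp16 rounding):

    import torch
    import bitsandbytes  # noqa: F401

    N, V = 16, 32000
    logits = torch.randn(N, V, device="cuda", dtype=torch.float16)
    labels = torch.randint(0, V, (N,), device="cuda")

    losses, lse = torch.ops.bitsandbytes.cross_entropy_forward(logits, labels, -100)
    assert torch.allclose(lse, torch.logsumexp(logits.float(), dim=-1), atol=1e-2)
    assert torch.allclose(
        losses,
        torch.nn.functional.cross_entropy(logits.float(), labels, reduction="none"),
        atol=1e-2,
    )

    grad_out = torch.full((N,), 1.0 / N, device="cuda")  # per-sample scale for a mean reduction
    grad = torch.ops.bitsandbytes.cross_entropy_backward(logits, labels, grad_out, lse, -100)
    ref_grad = (torch.softmax(logits.float(), -1) - torch.nn.functional.one_hot(labels, V)) / N
    assert torch.allclose(grad.float(), ref_grad, atol=1e-3)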

csrc/ops.cu

Lines changed: 140 additions & 0 deletions
@@ -3263,3 +3263,143 @@ void rope_forward(T* q, const T* cos_cache, const T* sin_cache,

template void rope_forward<half>(half*, const half*, const half*, int, int, int);
template void rope_forward<__nv_bfloat16>(__nv_bfloat16*, const __nv_bfloat16*, const __nv_bfloat16*, int, int, int);

// ---------- Cross-Entropy Loss forward+backward ----------
// Forward: loss = -log_softmax(logits)[label] per row
//              = -(logits[label] - logsumexp(logits))
// Uses a two-pass, max-stabilized logsumexp for numerical stability with large vocab.
// Backward: grad_logits = (softmax(logits) - one_hot(label)) * grad_output
// One thread block per row (sample in the batch).

template <typename T, int BLOCK_SIZE>
__global__ void kCrossEntropyForward(
    const T* __restrict__ logits,       // [N, V]
    const long* __restrict__ labels,    // [N]
    float* __restrict__ losses,         // [N]
    float* __restrict__ logsumexp_out,  // [N] stored for backward
    int N, int V, int ignore_index
) {
    int row = blockIdx.x;
    if (row >= N) return;

    long label = labels[row];
    if (label == ignore_index) {
        losses[row] = 0.0f;
        logsumexp_out[row] = 0.0f;
        return;
    }

    // size_t avoids 32-bit overflow of row * V for large N * V
    const T* logits_row = logits + (size_t)row * V;

    // Phase 1: find max for numerical stability
    __shared__ float shared[BLOCK_SIZE];
    float thread_max = -1e30f;
    for (int i = threadIdx.x; i < V; i += BLOCK_SIZE) {
        float val = float(logits_row[i]);
        thread_max = fmaxf(thread_max, val);
    }
    shared[threadIdx.x] = thread_max;
    __syncthreads();

    for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            shared[threadIdx.x] = fmaxf(shared[threadIdx.x], shared[threadIdx.x + s]);
        }
        __syncthreads();
    }
    float row_max = shared[0];
    __syncthreads();  // all threads must read row_max before shared[] is reused below

    // Phase 2: compute sum(exp(x - max))
    float thread_sum = 0.0f;
    for (int i = threadIdx.x; i < V; i += BLOCK_SIZE) {
        thread_sum += expf(float(logits_row[i]) - row_max);
    }
    shared[threadIdx.x] = thread_sum;
    __syncthreads();

    for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            shared[threadIdx.x] += shared[threadIdx.x + s];
        }
        __syncthreads();
    }
    float sum_exp = shared[0];

    float lse = row_max + logf(sum_exp);
    float logit_label = float(logits_row[label]);
    float loss = -(logit_label - lse);

    if (threadIdx.x == 0) {
        losses[row] = loss;
        logsumexp_out[row] = lse;
    }
}

template <typename T, int BLOCK_SIZE>
__global__ void kCrossEntropyBackward(
    const T* __restrict__ logits,           // [N, V]
    const long* __restrict__ labels,        // [N]
    const float* __restrict__ grad_output,  // [N] scalar per sample
    const float* __restrict__ logsumexp,    // [N] from forward
    T* __restrict__ grad_logits,            // [N, V]
    int N, int V, int ignore_index
) {
    int row = blockIdx.x;
    if (row >= N) return;

    long label = labels[row];
    const T* logits_row = logits + (size_t)row * V;
    T* grad_row = grad_logits + (size_t)row * V;
    float go = grad_output[row];

    if (label == ignore_index) {
        for (int i = threadIdx.x; i < V; i += BLOCK_SIZE) {
            grad_row[i] = T(0.0f);
        }
        return;
    }

    float lse = logsumexp[row];

    for (int i = threadIdx.x; i < V; i += BLOCK_SIZE) {
        float softmax_i = expf(float(logits_row[i]) - lse);
        float grad_i = softmax_i;
        if (i == label) {
            grad_i -= 1.0f;
        }
        grad_row[i] = T(grad_i * go);
    }
}

// C wrapper functions
template <typename T>
void cross_entropy_forward(const T* logits, const long* labels, float* losses, float* logsumexp,
                           int N, int V, int ignore_index) {
    if (V <= 256) {
        kCrossEntropyForward<T, 256><<<N, 256>>>(logits, labels, losses, logsumexp, N, V, ignore_index);
    } else if (V <= 512) {
        kCrossEntropyForward<T, 512><<<N, 512>>>(logits, labels, losses, logsumexp, N, V, ignore_index);
    } else {
        kCrossEntropyForward<T, 1024><<<N, 1024>>>(logits, labels, losses, logsumexp, N, V, ignore_index);
    }
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template <typename T>
void cross_entropy_backward(const T* logits, const long* labels, const float* grad_output,
                            const float* logsumexp, T* grad_logits,
                            int N, int V, int ignore_index) {
    if (V <= 256) {
        kCrossEntropyBackward<T, 256><<<N, 256>>>(logits, labels, grad_output, logsumexp, grad_logits, N, V, ignore_index);
    } else if (V <= 512) {
        kCrossEntropyBackward<T, 512><<<N, 512>>>(logits, labels, grad_output, logsumexp, grad_logits, N, V, ignore_index);
    } else {
        kCrossEntropyBackward<T, 1024><<<N, 1024>>>(logits, labels, grad_output, logsumexp, grad_logits, N, V, ignore_index);
    }
    CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

template void cross_entropy_forward<half>(const half*, const long*, float*, float*, int, int, int);
template void cross_entropy_forward<__nv_bfloat16>(const __nv_bfloat16*, const long*, float*, float*, int, int, int);
template void cross_entropy_backward<half>(const half*, const long*, const float*, const float*, half*, int, int, int);
template void cross_entropy_backward<__nv_bfloat16>(const __nv_bfloat16*, const long*, const float*, const float*, __nv_bfloat16*, int, int, int);
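
A minimal host-side harness sketch for the forward launcher above (a hypothetical standalone file, compiled with nvcc and linked against the object built from csrc/ops.cu; with uniform logits every valid row's loss should come out near log(V)):

    // cross_entropy_harness.cu -- sketch only, not part of this commit
    #include <cuda_fp16.h>
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <vector>

    template <typename T>
    void cross_entropy_forward(const T*, const long*, float*, float*, int, int, int);

    int main() {
        const int N = 4, V = 1000;
        std::vector<half> h_logits(N * V, __float2half(0.01f));  // uniform logits
        std::vector<long> h_labels = {3, 7, -100, 42};           // one ignored row

        half* d_logits; long* d_labels; float *d_losses, *d_lse;
        cudaMalloc(&d_logits, N * V * sizeof(half));
        cudaMalloc(&d_labels, N * sizeof(long));
        cudaMalloc(&d_losses, N * sizeof(float));
        cudaMalloc(&d_lse, N * sizeof(float));
        cudaMemcpy(d_logits, h_logits.data(), N * V * sizeof(half), cudaMemcpyHostToDevice);
        cudaMemcpy(d_labels, h_labels.data(), N * sizeof(long), cudaMemcpyHostToDevice);

        cross_entropy_forward<half>(d_logits, d_labels, d_losses, d_lse, N, V, -100);

        std::vector<float> h_losses(N);
        cudaMemcpy(h_losses.data(), d_losses, N * sizeof(float), cudaMemcpyDeviceToHost);
        for (int i = 0; i < N; ++i)
            printf("loss[%d] = %f\n", i, h_losses[i]);  // valid rows ~ log(1000) = 6.9078, ignored row 0
        return 0;
    }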

csrc/pythonInterface.cpp

Lines changed: 42 additions & 0 deletions
@@ -652,6 +652,29 @@ void crope_forward_bf16(__nv_bfloat16* q, const __nv_bfloat16* cos_cache, const
    rope_forward<__nv_bfloat16>(q, cos_cache, sin_cache, total_tokens, n_heads, head_dim);
}

// Cross-entropy loss forward declarations
template <typename T> void cross_entropy_forward(const T*, const long*, float*, float*, int, int, int);
template <typename T> void cross_entropy_backward(const T*, const long*, const float*, const float*, T*, int, int, int);

void ccross_entropy_forward_fp16(const half* logits, const long* labels, float* losses,
                                 float* logsumexp, int N, int V, int ignore_index) {
    cross_entropy_forward<half>(logits, labels, losses, logsumexp, N, V, ignore_index);
}
void ccross_entropy_backward_fp16(const half* logits, const long* labels, const float* grad_output,
                                  const float* logsumexp, half* grad_logits,
                                  int N, int V, int ignore_index) {
    cross_entropy_backward<half>(logits, labels, grad_output, logsumexp, grad_logits, N, V, ignore_index);
}
void ccross_entropy_forward_bf16(const __nv_bfloat16* logits, const long* labels, float* losses,
                                 float* logsumexp, int N, int V, int ignore_index) {
    cross_entropy_forward<__nv_bfloat16>(logits, labels, losses, logsumexp, N, V, ignore_index);
}
void ccross_entropy_backward_bf16(const __nv_bfloat16* logits, const long* labels, const float* grad_output,
                                  const float* logsumexp, __nv_bfloat16* grad_logits,
                                  int N, int V, int ignore_index) {
    cross_entropy_backward<__nv_bfloat16>(logits, labels, grad_output, logsumexp, grad_logits, N, V, ignore_index);
}

#endif // BUILD_CUDA || BUILD_HIP (kbit unmangled)

extern "C" {
@@ -1400,5 +1423,24 @@ void crope_forward_bf16_c(__nv_bfloat16* q, const __nv_bfloat16* cos_cache, cons
    crope_forward_bf16(q, cos_cache, sin_cache, total_tokens, n_heads, head_dim);
}

void ccross_entropy_forward_fp16_c(const half* logits, const long* labels, float* losses,
                                   float* logsumexp, int N, int V, int ignore_index) {
    ccross_entropy_forward_fp16(logits, labels, losses, logsumexp, N, V, ignore_index);
}
void ccross_entropy_backward_fp16_c(const half* logits, const long* labels, const float* grad_output,
                                    const float* logsumexp, half* grad_logits,
                                    int N, int V, int ignore_index) {
    ccross_entropy_backward_fp16(logits, labels, grad_output, logsumexp, grad_logits, N, V, ignore_index);
}
void ccross_entropy_forward_bf16_c(const __nv_bfloat16* logits, const long* labels, float* losses,
                                   float* logsumexp, int N, int V, int ignore_index) {
    ccross_entropy_forward_bf16(logits, labels, losses, logsumexp, N, V, ignore_index);
}
void ccross_entropy_backward_bf16_c(const __nv_bfloat16* logits, const long* labels, const float* grad_output,
                                    const float* logsumexp, __nv_bfloat16* grad_logits,
                                    int N, int V, int ignore_index) {
    ccross_entropy_backward_bf16(logits, labels, grad_output, logsumexp, grad_logits, N, V, ignore_index);
}

#endif
}
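
A small sketch for checking that the new extern "C" symbols are exported from the built library (the shared-library file name depends on the local build and CUDA version):

    import ctypes

    lib = ctypes.CDLL("libbitsandbytes_cuda126.so")  # adjust to the library your build produced
    for sym in (
        "ccross_entropy_forward_fp16_c",
        "ccross_entropy_backward_fp16_c",
        "ccross_entropy_forward_bf16_c",
        "ccross_entropy_backward_bf16_c",
    ):
        assert hasattr(lib, sym), f"missing symbol: {sym}"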
