Commit f0513d7
feat: Add CUDA cross-entropy loss forward+backward kernel
Cross-entropy loss with numerically stable logsumexp:
- Forward: per-sample loss via max-stabilized logsumexp
- Backward: softmax(logits) - one_hot(labels), scaled by grad_output
- One block per row, shared memory reduction for max and sum
- Supports ignore_index (-100 convention)
- Supports large vocabularies (tested up to 100K)
- fp16 and bf16 via C++ templates
Autograd wrapper computes mean loss with proper ignore_index
handling and per-sample gradient scaling.
10 tests pass covering forward correctness, gradient correctness,
ignore_index handling, large vocab, and both dtypes.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>1 parent 0cc1e5b commit f0513d7
File tree
6 files changed
+470
-2
lines changed- bitsandbytes
- autograd
- backends/cuda
- csrc
- tests
6 files changed
+470
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
782 | 782 | | |
783 | 783 | | |
784 | 784 | | |
| 785 | + | |
| 786 | + | |
| 787 | + | |
| 788 | + | |
| 789 | + | |
| 790 | + | |
| 791 | + | |
| 792 | + | |
| 793 | + | |
| 794 | + | |
| 795 | + | |
| 796 | + | |
| 797 | + | |
| 798 | + | |
| 799 | + | |
| 800 | + | |
| 801 | + | |
| 802 | + | |
| 803 | + | |
| 804 | + | |
| 805 | + | |
| 806 | + | |
| 807 | + | |
| 808 | + | |
| 809 | + | |
| 810 | + | |
| 811 | + | |
| 812 | + | |
| 813 | + | |
| 814 | + | |
| 815 | + | |
| 816 | + | |
| 817 | + | |
| 818 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
145 | 145 | | |
146 | 146 | | |
147 | 147 | | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1372 | 1372 | | |
1373 | 1373 | | |
1374 | 1374 | | |
| 1375 | + | |
| 1376 | + | |
| 1377 | + | |
| 1378 | + | |
| 1379 | + | |
| 1380 | + | |
| 1381 | + | |
| 1382 | + | |
| 1383 | + | |
| 1384 | + | |
| 1385 | + | |
| 1386 | + | |
| 1387 | + | |
| 1388 | + | |
| 1389 | + | |
| 1390 | + | |
| 1391 | + | |
| 1392 | + | |
| 1393 | + | |
| 1394 | + | |
| 1395 | + | |
| 1396 | + | |
| 1397 | + | |
| 1398 | + | |
| 1399 | + | |
| 1400 | + | |
| 1401 | + | |
| 1402 | + | |
| 1403 | + | |
| 1404 | + | |
| 1405 | + | |
| 1406 | + | |
| 1407 | + | |
| 1408 | + | |
| 1409 | + | |
| 1410 | + | |
| 1411 | + | |
| 1412 | + | |
| 1413 | + | |
| 1414 | + | |
| 1415 | + | |
| 1416 | + | |
| 1417 | + | |
| 1418 | + | |
| 1419 | + | |
| 1420 | + | |
| 1421 | + | |
| 1422 | + | |
| 1423 | + | |
| 1424 | + | |
| 1425 | + | |
| 1426 | + | |
| 1427 | + | |
| 1428 | + | |
| 1429 | + | |
| 1430 | + | |
| 1431 | + | |
| 1432 | + | |
| 1433 | + | |
| 1434 | + | |
| 1435 | + | |
| 1436 | + | |
| 1437 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3263 | 3263 | | |
3264 | 3264 | | |
3265 | 3265 | | |
| 3266 | + | |
| 3267 | + | |
| 3268 | + | |
| 3269 | + | |
| 3270 | + | |
| 3271 | + | |
| 3272 | + | |
| 3273 | + | |
| 3274 | + | |
| 3275 | + | |
| 3276 | + | |
| 3277 | + | |
| 3278 | + | |
| 3279 | + | |
| 3280 | + | |
| 3281 | + | |
| 3282 | + | |
| 3283 | + | |
| 3284 | + | |
| 3285 | + | |
| 3286 | + | |
| 3287 | + | |
| 3288 | + | |
| 3289 | + | |
| 3290 | + | |
| 3291 | + | |
| 3292 | + | |
| 3293 | + | |
| 3294 | + | |
| 3295 | + | |
| 3296 | + | |
| 3297 | + | |
| 3298 | + | |
| 3299 | + | |
| 3300 | + | |
| 3301 | + | |
| 3302 | + | |
| 3303 | + | |
| 3304 | + | |
| 3305 | + | |
| 3306 | + | |
| 3307 | + | |
| 3308 | + | |
| 3309 | + | |
| 3310 | + | |
| 3311 | + | |
| 3312 | + | |
| 3313 | + | |
| 3314 | + | |
| 3315 | + | |
| 3316 | + | |
| 3317 | + | |
| 3318 | + | |
| 3319 | + | |
| 3320 | + | |
| 3321 | + | |
| 3322 | + | |
| 3323 | + | |
| 3324 | + | |
| 3325 | + | |
| 3326 | + | |
| 3327 | + | |
| 3328 | + | |
| 3329 | + | |
| 3330 | + | |
| 3331 | + | |
| 3332 | + | |
| 3333 | + | |
| 3334 | + | |
| 3335 | + | |
| 3336 | + | |
| 3337 | + | |
| 3338 | + | |
| 3339 | + | |
| 3340 | + | |
| 3341 | + | |
| 3342 | + | |
| 3343 | + | |
| 3344 | + | |
| 3345 | + | |
| 3346 | + | |
| 3347 | + | |
| 3348 | + | |
| 3349 | + | |
| 3350 | + | |
| 3351 | + | |
| 3352 | + | |
| 3353 | + | |
| 3354 | + | |
| 3355 | + | |
| 3356 | + | |
| 3357 | + | |
| 3358 | + | |
| 3359 | + | |
| 3360 | + | |
| 3361 | + | |
| 3362 | + | |
| 3363 | + | |
| 3364 | + | |
| 3365 | + | |
| 3366 | + | |
| 3367 | + | |
| 3368 | + | |
| 3369 | + | |
| 3370 | + | |
| 3371 | + | |
| 3372 | + | |
| 3373 | + | |
| 3374 | + | |
| 3375 | + | |
| 3376 | + | |
| 3377 | + | |
| 3378 | + | |
| 3379 | + | |
| 3380 | + | |
| 3381 | + | |
| 3382 | + | |
| 3383 | + | |
| 3384 | + | |
| 3385 | + | |
| 3386 | + | |
| 3387 | + | |
| 3388 | + | |
| 3389 | + | |
| 3390 | + | |
| 3391 | + | |
| 3392 | + | |
| 3393 | + | |
| 3394 | + | |
| 3395 | + | |
| 3396 | + | |
| 3397 | + | |
| 3398 | + | |
| 3399 | + | |
| 3400 | + | |
| 3401 | + | |
| 3402 | + | |
| 3403 | + | |
| 3404 | + | |
| 3405 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
652 | 652 | | |
653 | 653 | | |
654 | 654 | | |
| 655 | + | |
| 656 | + | |
| 657 | + | |
| 658 | + | |
| 659 | + | |
| 660 | + | |
| 661 | + | |
| 662 | + | |
| 663 | + | |
| 664 | + | |
| 665 | + | |
| 666 | + | |
| 667 | + | |
| 668 | + | |
| 669 | + | |
| 670 | + | |
| 671 | + | |
| 672 | + | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
| 676 | + | |
| 677 | + | |
655 | 678 | | |
656 | 679 | | |
657 | 680 | | |
| |||
1400 | 1423 | | |
1401 | 1424 | | |
1402 | 1425 | | |
| 1426 | + | |
| 1427 | + | |
| 1428 | + | |
| 1429 | + | |
| 1430 | + | |
| 1431 | + | |
| 1432 | + | |
| 1433 | + | |
| 1434 | + | |
| 1435 | + | |
| 1436 | + | |
| 1437 | + | |
| 1438 | + | |
| 1439 | + | |
| 1440 | + | |
| 1441 | + | |
| 1442 | + | |
| 1443 | + | |
| 1444 | + | |
1403 | 1445 | | |
1404 | 1446 | | |
0 commit comments