Skip to content

Commit 0cc1e5b

Browse files
TimDettmers and claude committed
feat: Add CUDA SwiGLU, RMSNorm, and RoPE training kernels
CUDA kernels for three core training operations: - SwiGLU forward+backward: h = silu(gate) * up, element-wise - RMSNorm forward+backward: y = x * rsqrt(mean(x²) + eps) * w with Gemma variant (add_unit_offset: w + 1). One block per row, shared memory reduction. Forward stores rrms for backward. - RoPE forward (backward reuses with -sin): in-place rotary position embedding. Supports arbitrary head_dim. All kernels support fp16 and bf16 via C++ templates. Includes extern "C" wrappers, torch.library op registrations, backend dispatch, and torch.autograd.Function wrappers. 24 tests pass covering forward correctness against PyTorch reference, backward gradient correctness, Gemma variant, various sizes, and autograd integration. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 463908b commit 0cc1e5b

File tree

6 files changed

+1000
-0
lines changed

6 files changed

+1000
-0
lines changed

bitsandbytes/_ops.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,3 +705,80 @@ def _(
705705
)
706706
total_M = A_concat.shape[0]
707707
return torch.empty(total_M, N, device=A_concat.device, dtype=A_concat.dtype)
708+
709+
710+
# ============================================================================
# Training Kernels: SwiGLU, RMSNorm, RoPE
# ============================================================================


# SwiGLU forward: h = silu(gate) * up
torch.library.define(
    "bitsandbytes::swiglu_forward",
    "(Tensor gate, Tensor up) -> Tensor",
)


# Fake (meta) implementation: lets torch.compile / meta-device tracing infer
# the output's shape, dtype, and device without running the CUDA kernel.
@register_fake("bitsandbytes::swiglu_forward")
def _(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    torch._check(gate.shape == up.shape, lambda: "gate and up must have same shape")
    # Element-wise op: output metadata mirrors the gate input.
    return torch.empty_like(gate)
725+
726+
727+
# SwiGLU backward: (grad_gate, grad_up) from grad_h
torch.library.define(
    "bitsandbytes::swiglu_backward",
    "(Tensor grad_h, Tensor gate, Tensor up) -> (Tensor, Tensor)",
)


# Fake (meta) implementation for shape/dtype inference under tracing.
@register_fake("bitsandbytes::swiglu_backward")
def _(grad_h: torch.Tensor, gate: torch.Tensor, up: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Each gradient mirrors the metadata of the corresponding forward input.
    return torch.empty_like(gate), torch.empty_like(up)
737+
738+
739+
# RMSNorm forward: y = x * rsqrt(mean(x^2) + eps) * w, also returns rrms
torch.library.define(
    "bitsandbytes::rmsnorm_forward",
    "(Tensor x, Tensor w, float eps, bool add_unit_offset) -> (Tensor, Tensor)",
)


# Fake (meta) implementation for shape/dtype inference under tracing.
@register_fake("bitsandbytes::rmsnorm_forward")
def _(x: torch.Tensor, w: torch.Tensor, eps: float, add_unit_offset: bool) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check(x.dim() == 2, lambda: "x must be 2D [rows, cols]")
    rows = x.shape[0]
    out = torch.empty_like(x)
    # One reciprocal-RMS value per row, kept in fp32 for the backward pass
    # (matches the CUDA kernel's output dtype).
    rrms = torch.empty(rows, device=x.device, dtype=torch.float32)
    return out, rrms
753+
754+
755+
# RMSNorm backward: (grad_x, grad_w) from grad_out
torch.library.define(
    "bitsandbytes::rmsnorm_backward",
    "(Tensor grad_out, Tensor x, Tensor w, Tensor rrms, bool add_unit_offset) -> (Tensor, Tensor)",
)


# Fake (meta) implementation for shape/dtype inference under tracing.
@register_fake("bitsandbytes::rmsnorm_backward")
def _(
    grad_out: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    rrms: torch.Tensor,
    add_unit_offset: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    grad_x = torch.empty_like(x)
    # grad_w is accumulated across rows in fp32 (matches the CUDA kernel).
    grad_w = torch.empty(x.shape[1], device=x.device, dtype=torch.float32)
    return grad_x, grad_w
773+
774+
775+
# RoPE forward (in-place): applies rotary embeddings to Q (or Q+K)
# Tensor(a!) in the schema declares that q is mutated in place.
torch.library.define(
    "bitsandbytes::rope_forward",
    "(Tensor(a!) q, Tensor cos_cache, Tensor sin_cache, int n_heads) -> ()",
)


# Fake (meta) implementation: the op mutates q in place and returns nothing,
# so there is no output metadata to produce.
@register_fake("bitsandbytes::rope_forward")
def _(q: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, n_heads: int) -> None:
    pass
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""torch.autograd.Function wrappers for CUDA training kernels.
2+
3+
Wraps the low-level CUDA ops (SwiGLU, RMSNorm, RoPE) into autograd-aware
4+
functions that can be used directly in PyTorch training.
5+
"""
6+
7+
import torch
8+
9+
10+
class SwiGLUFunction(torch.autograd.Function):
    """SwiGLU activation: h = silu(gate) * up.

    Forward: h = (gate * sigmoid(gate)) * up
    Backward: grad_gate = grad_h * up * sigmoid(gate) * (1 + gate * (1 - sigmoid(gate)))
              grad_up = grad_h * silu(gate)
    """

    @staticmethod
    def forward(ctx, gate, up):
        # The CUDA kernel requires contiguous inputs (its dispatcher checks
        # is_contiguous). Normalize here — a no-op for already-contiguous
        # tensors — so callers may pass views/slices. The contiguous copies
        # are also what we save, so backward calls the kernel safely too.
        gate = gate.contiguous()
        up = up.contiguous()
        ctx.save_for_backward(gate, up)
        return torch.ops.bitsandbytes.swiglu_forward(gate, up)

    @staticmethod
    def backward(ctx, grad_h):
        gate, up = ctx.saved_tensors
        # grad_h arrives with arbitrary strides from autograd; the kernel
        # needs it contiguous.
        grad_gate, grad_up = torch.ops.bitsandbytes.swiglu_backward(
            grad_h.contiguous(), gate, up,
        )
        return grad_gate, grad_up
30+
31+
32+
def swiglu(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """Autograd-aware SwiGLU: ``silu(gate) * up``.

    Thin functional wrapper around :class:`SwiGLUFunction`.

    Args:
        gate: Gate tensor (any shape, fp16 or bf16).
        up: Up tensor (same shape as gate).

    Returns:
        The element-wise product ``silu(gate) * up``.
    """
    result = SwiGLUFunction.apply(gate, up)
    return result
43+
44+
45+
class RMSNormFunction(torch.autograd.Function):
    """RMS normalization: y = x * rsqrt(mean(x^2) + eps) * w.

    Supports Gemma variant with ``add_unit_offset=True`` (uses w + 1).
    """

    @staticmethod
    def forward(ctx, x, w, eps=1e-6, add_unit_offset=False):
        # Flatten to 2D for the CUDA kernel
        orig_shape = x.shape
        x_2d = x.reshape(-1, x.shape[-1]).contiguous()
        # The kernel reads w through a raw data pointer, so w must be
        # contiguous as well; a non-contiguous w would be silently misread.
        # (No-op when w is already contiguous.)
        w = w.contiguous()

        out_2d, rrms = torch.ops.bitsandbytes.rmsnorm_forward(
            x_2d, w, eps, add_unit_offset,
        )

        # rrms (per-row reciprocal RMS, fp32) is recomputation-free state
        # for the backward kernel.
        ctx.save_for_backward(x_2d, w, rrms)
        ctx.add_unit_offset = add_unit_offset
        ctx.orig_shape = orig_shape

        return out_2d.reshape(orig_shape)

    @staticmethod
    def backward(ctx, grad_out):
        x_2d, w, rrms = ctx.saved_tensors
        grad_out_2d = grad_out.reshape(x_2d.shape).contiguous()

        grad_x_2d, grad_w = torch.ops.bitsandbytes.rmsnorm_backward(
            grad_out_2d, x_2d, w, rrms, ctx.add_unit_offset,
        )

        grad_x = grad_x_2d.reshape(ctx.orig_shape)
        # grad_w is accumulated in fp32 by the kernel; cast back to w's dtype
        # so the optimizer sees a gradient matching the parameter.
        return grad_x, grad_w.to(w.dtype), None, None
78+
79+
80+
def rmsnorm(
    x: torch.Tensor,
    w: torch.Tensor,
    eps: float = 1e-6,
    add_unit_offset: bool = False,
) -> torch.Tensor:
    """Autograd-aware RMS normalization.

    Computes ``x * rsqrt(mean(x^2, dim=-1) + eps) * w`` via the CUDA kernel,
    delegating to :class:`RMSNormFunction`.

    Args:
        x: Input tensor (*, hidden_size), fp16 or bf16.
        w: Weight tensor (hidden_size,).
        eps: Epsilon for numerical stability.
        add_unit_offset: If True, uses (w + 1) instead of w (Gemma convention).

    Returns:
        Normalized tensor with the same shape as ``x``.
    """
    out = RMSNormFunction.apply(x, w, eps, add_unit_offset)
    return out
98+
99+
100+
class RoPEFunction(torch.autograd.Function):
    """Rotary Position Embedding (in-place).

    Forward: q[..., :half] = q[..., :half] * cos - q[..., half:] * sin
             q[..., half:] = q[..., half:] * cos + q[..., :half] * sin
    Backward: same operation with sin negated (the rotation's inverse/transpose).
    """

    @staticmethod
    def forward(ctx, q, cos_cache, sin_cache, n_heads):
        # q: [total_tokens, n_heads, head_dim]
        ctx.save_for_backward(cos_cache, sin_cache)
        ctx.n_heads = n_heads

        # The CUDA op rotates its argument in place and requires contiguous
        # memory. clone() alone preserves the input's strides, so a
        # non-contiguous q would produce a non-contiguous clone and fail the
        # kernel's contiguity check — force a contiguous layout explicitly.
        # Cloning also keeps the caller's q untouched despite the in-place op.
        q_out = q.clone(memory_format=torch.contiguous_format)
        torch.ops.bitsandbytes.rope_forward(q_out, cos_cache, sin_cache, n_heads)
        return q_out

    @staticmethod
    def backward(ctx, grad_q):
        cos_cache, sin_cache = ctx.saved_tensors

        # Backward of RoPE is the same operation with sin negated.
        # Same contiguity reasoning as forward: autograd may hand us a
        # non-contiguous grad_q.
        grad_q_out = grad_q.clone(memory_format=torch.contiguous_format)
        torch.ops.bitsandbytes.rope_forward(
            grad_q_out, cos_cache, -sin_cache, ctx.n_heads,
        )
        return grad_q_out, None, None, None
128+
129+
130+
def rope(
    q: torch.Tensor,
    cos_cache: torch.Tensor,
    sin_cache: torch.Tensor,
    n_heads: int,
) -> torch.Tensor:
    """Autograd-aware Rotary Position Embedding.

    Delegates to :class:`RoPEFunction`, which applies the rotation via the
    CUDA kernel and returns a new tensor (the input is not modified).

    Args:
        q: Query tensor [total_tokens, n_heads, head_dim], fp16 or bf16.
        cos_cache: Cosine cache [total_tokens, head_dim/2].
        sin_cache: Sine cache [total_tokens, head_dim/2].
        n_heads: Number of attention heads.

    Returns:
        Rotated query tensor with the same shape as ``q``.
    """
    rotated = RoPEFunction.apply(q, cos_cache, sin_cache, n_heads)
    return rotated

bitsandbytes/backends/cuda/ops.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,3 +1231,144 @@ def _(
12311231
)
12321232

12331233
return C_concat
1234+
1235+
1236+
# ============================================================================
# Training Kernels: SwiGLU, RMSNorm, RoPE
# ============================================================================


@register_kernel("bitsandbytes::swiglu_forward", "cuda")
def _(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    """SwiGLU forward on CUDA: element-wise ``silu(gate) * up``."""
    torch._check(gate.shape == up.shape, lambda: "gate and up must have same shape")
    torch._check(gate.is_contiguous(), lambda: "gate must be contiguous")
    torch._check(up.is_contiguous(), lambda: "up must be contiguous")
    torch._check(
        gate.dtype in (torch.float16, torch.bfloat16),
        lambda: f"swiglu supports float16/bfloat16, got {gate.dtype}",
    )
    # The kernel reads both buffers as the same element type; a dtype
    # mismatch between gate and up would silently reinterpret bits.
    torch._check(
        gate.dtype == up.dtype,
        lambda: f"gate and up must have the same dtype, got {gate.dtype} and {up.dtype}",
    )

    out = torch.empty_like(gate)
    # NOTE(review): n is passed as a C int — tensors with more than 2**31-1
    # elements would overflow; confirm against the kernel's index type.
    n = gate.numel()
    dtype_suffix = "fp16" if gate.dtype == torch.float16 else "bf16"

    with _cuda_device_of(gate):
        fn = getattr(lib, f"cswiglu_forward_{dtype_suffix}_c")
        fn(get_ptr(gate), get_ptr(up), get_ptr(out), ct.c_int(n))

    return out
1260+
1261+
1262+
@register_kernel("bitsandbytes::swiglu_backward", "cuda")
def _(grad_h: torch.Tensor, gate: torch.Tensor, up: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """SwiGLU backward on CUDA: (grad_gate, grad_up) from grad_h."""
    torch._check(grad_h.is_contiguous(), lambda: "grad_h must be contiguous")
    torch._check(gate.is_contiguous(), lambda: "gate must be contiguous")
    torch._check(up.is_contiguous(), lambda: "up must be contiguous")
    # Consistency with the forward kernel: all three tensors are walked with
    # one flat index, so shapes and dtypes must agree; a wrong dtype would
    # also dispatch the wrong templated kernel below.
    torch._check(
        grad_h.shape == gate.shape and gate.shape == up.shape,
        lambda: "grad_h, gate and up must have the same shape",
    )
    torch._check(
        gate.dtype in (torch.float16, torch.bfloat16),
        lambda: f"swiglu supports float16/bfloat16, got {gate.dtype}",
    )
    torch._check(
        grad_h.dtype == gate.dtype and up.dtype == gate.dtype,
        lambda: "grad_h, gate and up must have the same dtype",
    )

    grad_gate = torch.empty_like(gate)
    grad_up = torch.empty_like(up)
    n = gate.numel()
    dtype_suffix = "fp16" if gate.dtype == torch.float16 else "bf16"

    with _cuda_device_of(gate):
        fn = getattr(lib, f"cswiglu_backward_{dtype_suffix}_c")
        fn(get_ptr(grad_h), get_ptr(gate), get_ptr(up), get_ptr(grad_gate), get_ptr(grad_up), ct.c_int(n))

    return grad_gate, grad_up
1278+
1279+
1280+
@register_kernel("bitsandbytes::rmsnorm_forward", "cuda")
def _(
    x: torch.Tensor,
    w: torch.Tensor,
    eps: float,
    add_unit_offset: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMSNorm forward on CUDA.

    Returns (out, rrms) where rrms holds the per-row reciprocal RMS in fp32,
    saved for reuse by the backward kernel.
    """
    torch._check(x.dim() == 2, lambda: "x must be 2D [rows, cols]")
    torch._check(x.is_contiguous(), lambda: "x must be contiguous")
    # w is consumed through a raw pointer below, so it must be contiguous and
    # exactly one element per column — otherwise the kernel misreads memory.
    torch._check(w.is_contiguous(), lambda: "w must be contiguous")
    torch._check(
        w.numel() == x.shape[1],
        lambda: f"w must have {x.shape[1]} elements, got {w.numel()}",
    )
    torch._check(
        x.dtype in (torch.float16, torch.bfloat16),
        lambda: f"rmsnorm supports float16/bfloat16, got {x.dtype}",
    )

    rows, cols = x.shape
    out = torch.empty_like(x)
    # fp32 regardless of x's dtype: backward needs full-precision rrms.
    rrms = torch.empty(rows, device=x.device, dtype=torch.float32)
    dtype_suffix = "fp16" if x.dtype == torch.float16 else "bf16"

    with _cuda_device_of(x):
        fn = getattr(lib, f"crmsnorm_forward_{dtype_suffix}_c")
        fn(
            get_ptr(x),
            get_ptr(w),
            get_ptr(out),
            get_ptr(rrms),
            ct.c_int(rows),
            ct.c_int(cols),
            ct.c_float(eps),
            ct.c_bool(add_unit_offset),
        )

    return out, rrms
1313+
1314+
1315+
@register_kernel("bitsandbytes::rmsnorm_backward", "cuda")
def _(
    grad_out: torch.Tensor,
    x: torch.Tensor,
    w: torch.Tensor,
    rrms: torch.Tensor,
    add_unit_offset: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """RMSNorm backward on CUDA: (grad_x, grad_w) from grad_out."""
    torch._check(grad_out.is_contiguous(), lambda: "grad_out must be contiguous")
    torch._check(x.is_contiguous(), lambda: "x must be contiguous")
    # w and rrms are consumed through raw pointers, so they must be
    # contiguous too (consistency with the forward kernel's requirements).
    torch._check(w.is_contiguous(), lambda: "w must be contiguous")
    torch._check(rrms.is_contiguous(), lambda: "rrms must be contiguous")
    torch._check(
        x.dtype in (torch.float16, torch.bfloat16),
        lambda: f"rmsnorm supports float16/bfloat16, got {x.dtype}",
    )

    rows, cols = x.shape
    grad_x = torch.empty_like(x)
    # zeros (not empty): the kernel accumulates per-row contributions into
    # grad_w, so it must start from zero. Kept in fp32 for accumulation.
    grad_w = torch.zeros(cols, device=x.device, dtype=torch.float32)
    dtype_suffix = "fp16" if x.dtype == torch.float16 else "bf16"

    with _cuda_device_of(x):
        fn = getattr(lib, f"crmsnorm_backward_{dtype_suffix}_c")
        fn(
            get_ptr(grad_out),
            get_ptr(x),
            get_ptr(w),
            get_ptr(rrms),
            get_ptr(grad_x),
            get_ptr(grad_w),
            ct.c_int(rows),
            ct.c_int(cols),
            ct.c_bool(add_unit_offset),
        )

    return grad_x, grad_w
1346+
1347+
1348+
@register_kernel("bitsandbytes::rope_forward", "cuda")
def _(
    q: torch.Tensor,
    cos_cache: torch.Tensor,
    sin_cache: torch.Tensor,
    n_heads: int,
) -> None:
    """RoPE forward on CUDA: rotates q in place using the cos/sin caches."""
    torch._check(q.is_contiguous(), lambda: "q must be contiguous")
    # The caches are read through raw pointers, so they must be contiguous;
    # a strided cache would be silently misread.
    torch._check(cos_cache.is_contiguous(), lambda: "cos_cache must be contiguous")
    torch._check(sin_cache.is_contiguous(), lambda: "sin_cache must be contiguous")
    torch._check(
        q.dtype in (torch.float16, torch.bfloat16),
        lambda: f"rope supports float16/bfloat16, got {q.dtype}",
    )

    # NOTE(review): assumes q is [total_tokens, n_heads, head_dim]; a 4D
    # [batch, seq, heads, dim] tensor would make total_tokens wrong — the
    # autograd wrapper upholds the 3D contract, verify other callers do too.
    total_tokens = q.shape[0]
    head_dim = q.shape[-1]
    dtype_suffix = "fp16" if q.dtype == torch.float16 else "bf16"

    with _cuda_device_of(q):
        fn = getattr(lib, f"crope_forward_{dtype_suffix}_c")
        fn(
            get_ptr(q),
            get_ptr(cos_cache),
            get_ptr(sin_cache),
            ct.c_int(total_tokens),
            ct.c_int(n_heads),
            ct.c_int(head_dim),
        )

0 commit comments

Comments
 (0)