Commit 5bdc545

chunk over seq
1 parent 846dc1f commit 5bdc545

2 files changed

Lines changed: 117 additions & 66 deletions

src/liger_kernel/chunked_loss/fused_linear_ppo.py

Lines changed: 76 additions & 63 deletions
@@ -5,100 +5,113 @@
 import torch._dynamo.config
 
 
-_SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE = 2048
+_SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE = 4096
+_SELECTIVE_LOGPROB_SEQ_CHUNK_SIZE = 2048
 
 
 def _maybe_mark_dynamic_dim1(tensor):
     if tensor is not None:
         torch._dynamo.maybe_mark_dynamic(tensor, 1)
 
 
-def _selective_logprob_forward(hidden, weight, targets, bias=None, temperature=1.0, vocab_chunk_size=2048):
-    """Compute selective log-probabilities by streaming over vocab chunks (cuBLAS per chunk).
+def _selective_logprob_forward(hidden, weight, targets, bias=None, temperature=1.0, vocab_chunk_size=8192):
+    """Compute selective log-probabilities by streaming over sequence and vocab chunks.
 
-    Uses in-place ops and pre-allocated buffers for memory efficiency.
+    Dual chunking (sequence × vocab) bounds peak temporary memory to
+    ``seq_chunk_size × vocab_chunk_size`` regardless of total N or V.
     """
     device = hidden.device
     n_rows, _ = hidden.shape
     vocab_size, _ = weight.shape
     inv_t = 1.0 / temperature
+    seq_chunk_size = _SELECTIVE_LOGPROB_SEQ_CHUNK_SIZE
+
+    logprobs = torch.empty((n_rows,), device=device, dtype=torch.float32)
+    log_z = torch.empty((n_rows,), device=device, dtype=torch.float32)
+
+    for seq_start in range(0, n_rows, seq_chunk_size):
+        seq_end = min(seq_start + seq_chunk_size, n_rows)
+        n_chunk = seq_end - seq_start
+        hidden_chunk = hidden[seq_start:seq_end]
+        targets_chunk = targets[seq_start:seq_end]
+
+        max_old = torch.full((n_chunk,), float("-inf"), device=device, dtype=torch.float32)
+        sum_exp = torch.zeros((n_chunk,), device=device, dtype=torch.float32)
+        target_logit = torch.zeros((n_chunk,), device=device, dtype=torch.float32)
+        row_idx = torch.arange(n_chunk, device=device)
+
+        for vocab_start in range(0, vocab_size, vocab_chunk_size):
+            vocab_end = min(vocab_start + vocab_chunk_size, vocab_size)
+            weight_chunk = weight[vocab_start:vocab_end]
+            logits_chunk = (hidden_chunk @ weight_chunk.to(hidden.dtype).t()).float()
+            if bias is not None:
+                logits_chunk.add_(bias[vocab_start:vocab_end].to(torch.float32))
+            logits_chunk.mul_(inv_t)
 
-    max_old = torch.full((n_rows,), float("-inf"), device=device, dtype=torch.float32)
-    sum_exp = torch.zeros((n_rows,), device=device, dtype=torch.float32)
-    target_logit = torch.zeros((n_rows,), device=device, dtype=torch.float32)
-
-    mm_buf = torch.empty((n_rows, vocab_chunk_size), device=device, dtype=hidden.dtype)
-    logits_buf = torch.empty((n_rows, vocab_chunk_size), device=device, dtype=torch.float32)
-    row_idx = torch.arange(n_rows, device=device)
-
-    for start in range(0, vocab_size, vocab_chunk_size):
-        end = min(start + vocab_chunk_size, vocab_size)
-        chunk_width = end - start
-        weight_chunk = weight[start:end].to(hidden.dtype)
-        torch.mm(hidden, weight_chunk.t(), out=mm_buf[:, :chunk_width])
-        logits_chunk = logits_buf[:, :chunk_width]
-        logits_chunk.copy_(mm_buf[:, :chunk_width])
-        if bias is not None:
-            logits_chunk.add_(bias[start:end].to(torch.float32))
-        logits_chunk.mul_(inv_t)
+            chunk_max = logits_chunk.amax(dim=-1)
+            max_new = torch.maximum(max_old, chunk_max)
+            rescale = torch.exp(max_old - max_new)
+            chunk_exp = torch.exp(logits_chunk - max_new.unsqueeze(-1))
 
-        chunk_max = logits_chunk.amax(dim=-1)
-        max_new = torch.maximum(max_old, chunk_max)
-        rescale = torch.exp(max_old - max_new)
-        chunk_exp = torch.exp(logits_chunk - max_new.unsqueeze(-1))
+            sum_exp = sum_exp * rescale + chunk_exp.sum(dim=-1)
+            max_old = max_new
 
-        sum_exp = sum_exp * rescale + chunk_exp.sum(dim=-1)
-        max_old = max_new
+            in_chunk = (targets_chunk >= vocab_start) & (targets_chunk < vocab_end)
+            local_idx = torch.clamp(targets_chunk - vocab_start, 0, vocab_end - vocab_start - 1)
+            target_logit += logits_chunk[row_idx, local_idx] * in_chunk
 
-        in_chunk = (targets >= start) & (targets < end)
-        local_idx = torch.clamp(targets - start, 0, end - start - 1)
-        target_logit += logits_chunk[row_idx, local_idx] * in_chunk
+        log_z_chunk = max_old + torch.log(sum_exp)
+        log_z[seq_start:seq_end] = log_z_chunk
+        logprobs[seq_start:seq_end] = target_logit - log_z_chunk
 
-    log_z = max_old + torch.log(sum_exp)
-    return target_logit - log_z, log_z
+    return logprobs, log_z
 
 
 def _selective_logprob_backward(hidden, weight, targets, bias, log_z, grad_logprobs, temperature, vocab_chunk_size):
-    """Vocab-chunked backward for selective logprob (recomputes logits per chunk for memory efficiency)."""
+    """Dual-chunked (sequence × vocab) backward for selective logprob.
+
+    Recomputes logits per chunk for memory efficiency.
+    """
     inv_t = 1.0 / temperature
     n_rows, _ = hidden.shape
     vocab_size = weight.shape[0]
     has_bias = bias is not None
+    seq_chunk_size = _SELECTIVE_LOGPROB_SEQ_CHUNK_SIZE
 
     grad_hidden = torch.zeros(hidden.shape, device=hidden.device, dtype=torch.float32)
     grad_weight = torch.zeros(weight.shape, device=weight.device, dtype=torch.float32)
     grad_bias = torch.zeros((vocab_size,), device=weight.device, dtype=torch.float32) if has_bias else None
 
-    mm_buf = torch.empty((n_rows, vocab_chunk_size), device=hidden.device, dtype=hidden.dtype)
-    logits_buf = torch.empty((n_rows, vocab_chunk_size), device=hidden.device, dtype=torch.float32)
-
     grad_logprobs = grad_logprobs.to(torch.float32)
-    row_idx = torch.arange(n_rows, device=hidden.device)
-
-    for start in range(0, vocab_size, vocab_chunk_size):
-        end = min(start + vocab_chunk_size, vocab_size)
-        chunk_width = end - start
-        weight_chunk = weight[start:end]
-
-        torch.mm(hidden, weight_chunk.t(), out=mm_buf[:, :chunk_width])
-        logits_chunk = logits_buf[:, :chunk_width]
-        logits_chunk.copy_(mm_buf[:, :chunk_width])
-        if has_bias:
-            logits_chunk.add_(bias[start:end].to(torch.float32))
-        logits_chunk.mul_(inv_t)
-
-        probs = torch.exp(logits_chunk - log_z.unsqueeze(-1))
-        grad_logits = (-grad_logprobs).unsqueeze(-1) * probs
-
-        in_chunk = (targets >= start) & (targets < end)
-        local_idx = torch.clamp(targets - start, 0, end - start - 1)
-        grad_logits[row_idx, local_idx] += grad_logprobs * in_chunk
-        grad_logits.mul_(inv_t)
-
-        grad_hidden.add_(grad_logits @ weight_chunk.float())
-        grad_weight[start:end].add_(grad_logits.t() @ hidden.float())
-        if has_bias:
-            grad_bias[start:end].add_(grad_logits.sum(dim=0))
+
+    for seq_start in range(0, n_rows, seq_chunk_size):
+        seq_end = min(seq_start + seq_chunk_size, n_rows)
+        hidden_chunk = hidden[seq_start:seq_end]
+        targets_chunk = targets[seq_start:seq_end]
+        grad_chunk = grad_logprobs[seq_start:seq_end]
+        logz_chunk = log_z[seq_start:seq_end]
+        row_idx = torch.arange(seq_end - seq_start, device=hidden.device)
+
+        for vocab_start in range(0, vocab_size, vocab_chunk_size):
+            vocab_end = min(vocab_start + vocab_chunk_size, vocab_size)
+            weight_chunk = weight[vocab_start:vocab_end]
+            logits_chunk = (hidden_chunk @ weight_chunk.to(hidden.dtype).t()).float()
+            if has_bias:
+                logits_chunk.add_(bias[vocab_start:vocab_end].to(torch.float32))
+            logits_chunk.mul_(inv_t)
+
+            probs = torch.exp(logits_chunk - logz_chunk.unsqueeze(-1))
+            grad_logits = (-grad_chunk).unsqueeze(-1) * probs
+
+            in_chunk = (targets_chunk >= vocab_start) & (targets_chunk < vocab_end)
+            local_idx = torch.clamp(targets_chunk - vocab_start, 0, vocab_end - vocab_start - 1)
+            grad_logits[row_idx, local_idx] += grad_chunk * in_chunk
+            grad_logits.mul_(inv_t)
+
+            grad_hidden[seq_start:seq_end].add_(grad_logits @ weight_chunk.float())
+            grad_weight[vocab_start:vocab_end].add_(grad_logits.t() @ hidden_chunk.float())
+            if has_bias:
+                grad_bias[vocab_start:vocab_end].add_(grad_logits.sum(dim=0))
 
     return grad_hidden, grad_weight, grad_bias
 
test/chunked_loss/test_grpo_loss.py

Lines changed: 41 additions & 3 deletions
@@ -320,8 +320,14 @@ def forward(
 
 @pytest.mark.parametrize("dtype, atol, rtol", [(torch.float32, 1e-5, 1e-5), (torch.bfloat16, 5e-2, 5e-2)])
 @pytest.mark.parametrize("bias", [True, False])
-def test_selective_chunk_forward_matches_reference(dtype, atol, rtol, bias):
-    B, T, H, V = 3, 17, 31, 123
+@pytest.mark.parametrize(
+    "B, T, H, V",
+    [
+        (3, 17, 31, 123),  # small: no chunking exercised
+        (1, 4096, 256, 5000),  # large: exercises both sequence and vocab chunking
+    ],
+)
+def test_selective_chunk_forward_matches_reference(B, T, H, V, dtype, atol, rtol, bias):
     x = torch.randn(B, T, H, device=device, dtype=dtype, requires_grad=True)
     weight = torch.randn(V, H, device=device, dtype=dtype, requires_grad=True)
     bias_tensor = torch.randn(V, device=device, dtype=dtype, requires_grad=True) if bias else None
@@ -337,6 +343,38 @@ def test_selective_chunk_forward_matches_reference(dtype, atol, rtol, bias):
     assert_verbose_allclose(out, ref, atol=atol, rtol=rtol)
 
 
+@pytest.mark.parametrize("loss_type", ["dapo", "grpo"])
+@pytest.mark.parametrize("compiled", [True, False])
+def test_correctness_large_seq_exercises_chunking(loss_type, compiled):
+    """Test with N > seq_chunk_size and V > vocab_chunk_size to exercise both chunking loops."""
+    torch.compiler.reset()
+    B, T, H, V = 1, 4096, 256, 5000
+    dtype = torch.float32
+
+    torch_lm = TorchLMHeadGRPO(H=H, V=V, dtype=dtype, beta=0.04, loss_type=loss_type, use_ref_model=False)
+    liger_lm = LigerLMHeadGRPO(H=H, V=V, dtype=dtype, beta=0.04, loss_type=loss_type, use_ref_model=False)
+
+    torch_lm.lin.weight.data = liger_lm.lin.weight.data = torch.randn(V, H, device=device, dtype=dtype)
+
+    _input = torch.randn(B, T, H, device=device, dtype=dtype)
+    input1 = _input.detach().clone().requires_grad_(True)
+    input2 = _input.detach().clone().requires_grad_(True)
+    selected_token_ids = torch.randint(0, V, (B, T), device=device)
+    attention_mask = torch.ones(B, T, device=device)
+    attention_mask[:, -64:] = 0
+    advantages = torch.randn(B, device=device, dtype=dtype)
+
+    loss1, _ = torch_lm(input1, selected_token_ids, attention_mask, advantages)
+    loss2, _ = liger_lm(input2, selected_token_ids, attention_mask, advantages)
+
+    assert_verbose_allclose(loss1, loss2, atol=2e-5, rtol=1e-3)
+
+    loss1.backward()
+    loss2.backward()
+    assert_verbose_allclose(input1.grad, input2.grad, atol=2e-5, rtol=1e-3)
+    assert_verbose_allclose(torch_lm.lin.weight.grad, liger_lm.lin.weight.grad, atol=2e-5, rtol=1e-3)
+
+
 @pytest.mark.parametrize(
     "B, T, H, V",
     [
@@ -348,7 +386,7 @@ def test_selective_chunk_forward_matches_reference(dtype, atol, rtol, bias):
     "scalar, dtype, atol, rtol",
     [
         (1.0, torch.bfloat16, 1e-1, 5e-1),
-        (1.0, torch.float32, 1e-5, 5e-4),
+        (1.0, torch.float32, 2e-5, 1e-3),
     ],
 )
 @pytest.mark.parametrize("bias", [True, False])
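The loss and gradient comparisons in the new test ultimately rest on the closed-form softmax gradient that _selective_logprob_backward applies chunk by chunk: d logprob_i / d logit_{i,v} = 1[v == target_i] − softmax(logits_i)_v, followed by the 1/temperature chain-rule factor. A tiny float64 cross-check of that identity against autograd — a standalone sketch under illustrative sizes, not part of the test file:

import torch

N, H, V, temp = 6, 5, 12, 0.7
hidden = torch.randn(N, H, dtype=torch.float64, requires_grad=True)
weight = torch.randn(V, H, dtype=torch.float64, requires_grad=True)
targets = torch.randint(0, V, (N,))
grad_out = torch.randn(N, dtype=torch.float64)

# Reference: autograd through a full (non-chunked) log_softmax.
logits = (hidden @ weight.t()) / temp
logprobs = torch.log_softmax(logits, dim=-1)[torch.arange(N), targets]
logprobs.backward(grad_out)

# Closed form the chunked loop applies piecewise:
#   d logprob_i / d logit_{i,v} = 1[v == target_i] - softmax(logits_i)_v,
# then divide by temp for the chain rule through logits = (h @ W^T) / temp.
with torch.no_grad():
    probs = torch.softmax((hidden @ weight.t()) / temp, dim=-1)
    grad_logits = -grad_out.unsqueeze(-1) * probs
    grad_logits[torch.arange(N), targets] += grad_out
    grad_logits /= temp
    assert torch.allclose(hidden.grad, grad_logits @ weight)
    assert torch.allclose(weight.grad, grad_logits.t() @ hidden)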