5 | 5 | import torch._dynamo.config |
6 | 6 |
7 | 7 |
8 | | -_SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE = 2048 |
| 8 | +_SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE = 4096 |
| 9 | +_SELECTIVE_LOGPROB_SEQ_CHUNK_SIZE = 2048 |
9 | 10 |
10 | 11 |
11 | 12 | def _maybe_mark_dynamic_dim1(tensor): |
12 | 13 | if tensor is not None: |
13 | 14 | torch._dynamo.maybe_mark_dynamic(tensor, 1) |
14 | 15 |
15 | 16 |
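A note on the helper above: `torch._dynamo.maybe_mark_dynamic(tensor, 1)` hints to `torch.compile` that dim 1 can vary, so changing sequence lengths need not trigger a recompile. A minimal sketch of the effect; the `scale` function is a hypothetical stand-in, not part of this patch:

```python
import torch

@torch.compile
def scale(x):
    return x * 2.0

x = torch.randn(4, 128)
# Hint before the first compile that dim 1 is dynamic; without it, dynamo
# may specialize on 128 and recompile when that size changes.
torch._dynamo.maybe_mark_dynamic(x, 1)
scale(x)                      # compiles with dim 1 treated as symbolic
scale(torch.randn(4, 256))    # should reuse the dynamic-shape graph
```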
16 | | -def _selective_logprob_forward(hidden, weight, targets, bias=None, temperature=1.0, vocab_chunk_size=2048): |
17 | | - """Compute selective log-probabilities by streaming over vocab chunks (cuBLAS per chunk). |
| 17 | +def _selective_logprob_forward(hidden, weight, targets, bias=None, temperature=1.0, vocab_chunk_size=8192): |
| 18 | + """Compute selective log-probabilities by streaming over sequence and vocab chunks. |
18 | 19 |
19 | | - Uses in-place ops and pre-allocated buffers for memory efficiency. |
|  20 | +    Dual chunking (sequence × vocab) bounds the peak temporary logits buffer to
|  21 | +    ``seq_chunk_size × vocab_chunk_size`` elements, regardless of total rows N or vocab size V.
20 | 22 | """ |
21 | 23 | device = hidden.device |
22 | 24 | n_rows, _ = hidden.shape |
23 | 25 | vocab_size, _ = weight.shape |
24 | 26 | inv_t = 1.0 / temperature |
| 27 | + seq_chunk_size = _SELECTIVE_LOGPROB_SEQ_CHUNK_SIZE |
| 28 | + |
| 29 | + logprobs = torch.empty((n_rows,), device=device, dtype=torch.float32) |
| 30 | + log_z = torch.empty((n_rows,), device=device, dtype=torch.float32) |
| 31 | + |
| 32 | + for seq_start in range(0, n_rows, seq_chunk_size): |
| 33 | + seq_end = min(seq_start + seq_chunk_size, n_rows) |
| 34 | + n_chunk = seq_end - seq_start |
| 35 | + hidden_chunk = hidden[seq_start:seq_end] |
| 36 | + targets_chunk = targets[seq_start:seq_end] |
| 37 | + |
| 38 | + max_old = torch.full((n_chunk,), float("-inf"), device=device, dtype=torch.float32) |
| 39 | + sum_exp = torch.zeros((n_chunk,), device=device, dtype=torch.float32) |
| 40 | + target_logit = torch.zeros((n_chunk,), device=device, dtype=torch.float32) |
| 41 | + row_idx = torch.arange(n_chunk, device=device) |
| 42 | + |
| 43 | + for vocab_start in range(0, vocab_size, vocab_chunk_size): |
| 44 | + vocab_end = min(vocab_start + vocab_chunk_size, vocab_size) |
| 45 | + weight_chunk = weight[vocab_start:vocab_end] |
| 46 | + logits_chunk = (hidden_chunk @ weight_chunk.to(hidden.dtype).t()).float() |
| 47 | + if bias is not None: |
| 48 | + logits_chunk.add_(bias[vocab_start:vocab_end].to(torch.float32)) |
| 49 | + logits_chunk.mul_(inv_t) |
25 | 50 |
26 | | - max_old = torch.full((n_rows,), float("-inf"), device=device, dtype=torch.float32) |
27 | | - sum_exp = torch.zeros((n_rows,), device=device, dtype=torch.float32) |
28 | | - target_logit = torch.zeros((n_rows,), device=device, dtype=torch.float32) |
29 | | - |
30 | | - mm_buf = torch.empty((n_rows, vocab_chunk_size), device=device, dtype=hidden.dtype) |
31 | | - logits_buf = torch.empty((n_rows, vocab_chunk_size), device=device, dtype=torch.float32) |
32 | | - row_idx = torch.arange(n_rows, device=device) |
33 | | - |
34 | | - for start in range(0, vocab_size, vocab_chunk_size): |
35 | | - end = min(start + vocab_chunk_size, vocab_size) |
36 | | - chunk_width = end - start |
37 | | - weight_chunk = weight[start:end].to(hidden.dtype) |
38 | | - torch.mm(hidden, weight_chunk.t(), out=mm_buf[:, :chunk_width]) |
39 | | - logits_chunk = logits_buf[:, :chunk_width] |
40 | | - logits_chunk.copy_(mm_buf[:, :chunk_width]) |
41 | | - if bias is not None: |
42 | | - logits_chunk.add_(bias[start:end].to(torch.float32)) |
43 | | - logits_chunk.mul_(inv_t) |
| 51 | + chunk_max = logits_chunk.amax(dim=-1) |
| 52 | + max_new = torch.maximum(max_old, chunk_max) |
| 53 | + rescale = torch.exp(max_old - max_new) |
| 54 | + chunk_exp = torch.exp(logits_chunk - max_new.unsqueeze(-1)) |
44 | 55 |
45 | | - chunk_max = logits_chunk.amax(dim=-1) |
46 | | - max_new = torch.maximum(max_old, chunk_max) |
47 | | - rescale = torch.exp(max_old - max_new) |
48 | | - chunk_exp = torch.exp(logits_chunk - max_new.unsqueeze(-1)) |
| 56 | + sum_exp = sum_exp * rescale + chunk_exp.sum(dim=-1) |
| 57 | + max_old = max_new |
49 | 58 |
50 | | - sum_exp = sum_exp * rescale + chunk_exp.sum(dim=-1) |
51 | | - max_old = max_new |
| 59 | + in_chunk = (targets_chunk >= vocab_start) & (targets_chunk < vocab_end) |
| 60 | + local_idx = torch.clamp(targets_chunk - vocab_start, 0, vocab_end - vocab_start - 1) |
| 61 | + target_logit += logits_chunk[row_idx, local_idx] * in_chunk |
52 | 62 |
53 | | - in_chunk = (targets >= start) & (targets < end) |
54 | | - local_idx = torch.clamp(targets - start, 0, end - start - 1) |
55 | | - target_logit += logits_chunk[row_idx, local_idx] * in_chunk |
| 63 | + log_z_chunk = max_old + torch.log(sum_exp) |
| 64 | + log_z[seq_start:seq_end] = log_z_chunk |
| 65 | + logprobs[seq_start:seq_end] = target_logit - log_z_chunk |
56 | 66 |
57 | | - log_z = max_old + torch.log(sum_exp) |
58 | | - return target_logit - log_z, log_z |
| 67 | + return logprobs, log_z |
59 | 68 |
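The inner vocab loop is the standard online logsumexp recurrence: with running max `m` and running sum `s`, each chunk applies `m_new = max(m, chunk.max())` and `s_new = s * exp(m - m_new) + exp(chunk - m_new).sum()`, so `log Z = m + log s` once all chunks are seen. With the module constants (2048 × 4096, fp32), the per-chunk `logits_chunk` buffer is roughly 32 MiB. A minimal sanity-check sketch against a naive full-logits reference; shapes, seed, and tolerances here are illustrative assumptions:

```python
import torch

torch.manual_seed(0)
n, d, v = 3000, 64, 6000          # n > 2048 exercises the sequence loop
hidden = torch.randn(n, d)
weight = torch.randn(v, d)
bias = torch.randn(v)
targets = torch.randint(0, v, (n,))

logprobs, log_z = _selective_logprob_forward(
    hidden, weight, targets, bias=bias, temperature=1.0, vocab_chunk_size=1024
)

# Naive reference: materialize the full (n, v) logits matrix once.
logits = (hidden @ weight.t() + bias).float()
ref = torch.log_softmax(logits, dim=-1)[torch.arange(n), targets]
torch.testing.assert_close(logprobs, ref, rtol=1e-4, atol=1e-4)
```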
60 | 69 |
61 | 70 | def _selective_logprob_backward(hidden, weight, targets, bias, log_z, grad_logprobs, temperature, vocab_chunk_size): |
62 | | - """Vocab-chunked backward for selective logprob (recomputes logits per chunk for memory efficiency).""" |
| 71 | + """Dual-chunked (sequence × vocab) backward for selective logprob. |
|  72 | +
| 73 | + Recomputes logits per chunk for memory efficiency. |
| 74 | + """ |
63 | 75 | inv_t = 1.0 / temperature |
64 | 76 | n_rows, _ = hidden.shape |
65 | 77 | vocab_size = weight.shape[0] |
66 | 78 | has_bias = bias is not None |
| 79 | + seq_chunk_size = _SELECTIVE_LOGPROB_SEQ_CHUNK_SIZE |
67 | 80 |
68 | 81 | grad_hidden = torch.zeros(hidden.shape, device=hidden.device, dtype=torch.float32) |
69 | 82 | grad_weight = torch.zeros(weight.shape, device=weight.device, dtype=torch.float32) |
70 | 83 | grad_bias = torch.zeros((vocab_size,), device=weight.device, dtype=torch.float32) if has_bias else None |
71 | 84 |
72 | | - mm_buf = torch.empty((n_rows, vocab_chunk_size), device=hidden.device, dtype=hidden.dtype) |
73 | | - logits_buf = torch.empty((n_rows, vocab_chunk_size), device=hidden.device, dtype=torch.float32) |
74 | | - |
75 | 85 | grad_logprobs = grad_logprobs.to(torch.float32) |
76 | | - row_idx = torch.arange(n_rows, device=hidden.device) |
77 | | - |
78 | | - for start in range(0, vocab_size, vocab_chunk_size): |
79 | | - end = min(start + vocab_chunk_size, vocab_size) |
80 | | - chunk_width = end - start |
81 | | - weight_chunk = weight[start:end] |
82 | | - |
83 | | - torch.mm(hidden, weight_chunk.t(), out=mm_buf[:, :chunk_width]) |
84 | | - logits_chunk = logits_buf[:, :chunk_width] |
85 | | - logits_chunk.copy_(mm_buf[:, :chunk_width]) |
86 | | - if has_bias: |
87 | | - logits_chunk.add_(bias[start:end].to(torch.float32)) |
88 | | - logits_chunk.mul_(inv_t) |
89 | | - |
90 | | - probs = torch.exp(logits_chunk - log_z.unsqueeze(-1)) |
91 | | - grad_logits = (-grad_logprobs).unsqueeze(-1) * probs |
92 | | - |
93 | | - in_chunk = (targets >= start) & (targets < end) |
94 | | - local_idx = torch.clamp(targets - start, 0, end - start - 1) |
95 | | - grad_logits[row_idx, local_idx] += grad_logprobs * in_chunk |
96 | | - grad_logits.mul_(inv_t) |
97 | | - |
98 | | - grad_hidden.add_(grad_logits @ weight_chunk.float()) |
99 | | - grad_weight[start:end].add_(grad_logits.t() @ hidden.float()) |
100 | | - if has_bias: |
101 | | - grad_bias[start:end].add_(grad_logits.sum(dim=0)) |
| 86 | + |
| 87 | + for seq_start in range(0, n_rows, seq_chunk_size): |
| 88 | + seq_end = min(seq_start + seq_chunk_size, n_rows) |
| 89 | + hidden_chunk = hidden[seq_start:seq_end] |
| 90 | + targets_chunk = targets[seq_start:seq_end] |
| 91 | + grad_chunk = grad_logprobs[seq_start:seq_end] |
| 92 | + logz_chunk = log_z[seq_start:seq_end] |
| 93 | + row_idx = torch.arange(seq_end - seq_start, device=hidden.device) |
| 94 | + |
| 95 | + for vocab_start in range(0, vocab_size, vocab_chunk_size): |
| 96 | + vocab_end = min(vocab_start + vocab_chunk_size, vocab_size) |
| 97 | + weight_chunk = weight[vocab_start:vocab_end] |
| 98 | + logits_chunk = (hidden_chunk @ weight_chunk.to(hidden.dtype).t()).float() |
| 99 | + if has_bias: |
| 100 | + logits_chunk.add_(bias[vocab_start:vocab_end].to(torch.float32)) |
| 101 | + logits_chunk.mul_(inv_t) |
| 102 | + |
| 103 | + probs = torch.exp(logits_chunk - logz_chunk.unsqueeze(-1)) |
| 104 | + grad_logits = (-grad_chunk).unsqueeze(-1) * probs |
| 105 | + |
| 106 | + in_chunk = (targets_chunk >= vocab_start) & (targets_chunk < vocab_end) |
| 107 | + local_idx = torch.clamp(targets_chunk - vocab_start, 0, vocab_end - vocab_start - 1) |
| 108 | + grad_logits[row_idx, local_idx] += grad_chunk * in_chunk |
| 109 | + grad_logits.mul_(inv_t) |
| 110 | + |
| 111 | + grad_hidden[seq_start:seq_end].add_(grad_logits @ weight_chunk.float()) |
| 112 | + grad_weight[vocab_start:vocab_end].add_(grad_logits.t() @ hidden_chunk.float()) |
| 113 | + if has_bias: |
| 114 | + grad_bias[vocab_start:vocab_end].add_(grad_logits.sum(dim=0)) |
102 | 115 |
103 | 116 | return grad_hidden, grad_weight, grad_bias |
104 | 117 |
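Per chunk, the backward applies `d logprob_i / d logit_{i,v} = (1[v == target_i] - softmax_i[v]) / T`, which is why it only needs `log_z` and the original inputs rather than saved logits. The call site is not shown in this diff, but one plausible way to tie the two helpers together is a custom autograd function; the class below is a hypothetical sketch under that assumption, not part of this patch:

```python
import torch

class _SelectiveLogprob(torch.autograd.Function):
    """Hypothetical glue: differentiable selective logprobs without ever
    materializing the full (N, V) logits tensor."""

    @staticmethod
    def forward(ctx, hidden, weight, targets, bias, temperature):
        logprobs, log_z = _selective_logprob_forward(
            hidden, weight, targets, bias=bias, temperature=temperature,
            vocab_chunk_size=_SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE,
        )
        ctx.save_for_backward(hidden, weight, targets, log_z)
        ctx.bias = bias                  # kept simple for the sketch
        ctx.temperature = temperature
        return logprobs

    @staticmethod
    def backward(ctx, grad_logprobs):
        hidden, weight, targets, log_z = ctx.saved_tensors
        grad_hidden, grad_weight, grad_bias = _selective_logprob_backward(
            hidden, weight, targets, ctx.bias, log_z, grad_logprobs,
            ctx.temperature, _SELECTIVE_LOGPROB_VOCAB_CHUNK_SIZE,
        )
        if grad_bias is not None:
            grad_bias = grad_bias.to(ctx.bias.dtype)
        # One gradient per forward input; targets and temperature get None.
        return (grad_hidden.to(hidden.dtype), grad_weight.to(weight.dtype),
                None, grad_bias, None)
```

Validating such a pair with `torch.autograd.gradcheck` on small double-precision inputs would be the usual follow-up.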