forked from ML-GSAI/LLaDA
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathgenerate.py
More file actions
139 lines (115 loc) · 4.67 KB
/
generate.py
File metadata and controls
139 lines (115 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
import numpy as np
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
def add_gumbel_noise(logits, temperature):
    """
    Perturb logits so that an argmax over the result draws a Gumbel-max sample.

    The returned tensor is exp(logits) / (-log(u)) ** temperature with
    u ~ Uniform[0, 1). Its argmax matches argmax(logits + temperature * g) for
    Gumbel noise g. The computation is done in float64 because, per
    arXiv:2409.02908, low-precision Gumbel-max sampling for masked diffusion
    models improves perplexity but hurts generation quality.

    Args:
        logits: Raw model scores.
        temperature: Sampling temperature. ``0`` disables sampling and the
            input tensor itself is returned (greedy decoding).

    Returns:
        ``logits`` unchanged when ``temperature == 0``, otherwise a float64
        tensor of the same shape suited for argmax (not a log-probability).
    """
    if temperature != 0:
        logits64 = logits.to(torch.float64)
        uniform = torch.rand_like(logits64, dtype=torch.float64)
        return torch.exp(logits64) / torch.pow(-torch.log(uniform), temperature)
    return logits
def get_num_transfer_tokens(mask_index, steps):
    """
    Precompute how many masked tokens to unmask at each reverse-process step.

    LLaDA discretizes [0, 1] into ``steps`` equal intervals and uses a linear
    noise schedule (Eq. (8) of the paper), so every step should transition
    about the same number of tokens. Each row's masked-token count is split
    evenly across the steps. The leftover ``count % steps`` tokens go one
    apiece to the earliest steps.

    Args:
        mask_index: 2-D tensor marking masked positions, one row per sequence.
        steps: Number of denoising steps to spread the tokens over.

    Returns:
        Tensor of shape ``(rows, steps)`` whose row sums equal each row's
        masked-token count.
    """
    counts = mask_index.sum(dim=1, keepdim=True)  # (rows, 1)
    per_step = counts // steps
    leftover = counts % steps
    # Step s receives one extra token exactly when s < leftover for that row.
    step_ids = torch.arange(steps, device=mask_index.device).unsqueeze(0)  # (1, steps)
    extra = (step_ids < leftover).to(torch.int64)  # (rows, steps)
    return per_step + extra
_REMASKING_STRATEGIES = ("low_confidence", "random")


def _model_logits(model, x, prompt_index, cfg_scale, mask_id):
    """Run the mask predictor on ``x``, applying classifier-free guidance when ``cfg_scale > 0``."""
    if cfg_scale > 0.0:
        # Unconditional branch: the same sequences with every prompt token masked.
        un_x = x.clone()
        un_x[prompt_index] = mask_id
        logits = model(torch.cat([x, un_x], dim=0)).logits
        cond_logits, un_logits = torch.chunk(logits, 2, dim=0)
        return un_logits + (cfg_scale + 1) * (cond_logits - un_logits)
    return model(x).logits


def _token_confidence(logits, x0, remasking):
    """Score each predicted token of ``x0``; higher scores are unmasked first."""
    if remasking == "low_confidence":
        # Probability (under the un-noised logits) of the token that was chosen.
        p = F.softmax(logits.to(torch.float64), dim=-1)
        return torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
    if remasking == "random":
        return torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    raise NotImplementedError(remasking)


def _print_decoded(tokenizer, x):
    """Print the decoded text of every row of ``x``."""
    for row in x:
        print(tokenizer.decode(row))


@torch.no_grad()
def generate(
    model,
    prompt,
    steps=128,
    gen_length=128,
    block_length=128,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    mask_id=126336,
    verbose=False,
    tokenizer=None,
):
    """
    Generate ``gen_length`` tokens after ``prompt`` by iterative unmasking.

    The generated region starts fully masked and is split into blocks of
    ``block_length`` tokens that are filled left to right. Within each block,
    ``steps // num_blocks`` denoising steps are run. Each step predicts every
    position, then commits the highest-scoring predictions among the still
    masked positions of the current block. The per-step counts come from
    ``get_num_transfer_tokens``.

    Args:
        model: Mask predictor; called as ``model(x).logits`` and must expose
            ``model.device``.
        prompt: Integer tensor of shape (B, L). Batches are supported, but all
            rows share the same prompt length. NOTE(review): no padding or
            attention mask is applied, so rows are assumed to be unpadded.
        steps: Total sampling steps, less than or equal to gen_length. Must be
            a multiple of ``gen_length // block_length``.
        gen_length: Generated answer length.
        block_length: Block length, less than or equal to gen_length. If less
            than gen_length, semi-autoregressive remasking is used. Must divide
            gen_length.
        temperature: Categorical sampling temperature (0 means greedy).
        cfg_scale: Unsupervised classifier-free guidance scale (0 disables it).
        remasking: Remasking strategy, either 'low_confidence' or 'random'.
        mask_id: The token id of [MASK]; defaults to 126336.
        verbose: When True and a tokenizer is given, print the decoded
            sequences before generation and after every step.
        tokenizer: Used only for verbose printing.

    Returns:
        Long tensor of shape (B, L + gen_length) holding the prompt followed by
        the generated tokens.

    Raises:
        NotImplementedError: If ``remasking`` is not a known strategy. This is
            checked before any model call.
        AssertionError: If the lengths/steps do not divide evenly.
    """
    # Fail fast instead of after the first (expensive) forward pass.
    if remasking not in _REMASKING_STRATEGIES:
        raise NotImplementedError(remasking)

    batch_size, prompt_len = prompt.shape[0], prompt.shape[1]
    x = torch.full(
        (batch_size, prompt_len + gen_length),
        mask_id,
        dtype=torch.long,
        device=model.device,
    )
    x[:, :prompt_len] = prompt.to(x.device)
    prompt_index = x != mask_id

    assert gen_length % block_length == 0, "gen_length must be a multiple of block_length"
    num_blocks = gen_length // block_length
    assert steps % num_blocks == 0, "steps must be a multiple of gen_length // block_length"
    steps_per_block = steps // num_blocks

    show = verbose and tokenizer is not None
    if show:
        _print_decoded(tokenizer, x)

    for num_block in range(num_blocks):
        block_start = prompt_len + num_block * block_length
        block_end = block_start + block_length
        block_mask_index = x[:, block_start:block_end] == mask_id
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        for i in range(steps_per_block):
            mask_index = x == mask_id
            logits = _model_logits(model, x, prompt_index, cfg_scale, mask_id)
            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1)  # (B, L + gen_length)
            x0_p = _token_confidence(logits, x0, remasking)

            # Never unmask positions beyond the current block.
            x0_p[:, block_end:] = -np.inf
            # Already-decoded positions keep their token and can't be selected.
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            transfer_index = torch.zeros_like(x0, dtype=torch.bool)
            for j in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[j], k=int(num_transfer_tokens[j, i]))
                transfer_index[j, select_index] = True
            x[transfer_index] = x0[transfer_index]

            if show:
                _print_decoded(tokenizer, x)
    return x