
Commit cab4413

[DFlash] align train recipe with paper

1 parent 471afa9

1 file changed: 38 additions & 5 deletions

examples/discrete_diffusion/train_dflash.py
@@ -52,10 +52,13 @@ class TrainConfig:
     weight_decay: float
     lr_scheduler: str
     lr_warmup_steps: int
+    lr_warmup_ratio: float
+    max_grad_norm: float

     max_length: int
     block_size: int
     mask_token: str
+    loss_decay_gamma: float


 def parse_args() -> TrainConfig:
@@ -75,18 +78,32 @@ def parse_args() -> TrainConfig:

     parser.add_argument("--per_device_train_batch_size", type=int, default=2)
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-    parser.add_argument("--learning_rate", type=float, default=2e-5)
+    parser.add_argument("--learning_rate", type=float, default=6e-4)
     parser.add_argument("--weight_decay", type=float, default=0.0)
     parser.add_argument(
         "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"]
     )
-    parser.add_argument("--lr_warmup_steps", type=int, default=100)
+    parser.add_argument(
+        "--lr_warmup_steps",
+        type=int,
+        default=0,
+        help="Absolute warmup steps. Ignored when --lr_warmup_ratio > 0 (default).",
+    )
+    parser.add_argument("--lr_warmup_ratio", type=float, default=0.04)
+    parser.add_argument("--max_grad_norm", type=float, default=1.0)

-    parser.add_argument("--max_length", type=int, default=512)
+    parser.add_argument("--max_length", type=int, default=3072)
     parser.add_argument(
         "--block_size", type=int, default=0, help="Override draft block size (0 uses the model config)."
     )
     parser.add_argument("--mask_token", type=str, default="<|MASK|>")
+    parser.add_argument(
+        "--loss_decay_gamma",
+        type=float,
+        default=0.0,
+        help="Per-position loss decay γ for w_k = exp(-(k-1)/γ). 0 selects the paper default for the "
+        "draft block size (γ=7 for block 16, γ=5 for block 10, γ=4 for block 8, else block_size/2).",
+    )

     args = parser.parse_args()
     return TrainConfig(**vars(args))
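
For a quick sanity check of the recipe change, a trimmed-down parser with only the touched flags (an illustrative sketch, not the script's full argument list) shows what an empty command line now resolves to:

import argparse

# Hypothetical, trimmed-down parser mirroring only the defaults changed above.
parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=6e-4)
parser.add_argument("--lr_warmup_steps", type=int, default=0)
parser.add_argument("--lr_warmup_ratio", type=float, default=0.04)
parser.add_argument("--max_grad_norm", type=float, default=1.0)
parser.add_argument("--max_length", type=int, default=3072)
parser.add_argument("--loss_decay_gamma", type=float, default=0.0)

print(vars(parser.parse_args([])))
# {'learning_rate': 0.0006, 'lr_warmup_steps': 0, 'lr_warmup_ratio': 0.04,
#  'max_grad_norm': 1.0, 'max_length': 3072, 'loss_decay_gamma': 0.0}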
@@ -177,6 +194,14 @@ def main():
     if block_size < 2:
         raise ValueError("`block_size` must be at least 2 for DFlash training.")

+    # Eq. 4 in the DFlash paper: w_k = exp(-(k-1)/γ) over predicted positions k=1..block_size-1.
+    # Defaults from Appendix A.1.
+    if cfg.loss_decay_gamma > 0.0:
+        loss_gamma = float(cfg.loss_decay_gamma)
+    else:
+        loss_gamma = {16: 7.0, 10: 5.0, 8: 4.0}.get(block_size, max(2.0, block_size / 2.0))
+    pos_weights = torch.exp(-torch.arange(block_size - 1, dtype=torch.float32) / loss_gamma)
+
     layer_ids = getattr(draft_model, "target_layer_ids", None)
     if layer_ids is None:
         cfg_draft = getattr(draft_model, "config", None)
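
A standalone sketch of the Eq. 4 schedule added above, using the paper default of block_size=16 (γ=7); the printed values are illustrative:

import torch

# w_k = exp(-(k-1)/γ) for predicted positions k = 1..block_size-1 (Eq. 4).
block_size, loss_gamma = 16, 7.0
pos_weights = torch.exp(-torch.arange(block_size - 1, dtype=torch.float32) / loss_gamma)
print(pos_weights[:4])
# tensor([1.0000, 0.8669, 0.7515, 0.6514]): the first draft position gets full
# weight, and later, harder-to-predict positions contribute exponentially less.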
@@ -208,10 +233,14 @@ def main():
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
     num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)

+    if cfg.lr_warmup_ratio > 0.0:
+        num_warmup_steps = int(cfg.lr_warmup_ratio * cfg.max_train_steps)
+    else:
+        num_warmup_steps = cfg.lr_warmup_steps
     lr_scheduler = get_scheduler(
         name=cfg.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=cfg.lr_warmup_steps,
+        num_warmup_steps=num_warmup_steps,
         num_training_steps=cfg.max_train_steps,
     )
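
A worked example of the warmup resolution above, assuming a hypothetical budget of 50,000 update steps:

# With the new defaults (lr_warmup_ratio=0.04, lr_warmup_steps=0), warmup is
# derived from the total step budget; max_train_steps here is an assumption.
max_train_steps = 50_000
lr_warmup_ratio, lr_warmup_steps = 0.04, 0

if lr_warmup_ratio > 0.0:
    num_warmup_steps = int(lr_warmup_ratio * max_train_steps)
else:
    num_warmup_steps = lr_warmup_steps
print(num_warmup_steps)  # 2000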

@@ -220,6 +249,7 @@ def main():
     )
     input_embeddings = get_target_input_embeddings(target_model)
     output_embeddings = get_target_output_embeddings(target_model)
+    pos_weights = pos_weights.to(accelerator.device)

     global_step = 0
     draft_model.train()
@@ -279,9 +309,12 @@ def main():
            vocab_size = logits.shape[-1]
            loss = F.cross_entropy(logits.view(-1, vocab_size), block_targets.reshape(-1), reduction="none")
            loss = loss.view(block_targets.shape[0], -1)
-            loss = (loss * block_mask.to(loss.dtype)).sum() / block_mask.sum().clamp_min(1)
+            weights = pos_weights.to(loss.dtype)[None, :].expand_as(loss) * block_mask.to(loss.dtype)
+            loss = (loss * weights).sum() / weights.sum().clamp_min(1)

            accelerator.backward(loss)
+            if accelerator.sync_gradients and cfg.max_grad_norm > 0:
+                accelerator.clip_grad_norm_(draft_model.parameters(), cfg.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)
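
A minimal self-contained sketch of the position-weighted masked loss above; the shapes (batch 2, block_size 16, i.e. 15 predicted positions) and random tensors are stand-ins, not real model outputs:

import torch
import torch.nn.functional as F

batch, positions, vocab_size = 2, 15, 32           # block_size=16 → 15 predictions
logits = torch.randn(batch, positions, vocab_size)
block_targets = torch.randint(0, vocab_size, (batch, positions))
block_mask = torch.ones(batch, positions)          # 1 = position counts toward the loss
pos_weights = torch.exp(-torch.arange(positions, dtype=torch.float32) / 7.0)

# Per-token cross entropy, then weight by position decay and mask before reducing.
loss = F.cross_entropy(logits.view(-1, vocab_size), block_targets.reshape(-1), reduction="none")
loss = loss.view(batch, -1)
weights = pos_weights.to(loss.dtype)[None, :].expand_as(loss) * block_mask.to(loss.dtype)
loss = (loss * weights).sum() / weights.sum().clamp_min(1)
print(loss)  # scalar; dividing by weights.sum() keeps the loss scale comparable across γ.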
