Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions scripts/train_dflash.py
Original file line number Diff line number Diff line change
@@ -117,6 +117,11 @@ def parse_args():
training_group.add_argument("--accumulation-steps", type=int, default=1)
training_group.add_argument("--seed", type=int, default=42)
training_group.add_argument("--resume", action="store_true")
training_group.add_argument(
"--log-grad-norm",
action="store_true",
help="Log train/pre_clip_grad_norm during training.",
)

output_group = parser.add_argument_group("output")
output_group.add_argument("--output-dir", type=str, required=True)
@@ -328,6 +333,9 @@ def record_metrics(

if mode == "train" and optimizer is not None:
logdict["train/lr"] = optimizer.get_learning_rate()
grad_norm = optimizer.get_last_grad_norm()
if grad_norm is not None:
logdict["train/pre_clip_grad_norm"] = grad_norm

logdict[f"{mode}/loss"] = loss
logdict[f"{mode}/accuracy"] = accuracy
@@ -458,6 +466,7 @@ def main():
max_grad_norm=args.max_grad_norm,
warmup_ratio=args.warmup_ratio,
total_steps=total_steps,
log_grad_norm=args.log_grad_norm,
)

if resume_state is not None:
Expand Down
11 changes: 10 additions & 1 deletion scripts/train_eagle3.py
Original file line number Diff line number Diff line change
@@ -165,6 +165,11 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
)
training_group.add_argument("--seed", type=int, default=0)
training_group.add_argument("--draft-accumulation-steps", type=int, default=1)
training_group.add_argument(
"--log-grad-norm",
action="store_true",
help="Log train/pre_clip_grad_norm during training.",
)

# data processing type
optimization_group = parser.add_argument_group("optimization")
@@ -181,7 +186,7 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
"--attention-backend",
type=str,
default="flex_attention",
help="The attention backend for the draft model",
help="The attention backend for the draft model (e.g. flex_attention, fa, usp)",
)

# other args
@@ -693,6 +698,9 @@ def record_metrcs(

if mode == "train" and optimizer is not None:
logdict["train/lr"] = optimizer.get_learning_rate()
grad_norm = optimizer.get_last_grad_norm()
if grad_norm is not None:
logdict["train/pre_clip_grad_norm"] = grad_norm

accuracies = torch.stack(accuracies)
plosses = torch.stack(plosses)
@@ -826,6 +834,7 @@ def main():
max_grad_norm=args.max_grad_norm,
warmup_ratio=args.warmup_ratio,
total_steps=args.total_steps,
log_grad_norm=args.log_grad_norm,
)
print_with_rank("Initialized optimizer and scheduler")

Expand Down
Loading
Loading