Skip to content

Commit 13b4340

Browse files
committed
format
1 parent 174ff39 commit 13b4340

1 file changed

Lines changed: 9 additions & 11 deletions

File tree

scripts/train_eagle3.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,16 @@
4646
)
4747
from specforge.optimizer import BF16Optimizer
4848
from specforge.tracker import Tracker, create_tracker, get_tracker_class
49+
from specforge.utils import create_draft_config_from_target, get_last_checkpoint
50+
from specforge.utils import maybe_fetch_remote_config as _maybe_fetch_remote_config
4951
from specforge.utils import (
50-
create_draft_config_from_target,
51-
get_last_checkpoint,
52-
maybe_fetch_remote_config as _maybe_fetch_remote_config,
5352
print_args_with_dots,
5453
print_on_rank0,
5554
print_with_rank,
5655
rank_0_priority,
57-
resolve_local_model_path as _resolve_local_model_path,
58-
safe_conversations_generator,
5956
)
57+
from specforge.utils import resolve_local_model_path as _resolve_local_model_path
58+
from specforge.utils import safe_conversations_generator
6059

6160

6261
def print_cuda_memory_debug(label: str) -> None:
@@ -1165,7 +1164,9 @@ def update_training_metric_accumulators(
11651164
nonlocal metric_loss_weighted_sums, metric_loss_denom_sums
11661165
with torch.no_grad():
11671166
if metric_correct_sums is None:
1168-
metric_correct_sums = [correct.detach().clone() for correct in acc_corrects]
1167+
metric_correct_sums = [
1168+
correct.detach().clone() for correct in acc_corrects
1169+
]
11691170
metric_denom_sums = [denom.detach().clone() for denom in acc_denoms]
11701171
metric_loss_weighted_sums = [
11711172
loss.detach() * denom.detach()
@@ -1315,8 +1316,7 @@ def maybe_evaluate(epoch: int) -> None:
13151316
eval_acces[i] + [eval_acc[i]] for i in range(len(eval_acc))
13161317
]
13171318
eval_plosses = [
1318-
eval_plosses[i] + [eval_ploss[i]]
1319-
for i in range(len(eval_ploss))
1319+
eval_plosses[i] + [eval_ploss[i]] for i in range(len(eval_ploss))
13201320
]
13211321

13221322
eval_acces = [torch.stack(acc).mean() for acc in eval_acces]
@@ -1467,9 +1467,7 @@ def fill_prefetch_queue(pending_current: int) -> None:
14671467
):
14681468
global_step += 1
14691469
maybe_profile_step()
1470-
if finish_eagle3_training_step(
1471-
epoch, progress_bar, *outputs
1472-
):
1470+
if finish_eagle3_training_step(epoch, progress_bar, *outputs):
14731471
break
14741472
else:
14751473
if train_one_eagle3_batch(epoch, progress_bar, data):

0 commit comments

Comments
 (0)