|
46 | 46 | ) |
47 | 47 | from specforge.optimizer import BF16Optimizer |
48 | 48 | from specforge.tracker import Tracker, create_tracker, get_tracker_class |
| 49 | +from specforge.utils import create_draft_config_from_target, get_last_checkpoint |
| 50 | +from specforge.utils import maybe_fetch_remote_config as _maybe_fetch_remote_config |
49 | 51 | from specforge.utils import ( |
50 | | - create_draft_config_from_target, |
51 | | - get_last_checkpoint, |
52 | | - maybe_fetch_remote_config as _maybe_fetch_remote_config, |
53 | 52 | print_args_with_dots, |
54 | 53 | print_on_rank0, |
55 | 54 | print_with_rank, |
56 | 55 | rank_0_priority, |
57 | | - resolve_local_model_path as _resolve_local_model_path, |
58 | | - safe_conversations_generator, |
59 | 56 | ) |
| 57 | +from specforge.utils import resolve_local_model_path as _resolve_local_model_path |
| 58 | +from specforge.utils import safe_conversations_generator |
60 | 59 |
|
61 | 60 |
|
62 | 61 | def print_cuda_memory_debug(label: str) -> None: |
@@ -1165,7 +1164,9 @@ def update_training_metric_accumulators( |
1165 | 1164 | nonlocal metric_loss_weighted_sums, metric_loss_denom_sums |
1166 | 1165 | with torch.no_grad(): |
1167 | 1166 | if metric_correct_sums is None: |
1168 | | - metric_correct_sums = [correct.detach().clone() for correct in acc_corrects] |
| 1167 | + metric_correct_sums = [ |
| 1168 | + correct.detach().clone() for correct in acc_corrects |
| 1169 | + ] |
1169 | 1170 | metric_denom_sums = [denom.detach().clone() for denom in acc_denoms] |
1170 | 1171 | metric_loss_weighted_sums = [ |
1171 | 1172 | loss.detach() * denom.detach() |
@@ -1315,8 +1316,7 @@ def maybe_evaluate(epoch: int) -> None: |
1315 | 1316 | eval_acces[i] + [eval_acc[i]] for i in range(len(eval_acc)) |
1316 | 1317 | ] |
1317 | 1318 | eval_plosses = [ |
1318 | | - eval_plosses[i] + [eval_ploss[i]] |
1319 | | - for i in range(len(eval_ploss)) |
| 1319 | + eval_plosses[i] + [eval_ploss[i]] for i in range(len(eval_ploss)) |
1320 | 1320 | ] |
1321 | 1321 |
|
1322 | 1322 | eval_acces = [torch.stack(acc).mean() for acc in eval_acces] |
@@ -1467,9 +1467,7 @@ def fill_prefetch_queue(pending_current: int) -> None: |
1467 | 1467 | ): |
1468 | 1468 | global_step += 1 |
1469 | 1469 | maybe_profile_step() |
1470 | | - if finish_eagle3_training_step( |
1471 | | - epoch, progress_bar, *outputs |
1472 | | - ): |
| 1470 | + if finish_eagle3_training_step(epoch, progress_bar, *outputs): |
1473 | 1471 | break |
1474 | 1472 | else: |
1475 | 1473 | if train_one_eagle3_batch(epoch, progress_bar, data): |
|
0 commit comments