Skip to content

Commit b1aa044

Browse files
committed
log optimize: merge acc/ploss log with base log && add remaining_time log
1 parent 43e29f8 commit b1aa044

1 file changed

Lines changed: 72 additions & 9 deletions

File tree

angelslim/compressor/speculative/train/trainer/eagle3_trainer.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16+
import time
1617
from abc import ABC, abstractmethod
1718
from typing import Dict, List, Optional, Tuple
1819

@@ -44,6 +45,71 @@ def __init__(self, draft_model: nn.Module, length: int, **kwargs):
4445
"""
4546
super().__init__(model=draft_model, **kwargs)
4647
self.length = length
48+
self._train_start_time = None
49+
self._pending_log: dict = (
50+
{}
51+
) # cache acc/ploss log for merging with base Trainer's loss log
52+
self._pending_log_count: int = 0 # accumulated batch count for averaging the cached log
53+
54+
def train(self, *args, **kwargs):
    """Run training after stamping the wall-clock start time.

    The timestamp is stored on ``self._train_start_time``; ``log`` reads it
    to estimate the remaining training time. All positional and keyword
    arguments are forwarded unchanged to the parent class's ``train`` and
    its return value is handed back to the caller.
    """
    started_at = time.time()
    self._train_start_time = started_at
    outcome = super().train(*args, **kwargs)
    return outcome
58+
59+
def log(self, logs: dict, start_time: Optional[float] = None) -> None:
    """Merge the cached acc/ploss metrics into the base Trainer's loss log.

    When ``logs`` carries a ``"loss"`` entry and acc/ploss values have been
    cached in ``self._pending_log``, a single merged record is built and
    forwarded to the parent ``log``. The record contains, in order: step,
    epoch, loss, grad_norm, lr, any other entries of ``logs`` (unchanged),
    the averaged acc/ploss values and an estimated remaining time. The
    cache is reset before forwarding. In every other case ``logs`` is
    forwarded untouched.

    Args:
        logs: Metrics dictionary produced by the training loop.
        start_time: Passed through unchanged to the parent ``log``.

    NOTE(review): the formatted entries of the merged record are strings,
    not numbers; confirm that every consumer of the parent ``log`` output
    accepts string values.
    """
    if "loss" not in logs or not self._pending_log:
        super().log(logs, start_time)
        return

    # Average the accumulated sums over the number of cached batches.
    # float() normalises scalar-like values (e.g. 0-d tensors) that may not
    # implement __round__; plain Python floats are unaffected.
    batches = max(self._pending_log_count, 1)
    acc_ploss = {
        key: f"{round(float(total) / batches, 3):.4f}"
        for key, total in self._pending_log.items()
    }

    state = self.state
    merged = {}

    # step
    if state is not None:
        merged["step"] = f"{state.global_step:>5}"

    # epoch
    if "epoch" in logs:
        merged["epoch"] = f"{logs['epoch']:.4f}"

    # loss (presence is guaranteed by the guard at the top)
    merged["loss"] = f"{logs['loss']:.6f}"

    # grad_norm (6 decimal places)
    if "grad_norm" in logs:
        merged["grad_norm"] = f"{logs['grad_norm']:.6f}"

    # learning_rate (scientific notation, 6 decimal places), logged as "lr"
    if "learning_rate" in logs:
        merged["lr"] = f"{logs['learning_rate']:.6e}"

    # Keep every other metric instead of silently dropping it; entries that
    # were already written above (e.g. "step") take precedence.
    formatted_keys = {"epoch", "loss", "grad_norm", "learning_rate"}
    for key, value in logs.items():
        if key not in formatted_keys:
            merged.setdefault(key, value)

    # acc/ploss
    merged.update(acc_ploss)

    # remaining_time
    # NOTE(review): elapsed time is measured from the latest train() call
    # while global_step is read from the trainer state; if the state can
    # start above zero (e.g. a resumed run) the estimate will be too low.
    if state is not None and self._train_start_time is not None:
        global_step = state.global_step
        max_steps = state.max_steps
        if global_step > 0 and max_steps > 0:
            elapsed = time.time() - self._train_start_time
            time_per_step = elapsed / global_step
            # Clamp at zero so overshooting max_steps never yields a
            # negative duration.
            remaining_seconds = max(0, int(time_per_step * (max_steps - global_step)))
            hours, remainder = divmod(remaining_seconds, 3600)
            minutes, seconds = divmod(remainder, 60)
            merged["remaining_time"] = f"{hours:02d}h:{minutes:02d}m:{seconds:02d}s"

    # Reset the cache before delegating, so the next window starts empty.
    self._pending_log.clear()
    self._pending_log_count = 0
    super().log(merged, start_time)
47113

48114
@property
49115
def draft_model(self) -> nn.Module:
@@ -214,15 +280,12 @@ def draft_model_training_time_test(
214280
ploss_weight = [0.8**i for i in range(len(plosses))]
215281
ploss = sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))])
216282

217-
log = {f"{log_prefix}/acc_{i}": round(float(acces[i]), 3) for i in range(len(acces))}
218-
log.update(
219-
{
220-
f"{log_prefix}/ploss_{i}": round(float(plosses[i].item()), 3)
221-
for i in range(len(plosses))
222-
}
223-
)
224-
self.log(log)
225-
283+
log = {f"{log_prefix}/acc_{i}": acces[i] for i in range(len(acces))}
284+
log.update({f"{log_prefix}/ploss_{i}": plosses[i].item() for i in range(len(plosses))})
285+
# Cache log for merging with base Trainer's loss log
286+
for k, v in log.items():
287+
self._pending_log[k] = self._pending_log.get(k, 0.0) + v
288+
self._pending_log_count += 1
226289
# Step 9: Return loss
227290
return ploss
228291

0 commit comments

Comments
 (0)