|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import os |
| 16 | +import time |
16 | 17 | from abc import ABC, abstractmethod |
17 | 18 | from typing import Dict, List, Optional, Tuple |
18 | 19 |
|
@@ -44,6 +45,71 @@ def __init__(self, draft_model: nn.Module, length: int, **kwargs): |
44 | 45 | """ |
45 | 46 | super().__init__(model=draft_model, **kwargs) |
46 | 47 | self.length = length |
| 48 | + self._train_start_time = None |
| 49 | + self._pending_log: dict = ( |
| 50 | + {} |
| 51 | + ) # cache acc/ploss log for merging with base Trainer's loss log |
| 52 | + self._pending_log_count: int = 0 # accumulated batch count for averaging the cached log |
| 53 | + |
| 54 | + def train(self, *args, **kwargs): |
| 55 | + """Override train method to record training start time for estimating remaining time.""" |
| 56 | + self._train_start_time = time.time() |
| 57 | + return super().train(*args, **kwargs) |
| 58 | + |
| 59 | + def log(self, logs: dict, start_time: Optional[float] = None) -> None: |
| 60 | + """ |
| 61 | + rewrite log method to merge acc/ploss log with base Trainer's loss log. |
| 62 | + """ |
| 63 | + if "loss" in logs and self._pending_log: |
| 64 | + # merge cached acc/ploss data (average) |
| 65 | + count = max(self._pending_log_count, 1) |
| 66 | + acc_ploss = {k: f"{round(v / count, 3):.4f}" for k, v in self._pending_log.items()} |
| 67 | + merged = {} |
| 68 | + |
| 69 | + # step |
| 70 | + if self.state is not None: |
| 71 | + global_step = self.state.global_step |
| 72 | + max_steps = self.state.max_steps |
| 73 | + merged["step"] = f"{global_step:>5}" |
| 74 | + |
| 75 | + # epoch |
| 76 | + if "epoch" in logs: |
| 77 | + merged["epoch"] = f"{logs['epoch']:.4f}" |
| 78 | + |
| 79 | + # loss |
| 80 | + if "loss" in logs: |
| 81 | + merged["loss"] = f"{logs['loss']:.6f}" |
| 82 | + |
| 83 | + # grad_norm (6 decimal places) |
| 84 | + if "grad_norm" in logs: |
| 85 | + merged["grad_norm"] = f"{logs['grad_norm']:.6f}" |
| 86 | + |
| 87 | + # learning_rate (scientific notation, 6 decimal places) |
| 88 | + if "learning_rate" in logs: |
| 89 | + merged["lr"] = f"{logs['learning_rate']:.6e}" |
| 90 | + |
| 91 | + # acc/ploss |
| 92 | + merged.update(acc_ploss) |
| 93 | + |
| 94 | + # remaining_time |
| 95 | + if ( |
| 96 | + self.state is not None |
| 97 | + and self._train_start_time is not None |
| 98 | + and global_step > 0 |
| 99 | + and max_steps > 0 |
| 100 | + ): |
| 101 | + elapsed = time.time() - self._train_start_time |
| 102 | + time_per_step = elapsed / global_step |
| 103 | + remaining_seconds = int(time_per_step * (max_steps - global_step)) |
| 104 | + hours, remainder = divmod(remaining_seconds, 3600) |
| 105 | + minutes, seconds = divmod(remainder, 60) |
| 106 | + merged["remaining_time"] = f"{hours:02d}h:{minutes:02d}m:{seconds:02d}s" |
| 107 | + |
| 108 | + self._pending_log.clear() |
| 109 | + self._pending_log_count = 0 |
| 110 | + super().log(merged, start_time) |
| 111 | + else: |
| 112 | + super().log(logs, start_time) |
47 | 113 |
|
48 | 114 | @property |
49 | 115 | def draft_model(self) -> nn.Module: |
@@ -214,15 +280,12 @@ def draft_model_training_time_test( |
214 | 280 | ploss_weight = [0.8**i for i in range(len(plosses))] |
215 | 281 | ploss = sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))]) |
216 | 282 |
|
217 | | - log = {f"{log_prefix}/acc_{i}": round(float(acces[i]), 3) for i in range(len(acces))} |
218 | | - log.update( |
219 | | - { |
220 | | - f"{log_prefix}/ploss_{i}": round(float(plosses[i].item()), 3) |
221 | | - for i in range(len(plosses)) |
222 | | - } |
223 | | - ) |
224 | | - self.log(log) |
225 | | - |
| 283 | + log = {f"{log_prefix}/acc_{i}": acces[i] for i in range(len(acces))} |
| 284 | + log.update({f"{log_prefix}/ploss_{i}": plosses[i].item() for i in range(len(plosses))}) |
| 285 | + # Cache log for merging with base Trainer's loss log |
| 286 | + for k, v in log.items(): |
| 287 | + self._pending_log[k] = self._pending_log.get(k, 0.0) + v |
| 288 | + self._pending_log_count += 1 |
226 | 289 | # Step 9: Return loss |
227 | 290 | return ploss |
228 | 291 |
|
|
0 commit comments