|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import os |
| 16 | +import time |
16 | 17 | from abc import ABC, abstractmethod |
17 | 18 | from typing import Dict, List, Optional, Tuple |
18 | 19 |
|
@@ -44,6 +45,67 @@ def __init__(self, draft_model: nn.Module, length: int, **kwargs): |
44 | 45 | """ |
45 | 46 | super().__init__(model=draft_model, **kwargs) |
46 | 47 | self.length = length |
| 48 | + self._train_start_time = None |
| 49 | + self._pending_log: dict = ( |
| 50 | + {} |
| 51 | + ) # cache acc/ploss log for merging with base Trainer's loss log |
| 52 | + self._pending_log_count: int = 0 # accumulated batch count for averaging the cached log |
| 53 | + |
| 54 | + def train(self, *args, **kwargs): |
| 55 | + """Override train method to record training start time for estimating remaining time.""" |
| 56 | + self._train_start_time = time.time() |
| 57 | + return super().train(*args, **kwargs) |
| 58 | + |
| 59 | + def log(self, logs: dict, start_time: Optional[float] = None) -> None: |
| 60 | + """ |
| 61 | + rewrite log method to merge acc/ploss log with base Trainer's loss log. |
| 62 | + """ |
| 63 | + if "loss" in logs and self._pending_log: |
| 64 | + # merge cached acc/ploss data (average) |
| 65 | + count = max(self._pending_log_count, 1) |
| 66 | + acc_ploss = {k: v / count for k, v in self._pending_log.items()} |
| 67 | + merged = {} |
| 68 | + |
| 69 | + # step |
| 70 | + max_steps = 0 |
| 71 | + if self.state is not None: |
| 72 | + global_step = self.state.global_step |
| 73 | + max_steps = self.state.max_steps |
| 74 | + merged["step"] = global_step |
| 75 | + |
| 76 | + # epoch |
| 77 | + if "epoch" in logs: |
| 78 | + merged["epoch"] = logs["epoch"] |
| 79 | + if "loss" in logs: |
| 80 | + merged["loss"] = logs["loss"] |
| 81 | + if "grad_norm" in logs: |
| 82 | + merged["grad_norm"] = logs["grad_norm"] |
| 83 | + |
| 84 | + if "learning_rate" in logs: |
| 85 | + merged["lr"] = logs["learning_rate"] |
| 86 | + |
| 87 | + # acc/ploss |
| 88 | + merged.update(acc_ploss) |
| 89 | + |
| 90 | + # remaining_time |
| 91 | + if ( |
| 92 | + self.state is not None |
| 93 | + and self._train_start_time is not None |
| 94 | + and global_step > 0 |
| 95 | + and max_steps > 0 |
| 96 | + ): |
| 97 | + elapsed = time.time() - self._train_start_time |
| 98 | + time_per_step = elapsed / global_step |
| 99 | + remaining_seconds = int(time_per_step * (max_steps - global_step)) |
| 100 | + hours, remainder = divmod(remaining_seconds, 3600) |
| 101 | + minutes, seconds = divmod(remainder, 60) |
| 102 | + merged["remaining_time"] = f"{hours:02d}h:{minutes:02d}m:{seconds:02d}s" |
| 103 | + |
| 104 | + self._pending_log.clear() |
| 105 | + self._pending_log_count = 0 |
| 106 | + super().log(merged, start_time) |
| 107 | + else: |
| 108 | + super().log(logs, start_time) |
47 | 109 |
|
48 | 110 | @property |
49 | 111 | def draft_model(self) -> nn.Module: |
@@ -131,7 +193,11 @@ def prepare_attention_mask_and_position_ids( |
131 | 193 | position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device) |
132 | 194 | position_ids = position_ids.unsqueeze(0).view(-1, seq_length) |
133 | 195 | else: |
134 | | - position_ids = position_ids.view(-1, seq_length).long() |
| 196 | + if position_ids.ndim == 3: |
| 197 | + # MRoPE format: (3, batch, seq_len), keep as-is |
| 198 | + position_ids = position_ids.long() |
| 199 | + else: |
| 200 | + position_ids = position_ids.view(-1, seq_length).long() |
135 | 201 |
|
136 | 202 | if attention_mask is None: |
137 | 203 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.bool, device=device) |
@@ -210,15 +276,12 @@ def draft_model_training_time_test( |
210 | 276 | ploss_weight = [0.8**i for i in range(len(plosses))] |
211 | 277 | ploss = sum([ploss_weight[i] * plosses[i] for i in range(len(plosses))]) |
212 | 278 |
|
213 | | - log = {f"{log_prefix}/acc_{i}": round(float(acces[i]), 3) for i in range(len(acces))} |
214 | | - log.update( |
215 | | - { |
216 | | - f"{log_prefix}/ploss_{i}": round(float(plosses[i].item()), 3) |
217 | | - for i in range(len(plosses)) |
218 | | - } |
219 | | - ) |
220 | | - self.log(log) |
221 | | - |
| 279 | + log = {f"{log_prefix}/acc_{i}": acces[i] for i in range(len(acces))} |
| 280 | + log.update({f"{log_prefix}/ploss_{i}": plosses[i].item() for i in range(len(plosses))}) |
| 281 | + # Cache log for merging with base Trainer's loss log |
| 282 | + for k, v in log.items(): |
| 283 | + self._pending_log[k] = self._pending_log.get(k, 0.0) + v |
| 284 | + self._pending_log_count += 1 |
222 | 285 | # Step 9: Return loss |
223 | 286 | return ploss |
224 | 287 |
|
|
0 commit comments