@@ -102,6 +102,7 @@ def __init__(
102102 self ._inference_pipeline = inference_pipeline
103103
104104 self ._global_step = 0
105+ self ._data_epoch = 0
105106 self ._wandb_run = None
106107
107108 set_seed (config .seed )
@@ -671,30 +672,6 @@ def _training_step(self, batch: dict[str, Tensor]) -> Tensor:
671672
672673 return total_loss
673674
674- def _compute_distillation_loss (
675- self , student_pred : Tensor , teacher_pred : Tensor , loss_mask : Tensor
676- ) -> Tensor :
677- loss_type = self ._config .distillation .distillation_loss_type
678-
679- if loss_type == "mse" :
680- loss = torch .nn .functional .mse_loss (student_pred , teacher_pred , reduction = "none" )
681- elif loss_type == "cosine" :
682- s_flat = student_pred .flatten (start_dim = 2 )
683- t_flat = teacher_pred .flatten (start_dim = 2 )
684- cos_sim = torch .nn .functional .cosine_similarity (s_flat , t_flat , dim = - 1 )
685- loss = 1.0 - cos_sim # [B, T]
686- else :
687- raise ValueError (f"Unknown distillation loss type: { loss_type } " )
688-
689- if loss_mask is not None and loss_mask .numel () > 0 :
690- # Expand mask to match loss dimensions
691- while loss_mask .dim () < loss .dim ():
692- loss_mask = loss_mask .unsqueeze (- 1 )
693- mask = loss_mask .float ()
694- loss = loss .mul (mask ).div (mask .mean ())
695-
696- return loss .mean ()
697-
698675 def _compute_layer_distillation_loss (self ) -> Tensor :
699676 """Compute distillation loss across hooked intermediate layers."""
700677 assert self ._student_extractor is not None
@@ -1078,10 +1055,12 @@ def train(self) -> dict:
10781055 f"batch_size={ cfg .optimization .batch_size } "
10791056 )
10801057
1058+ start_micro = self ._global_step * grad_accum
1059+ total_micro = total_steps * grad_accum
10811060 pbar = tqdm (
1082- range (self . _global_step , total_steps * grad_accum ),
1083- initial = self . _global_step * grad_accum ,
1084- total = total_steps * grad_accum ,
1061+ range (start_micro , total_micro ),
1062+ initial = start_micro ,
1063+ total = total_micro ,
10851064 desc = "Training" ,
10861065 disable = not _is_global_rank0 (),
10871066 )
@@ -1091,6 +1070,10 @@ def train(self) -> dict:
10911070 try :
10921071 batch = next (data_iter )
10931072 except StopIteration :
1073+ self ._data_epoch += 1
1074+ sampler = getattr (self ._dataloader , "sampler" , None )
1075+ if sampler is not None and hasattr (sampler , "set_epoch" ):
1076+ sampler .set_epoch (self ._data_epoch )
10941077 data_iter = iter (self ._dataloader )
10951078 batch = next (data_iter )
10961079
0 commit comments