@@ -317,6 +317,55 @@ def resolve_model_prob(
317317 raise ValueError ("Sum of model prob must be larger than 0!" )
318318 return model_prob / sum_prob
319319
def resolve_model_prob_from_epochs(
    model_keys: list[str],
    num_epoch_dict_config: dict[str, Any],
    per_task_total: np.ndarray,
) -> tuple[np.ndarray, int, dict[str, float]]:
    """Derive per-task sampling probabilities and a total step count from
    per-task epoch targets.

    Each task's target step count is its per-epoch batch count multiplied by
    its requested number of epochs; sampling probabilities are these targets
    normalized to sum to one, and the overall step count is the ceiling of
    their sum.

    Parameters
    ----------
    model_keys : list[str]
        Ordered task identifiers.
    num_epoch_dict_config : dict[str, Any]
        Mapping from task key to a positive, finite epoch target; must cover
        every key in ``model_keys``.
    per_task_total : np.ndarray
        1D array of positive, finite per-task batch counts aligned with
        ``model_keys``.

    Returns
    -------
    tuple[np.ndarray, int, dict[str, float]]
        Normalized sampling probabilities, the total number of training
        steps, and a mapping from task key to its (float) target steps.

    Raises
    ------
    ValueError
        If the epoch config is empty or incomplete, an epoch target is
        ``None``, non-finite or non-positive, or ``per_task_total`` is not a
        1D, finite, positive array matching ``model_keys`` in length.
    """
    if not num_epoch_dict_config:
        raise ValueError(
            "training.num_epoch_dict must be set for multi-task epochs."
        )
    absent = [key for key in model_keys if key not in num_epoch_dict_config]
    if absent:
        raise ValueError(
            "training.num_epoch_dict must specify all tasks; "
            f"missing: {absent}"
        )
    # Validate and collect epoch targets in model_keys order.
    targets: list[float] = []
    for task in model_keys:
        raw_value = num_epoch_dict_config[task]
        if raw_value is None:
            raise ValueError(
                f"training.num_epoch_dict['{task}'] must be positive."
            )
        value = float(raw_value)
        if not np.isfinite(value) or value <= 0.0:
            raise ValueError(
                f"training.num_epoch_dict['{task}'] must be positive, got {value}."
            )
        targets.append(value)
    epoch_targets = np.array(targets, dtype=np.float64)
    totals = np.asarray(per_task_total, dtype=np.float64)
    if totals.ndim != 1:
        raise ValueError("Per-task total batches must be 1D.")
    if totals.shape[0] != epoch_targets.shape[0]:
        raise ValueError("Per-task totals and epoch targets must match.")
    if not np.all(np.isfinite(totals)):
        raise ValueError("Per-task total batches must be finite.")
    if np.any(totals <= 0.0):
        raise ValueError("Per-task total batches must be positive.")
    # steps_i = batches_per_epoch_i * epochs_i; probabilities follow by
    # normalizing, and the total step budget is the ceiled sum.
    target_steps = totals * epoch_targets
    step_sum = float(np.sum(target_steps))
    if step_sum <= 0.0:
        raise ValueError("Sum of target steps must be positive.")
    probabilities = target_steps / step_sum
    total_steps = int(np.ceil(step_sum))
    steps_by_task = {
        task: float(target_steps[idx]) for idx, task in enumerate(model_keys)
    }
    return probabilities, total_steps, steps_by_task
368+
320369 def single_model_stat (
321370 _model : Any ,
322371 _data_stat_nbatch : int ,
@@ -508,6 +557,24 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
508557 training_data .index ,
509558 sampler_weights ,
510559 )
560+ if self .num_steps is None :
561+ if self .num_epoch is None :
562+ raise ValueError (
563+ "Either training.numb_steps or training.num_epoch must be set."
564+ )
565+ if self .num_epoch <= 0 :
566+ raise ValueError ("training.num_epoch must be positive." )
567+ if total_numb_batch <= 0 :
568+ raise ValueError (
569+ "Total number of training batches must be positive."
570+ )
571+ self .num_steps = int (np .ceil (self .num_epoch * total_numb_batch ))
572+ log .info (
573+ "Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d." ,
574+ self .num_steps ,
575+ self .num_epoch ,
576+ total_numb_batch ,
577+ )
511578 else :
512579 for model_key in self .model_keys :
513580 sampler_weights = to_numpy_array (
@@ -519,80 +586,35 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
519586 sampler_weights ,
520587 )
521588 )
522- self .model_prob = resolve_model_prob (
523- self .model_keys ,
524- training_params .get ("model_prob" ),
525- training_data ,
526- )
527- total_numb_batch = int (
528- np .ceil (np .sum (np .asarray (per_task_total ) * self .model_prob ))
529- )
530- if self .num_steps is None :
531- # === Step 1. Check num_epoch_dict first (multi-task only) ===
532- if self .multi_task and self .num_epoch_dict :
533- missing = [k for k in self .model_keys if k not in self .num_epoch_dict ]
534- if missing :
535- raise ValueError (
536- f"training.num_epoch_dict must specify all tasks; missing: { missing } "
537- )
538- # Validate epoch values
539- for model_key in self .model_keys :
540- epoch_value = self .num_epoch_dict [model_key ]
541- if epoch_value is not None and epoch_value <= 0 :
542- raise ValueError (
543- f"training.num_epoch_dict['{ model_key } '] must be positive, got { epoch_value } ."
544- )
545- # Compute steps needed for each task to complete its epochs
546- per_task_steps : dict [str , float ] = {}
547- for ii , model_key in enumerate (self .model_keys ):
548- epoch_value = self .num_epoch_dict [model_key ]
549- if epoch_value is not None :
550- if self .model_prob [ii ] <= 0.0 :
551- raise ValueError (
552- f"training.model_prob['{ model_key } '] must be positive when num_epoch_dict targets it."
553- )
554- # steps_i = epoch_i * per_task_total[i] / model_prob[i]
555- steps_i = epoch_value * per_task_total [ii ] / self .model_prob [ii ]
556- per_task_steps [model_key ] = float (steps_i )
557- if not per_task_steps :
558- raise ValueError (
559- "training.num_epoch_dict must have at least one non-null epoch target."
560- )
561- self .num_steps = int (np .ceil (np .max (list (per_task_steps .values ()))))
589+ if self .num_epoch_dict :
590+ (
591+ self .model_prob ,
592+ self .num_steps ,
593+ per_task_steps ,
594+ ) = resolve_model_prob_from_epochs (
595+ self .model_keys ,
596+ self .num_epoch_dict ,
597+ np .asarray (per_task_total , dtype = np .float64 ),
598+ )
562599 log .info (
563- "Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s." ,
600+ "Computed model_prob=%s and num_steps=%d from num_epoch_dict=%s "
601+ "with per-task target steps: %s." ,
602+ self .model_prob ,
564603 self .num_steps ,
565604 self .num_epoch_dict ,
566605 {k : int (np .ceil (v )) for k , v in per_task_steps .items ()},
567606 )
568- # === Step 2. Fall back to num_epoch ===
569- elif self .num_epoch is None :
570- raise ValueError (
571- "Either training.numb_steps, training.num_epoch, or "
572- "training.num_epoch_dict (multi-task only) must be set."
573- )
574607 else :
575- if self .num_epoch <= 0 :
576- raise ValueError ("training.num_epoch must be positive." )
577- if total_numb_batch <= 0 :
608+ if self .num_steps is None :
578609 raise ValueError (
579- "Total number of training batches must be positive."
610+ "Either training.numb_steps (multi-task only) or "
611+ "training.num_epoch_dict must be set."
580612 )
581- self .num_steps = int (np .ceil (self .num_epoch * total_numb_batch ))
582- log .info (
583- "Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d." ,
584- self .num_steps ,
585- self .num_epoch ,
586- total_numb_batch ,
613+ self .model_prob = resolve_model_prob (
614+ self .model_keys ,
615+ training_params .get ("model_prob" ),
616+ training_data ,
587617 )
588- elif self .num_epoch is not None or (
589- self .multi_task and self .num_epoch_dict is not None
590- ):
591- log .warning (
592- "Both training.numb_steps and training.num_epoch (or num_epoch_dict) are set; "
593- "using numb_steps=%d." ,
594- self .num_steps ,
595- )
596618
597619 # Learning rate
598620 warmup_steps = training_params .get ("warmup_steps" , None )
0 commit comments