@@ -223,6 +223,14 @@ def compute_total_numb_batch(numb_batches, sampler_weights) -> int:
223223 raise ValueError ("Sampler weights must sum to a positive value." )
224224 probs = weights / weight_sum
225225 nbatches = np .asarray (numb_batches , dtype = np .float64 )
226+ if nbatches .ndim != 1 :
227+ raise ValueError ("Number of batches must be 1D." )
228+ if nbatches .size == 0 :
229+ raise ValueError ("Number of batches is empty." )
230+ if not np .all (np .isfinite (nbatches )):
231+ raise ValueError ("Number of batches must be finite." )
232+ if np .any (nbatches < 0.0 ):
233+ raise ValueError ("Number of batches must be non-negative." )
226234 if nbatches .shape [0 ] != probs .shape [0 ]:
227235 raise ValueError ("Number of batches and sampler weights must match." )
228236 valid = probs > 0.0
@@ -239,6 +247,11 @@ def resolve_model_prob(
239247 ) -> np .ndarray :
240248 model_prob = np .zeros (len (model_keys ), dtype = np .float64 )
241249 if model_prob_config :
250+ missing = [k for k in model_keys if k not in model_prob_config ]
251+ if missing :
252+ raise ValueError (
253+ f"training.model_prob must specify all tasks; missing: { missing } "
254+ )
242255 for ii , model_key in enumerate (model_keys ):
243256 if model_key in model_prob_config :
244257 model_prob [ii ] = float (model_prob_config [model_key ])
@@ -437,6 +450,7 @@ def get_lr(lr_params):
437450 ),
438451 )
439452
453+ per_task_total = []
440454 if not self .multi_task :
441455 sampler_weights = to_numpy_array (
442456 self .training_dataloader .batch_sampler .sampler .weights
@@ -446,7 +460,6 @@ def get_lr(lr_params):
446460 sampler_weights ,
447461 )
448462 else :
449- per_task_total = []
450463 for model_key in self .model_keys :
451464 sampler_weights = to_numpy_array (
452465 self .training_dataloader [model_key ].batch_sampler .sampler .weights
@@ -481,22 +494,27 @@ def get_lr(lr_params):
481494 f"training.num_epoch_dict['{ model_key } '] must be positive, got { epoch_value } ."
482495 )
483496 # Compute steps needed for each task to complete its epochs
484- per_task_steps = []
497+ per_task_steps : dict [ str , float ] = {}
485498 for ii , model_key in enumerate (self .model_keys ):
486499 epoch_value = self .num_epoch_dict [model_key ]
487500 if epoch_value is not None :
501+ if self .model_prob [ii ] <= 0.0 :
502+ raise ValueError (
503+ f"training.model_prob['{ model_key } '] must be positive when num_epoch_dict targets it."
504+ )
488505 # steps_i = epoch_i * per_task_total[i] / model_prob[i]
489506 steps_i = epoch_value * per_task_total [ii ] / self .model_prob [ii ]
490- per_task_steps .append (steps_i )
491- self .num_steps = int (np .ceil (np .max (per_task_steps )))
507+ per_task_steps [model_key ] = float (steps_i )
508+ if not per_task_steps :
509+ raise ValueError (
510+ "training.num_epoch_dict must have at least one non-null epoch target."
511+ )
512+ self .num_steps = int (np .ceil (np .max (list (per_task_steps .values ()))))
492513 log .info (
493514 "Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s." ,
494515 self .num_steps ,
495516 self .num_epoch_dict ,
496- {
497- k : int (np .ceil (v ))
498- for k , v in zip (self .model_keys , per_task_steps )
499- },
517+ {k : int (np .ceil (v )) for k , v in per_task_steps .items ()},
500518 )
501519 # === Step 2. Fall back to num_epoch ===
502520 elif self .num_epoch is None :
0 commit comments