@@ -223,6 +223,14 @@ def compute_total_numb_batch(numb_batches, sampler_weights) -> int:
223223 raise ValueError ("Sampler weights must sum to a positive value." )
224224 probs = weights / weight_sum
225225 nbatches = np .asarray (numb_batches , dtype = np .float64 )
226+ if nbatches .ndim != 1 :
227+ raise ValueError ("Number of batches must be 1D." )
228+ if nbatches .size == 0 :
229+ raise ValueError ("Number of batches is empty." )
230+ if not np .all (np .isfinite (nbatches )):
231+ raise ValueError ("Number of batches must be finite." )
232+ if np .any (nbatches < 0.0 ):
233+ raise ValueError ("Number of batches must be non-negative." )
226234 if nbatches .shape [0 ] != probs .shape [0 ]:
227235 raise ValueError ("Number of batches and sampler weights must match." )
228236 valid = probs > 0.0
@@ -239,6 +247,11 @@ def resolve_model_prob(
239247 ) -> np .ndarray :
240248 model_prob = np .zeros (len (model_keys ), dtype = np .float64 )
241249 if model_prob_config :
250+ missing = [k for k in model_keys if k not in model_prob_config ]
251+ if missing :
252+ raise ValueError (
253+ f"training.model_prob must specify all tasks; missing: { missing } "
254+ )
242255 for ii , model_key in enumerate (model_keys ):
243256 if model_key in model_prob_config :
244257 model_prob [ii ] = float (model_prob_config [model_key ])
@@ -434,6 +447,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
434447 ),
435448 )
436449
450+ per_task_total = []
437451 if not self .multi_task :
438452 sampler_weights = to_numpy_array (
439453 self .training_dataloader .batch_sampler .sampler .weights
@@ -443,7 +457,6 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
443457 sampler_weights ,
444458 )
445459 else :
446- per_task_total = []
447460 for model_key in self .model_keys :
448461 sampler_weights = to_numpy_array (
449462 self .training_dataloader [model_key ].batch_sampler .sampler .weights
@@ -478,22 +491,27 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
478491 f"training.num_epoch_dict['{ model_key } '] must be positive, got { epoch_value } ."
479492 )
480493 # Compute steps needed for each task to complete its epochs
481- per_task_steps = []
494+ per_task_steps : dict [ str , float ] = {}
482495 for ii , model_key in enumerate (self .model_keys ):
483496 epoch_value = self .num_epoch_dict [model_key ]
484497 if epoch_value is not None :
498+ if self .model_prob [ii ] <= 0.0 :
499+ raise ValueError (
500+ f"training.model_prob['{ model_key } '] must be positive when num_epoch_dict targets it."
501+ )
485502 # steps_i = epoch_i * per_task_total[i] / model_prob[i]
486503 steps_i = epoch_value * per_task_total [ii ] / self .model_prob [ii ]
487- per_task_steps .append (steps_i )
488- self .num_steps = int (np .ceil (np .max (per_task_steps )))
504+ per_task_steps [model_key ] = float (steps_i )
505+ if not per_task_steps :
506+ raise ValueError (
507+ "training.num_epoch_dict must have at least one non-null epoch target."
508+ )
509+ self .num_steps = int (np .ceil (np .max (list (per_task_steps .values ()))))
489510 log .info (
490511 "Computed num_steps=%d from num_epoch_dict=%s with per-task steps: %s." ,
491512 self .num_steps ,
492513 self .num_epoch_dict ,
493- {
494- k : int (np .ceil (v ))
495- for k , v in zip (self .model_keys , per_task_steps )
496- },
514+ {k : int (np .ceil (v )) for k , v in per_task_steps .items ()},
497515 )
498516 # === Step 2. Fall back to num_epoch ===
499517 elif self .num_epoch is None :