@@ -139,9 +139,11 @@ def __init__(
139139 else 1
140140 )
141141 self .num_model = len (self .model_keys )
142+ self .model_prob = None
142143
143144 # Iteration config
144- self .num_steps = training_params ["numb_steps" ]
145+ self .num_steps = training_params .get ("numb_steps" )
146+ self .num_epoch = training_params .get ("num_epoch" )
145147 self .disp_file = training_params .get ("disp_file" , "lcurve.out" )
146148 self .disp_freq = training_params .get ("disp_freq" , 1000 )
147149 self .disp_avg = training_params .get ("disp_avg" , False )
@@ -247,6 +249,47 @@ def get_dataloader_and_iter(
247249 valid_numb_batch ,
248250 )
249251
def compute_total_numb_batch(
    numb_batches: Iterable[int],
    sampler_weights: np.ndarray,
) -> int:
    """Compute how many sampling steps are needed to cover every system.

    The weights are normalized into probabilities ``p_i``. For each entry
    with ``p_i > 0`` the ratio ``n_i / p_i`` is the number of draws after
    which the expected number of draws of entry ``i`` equals its batch
    count ``n_i``. The largest such ratio, rounded up, is returned.
    Entries with zero probability are ignored.

    Parameters
    ----------
    numb_batches : Iterable[int]
        Number of batches of each entry. Any iterable is accepted,
        including one-shot iterators and generators.
    sampler_weights : np.ndarray
        One-dimensional, non-negative, finite sampling weights, one per
        entry. They do not need to be normalized.

    Returns
    -------
    int
        ``ceil(max_i(n_i / p_i))`` over the entries with ``p_i > 0``.

    Raises
    ------
    ValueError
        If the weights are not 1D, are empty, contain negative or
        non-finite values, or do not sum to a positive value; if the
        batch counts are not 1D or their length differs from the
        weights; or if no normalized probability is positive.
    """
    weights = np.asarray(sampler_weights, dtype=np.float64)
    if weights.ndim != 1:
        raise ValueError("Sampler weights must be 1D.")
    if weights.size == 0:
        raise ValueError("Sampler weights are empty.")
    # A negative weight can hide behind a positive sum, and NaN compares
    # False against every threshold, so both must be rejected explicitly.
    if not np.all(np.isfinite(weights)) or np.any(weights < 0.0):
        raise ValueError("Sampler weights must be finite and non-negative.")
    weight_sum = float(np.sum(weights))
    if weight_sum <= 0.0:
        raise ValueError("Sampler weights must sum to a positive value.")
    probs = weights / weight_sum

    # np.asarray cannot consume a one-shot iterator; materialize it first.
    if not hasattr(numb_batches, "__len__"):
        numb_batches = list(numb_batches)
    nbatches = np.asarray(numb_batches, dtype=np.float64)
    if nbatches.ndim != 1:
        raise ValueError("Number of batches must be 1D.")
    if nbatches.shape[0] != probs.shape[0]:
        raise ValueError("Number of batches and sampler weights must match.")

    valid = probs > 0.0
    # Still reachable when the weight sum overflows to inf, which drives
    # every normalized probability to zero.
    if not np.any(valid):
        raise ValueError(
            "Sampler probabilities must contain at least one positive entry."
        )
    return int(np.ceil(np.max(nbatches[valid] / probs[valid])))
274+
def resolve_model_prob(
    model_keys: list[str],
    model_prob_config: dict[str, Any] | None,
    model_training_data: dict[str, DpLoaderSet],
) -> np.ndarray:
    """Build the normalized sampling probability of each model.

    Parameters
    ----------
    model_keys : list[str]
        Model names; the returned array follows this order.
    model_prob_config : dict[str, Any] | None
        Optional mapping from model name to an unnormalized probability.
        Models missing from the mapping get probability zero. When it is
        ``None``, ``len(model_training_data[key])`` is used as the
        unnormalized probability of each model instead.
    model_training_data : dict[str, DpLoaderSet]
        Training data of each model; only consulted when
        ``model_prob_config`` is ``None``.

    Returns
    -------
    np.ndarray
        Float64 array with one entry per model key, summing to one.

    Raises
    ------
    ValueError
        If any unnormalized probability is negative or non-finite, or if
        their sum is not larger than zero.
    """
    if model_prob_config is not None:
        raw_prob = [
            float(model_prob_config[model_key])
            if model_key in model_prob_config
            else 0.0
            for model_key in model_keys
        ]
    else:
        raw_prob = [
            float(len(model_training_data[model_key])) for model_key in model_keys
        ]
    model_prob = np.asarray(raw_prob, dtype=np.float64)
    # A negative entry can hide behind a positive sum, and NaN compares False
    # against the sum threshold, so neither is caught by the sum check alone.
    if not np.all(np.isfinite(model_prob)) or np.any(model_prob < 0.0):
        raise ValueError("Model prob must be finite and non-negative!")
    sum_prob = float(np.sum(model_prob))
    if sum_prob <= 0.0:
        raise ValueError("Sum of model prob must be larger than 0!")
    return model_prob / sum_prob
292+
250293 def single_model_stat (
251294 _model : Any ,
252295 _data_stat_nbatch : int ,
@@ -430,6 +473,56 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
430473 ),
431474 )
432475
476+ # Resolve training steps
477+ if not self .multi_task :
478+ sampler_weights = to_numpy_array (self .training_dataloader .sampler .weights )
479+ total_numb_batch = compute_total_numb_batch (
480+ training_data .index ,
481+ sampler_weights ,
482+ )
483+ else :
484+ per_task_total = []
485+ for model_key in self .model_keys :
486+ sampler_weights = to_numpy_array (
487+ self .training_dataloader [model_key ].sampler .weights
488+ )
489+ per_task_total .append (
490+ compute_total_numb_batch (
491+ training_data [model_key ].index ,
492+ sampler_weights ,
493+ )
494+ )
495+ self .model_prob = resolve_model_prob (
496+ self .model_keys ,
497+ training_params .get ("model_prob" ),
498+ training_data ,
499+ )
500+ total_numb_batch = int (
501+ np .ceil (np .sum (np .asarray (per_task_total ) * self .model_prob ))
502+ )
503+ if self .num_steps is None :
504+ if self .num_epoch is None :
505+ raise ValueError (
506+ "Either training.numb_steps or training.num_epoch must be set."
507+ )
508+ if self .num_epoch <= 0 :
509+ raise ValueError ("training.num_epoch must be positive." )
510+ if total_numb_batch <= 0 :
511+ raise ValueError ("Total number of training batches must be positive." )
512+ self .num_steps = int (np .ceil (self .num_epoch * total_numb_batch ))
513+ log .info (
514+ "Computed num_steps=%d from num_epoch=%s and total_numb_batch=%d." ,
515+ self .num_steps ,
516+ self .num_epoch ,
517+ total_numb_batch ,
518+ )
519+ elif self .num_epoch is not None :
520+ log .warning (
521+ "Both training.numb_steps and training.num_epoch are set; "
522+ "using numb_steps=%d." ,
523+ self .num_steps ,
524+ )
525+
433526 # Learning rate
434527 self .warmup_steps = training_params .get ("warmup_steps" , 0 )
435528 self .gradient_max_norm = training_params .get ("gradient_max_norm" , 0.0 )
@@ -637,19 +730,12 @@ def single_model_finetune(
637730 )
638731
639732 # Get model prob for multi-task
640- if self .multi_task :
641- self .model_prob = np .array ([0.0 for key in self .model_keys ])
642- if training_params .get ("model_prob" , None ) is not None :
643- model_prob = training_params ["model_prob" ]
644- for ii , model_key in enumerate (self .model_keys ):
645- if model_key in model_prob :
646- self .model_prob [ii ] += float (model_prob [model_key ])
647- else :
648- for ii , model_key in enumerate (self .model_keys ):
649- self .model_prob [ii ] += float (len (self .training_data [model_key ]))
650- sum_prob = np .sum (self .model_prob )
651- assert sum_prob > 0.0 , "Sum of model prob must be larger than 0!"
652- self .model_prob = self .model_prob / sum_prob
733+ if self .multi_task and self .model_prob is None :
734+ self .model_prob = resolve_model_prob (
735+ self .model_keys ,
736+ training_params .get ("model_prob" ),
737+ training_data ,
738+ )
653739
654740 # Multi-task share params
655741 if shared_links is not None :
0 commit comments