@@ -293,6 +293,7 @@ def train_ensemble(
293293 model_params : dict [str , Any ],
294294 loss_dict : dict [str , Literal ["L1" , "L2" , "CSE" ]],
295295 patience : int = None ,
296+ verbose : bool = False ,
296297) -> None :
297298 """Convenience method to train multiple models in serial.
298299
@@ -313,6 +314,7 @@ def train_ensemble(
313314 to loss functions.
314315 patience (int, optional): Maximum number of epochs without improvement
315316 when early stopping. Defaults to None.
 317+ verbose (bool, optional): Whether to show progress bars for each epoch. Defaults to False.
316318 """
317319 train_generator = DataLoader (train_set , ** data_params )
318320 print (f"Training on { len (train_set ):,} samples" )
@@ -359,9 +361,7 @@ def train_ensemble(
359361
360362 if log :
361363 writer = SummaryWriter (
362- log_dir = (
363- f"runs/{ model_name } /{ model_name } -r{ j } _{ datetime .now ():%d-%m-%Y_%H-%M-%S} "
364- )
364+ f"runs/{ model_name } /{ model_name } -r{ j } _{ datetime .now ():%d-%m-%Y_%H-%M-%S} "
365365 )
366366 else :
367367 writer = None
@@ -375,7 +375,7 @@ def train_ensemble(
375375 optimizer = None ,
376376 normalizer_dict = normalizer_dict ,
377377 action = "val" ,
378- verbose = True ,
378+ verbose = verbose ,
379379 )
380380
381381 val_score = {}
@@ -727,7 +727,7 @@ def save_results_dict(
727727 for col , data in results_dict [target_name ].items ():
728728
729729 # NOTE we save pre_logits rather than logits due to fact
730- # that with the hetroskedastic setup we want to be able to
730+ # that with the heteroskedastic setup we want to be able to
731731 # sample from the Gaussian distributed pre_logits we parameterise.
732732 if "pre-logits" in col :
733733 for n_ens , y_pre_logit in enumerate (data ):
@@ -760,6 +760,8 @@ def save_results_dict(
760760
761761 file_name = model_name .replace ("/" , "_" )
762762
763+ os .makedirs ("results" , exist_ok = True )
764+
763765 csv_path = f"results/{ file_name } .csv"
764766 df .to_csv (csv_path , index = False )
765767 print (f"\n Saved model predictions to '{ csv_path } '" )
0 commit comments