33import hydra
44import logging
55from omegaconf import DictConfig
6+ import shutil
67from typing import Callable
78
89from .datasets .utils import is_dataset_available
@@ -18,6 +19,7 @@ def benchmark(
1819 ckpt_save_all : bool = False ,
1920 online_wandb : bool = False ,
2021 recomp_embs : bool = False ,
22+ retrain_model : bool = False ,
2123 ** kwargs ,
2224):
2325 """
@@ -38,6 +40,7 @@ def benchmark(
3840 ckpt_save_all (bool): Whether to save all checkpoints during training. Default is False which means that only the best is saved.
3941 online_wandb (bool): Whether to use online mode for Weights & Biases (wandb) logging. Default is False which means offline mode.
4042 recomp_embs (bool): Whether to recompute embeddings if already saved.
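43+ retrain_model (bool): Whether to retrain the model even if a trained best checkpoint has already been saved. Default is False which means that an existing best checkpoint is loaded instead of retraining.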
4144 """
4245 from hydra import compose , initialize
4346 from omegaconf import OmegaConf
@@ -48,6 +51,7 @@ def benchmark(
4851 adaptation_type = "lora" if lora else "frozen"
4952 ckpt_saving = "save_ckpts_all_epochs" if ckpt_save_all else "save_best_ckpt_only"
5053 embedding_recomputing = "recomp_embs" if recomp_embs else "no_recomp_embs"
54+ model_retraining = "retrain_model" if retrain_model else "no_retrain_model"
5155 model_name = model if isinstance (model , str ) else None
5256
5357 if model_name and model_name .startswith ("custom:" ):
@@ -64,6 +68,7 @@ def benchmark(
6468 loading_mode ,
6569 wandb_mode ,
6670 embedding_recomputing ,
71+ model_retraining ,
6772 ** kwargs ,
6873 )
6974
@@ -154,15 +159,6 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
154159 mode = cfg .wandb .mode ,
155160 )
156161 wandb_base_folder = f"|{ task_type } | |{ adaptation_type } | |{ dataset_name } |"
157- ckpt_folder = os .path .join (
158- os .environ ["THUNDER_BASE_DATA_FOLDER" ],
159- "outputs" ,
160- "ckpts" ,
161- dataset_name ,
162- model_name ,
163- adaptation_type ,
164- )
165- os .makedirs (ckpt_folder , exist_ok = True )
166162
167163 # Folder to save results
168164 res_folder = os .path .join (
@@ -177,6 +173,16 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
177173 os .makedirs (res_folder , exist_ok = True )
178174
179175 if task_type in ["linear_probing" , "segmentation" ]:
176+ # Model checkpoints folder
177+ ckpt_folder = os .path .join (
178+ os .environ ["THUNDER_BASE_DATA_FOLDER" ],
179+ "outputs" ,
180+ "ckpts" ,
181+ dataset_name ,
182+ model_name ,
183+ adaptation_type ,
184+ )
185+
180186 # Criterion
181187 if task_type == "linear_probing" :
182188 criterion = torch .nn .CrossEntropyLoss ()
@@ -191,6 +197,28 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
191197
192198 # Probe training
193199 if not os .path .exists (os .path .join (ckpt_folder , "best_model.pth" )):
200+ logging .info (f"Not found already trained best model. Training a model." )
201+ train_model = True
202+
203+ # No best_model.pth was found: remove any leftover checkpoint folder so training starts from an empty directory
204+ if os .path .exists (ckpt_folder ):
205+ shutil .rmtree (ckpt_folder )
206+ os .makedirs (ckpt_folder )
207+ else :
208+ model_train_info_str = f"Found already trained best model { os .path .join (ckpt_folder , 'best_model.pth' )} ."
209+
210+ if cfg .model_retraining .retrain_model :
211+ model_train_info_str += " Deleting saved weights and re-training a model as explictly requested."
212+ shutil .rmtree (ckpt_folder )
213+ os .makedirs (ckpt_folder )
214+ train_model = True
215+ else :
216+ model_train_info_str += " Not re-training a model."
217+ train_model = False
218+
219+ logging .info (model_train_info_str )
220+
221+ if train_model :
194222 best_ckpt_dict = train_probe (
195223 cfg ,
196224 data ,
@@ -209,10 +237,6 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
209237 model_cls ,
210238 )
211239 else :
212- logging .info (
213- f"Found already trained best model { os .path .join (ckpt_folder , 'best_model.pth' )} . "
214- f"Not re-training."
215- )
216240 best_ckpt_dict = torch .load (
217241 os .path .join (ckpt_folder , "best_model.pth" ), weights_only = True
218242 )
0 commit comments