Skip to content

Commit 54a3888

Browse files
committed
Added '--retrain-model' flag to retrain the model even if it has already been trained and its checkpoints saved
1 parent b4f499a commit 54a3888

5 files changed

Lines changed: 48 additions & 13 deletions

File tree

src/thunder/benchmark.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import hydra
44
import logging
55
from omegaconf import DictConfig
6+
import shutil
67
from typing import Callable
78

89
from .datasets.utils import is_dataset_available
@@ -18,6 +19,7 @@ def benchmark(
1819
ckpt_save_all: bool = False,
1920
online_wandb: bool = False,
2021
recomp_embs: bool = False,
22+
retrain_model: bool = False,
2123
**kwargs,
2224
):
2325
"""
@@ -38,6 +40,7 @@ def benchmark(
3840
ckpt_save_all (bool): Whether to save all checkpoints during training. Default is False which means that only the best is saved.
3941
online_wandb (bool): Whether to use online mode for Weights & Biases (wandb) logging. Default is False which means offline mode.
4042
recomp_embs (bool): Whether to recompute embeddings if already saved.
43+
retrain_model (bool): Whether to retrain the model even if a trained checkpoint has already been saved. Default is False.
4144
"""
4245
from hydra import compose, initialize
4346
from omegaconf import OmegaConf
@@ -48,6 +51,7 @@ def benchmark(
4851
adaptation_type = "lora" if lora else "frozen"
4952
ckpt_saving = "save_ckpts_all_epochs" if ckpt_save_all else "save_best_ckpt_only"
5053
embedding_recomputing = "recomp_embs" if recomp_embs else "no_recomp_embs"
54+
model_retraining = "retrain_model" if retrain_model else "no_retrain_model"
5155
model_name = model if isinstance(model, str) else None
5256

5357
if model_name and model_name.startswith("custom:"):
@@ -64,6 +68,7 @@ def benchmark(
6468
loading_mode,
6569
wandb_mode,
6670
embedding_recomputing,
71+
model_retraining,
6772
**kwargs,
6873
)
6974

@@ -154,15 +159,6 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
154159
mode=cfg.wandb.mode,
155160
)
156161
wandb_base_folder = f"|{task_type}| |{adaptation_type}| |{dataset_name}|"
157-
ckpt_folder = os.path.join(
158-
os.environ["THUNDER_BASE_DATA_FOLDER"],
159-
"outputs",
160-
"ckpts",
161-
dataset_name,
162-
model_name,
163-
adaptation_type,
164-
)
165-
os.makedirs(ckpt_folder, exist_ok=True)
166162

167163
# Folder to save results
168164
res_folder = os.path.join(
@@ -177,6 +173,16 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
177173
os.makedirs(res_folder, exist_ok=True)
178174

179175
if task_type in ["linear_probing", "segmentation"]:
176+
# Model checkpoints folder
177+
ckpt_folder = os.path.join(
178+
os.environ["THUNDER_BASE_DATA_FOLDER"],
179+
"outputs",
180+
"ckpts",
181+
dataset_name,
182+
model_name,
183+
adaptation_type,
184+
)
185+
180186
# Criterion
181187
if task_type == "linear_probing":
182188
criterion = torch.nn.CrossEntropyLoss()
@@ -191,6 +197,28 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
191197

192198
# Probe training
193199
if not os.path.exists(os.path.join(ckpt_folder, "best_model.pth")):
200+
logging.info(f"Not found already trained best model. Training a model.")
201+
train_model = True
202+
203+
# Deleting existing folder
204+
if os.path.exists(ckpt_folder):
205+
shutil.rmtree(ckpt_folder)
206+
os.makedirs(ckpt_folder)
207+
else:
208+
model_train_info_str = f"Found already trained best model {os.path.join(ckpt_folder, 'best_model.pth')}."
209+
210+
if cfg.model_retraining.retrain_model:
211+
model_train_info_str += " Deleting saved weights and re-training a model as explictly requested."
212+
shutil.rmtree(ckpt_folder)
213+
os.makedirs(ckpt_folder)
214+
train_model = True
215+
else:
216+
model_train_info_str += " Not re-training a model."
217+
train_model = False
218+
219+
logging.info(model_train_info_str)
220+
221+
if train_model:
194222
best_ckpt_dict = train_probe(
195223
cfg,
196224
data,
@@ -209,10 +237,6 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
209237
model_cls,
210238
)
211239
else:
212-
logging.info(
213-
f"Found already trained best model {os.path.join(ckpt_folder, 'best_model.pth')}. "
214-
f"Not re-training."
215-
)
216240
best_ckpt_dict = torch.load(
217241
os.path.join(ckpt_folder, "best_model.pth"), weights_only=True
218242
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
retrain_model: False
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
retrain_model: True

src/thunder/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ def benchmark(
7777
help="If provided embeddings will be re-computed even if already saved"
7878
),
7979
] = False,
80+
retrain_model: Annotated[
81+
bool,
82+
typer.Option(
83+
help="If provided model will be re-trained even if already trained and saved ckpts"
84+
),
85+
] = False,
8086
kwargs: Annotated[List[str], typer.Argument(help="Additional arguments")] = None,
8187
):
8288
from . import benchmark
@@ -101,6 +107,7 @@ def benchmark(
101107
ckpt_save_all,
102108
online_wandb,
103109
recomp_embs,
110+
retrain_model,
104111
**kwargs,
105112
)
106113

src/thunder/utils/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def get_config(
1313
data_loading_type: Optional[str] = None,
1414
wandb_mode: Optional[str] = None,
1515
embedding_recomputing: Optional[str] = None,
16+
model_retraining: Optional[str] = None,
1617
**kwargs,
1718
) -> DictConfig:
1819
params = {
@@ -24,6 +25,7 @@ def get_config(
2425
"task": task,
2526
"wandb": wandb_mode,
2627
"embedding_recomputing": embedding_recomputing,
28+
"model_retraining": model_retraining,
2729
}
2830

2931
overrides = [f"+{k}={v}" for k, v in params.items() if v is not None]

0 commit comments

Comments
 (0)