Skip to content

Commit 3e41e15

Browse files
Print configuration and hyperaparameters
Display configuration and hyperparameters Remove double import and change colors Fixing custom models parameters print
1 parent 025a275 commit 3e41e15

6 files changed

Lines changed: 82 additions & 8 deletions

File tree

docs/custom_config.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,5 @@ Here is a non exhaustive list of the parameters that you may want to override pe
7979
### Transformation invariance
8080
| Name | Type | Description |
8181
|------|---|---|
82-
| task.pre_comp_emb_batch_size | int | Batch size for precomputing the embeddings. |
82+
| task.transformation_invariance_batch_size | int | Batch size for the transformations. |
8383
| task.nb_images | int | Number of images to use. |

src/thunder/benchmark.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .datasets.utils import is_dataset_available
1010
from .models.utils import is_model_available, load_custom_model_from_file
11-
11+
from .utils.utils import print_task_hyperparams
1212

1313
def benchmark(
1414
model: str | Callable,
@@ -46,17 +46,19 @@ def benchmark(
4646
from omegaconf import OmegaConf
4747

4848
from .utils.config import get_config
49-
49+
5050
wandb_mode = "online" if online_wandb else "offline"
5151
adaptation_type = "lora" if lora else "frozen"
5252
ckpt_saving = "save_ckpts_all_epochs" if ckpt_save_all else "save_best_ckpt_only"
5353
embedding_recomputing = "recomp_embs" if recomp_embs else "no_recomp_embs"
5454
model_retraining = "retrain_model" if retrain_model else "no_retrain_model"
5555
model_name = model if isinstance(model, str) else None
56+
custom_name = None
5657

5758
if model_name and model_name.startswith("custom:"):
5859
model = load_custom_model_from_file(model_name.split(":")[1])
5960
model_name = None
61+
custom_name = model.name
6062

6163
# Get Config
6264
cfg = get_config(
@@ -72,6 +74,8 @@ def benchmark(
7274
**kwargs,
7375
)
7476

77+
print_task_hyperparams(cfg, custom_name=custom_name)
78+
7579
if not is_dataset_available(dataset):
7680
from . import download_datasets
7781

@@ -159,7 +163,6 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
159163
mode=cfg.wandb.mode,
160164
)
161165
wandb_base_folder = f"|{task_type}| |{adaptation_type}| |{dataset_name}|"
162-
163166
# Folder to save results
164167
res_folder = os.path.join(
165168
os.environ["THUNDER_BASE_DATA_FOLDER"],

src/thunder/config/task/adversarial_attack.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
type: adversarial_attack
22
compatible_adaptation_types: ["frozen"]
33
base_embeddings_folder: ${oc.env:THUNDER_BASE_DATA_FOLDER}/embeddings/
4-
pre_comp_emb_batch_size: 128
54
attack_batch_size: 8
65
nb_attack_images: 10000
76
# ----------------------------------------------------------------------
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
type: transformation_invariance
22
compatible_adaptation_types: ["frozen"]
33
base_embeddings_folder: ${oc.env:THUNDER_BASE_DATA_FOLDER}/embeddings/
4-
pre_comp_emb_batch_size: 128
4+
transformation_invariance_batch_size: 64
55
nb_images: 1000

src/thunder/tasks/transformation_invariance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def transformation_invariance(
143143

144144
dataloader = DataLoader(
145145
subset_dataset,
146-
batch_size=cfg.adaptation.batch_size,
146+
batch_size=cfg.task.transformation_invariance_batch_size,
147147
shuffle=False,
148148
num_workers=cfg.adaptation.num_workers,
149149
generator=generator,

src/thunder/utils/utils.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import json
22
import numpy as np
3-
from omegaconf import DictConfig
43
import os
54
import random
65
import torch
@@ -9,6 +8,79 @@
98
import h5py
109
from typing import Dict, Any, Optional
1110

11+
from typing import Optional
12+
from omegaconf import DictConfig, OmegaConf
13+
import logging
14+
15+
def print_task_hyperparams(
16+
cfg: DictConfig,
17+
custom_name: Optional[str] = None
18+
) -> None:
19+
"""
20+
Print dataset, model, and only the task-specific hyper-parameters
21+
from a full Hydra cfg, using classic ANSI colors.
22+
23+
If `custom_name` is provided, it will be used instead of
24+
cfg.pretrained_model.model_name.
25+
"""
26+
RESET = "\033[0m"
27+
BOLD = "\033[1m"
28+
UNDERLINE = "\033[4m"
29+
BLUE = "\033[34m"
30+
WHITE = "\033[37m"
31+
GREEN = "\033[32m"
32+
RED = "\033[31m"
33+
34+
task = cfg.task.type
35+
dataset_name = cfg.dataset.dataset_name
36+
37+
# Safe-fetch the model name, falling back to custom_name if given
38+
if custom_name is not None:
39+
model_label = custom_name
40+
else:
41+
model_label = OmegaConf.select(
42+
cfg, "pretrained_model.model_name", default="<unknown model>"
43+
)
44+
45+
# Choose where hyperparams live
46+
task_cfg = cfg.task if task not in ["linear_probing", "segmentation"] else cfg.adaptation
47+
48+
sep = "-" * 60
49+
50+
logging.info(f"\n{BOLD}{BLUE}\U0001F680 Experiment Info{RESET}")
51+
print(sep)
52+
print(f"{BLUE}Task :{RESET} {WHITE}{task}{RESET}")
53+
print(f"{BLUE}Dataset:{RESET} {WHITE}{dataset_name}{RESET}")
54+
print(f"{BLUE}Model :{RESET} {WHITE}{model_label}{RESET}")
55+
print(sep)
56+
57+
print(f"\n{BOLD}{BLUE}Hyper-parameters{RESET}\n{sep}")
58+
59+
# Fields to skip
60+
skip = {"compatible_adaptation_types", "base_embeddings_folder"}
61+
if task not in ["linear_probing", "segmentation"]:
62+
skip.add("type")
63+
64+
# Print each hyperparam
65+
for key, val in task_cfg.items():
66+
if key in skip:
67+
continue
68+
69+
# Special PGD block header
70+
if task == "adversarial_attack" and key == "attack":
71+
print(f"\n{RED}{BOLD}PGD attack hyper-parameters{RESET}")
72+
73+
# Nested dicts
74+
if isinstance(val, (DictConfig, dict)):
75+
print(f"{BLUE}{BOLD}{key}{RESET}:")
76+
for subkey, subval in val.items():
77+
print(f" {WHITE}{subkey}{RESET}: {subval}")
78+
else:
79+
print(f"{BLUE}{key}{RESET}: {WHITE}{val}{RESET}")
80+
81+
print(sep)
82+
83+
1284

1385
def get_hyperaparams_dict(cfg: DictConfig) -> dict:
1486
"""

0 commit comments

Comments
 (0)