|
1 | 1 | import json |
2 | 2 | import numpy as np |
3 | | -from omegaconf import DictConfig |
4 | 3 | import os |
5 | 4 | import random |
6 | 5 | import torch |
|
9 | 8 | import h5py |
10 | 9 | from typing import Dict, Any, Optional |
11 | 10 |
|
| 11 | +from typing import Optional |
| 12 | +from omegaconf import DictConfig, OmegaConf |
| 13 | +import logging |
| 14 | + |
| 15 | +def print_task_hyperparams( |
| 16 | + cfg: DictConfig, |
| 17 | + custom_name: Optional[str] = None |
| 18 | +) -> None: |
| 19 | + """ |
| 20 | + Print dataset, model, and only the task-specific hyper-parameters |
| 21 | + from a full Hydra cfg, using classic ANSI colors. |
| 22 | +
|
| 23 | + If `custom_name` is provided, it will be used instead of |
| 24 | + cfg.pretrained_model.model_name. |
| 25 | + """ |
| 26 | + RESET = "\033[0m" |
| 27 | + BOLD = "\033[1m" |
| 28 | + UNDERLINE = "\033[4m" |
| 29 | + BLUE = "\033[34m" |
| 30 | + WHITE = "\033[37m" |
| 31 | + GREEN = "\033[32m" |
| 32 | + RED = "\033[31m" |
| 33 | + |
| 34 | + task = cfg.task.type |
| 35 | + dataset_name = cfg.dataset.dataset_name |
| 36 | + |
| 37 | + # Safe-fetch the model name, falling back to custom_name if given |
| 38 | + if custom_name is not None: |
| 39 | + model_label = custom_name |
| 40 | + else: |
| 41 | + model_label = OmegaConf.select( |
| 42 | + cfg, "pretrained_model.model_name", default="<unknown model>" |
| 43 | + ) |
| 44 | + |
| 45 | + # Choose where hyperparams live |
| 46 | + task_cfg = cfg.task if task not in ["linear_probing", "segmentation"] else cfg.adaptation |
| 47 | + |
| 48 | + sep = "-" * 60 |
| 49 | + |
| 50 | + logging.info(f"\n{BOLD}{BLUE}\U0001F680 Experiment Info{RESET}") |
| 51 | + print(sep) |
| 52 | + print(f"{BLUE}Task :{RESET} {WHITE}{task}{RESET}") |
| 53 | + print(f"{BLUE}Dataset:{RESET} {WHITE}{dataset_name}{RESET}") |
| 54 | + print(f"{BLUE}Model :{RESET} {WHITE}{model_label}{RESET}") |
| 55 | + print(sep) |
| 56 | + |
| 57 | + print(f"\n{BOLD}{BLUE}Hyper-parameters{RESET}\n{sep}") |
| 58 | + |
| 59 | + # Fields to skip |
| 60 | + skip = {"compatible_adaptation_types", "base_embeddings_folder"} |
| 61 | + if task not in ["linear_probing", "segmentation"]: |
| 62 | + skip.add("type") |
| 63 | + |
| 64 | + # Print each hyperparam |
| 65 | + for key, val in task_cfg.items(): |
| 66 | + if key in skip: |
| 67 | + continue |
| 68 | + |
| 69 | + # Special PGD block header |
| 70 | + if task == "adversarial_attack" and key == "attack": |
| 71 | + print(f"\n{RED}{BOLD}PGD attack hyper-parameters{RESET}") |
| 72 | + |
| 73 | + # Nested dicts |
| 74 | + if isinstance(val, (DictConfig, dict)): |
| 75 | + print(f"{BLUE}{BOLD}{key}{RESET}:") |
| 76 | + for subkey, subval in val.items(): |
| 77 | + print(f" {WHITE}{subkey}{RESET}: {subval}") |
| 78 | + else: |
| 79 | + print(f"{BLUE}{key}{RESET}: {WHITE}{val}{RESET}") |
| 80 | + |
| 81 | + print(sep) |
| 82 | + |
| 83 | + |
12 | 84 |
|
13 | 85 | def get_hyperaparams_dict(cfg: DictConfig) -> dict: |
14 | 86 | """ |
|
0 commit comments