Skip to content

Commit 040eb57

Browse files
committed
Added short benchmark description and moved imports to speed up help
1 parent c44df9a commit 040eb57

File tree

5 files changed

+31
-25
lines changed

5 files changed

+31
-25
lines changed

src/thunder/benchmark.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,7 @@
1-
import logging
2-
import os
3-
import shutil
41
from typing import Callable
52

6-
import h5py
7-
import hydra
83
from omegaconf import DictConfig
94

10-
from .datasets.utils import is_dataset_available
11-
from .models.utils import (
12-
is_model_available,
13-
load_custom_dataset_from_file,
14-
load_custom_model_from_file,
15-
)
16-
from .utils.utils import print_task_hyperparams, save_config
17-
185

196
def benchmark(
207
model: str | Callable,
@@ -51,7 +38,14 @@ def benchmark(
5138
from hydra import compose, initialize
5239
from omegaconf import OmegaConf
5340

41+
from .datasets.utils import is_dataset_available
42+
from .models.utils import (
43+
is_model_available,
44+
load_custom_dataset_from_file,
45+
load_custom_model_from_file,
46+
)
5447
from .utils.config import get_config
48+
from .utils.utils import print_task_hyperparams
5549

5650
wandb_mode = "online" if online_wandb else "offline"
5751
adaptation_type = "lora" if lora else "frozen"
@@ -119,6 +113,11 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
119113
:param cfg: config defining the job to run.
120114
"""
121115

116+
import logging
117+
import os
118+
import shutil
119+
120+
import h5py
122121
import numpy as np
123122
import torch
124123
import wandb
@@ -136,7 +135,7 @@ def run_benchmark(cfg: DictConfig, model_cls: Callable = None) -> None:
136135
from .utils.constants import UtilsConstants
137136
from .utils.data import get_data, h5_to_np, load_embeddings
138137
from .utils.dice_loss import multiclass_dice_loss
139-
from .utils.utils import set_seed
138+
from .utils.utils import save_config, set_seed
140139

141140
# Getting device
142141
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

src/thunder/datasets/data_splits.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
1-
import logging
21
import os
3-
import random
4-
from collections import defaultdict
52
from collections.abc import Callable
6-
from pathlib import Path
73
from typing import List, Union
84

95

@@ -86,6 +82,8 @@ def generate_splits_for_dataset(dataset_name: str) -> None:
8682
Args:
8783
dataset_name (str): The name of the dataset to generate splits for.
8884
"""
85+
from pathlib import Path
86+
8987
from omegaconf import OmegaConf
9088

9189
from ..utils.constants import DatasetConstants
@@ -277,6 +275,9 @@ def create_few_shot_training_data(
277275
:param data_splits: data splits dictionary.
278276
:param nb_sets_per_nb_shot: number of data sets to create for each number of shots.
279277
"""
278+
import random
279+
from collections import defaultdict
280+
280281
# Creating label2 image dict
281282
train_images, train_labels = (
282283
data_splits["train"]["images"],
@@ -313,6 +314,8 @@ def generate_data_splits(
313314
:param dataset_yaml: path to the dataset yaml config file.
314315
:param split_function: function to run to generate the data splits.
315316
"""
317+
import logging
318+
316319
from omegaconf import OmegaConf
317320

318321
dataset_folder = os.path.join(base_folder, dataset_name)

src/thunder/datasets/download.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
import logging
21
import os
3-
from pathlib import Path
42
from typing import List, Union
53

6-
from .data_splits import generate_splits
74
from .dataset import *
85

96

@@ -39,6 +36,8 @@ def download_datasets(datasets: Union[List[str], str], make_splits: bool = False
3936
datasets (List[str] or str): A dataset name string or a List of dataset names to download or one of the following aliases: `all`, `classification`, `segmentation`.
4037
make_splits (bool): Whether to generate data splits for the datasets. Defaults to False.
4138
"""
39+
from .data_splits import generate_splits
40+
4241
if "THUNDER_BASE_DATA_FOLDER" not in os.environ:
4342
raise EnvironmentError(
4443
"Please set base data directory of thunder using `export THUNDER_BASE_DATA_FOLDER=/base/data/directory`"
@@ -105,6 +104,9 @@ def download_datasets(datasets: Union[List[str], str], make_splits: bool = False
105104

106105

107106
def download_dataset(dataset: str):
107+
import logging
108+
from pathlib import Path
109+
108110
root_folder = os.path.join(
109111
os.environ["THUNDER_BASE_DATA_FOLDER"], f"datasets/{dataset}"
110112
)

src/thunder/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def benchmark(
8686
] = False,
8787
kwargs: Annotated[List[str], typer.Argument(help="Additional arguments")] = None,
8888
):
89+
"""Benchmark a model on a dataset for a task."""
8990
from . import benchmark
9091

9192
if "THUNDER_BASE_DATA_FOLDER" not in os.environ:

src/thunder/models/download.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
import logging
2-
import os
3-
from pathlib import Path
42
from typing import List, Union
53

6-
from ..utils.utils import wget_download
7-
84
# Configure logging
95
logging.basicConfig(
106
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -185,8 +181,13 @@ def download_models(models: Union[List[str], str]) -> None:
185181

186182

187183
def download_model(model: str) -> None:
184+
import os
185+
from pathlib import Path
186+
188187
from huggingface_hub import hf_hub_download
189188

189+
from ..utils.utils import wget_download
190+
190191
if "THUNDER_BASE_DATA_FOLDER" not in os.environ:
191192
raise EnvironmentError(
192193
"Please set base data directory of thunder using `export THUNDER_BASE_DATA_FOLDER=/base/data/directory`"

0 commit comments

Comments (0)