get_models_meta.py
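"""Collect metadata for every model artifact logged to a WandB project and
save it as a single CSV under BASE_PATH/<dataset>/<project>/metadata/."""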
import argparse
import os
from typing import Optional

import pandas as pd
import wandb

from my_utilities.config_reader import read_config
# Root directory under which per-project metadata folders are created.
BASE_PATH = "/exploiting_model_multiplicity/models"


def _form_path_to_save(project_name, dataset_name):
    """Build (and create if missing) the metadata directory for a dataset/project pair."""
    path_to_save = os.path.join(BASE_PATH, dataset_name, project_name, "metadata")
    os.makedirs(path_to_save, exist_ok=True)
    return path_to_save

def _save_models_metadata(models_metadata: dict, path_to_save):
    """Write the collected artifact metadata to one CSV, one row per model."""
    df = pd.DataFrame(list(models_metadata.values()), index=models_metadata.keys())
    df.to_csv(os.path.join(path_to_save, "all_models.csv"), index=True)

def _get_candidates(
    runs,
    path_to_save,
    filter_runs_name: Optional[list] = None,
) -> None:
    """
    Collect multimodel candidates (model artifacts) from each experiment
    within one project and save their metadata to disk.
    Inputs:
    - runs: iterable of WandB Run objects, e.g. from wandb.Api().runs(project)
    - path_to_save: directory where the metadata CSV is written
    - filter_runs_name: list of run names that restricts which experiments
      are used
      default = None (use every run)
    Outputs:
    - none; writes "all_models.csv" in path_to_save, where the index holds
      the model (artifact) names and the columns hold each artifact's
      metadata fields
    """
    models_metadata = {}
    if filter_runs_name:
        # keep only the explicitly requested runs
        runs = [run for run in runs if run.name in filter_runs_name]
    for run in runs:
        print("run selected:", run.name)
        for artifact in run.logged_artifacts():
            # available artifact fields: https://docs.wandb.ai/ref/python/artifact
            if artifact.type == "model":
                # key by the artifact name without its version suffix (":v0", ...)
                models_metadata[artifact.name.split(":")[0]] = artifact.metadata
    _save_models_metadata(models_metadata, path_to_save)
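
# Usage sketch for calling _get_candidates outside the CLI flow below; the
# project and run names here are hypothetical placeholders, not values from
# this repository:
#
#   api = wandb.Api(timeout=15)
#   _get_candidates(
#       runs=api.runs("my-team/my-project"),
#       path_to_save="/tmp/metadata",
#       filter_runs_name=["baseline-seed-0", "baseline-seed-1"],
#   )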

def get_models():
    key = ""  # WandB API key; fill in or authenticate beforehand with `wandb login`
    api = wandb.Api(api_key=key, timeout=15)
    parser = argparse.ArgumentParser(
        description="Get trained models' metadata from the experiment"
    )
    parser.add_argument(
        "--config",
        help="Config/Experiment name",
        type=str,
        default=None,
    )
    args = parser.parse_args()
    config = read_config(args.config)
    project_name = config.get("project")
    runs = api.runs(project_name)
    dataset_name = config.get("parameters").get("dataset").get("value")
    path_to_save = _form_path_to_save(project_name, dataset_name)
_get_candidates(
runs=runs,
path_to_save=path_to_save,
filter_runs_name=[],
)
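
# The config consumed by read_config is assumed to follow a wandb sweep-style
# nested layout; the shape is inferred from the lookups in get_models, and the
# field values below are illustrative only:
#
#   project: my-wandb-project
#   parameters:
#     dataset:
#       value: my-dataset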

if __name__ == "__main__":
    get_models()