-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathinit_sweep.py
More file actions
84 lines (71 loc) · 2.94 KB
/
init_sweep.py
File metadata and controls
84 lines (71 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
from typing import Any, Dict, Union
import argparse
import json
import yaml
from wandb import agent, sweep
def check_sweep_configs(configs: list[dict[str, Any]]) -> None:
    """Validate every sweep config before any sweep is started.

    For each config, checks that ``project_name`` consists of at least four
    dash-separated parts and that ``config_path`` can be loaded by
    ``get_config``.

    Args:
        configs: Sweep definitions; each one must provide the keys
            ``"project_name"`` and ``"config_path"``.

    Raises:
        AssertionError: If a project name has fewer than four dash-separated
            parts.
        KeyError: If a config lacks one of the required keys.

    Any error raised by ``get_config`` for the referenced file propagates
    unchanged.
    """
    for config in configs:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if len(config["project_name"].split("-")) < 4:
            raise AssertionError(
                "Project name must be of the form <Function>-<Backbone>-<Dataset>-<Set-Type>"
            )
        # Loading the file is the check itself; the parsed result is discarded.
        get_config(config["config_path"])
def get_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML config file and return its parsed content.

    Args:
        config_path: Path to an existing file whose name ends in ``.yml`` or
            ``.yaml``.

    Returns:
        The document parsed by ``yaml.safe_load``. An empty file (for which
        ``safe_load`` yields ``None``) is returned as an empty dict so the
        declared return type holds.

    Raises:
        AssertionError: If the path is not an existing file, or if it does
            not carry a YAML extension.
    """
    # Explicit raises (rather than `assert`) keep the validation active under
    # `python -O` while preserving the exception type and messages.
    if not os.path.isfile(config_path):
        raise AssertionError(f"Config file not found at {config_path}")
    if not config_path.endswith((".yml", ".yaml")):
        raise AssertionError("Config file must be YAML")

    # Read as UTF-8 explicitly instead of depending on the platform locale.
    with open(config_path, "r", encoding="utf-8") as file:
        config_dict = yaml.safe_load(file)
    return {} if config_dict is None else config_dict
def run_sweep(
    project_name: str,
    entity: str,
    config_path: str,
    parameters: Dict[str, Dict[str, Any]],
    sweep_name: str = "",
) -> None:
    """Create a W&B grid sweep over ``parameters`` and start an agent for it.

    Args:
        project_name: W&B project the sweep is created in.
        entity: W&B entity that owns the project.
        config_path: Path passed to the training script through a fixed
            ``--config_path`` argument.
        parameters: Sweep parameter definitions, placed unchanged under the
            ``"parameters"`` key of the sweep config.
        sweep_name: Display name of the sweep; when empty, ``project_name``
            is used instead.

    Side effects:
        Prints ``SWEEP_PATH=<entity>/<project_name>/<sweep_id>`` and then
        hands control to ``wandb.agent`` for the created sweep.
    """
    if not sweep_name:
        sweep_name = project_name

    sweep_config = {
        "program": "./train.py",  # the training script to launch, not this sweep file
        "name": sweep_name,
        "method": "grid",  # grid search over the values listed in `parameters`
        # Metric the sweep is asked to optimize.
        "metric": {
            "goal": "maximize",
            "name": "aggregated/cxlkfold/val/embeddings/knn5/accuracy",
        },
        "parameters": parameters,
        # The ${...} placeholders are W&B command macros; --config_path is
        # appended as a constant argument after them.
        "command": ["${interpreter}", "${program}", "${args}", "--config_path", config_path],
    }
    sweep_id = sweep(sweep=sweep_config, project=project_name, entity=entity)

    # Emit the full sweep path in a machine-greppable KEY=VALUE form.
    print(f"SWEEP_PATH={entity}/{project_name}/{sweep_id}")

    # Name the entity and project explicitly so the agent resolves the sweep
    # it was just given rather than depending on ambient W&B defaults.
    agent(sweep_id, entity=entity, project=project_name)
if __name__ == "__main__":
    # Built-in default sweep, used when no --sweep_config_file is given.
    # Each entry's keys are passed to run_sweep as keyword arguments.
    sweeps = [
        {
            "project_name": "Embedding-EfficientNetRWM-CXL-OpenSet",
            "entity": "gorillas",
            "config_path": "./cfgs/efficientnet_rw_m_cxl.yml",
            "parameters": {
                "weight_decay": {
                    "values": [
                        0.7,
                        0.5,
                        0.1,
                    ]
                },
                "dropout_p": {"values": [0.5, 0.3, 0.1]},
                "start_lr": {"values": [1e-3, 1e-4, 1e-5]},
                "end_lr": {"values": [1e-5, 1e-6, 1e-7]},
            },
        },
    ]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sweep_config_file",
        type=str,
        default=None,
        # The argument is opened and parsed as one JSON file, not a directory.
        help="Path to a JSON file containing a list of sweep configs",
    )
    args = parser.parse_args()

    if args.sweep_config_file:
        # Context manager closes the handle even if the JSON is malformed.
        with open(args.sweep_config_file, "r", encoding="utf-8") as sweep_file:
            sweeps = json.load(sweep_file)

    # Validate all configs up front so a bad entry fails before any sweep starts.
    check_sweep_configs(sweeps)

    for current_sweep in sweeps:
        print(f"Running sweep: {current_sweep['project_name']}")
        try:
            run_sweep(**current_sweep)  # type: ignore
        except Exception as e:
            # Deliberately broad: one failing sweep is reported and the
            # remaining sweeps still get their turn.
            print(f"Error running sweep: {current_sweep['project_name']}")
            print(e)