-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathrun.py
More file actions
32 lines (23 loc) · 956 Bytes
/
run.py
File metadata and controls
32 lines (23 loc) · 956 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import hydra
import wandb
from omegaconf import DictConfig
from src.train import run_model
from src.utilities.utils import get_logger
# Module-level logger obtained from the project's logging helper.
log = get_logger(__name__)
if "CONFIG_PATH" in os.environ:
# Split config path and config name from config path (split by last '/')
config_path, config_name = os.environ["CONFIG_PATH"].rsplit("/", 1)
log.info(f"Using config path from environment variable: {os.environ['CONFIG_PATH']}")
else:
config_path = "src/configs/"
config_name = "main_config.yaml"
@hydra.main(config_path=config_path, config_name=config_name, version_base=None)
def main(config: DictConfig) -> float:
    """Run/train a model from the composed Hydra config.

    The primary config defaults to ``src/configs/main_config.yaml``; setting the
    ``CONFIG_PATH`` environment variable selects a different file. Any
    command-line overrides are applied on top by Hydra.

    Args:
        config: The composed Hydra configuration.

    Returns:
        Whatever ``run_model`` returns for this config (annotated as ``float``;
        presumably a metric value — TODO confirm against ``run_model``).
    """
    return run_model(config)
if __name__ == "__main__":
    # Log in to Weights & Biases only when an API key is provided via the environment.
    wandb_api_key = os.environ.get("WANDB_API_KEY")
    if wandb_api_key is not None:
        wandb.login(key=wandb_api_key)
    # Make Hydra print the complete stack trace on errors instead of a short summary.
    os.environ["HYDRA_FULL_ERROR"] = "1"
    main()