-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvalidate_models.py
More file actions
71 lines (69 loc) · 2.67 KB
/
validate_models.py
File metadata and controls
71 lines (69 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os

from src.dataset import Dataset
from src.dataset_loader import (
    NoWeekLoader,
    WeatherLoader,
    TrafficLoader,
    ElectricityLoader,
)
from src.window import WindowConfig
from src.validate import validate
from src.predictors.ml_predictor import MLPredictor
from src.predictors.dbp_predictor import DBPPredictor
from src.predictors.kf_predictor import KFPredictor
# Filesystem locations used by the sweep. They default to the original
# author's workspace. Set DLDS_PAPER_DIR to relocate all three at once, or
# DLDS_DATASET_DIR / DLDS_MODELS_DIR / DLDS_OUTPUT_DIR to override one
# individually. This lets the script run on another machine without editing it.
BASE_DIR = os.environ.get("DLDS_PAPER_DIR", "/home/l.calisti/notebooks/dlds_paper")
DATASET_DIR = os.environ.get("DLDS_DATASET_DIR", os.path.join(BASE_DIR, "datasets"))
MODELS_DIR = os.environ.get("DLDS_MODELS_DIR", os.path.join(BASE_DIR, "models"))
OUTPUT_DIR = os.environ.get("DLDS_OUTPUT_DIR", os.path.join(BASE_DIR, "outputs"))

# Sweep grid: every dataset is validated for each combination of
# SEEDS x WS x TS x MODELS. The trailing comments keep a wider set of values
# that can be swapped in for a larger sweep.
MODELS = ["model3"]
SEEDS = [69]  # [42, 69, 911, 2020, 42069]
WS = [5]  # [3, 5, 7, 10, 15]
TS = [1, 7]  # [1, 3, 5, 7, 10, 15]
# Datasets to sweep over, as (file name, loader) pairs. The file name is passed
# to Dataset as `name`, alongside base_path=DATASET_DIR. All active entries use
# NoWeekLoader, each with its own instance. Comment a path out to skip it.
_NO_WEEKEND_FILES = (
    "noweekend/co2_peano_no_weekend.csv",
    "noweekend/pm2p5_peano_no_weekend.csv",
    "noweekend/rad_peano_no_weekend.csv",
    "noweekend/noise_peano_no_weekend.csv",
)
DATASET_NAMES = [(csv_path, NoWeekLoader()) for csv_path in _NO_WEEKEND_FILES]

# Disabled external datasets. Append any of these pairs to DATASET_NAMES to
# include them in the sweep.
# NOTE(review): the "�" in the SWDR column name looks like a mis-encoded "²".
# It must match the CSV header exactly, so check the file before changing it.
# DATASET_NAMES += [
#     ("external/weather.csv", WeatherLoader("T (degC)")),
#     ("external/weather.csv", WeatherLoader("rh (%)")),
#     ("external/weather.csv", WeatherLoader("wv (m/s)")),
#     ("external/weather.csv", WeatherLoader("SWDR (W/m�)")),
#     ("external/traffic.csv", TrafficLoader()),
#     ("external/electricity.csv", ElectricityLoader()),
# ]
def _sweep_ml_models(dataset, seed):
    """Run ``validate`` with an MLPredictor for every configured combination.

    Combinations are drawn from WS, TS and MODELS in nested order: ws is the
    outermost, then ts, with model_name varying fastest. Each combination gets
    a freshly built WindowConfig and MLPredictor. ``validate`` receives
    OUTPUT_DIR as its ``output_path``.
    """
    grid = ((ws, ts, name) for ws in WS for ts in TS for name in MODELS)
    for ws, ts, model_name in grid:
        window_cfg = WindowConfig(ws, ts)
        predictor = MLPredictor(
            model_name, MODELS_DIR, dataset.name(), window_cfg, seed
        )
        validate(
            dataset=dataset,
            predictor=predictor,
            window_config=window_cfg,
            seed=seed,
            output_path=OUTPUT_DIR,
        )


for dataset_name, dataset_loader in DATASET_NAMES:
    ds = Dataset(
        name=dataset_name,
        base_path=DATASET_DIR,
        loader=dataset_loader,
        smooth=None,
    )
    for seed in SEEDS:
        _sweep_ml_models(ds, seed)

        # Alternative predictors (DBPPredictor, KFPredictor), currently
        # disabled. Each uses a fixed WindowConfig and depends only on `ds` and
        # `seed`. If re-enabled, it therefore needs to run only once per
        # (dataset, seed), not once per ML configuration.
        # dbp_predictor = DBPPredictor(20)
        # validate(
        #     dataset=ds,
        #     predictor=dbp_predictor,
        #     window_config=WindowConfig(20, 1),
        #     seed=seed,
        #     output_path=OUTPUT_DIR,
        # )
        # kf_predictor = KFPredictor(x_size=3)
        # validate(
        #     dataset=ds,
        #     predictor=kf_predictor,
        #     window_config=WindowConfig(3, 1),
        #     seed=seed,
        #     output_path=OUTPUT_DIR,
        # )