Skip to content

Commit b2b4448

Browse files
authored
Merge pull request #30 from AdityaLab/Shiduo
Add logic for using all data (boundaries = [-1, -1, -1]) and for capping the horizon length on short series
2 parents 6264ac2 + db254bf commit b2b4448

3 files changed

Lines changed: 53 additions & 22 deletions

File tree

leaderboard.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
import pandas as pd
55
import time
66

7-
# src_path = os.path.abspath(os.path.join("src"))
8-
# if src_path not in sys.path:
9-
# sys.path.insert(0, src_path)
7+
src_path = os.path.abspath(os.path.join("src"))
8+
if src_path not in sys.path:
9+
sys.path.insert(0, src_path)
1010

11-
from src.samay.model import TimesfmModel, MomentModel, ChronosModel, ChronosBoltModel, TinyTimeMixerModel, MoiraiTSModel
12-
from src.samay.dataset import TimesfmDataset, MomentDataset, ChronosDataset, ChronosBoltDataset, TinyTimeMixerDataset, MoiraiDataset
13-
from src.samay.utils import load_args, get_gifteval_datasets
14-
from src.samay.metric import *
11+
from samay.model import TimesfmModel, MomentModel, ChronosModel, ChronosBoltModel, TinyTimeMixerModel, MoiraiTSModel
12+
from samay.dataset import TimesfmDataset, MomentDataset, ChronosDataset, ChronosBoltDataset, TinyTimeMixerDataset, MoiraiDataset
13+
from samay.utils import load_args, get_gifteval_datasets
14+
from samay.metric import *
1515

1616

1717
# ECON_NAMES = {
@@ -146,7 +146,7 @@ def calc_pred_and_context_len(freq):
146146

147147
if __name__ == "__main__":
148148

149-
for model_name in MODEL_NAMES[4:]:
149+
for model_name in MODEL_NAMES[3:]:
150150
print(f"Evaluating model: {model_name}")
151151
# create csv file for leaderboard if not already created
152152
csv_path = f"leaderboard/{model_name}.csv"
@@ -173,6 +173,10 @@ def calc_pred_and_context_len(freq):
173173
args = load_args(arg_path)
174174

175175
for fname, freq, fs in filesizes:
176+
if fname != "solar":
177+
continue
178+
elif freq != "W":
179+
continue
176180
print(f"Evaluating {fname} ({freq})")
177181
# Adjust the context and prediction length based on the frequency
178182

@@ -201,6 +205,7 @@ def calc_pred_and_context_len(freq):
201205
dataset = TimesfmDataset(datetime_col='timestamp', path=dataset_path, mode='test', context_len=args["config"]["context_len"], horizon_len=args["config"]["horizon_len"], boundaries=(-1, -1, -1), batchsize=64)
202206
start = time.time()
203207
metrics = model.evaluate(dataset)
208+
print("Metrics: ", metrics)
204209
end = time.time()
205210
print(f"Size of dataset: {fs:.2f} MB")
206211
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
@@ -209,7 +214,7 @@ def calc_pred_and_context_len(freq):
209214
model = MomentModel(**args)
210215
args["config"]["task_name"] = "forecasting"
211216
train_dataset = MomentDataset(datetime_col='timestamp', path=dataset_path, mode='train', horizon_len=args["config"]["forecast_horizon"], normalize=False)
212-
dataset = MomentDataset(datetime_col='timestamp', path=dataset_path, mode='test', horizon_len=args["config"]["forecast_horizon"], normalize=False)
217+
dataset = MomentDataset(datetime_col='timestamp', path=dataset_path, mode='test', horizon_len=args["config"]["forecast_horizon"], normalize=False, boundaries=[-1, -1, -1])
213218
finetuned_model = model.finetune(train_dataset, task_name="forecasting")
214219
start = time.time()
215220
metrics = model.evaluate(dataset, task_name="forecasting")
@@ -223,7 +228,7 @@ def calc_pred_and_context_len(freq):
223228
dataset_config = load_args("config/chronos_dataset.json")
224229
dataset_config["context_length"] = context_len
225230
dataset_config["prediction_length"] = pred_len
226-
dataset = ChronosDataset(datetime_col='timestamp', path=dataset_path, mode='test', config=dataset_config, batch_size=4)
231+
dataset = ChronosDataset(datetime_col='timestamp', path=dataset_path, mode='test', config=dataset_config, batch_size=4, boundaries=[-1, -1, -1])
227232
start = time.time()
228233
metrics = model.evaluate(dataset, horizon_len=dataset_config["prediction_length"], quantile_levels=[0.1, 0.5, 0.9])
229234
end = time.time()
@@ -233,7 +238,7 @@ def calc_pred_and_context_len(freq):
233238
elif model_name == "chronosbolt":
234239
repo = "amazon/chronos-bolt-small"
235240
model = ChronosBoltModel(repo=repo)
236-
dataset = ChronosBoltDataset(datetime_col='timestamp', path=dataset_path, mode='test', batch_size=8, context_len=context_len, horizon_len=pred_len)
241+
dataset = ChronosBoltDataset(datetime_col='timestamp', path=dataset_path, mode='test', batch_size=8, context_len=context_len, horizon_len=pred_len, boundaries=[-1, -1, -1])
237242
start = time.time()
238243
metrics = model.evaluate(dataset, horizon_len=pred_len, quantile_levels=[0.1, 0.5, 0.9])
239244
end = time.time()
@@ -242,7 +247,7 @@ def calc_pred_and_context_len(freq):
242247

243248
elif model_name == "ttm":
244249
model = TinyTimeMixerModel(**args)
245-
dataset = TinyTimeMixerDataset(datetime_col='timestamp', path=dataset_path, mode='test', context_len=context_len, horizon_len=pred_len)
250+
dataset = TinyTimeMixerDataset(datetime_col='timestamp', path=dataset_path, mode='test', context_len=context_len, horizon_len=pred_len, boundaries=[-1, -1, -1])
246251
start = time.time()
247252
metrics = model.evaluate(dataset)
248253
end = time.time()

leaderboard/timesfm.csv

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
dataset,mse,mae,mase,mape,rmse,nrmse,smape,msis,nd,mwsq,crps
2-
m4_yearly,9324.314453125,1.453155755996704,,129.92918395996094,96.56249237060548,0.003338028634445,1.9404819011688232,1825546.875,0.1487463746984999,736.5278930664062,4763464.50894722
3-
m4_quarterly,11540.2001953125,4.373758316040039,,0.000894644006621,107.42532348632812,0.0021157962591371,0.0008130415226332,0.0147697096690535,0.0007122407249761,8877.486328125,76965652.61745518
4-
m4_monthly,21410.37890625,6.217271327972412,0.3291291296482086,0.0012720649829134,146.32286071777344,0.0019095096563116,0.0012352424673736,0.019323231652379,0.001287188373541,12769.849609375,105929602.0931511
5-
m4_weekly,561843.9375,169.09071350097656,0.0666100904345512,0.0394658930599689,749.5625,0.0146108341018089,0.0377030260860919,0.026721965521574,0.0343743559651665,258700.15625,2611755055.7616825
6-
m4_daily,72798.4453125,27.041128158569336,0.0343014523386955,0.0043948837555944,269.8118591308594,0.006309027242851,0.0045300694182515,0.0084697818383574,0.0044873143451785,40192.64453125,424549963.43306565
7-
m4_hourly,5090942.0,191.70616149902344,,0.3146225810050964,2256.3115234375,0.0032095561059994,0.0889025703072547,0.1215216591954231,0.0266055212389293,663349.375,118992351217.515
8-
car_parts_with_missing,1.4118587970733645,0.4692889750003814,,11805.2060546875,1.1882166862487793,0.029705409729867,1.828437566757202,0.1179555356502533,1.0383541264832732,0.2255326509475708,0.730486241146584
9-
hierarchical_sales,223.04559326171875,3.070805549621582,,1988.4501953125,14.934711456298828,0.0273529508841928,1.5463078022003174,0.0621227174997329,0.4439457604474497,8.156706809997559,1187.0533112340966
10-
restaurant,23.778711318969727,0.7549307346343994,,122.93232727050781,4.876341819763184,0.018611990968104102,1.880303978919983,0.04096284881234169,0.5990635938742331,0.7695077061653137,48.720557867116376
1+
dataset,mse,mae,mase,mape,rmse,nrmse,smape,msis,nd,mwsq,crps,size_in_MB,eval_time
2+
m4_yearly,9324.314453125,1.453155755996704,,129.92918395996094,96.56249237060548,0.003338028634445,1.9404819011688232,1825546.875,0.1487463746984999,736.5278930664062,4763464.50894722,,
3+
m4_quarterly,11540.2001953125,4.373758316040039,,0.000894644006621,107.42532348632812,0.0021157962591371,0.0008130415226332,0.0147697096690535,0.0007122407249761,8877.486328125,76965652.61745518,,
4+
m4_monthly,21410.37890625,6.217271327972412,0.3291291296482086,0.0012720649829134,146.32286071777344,0.0019095096563116,0.0012352424673736,0.019323231652379,0.001287188373541,12769.849609375,105929602.0931511,,
5+
m4_weekly,561843.9375,169.09071350097656,0.0666100904345512,0.0394658930599689,749.5625,0.0146108341018089,0.0377030260860919,0.026721965521574,0.0343743559651665,258700.15625,2611755055.7616825,,
6+
m4_daily,72798.4453125,27.041128158569336,0.0343014523386955,0.0043948837555944,269.8118591308594,0.006309027242851,0.0045300694182515,0.0084697818383574,0.0044873143451785,40192.64453125,424549963.43306565,,
7+
m4_hourly,5090942.0,191.70616149902344,,0.3146225810050964,2256.3115234375,0.0032095561059994,0.0889025703072547,0.1215216591954231,0.0266055212389293,663349.375,118992351217.515,,
8+
car_parts_with_missing,1.4118587970733645,0.4692889750003814,,11805.2060546875,1.1882166862487793,0.029705409729867,1.828437566757202,0.1179555356502533,1.0383541264832732,0.2255326509475708,0.730486241146584,,
9+
hierarchical_sales,223.04559326171875,3.070805549621582,,1988.4501953125,14.934711456298828,0.0273529508841928,1.5463078022003174,0.0621227174997329,0.4439457604474497,8.156706809997559,1187.0533112340966,,
10+
restaurant,23.778711318969727,0.7549307346343994,,122.9323272705078,4.876341819763184,0.0186119909681041,1.880303978919983,0.0409628488123416,0.5990635938742331,0.7695077061653137,48.720557867116376,,
11+
solar,243211.953125,140.67047119140625,0.10394947230815887,168.28846740722656,493.1652526855469,0.027132100801376485,1.5959948301315308,0.037582240998744965,0.1559242251909539,68889.2734375,313517743.6982181,0.06,3.3s

src/samay/dataset.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def __init__(self, name=None,
170170
self.boundaries = [0, 0, len(self.data) - 1]
171171
else:
172172
self.boundaries = boundaries
173+
self.horizon_len = min(self.horizon_len, int(0.3*len(self.data)+1))
173174
self.ts_cols = [col for col in self.data.columns if col != self.datetime_col]
174175
tfdtl = TimeSeriesdata(
175176
data_path=self.data_path,
@@ -301,6 +302,12 @@ def _read_data(self):
301302
if self.boundaries[2] == 0:
302303
self.boundaries[2] = int(len(self.df) - 1)
303304

305+
if self.boundaries == [-1, -1, -1]:
306+
# use all data for training
307+
self.boundaries = [0, 0, len(self.df) - 1]
308+
309+
self.horizon_len = min(self.horizon_len, int(0.3*len(self.df)+1))
310+
304311
self.n_channels = self.df.shape[1] - 1
305312
self.num_chunks = (self.n_channels + self.max_col_num - 1) // self.max_col_num
306313

@@ -445,6 +452,12 @@ def _read_data(self):
445452
if self.boundaries[2] == 0:
446453
self.boundaries[2] = int(len(self.df) - 1)
447454

455+
if self.boundaries == [-1, -1, -1]:
456+
# use all data for training
457+
self.boundaries = [0, 0, len(self.df) - 1]
458+
459+
self.horizon_len = min(self.horizon_len, int(0.3*len(self.df)+1))
460+
448461
self.n_channels = self.df.shape[1] - 1
449462
self.num_chunks = (self.n_channels + self.max_col_num - 1) // self.max_col_num
450463

@@ -566,6 +579,12 @@ def _read_data(self):
566579
if self.boundaries[2] == 0:
567580
self.boundaries[2] = int(len(self.df) - 1)
568581

582+
if self.boundaries == [-1, -1, -1]:
583+
# use all data for training
584+
self.boundaries = [0, 0, len(self.df) - 1]
585+
586+
self.forecast_horizon = min(self.forecast_horizon, int(0.3*len(self.df)+1))
587+
569588
if self.task_name == 'detection':
570589
self.n_channels = 1
571590
else:
@@ -748,6 +767,12 @@ def _read_data(self):
748767
if self.boundaries[2] == 0:
749768
self.boundaries[2] = int(len(self.df) - 1)
750769

770+
if self.boundaries == [-1, -1, -1]:
771+
# use all data for training
772+
self.boundaries = [0, 0, len(self.df) - 1]
773+
774+
self.horizon_len = min(self.horizon_len, int(0.3*len(self.df)+1))
775+
751776
self.n_channels = self.df.shape[1] - 1
752777
self.num_chunks = (self.n_channels + self.max_col_num - 1) // self.max_col_num
753778

0 commit comments

Comments
 (0)