Skip to content

Commit d84e8ce

Browse files
authored
Merge pull request #26 from AdityaLab/Sudarshan
Leaderboard Setup
2 parents d90df82 + 378abe4 commit d84e8ce

11 files changed

Lines changed: 367 additions & 165 deletions

File tree

README.md

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,6 @@ avg_loss, trues, preds, histories = tfm.evaluate(val_dataset)
138138

139139
### MOIRAI
140140

141-
Install the package: `pip install git+https://github.com/AdityaLab/Samay.git`.
142-
143141
#### Loading Model
144142

145143
```python
@@ -161,16 +159,11 @@ moirai_model = MoiraiTSModel(repo=repo, config=config)
161159
#### Loading Dataset
162160

163161
```python
164-
data_config = {"name" : "ett",
165-
"path" : "../src/samay/models/moment/data/ETTh1.csv",
166-
"date_col" : "date",
167-
"freq": "h"
168-
}
169162

170-
train_dataset = MoiraiDataset(name=data_config['name'], mode="train", path=data_config['path'], datetime_col=data_config['date_col'], freq=data_config['freq'],
163+
train_dataset = MoiraiDataset(name="ett", mode="train", path="data/ETTh1.csv", datetime_col="date", freq="h",
171164
context_len=config['context_len'], horizon_len=config['horizon_len'])
172165

173-
test_dataset = MoiraiDataset(name=data_config['name'], mode="test", path=data_config['path'], datetime_col=data_config['date_col'], freq=data_config['freq'],
166+
test_dataset = MoiraiDataset(name="ett", mode="test", path="data/ETTh1.csv", datetime_col="date", freq="h",
174167
context_len=config['context_len'], horizon_len=config['horizon_len'])
175168
```
176169

example/moirai.ipynb

Lines changed: 46 additions & 38 deletions
Large diffs are not rendered by default.

leaderboard.py

Lines changed: 114 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,40 @@
22
import sys
33
import numpy as np
44
import pandas as pd
5-
5+
import time
66

77
# src_path = os.path.abspath(os.path.join("src"))
88
# if src_path not in sys.path:
99
# sys.path.insert(0, src_path)
1010

11-
from samay.model import TimesfmModel, MomentModel, ChronosModel, ChronosBoltModel, TinyTimeMixerModel
12-
from samay.dataset import TimesfmDataset, MomentDataset, ChronosDataset, ChronosBoltDataset, TinyTimeMixerDataset
13-
from samay.utils import load_args
14-
from samay.metric import *
11+
from src.samay.model import TimesfmModel, MomentModel, ChronosModel, ChronosBoltModel, TinyTimeMixerModel, MoiraiTSModel
12+
from src.samay.dataset import TimesfmDataset, MomentDataset, ChronosDataset, ChronosBoltDataset, TinyTimeMixerDataset, MoiraiDataset
13+
from src.samay.utils import load_args, get_gifteval_datasets
14+
from src.samay.metric import *
1515

1616

17-
ECON_NAMES = {
18-
"m4_yearly": ["Y"],
19-
"m4_quarterly": ["Q"],
20-
"m4_monthly": ["M"],
21-
"m4_weekly": ["W"],
22-
"m4_daily": ["D"],
23-
"m4_hourly": ["H"],
24-
}
17+
# ECON_NAMES = {
18+
# "m4_yearly": ["Y"],
19+
# "m4_quarterly": ["Q"],
20+
# "m4_monthly": ["M"],
21+
# "m4_weekly": ["W"],
22+
# "m4_daily": ["D"],
23+
# "m4_hourly": ["H"],
24+
# }
2525

26-
SALES_NAMES = {
27-
"car_parts_with_missing": ['M'],
28-
"hierarchical_sales": ['D', 'W'],
29-
"restaurant": ['D'],
30-
}
26+
# SALES_NAMES = {
27+
# "car_parts_with_missing": ['M'],
28+
# "hierarchical_sales": ['D', 'W'],
29+
# "restaurant": ['D'],
30+
# }
31+
32+
start = time.time()
33+
NAMES, filesizes = get_gifteval_datasets("data/gifteval")
34+
end = time.time()
35+
36+
print(f"Time taken to load datasets: {end-start:.2f} seconds")
3137

38+
MODEL_NAMES = ["moirai", "chronos", "chronosbolt", "timesfm", "moment", "ttm"]
3239
MONASH_NAMES = {
3340
# "weather": "1D",
3441
"tourism_yearly": ["1YE"],
@@ -99,7 +106,6 @@
99106
"temperature_rain": 30
100107
}
101108

102-
MODEL_NAMES = ["timesfm", "moment", "chronos", "chronosbolt", "ttm"]
103109
MODEL_CONTEXT_LEN = {
104110
"timesfm": 32,
105111
"moment": 512,
@@ -139,35 +145,37 @@ def calc_pred_and_context_len(freq):
139145

140146

141147
if __name__ == "__main__":
142-
model_name = MODEL_NAMES[1]
143-
# create csv file for leaderboard if not already created
144-
# csv_path = f"leaderboard/{model_name}.csv"
145-
csv_path = f"leaderboard/monash_{model_name}.csv"
146-
msh = False
147-
148-
if not os.path.exists(csv_path):
149-
df = pd.DataFrame(columns=["dataset", "mse", "mae", "mase", "mape", "rmse", "nrmse", "smape", "msis", "nd", "mwsq", "crps"])
150-
df.to_csv(csv_path, index=False)
151-
152-
if model_name == "timesfm":
153-
arg_path = "config/timesfm.json"
154-
args = load_args(arg_path)
155-
elif model_name == "moment":
156-
arg_path = "config/moment_forecast.json"
157-
args = load_args(arg_path)
158-
elif model_name == "chronos":
159-
arg_path = "config/chronos.json"
160-
args = load_args(arg_path)
161-
elif model_name == "ttm":
162-
arg_path = "config/tinytimemixer.json"
163-
args = load_args(arg_path)
164-
165-
# NAMES = ECON_NAMES | SALES_NAMES
166-
NAMES = MONASH_NAMES
167-
msh = True
168-
169-
for dataset_name, freqs in NAMES.items():
170-
for freq in freqs:
148+
149+
for model_name in MODEL_NAMES[4:]:
150+
print(f"Evaluating model: {model_name}")
151+
# create csv file for leaderboard if not already created
152+
csv_path = f"leaderboard/{model_name}.csv"
153+
if not os.path.exists(csv_path):
154+
print(f"Creating leaderboard csv file: {csv_path}")
155+
df = pd.DataFrame(columns=["dataset", "size_in_MB", "eval_time", "mse", "mae", "mase", "mape", "rmse", "nrmse", "smape", "msis", "nd", "mwsq", "crps"])
156+
df.to_csv(csv_path, index=False)
157+
158+
# Load model config
159+
if model_name == "timesfm":
160+
arg_path = "config/timesfm.json"
161+
args = load_args(arg_path)
162+
elif model_name == "moment":
163+
arg_path = "config/moment_forecast.json"
164+
args = load_args(arg_path)
165+
elif model_name == "chronos":
166+
arg_path = "config/chronos.json"
167+
args = load_args(arg_path)
168+
elif model_name == "ttm":
169+
arg_path = "config/tinytimemixer.json"
170+
args = load_args(arg_path)
171+
elif model_name == "moirai":
172+
arg_path = "config/moirai.json"
173+
args = load_args(arg_path)
174+
175+
for fname, freq, fs in filesizes:
176+
print(f"Evaluating {fname} ({freq})")
177+
# Adjust the context and prediction length based on the frequency
178+
171179
# pred_len, context_len = calc_pred_and_context_len(freq)
172180
pred_len, context_len = 96, 512
173181
if msh:
@@ -180,6 +188,17 @@ def calc_pred_and_context_len(freq):
180188
elif model_name == "ttm":
181189
args["config"]["horizon_len"] = pred_len
182190
args["config"]["context_len"] = context_len
191+
elif model_name == "moirai":
192+
args["config"]["horizon_len"] = pred_len
193+
args["config"]["context_len"] = context_len
194+
195+
# Set the dataset path
196+
if len(NAMES.get(fname)) == 1:
197+
dataset_path = f"data/gifteval/{fname}/data.csv"
198+
else:
199+
dataset_path = f"data/gifteval/{fname}/{freq}/data.csv"
200+
201+
# Initialize the model and dataset
183202
if msh:
184203
dataset_path = f"data/monash/{dataset_name}/test/data.csv"
185204
else:
@@ -191,15 +210,23 @@ def calc_pred_and_context_len(freq):
191210
if model_name == "timesfm":
192211
model = TimesfmModel(**args)
193212
dataset = TimesfmDataset(datetime_col='timestamp', path=dataset_path, mode='test', context_len=args["config"]["context_len"], horizon_len=args["config"]["horizon_len"], boundaries=(-1, -1, -1), batchsize=64)
213+
start = time.time()
194214
metrics = model.evaluate(dataset)
215+
end = time.time()
216+
print(f"Size of dataset: {fs:.2f} MB")
217+
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
195218

196219
elif model_name == "moment":
197220
model = MomentModel(**args)
198221
args["config"]["task_name"] = "forecasting"
199222
train_dataset = MomentDataset(datetime_col='timestamp', path=dataset_path, mode='train', horizon_len=args["config"]["forecast_horizon"], normalize=False)
200223
dataset = MomentDataset(datetime_col='timestamp', path=dataset_path, mode='test', horizon_len=args["config"]["forecast_horizon"], normalize=False)
201224
finetuned_model = model.finetune(train_dataset, task_name="forecasting")
225+
start = time.time()
202226
metrics = model.evaluate(dataset, task_name="forecasting")
227+
end = time.time()
228+
print(f"Size of dataset: {fs:.2f} MB")
229+
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
203230
print(metrics)
204231

205232
elif model_name == "chronos":
@@ -208,26 +235,58 @@ def calc_pred_and_context_len(freq):
208235
dataset_config["context_length"] = context_len
209236
dataset_config["prediction_length"] = pred_len
210237
dataset = ChronosDataset(datetime_col='timestamp', path=dataset_path, mode='test', config=dataset_config, batch_size=4)
238+
start = time.time()
211239
metrics = model.evaluate(dataset, horizon_len=dataset_config["prediction_length"], quantile_levels=[0.1, 0.5, 0.9])
240+
end = time.time()
241+
print(f"Size of dataset: {fs:.2f} MB")
242+
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
212243

213244
elif model_name == "chronosbolt":
214245
repo = "amazon/chronos-bolt-small"
215246
model = ChronosBoltModel(repo=repo)
216-
dataset = ChronosBoltDataset(datetime_col='timestamp', path=dataset_path, mode='test', batch_size=16, context_len=context_len, horizon_len=pred_len)
247+
dataset = ChronosBoltDataset(datetime_col='timestamp', path=dataset_path, mode='test', batch_size=8, context_len=context_len, horizon_len=pred_len)
248+
start = time.time()
217249
metrics = model.evaluate(dataset, horizon_len=pred_len, quantile_levels=[0.1, 0.5, 0.9])
250+
end = time.time()
251+
print(f"Size of dataset: {fs:.2f} MB")
252+
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
218253

219254
elif model_name == "ttm":
220255
model = TinyTimeMixerModel(**args)
221256
dataset = TinyTimeMixerDataset(datetime_col='timestamp', path=dataset_path, mode='test', context_len=context_len, horizon_len=pred_len)
257+
start = time.time()
222258
metrics = model.evaluate(dataset)
259+
end = time.time()
260+
print(f"Size of dataset: {fs:.2f} MB")
261+
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
262+
263+
elif model_name == "moirai":
264+
model = MoiraiTSModel(**args)
265+
dataset = MoiraiDataset(name=fname,datetime_col='timestamp', freq=freq,
266+
path=dataset_path, mode='test', context_len=context_len, horizon_len=pred_len)
267+
268+
start = time.time()
269+
metrics = model.evaluate(dataset,leaderboard=True)
270+
end = time.time()
271+
print(f"Size of dataset: {fs:.2f} MB")
272+
print(f"Time taken for evaluation of {fname}: {end-start:.2f} seconds")
273+
274+
print("Evaluation done!")
275+
276+
eval_time = end - start
277+
unit = "s"
278+
if eval_time > 1000: # convert to minutes
279+
eval_time = eval_time / 60
280+
unit = "m"
281+
223282

224283
df = pd.read_csv(csv_path)
225-
if dataset_name in df["dataset"].values:
226-
df.loc[df["dataset"] == dataset_name, list(metrics.keys())] = list(metrics.values())
284+
if fname in df["dataset"].values:
285+
df.loc[df["dataset"] == fname, "size_in_MB"] = round(fs,2)
286+
df.loc[df["dataset"] == fname, "eval_time"] = str(round(eval_time,2)) + unit
287+
df.loc[df["dataset"] == fname, list(metrics.keys())] = list(metrics.values())
227288
else:
228-
new_row = pd.DataFrame([{**{"dataset": dataset_name}, **metrics}])
289+
new_row = pd.DataFrame([{**{"dataset": fname, "size_in_MB":round(fs,2), "eval_time":str(round(eval_time,2)) + unit}, **metrics}])
229290
df = pd.concat([df, new_row], ignore_index=True)
230291

231-
df.to_csv(csv_path, index=False)
232-
233-
292+
df.to_csv(csv_path, index=False)

leaderboard/chronos.csv

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,29 @@
1-
dataset,mse,mae,mase,mape,rmse,nrmse,smape,msis,nd,mwsq,crps
2-
m4_yearly,66252550.56894015,5959.432080232358,1.2466633549052129,116.0693424385617,8139.566976746377,0.0539687506710429,1.7844448819195622,0.2462129957053928,0.9999963526074004,5.684685482926337,62499.61471598859
3-
m4_quarterly,53476536.116586566,5412.361620228974,1.1984195060804983,111.97136705228762,7312.765285210962,0.1437596383536479,1.7840546713665213,0.2660070752214569,0.9999961134397592,5.248531341335546,50733.14054465674
4-
m4_monthly,29150668.200291883,3928.301527841015,1.2465767579479727,110.07550748571124,5399.135875331523,0.0407481952824455,1.7841431576288205,0.2790665130307588,0.9999945622402834,4.109593239536733,30980.31590585716
5-
m4_weekly,13127156.156438977,2593.3999113001714,,119.017248214552,3623.141752186764,0.1177759565389917,1.7429390776203342,0.2179777627700754,0.9999917243130576,4.451479711167246,15381.47052969332
6-
m4_daily,20818955.37703621,3383.3343197166814,1.217461272876409,115.13205622515392,4562.779347835726,0.066050656444343,1.7813443961611557,0.2791981407056162,0.999993502253278,4.0132389355638445,25368.806662947107
7-
m4_hourly,1249343690.1846278,5731.696451628863,0.3711103980506451,87.29192972826681,35346.05621826327,0.0569406128004115,1.774749758507997,0.0854603735323393,0.9999964202687148,3.501781981392865,93630.61016010716
8-
car_parts_with_missing,0.0133717639281966,0.0239443373268813,2394.433732688132,2394.433732688132,0.1156363434573951,11563.634345739518,0.2915361131408352,2394.433732688132,2394.433732688132,0.0239310713895487,0.0133661095415976
9-
hierarchical_sales,1164.580781760996,14.276960861055366,,769.319906400828,34.12595466446317,0.0418723364978402,1.276396955962651,0.1320677145641298,0.999488233832982,0.0162003405008592,0.6419357896665321
10-
restaurant,549.8892230222452,16.931034421372605,1.1714066015882363,186.38452736617225,23.4497169070811,0.03739986687891935,1.7225995714648292,0.25489343812543663,0.9990823009939029,0.015270601154084623,0.35233365494850294
1+
dataset,size_in_MB,eval_time,mse,mae,mase,mape,rmse,nrmse,smape,msis,nd,mwsq,crps
2+
us_births,0.13,1.24s,78187387.24978645,7586.208333326565,,0.7399471304883386,8842.363216345868,0.77496610066575,1.4798932459530707,0.6124445785896484,0.9999999986809264,0.0001059348302209,1.091746157509392
3+
ett1,4.4,1.41s,93.54429120154582,5.897260479994885,,280.6259892554294,9.671829775256894,0.3138980473931072,1.5571668415018334,0.1657898062363476,0.9972802436063892,0.0088134578993685,0.0136036177875246
4+
ett2,4.57,1.41s,325.75832724415693,12.114023534551617,,0.6626575551293206,18.04877633647658,0.4507572695213377,1.325280077867013,0.2904934515797344,0.9999600900592192,1.7213636914300657e-07,4.499342450936781e-06
5+
saugeenday,0.38,1.16s,76.49666646944554,7.243749993527031,,0.7399463236204532,8.746237274934035,0.5333068257235231,1.4798916300056202,0.3977856881034689,0.9999986186078822,1.0245254431979265e-07,1.0366162516862808e-06
6+
solar,33.4,22.88s,106.85474845514084,6.029460579751957,,1173.4230367899731,10.337057050009005,0.1747600220187452,1.1180718304037367,0.1909287710948171,1.0006912266188184,0.0213146983129753,0.1601412125518966
7+
jena_weather,7.18,5.15s,72581.63907926583,88.13806604206415,,1469.4794364536383,269.40979766754185,0.2053960596599586,1.0099967438361368,0.0373644719521297,1.0042994579327431,0.2078406942308367,4.752297940664754
8+
hierarchical_sales,0.9,19.87s,40.39378747612249,2.5460542771161907,,1012.937310835032,6.355610708352305,0.0557509662354617,1.1131198381313607,0.1028032736638465,0.9992906657840404,0.0186013412961492,0.4217582559078034
9+
bizitobs_l2c,1.68,1.41s,0.0025111913971847,0.0033564104287851,,335.641042878519,0.050111789004033,5011.178900403307,0.0175835686284836,335.641042878519,335.64104287851893,0.0033564103824191,0.0025111913624956
10+
M_DENSE,3.7,5.15s,779166.6419101359,418.83873638736617,,538.6328731314599,882.704164434572,0.1915278898875602,1.4314799877574504,0.0814654073003754,0.9999979824084412,0.0063960957854613,0.1128682919134299
11+
covid_deaths,0.27,42.56s,70085795.96780607,1122.6806332936076,,1615.2718101733803,8371.725984992943,0.048042959940963,0.9825853633719572,0.0872468856187782,1.0000105554869307,1.0624762911264325,17230.311071763055
12+
bizitobs_application,0.33,1.19s,676944756.5826368,15653.719594373208,,356.6171618056131,26018.1620523556,0.4628907281477136,1.6665565059982017,0.2422907333723502,0.9999996211460498,10.955474761899952,446608.70469783817
13+
hospital,0.35,118.95s,0.0123168371948544,0.0221001298498491,2210.012984984913,2210.012984984913,0.1109812470413558,11098.124704135587,0.2770874727972745,2210.012984984913,2210.012984984913,0.0220895832741943,0.0123131651946033
14+
car_parts_with_missing,0.58,416.5s,0.014102859489454,0.024796217136795,2479.621713679505,2479.621713679505,0.1187554608826646,11875.546088266468,0.2978937720422596,2479.621713679505,2479.621713679504,0.0247706001864832,0.0140946431523094
15+
electricity,442.39,59.46s,1456831.4934865898,206.09218703957248,,1708.2964658069727,1206.992747901407,0.0394657020554642,0.8131836742326622,0.0847968373652326,1.0000367184379413,0.0707768180220827,60.642085873034695
16+
kdd_cup_2018_with_missing,14.28,42.49s,3851.45839608552,27.90325060336912,,584.0685303780313,62.06011920779334,0.0859439389950892,1.5066524443334202,0.1031155597885626,0.9994176112324632,0.0298904942039703,2.423892971840899
17+
LOOP_SEATTLE,324.08,52.43s,2927.766440554279,45.57446018044582,,503.21599800509233,54.10883883945653,0.7032451063819988,1.4931622460939462,0.6093650608110267,0.999748441598298,0.0281530233428902,1.460833342606207
18+
SZ_TAXI,4.58,25.02s,435.332686351744,14.245694219987042,,343.8599873289237,20.86462763510876,0.2773723487483956,1.484375229500433,0.1972819894514714,0.9990451614924684,0.0166508015743499,0.3800147299225255
19+
restaurant,1.77,126.54s,423.68930848194793,12.718873056687327,1.08509061349554,972.7435288516384,20.583714642453337,0.030494391611125,1.3610320390638275,0.190555117981387,0.9995774903434064,0.0244602911956938,0.4907717847672417
20+
m4_hourly,2.43,64.74s,1103542760.0901096,4911.378860530786,0.3726868082865278,394.8534542483283,33219.61408701356,0.0535150278556501,1.5256166795702857,0.0742795737465515,0.9999971534553816,1.5810111396867963,92641.90190390203
21+
bizitobs_service,3.07,7.55s,3000768.467127385,675.8270123436982,,750.3495003106024,1732.2726307159005,0.1668213240608327,1.2493755082437543,0.061203278752978,0.9999990735489268,0.0288586742841926,45.18272264539291
22+
bitbrains_rnd,63.69,157.59s,2202671.051954235,203.4095291804485,0.5728890817093389,893.1173610062597,1484.1398357143555,0.0677454571473168,1.1162888236525472,0.072972869194764,0.9999826629630684,0.6149925771633011,6452.646928677517
23+
m4_weekly,7.18,57.16s,9606190.083518032,2050.181715607978,,341.01627863171103,3099.385436424136,0.112819795984855,1.4964896778266703,0.2069794960936107,0.999993615699366,1.6910199906082488,8296.014177909163
24+
bitbrains_fast_storage,160.06,389.4s,10569478.96346236,635.655498843825,0.5793226061139789,913.8611399891372,3251.073509390761,0.1477797025787299,1.1804512642110845,0.0094477501887851,0.999994204664338,0.3211663613656945,3115.649937010363
25+
m4_yearly,51.4,59.61m,55308544.29912007,4976.860027115842,1.246930048523901,472.7850860445029,7436.971446705982,0.0493102469580485,1.5357736163939335,0.2107337312526032,0.9999971127251904,5.343212252281464,57858.717003997976
26+
temperature_rain_with_missing,113.99,83.0m,191.1150142258784,5.091763323149936,0.7139944157811877,1413.0003585755453,13.824435403512089,0.0229565512685927,0.8992030209276702,0.0587361533892377,1.001276078825235,0.0206105243297257,0.2241398718823561
27+
m4_quarterly,163.93,62.8m,44005360.43136102,4493.896279483806,1.200773477682749,473.94502409748384,6633.653626121959,0.1304091693170139,1.5372096040725636,0.2291787119412921,0.999996840921094,5.360133396424458,51024.99019915791
28+
m4_daily,316.28,660.71s,17141343.871600132,2792.680250203576,1.2052741825168067,463.65871661860456,4140.210607155164,0.0618773069277595,1.5352677813436568,0.229351670551101,0.9999950900213428,3.0986971854360625,17456.266419993455
29+
m4_monthly,1025.34,124.6m,22621233.346859664,3182.69983324925,1.2572740200870813,467.3309500550832,4756.178439341786,0.035895686331945884,1.536661107605411,0.23811709044229257,0.9999954704061967,3.350789658660369,22262.53123278339

0 commit comments

Comments
 (0)