Skip to content

Commit ae7a128

Browse files
author
Bohdan Bilonoh
committed
WIP: stuck with model.set_predict_parameters
1 parent 932514e commit ae7a128

2 files changed

Lines changed: 163 additions & 79 deletions

File tree

darts/models/forecasting/torch_forecasting_model.py

Lines changed: 152 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import torch
4242
from pytorch_lightning import loggers as pl_loggers
4343
from torch import Tensor
44-
from torch.utils.data import DataLoader
44+
from torch.utils.data import DataLoader, Dataset
4545

4646
from darts.dataprocessing.encoders import SequentialEncoder
4747
from darts.logging import (
@@ -996,28 +996,20 @@ def _setup_for_train(
996996

997997
# Setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at
998998
# least one batch no matter the chosen batch size
999-
train_loader = DataLoader(
1000-
train_dataset,
1001-
batch_size=self.batch_size,
1002-
shuffle=True,
1003-
num_workers=num_loader_workers,
1004-
pin_memory=True,
1005-
drop_last=False,
1006-
collate_fn=self._batch_collate_fn,
999+
train_loader = self._build_dataloader(
1000+
split="train",
1001+
dataset=train_dataset,
1002+
num_loader_workers=num_loader_workers,
10071003
)
10081004

10091005
# Prepare validation data
10101006
val_loader = (
10111007
None
10121008
if val_dataset is None
1013-
else DataLoader(
1014-
val_dataset,
1015-
batch_size=self.batch_size,
1016-
shuffle=False,
1017-
num_workers=num_loader_workers,
1018-
pin_memory=True,
1019-
drop_last=False,
1020-
collate_fn=self._batch_collate_fn,
1009+
else self._build_dataloader(
1010+
split="val",
1011+
dataset=val_dataset,
1012+
num_loader_workers=num_loader_workers,
10211013
)
10221014
)
10231015

@@ -1210,17 +1202,17 @@ def lr_find(
12101202
def scale_batch_size(
12111203
self,
12121204
series: Union[TimeSeries, Sequence[TimeSeries]],
1213-
val_series: Union[TimeSeries, Sequence[TimeSeries]],
1205+
n: int = 1,
1206+
n_jobs: int = 1,
1207+
roll_size: Optional[int] = None,
1208+
num_samples: int = 1,
1209+
mc_dropout: bool = False,
1210+
predict_likelihood_parameters: bool = False,
12141211
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
12151212
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
1216-
val_past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
1217-
val_future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
12181213
trainer: Optional[pl.Trainer] = None,
12191214
verbose: Optional[bool] = None,
1220-
epochs: int = 0,
1221-
max_samples_per_ts: Optional[int] = None,
1222-
num_loader_workers: int = 0,
1223-
method: Literal["fit", "validate", "test", "predict"] = "fit",
1215+
method: Literal["fit", "predict"] = "fit",
12241216
mode: str = "power",
12251217
steps_per_trial: int = 3,
12261218
init_val: int = 2,
@@ -1236,37 +1228,19 @@ def scale_batch_size(
12361228
----------
12371229
series
12381230
A series or sequence of series serving as target (i.e. what the model will be trained to forecast)
1231+
n
1232+
The number of time steps after the end of the training time series for which to produce predictions.
1233+
Only for the `predict` method.
12391234
past_covariates
12401235
Optionally, a series or sequence of series specifying past-observed covariates
12411236
future_covariates
12421237
Optionally, a series or sequence of series specifying future-known covariates
1243-
val_series
1244-
Optionally, one or a sequence of validation target series, which will be used to compute the validation
1245-
loss throughout training and keep track of the best performing models.
1246-
val_past_covariates
1247-
Optionally, the past covariates corresponding to the validation series (must match ``covariates``)
1248-
val_future_covariates
1249-
Optionally, the future covariates corresponding to the validation series (must match ``covariates``)
12501238
trainer
12511239
Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom ``trainer`` will
12521240
override Darts' default trainer.
12531241
verbose
12541242
Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in
12551243
`pl_trainer_kwargs`.
1256-
epochs
1257-
If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs``
1258-
was provided to the model constructor.
1259-
max_samples_per_ts
1260-
Optionally, a maximum number of samples to use per time series. Models are trained in a supervised fashion
1261-
by constructing slices of (input, output) examples. On long time series, this can result in unnecessarily
1262-
large number of training samples. This parameter upper-bounds the number of training samples per time
1263-
series (taking only the most recent samples in each series). Leaving to None does not apply any
1264-
upper bound.
1265-
num_loader_workers
1266-
Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
1267-
both for the training and validation loaders (if any).
1268-
A larger number of workers can sometimes increase performance, but can also incur extra overheads
1269-
and increase memory usage, as more batches are loaded in parallel.
12701244
method
12711245
The method to use for scaling the batch size. Can be one of 'fit', 'validate', 'test', or 'predict'.
12721246
mode
@@ -1277,8 +1251,6 @@ def scale_batch_size(
12771251
The initial value to start the search with.
12781252
max_trials
12791253
The maximum number of trials to run.
1280-
batch_arg_name
1281-
The name of the argument to scale in the model. Defaults to 'batch_size'.
12821254
12831255
Returns
12841256
-------
@@ -1289,43 +1261,93 @@ def scale_batch_size(
12891261
series=series,
12901262
past_covariates=past_covariates,
12911263
future_covariates=future_covariates,
1292-
val_series=val_series,
1293-
val_past_covariates=val_past_covariates,
1294-
val_future_covariates=val_future_covariates,
1264+
val_series=series,
1265+
val_past_covariates=past_covariates,
1266+
val_future_covariates=future_covariates,
12951267
trainer=trainer,
12961268
verbose=verbose,
1297-
epochs=epochs,
1298-
max_samples_per_ts=max_samples_per_ts,
1299-
num_loader_workers=num_loader_workers,
13001269
)
13011270
trainer, model, train_loader, val_loader = self._setup_for_train(*params)
13021271

1272+
if method == "predict":
1273+
if roll_size is None:
1274+
roll_size = self.output_chunk_length
1275+
else:
1276+
raise_if_not(
1277+
0 < roll_size <= self.output_chunk_length,
1278+
"`roll_size` must be an integer between 1 and `self.output_chunk_length`.",
1279+
)
1280+
predict_dataset = self._build_inference_dataset(
1281+
target=series,
1282+
n=n,
1283+
past_covariates=past_covariates,
1284+
future_covariates=future_covariates,
1285+
stride=0,
1286+
bounds=None,
1287+
)
1288+
model.set_predict_parameters(
1289+
n=n,
1290+
num_samples=num_samples,
1291+
roll_size=roll_size,
1292+
batch_size=1,
1293+
n_jobs=n_jobs,
1294+
predict_likelihood_parameters=predict_likelihood_parameters,
1295+
mc_dropout=mc_dropout,
1296+
)
1297+
1298+
build_dataloader = self._build_dataloader
1299+
13031300
class DataModule(pl.LightningDataModule):
13041301
def __init__(self, batch_size):
13051302
super().__init__()
13061303
self.save_hyperparameters()
1307-
self.batch_size = batch_size
1304+
self._batch_size = batch_size
1305+
1306+
@property
1307+
def batch_size(self):
1308+
return self._batch_size
1309+
1310+
@batch_size.setter
1311+
def batch_size(self, batch_size):
1312+
model.set_predict_parameters(
1313+
n=n,
1314+
num_samples=num_samples,
1315+
roll_size=roll_size,
1316+
batch_size=batch_size,
1317+
n_jobs=n_jobs,
1318+
predict_likelihood_parameters=predict_likelihood_parameters,
1319+
mc_dropout=mc_dropout,
1320+
)
1321+
self._batch_size = batch_size
13081322

13091323
def train_dataloader(self):
1310-
return DataLoader(
1311-
train_loader.dataset,
1324+
return build_dataloader(
1325+
split="train",
1326+
dataset=train_loader.dataset,
13121327
batch_size=self.batch_size,
1313-
shuffle=True,
1314-
num_workers=train_loader.num_workers,
1315-
pin_memory=True,
1316-
drop_last=False,
1317-
collate_fn=train_loader.collate_fn,
13181328
)
13191329

13201330
def val_dataloader(self):
1321-
return DataLoader(
1322-
val_loader.dataset,
1331+
return build_dataloader(
1332+
split="val",
1333+
dataset=val_loader.dataset,
1334+
batch_size=self.batch_size,
1335+
)
1336+
1337+
def predict_dataloader(self):
1338+
model.set_predict_parameters(
1339+
n=n,
1340+
num_samples=num_samples,
1341+
roll_size=roll_size,
1342+
batch_size=self._batch_size,
1343+
n_jobs=n_jobs,
1344+
predict_likelihood_parameters=predict_likelihood_parameters,
1345+
mc_dropout=mc_dropout,
1346+
)
1347+
return build_dataloader(
1348+
split="predict",
1349+
dataset=predict_dataset,
13231350
batch_size=self.batch_size,
1324-
shuffle=False,
1325-
num_workers=val_loader.num_workers,
1326-
pin_memory=True,
1327-
drop_last=False,
1328-
collate_fn=val_loader.collate_fn,
13291351
)
13301352

13311353
return Tuner(trainer).scale_batch_size(
@@ -1619,14 +1641,11 @@ def predict_from_dataset(
16191641
mc_dropout=mc_dropout,
16201642
)
16211643

1622-
pred_loader = DataLoader(
1623-
input_series_dataset,
1644+
pred_loader = self._build_dataloader(
1645+
split="predict",
1646+
dataset=input_series_dataset,
1647+
num_loader_workers=num_loader_workers,
16241648
batch_size=batch_size,
1625-
shuffle=False,
1626-
num_workers=num_loader_workers,
1627-
pin_memory=True,
1628-
drop_last=False,
1629-
collate_fn=self._batch_collate_fn,
16301649
)
16311650

16321651
# set up trainer. use user supplied trainer or create a new trainer from scratch
@@ -2377,6 +2396,64 @@ def _check_ckpt_parameters(self, tfm_save):
23772396

23782397
raise_log(ValueError("\n".join(msg)), logger)
23792398

2399+
def _build_dataloader(
2400+
self,
2401+
split: Literal["train", "val", "predict"],
2402+
dataset: Dataset,
2403+
batch_size: Optional[int] = None,
2404+
num_loader_workers: int = 0,
2405+
) -> DataLoader:
2406+
"""
2407+
Builds a PyTorch DataLoader from a given dataset.
2408+
2409+
Parameters
2410+
----------
2411+
split
2412+
The split for which the DataLoader is built. Can be "train", "val" or "predict".
2413+
dataset
2414+
The dataset from which to build the DataLoader.
2415+
batch_size
2416+
The batch size for the DataLoader. If not specified, the model's default batch size is used.
2417+
num_loader_workers
2418+
The number of workers for the DataLoader. Default is 0.
2419+
"""
2420+
2421+
if batch_size is None:
2422+
batch_size = self.batch_size
2423+
2424+
if split == "train":
2425+
return DataLoader(
2426+
dataset=dataset,
2427+
batch_size=batch_size,
2428+
shuffle=True,
2429+
num_workers=num_loader_workers,
2430+
pin_memory=True,
2431+
drop_last=False,
2432+
collate_fn=self._batch_collate_fn,
2433+
)
2434+
2435+
if split == "val":
2436+
return DataLoader(
2437+
dataset=dataset,
2438+
batch_size=batch_size,
2439+
shuffle=False,
2440+
num_workers=num_loader_workers,
2441+
pin_memory=True,
2442+
drop_last=False,
2443+
collate_fn=self._batch_collate_fn,
2444+
)
2445+
2446+
if split == "predict":
2447+
return DataLoader(
2448+
dataset=dataset,
2449+
batch_size=batch_size,
2450+
shuffle=False,
2451+
num_workers=num_loader_workers,
2452+
pin_memory=True,
2453+
drop_last=False,
2454+
collate_fn=self._batch_collate_fn,
2455+
)
2456+
23802457
def __getstate__(self):
23812458
# do not pickle the PyTorch LightningModule, and Trainer
23822459
return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE}

darts/tests/models/forecasting/test_torch_forecasting_model.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,17 +1402,24 @@ def test_lr_find(self):
14021402
)
14031403
assert scores["worst"] > scores["suggested"]
14041404

1405-
@pytest.mark.slow
14061405
def test_scale_batch_size(self):
1407-
train_series, val_series = self.series[:-40], self.series[-40:]
1406+
train_series, predict_series = self.series[:-40], self.series[-40:]
14081407
model = RNNModel(12, "RNN", 10, 10, random_state=42, batch_size=1, **tfm_kwargs)
14091408
# find the batch size
14101409
init_batch_size = model.batch_size
14111410
batch_size = model.scale_batch_size(
14121411
series=train_series,
1413-
val_series=val_series,
1414-
epochs=50,
14151412
init_val=init_batch_size,
1413+
method="fit",
1414+
)
1415+
assert isinstance(batch_size, int)
1416+
assert batch_size != init_batch_size
1417+
1418+
batch_size = model.scale_batch_size(
1419+
series=predict_series,
1420+
init_val=init_batch_size,
1421+
method="predict",
1422+
n=10,
14161423
)
14171424
assert isinstance(batch_size, int)
14181425
assert batch_size != init_batch_size

0 commit comments

Comments
 (0)