4141import torch
4242from pytorch_lightning import loggers as pl_loggers
4343from torch import Tensor
44- from torch .utils .data import DataLoader
44+ from torch .utils .data import DataLoader , Dataset
4545
4646from darts .dataprocessing .encoders import SequentialEncoder
4747from darts .logging import (
@@ -996,28 +996,20 @@ def _setup_for_train(
996996
997997 # Setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at
998998 # least one batch no matter the chosen batch size
999- train_loader = DataLoader (
1000- train_dataset ,
1001- batch_size = self .batch_size ,
1002- shuffle = True ,
1003- num_workers = num_loader_workers ,
1004- pin_memory = True ,
1005- drop_last = False ,
1006- collate_fn = self ._batch_collate_fn ,
999+ train_loader = self ._build_dataloader (
1000+ split = "train" ,
1001+ dataset = train_dataset ,
1002+ num_loader_workers = num_loader_workers ,
10071003 )
10081004
10091005 # Prepare validation data
10101006 val_loader = (
10111007 None
10121008 if val_dataset is None
1013- else DataLoader (
1014- val_dataset ,
1015- batch_size = self .batch_size ,
1016- shuffle = False ,
1017- num_workers = num_loader_workers ,
1018- pin_memory = True ,
1019- drop_last = False ,
1020- collate_fn = self ._batch_collate_fn ,
1009+ else self ._build_dataloader (
1010+ split = "val" ,
1011+ dataset = val_dataset ,
1012+ num_loader_workers = num_loader_workers ,
10211013 )
10221014 )
10231015
@@ -1210,17 +1202,17 @@ def lr_find(
12101202 def scale_batch_size (
12111203 self ,
12121204 series : Union [TimeSeries , Sequence [TimeSeries ]],
1213- val_series : Union [TimeSeries , Sequence [TimeSeries ]],
1205+ n : int = 1 ,
1206+ n_jobs : int = 1 ,
1207+ roll_size : Optional [int ] = None ,
1208+ num_samples : int = 1 ,
1209+ mc_dropout : bool = False ,
1210+ predict_likelihood_parameters : bool = False ,
12141211 past_covariates : Optional [Union [TimeSeries , Sequence [TimeSeries ]]] = None ,
12151212 future_covariates : Optional [Union [TimeSeries , Sequence [TimeSeries ]]] = None ,
1216- val_past_covariates : Optional [Union [TimeSeries , Sequence [TimeSeries ]]] = None ,
1217- val_future_covariates : Optional [Union [TimeSeries , Sequence [TimeSeries ]]] = None ,
12181213 trainer : Optional [pl .Trainer ] = None ,
12191214 verbose : Optional [bool ] = None ,
1220- epochs : int = 0 ,
1221- max_samples_per_ts : Optional [int ] = None ,
1222- num_loader_workers : int = 0 ,
1223- method : Literal ["fit" , "validate" , "test" , "predict" ] = "fit" ,
1215+ method : Literal ["fit" , "predict" ] = "fit" ,
12241216 mode : str = "power" ,
12251217 steps_per_trial : int = 3 ,
12261218 init_val : int = 2 ,
@@ -1236,37 +1228,19 @@ def scale_batch_size(
12361228 ----------
12371229 series
12381230 A series or sequence of series serving as target (i.e. what the model will be trained to forecast)
1231+ n
1232+ The number of time steps after the end of the training time series for which to produce predictions.
1233+ Only for the `predict` method.
12391234 past_covariates
12401235 Optionally, a series or sequence of series specifying past-observed covariates
12411236 future_covariates
12421237 Optionally, a series or sequence of series specifying future-known covariates
1243- val_series
1244- Optionally, one or a sequence of validation target series, which will be used to compute the validation
1245- loss throughout training and keep track of the best performing models.
1246- val_past_covariates
1247- Optionally, the past covariates corresponding to the validation series (must match ``covariates``)
1248- val_future_covariates
1249- Optionally, the future covariates corresponding to the validation series (must match ``covariates``)
12501238 trainer
12511239 Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom ``trainer`` will
12521240 override Darts' default trainer.
12531241 verbose
12541242 Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in
12551243 `pl_trainer_kwargs`.
1256- epochs
1257- If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs``
1258- was provided to the model constructor.
1259- max_samples_per_ts
1260- Optionally, a maximum number of samples to use per time series. Models are trained in a supervised fashion
1261- by constructing slices of (input, output) examples. On long time series, this can result in unnecessarily
1262- large number of training samples. This parameter upper-bounds the number of training samples per time
1263- series (taking only the most recent samples in each series). Leaving to None does not apply any
1264- upper bound.
1265- num_loader_workers
1266- Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
1267- both for the training and validation loaders (if any).
1268- A larger number of workers can sometimes increase performance, but can also incur extra overheads
1269- and increase memory usage, as more batches are loaded in parallel.
12701244 method
12711245 The method to use for scaling the batch size. Can be one of 'fit', 'validate', 'test', or 'predict'.
12721246 mode
@@ -1277,8 +1251,6 @@ def scale_batch_size(
12771251 The initial value to start the search with.
12781252 max_trials
12791253 The maximum number of trials to run.
1280- batch_arg_name
1281- The name of the argument to scale in the model. Defaults to 'batch_size'.
12821254
12831255 Returns
12841256 -------
@@ -1289,43 +1261,93 @@ def scale_batch_size(
12891261 series = series ,
12901262 past_covariates = past_covariates ,
12911263 future_covariates = future_covariates ,
1292- val_series = val_series ,
1293- val_past_covariates = val_past_covariates ,
1294- val_future_covariates = val_future_covariates ,
1264+ val_series = series ,
1265+ val_past_covariates = past_covariates ,
1266+ val_future_covariates = future_covariates ,
12951267 trainer = trainer ,
12961268 verbose = verbose ,
1297- epochs = epochs ,
1298- max_samples_per_ts = max_samples_per_ts ,
1299- num_loader_workers = num_loader_workers ,
13001269 )
13011270 trainer , model , train_loader , val_loader = self ._setup_for_train (* params )
13021271
1272+ if method == "predict" :
1273+ if roll_size is None :
1274+ roll_size = self .output_chunk_length
1275+ else :
1276+ raise_if_not (
1277+ 0 < roll_size <= self .output_chunk_length ,
1278+ "`roll_size` must be an integer between 1 and `self.output_chunk_length`." ,
1279+ )
1280+ predict_dataset = self ._build_inference_dataset (
1281+ target = series ,
1282+ n = n ,
1283+ past_covariates = past_covariates ,
1284+ future_covariates = future_covariates ,
1285+ stride = 0 ,
1286+ bounds = None ,
1287+ )
1288+ model .set_predict_parameters (
1289+ n = n ,
1290+ num_samples = num_samples ,
1291+ roll_size = roll_size ,
1292+ batch_size = 1 ,
1293+ n_jobs = n_jobs ,
1294+ predict_likelihood_parameters = predict_likelihood_parameters ,
1295+ mc_dropout = mc_dropout ,
1296+ )
1297+
1298+ build_dataloader = self ._build_dataloader
1299+
13031300 class DataModule (pl .LightningDataModule ):
13041301 def __init__ (self , batch_size ):
13051302 super ().__init__ ()
13061303 self .save_hyperparameters ()
1307- self .batch_size = batch_size
1304+ self ._batch_size = batch_size
1305+
1306+ @property
1307+ def batch_size (self ):
1308+ return self ._batch_size
1309+
1310+ @batch_size .setter
1311+ def batch_size (self , batch_size ):
1312+ model .set_predict_parameters (
1313+ n = n ,
1314+ num_samples = num_samples ,
1315+ roll_size = roll_size ,
1316+ batch_size = batch_size ,
1317+ n_jobs = n_jobs ,
1318+ predict_likelihood_parameters = predict_likelihood_parameters ,
1319+ mc_dropout = mc_dropout ,
1320+ )
1321+ self ._batch_size = batch_size
13081322
13091323 def train_dataloader (self ):
1310- return DataLoader (
1311- train_loader .dataset ,
1324+ return build_dataloader (
1325+ split = "train" ,
1326+ dataset = train_loader .dataset ,
13121327 batch_size = self .batch_size ,
1313- shuffle = True ,
1314- num_workers = train_loader .num_workers ,
1315- pin_memory = True ,
1316- drop_last = False ,
1317- collate_fn = train_loader .collate_fn ,
13181328 )
13191329
13201330 def val_dataloader (self ):
1321- return DataLoader (
1322- val_loader .dataset ,
1331+ return build_dataloader (
1332+ split = "val" ,
1333+ dataset = val_loader .dataset ,
1334+ batch_size = self .batch_size ,
1335+ )
1336+
1337+ def predict_dataloader (self ):
1338+ model .set_predict_parameters (
1339+ n = n ,
1340+ num_samples = num_samples ,
1341+ roll_size = roll_size ,
1342+ batch_size = self ._batch_size ,
1343+ n_jobs = n_jobs ,
1344+ predict_likelihood_parameters = predict_likelihood_parameters ,
1345+ mc_dropout = mc_dropout ,
1346+ )
1347+ return build_dataloader (
1348+ split = "predict" ,
1349+ dataset = predict_dataset ,
13231350 batch_size = self .batch_size ,
1324- shuffle = False ,
1325- num_workers = val_loader .num_workers ,
1326- pin_memory = True ,
1327- drop_last = False ,
1328- collate_fn = val_loader .collate_fn ,
13291351 )
13301352
13311353 return Tuner (trainer ).scale_batch_size (
@@ -1619,14 +1641,11 @@ def predict_from_dataset(
16191641 mc_dropout = mc_dropout ,
16201642 )
16211643
1622- pred_loader = DataLoader (
1623- input_series_dataset ,
1644+ pred_loader = self ._build_dataloader (
1645+ split = "predict" ,
1646+ dataset = input_series_dataset ,
1647+ num_loader_workers = num_loader_workers ,
16241648 batch_size = batch_size ,
1625- shuffle = False ,
1626- num_workers = num_loader_workers ,
1627- pin_memory = True ,
1628- drop_last = False ,
1629- collate_fn = self ._batch_collate_fn ,
16301649 )
16311650
16321651 # set up trainer. use user supplied trainer or create a new trainer from scratch
@@ -2377,6 +2396,64 @@ def _check_ckpt_parameters(self, tfm_save):
23772396
23782397 raise_log (ValueError ("\n " .join (msg )), logger )
23792398
2399+ def _build_dataloader (
2400+ self ,
2401+ split : Literal ["train" , "val" , "predict" ],
2402+ dataset : Dataset ,
2403+ batch_size : Optional [int ] = None ,
2404+ num_loader_workers : int = 0 ,
2405+ ) -> DataLoader :
2406+ """
2407+ Builds a PyTorch DataLoader from a given dataset.
2408+
2409+ Parameters
2410+ ----------
2411+ split
2412+ The split for which the DataLoader is built. Can be "train", "val" or "predict".
2413+ dataset
2414+ The dataset from which to build the DataLoader.
2415+ batch_size
2416+ The batch size for the DataLoader. If not specified, the model's default batch size is used.
2417+ num_loader_workers
2418+ The number of workers for the DataLoader. Default is 0.
2419+ """
2420+
2421+ if batch_size is None :
2422+ batch_size = self .batch_size
2423+
2424+ if split == "train" :
2425+ return DataLoader (
2426+ dataset = dataset ,
2427+ batch_size = batch_size ,
2428+ shuffle = True ,
2429+ num_workers = num_loader_workers ,
2430+ pin_memory = True ,
2431+ drop_last = False ,
2432+ collate_fn = self ._batch_collate_fn ,
2433+ )
2434+
2435+ if split == "val" :
2436+ return DataLoader (
2437+ dataset = dataset ,
2438+ batch_size = batch_size ,
2439+ shuffle = False ,
2440+ num_workers = num_loader_workers ,
2441+ pin_memory = True ,
2442+ drop_last = False ,
2443+ collate_fn = self ._batch_collate_fn ,
2444+ )
2445+
2446+ if split == "predict" :
2447+ return DataLoader (
2448+ dataset = dataset ,
2449+ batch_size = batch_size ,
2450+ shuffle = False ,
2451+ num_workers = num_loader_workers ,
2452+ pin_memory = True ,
2453+ drop_last = False ,
2454+ collate_fn = self ._batch_collate_fn ,
2455+ )
2456+
23802457 def __getstate__ (self ):
23812458 # do not pickle the PyTorch LightningModule, and Trainer
23822459 return {k : v for k , v in self .__dict__ .items () if k not in TFM_ATTRS_NO_PICKLE }
0 commit comments