diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f996b6e7a..7081f5628c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** +- Added `scale_batch_size()` method to torch-based models to find the largest batch size. It is recommended to re-initialize the model after calling this method before calling `fit()`. [#2905](https://github.com/unit8co/darts/pull/2905) by [Zhihao Dai](https://github.com/daidahao). - Added hyperparameter `skip_interpolation` to `TFTModel` that will replace 1D interpolation on feature embeddings with linear projection. When `True`, it can greatly increase training and inference efficiency while predictive accuracy remains largely unaffected. [#2898](https://github.com/unit8co/darts/pull/2898) by [Zhihao Dai](https://github.com/daidahao). - Added mixed precision and 16-bit precision support to `TorchForecastingModel`. Simply specify `{"precision": "bf16-mixed" }` for `pl_trainer_kwargs` to enable mixed precision training. Alternatively, declare a custom `pytorch_lightning.Trainer` with a `"precision"` parameter and pass the trainer to `fit()`. Other precision options such as `"64-true"` and `"16-mixed"` supported by `pytorch_lightning` are also allowed. [#2883](https://github.com/unit8co/darts/pull/2883) by [Zhihao Dai](https://github.com/daidahao). - 🔴 Added future and static covariates support to `BlockRNNModel`. This improvement required changes to the underlying model architecture which means that saved model instances from older Darts versions cannot be loaded any longer. [#2845](https://github.com/unit8co/darts/pull/2845) by [Gabriel Margaria](https://github.com/Jaco-Pastorius). diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 982b44429b..a3f8b42c20 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -36,6 +36,7 @@ import pandas as pd import pytorch_lightning as pl import torch +from pytorch_lightning import LightningDataModule from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks import ProgressBar from pytorch_lightning.tuner import Tuner @@ -128,6 +129,70 @@ def _get_checkpoint_fname(work_dir, model_name, best=False): return os.path.basename(file_name) +class _CustomDataModule(LightningDataModule): + def __init__( + self, + train_dataset: TorchTrainingDataset, + val_dataset: Optional[TorchTrainingDataset], + batch_size: int, + collate_fn: Callable, + dataloader_kwargs: Optional[dict[str, Any]], + ): + """Custom LightningDataModule to handle train and val dataloaders. + + Parameters + ---------- + train_dataset + Dataset for training. + val_dataset + Dataset for validation. + batch_size + Number of time series (input and output sequences) used in each training pass. + collate_fn + Function to collate samples into a batch. + dataloader_kwargs + Additional keyword arguments for DataLoader. + """ + super().__init__() + self.train_dataset = train_dataset + self.val_dataset = val_dataset + if dataloader_kwargs is None: + dataloader_kwargs = dict() + self.batch_size = dataloader_kwargs.pop("batch_size", batch_size) + self.shuffle = dataloader_kwargs.pop("shuffle", True) + + # setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at + # least one batch no matter the chosen batch size + self.dataloader_kwargs = dict( + { + "pin_memory": True, + "drop_last": False, + "collate_fn": collate_fn, + }, + **dataloader_kwargs, + ) + + def train_dataloader(self): + """Train dataloader.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + **self.dataloader_kwargs, + ) + + def val_dataloader(self): + """Validation dataloader.""" + if self.val_dataset is None: + return [] + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + shuffle=False, + **self.dataloader_kwargs, + ) + + class TorchForecastingModel(GlobalForecastingModel, ABC): @random_method def __init__( @@ -1134,17 +1199,16 @@ def fit_from_dataset( self Fitted model. """ - self._train( - *self._setup_for_train( - train_dataset=train_dataset, - val_dataset=val_dataset, - trainer=trainer, - verbose=verbose, - epochs=epochs, - dataloader_kwargs=dataloader_kwargs, - load_best=load_best, - ) + trainer, model, datamodule, load_best = self._setup_for_train( + train_dataset=train_dataset, + val_dataset=val_dataset, + trainer=trainer, + verbose=verbose, + epochs=epochs, + dataloader_kwargs=dataloader_kwargs, + load_best=load_best, ) + self._train(trainer, model, datamodule, load_best) return self def _setup_for_train( @@ -1156,9 +1220,9 @@ def _setup_for_train( epochs: int = 0, dataloader_kwargs: Optional[dict[str, Any]] = None, load_best: bool = False, - ) -> tuple[pl.Trainer, PLForecastingModule, DataLoader, Optional[DataLoader], bool]: + ) -> tuple[pl.Trainer, PLForecastingModule, LightningDataModule, bool]: """This method acts on `TorchTrainingDataset` inputs. It performs sanity checks, and sets up / returns the - trainer, model, and dataset loaders required for training the model with `_train()`. + trainer, model, and datamodule required for training the model with `_train()`. """ self._verify_train_dataset_type(train_dataset) @@ -1261,33 +1325,13 @@ def _setup_for_train( logger=logger, ) - # setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at - # least one batch no matter the chosen batch size - dataloader_kwargs = dict( - { - "batch_size": self.batch_size, - "shuffle": True, - "pin_memory": True, - "drop_last": False, - "collate_fn": self._batch_collate_fn, - }, - **(dataloader_kwargs or dict()), - ) - - train_loader = DataLoader( - train_dataset, - **dataloader_kwargs, - ) - - # prepare validation data - dataloader_kwargs["shuffle"] = False - val_loader = ( - None - if val_dataset is None - else DataLoader( - val_dataset, - **dataloader_kwargs, - ) + # setup datamodule + datamodule = _CustomDataModule( + train_dataset=train_dataset, + val_dataset=val_dataset, + batch_size=self.batch_size, + collate_fn=self._batch_collate_fn, + dataloader_kwargs=dataloader_kwargs, ) # if user wants to train the model for more epochs, ignore the n_epochs parameter @@ -1302,14 +1346,13 @@ def _setup_for_train( f"discouraged. Consider model `{self.__class__.__name__}.load_weights()` to load the weights for " f"fine-tuning." ) - return trainer, model, train_loader, val_loader, load_best + return trainer, model, datamodule, load_best def _train( self, trainer: pl.Trainer, model: PLForecastingModule, - train_loader: DataLoader, - val_loader: Optional[DataLoader], + datamodule: LightningDataModule, load_best: bool = False, ) -> None: """ @@ -1317,10 +1360,12 @@ def _train( Parameters ---------- - train_loader - the training data loader feeding the training data and targets - val_loader - optionally, a validation set loader + trainer + The PyTorch Lightning Trainer object to use for training + model + The PyTorch Lightning Module to train + datamodule + The PyTorch Lightning DataModule to use for training load_best Whether to load the best model checkpoint after training. """ @@ -1338,7 +1383,7 @@ def _train( ckpt_activated = ckpt_callback is not None and hasattr( ckpt_callback, "best_model_path" ) - if not ckpt_activated or val_loader is None: + if not ckpt_activated or len(datamodule.val_dataloader()) == 0: logger.warning( "Loading the best model will be skipped (`load_best` is ignored), as it requires " "active checkpointing and a validation set to be provided to the current fit method." @@ -1351,9 +1396,8 @@ def _train( if self._requires_training: trainer.fit( - model, - train_dataloaders=train_loader, - val_dataloaders=val_loader, + model=model, + datamodule=datamodule, ckpt_path=ckpt_path, ) if load_best: @@ -1395,12 +1439,12 @@ def lr_find( A wrapper around PyTorch Lightning's `Tuner.lr_find()`. Performs a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate. For more information on PyTorch Lightning's Tuner check out - `this link `_. + `this link `_. It is recommended to increase the number of `epochs` if the tuner did not give satisfactory results. Consider creating a new model object with the suggested learning rate for example using model creation parameters `optimizer_cls`, `optimizer_kwargs`, `lr_scheduler_cls`, and `lr_scheduler_kwargs`. - Example using a :class:`RNNModel`: + Example using a :class:`NBEATSModel`: .. highlight:: python .. code-block:: python @@ -1509,11 +1553,10 @@ def lr_find( max_samples_per_ts=max_samples_per_ts, dataloader_kwargs=dataloader_kwargs, ) - trainer, model, train_loader, val_loader, _ = self._setup_for_train(*params) + trainer, model, datamodule, _ = self._setup_for_train(*params) return Tuner(trainer).lr_find( model, - train_dataloaders=train_loader, - val_dataloaders=val_loader, + datamodule=datamodule, method="fit", min_lr=min_lr, max_lr=max_lr, @@ -1523,6 +1566,149 @@ def lr_find( update_attr=False, ) + @random_method + def scale_batch_size( + self, + series: Union[TimeSeries, Sequence[TimeSeries]], + past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + val_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + val_past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + val_future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + sample_weight: Optional[Union[TimeSeries, Sequence[TimeSeries], str]] = None, + val_sample_weight: Optional[ + Union[TimeSeries, Sequence[TimeSeries], str] + ] = None, + trainer: Optional[pl.Trainer] = None, + verbose: Optional[bool] = None, + epochs: int = 0, + max_samples_per_ts: Optional[int] = None, + dataloader_kwargs: Optional[dict[str, Any]] = None, + mode: str = "power", + steps_per_trial: int = 3, + init_val: int = 2, + max_trials: int = 25, + ): + """ + A wrapper around PyTorch Lightning's `Tuner.scale_batch_size()`. Performs a batch size scaling test to + find the largest batch size to use for training. The batch size in the model would be updated after this + call. For more information on PyTorch Lightning's Tuner check out + `this link `_. + + Example using a :class:`NBEATSModel`: + + .. highlight:: python + .. code-block:: python + + from darts.datasets import AirPassengersDataset + from darts.models import NBEATSModel + + series = AirPassengersDataset().load().astype("float32") + train, val = series[:-18], series[-18:] + model = NBEATSModel(12, 6, random_state=42) + # run the batch size tuner + model.scale_batch_size(series=train, val_series=val) + # train the model with the suggested batch size + model.fit(train, val_series=val, epochs=1) + .. + + Parameters + ---------- + series + A series or sequence of series serving as target (i.e. what the model will be trained to forecast) + past_covariates + Optionally, a series or sequence of series specifying past-observed covariates + future_covariates + Optionally, a series or sequence of series specifying future-known covariates + val_series + Optionally, one or a sequence of validation target series, which will be used to compute the validation + loss throughout training and keep track of the best performing models. + val_past_covariates + Optionally, the past covariates corresponding to the validation series (must match ``covariates``) + val_future_covariates + Optionally, the future covariates corresponding to the validation series (must match ``covariates``) + sample_weight + Optionally, some sample weights to apply to the target `series` labels. They are applied per observation, + per label (each step in `output_chunk_length`), and per component. + If a series or sequence of series, then those weights are used. If the weight series only have a single + component / column, then the weights are applied globally to all components in `series`. Otherwise, for + component-specific weights, the number of components must match those of `series`. + If a string, then the weights are generated using built-in weighting functions. The available options are + `"linear"` or `"exponential"` decay - the further in the past, the lower the weight. The weights are + computed globally based on the length of the longest series in `series`. Then for each series, the weights + are extracted from the end of the global weights. This gives a common time weighting across all series. + val_sample_weight + Same as for `sample_weight` but for the evaluation dataset. + trainer + Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom ``trainer`` will + override Darts' default trainer. + verbose + Whether to print the progress. Ignored if there is a `ProgressBar` callback in + `pl_trainer_kwargs`. + epochs + If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs`` + was provided to the model constructor. + max_samples_per_ts + Optionally, a maximum number of samples to use per time series. Models are trained in a supervised fashion + by constructing slices of (input, output) examples. On long time series, this can result in unnecessarily + large number of training samples. This parameter upper-bounds the number of training samples per time + series (taking only the most recent samples in each series). Leaving to None does not apply any + upper bound. + dataloader_kwargs + Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instances for the + training and validation datasets. For more information on `DataLoader`, check out `this link + `_. + By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory") + for seamless forecasting. Changing them should be done with care to avoid unexpected behavior. + mode + Search strategy to update batch size after each trial, either 'power' or 'binsearch'. + steps_per_trial + Number of steps to take per trial. + init_val + Initial batch size to try. + max_trials + Maximum number of batch size trials to run. + + Returns + ------- + batch_size + The optimal batch size found by the tuner. + """ + _, params = self._setup_for_fit_from_dataset( + series=series, + past_covariates=past_covariates, + future_covariates=future_covariates, + sample_weight=sample_weight, + val_series=val_series, + val_past_covariates=val_past_covariates, + val_future_covariates=val_future_covariates, + val_sample_weight=val_sample_weight, + trainer=trainer, + verbose=verbose, + epochs=epochs, + max_samples_per_ts=max_samples_per_ts, + dataloader_kwargs=dataloader_kwargs, + ) + trainer, model, datamodule, _ = self._setup_for_train(*params) + batch_size = Tuner(trainer).scale_batch_size( + model=model, + datamodule=datamodule, + mode=mode, + steps_per_trial=steps_per_trial, + init_val=init_val, + max_trials=max_trials, + batch_arg_name="batch_size", + ) + if batch_size is None: + logger.warning( + "Batch size scaling did not find a solution. " + f"Default batch size {self.batch_size} is kept." + ) + else: + logger.info(f"Batch size set to {batch_size}.") + self.batch_size = batch_size + return self.batch_size + @random_method def predict( self, diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index 5252056ba2..5ece0d26b3 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -21,6 +21,7 @@ import pytest import pytorch_lightning as pl import torch +from pytorch_lightning import LightningDataModule from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.loggers.logger import DummyLogger from pytorch_lightning.tuner.lr_finder import _LRFinder @@ -1658,6 +1659,65 @@ def test_lr_find(self): ) assert scores["worst"] > scores["suggested"] + @pytest.mark.slow + def test_scale_batch_size(self): + def create_series(length: int) -> TimeSeries: + return TimeSeries.from_times_and_values( + times=pd.date_range("20130101", periods=length, freq="s"), + values=np.cos( + np.linspace(0, np.pi * 10, length, dtype=np.float32) + ).reshape(-1, 1), + ) + + series = create_series(10_000) + train_series, val_series = series.split_after(0.8) + + model = RNNModel(12, "RNN", 10, 10, **tfm_kwargs) + # find the largest batch size + res = model.scale_batch_size(series=train_series, val_series=val_series) + + # verify results and that batch size is set + assert isinstance(res, int) + assert res == model.batch_size + + # verify that batch size finder bypasses the `fit` logic + assert model.model is None + assert not model._fit_called + # cannot predict with an untrained model + with pytest.raises(ValueError): + model.predict(n=3, series=self.series) + + # check that batch size could indeed fit in the memory + model.fit(train_series, val_series=val_series, epochs=1) + assert model.batch_size == res + assert model.epochs_trained == 1 + + # check that results are reproducible + model = RNNModel(12, "RNN", 10, 10, **tfm_kwargs) + res2 = model.scale_batch_size(series=train_series, val_series=val_series) + assert res == res2 + + @pytest.mark.slow + def test_scale_batch_size_no_updates(self): + model = RNNModel(12, "RNN", 10, 10, **tfm_kwargs) + + # train for 1 epoch with default batch size + model.fit(self.series, epochs=1) + assert model.epochs_trained == 1 + # store the predictions after 1 epoch + preds = model.predict(n=3, series=self.series) + + # find the largest batch size, should not change the model weights + res = model.scale_batch_size(series=self.series) + # verify that batch size is set + assert isinstance(res, int) + assert res == model.batch_size + + # verify that weights have not changed after batch size scaling + preds_after = model.predict(n=3, series=self.series) + assert isinstance(preds, TimeSeries) and isinstance(preds_after, TimeSeries) + assert np.isclose(preds.values(), preds_after.values()).all() + def test_encoders(self, tmpdir_fn): series = tg.linear_timeseries(length=10) pc = tg.linear_timeseries(length=12) @@ -1736,12 +1796,12 @@ def test_dataloader_kwargs_setup(self): with patch("pytorch_lightning.Trainer.fit") as fit_patch: model.fit(train_series, val_series=val_series) - assert "train_dataloaders" in fit_patch.call_args.kwargs - assert "val_dataloaders" in fit_patch.call_args.kwargs - train_dl = fit_patch.call_args.kwargs["train_dataloaders"] + datamodule = fit_patch.call_args.kwargs["datamodule"] + assert isinstance(datamodule, LightningDataModule) + train_dl = datamodule.train_dataloader() assert isinstance(train_dl, DataLoader) - val_dl = fit_patch.call_args.kwargs["val_dataloaders"] + val_dl = datamodule.val_dataloader() assert isinstance(val_dl, DataLoader) dl_defaults = { @@ -1761,8 +1821,10 @@ def test_dataloader_kwargs_setup(self): # check that overwriting the dataloader kwargs works dl_custom = dict(dl_defaults, **{"batch_size": 50, "drop_last": True}) model.fit(train_series, val_series=val_series, dataloader_kwargs=dl_custom) - train_dl = fit_patch.call_args.kwargs["train_dataloaders"] - val_dl = fit_patch.call_args.kwargs["val_dataloaders"] + + datamodule = fit_patch.call_args.kwargs["datamodule"] + train_dl = datamodule.train_dataloader() + val_dl = datamodule.val_dataloader() assert all([getattr(train_dl, k) == v for k, v in dl_custom.items()]) assert all([getattr(val_dl, k) == v for k, v in dl_custom.items()]) @@ -1883,8 +1945,11 @@ def helper_check_val_set(self, model_cls, model_kwargs, fit_patch): # fit called only once assert fit_patch.call_count == 1 - train_ds = fit_patch.call_args[1]["train_dataloaders"].dataset - val_dl = fit_patch.call_args[1]["val_dataloaders"] + datamodule = fit_patch.call_args.kwargs["datamodule"] + assert isinstance(datamodule, LightningDataModule) + + train_ds = datamodule.train_dataloader().dataset + val_dl = datamodule.val_dataloader() assert val_dl is not None val_ds = val_dl.dataset