From 86f5e67f7a3370a479938597101e3fe744bec2fc Mon Sep 17 00:00:00 2001 From: Zhihao Dai Date: Thu, 18 Sep 2025 15:01:20 +0100 Subject: [PATCH 1/6] Add `scale_batch_size()` to `TorchForecastingModel` - A wrapper around Lightning Tuner's method of the same name, `scale_batch_size()` finds the largest batch size that fits before an out-of-memory error occurs. - Options for the Tuner method are supported, including `mode`, `steps_per_trial`, `init_val`, and `max_trials`. - The Tuner requires a `batch_size` attribute on the `LightningDataModule` or the model, and does not accept the `train_loader` and `val_loader` that were previously passed directly. - Because of that, I implemented `_CustomDataModule` and `_CustomDataModuleWithVal` to return dataloaders as per `batch_size`. - The previous behaviour of `dataloader_kwargs` is preserved with the new datamodules. - Update `_setup_for_train()`, `_train()`, `fit_from_dataset()`, `lr_find()` methods to use datamodules instead of direct data loaders. --- .../forecasting/torch_forecasting_model.py | 296 ++++++++++++++---- 1 file changed, 243 insertions(+), 53 deletions(-) diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 6afdb3b2ed..1d8beb25b3 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -36,6 +36,7 @@ import pandas as pd import pytorch_lightning as pl import torch +from pytorch_lightning import LightningDataModule from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks import ProgressBar from pytorch_lightning.tuner import Tuner @@ -128,6 +129,72 @@ def _get_checkpoint_fname(work_dir, model_name, best=False): return os.path.basename(file_name) +class _CustomDataModule(LightningDataModule): + def __init__( + self, + train_dataset: TorchTrainingDataset, + val_dataset: Optional[TorchTrainingDataset], + batch_size: int, + collate_fn: Callable, + dataloader_kwargs: Optional[dict[str, Any]], + ): + """Custom 
LightningDataModule to handle train and val dataloaders. + + Parameters + ---------- + train_dataset + Dataset for training. + val_dataset + Dataset for validation. + batch_size + Number of time series (input and output sequences) used in each training pass. + collate_fn + Function to collate samples into a batch. + dataloader_kwargs + Additional keyword arguments for DataLoader. + """ + super().__init__() + self.train_dataset = train_dataset + self.val_dataset = val_dataset + if dataloader_kwargs is None: + dataloader_kwargs = dict() + self.batch_size = dataloader_kwargs.pop("batch_size", batch_size) + self.shuffle = dataloader_kwargs.pop("shuffle", True) + + # setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at + # least one batch no matter the chosen batch size + self.dataloader_kwargs = dict( + { + "pin_memory": True, + "drop_last": False, + "collate_fn": collate_fn, + }, + **dataloader_kwargs, + ) + + def train_dataloader(self): + """Train dataloader.""" + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + **self.dataloader_kwargs, + ) + + +class _CustomDataModuleWithVal(_CustomDataModule): + """Custom LightningDataModule (with validation dataset) to handle train and val dataloaders.""" + + def val_dataloader(self): + """Validation dataloader.""" + return DataLoader( + self.val_dataset, # pyright: ignore[reportArgumentType] + batch_size=self.batch_size, + shuffle=False, + **self.dataloader_kwargs, + ) + + class TorchForecastingModel(GlobalForecastingModel, ABC): @random_method def __init__( @@ -1120,16 +1187,15 @@ def fit_from_dataset( self Fitted model. 
""" - self._train( - *self._setup_for_train( - train_dataset=train_dataset, - val_dataset=val_dataset, - trainer=trainer, - verbose=verbose, - epochs=epochs, - dataloader_kwargs=dataloader_kwargs, - ) + trainer, model, datamodule = self._setup_for_train( + train_dataset=train_dataset, + val_dataset=val_dataset, + trainer=trainer, + verbose=verbose, + epochs=epochs, + dataloader_kwargs=dataloader_kwargs, ) + self._train(trainer, model, datamodule) return self def _setup_for_train( @@ -1140,9 +1206,9 @@ def _setup_for_train( verbose: Optional[bool] = None, epochs: int = 0, dataloader_kwargs: Optional[dict[str, Any]] = None, - ) -> tuple[pl.Trainer, PLForecastingModule, DataLoader, Optional[DataLoader]]: + ) -> tuple[pl.Trainer, PLForecastingModule, LightningDataModule]: """This method acts on `TorchTrainingDataset` inputs. It performs sanity checks, and sets up / returns the - trainer, model, and dataset loaders required for training the model with `_train()`. + trainer, model, and datamodule required for training the model with `_train()`. 
""" self._verify_train_dataset_type(train_dataset) @@ -1245,33 +1311,14 @@ def _setup_for_train( logger=logger, ) - # setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at - # least one batch no matter the chosen batch size - dataloader_kwargs = dict( - { - "batch_size": self.batch_size, - "shuffle": True, - "pin_memory": True, - "drop_last": False, - "collate_fn": self._batch_collate_fn, - }, - **(dataloader_kwargs or dict()), - ) - - train_loader = DataLoader( - train_dataset, - **dataloader_kwargs, - ) - - # prepare validation data - dataloader_kwargs["shuffle"] = False - val_loader = ( - None - if val_dataset is None - else DataLoader( - val_dataset, - **dataloader_kwargs, - ) + # setup datamodule + datamodule_cls = _CustomDataModuleWithVal if val_dataset else _CustomDataModule + datamodule = datamodule_cls( + train_dataset=train_dataset, + val_dataset=val_dataset, + batch_size=self.batch_size, + collate_fn=self._batch_collate_fn, + dataloader_kwargs=dataloader_kwargs, ) # if user wants to train the model for more epochs, ignore the n_epochs parameter @@ -1286,24 +1333,25 @@ def _setup_for_train( f"discouraged. Consider model `{self.__class__.__name__}.load_weights()` to load the weights for " f"fine-tuning." 
) - return trainer, model, train_loader, val_loader + return trainer, model, datamodule def _train( self, trainer: pl.Trainer, model: PLForecastingModule, - train_loader: DataLoader, - val_loader: Optional[DataLoader], + datamodule: LightningDataModule, ) -> None: """ Performs the actual training Parameters ---------- - train_loader - the training data loader feeding the training data and targets - val_loader - optionally, a validation set loader + trainer + The PyTorch Lightning Trainer object to use for training + model + The PyTorch Lightning Module to train + datamodule + The PyTorch Lightning DataModule to use for training """ self._fit_called = True @@ -1314,9 +1362,8 @@ def _train( if self._requires_training: trainer.fit( - model, - train_dataloaders=train_loader, - val_dataloaders=val_loader, + model=model, + datamodule=datamodule, ckpt_path=ckpt_path, ) else: @@ -1352,12 +1399,12 @@ def lr_find( A wrapper around PyTorch Lightning's `Tuner.lr_find()`. Performs a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate. For more information on PyTorch Lightning's Tuner check out - `this link `_. + `this link `_. It is recommended to increase the number of `epochs` if the tuner did not give satisfactory results. Consider creating a new model object with the suggested learning rate for example using model creation parameters `optimizer_cls`, `optimizer_kwargs`, `lr_scheduler_cls`, and `lr_scheduler_kwargs`. - Example using a :class:`RNNModel`: + Example using a :class:`NBEATSModel`: .. highlight:: python .. 
code-block:: python @@ -1466,11 +1513,10 @@ def lr_find( max_samples_per_ts=max_samples_per_ts, dataloader_kwargs=dataloader_kwargs, ) - trainer, model, train_loader, val_loader = self._setup_for_train(*params) + trainer, model, datamodule = self._setup_for_train(*params) return Tuner(trainer).lr_find( model, - train_dataloaders=train_loader, - val_dataloaders=val_loader, + datamodule=datamodule, method="fit", min_lr=min_lr, max_lr=max_lr, @@ -1480,6 +1526,150 @@ def lr_find( update_attr=False, ) + @random_method + def scale_batch_size( + self, + series: Union[TimeSeries, Sequence[TimeSeries]], + past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + val_series: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + val_past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + val_future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + sample_weight: Optional[Union[TimeSeries, Sequence[TimeSeries], str]] = None, + val_sample_weight: Optional[ + Union[TimeSeries, Sequence[TimeSeries], str] + ] = None, + trainer: Optional[pl.Trainer] = None, + verbose: Optional[bool] = None, + epochs: int = 0, + max_samples_per_ts: Optional[int] = None, + dataloader_kwargs: Optional[dict[str, Any]] = None, + mode: str = "power", + steps_per_trial: int = 3, + init_val: int = 2, + max_trials: int = 25, + ): + """ + A wrapper around PyTorch Lightning's `Tuner.scale_batch_size()`. Performs a batch size range test to find a + good batch size to use for training. For more information on PyTorch Lightning's Tuner check out + `this link `_. + + Example using a :class:`NBEATSModel`: + + .. highlight:: python + .. 
code-block:: python + + from darts.datasets import AirPassengersDataset + from darts.models import NBEATSModel + + series = AirPassengersDataset().load().astype("float32") + train, val = series[:-18], series[-18:] + model = NBEATSModel(12, 6, random_state=42) + # run the batch size tuner + batch_size = model.scale_batch_size(series=train, val_series=val) + # create a new model with the optimal batch size + model = NBEATSModel(12, 6, random_state=42, batch_size=batch_size) + # train the new model + model.fit(train, val_series=val, epochs=1) + .. + + Parameters + ---------- + series + A series or sequence of series serving as target (i.e. what the model will be trained to forecast) + past_covariates + Optionally, a series or sequence of series specifying past-observed covariates + future_covariates + Optionally, a series or sequence of series specifying future-known covariates + val_series + Optionally, one or a sequence of validation target series, which will be used to compute the validation + loss throughout training and keep track of the best performing models. + val_past_covariates + Optionally, the past covariates corresponding to the validation series (must match ``covariates``) + val_future_covariates + Optionally, the future covariates corresponding to the validation series (must match ``covariates``) + sample_weight + Optionally, some sample weights to apply to the target `series` labels. They are applied per observation, + per label (each step in `output_chunk_length`), and per component. + If a series or sequence of series, then those weights are used. If the weight series only have a single + component / column, then the weights are applied globally to all components in `series`. Otherwise, for + component-specific weights, the number of components must match those of `series`. + If a string, then the weights are generated using built-in weighting functions. 
The available options are + `"linear"` or `"exponential"` decay - the further in the past, the lower the weight. The weights are + computed globally based on the length of the longest series in `series`. Then for each series, the weights + are extracted from the end of the global weights. This gives a common time weighting across all series. + val_sample_weight + Same as for `sample_weight` but for the evaluation dataset. + trainer + Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom ``trainer`` will + override Darts' default trainer. + verbose + Whether to print the progress. Ignored if there is a `ProgressBar` callback in + `pl_trainer_kwargs`. + epochs + If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs`` + was provided to the model constructor. + max_samples_per_ts + Optionally, a maximum number of samples to use per time series. Models are trained in a supervised fashion + by constructing slices of (input, output) examples. On long time series, this can result in unnecessarily + large number of training samples. This parameter upper-bounds the number of training samples per time + series (taking only the most recent samples in each series). Leaving to None does not apply any + upper bound. + dataloader_kwargs + Optionally, a dictionary of keyword arguments used to create the PyTorch `DataLoader` instances for the + training and validation datasets. For more information on `DataLoader`, check out `this link + `_. + By default, Darts configures parameters ("batch_size", "shuffle", "drop_last", "collate_fn", "pin_memory") + for seamless forecasting. Changing them should be done with care to avoid unexpected behavior. + mode + Search strategy to update batch size after each trial, either 'power' or 'binsearch'. + steps_per_trial + Number of steps to take per trial. + init_val + Initial batch size to try. + max_trials + Maximum number of batch size trials to run. 
+ + Returns + ------- + batch_size + The optimal batch size found by the tuner. + """ + _, params = self._setup_for_fit_from_dataset( + series=series, + past_covariates=past_covariates, + future_covariates=future_covariates, + sample_weight=sample_weight, + val_series=val_series, + val_past_covariates=val_past_covariates, + val_future_covariates=val_future_covariates, + val_sample_weight=val_sample_weight, + trainer=trainer, + verbose=verbose, + epochs=epochs, + max_samples_per_ts=max_samples_per_ts, + dataloader_kwargs=dataloader_kwargs, + ) + trainer, model, datamodule = self._setup_for_train(*params) + batch_size = Tuner(trainer).scale_batch_size( + model=model, + datamodule=datamodule, + mode=mode, + steps_per_trial=steps_per_trial, + init_val=init_val, + max_trials=max_trials, + batch_arg_name="batch_size", + ) + if batch_size is None: + logger.warning( + "Batch size scaling did not find a solution. " + f"Default batch size {self.batch_size} is kept." + ) + else: + logger.info(f"Batch size set to {batch_size}.") + self.batch_size = batch_size + return self.batch_size + @random_method def predict( self, From 4f0bb574d940851125fdc5bee7c4339475f1c362 Mon Sep 17 00:00:00 2001 From: Zhihao Dai Date: Thu, 18 Sep 2025 15:10:16 +0100 Subject: [PATCH 2/6] Add a `scale_batch_size()` test and update dataloader-related tests - Add `test_scale_batch_size` for validating `scale_batch_size()` method. - Update `test_dataloader_kwargs_setup` to validate `datamodule` instead of `train_dataloaders` and `val_dataloaders` due to changes. - Update `helper_check_val_set` used in `test_val_set` to again validate `datamodule`. 
--- .../test_torch_forecasting_model.py | 62 ++++++++++++++++--- 1 file changed, 54 insertions(+), 8 deletions(-) diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index b9242cc6a8..c798c7ff39 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -19,6 +19,7 @@ import pandas as pd import pytest import torch +from pytorch_lightning import LightningDataModule from pytorch_lightning.callbacks import Callback from pytorch_lightning.loggers.logger import DummyLogger from pytorch_lightning.tuner.lr_finder import _LRFinder @@ -1656,6 +1657,46 @@ def test_lr_find(self): ) assert scores["worst"] > scores["suggested"] + @pytest.mark.slow + def test_scale_batch_size(self): + def create_series(length: int) -> TimeSeries: + return TimeSeries.from_times_and_values( + times=pd.date_range("20130101", periods=length, freq="s"), + values=np.cos( + np.linspace(0, np.pi * 10, length, dtype=np.float32) + ).reshape(-1, 1), + ) + + def create_model(batch_size: Optional[int] = None) -> RNNModel: + batch_size_kwarg = {"batch_size": batch_size} if batch_size else {} + return RNNModel(12, "RNN", 10, 10, **batch_size_kwarg, **tfm_kwargs) + + series = create_series(10_000) + train_series, val_series = series.split_after(0.8) + + model = create_model() + # find the batch size + res = model.scale_batch_size(series=train_series, val_series=val_series) + assert isinstance(res, int) + assert res == model.batch_size + # verify that learning rate finder bypasses the `fit` logic + assert model.model is None + assert not model._fit_called + # cannot predict with an untrained model + with pytest.raises(ValueError): + model.predict(n=3, series=self.series) + + # check that results are reproducible + model = create_model() + res2 = model.scale_batch_size(series=train_series, val_series=val_series) + assert res == res2 + + # 
check that batch size could indeed fit in the memory + model = create_model(batch_size=res) + model.fit(train_series, val_series=val_series, epochs=1) + assert model.batch_size == res + assert model.epochs_trained == 1 + def test_encoders(self, tmpdir_fn): series = tg.linear_timeseries(length=10) pc = tg.linear_timeseries(length=12) @@ -1734,12 +1775,12 @@ def test_dataloader_kwargs_setup(self): with patch("pytorch_lightning.Trainer.fit") as fit_patch: model.fit(train_series, val_series=val_series) - assert "train_dataloaders" in fit_patch.call_args.kwargs - assert "val_dataloaders" in fit_patch.call_args.kwargs - train_dl = fit_patch.call_args.kwargs["train_dataloaders"] + datamodule = fit_patch.call_args.kwargs["datamodule"] + assert isinstance(datamodule, LightningDataModule) + train_dl = datamodule.train_dataloader() assert isinstance(train_dl, DataLoader) - val_dl = fit_patch.call_args.kwargs["val_dataloaders"] + val_dl = datamodule.val_dataloader() assert isinstance(val_dl, DataLoader) dl_defaults = { @@ -1759,8 +1800,10 @@ def test_dataloader_kwargs_setup(self): # check that overwriting the dataloader kwargs works dl_custom = dict(dl_defaults, **{"batch_size": 50, "drop_last": True}) model.fit(train_series, val_series=val_series, dataloader_kwargs=dl_custom) - train_dl = fit_patch.call_args.kwargs["train_dataloaders"] - val_dl = fit_patch.call_args.kwargs["val_dataloaders"] + + datamodule = fit_patch.call_args.kwargs["datamodule"] + train_dl = datamodule.train_dataloader() + val_dl = datamodule.val_dataloader() assert all([getattr(train_dl, k) == v for k, v in dl_custom.items()]) assert all([getattr(val_dl, k) == v for k, v in dl_custom.items()]) @@ -1881,8 +1924,11 @@ def helper_check_val_set(self, model_cls, model_kwargs, fit_patch): # fit called only once assert fit_patch.call_count == 1 - train_ds = fit_patch.call_args[1]["train_dataloaders"].dataset - val_dl = fit_patch.call_args[1]["val_dataloaders"] + datamodule = 
fit_patch.call_args.kwargs["datamodule"] + assert isinstance(datamodule, LightningDataModule) + + train_ds = datamodule.train_dataloader().dataset + val_dl = datamodule.val_dataloader() assert val_dl is not None val_ds = val_dl.dataset From c1fc1b9876fe29368efa6a622f2cf5d5b48079d8 Mon Sep 17 00:00:00 2001 From: Zhihao Dai Date: Thu, 18 Sep 2025 15:54:47 +0100 Subject: [PATCH 3/6] Update CHANGELOG & `scale_batch_size` docstring --- CHANGELOG.md | 1 + darts/models/forecasting/torch_forecasting_model.py | 3 ++- darts/tests/models/forecasting/test_torch_forecasting_model.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 16822cbbce..6d6effdd6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co **Improved** +- Added `scale_batch_size()` method to torch-based models to find the largest batch size. It is recommended to re-initialize the model after calling this method before calling `fit()`. [#2905](https://github.com/unit8co/darts/pull/2905) by [Zhihao Dai](https://github.com/daidahao). - Added mixed precision and 16-bit precision support to `TorchForecastingModel`. Simply specify `{"precision": "bf16-mixed" }` for `pl_trainer_kwargs` to enable mixed precision training. Alternatively, declare a custom `pytorch_lightning.Trainer` with a `"precision"` parameter and pass the trainer to `fit()`. Other precision options such as `"64-true"` and `"16-mixed"` supported by `pytorch_lightning` are also allowed. [#2883](https://github.com/unit8co/darts/pull/2883) by [Zhihao Dai](https://github.com/daidahao). - 🔴 Added future and static covariates support to `BlockRNNModel`. This improvement required changes to the underlying model architecture which means that saved model instances from older Darts versions cannot be loaded any longer. 
[#2845](https://github.com/unit8co/darts/pull/2845) by [Gabriel Margaria](https://github.com/Jaco-Pastorius). - `from_group_dataframe()` now supports creating `TimeSeries` from **additional DataFrame backends** (Polars, PyArrow, ...). We leverage `narwhals` as the compatibility layer between DataFrame libraries. See their [documentation](https://narwhals-dev.github.io/narwhals/) for all supported backends. [#2766](https://github.com/unit8co/darts/pull/2766) by [He Weilin](https://github.com/cnhwl). diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 1d8beb25b3..e01181f134 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -1551,7 +1551,8 @@ def scale_batch_size( ): """ A wrapper around PyTorch Lightning's `Tuner.scale_batch_size()`. Performs a batch size range test to find a - good batch size to use for training. For more information on PyTorch Lightning's Tuner check out + the largest batch size to use for training. It is recommended to re-initialize the model after calling this + method before fitting, using the found batch size. For more information on PyTorch Lightning's Tuner check out `this link `_. 
Example using a :class:`NBEATSModel`: diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index c798c7ff39..058e69288b 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -1679,7 +1679,7 @@ def create_model(batch_size: Optional[int] = None) -> RNNModel: res = model.scale_batch_size(series=train_series, val_series=val_series) assert isinstance(res, int) assert res == model.batch_size - # verify that learning rate finder bypasses the `fit` logic + # verify that batch size finder bypasses the `fit` logic assert model.model is None assert not model._fit_called # cannot predict with an untrained model From 0f72406fc3f5ce7384bf63f8599493a1360d3e98 Mon Sep 17 00:00:00 2001 From: Zhihao Dai Date: Fri, 19 Sep 2025 13:24:11 +0100 Subject: [PATCH 4/6] Fix a bug in `_setup_for_train` when `val_dataset` is `None` - When `val_dataset` is `None`, `_CustomDataModule` would still need to implement `val_dataloader()` for Lightning to work. - Since `val_dataloader()` can return ANY iterable but not `None` as per Lightning `EVAL_DATALOADERS`, we return an empty list here. - Batch size scaling would not update the model weights, so there is no need to re-initialize the model after scaling. 
--- .../forecasting/torch_forecasting_model.py | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index e01181f134..67b37ff23f 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -181,14 +181,12 @@ def train_dataloader(self): **self.dataloader_kwargs, ) - -class _CustomDataModuleWithVal(_CustomDataModule): - """Custom LightningDataModule (with validation dataset) to handle train and val dataloaders.""" - def val_dataloader(self): """Validation dataloader.""" + if self.val_dataset is None: + return [] return DataLoader( - self.val_dataset, # pyright: ignore[reportArgumentType] + self.val_dataset, batch_size=self.batch_size, shuffle=False, **self.dataloader_kwargs, @@ -1312,8 +1310,7 @@ def _setup_for_train( ) # setup datamodule - datamodule_cls = _CustomDataModuleWithVal if val_dataset else _CustomDataModule - datamodule = datamodule_cls( + datamodule = _CustomDataModule( train_dataset=train_dataset, val_dataset=val_dataset, batch_size=self.batch_size, @@ -1550,9 +1547,9 @@ def scale_batch_size( max_trials: int = 25, ): """ - A wrapper around PyTorch Lightning's `Tuner.scale_batch_size()`. Performs a batch size range test to find a - the largest batch size to use for training. It is recommended to re-initialize the model after calling this - method before fitting, using the found batch size. For more information on PyTorch Lightning's Tuner check out + A wrapper around PyTorch Lightning's `Tuner.scale_batch_size()`. Performs a batch size scaling test to + find the largest batch size to use for training. The batch size in the model would be updated after this + call. For more information on PyTorch Lightning's Tuner check out `this link `_. 
Example using a :class:`NBEATSModel`: @@ -1567,10 +1564,8 @@ def scale_batch_size( train, val = series[:-18], series[-18:] model = NBEATSModel(12, 6, random_state=42) # run the batch size tuner - batch_size = model.scale_batch_size(series=train, val_series=val) - # create a new model with the optimal batch size - model = NBEATSModel(12, 6, random_state=42, batch_size=batch_size) - # train the new model + model.scale_batch_size(series=train, val_series=val) + # train the model with the suggested batch size model.fit(train, val_series=val, epochs=1) .. From 9a735106581ff49bbfa8228eef50d618ff218b4f Mon Sep 17 00:00:00 2001 From: Zhihao Dai Date: Fri, 19 Sep 2025 13:28:39 +0100 Subject: [PATCH 5/6] Add tests to validate unchanged model weights after batch size scaling As per the previous commit, batch size scaling does not update the model weights, so the model can be used for training directly. - Update `test_scale_batch_size()` to NOT re-initialize the model after scaling. - Add `test_scale_batch_size_no_updates()` to validate that the model weights do not change after scaling. 
--- .../test_torch_forecasting_model.py | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index 058e69288b..18b55f1927 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -1667,18 +1667,17 @@ def create_series(length: int) -> TimeSeries: ).reshape(-1, 1), ) - def create_model(batch_size: Optional[int] = None) -> RNNModel: - batch_size_kwarg = {"batch_size": batch_size} if batch_size else {} - return RNNModel(12, "RNN", 10, 10, **batch_size_kwarg, **tfm_kwargs) - series = create_series(10_000) train_series, val_series = series.split_after(0.8) - model = create_model() - # find the batch size + model = RNNModel(12, "RNN", 10, 10, **tfm_kwargs) + # find the largest batch size res = model.scale_batch_size(series=train_series, val_series=val_series) + + # verify results and that batch size is set assert isinstance(res, int) assert res == model.batch_size + # verify that batch size finder bypasses the `fit` logic assert model.model is None assert not model._fit_called @@ -1686,16 +1685,36 @@ def create_model(batch_size: Optional[int] = None) -> RNNModel: with pytest.raises(ValueError): model.predict(n=3, series=self.series) + # check that batch size could indeed fit in the memory + model.fit(train_series, val_series=val_series, epochs=1) + assert model.batch_size == res + assert model.epochs_trained == 1 + # check that results are reproducible - model = create_model() + model = RNNModel(12, "RNN", 10, 10, **tfm_kwargs) res2 = model.scale_batch_size(series=train_series, val_series=val_series) assert res == res2 - # check that batch size could indeed fit in the memory - model = create_model(batch_size=res) - model.fit(train_series, val_series=val_series, epochs=1) - assert model.batch_size == res + @pytest.mark.slow 
+ def test_scale_batch_size_no_updates(self): + model = RNNModel(12, "RNN", 10, 10, **tfm_kwargs) + + # train for 1 epoch with default batch size + model.fit(self.series, epochs=1) assert model.epochs_trained == 1 + # store the predictions after 1 epoch + preds = model.predict(n=3, series=self.series) + + # find the largest batch size, should not change the model weights + res = model.scale_batch_size(series=self.series) + # verify that batch size is set + assert isinstance(res, int) + assert res == model.batch_size + + # verify that weights have not changed after batch size scaling + preds_after = model.predict(n=3, series=self.series) + assert isinstance(preds, TimeSeries) and isinstance(preds_after, TimeSeries) + assert np.isclose(preds.values(), preds_after.values()).all() def test_encoders(self, tmpdir_fn): series = tg.linear_timeseries(length=10) From 7ede1c9186ab346aa679a6d57cf7458a4c9b0454 Mon Sep 17 00:00:00 2001 From: Zhihao Dai Date: Mon, 29 Sep 2025 11:57:15 +0100 Subject: [PATCH 6/6] Fix on `load_best` when merging commits Co-authored-by: Zhihao Dai --- darts/models/forecasting/torch_forecasting_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index a96ef0ae73..a3f8b42c20 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -1199,7 +1199,7 @@ def fit_from_dataset( self Fitted model. 
""" - trainer, model, datamodule = self._setup_for_train( + trainer, model, datamodule, load_best = self._setup_for_train( train_dataset=train_dataset, val_dataset=val_dataset, trainer=trainer, @@ -1208,7 +1208,7 @@ def fit_from_dataset( dataloader_kwargs=dataloader_kwargs, load_best=load_best, ) - self._train(trainer, model, datamodule) + self._train(trainer, model, datamodule, load_best) return self def _setup_for_train( @@ -1383,7 +1383,7 @@ def _train( ckpt_activated = ckpt_callback is not None and hasattr( ckpt_callback, "best_model_path" ) - if not ckpt_activated or val_loader is None: + if not ckpt_activated or len(datamodule.val_dataloader()) == 0: logger.warning( "Loading the best model will be skipped (`load_best` is ignored), as it requires " "active checkpointing and a validation set to be provided to the current fit method." @@ -1689,7 +1689,7 @@ def scale_batch_size( max_samples_per_ts=max_samples_per_ts, dataloader_kwargs=dataloader_kwargs, ) - trainer, model, datamodule = self._setup_for_train(*params) + trainer, model, datamodule, _ = self._setup_for_train(*params) batch_size = Tuner(trainer).scale_batch_size( model=model, datamodule=datamodule,