diff --git a/src/queens/schedulers/__init__.py b/src/queens/schedulers/__init__.py
index 1bbcf70b7..71bd58715 100644
--- a/src/queens/schedulers/__init__.py
+++ b/src/queens/schedulers/__init__.py
@@ -21,8 +21,10 @@
from queens.utils.imports import extract_type_checking_imports, import_class_from_class_module_map
if TYPE_CHECKING:
+ from queens.schedulers._cluster_base import _BaseCluster
from queens.schedulers._scheduler import Scheduler
from queens.schedulers.cluster import Cluster
+ from queens.schedulers.cluster_local import ClusterLocal
from queens.schedulers.local import Local
from queens.schedulers.pool import Pool
diff --git a/src/queens/schedulers/_cluster_base.py b/src/queens/schedulers/_cluster_base.py
new file mode 100644
index 000000000..ae04ad37d
--- /dev/null
+++ b/src/queens/schedulers/_cluster_base.py
@@ -0,0 +1,257 @@
+#
+# SPDX-License-Identifier: LGPL-3.0-or-later
+# Copyright (c) 2024-2025, QUEENS contributors.
+#
+# This file is part of QUEENS.
+#
+# QUEENS is free software: you can redistribute it and/or modify it under the terms of the GNU
+# Lesser General Public License as published by the Free Software Foundation, either version 3 of
+# the License, or (at your option) any later version. QUEENS is distributed in the hope that it will
+# be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You
+# should have received a copy of the GNU Lesser General Public License along with QUEENS. If not,
+# see <https://www.gnu.org/licenses/>.
+#
+"""Base class for QUEENS cluster schedulers."""
+
+import logging
+from abc import abstractmethod
+from datetime import timedelta
+
+from dask_jobqueue import PBSCluster, SLURMCluster
+
+from queens.schedulers._dask import Dask
+from queens.utils.logger_settings import log_init_args
+from queens.utils.valid_options import get_option
+
+_logger = logging.getLogger(__name__)
+
+# Registry of the supported workload managers, keyed by the value accepted for the
+# ``workload_manager`` argument of the cluster schedulers. Each entry provides:
+#   - "dask_cluster_cls": the dask_jobqueue cluster class that is instantiated,
+#   - "job_extra_directives": callable ``(nodes, cores) -> str`` building the resource directive,
+#   - "job_directives_skip": list forwarded as the ``job_directives_skip`` cluster kwarg.
+# NOTE(review): this dict is module-level state shared by every scheduler instance; code that
+# needs to extend "job_directives_skip" should work on a copy of the list.
+VALID_WORKLOAD_MANAGERS = {
+ "slurm": {
+ "dask_cluster_cls": SLURMCluster,
+ "job_extra_directives": lambda nodes, cores: f"--ntasks={nodes * cores}",
+ "job_directives_skip": [
+ "#SBATCH -n 1",
+ "#SBATCH --mem=",
+ "#SBATCH --cpus-per-task=",
+ ],
+ },
+ "pbs": {
+ "dask_cluster_cls": PBSCluster,
+ "job_extra_directives": lambda nodes, cores: f"-l nodes={nodes}:ppn={cores}",
+ "job_directives_skip": ["#PBS -l select"],
+ },
+}
+
+
+def timedelta_to_str(timedelta_obj):
+    """Render a timedelta as an ``HH:MM:SS`` string.
+
+    The datetime library offers no formatter for timedeltas, so the duration is split into its
+    components by hand. Hours are not wrapped at 24, i.e. durations of one day or more yield hour
+    values of 24 and above.
+
+    Args:
+        timedelta_obj (datetime.timedelta): Duration to format
+
+    Returns:
+        str: Duration formatted as HH:MM:SS
+    """
+    # Fractions of a second are dropped by the int conversion
+    total_seconds = int(timedelta_obj.total_seconds())
+    hours, remainder = divmod(total_seconds, 3600)
+    minutes, seconds = divmod(remainder, 60)
+    return "{:02}:{:02}:{:02}".format(hours, minutes, seconds)
+
+
+def _initialize_dask_cluster(
+    logger, dask_cluster_cls, dask_cluster_kwargs, dask_cluster_adapt_kwargs, experiment_dir
+):
+    """Create a Dask cluster, enable adaptive scaling and dump its jobscript.
+
+    Args:
+        logger (logging.Logger): Logger that receives the progress messages
+        dask_cluster_cls (type): Cluster class to instantiate
+        dask_cluster_kwargs (dict): Keyword arguments for the cluster constructor
+        dask_cluster_adapt_kwargs (dict): Keyword arguments for the ``adapt`` call of the cluster
+        experiment_dir (Path): Directory in which ``dask_jobscript.sh`` is written
+
+    Returns:
+        The started and adapted cluster object
+    """
+    logger.info("Starting dask cluster of type: %s", dask_cluster_cls)
+    logger.debug("Dask cluster kwargs:")
+    logger.debug(dask_cluster_kwargs)
+    dask_cluster = dask_cluster_cls(**dask_cluster_kwargs)
+
+    logger.info("Adapting dask cluster settings")
+    logger.debug("Dask cluster adapt kwargs:")
+    logger.debug(dask_cluster_adapt_kwargs)
+    dask_cluster.adapt(**dask_cluster_adapt_kwargs)
+
+    logger.info("Dask cluster info:")
+    logger.info(dask_cluster)
+
+    # Keep a copy of the generated jobscript next to the experiment data
+    jobscript_path = experiment_dir / "dask_jobscript.sh"
+    logger.info("Writing dask jobscript to:")
+    logger.info(jobscript_path)
+    jobscript_content = str(dask_cluster.job_script())
+    jobscript_path.write_text(jobscript_content)
+
+    return dask_cluster
+
+
+class _BaseCluster(Dask):
+    """Abstract base class for QUEENS cluster schedulers.
+
+    The class collects the settings of the Dask jobqueue cluster that are independent of where the
+    cluster is started. Subclasses provide the experiment directory, the actual cluster start-up
+    and the file transfer via *_get_experiment_dir*, *_start_cluster* and
+    *copy_files_to_experiment_dir*.
+
+    Attributes:
+        workload_manager (str): Workload manager ("pbs" or "slurm")
+        walltime (str): Walltime for each worker job. Format (hh:mm:ss)
+        min_jobs (int): Minimum number of active workers for the cluster
+        num_nodes (int): Number of cluster nodes per job
+        queue (str): Destination queue for each worker job
+        cluster_internal_address (str): Internal address of cluster
+        allowed_failures (int): Number of allowed failures for a task before an error is raised
+        job_script_prologue (list): List of commands to be executed before starting a worker
+    """
+
+    @log_init_args
+    def __init__(
+        self,
+        experiment_name,
+        workload_manager,
+        walltime,
+        num_jobs=1,
+        min_jobs=0,
+        num_procs=1,
+        num_nodes=1,
+        queue=None,
+        cluster_internal_address=None,
+        restart_workers=False,
+        allowed_failures=5,
+        verbose=True,
+        experiment_base_dir=None,
+        overwrite_existing_experiment=False,
+        job_script_prologue=None,
+    ):
+        """Init method for the abstract cluster scheduler.
+
+        The total number of cores per job is given by num_procs*num_nodes.
+
+        Args:
+            experiment_name (str): Name of the current experiment
+            workload_manager (str): Workload manager ("pbs" or "slurm")
+            walltime (str): Walltime for each worker job. Format (hh:mm:ss)
+            num_jobs (int, opt): Maximum number of parallel jobs
+            min_jobs (int, opt): Minimum number of active workers for the cluster
+            num_procs (int, opt): Number of processors per job per node
+            num_nodes (int, opt): Number of cluster nodes per job
+            queue (str, opt): Destination queue for each worker job
+            cluster_internal_address (str, opt): Internal address of cluster
+            restart_workers (bool): If True, restart workers after each finished job. For larger
+                                    jobs (>1min) this should be set to True in most cases.
+            allowed_failures (int): Number of allowed failures for a task before an error is raised
+            verbose (bool, opt): Verbosity of evaluations. Defaults to True.
+            experiment_base_dir (str, Path): Base directory for the simulation outputs
+            overwrite_existing_experiment (bool): If True, overwrite experiment directory if it
+                exists already. If False, prompt user for confirmation before overwriting.
+            job_script_prologue (list, opt): List of commands to be executed before starting a
+                                             worker.
+        """
+        # These attributes have to be set before the parent initializer is called
+        self.workload_manager = workload_manager
+        self.walltime = walltime
+        self.min_jobs = min_jobs
+        self.num_nodes = num_nodes
+        self.queue = queue
+        self.cluster_internal_address = cluster_internal_address
+        self.allowed_failures = allowed_failures
+        self.job_script_prologue = job_script_prologue
+
+        experiment_dir = self._get_experiment_dir(
+            experiment_name, experiment_base_dir, overwrite_existing_experiment
+        )
+
+        _logger.debug("experiment directory: %s", experiment_dir)
+
+        super().__init__(
+            experiment_name=experiment_name,
+            experiment_dir=experiment_dir,
+            num_jobs=num_jobs,
+            num_procs=num_procs,
+            restart_workers=restart_workers,
+            verbose=verbose,
+        )
+
+    @abstractmethod
+    def _get_experiment_dir(
+        self, experiment_name, experiment_base_dir, overwrite_existing_experiment
+    ):
+        """Get experiment directory.
+
+        Args:
+            experiment_name (str): Name of the current experiment
+            experiment_base_dir (str, Path): Base directory for the simulation outputs
+            overwrite_existing_experiment (bool): If True, overwrite experiment directory if it
+                exists already.
+
+        Returns:
+            Path of the experiment directory
+        """
+
+    @abstractmethod
+    def _start_cluster(self, dask_cluster_kwargs, dask_cluster_adapt_kwargs):
+        """Start cluster and return connected client and dashboard port.
+
+        Args:
+            dask_cluster_kwargs (dict): Keyword arguments for the Dask cluster
+            dask_cluster_adapt_kwargs (dict): Keyword arguments for the adaptive scaling
+
+        Returns:
+            tuple: Connected Dask client and the port under which the dashboard is reachable
+        """
+
+    @abstractmethod
+    def copy_files_to_experiment_dir(self, paths):
+        """Copy file to experiment directory.
+
+        Args:
+            paths (Path, list): Paths to files or directories that should be copied to experiment
+                                directory
+        """
+
+    def _start_cluster_and_connect_client(self):
+        """Start a Dask cluster and a client that connects to it.
+
+        Returns:
+            client (Client): Dask client that is connected to and submits computations to a Dask
+                             cluster.
+
+        Raises:
+            ValueError: If the walltime is not given in the format hh:mm:ss
+        """
+        # collect all settings for the dask cluster
+        dask_cluster_options = get_option(VALID_WORKLOAD_MANAGERS, self.workload_manager)
+        job_extra_directives = dask_cluster_options["job_extra_directives"](
+            self.num_nodes, self.num_procs
+        )
+        # Copy the list: the options dict is module-level state shared by all schedulers, so an
+        # in-place append would grow it with every cluster start.
+        job_directives_skip = list(dask_cluster_options["job_directives_skip"])
+        if self.queue is None:
+            job_directives_skip.append("#SBATCH -p")
+
+        hours, minutes, seconds = map(int, self.walltime.split(":"))
+        walltime_delta = timedelta(hours=hours, minutes=minutes, seconds=seconds)
+
+        # Increase jobqueue walltime by 5 minutes to kill dask workers in time
+        increased_walltime = timedelta_to_str(walltime_delta + timedelta(minutes=5))
+
+        # dask worker lifetime = increased walltime - 3m +/- 2m, i.e. requested walltime + 2m with
+        # a stagger of 2m (see "--lifetime-stagger" below)
+        worker_lifetime = str(int((walltime_delta + timedelta(minutes=2)).total_seconds())) + "s"
+
+        dask_cluster_kwargs = {
+            "job_name": self.experiment_name,
+            "queue": self.queue,
+            "memory": "10TB",
+            "walltime": increased_walltime,
+            "log_directory": str(self.experiment_dir),
+            "job_directives_skip": job_directives_skip,
+            "job_extra_directives": [job_extra_directives],
+            "worker_extra_args": ["--lifetime", worker_lifetime, "--lifetime-stagger", "2m"],
+            "job_script_prologue": self.job_script_prologue,
+            # keep this hardcoded to 1, the number of threads for the mpi run is handled by
+            # job_extra_directives. Note that the number of workers is not the number of
+            # parallel simulations!
+            "cores": 1,
+            "processes": 1,
+            "n_workers": 1,
+        }
+        dask_cluster_adapt_kwargs = {
+            "minimum_jobs": self.min_jobs,
+            "maximum_jobs": self.num_jobs,
+        }
+
+        # start dask cluster
+        client, dashboard_port = self._start_cluster(dask_cluster_kwargs, dask_cluster_adapt_kwargs)
+
+        _logger.debug("Submitting dummy job to check basic functionality of client.")
+        client.submit(lambda: "Dummy job").result(timeout=180)
+        _logger.debug("Dummy job was successful.")
+        _logger.info(
+            "To view the Dask dashboard open this link in your browser: "
+            "http://localhost:%i/status",
+            dashboard_port,
+        )
+        return client
+
+    def restart_worker(self, worker):
+        """Restart a worker.
+
+        This method retires a dask worker. The adaptive scaling of the dask cluster takes care of
+        submitting new workers subsequently.
+
+        Args:
+            worker (str, tuple): Worker to restart. This can be a single worker address or name, or
+                                 a collection of them.
+        """
+        # A plain string must be wrapped, list("tcp://...") would split it into characters
+        workers = [worker] if isinstance(worker, str) else list(worker)
+        self.client.retire_workers(workers=workers)
+
+    @staticmethod
+    def delete_experiment_dir_if_empty(_):
+        """The remote experiment directory will never be empty, so pass."""
diff --git a/src/queens/schedulers/cluster.py b/src/queens/schedulers/cluster.py
index b00cadb6a..59b303d42 100644
--- a/src/queens/schedulers/cluster.py
+++ b/src/queens/schedulers/cluster.py
@@ -16,60 +16,21 @@
import logging
import time
-from datetime import timedelta
from pathlib import Path
from typing import Sequence
from dask.distributed import Client
-from dask_jobqueue import PBSCluster, SLURMCluster
-from queens.schedulers._dask import Dask
+from queens.schedulers._cluster_base import _BaseCluster
from queens.utils.config_directories import experiment_directory # Do not change this import!
from queens.utils.config_directories import create_directory
from queens.utils.logger_settings import log_init_args
-from queens.utils.valid_options import get_option
_logger = logging.getLogger(__name__)
-VALID_WORKLOAD_MANAGERS = {
- "slurm": {
- "dask_cluster_cls": SLURMCluster,
- "job_extra_directives": lambda nodes, cores: f"--ntasks={nodes * cores}",
- "job_directives_skip": [
- "#SBATCH -n 1",
- "#SBATCH --mem=",
- "#SBATCH --cpus-per-task=",
- ],
- },
- "pbs": {
- "dask_cluster_cls": PBSCluster,
- "job_extra_directives": lambda nodes, cores: f"-l nodes={nodes}:ppn={cores}",
- "job_directives_skip": ["#PBS -l select"],
- },
-}
-
-def timedelta_to_str(timedelta_obj):
- """Format a timedelta object to str.
-
- This function seems unnecessarily complicated, but unfortunately the datetime library does not
- support this formatting for timedeltas. Returns the format HH:MM:SS.
-
- Args:
- timedelta_obj (datetime.timedelta): Timedelta object to format
-
- Returns:
- str: String of the timedelta object
- """
- # Time in seconds
- time_in_seconds = int(timedelta_obj.total_seconds())
- (minutes, seconds) = divmod(time_in_seconds, 60)
- (hours, minutes) = divmod(minutes, 60)
- return f"{hours:02}:{minutes:02}:{seconds:02}"
-
-
-class Cluster(Dask):
- """Cluster scheduler for QUEENS."""
+class Cluster(_BaseCluster):
+ """Cluster (remote) scheduler for QUEENS."""
@log_init_args
def __init__(
@@ -91,7 +52,7 @@ def __init__(
overwrite_existing_experiment=False,
job_script_prologue=None,
):
- """Init method for the cluster scheduler.
+ """Init method for the remote cluster scheduler.
The total number of cores per job is given by num_procs*num_nodes.
@@ -122,37 +83,25 @@ def __init__(
# sync remote source code with local state
self.remote_connection.sync_remote_repository()
- self.workload_manager = workload_manager
- self.walltime = walltime
- self.min_jobs = min_jobs
- self.num_nodes = num_nodes
- self.queue = queue
- self.cluster_internal_address = cluster_internal_address
- self.allowed_failures = allowed_failures
- self.job_script_prologue = job_script_prologue
-
- # get the path of the experiment directory on remote host
- experiment_dir = self.remote_experiment_dir(
- experiment_name, experiment_base_dir, overwrite_existing_experiment
- )
-
- _logger.debug(
- "experiment directory on %s@%s: %s",
- self.remote_connection.user,
- self.remote_connection.host,
- experiment_dir,
- )
-
super().__init__(
experiment_name=experiment_name,
- experiment_dir=experiment_dir,
+ workload_manager=workload_manager,
+ walltime=walltime,
num_jobs=num_jobs,
+ min_jobs=min_jobs,
num_procs=num_procs,
+ num_nodes=num_nodes,
+ queue=queue,
+ cluster_internal_address=cluster_internal_address,
restart_workers=restart_workers,
+ allowed_failures=allowed_failures,
verbose=verbose,
+ experiment_base_dir=experiment_base_dir,
+ overwrite_existing_experiment=overwrite_existing_experiment,
+ job_script_prologue=job_script_prologue,
)
- def remote_experiment_dir(
+ def _get_experiment_dir(
self, experiment_name, experiment_base_dir, overwrite_existing_experiment
):
"""Get experiment directory on remote host.
@@ -175,31 +124,13 @@ def remote_experiment_dir(
return experiment_dir
- def _start_cluster_and_connect_client(self):
- """Start a Dask cluster and a client that connects to it.
+ def _start_cluster(self, dask_cluster_kwargs, dask_cluster_adapt_kwargs):
+ """Start a Dask cluster and connect a client on remote host.
Returns:
client (Client): Dask client that is connected to and submits computations to a Dask
cluster.
"""
- # collect all settings for the dask cluster
- dask_cluster_options = get_option(VALID_WORKLOAD_MANAGERS, self.workload_manager)
- job_extra_directives = dask_cluster_options["job_extra_directives"](
- self.num_nodes, self.num_procs
- )
- job_directives_skip = dask_cluster_options["job_directives_skip"]
- if self.queue is None:
- job_directives_skip.append("#SBATCH -p")
-
- hours, minutes, seconds = map(int, self.walltime.split(":"))
- walltime_delta = timedelta(hours=hours, minutes=minutes, seconds=seconds)
-
- # Increase jobqueue walltime by 5 minutes to kill dask workers in time
- increased_walltime = timedelta_to_str(walltime_delta + timedelta(minutes=5))
-
- # dask worker lifetime = walltime - 3m +/- 2m
- worker_lifetime = str(int((walltime_delta + timedelta(minutes=2)).total_seconds())) + "s"
-
local_port, remote_port = self.remote_connection.open_port_forwarding()
local_port_dashboard, remote_port_dashboard = self.remote_connection.open_port_forwarding()
@@ -210,30 +141,9 @@ def _start_cluster_and_connect_client(self):
}
if self.cluster_internal_address is not None:
scheduler_options["contact_address"] = f"{self.cluster_internal_address}:{remote_port}"
- dask_cluster_kwargs = {
- "job_name": self.experiment_name,
- "queue": self.queue,
- "memory": "10TB",
- "scheduler_options": scheduler_options,
- "walltime": increased_walltime,
- "log_directory": str(self.experiment_dir),
- "job_directives_skip": job_directives_skip,
- "job_extra_directives": [job_extra_directives],
- "worker_extra_args": ["--lifetime", worker_lifetime, "--lifetime-stagger", "2m"],
- "job_script_prologue": self.job_script_prologue,
- # keep this hardcoded to 1, the number of threads for the mpi run is handled by
- # job_extra_directives. Note that the number of workers is not the number of
- # parallel simulations!
- "cores": 1,
- "processes": 1,
- "n_workers": 1,
- }
- dask_cluster_adapt_kwargs = {
- "minimum_jobs": self.min_jobs,
- "maximum_jobs": self.num_jobs,
- }
- # actually start the dask cluster on remote host
+ dask_cluster_kwargs["scheduler_options"] = scheduler_options
+
stdout, stderr = self.remote_connection.start_cluster(
self.workload_manager,
dask_cluster_kwargs,
@@ -255,26 +165,7 @@ def _start_cluster_and_connect_client(self):
) from exc
time.sleep(1)
- _logger.debug("Submitting dummy job to check basic functionality of client.")
- client.submit(lambda: "Dummy job").result(timeout=180)
- _logger.debug("Dummy job was successful.")
- _logger.info(
- "To view the Dask dashboard open this link in your browser: "
- "http://localhost:%i/status",
- local_port_dashboard,
- )
- return client
-
- def restart_worker(self, worker):
- """Restart a worker.
-
- This method retires a dask worker. The Client.adapt method of dask takes cares of submitting
- new workers subsequently.
-
- Args:
- worker (str, tuple): Worker to restart. This can be a worker address, name, or a both.
- """
- self.client.retire_workers(workers=list(worker))
+ return client, local_port_dashboard
def copy_files_to_experiment_dir(self, paths):
"""Copy file to experiment directory.
@@ -316,7 +207,3 @@ def copy_files_from_experiment_dir(
self.remote_connection.copy_from_remote(
self.experiment_dir, destination, verbose, exclude, filters
)
-
- @staticmethod
- def delete_experiment_dir_if_empty(_):
- """The remote experiment directory will never be empty, so pass."""
diff --git a/src/queens/schedulers/cluster_local.py b/src/queens/schedulers/cluster_local.py
new file mode 100644
index 000000000..999b4952d
--- /dev/null
+++ b/src/queens/schedulers/cluster_local.py
@@ -0,0 +1,99 @@
+#
+# SPDX-License-Identifier: LGPL-3.0-or-later
+# Copyright (c) 2024-2025, QUEENS contributors.
+#
+# This file is part of QUEENS.
+#
+# QUEENS is free software: you can redistribute it and/or modify it under the terms of the GNU
+# Lesser General Public License as published by the Free Software Foundation, either version 3 of
+# the License, or (at your option) any later version. QUEENS is distributed in the hope that it will
+# be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You
+# should have received a copy of the GNU Lesser General Public License along with QUEENS. If not,
+# see <https://www.gnu.org/licenses/>.
+#
+"""Cluster scheduler for QUEENS runs."""
+
+import logging
+import time
+
+from dask.distributed import Client
+
+from queens.schedulers._cluster_base import (
+ VALID_WORKLOAD_MANAGERS,
+ _BaseCluster,
+ _initialize_dask_cluster,
+)
+from queens.schedulers._scheduler import Scheduler
+from queens.utils.remote_operations import get_port
+from queens.utils.valid_options import get_option
+
+_logger = logging.getLogger(__name__)
+
+
+class ClusterLocal(_BaseCluster):
+    """Cluster (local) scheduler for QUEENS.
+
+    Can be used to schedule jobs to a cluster scheduler with local
+    access i.e. without a network connection.
+    """
+
+    def _get_experiment_dir(
+        self, experiment_name, experiment_base_dir, overwrite_existing_experiment
+    ):
+        """Get local experiment directory.
+
+        Args:
+            experiment_name (str): Name of the current experiment
+            experiment_base_dir (str, Path): Base directory for the simulation outputs
+            overwrite_existing_experiment (bool): If True, overwrite experiment directory if it
+                exists already.
+
+        Returns:
+            Experiment directory as returned by *Scheduler.local_experiment_dir*
+        """
+        return Scheduler.local_experiment_dir(
+            self, experiment_name, experiment_base_dir, overwrite_existing_experiment
+        )
+
+    def _start_cluster(self, dask_cluster_kwargs, dask_cluster_adapt_kwargs):
+        """Start a Dask cluster and connect a client locally.
+
+        Args:
+            dask_cluster_kwargs (dict): Keyword arguments for the Dask cluster. The entry
+                "scheduler_options" is added to this dict.
+            dask_cluster_adapt_kwargs (dict): Keyword arguments for the adaptive scaling
+
+        Returns:
+            tuple: Connected Dask client and the port on which the dashboard is served
+
+        Raises:
+            RuntimeError: If the Dask cluster could not be started
+            OSError: If no client could be connected to the cluster
+        """
+        # collect all settings for the dask cluster
+        dask_cluster_options = get_option(VALID_WORKLOAD_MANAGERS, self.workload_manager)
+        dask_cluster_cls = dask_cluster_options["dask_cluster_cls"]
+
+        scheduler_port = get_port()
+        # The cluster runs on this machine, so no port forwarding is involved and the dashboard
+        # is reachable on exactly the port handed to the scheduler.
+        dashboard_port = get_port()
+
+        scheduler_options = {
+            "port": scheduler_port,
+            "dashboard_address": dashboard_port,
+            "allowed_failures": self.allowed_failures,
+        }
+        if self.cluster_internal_address:
+            scheduler_options["contact_address"] = (
+                f"{self.cluster_internal_address}:{scheduler_port}"
+            )
+
+        dask_cluster_kwargs["scheduler_options"] = scheduler_options
+
+        try:
+            cluster = _initialize_dask_cluster(  # pylint: disable=duplicate-code
+                _logger,
+                dask_cluster_cls,
+                dask_cluster_kwargs,
+                dask_cluster_adapt_kwargs,
+                self.experiment_dir,
+            )
+        except Exception as e:
+            raise RuntimeError("Starting the local dask cluster failed.") from e
+
+        for i in range(20, 0, -1):  # 20 tries to connect
+            _logger.debug("Trying to connect to Dask Cluster: try #%d", i)
+            try:
+                client = Client(cluster)
+                break
+            except OSError as exc:
+                if i == 1:
+                    raise OSError(
+                        "Could not connect a client to the local dask cluster after 20 tries."
+                    ) from exc
+                time.sleep(1)
+
+        return client, dashboard_port
+
+    def copy_files_to_experiment_dir(self, paths):
+        """Copy file to experiment directory.
+
+        Args:
+            paths (Path, list): paths to files or directories that should be copied to experiment
+                                directory
+        """
+        return Scheduler.copy_files_to_experiment_dir(self, paths)
diff --git a/src/queens/utils/start_dask_cluster.py b/src/queens/utils/start_dask_cluster.py
index b9105b8f8..06238a649 100644
--- a/src/queens/utils/start_dask_cluster.py
+++ b/src/queens/utils/start_dask_cluster.py
@@ -22,7 +22,7 @@
from pathlib import Path
from typing import Sequence
-from queens.schedulers.cluster import VALID_WORKLOAD_MANAGERS
+from queens.schedulers._cluster_base import VALID_WORKLOAD_MANAGERS, _initialize_dask_cluster
from queens.utils.logger_settings import setup_basic_logging
from queens.utils.valid_options import get_option
@@ -81,24 +81,13 @@ def parse_arguments(unparsed_args: Sequence[str]) -> argparse.Namespace:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
- _logger.info("Starting dask cluster of type: %s", dask_cluster_cls)
- _logger.debug("Dask cluster kwargs:")
- _logger.debug(dask_cluster_kwargs)
- cluster = dask_cluster_cls(**dask_cluster_kwargs)
-
- _logger.info("Adapting dask cluster settings")
- _logger.debug("Dask cluster adapt kwargs:")
- _logger.debug(dask_cluster_adapt_kwargs)
- cluster.adapt(**dask_cluster_adapt_kwargs)
-
- _logger.info("Dask cluster info:")
- _logger.info(cluster)
-
- dask_jobscript = experiment_dir / "dask_jobscript.sh"
- _logger.info("Writing dask jobscript to:")
- _logger.info(dask_jobscript)
- dask_jobscript.write_text(str(cluster.job_script()))
-
+ cluster = _initialize_dask_cluster(
+ _logger,
+ dask_cluster_cls,
+ dask_cluster_kwargs,
+ dask_cluster_adapt_kwargs,
+ experiment_dir,
+ )
loop.run_forever()
except KeyboardInterrupt:
_logger.info("Caught keyboard interrupt")