diff --git a/src/queens/schedulers/__init__.py b/src/queens/schedulers/__init__.py index 1bbcf70b7..71bd58715 100644 --- a/src/queens/schedulers/__init__.py +++ b/src/queens/schedulers/__init__.py @@ -21,8 +21,10 @@ from queens.utils.imports import extract_type_checking_imports, import_class_from_class_module_map if TYPE_CHECKING: + from queens.schedulers._cluster_base import _BaseCluster from queens.schedulers._scheduler import Scheduler from queens.schedulers.cluster import Cluster + from queens.schedulers.cluster_local import ClusterLocal from queens.schedulers.local import Local from queens.schedulers.pool import Pool diff --git a/src/queens/schedulers/_cluster_base.py b/src/queens/schedulers/_cluster_base.py new file mode 100644 index 000000000..ae04ad37d --- /dev/null +++ b/src/queens/schedulers/_cluster_base.py @@ -0,0 +1,257 @@ +# +# SPDX-License-Identifier: LGPL-3.0-or-later +# Copyright (c) 2024-2025, QUEENS contributors. +# +# This file is part of QUEENS. +# +# QUEENS is free software: you can redistribute it and/or modify it under the terms of the GNU +# Lesser General Public License as published by the Free Software Foundation, either version 3 of +# the License, or (at your option) any later version. QUEENS is distributed in the hope that it will +# be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You +# should have received a copy of the GNU Lesser General Public License along with QUEENS. If not, +# see . +# +"""Base class for QUEENS cluster schedulers.""" + +import logging +from abc import abstractmethod +from datetime import timedelta + +from dask_jobqueue import PBSCluster, SLURMCluster + +from queens.schedulers._dask import Dask +from queens.utils.logger_settings import log_init_args +from queens.utils.valid_options import get_option + +_logger = logging.getLogger(__name__) + +VALID_WORKLOAD_MANAGERS = { + "slurm": { + "dask_cluster_cls": SLURMCluster, + "job_extra_directives": lambda nodes, cores: f"--ntasks={nodes * cores}", + "job_directives_skip": [ + "#SBATCH -n 1", + "#SBATCH --mem=", + "#SBATCH --cpus-per-task=", + ], + }, + "pbs": { + "dask_cluster_cls": PBSCluster, + "job_extra_directives": lambda nodes, cores: f"-l nodes={nodes}:ppn={cores}", + "job_directives_skip": ["#PBS -l select"], + }, +} + + +def timedelta_to_str(timedelta_obj): + """Format a timedelta object to str. + + This function seems unnecessarily complicated, but unfortunately the datetime library does not + support this formatting for timedeltas. Returns the format HH:MM:SS. + + Args: + timedelta_obj (datetime.timedelta): Timedelta object to format + + Returns: + str: String of the timedelta object + """ + # Time in seconds + time_in_seconds = int(timedelta_obj.total_seconds()) + (minutes, seconds) = divmod(time_in_seconds, 60) + (hours, minutes) = divmod(minutes, 60) + return f"{hours:02}:{minutes:02}:{seconds:02}" + + +def _initialize_dask_cluster( + logger, dask_cluster_cls, dask_cluster_kwargs, dask_cluster_adapt_kwargs, experiment_dir +): + """Initialize a Dask cluster. + + Start dask cluster, adapt it to the requested worker settings, and + write jobscript. + """ + logger.info("Starting dask cluster of type: %s", dask_cluster_cls) + logger.debug("Dask cluster kwargs:") + logger.debug(dask_cluster_kwargs) + cluster = dask_cluster_cls(**dask_cluster_kwargs) + + logger.info("Adapting dask cluster settings") + logger.debug("Dask cluster adapt kwargs:") + logger.debug(dask_cluster_adapt_kwargs) + cluster.adapt(**dask_cluster_adapt_kwargs) + + logger.info("Dask cluster info:") + logger.info(cluster) + + dask_jobscript = experiment_dir / "dask_jobscript.sh" + logger.info("Writing dask jobscript to:") + logger.info(dask_jobscript) + dask_jobscript.write_text(str(cluster.job_script())) + + return cluster + + +class _BaseCluster(Dask): + """Abstract base class for QUEENS cluster schedulers.""" + + @log_init_args + def __init__( + self, + experiment_name, + workload_manager, + walltime, + num_jobs=1, + min_jobs=0, + num_procs=1, + num_nodes=1, + queue=None, + cluster_internal_address=None, + restart_workers=False, + allowed_failures=5, + verbose=True, + experiment_base_dir=None, + overwrite_existing_experiment=False, + job_script_prologue=None, + ): + """Init method for the abstract cluster scheduler. + + The total number of cores per job is given by num_procs*num_nodes. + + Args: + experiment_name (str): Name of the current experiment + workload_manager (str): Workload manager ("pbs" or "slurm") + walltime (str): Walltime for each worker job. Format (hh:mm:ss) + num_jobs (int, opt): Maximum number of parallel jobs + min_jobs (int, opt): Minimum number of active workers for the cluster + num_procs (int, opt): Number of processors per job per node + num_nodes (int, opt): Number of cluster nodes per job + queue (str, opt): Destination queue for each worker job + cluster_internal_address (str, opt): Internal address of cluster + restart_workers (bool): If True, restart workers after each finished job. For larger + jobs (>1min) this should be set to True in most cases. + allowed_failures (int): Number of allowed failures for a task before an error is raised + verbose (bool, opt): Verbosity of evaluations. Defaults to True. + experiment_base_dir (str, Path): Base directory for the simulation outputs + overwrite_existing_experiment (bool): If True, overwrite experiment directory if it + exists already. If False, prompt user for confirmation before overwriting. + job_script_prologue (list, opt): List of commands to be executed before starting a + worker. + """ + self.workload_manager = workload_manager + self.walltime = walltime + self.min_jobs = min_jobs + self.num_nodes = num_nodes + self.queue = queue + self.cluster_internal_address = cluster_internal_address + self.allowed_failures = allowed_failures + self.job_script_prologue = job_script_prologue + + experiment_dir = self._get_experiment_dir( + experiment_name, experiment_base_dir, overwrite_existing_experiment + ) + + _logger.debug("experiment directory: %s", experiment_dir) + + super().__init__( + experiment_name=experiment_name, + experiment_dir=experiment_dir, + num_jobs=num_jobs, + num_procs=num_procs, + restart_workers=restart_workers, + verbose=verbose, + ) + + @abstractmethod + def _get_experiment_dir( + self, experiment_name, experiment_base_dir, overwrite_existing_experiment + ): + """Get experiment directory.""" + + @abstractmethod + def _start_cluster(self, dask_cluster_kwargs, dask_cluster_adapt_kwargs): + """Start cluster and return connected client and dashboard port.""" + + @abstractmethod + def copy_files_to_experiment_dir(self, paths): + """Copy file to experiment directory. + + Args: + paths (Path, list): Paths to files or directories that should be copied to experiment + directory + """ + + def _start_cluster_and_connect_client(self): + """Start a Dask cluster and a client that connects to it. + + Returns: + client (Client): Dask client that is connected to and submits computations to a Dask + cluster. + """ + # collect all settings for the dask cluster + dask_cluster_options = get_option(VALID_WORKLOAD_MANAGERS, self.workload_manager) + job_extra_directives = dask_cluster_options["job_extra_directives"]( + self.num_nodes, self.num_procs + ) + job_directives_skip = dask_cluster_options["job_directives_skip"] + if self.queue is None: + job_directives_skip.append("#SBATCH -p") + + hours, minutes, seconds = map(int, self.walltime.split(":")) + walltime_delta = timedelta(hours=hours, minutes=minutes, seconds=seconds) + + # Increase jobqueue walltime by 5 minutes to kill dask workers in time + increased_walltime = timedelta_to_str(walltime_delta + timedelta(minutes=5)) + + # dask worker lifetime = walltime - 3m +/- 2m + worker_lifetime = str(int((walltime_delta + timedelta(minutes=2)).total_seconds())) + "s" + + dask_cluster_kwargs = { + "job_name": self.experiment_name, + "queue": self.queue, + "memory": "10TB", + "walltime": increased_walltime, + "log_directory": str(self.experiment_dir), + "job_directives_skip": job_directives_skip, + "job_extra_directives": [job_extra_directives], + "worker_extra_args": ["--lifetime", worker_lifetime, "--lifetime-stagger", "2m"], + "job_script_prologue": self.job_script_prologue, + # keep this hardcoded to 1, the number of threads for the mpi run is handled by + # job_extra_directives. Note that the number of workers is not the number of + # parallel simulations! + "cores": 1, + "processes": 1, + "n_workers": 1, + } + dask_cluster_adapt_kwargs = { + "minimum_jobs": self.min_jobs, + "maximum_jobs": self.num_jobs, + } + + # start dask cluster + client, dashboard_port = self._start_cluster(dask_cluster_kwargs, dask_cluster_adapt_kwargs) + + _logger.debug("Submitting dummy job to check basic functionality of client.") + client.submit(lambda: "Dummy job").result(timeout=180) + _logger.debug("Dummy job was successful.") + _logger.info( + "To view the Dask dashboard open this link in your browser: " + "http://localhost:%i/status", + dashboard_port, + ) + return client + + def restart_worker(self, worker): + """Restart a worker. + + This method retires a dask worker. The Client.adapt method of dask takes cares of submitting + new workers subsequently. + + Args: + worker (str, tuple): Worker to restart. This can be a worker address, name, or a both. + """ + self.client.retire_workers(workers=list(worker)) + + @staticmethod + def delete_experiment_dir_if_empty(_): + """The remote experiment directory will never be empty, so pass.""" diff --git a/src/queens/schedulers/cluster.py b/src/queens/schedulers/cluster.py index b00cadb6a..59b303d42 100644 --- a/src/queens/schedulers/cluster.py +++ b/src/queens/schedulers/cluster.py @@ -16,60 +16,21 @@ import logging import time -from datetime import timedelta from pathlib import Path from typing import Sequence from dask.distributed import Client -from dask_jobqueue import PBSCluster, SLURMCluster -from queens.schedulers._dask import Dask +from queens.schedulers._cluster_base import _BaseCluster from queens.utils.config_directories import experiment_directory # Do not change this import! from queens.utils.config_directories import create_directory from queens.utils.logger_settings import log_init_args -from queens.utils.valid_options import get_option _logger = logging.getLogger(__name__) -VALID_WORKLOAD_MANAGERS = { - "slurm": { - "dask_cluster_cls": SLURMCluster, - "job_extra_directives": lambda nodes, cores: f"--ntasks={nodes * cores}", - "job_directives_skip": [ - "#SBATCH -n 1", - "#SBATCH --mem=", - "#SBATCH --cpus-per-task=", - ], - }, - "pbs": { - "dask_cluster_cls": PBSCluster, - "job_extra_directives": lambda nodes, cores: f"-l nodes={nodes}:ppn={cores}", - "job_directives_skip": ["#PBS -l select"], - }, -} - -def timedelta_to_str(timedelta_obj): - """Format a timedelta object to str. - - This function seems unnecessarily complicated, but unfortunately the datetime library does not - support this formatting for timedeltas. Returns the format HH:MM:SS. - - Args: - timedelta_obj (datetime.timedelta): Timedelta object to format - - Returns: - str: String of the timedelta object - """ - # Time in seconds - time_in_seconds = int(timedelta_obj.total_seconds()) - (minutes, seconds) = divmod(time_in_seconds, 60) - (hours, minutes) = divmod(minutes, 60) - return f"{hours:02}:{minutes:02}:{seconds:02}" - - -class Cluster(Dask): - """Cluster scheduler for QUEENS.""" +class Cluster(_BaseCluster): + """Cluster (remote) scheduler for QUEENS.""" @log_init_args def __init__( @@ -91,7 +52,7 @@ def __init__( overwrite_existing_experiment=False, job_script_prologue=None, ): - """Init method for the cluster scheduler. + """Init method for the remote cluster scheduler. The total number of cores per job is given by num_procs*num_nodes. @@ -122,37 +83,25 @@ def __init__( # sync remote source code with local state self.remote_connection.sync_remote_repository() - self.workload_manager = workload_manager - self.walltime = walltime - self.min_jobs = min_jobs - self.num_nodes = num_nodes - self.queue = queue - self.cluster_internal_address = cluster_internal_address - self.allowed_failures = allowed_failures - self.job_script_prologue = job_script_prologue - - # get the path of the experiment directory on remote host - experiment_dir = self.remote_experiment_dir( - experiment_name, experiment_base_dir, overwrite_existing_experiment - ) - - _logger.debug( - "experiment directory on %s@%s: %s", - self.remote_connection.user, - self.remote_connection.host, - experiment_dir, - ) - super().__init__( experiment_name=experiment_name, - experiment_dir=experiment_dir, + workload_manager=workload_manager, + walltime=walltime, num_jobs=num_jobs, + min_jobs=min_jobs, num_procs=num_procs, + num_nodes=num_nodes, + queue=queue, + cluster_internal_address=cluster_internal_address, restart_workers=restart_workers, + allowed_failures=allowed_failures, verbose=verbose, + experiment_base_dir=experiment_base_dir, + overwrite_existing_experiment=overwrite_existing_experiment, + job_script_prologue=job_script_prologue, ) - def remote_experiment_dir( + def _get_experiment_dir( self, experiment_name, experiment_base_dir, overwrite_existing_experiment ): """Get experiment directory on remote host. @@ -175,31 +124,13 @@ def remote_experiment_dir( return experiment_dir - def _start_cluster_and_connect_client(self): - """Start a Dask cluster and a client that connects to it. + def _start_cluster(self, dask_cluster_kwargs, dask_cluster_adapt_kwargs): + """Start a Dask cluster and connect a client on remote host. Returns: client (Client): Dask client that is connected to and submits computations to a Dask cluster. """ - # collect all settings for the dask cluster - dask_cluster_options = get_option(VALID_WORKLOAD_MANAGERS, self.workload_manager) - job_extra_directives = dask_cluster_options["job_extra_directives"]( - self.num_nodes, self.num_procs - ) - job_directives_skip = dask_cluster_options["job_directives_skip"] - if self.queue is None: - job_directives_skip.append("#SBATCH -p") - - hours, minutes, seconds = map(int, self.walltime.split(":")) - walltime_delta = timedelta(hours=hours, minutes=minutes, seconds=seconds) - - # Increase jobqueue walltime by 5 minutes to kill dask workers in time - increased_walltime = timedelta_to_str(walltime_delta + timedelta(minutes=5)) - - # dask worker lifetime = walltime - 3m +/- 2m - worker_lifetime = str(int((walltime_delta + timedelta(minutes=2)).total_seconds())) + "s" - local_port, remote_port = self.remote_connection.open_port_forwarding() local_port_dashboard, remote_port_dashboard = self.remote_connection.open_port_forwarding() @@ -210,30 +141,9 @@ def _start_cluster_and_connect_client(self): } if self.cluster_internal_address is not None: scheduler_options["contact_address"] = f"{self.cluster_internal_address}:{remote_port}" - dask_cluster_kwargs = { - "job_name": self.experiment_name, - "queue": self.queue, - "memory": "10TB", - "scheduler_options": scheduler_options, - "walltime": increased_walltime, - "log_directory": str(self.experiment_dir), - "job_directives_skip": job_directives_skip, - "job_extra_directives": [job_extra_directives], - "worker_extra_args": ["--lifetime", worker_lifetime, "--lifetime-stagger", "2m"], - "job_script_prologue": self.job_script_prologue, - # keep this hardcoded to 1, the number of threads for the mpi run is handled by - # job_extra_directives. Note that the number of workers is not the number of - # parallel simulations! - "cores": 1, - "processes": 1, - "n_workers": 1, - } - dask_cluster_adapt_kwargs = { - "minimum_jobs": self.min_jobs, - "maximum_jobs": self.num_jobs, - } - # actually start the dask cluster on remote host + dask_cluster_kwargs["scheduler_options"] = scheduler_options + stdout, stderr = self.remote_connection.start_cluster( self.workload_manager, dask_cluster_kwargs, @@ -255,26 +165,7 @@ def _start_cluster_and_connect_client(self): ) from exc time.sleep(1) - _logger.debug("Submitting dummy job to check basic functionality of client.") - client.submit(lambda: "Dummy job").result(timeout=180) - _logger.debug("Dummy job was successful.") - _logger.info( - "To view the Dask dashboard open this link in your browser: " - "http://localhost:%i/status", - local_port_dashboard, - ) - return client - - def restart_worker(self, worker): - """Restart a worker. - - This method retires a dask worker. The Client.adapt method of dask takes cares of submitting - new workers subsequently. - - Args: - worker (str, tuple): Worker to restart. This can be a worker address, name, or a both. - """ - self.client.retire_workers(workers=list(worker)) + return client, local_port_dashboard def copy_files_to_experiment_dir(self, paths): """Copy file to experiment directory. @@ -316,7 +207,3 @@ def copy_files_from_experiment_dir( self.remote_connection.copy_from_remote( self.experiment_dir, destination, verbose, exclude, filters ) - - @staticmethod - def delete_experiment_dir_if_empty(_): - """The remote experiment directory will never be empty, so pass.""" diff --git a/src/queens/schedulers/cluster_local.py b/src/queens/schedulers/cluster_local.py new file mode 100644 index 000000000..999b4952d --- /dev/null +++ b/src/queens/schedulers/cluster_local.py @@ -0,0 +1,99 @@ +# +# SPDX-License-Identifier: LGPL-3.0-or-later +# Copyright (c) 2024-2025, QUEENS contributors. +# +# This file is part of QUEENS. +# +# QUEENS is free software: you can redistribute it and/or modify it under the terms of the GNU +# Lesser General Public License as published by the Free Software Foundation, either version 3 of +# the License, or (at your option) any later version. QUEENS is distributed in the hope that it will +# be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You +# should have received a copy of the GNU Lesser General Public License along with QUEENS. If not, +# see . +# +"""Cluster scheduler for QUEENS runs.""" + +import logging +import time + +from dask.distributed import Client + +from queens.schedulers._cluster_base import ( + VALID_WORKLOAD_MANAGERS, + _BaseCluster, + _initialize_dask_cluster, +) +from queens.schedulers._scheduler import Scheduler +from queens.utils.remote_operations import get_port +from queens.utils.valid_options import get_option + +_logger = logging.getLogger(__name__) + + +class ClusterLocal(_BaseCluster): + """Cluster (local) scheduler for QUEENS. + + Can be used to schedule jobs to a cluster scheduler with local + access i.e. without a network connection. + """ + + def _get_experiment_dir( + self, experiment_name, experiment_base_dir, overwrite_existing_experiment + ): + """Get local experiment directory.""" + return Scheduler.local_experiment_dir( + self, experiment_name, experiment_base_dir, overwrite_existing_experiment + ) + + def _start_cluster(self, dask_cluster_kwargs, dask_cluster_adapt_kwargs): + """Start a Dask cluster and connect a client locally.""" + # collect all settings for the dask cluster + dask_cluster_options = get_option(VALID_WORKLOAD_MANAGERS, self.workload_manager) + dask_cluster_cls = dask_cluster_options["dask_cluster_cls"] + + remote_port = get_port() + local_port_dashboard = get_port() + remote_port_dashboard = get_port() + + scheduler_options = { + "port": remote_port, + "dashboard_address": remote_port_dashboard, + "allowed_failures": self.allowed_failures, + } + if self.cluster_internal_address: + scheduler_options["contact_address"] = f"{self.cluster_internal_address}:{remote_port}" + + dask_cluster_kwargs["scheduler_options"] = scheduler_options + + try: + cluster = _initialize_dask_cluster( # pylint: disable=duplicate-code + _logger, + dask_cluster_cls, + dask_cluster_kwargs, + dask_cluster_adapt_kwargs, + self.experiment_dir, + ) + except Exception as e: + raise RuntimeError() from e + + for i in range(20, 0, -1): # 20 tries to connect + _logger.debug("Trying to connect to Dask Cluster: try #%d", i) + try: + client = Client(cluster) + break + except OSError as exc: + if i == 1: + raise OSError() from exc + time.sleep(1) + + return client, local_port_dashboard + + def copy_files_to_experiment_dir(self, paths): + """Copy file to experiment directory. + + Args: + paths (Path, list): paths to files or directories that should be copied to experiment + directory + """ + return Scheduler.copy_files_to_experiment_dir(self, paths) diff --git a/src/queens/utils/start_dask_cluster.py b/src/queens/utils/start_dask_cluster.py index b9105b8f8..06238a649 100644 --- a/src/queens/utils/start_dask_cluster.py +++ b/src/queens/utils/start_dask_cluster.py @@ -22,7 +22,7 @@ from pathlib import Path from typing import Sequence -from queens.schedulers.cluster import VALID_WORKLOAD_MANAGERS +from queens.schedulers._cluster_base import VALID_WORKLOAD_MANAGERS, _initialize_dask_cluster from queens.utils.logger_settings import setup_basic_logging from queens.utils.valid_options import get_option @@ -81,24 +81,13 @@ def parse_arguments(unparsed_args: Sequence[str]) -> argparse.Namespace: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - _logger.info("Starting dask cluster of type: %s", dask_cluster_cls) - _logger.debug("Dask cluster kwargs:") - _logger.debug(dask_cluster_kwargs) - cluster = dask_cluster_cls(**dask_cluster_kwargs) - - _logger.info("Adapting dask cluster settings") - _logger.debug("Dask cluster adapt kwargs:") - _logger.debug(dask_cluster_adapt_kwargs) - cluster.adapt(**dask_cluster_adapt_kwargs) - - _logger.info("Dask cluster info:") - _logger.info(cluster) - - dask_jobscript = experiment_dir / "dask_jobscript.sh" - _logger.info("Writing dask jobscript to:") - _logger.info(dask_jobscript) - dask_jobscript.write_text(str(cluster.job_script())) - + cluster = _initialize_dask_cluster( + _logger, + dask_cluster_cls, + dask_cluster_kwargs, + dask_cluster_adapt_kwargs, + experiment_dir, + ) loop.run_forever() except KeyboardInterrupt: _logger.info("Caught keyboard interrupt")