From d7fe9c272396cc78d1fc19e9b9818e38185c4016 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Mon, 30 Sep 2024 10:35:33 +0000 Subject: [PATCH 1/8] Allow ASE constraint to be used in geomopt - Separate parameter from filter_func - kwargs come from minimize_args, same as filter args - Check if the constraint needs "atoms" and insert if so --- janus_core/calculations/geom_opt.py | 46 +++++++++++++++++++++++++---- janus_core/cli/geomopt.py | 14 +++++++-- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/janus_core/calculations/geom_opt.py b/janus_core/calculations/geom_opt.py index 4a5f5aeb..8099b585 100644 --- a/janus_core/calculations/geom_opt.py +++ b/janus_core/calculations/geom_opt.py @@ -3,11 +3,13 @@ from __future__ import annotations from collections.abc import Callable +import inspect from pathlib import Path -from typing import Any +from typing import Any, Callable + import warnings -from ase import Atoms, filters, units +from ase import Atoms, constraints, filters, units from ase.filters import FrechetCellFilter from ase.io import read import ase.optimize @@ -80,6 +82,10 @@ class GeomOpt(BaseCalculation): Deprecated. Please use `filter_class`. filter_kwargs Keyword arguments to pass to filter_class. Default is {}. + constraint_func + Constraint function, or name of function from ase.constraints. Default is None. + constraint_kwargs + Keyword arguments to pass to constraint_func. Default is {}. optimizer Optimization function, or name of function from ase.optimize. Default is `LBFGS`. @@ -121,6 +127,8 @@ def __init__( filter_kwargs: dict[str, Any] | None = None, optimizer: Callable | str = LBFGS, opt_kwargs: ASEOptArgs | None = None, + constraint_func: Callable | str | None = None, + constraint_kwargs: dict[str, Any] | None = None, write_results: bool = False, write_kwargs: OutputKwargs | None = None, write_traj: bool = False, @@ -194,9 +202,9 @@ def __init__( "filename" keyword is inferred from `file_prefix` if not given. Default is {}. """ - read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs = ( + read_kwargs, constraint_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs = ( none_to_dict( - read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs + (read_kwargs, constraint_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs) ) ) @@ -211,6 +219,8 @@ def __init__( self.angle_tolerance = angle_tolerance self.filter_class = filter_class self.filter_kwargs = filter_kwargs + self.constraint_func = constraint_func + self.constraint_kwargs = constraint_kwargs self.optimizer = optimizer self.opt_kwargs = opt_kwargs self.write_results = write_results @@ -301,12 +311,30 @@ def output_files(self) -> None: "trajectory": self.traj_kwargs.get("filename"), } + def _get_constraint_args(self, constraint_class: object) -> list[Any]: + """Inspect constraint class for mandatory arguments + + For now we are just looking for the "atoms" parameter of FixSymmetry + """ + parameters = inspect.signature(constraint_class.__init__).parameters + if "atoms" in parameters: + return [self.struct] + + return [] + def set_optimizer(self) -> None: """Set optimizer for geometry optimization.""" self._set_functions() if self.logger: self.logger.info("Using optimizer: %s", self.optimizer.__name__) + if self.constraint_func is not None: + constraint_args = self._get_constraint_args(self.constraint_func) + self.struct.set_constraint(self.constraint_func(*constraint_args, **self.constraint_kwargs)) + + if self.logger: + self.logger.info("Using constraint: %s", self.constraint_func.__name__) + if self.filter_class is not None: if "scalar_pressure" in self.filter_kwargs: self.filter_kwargs["scalar_pressure"] *= units.GPa @@ -332,14 +360,20 @@ def set_optimizer(self) -> None: self.dyn = self.optimizer(self.struct, **self.opt_kwargs) def _set_functions(self) -> None: - """Set optimizer and filter.""" + """Set optimizer, constraint and filter functions.""" if isinstance(self.optimizer, str): try: self.optimizer = getattr(ase.optimize, self.optimizer) except AttributeError as e: raise AttributeError(f"No such optimizer: {self.optimizer}") from e - if self.filter_class is not None and isinstance(self.filter_class, str): + if self.constraint_func is not None and isinstance(self.constraint_func, str): + try: + self.constraint_func = getattr(constraints, self.constraint_func) + except AttributeError as e: + raise AttributeError(f"No such constraint: {self.constraint_func}") from e + + if self.filter_class is not None and isinstance(self.filter_func, str): try: self.filter_class = getattr(filters, self.filter_class) except AttributeError as e: diff --git a/janus_core/cli/geomopt.py b/janus_core/cli/geomopt.py index eac3c105..3cc9f3f2 100644 --- a/janus_core/cli/geomopt.py +++ b/janus_core/cli/geomopt.py @@ -132,6 +132,12 @@ def geomopt( rich_help_panel="Calculation", callback=deprecated_option, hidden=True, + ) + ] = None, + constraint_func: Annotated[ + str, + Option( + help="Name of ASE constraint function to use." ), ] = None, pressure: Annotated[ @@ -216,8 +222,10 @@ def geomopt( filter_class Name of filter from ase.filters to wrap around atoms. If using --opt-cell-lengths or --opt-cell-fully, defaults to `FrechetCellFilter`. - filter_func - Deprecated. Please use `filter_class`. + constraint_func + Name of constraint function from ase.constraints, to apply constraints + to atoms. Parameters should be included as a "constraint_kwargs" dict + within "minimize_kwargs". Default is None pressure Scalar pressure when optimizing cell geometry, in GPa. Passed to the filter function if either `opt_cell_lengths` or `opt_cell_fully` is True. Default is @@ -298,6 +306,7 @@ def geomopt( # Check optimized structure path not duplicated if "filename" in write_kwargs: raise ValueError("'filename' must be passed through the --out option") + if out: write_kwargs["filename"] = out @@ -344,6 +353,7 @@ def geomopt( "symmetrize": symmetrize, "symmetry_tolerance": symmetry_tolerance, "file_prefix": file_prefix, + "constraint_func": constraint_func, **opt_cell_fully_dict, **minimize_kwargs, "write_results": True, From 4ce4e3ac8b013909ae47c51a5561bc9fb9b649eb Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Mon, 30 Sep 2024 10:42:33 +0000 Subject: [PATCH 2/8] Tidy up a bit --- janus_core/calculations/geom_opt.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/janus_core/calculations/geom_opt.py b/janus_core/calculations/geom_opt.py index 8099b585..e2b19377 100644 --- a/janus_core/calculations/geom_opt.py +++ b/janus_core/calculations/geom_opt.py @@ -311,16 +311,14 @@ def output_files(self) -> None: "trajectory": self.traj_kwargs.get("filename"), } - def _get_constraint_args(self, constraint_class: object) -> list[Any]: + def _set_mandatory_constraint_kwargs(self) -> None: """Inspect constraint class for mandatory arguments For now we are just looking for the "atoms" parameter of FixSymmetry """ - parameters = inspect.signature(constraint_class.__init__).parameters + parameters = inspect.signature(self.constraint_func.__init__).parameters if "atoms" in parameters: - return [self.struct] - - return [] + self.constraint_kwargs["atoms"] = self.struct def set_optimizer(self) -> None: """Set optimizer for geometry optimization.""" @@ -329,8 +327,8 @@ def set_optimizer(self) -> None: self.logger.info("Using optimizer: %s", self.optimizer.__name__) if self.constraint_func is not None: - constraint_args = self._get_constraint_args(self.constraint_func) - self.struct.set_constraint(self.constraint_func(*constraint_args, **self.constraint_kwargs)) + self._set_mandatory_constraint_kwargs() + self.struct.set_constraint(self.constraint_func(**self.constraint_kwargs)) if self.logger: self.logger.info("Using constraint: %s", self.constraint_func.__name__) From 38f9f65e69f4963d9beecf464fa44efe2a8cb91f Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Mon, 2 Jun 2025 12:38:37 +0000 Subject: [PATCH 3/8] Fix test failures after rebase --- janus_core/calculations/geom_opt.py | 39 +++++++++++++++++++++-------- janus_core/cli/geomopt.py | 8 +++--- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/janus_core/calculations/geom_opt.py b/janus_core/calculations/geom_opt.py index e2b19377..1b752d2f 100644 --- a/janus_core/calculations/geom_opt.py +++ b/janus_core/calculations/geom_opt.py @@ -5,8 +5,7 @@ from collections.abc import Callable import inspect from pathlib import Path -from typing import Any, Callable - +from typing import Any import warnings from ase import Atoms, constraints, filters, units @@ -125,10 +124,10 @@ def __init__( filter_class: Callable | str | None = FrechetCellFilter, filter_func: Callable | str | None = None, filter_kwargs: dict[str, Any] | None = None, - optimizer: Callable | str = LBFGS, - opt_kwargs: ASEOptArgs | None = None, constraint_func: Callable | str | None = None, constraint_kwargs: dict[str, Any] | None = None, + optimizer: Callable | str = LBFGS, + opt_kwargs: ASEOptArgs | None = None, write_results: bool = False, write_kwargs: OutputKwargs | None = None, write_traj: bool = False, @@ -185,6 +184,11 @@ def __init__( Deprecated. Please use `filter_class`. filter_kwargs Keyword arguments to pass to filter_class. Default is {}. + constraint_func + Constraint function, or name of function from ase.constraints. Default is + None. + constraint_kwargs + Keyword arguments to pass to constraint_func. Default is {}. optimizer Optimization function, or name of function from ase.optimize. Default is `LBFGS`. @@ -202,9 +206,21 @@ def __init__( "filename" keyword is inferred from `file_prefix` if not given. Default is {}. """ - read_kwargs, constraint_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs = ( + ( + read_kwargs, + constraint_kwargs, + filter_kwargs, + opt_kwargs, + write_kwargs, + traj_kwargs, + ) = list( none_to_dict( - (read_kwargs, constraint_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs) + read_kwargs, + constraint_kwargs, + filter_kwargs, + opt_kwargs, + write_kwargs, + traj_kwargs, ) ) @@ -220,7 +236,7 @@ def __init__( self.filter_class = filter_class self.filter_kwargs = filter_kwargs self.constraint_func = constraint_func - self.constraint_kwargs = constraint_kwargs + self.constraint_kwargs = constraint_kwargs self.optimizer = optimizer self.opt_kwargs = opt_kwargs self.write_results = write_results @@ -312,7 +328,8 @@ def output_files(self) -> None: } def _set_mandatory_constraint_kwargs(self) -> None: - """Inspect constraint class for mandatory arguments + """ + Inspect constraint class for mandatory arguments. For now we are just looking for the "atoms" parameter of FixSymmetry """ @@ -369,9 +386,11 @@ def _set_functions(self) -> None: try: self.constraint_func = getattr(constraints, self.constraint_func) except AttributeError as e: - raise AttributeError(f"No such constraint: {self.constraint_func}") from e + raise AttributeError( + f"No such constraint: {self.constraint_func}" + ) from e - if self.filter_class is not None and isinstance(self.filter_func, str): + if self.filter_class is not None and isinstance(self.filter_class, str): try: self.filter_class = getattr(filters, self.filter_class) except AttributeError as e: diff --git a/janus_core/cli/geomopt.py b/janus_core/cli/geomopt.py index 3cc9f3f2..7f84960f 100644 --- a/janus_core/cli/geomopt.py +++ b/janus_core/cli/geomopt.py @@ -136,9 +136,7 @@ def geomopt( ] = None, constraint_func: Annotated[ str, - Option( - help="Name of ASE constraint function to use." - ), + Option(help="Name of ASE constraint function to use."), ] = None, pressure: Annotated[ float, @@ -222,10 +220,12 @@ def geomopt( filter_class Name of filter from ase.filters to wrap around atoms. If using --opt-cell-lengths or --opt-cell-fully, defaults to `FrechetCellFilter`. + filter_func + Deprecated. Please use --filter. constraint_func Name of constraint function from ase.constraints, to apply constraints to atoms. Parameters should be included as a "constraint_kwargs" dict - within "minimize_kwargs". Default is None + within "minimize_kwargs". Default is None. pressure Scalar pressure when optimizing cell geometry, in GPa. Passed to the filter function if either `opt_cell_lengths` or `opt_cell_fully` is True. Default is From 79b8ef2099b6a2810edd667787b0cd1d901238ba Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Mon, 2 Jun 2025 14:30:18 +0000 Subject: [PATCH 4/8] Revert unintended docstring change. --- janus_core/cli/geomopt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/janus_core/cli/geomopt.py b/janus_core/cli/geomopt.py index 7f84960f..d7d5ac29 100644 --- a/janus_core/cli/geomopt.py +++ b/janus_core/cli/geomopt.py @@ -221,7 +221,7 @@ def geomopt( Name of filter from ase.filters to wrap around atoms. If using --opt-cell-lengths or --opt-cell-fully, defaults to `FrechetCellFilter`. filter_func - Deprecated. Please use --filter. + Deprecated. Please use `--filter_class`. constraint_func Name of constraint function from ase.constraints, to apply constraints to atoms. Parameters should be included as a "constraint_kwargs" dict From c0d500ecfcf5d57d8a5265ca4f2e3e094365ee6b Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Tue, 9 Dec 2025 16:15:09 +0000 Subject: [PATCH 5/8] constraint_func -> constraint_class, ruff format Constraints are classes, use correct name and tweak some docstrings. --- janus_core/calculations/geom_opt.py | 30 ++++++++++++++--------------- janus_core/cli/geomopt.py | 16 +++++++++------ 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/janus_core/calculations/geom_opt.py b/janus_core/calculations/geom_opt.py index 1b752d2f..633b4e33 100644 --- a/janus_core/calculations/geom_opt.py +++ b/janus_core/calculations/geom_opt.py @@ -81,10 +81,10 @@ class GeomOpt(BaseCalculation): Deprecated. Please use `filter_class`. filter_kwargs Keyword arguments to pass to filter_class. Default is {}. - constraint_func - Constraint function, or name of function from ase.constraints. Default is None. + constraint_class + Constraint class, or name of class from ase.constraints. Default is None. constraint_kwargs - Keyword arguments to pass to constraint_func. Default is {}. + Keyword arguments to pass to constraint_class. Default is {}. optimizer Optimization function, or name of function from ase.optimize. Default is `LBFGS`. @@ -124,7 +124,7 @@ def __init__( filter_class: Callable | str | None = FrechetCellFilter, filter_func: Callable | str | None = None, filter_kwargs: dict[str, Any] | None = None, - constraint_func: Callable | str | None = None, + constraint_class: type | str | None = None, constraint_kwargs: dict[str, Any] | None = None, optimizer: Callable | str = LBFGS, opt_kwargs: ASEOptArgs | None = None, @@ -184,11 +184,11 @@ def __init__( Deprecated. Please use `filter_class`. filter_kwargs Keyword arguments to pass to filter_class. Default is {}. - constraint_func - Constraint function, or name of function from ase.constraints. Default is + constraint_class + Constraint class, or name of class from ase.constraints. Default is None. constraint_kwargs - Keyword arguments to pass to constraint_func. Default is {}. + Keyword arguments to construct constraint_class. Default is {}. optimizer Optimization function, or name of function from ase.optimize. Default is `LBFGS`. @@ -235,7 +235,7 @@ def __init__( self.angle_tolerance = angle_tolerance self.filter_class = filter_class self.filter_kwargs = filter_kwargs - self.constraint_func = constraint_func + self.constraint_class = constraint_class self.constraint_kwargs = constraint_kwargs self.optimizer = optimizer self.opt_kwargs = opt_kwargs @@ -333,7 +333,7 @@ def _set_mandatory_constraint_kwargs(self) -> None: For now we are just looking for the "atoms" parameter of FixSymmetry """ - parameters = inspect.signature(self.constraint_func.__init__).parameters + parameters = inspect.signature(self.constraint_class.__init__).parameters if "atoms" in parameters: self.constraint_kwargs["atoms"] = self.struct @@ -343,12 +343,12 @@ def set_optimizer(self) -> None: if self.logger: self.logger.info("Using optimizer: %s", self.optimizer.__name__) - if self.constraint_func is not None: + if self.constraint_class is not None: self._set_mandatory_constraint_kwargs() - self.struct.set_constraint(self.constraint_func(**self.constraint_kwargs)) + self.struct.set_constraint(self.constraint_class(**self.constraint_kwargs)) if self.logger: - self.logger.info("Using constraint: %s", self.constraint_func.__name__) + self.logger.info("Using constraint: %s", self.constraint_class.__name__) if self.filter_class is not None: if "scalar_pressure" in self.filter_kwargs: @@ -382,12 +382,12 @@ def _set_functions(self) -> None: except AttributeError as e: raise AttributeError(f"No such optimizer: {self.optimizer}") from e - if self.constraint_func is not None and isinstance(self.constraint_func, str): + if self.constraint_class is not None and isinstance(self.constraint_class, str): try: - self.constraint_func = getattr(constraints, self.constraint_func) + self.constraint_class = getattr(constraints, self.constraint_class) except AttributeError as e: raise AttributeError( - f"No such constraint: {self.constraint_func}" + f"No such constraint: {self.constraint_class}" ) from e if self.filter_class is not None and isinstance(self.filter_class, str): diff --git a/janus_core/cli/geomopt.py b/janus_core/cli/geomopt.py index d7d5ac29..d920f473 100644 --- a/janus_core/cli/geomopt.py +++ b/janus_core/cli/geomopt.py @@ -132,11 +132,15 @@ def geomopt( rich_help_panel="Calculation", callback=deprecated_option, hidden=True, - ) + ), ] = None, - constraint_func: Annotated[ + constraint_class: Annotated[ str, - Option(help="Name of ASE constraint function to use."), + Option( + "--constraint", + help="Name of ASE constraint to attach to atoms.", + rich_help_panel="Calculation", + ), ] = None, pressure: Annotated[ float, @@ -222,8 +226,8 @@ def geomopt( --opt-cell-lengths or --opt-cell-fully, defaults to `FrechetCellFilter`. filter_func Deprecated. Please use `--filter_class`. - constraint_func - Name of constraint function from ase.constraints, to apply constraints + constraint_class + Name of constraint class from ase.constraints, to apply constraints to atoms. Parameters should be included as a "constraint_kwargs" dict within "minimize_kwargs". Default is None. pressure @@ -353,7 +357,7 @@ def geomopt( "symmetrize": symmetrize, "symmetry_tolerance": symmetry_tolerance, "file_prefix": file_prefix, - "constraint_func": constraint_func, + "constraint_class": constraint_class, **opt_cell_fully_dict, **minimize_kwargs, "write_results": True, From 7669d889a639b70474e080612869900d46d51862 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 10 Dec 2025 11:09:10 +0000 Subject: [PATCH 6/8] Add unit test for geom_opt python API with ASE constraint Test FixAtoms as this requires additional parameter. Make sure the fixed atoms stays where it is and the others move! --- tests/test_geom_opt.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_geom_opt.py b/tests/test_geom_opt.py index fad926ef..cf80db84 100644 --- a/tests/test_geom_opt.py +++ b/tests/test_geom_opt.py @@ -6,6 +6,7 @@ from ase.filters import FrechetCellFilter, UnitCellFilter from ase.io import read +from numpy.testing import assert_allclose import pytest from janus_core.calculations.geom_opt import GeomOpt @@ -52,6 +53,28 @@ def test_optimize(arch, struct, expected, kwargs): ) +def test_constrained_optimize(): + """Test optimizing geometry using MACE with ASE constraint.""" + single_point = SinglePoint( + struct=DATA_PATH / "H2O.cif", + arch="mace", + model=MODEL_PATH, + ) + + initial_positions = single_point.struct.positions.copy() + + optimizer = GeomOpt( + single_point.struct, + filter_class=None, # No volume opt + constraint_class="FixAtoms", + constraint_kwargs={"indices": [2]}, + ) + optimizer.run() + + assert_allclose(initial_positions[2], optimizer.struct.positions[2]) + assert (initial_positions[:2] - optimizer.struct.positions[:2]).any() + + def test_saving_struct(tmp_path): """Test saving optimized structure.""" results_path = tmp_path / "NaCl.extxyz" From a169a363b1d1d839d4b0ece2d2fb68833878b651 Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 10 Dec 2025 12:20:48 +0000 Subject: [PATCH 7/8] Add unit test for geom_opt CLI with constraint --- tests/test_geomopt_cli.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_geomopt_cli.py b/tests/test_geomopt_cli.py index 07c1249c..c2e67588 100644 --- a/tests/test_geomopt_cli.py +++ b/tests/test_geomopt_cli.py @@ -603,6 +603,39 @@ def test_filter_str_error(tmp_path): assert isinstance(result.exception, ValueError) +def test_constraint(tmp_path): + """Test setting constraint.""" + results_path = tmp_path / "H2O-opt.extxyz" + log_path = tmp_path / "test.log" + summary_path = tmp_path / "summary.yml" + + result = runner.invoke( + app, + [ + "geomopt", + "--struct", + DATA_PATH / "H2O.cif", + "--arch", + "mace_mp", + "--out", + results_path, + "--constraint", + "FixAtoms", + "--minimize-kwargs", + "{'constraint_kwargs': {'indices': [2]}}", + "--log", + log_path, + "--summary", + summary_path, + ], + ) + assert result.exit_code == 0 + assert_log_contains( + log_path, + includes=["Starting geometry optimization", "Using constraint: FixAtoms"], + ) + + @pytest.mark.parametrize("read_kwargs", ["{'index': 1}", "{}"]) def test_valid_traj_input(read_kwargs, tmp_path): """Test valid trajectory input structure handled.""" From a7f7315eeacb970540213a2343f982f9985e93ca Mon Sep 17 00:00:00 2001 From: "Adam J. Jackson" Date: Wed, 10 Dec 2025 12:35:54 +0000 Subject: [PATCH 8/8] Remove filter_func (again) for consistency with main --- janus_core/cli/geomopt.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/janus_core/cli/geomopt.py b/janus_core/cli/geomopt.py index 191c98aa..12c11d3f 100644 --- a/janus_core/cli/geomopt.py +++ b/janus_core/cli/geomopt.py @@ -124,15 +124,6 @@ def geomopt( rich_help_panel="Calculation", ), ] = None, - filter_func: Annotated[ - str | None, - Option( - help="Deprecated. Please use --filter", - rich_help_panel="Calculation", - callback=deprecated_option, - hidden=True, - ), - ] = None, constraint_class: Annotated[ str, Option( @@ -222,8 +213,6 @@ def geomopt( filter_class Name of filter from ase.filters to wrap around atoms. If using --opt-cell-lengths or --opt-cell-fully, defaults to `FrechetCellFilter`. - filter_func - Deprecated. Please use `--filter_class`. constraint_class Name of constraint class from ase.constraints, to apply constraints to atoms. Parameters should be included as a "constraint_kwargs" dict