diff --git a/janus_core/calculations/geom_opt.py b/janus_core/calculations/geom_opt.py index 49a3e147..9b370780 100644 --- a/janus_core/calculations/geom_opt.py +++ b/janus_core/calculations/geom_opt.py @@ -3,11 +3,12 @@ from __future__ import annotations from collections.abc import Callable +import inspect from pathlib import Path from typing import Any import warnings -from ase import Atoms, filters, units +from ase import Atoms, constraints, filters, units from ase.filters import FrechetCellFilter from ase.io import read import ase.optimize @@ -76,6 +77,10 @@ class GeomOpt(BaseCalculation): Default is `FrechetCellFilter`. filter_kwargs Keyword arguments to pass to filter_class. Default is {}. + constraint_class + Constraint class, or name of class from ase.constraints. Default is None. + constraint_kwargs + Keyword arguments to pass to constraint_class. Default is {}. optimizer Optimization function, or name of function from ase.optimize. Default is `LBFGS`. @@ -113,6 +118,8 @@ def __init__( angle_tolerance: float = -1.0, filter_class: Callable | str | None = FrechetCellFilter, filter_kwargs: dict[str, Any] | None = None, + constraint_class: type | str | None = None, + constraint_kwargs: dict[str, Any] | None = None, optimizer: Callable | str = LBFGS, opt_kwargs: ASEOptArgs | None = None, write_results: bool = False, @@ -167,6 +174,11 @@ def __init__( Default is `FrechetCellFilter`. filter_kwargs Keyword arguments to pass to filter_class. Default is {}. + constraint_class + Constraint class, or name of class from ase.constraints. Default is + None. + constraint_kwargs + Keyword arguments to construct constraint_class. Default is {}. optimizer Optimization function, or name of function from ase.optimize. Default is `LBFGS`. @@ -184,9 +196,21 @@ def __init__( "filename" keyword is inferred from `file_prefix` if not given. Default is {}. """ - read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs = ( + ( + read_kwargs, + constraint_kwargs, + filter_kwargs, + opt_kwargs, + write_kwargs, + traj_kwargs, + ) = list( none_to_dict( - read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs + read_kwargs, + constraint_kwargs, + filter_kwargs, + opt_kwargs, + write_kwargs, + traj_kwargs, ) ) @@ -197,6 +221,8 @@ def __init__( self.angle_tolerance = angle_tolerance self.filter_class = filter_class self.filter_kwargs = filter_kwargs + self.constraint_class = constraint_class + self.constraint_kwargs = constraint_kwargs self.optimizer = optimizer self.opt_kwargs = opt_kwargs self.write_results = write_results @@ -277,12 +303,29 @@ def output_files(self) -> None: "trajectory": self.traj_kwargs.get("filename"), } + def _set_mandatory_constraint_kwargs(self) -> None: + """ + Inspect constraint class for mandatory arguments. + + For now we are just looking for the "atoms" parameter of FixSymmetry + """ + parameters = inspect.signature(self.constraint_class.__init__).parameters + if "atoms" in parameters: + self.constraint_kwargs["atoms"] = self.struct + def set_optimizer(self) -> None: """Set optimizer for geometry optimization.""" self._set_functions() if self.logger: self.logger.info("Using optimizer: %s", self.optimizer.__name__) + if self.constraint_class is not None: + self._set_mandatory_constraint_kwargs() + self.struct.set_constraint(self.constraint_class(**self.constraint_kwargs)) + + if self.logger: + self.logger.info("Using constraint: %s", self.constraint_class.__name__) + if self.filter_class is not None: if "scalar_pressure" in self.filter_kwargs: self.filter_kwargs["scalar_pressure"] *= units.GPa @@ -308,13 +351,21 @@ def set_optimizer(self) -> None: self.dyn = self.optimizer(self.struct, **self.opt_kwargs) def _set_functions(self) -> None: - """Set optimizer and filter.""" + """Set optimizer, constraint and filter functions.""" if isinstance(self.optimizer, str): try: self.optimizer = getattr(ase.optimize, self.optimizer) except AttributeError as e: raise AttributeError(f"No such optimizer: {self.optimizer}") from e + if self.constraint_class is not None and isinstance(self.constraint_class, str): + try: + self.constraint_class = getattr(constraints, self.constraint_class) + except AttributeError as e: + raise AttributeError( + f"No such constraint: {self.constraint_class}" + ) from e + if self.filter_class is not None and isinstance(self.filter_class, str): try: self.filter_class = getattr(filters, self.filter_class) diff --git a/janus_core/cli/geomopt.py b/janus_core/cli/geomopt.py index 78ad3350..12c11d3f 100644 --- a/janus_core/cli/geomopt.py +++ b/janus_core/cli/geomopt.py @@ -124,6 +124,14 @@ def geomopt( rich_help_panel="Calculation", ), ] = None, + constraint_class: Annotated[ + str, + Option( + "--constraint", + help="Name of ASE constraint to attach to atoms.", + rich_help_panel="Calculation", + ), + ] = None, pressure: Annotated[ float, Option( @@ -205,6 +213,10 @@ def geomopt( filter_class Name of filter from ase.filters to wrap around atoms. If using --opt-cell-lengths or --opt-cell-fully, defaults to `FrechetCellFilter`. + constraint_class + Name of constraint class from ase.constraints, to apply constraints + to atoms. Parameters should be included as a "constraint_kwargs" dict + within "minimize_kwargs". Default is None. pressure Scalar pressure when optimizing cell geometry, in GPa. Passed to the filter function if either `opt_cell_lengths` or `opt_cell_fully` is True. Default is @@ -283,6 +295,7 @@ def geomopt( # Check optimized structure path not duplicated if "filename" in write_kwargs: raise ValueError("'filename' must be passed through the --out option") + if out: write_kwargs["filename"] = out @@ -323,6 +336,7 @@ def geomopt( "symmetrize": symmetrize, "symmetry_tolerance": symmetry_tolerance, "file_prefix": file_prefix, + "constraint_class": constraint_class, **opt_cell_fully_dict, **minimize_kwargs, "write_results": True, diff --git a/tests/test_geom_opt.py b/tests/test_geom_opt.py index 67fa314a..8d3d6cb4 100644 --- a/tests/test_geom_opt.py +++ b/tests/test_geom_opt.py @@ -6,6 +6,7 @@ from ase.filters import FrechetCellFilter, UnitCellFilter from ase.io import read +from numpy.testing import assert_allclose import pytest from janus_core.calculations.geom_opt import GeomOpt @@ -52,6 +53,28 @@ def test_optimize(arch, struct, expected, kwargs): ) +def test_constrained_optimize(): + """Test optimizing geometry using MACE with ASE constraint.""" + single_point = SinglePoint( + struct=DATA_PATH / "H2O.cif", + arch="mace", + model=MODEL_PATH, + ) + + initial_positions = single_point.struct.positions.copy() + + optimizer = GeomOpt( + single_point.struct, + filter_class=None, # No volume opt + constraint_class="FixAtoms", + constraint_kwargs={"indices": [2]}, + ) + optimizer.run() + + assert_allclose(initial_positions[2], optimizer.struct.positions[2]) + assert (initial_positions[:2] - optimizer.struct.positions[:2]).any() + + def test_saving_struct(tmp_path): """Test saving optimized structure.""" results_path = tmp_path / "NaCl.extxyz" diff --git a/tests/test_geomopt_cli.py b/tests/test_geomopt_cli.py index 1ef21e4b..494b6754 100644 --- a/tests/test_geomopt_cli.py +++ b/tests/test_geomopt_cli.py @@ -603,6 +603,39 @@ def test_filter_str_error(tmp_path): assert isinstance(result.exception, ValueError) +def test_constraint(tmp_path): + """Test setting constraint.""" + results_path = tmp_path / "H2O-opt.extxyz" + log_path = tmp_path / "test.log" + summary_path = tmp_path / "summary.yml" + + result = runner.invoke( + app, + [ + "geomopt", + "--struct", + DATA_PATH / "H2O.cif", + "--arch", + "mace_mp", + "--out", + results_path, + "--constraint", + "FixAtoms", + "--minimize-kwargs", + "{'constraint_kwargs': {'indices': [2]}}", + "--log", + log_path, + "--summary", + summary_path, + ], + ) + assert result.exit_code == 0 + assert_log_contains( + log_path, + includes=["Starting geometry optimization", "Using constraint: FixAtoms"], + ) + + @pytest.mark.parametrize("read_kwargs", ["{'index': 1}", "{}"]) def test_valid_traj_input(read_kwargs, tmp_path): """Test valid trajectory input structure handled."""