Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 55 additions & 4 deletions janus_core/calculations/geom_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from __future__ import annotations

from collections.abc import Callable
import inspect
from pathlib import Path
from typing import Any
import warnings

from ase import Atoms, filters, units
from ase import Atoms, constraints, filters, units
from ase.filters import FrechetCellFilter
from ase.io import read
import ase.optimize
Expand Down Expand Up @@ -76,6 +77,10 @@ class GeomOpt(BaseCalculation):
Default is `FrechetCellFilter`.
filter_kwargs
Keyword arguments to pass to filter_class. Default is {}.
constraint_class
Constraint class, or name of class from ase.constraints. Default is None.
constraint_kwargs
Keyword arguments to pass to constraint_class. Default is {}.
optimizer
Optimization function, or name of function from ase.optimize. Default is
`LBFGS`.
Expand Down Expand Up @@ -113,6 +118,8 @@ def __init__(
angle_tolerance: float = -1.0,
filter_class: Callable | str | None = FrechetCellFilter,
filter_kwargs: dict[str, Any] | None = None,
constraint_class: type | str | None = None,
constraint_kwargs: dict[str, Any] | None = None,
optimizer: Callable | str = LBFGS,
opt_kwargs: ASEOptArgs | None = None,
write_results: bool = False,
Expand Down Expand Up @@ -167,6 +174,11 @@ def __init__(
Default is `FrechetCellFilter`.
filter_kwargs
Keyword arguments to pass to filter_class. Default is {}.
constraint_class
Constraint class, or name of class from ase.constraints. Default is
None.
constraint_kwargs
Keyword arguments to construct constraint_class. Default is {}.
optimizer
Optimization function, or name of function from ase.optimize. Default is
`LBFGS`.
Expand All @@ -184,9 +196,21 @@ def __init__(
"filename" keyword is inferred from `file_prefix` if not given.
Default is {}.
"""
read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs = (
(
read_kwargs,
constraint_kwargs,
filter_kwargs,
opt_kwargs,
write_kwargs,
traj_kwargs,
) = list(
none_to_dict(
read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs
read_kwargs,
constraint_kwargs,
filter_kwargs,
opt_kwargs,
write_kwargs,
traj_kwargs,
)
)

Expand All @@ -197,6 +221,8 @@ def __init__(
self.angle_tolerance = angle_tolerance
self.filter_class = filter_class
self.filter_kwargs = filter_kwargs
self.constraint_class = constraint_class
self.constraint_kwargs = constraint_kwargs
self.optimizer = optimizer
self.opt_kwargs = opt_kwargs
self.write_results = write_results
Expand Down Expand Up @@ -277,12 +303,29 @@ def output_files(self) -> None:
"trajectory": self.traj_kwargs.get("filename"),
}

def _set_mandatory_constraint_kwargs(self) -> None:
"""
Inspect constraint class for mandatory arguments.

For now we are just looking for the "atoms" parameter of FixSymmetry
"""
parameters = inspect.signature(self.constraint_class.__init__).parameters
if "atoms" in parameters:
self.constraint_kwargs["atoms"] = self.struct

def set_optimizer(self) -> None:
"""Set optimizer for geometry optimization."""
self._set_functions()
if self.logger:
self.logger.info("Using optimizer: %s", self.optimizer.__name__)

if self.constraint_class is not None:
self._set_mandatory_constraint_kwargs()
self.struct.set_constraint(self.constraint_class(**self.constraint_kwargs))

if self.logger:
self.logger.info("Using constraint: %s", self.constraint_class.__name__)

if self.filter_class is not None:
if "scalar_pressure" in self.filter_kwargs:
self.filter_kwargs["scalar_pressure"] *= units.GPa
Expand All @@ -308,13 +351,21 @@ def set_optimizer(self) -> None:
self.dyn = self.optimizer(self.struct, **self.opt_kwargs)

def _set_functions(self) -> None:
"""Set optimizer and filter."""
"""Set optimizer, constraint and filter functions."""
if isinstance(self.optimizer, str):
try:
self.optimizer = getattr(ase.optimize, self.optimizer)
except AttributeError as e:
raise AttributeError(f"No such optimizer: {self.optimizer}") from e

if self.constraint_class is not None and isinstance(self.constraint_class, str):
try:
self.constraint_class = getattr(constraints, self.constraint_class)
except AttributeError as e:
raise AttributeError(
f"No such constraint: {self.constraint_class}"
) from e

if self.filter_class is not None and isinstance(self.filter_class, str):
try:
self.filter_class = getattr(filters, self.filter_class)
Expand Down
14 changes: 14 additions & 0 deletions janus_core/cli/geomopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ def geomopt(
rich_help_panel="Calculation",
),
] = None,
constraint_class: Annotated[
str,
Option(
"--constraint",
help="Name of ASE constraint to attach to atoms.",
rich_help_panel="Calculation",
),
] = None,
pressure: Annotated[
float,
Option(
Expand Down Expand Up @@ -205,6 +213,10 @@ def geomopt(
filter_class
Name of filter from ase.filters to wrap around atoms. If using
--opt-cell-lengths or --opt-cell-fully, defaults to `FrechetCellFilter`.
constraint_class
Name of constraint class from ase.constraints, to apply constraints
to atoms. Parameters should be included as a "constraint_kwargs" dict
within "minimize_kwargs". Default is None.
pressure
Scalar pressure when optimizing cell geometry, in GPa. Passed to the filter
function if either `opt_cell_lengths` or `opt_cell_fully` is True. Default is
Expand Down Expand Up @@ -283,6 +295,7 @@ def geomopt(
# Check optimized structure path not duplicated
if "filename" in write_kwargs:
raise ValueError("'filename' must be passed through the --out option")

if out:
write_kwargs["filename"] = out

Expand Down Expand Up @@ -323,6 +336,7 @@ def geomopt(
"symmetrize": symmetrize,
"symmetry_tolerance": symmetry_tolerance,
"file_prefix": file_prefix,
"constraint_class": constraint_class,
**opt_cell_fully_dict,
**minimize_kwargs,
"write_results": True,
Expand Down
23 changes: 23 additions & 0 deletions tests/test_geom_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ase.filters import FrechetCellFilter, UnitCellFilter
from ase.io import read
from numpy.testing import assert_allclose
import pytest

from janus_core.calculations.geom_opt import GeomOpt
Expand Down Expand Up @@ -52,6 +53,28 @@ def test_optimize(arch, struct, expected, kwargs):
)


def test_constrained_optimize():
"""Test optimizing geometry using MACE with ASE constraint."""
single_point = SinglePoint(
struct=DATA_PATH / "H2O.cif",
arch="mace",
model=MODEL_PATH,
)

initial_positions = single_point.struct.positions.copy()

optimizer = GeomOpt(
single_point.struct,
filter_class=None, # No volume opt
constraint_class="FixAtoms",
constraint_kwargs={"indices": [2]},
)
optimizer.run()

assert_allclose(initial_positions[2], optimizer.struct.positions[2])
assert (initial_positions[:2] - optimizer.struct.positions[:2]).any()


def test_saving_struct(tmp_path):
"""Test saving optimized structure."""
results_path = tmp_path / "NaCl.extxyz"
Expand Down
33 changes: 33 additions & 0 deletions tests/test_geomopt_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,39 @@ def test_filter_str_error(tmp_path):
assert isinstance(result.exception, ValueError)


def test_constraint(tmp_path):
"""Test setting constraint."""
results_path = tmp_path / "H2O-opt.extxyz"
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"

result = runner.invoke(
app,
[
"geomopt",
"--struct",
DATA_PATH / "H2O.cif",
"--arch",
"mace_mp",
"--out",
results_path,
"--constraint",
"FixAtoms",
"--minimize-kwargs",
"{'constraint_kwargs': {'indices': [2]}}",
"--log",
log_path,
"--summary",
summary_path,
],
)
assert result.exit_code == 0
assert_log_contains(
log_path,
includes=["Starting geometry optimization", "Using constraint: FixAtoms"],
)


@pytest.mark.parametrize("read_kwargs", ["{'index': 1}", "{}"])
def test_valid_traj_input(read_kwargs, tmp_path):
"""Test valid trajectory input structure handled."""
Expand Down