Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests_and_linters.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ jobs:
run: |
apt-get update && apt-get install -y \
coreutils \
gcc \
git
git config --system --add safe.directory $GITHUB_WORKSPACE

Expand Down
7 changes: 7 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## Release 0.1.8

- Fixing bug in ASE simulation engine to allow for passing Periodic Boundary
Conditions via the box config value.
- Adapting setup line for zero shifts array to use `np.zeros` to guarantee correct
array shape, see issue #36 for reference.

## Release 0.1.7

- Fixing issues with Periodic Boundary Conditions (PBCs) during inference.
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "mlip"
version = "0.1.7"
version = "0.1.8"
description = "Machine Learning Interatomic Potentials in JAX"
license-files = [
"LICENSE"
Expand Down
1 change: 1 addition & 0 deletions src/mlip/data/graph_dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
self._reader = reader
self._config = dataset_config
self._dataset_info: Optional[DatasetInfo] = dataset_info
self._graphs: Optional[dict[str, list[jraph.GraphsTuple]]] = None
self._datasets: Optional[dict[str, Optional[GraphDataset]]] = None
# Sanity check when DatasetInfo is passed from the outside
cutoff = self._config.graph_cutoff_angstrom
Expand Down
2 changes: 1 addition & 1 deletion src/mlip/data/helpers/neighborhood.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_neighborhood(
)

# If we are not having PBCs, then use shifts of zero
shifts = senders_unit_shifts if any(pbc) else np.array([[0] * 3] * len(senders))
shifts = senders_unit_shifts if any(pbc) else np.zeros((len(senders), 3))

# See docstring of functions get_edge_relative_vectors() and
# get_edge_vectors() on how these return values are used
Expand Down
5 changes: 4 additions & 1 deletion src/mlip/simulation/ase/ase_simulation_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def _initialize(
def _init_box(self) -> None:
"""Update the PBC parameters of the underlying `ase.Atoms`"""
# Pass if atoms already have PBC and cell, best source of truth
if self.atoms.cell is not None and self.atoms.pbc is not None:
if np.any(self.atoms.cell) or np.any(self.atoms.pbc):
logger.warning(
"Ignoring `box` parameter as `atoms` already has PBC configured."
)
return
# Support cubic periodic box from config for Jax-MD consistency.
# To be discouraged once both engines support arbitrary lattices.
Expand Down
19 changes: 19 additions & 0 deletions tests/models_inference/test_batched_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,22 @@ def test_batched_inference_works_correctly(
assert result[0].pressure is None
assert result[0].energy == pytest.approx(-0.11254195, abs=1e-3)
assert result[0].forces[0][0] == pytest.approx(0.04921325, abs=1e-3)


def test_batched_inference_with_graph_without_edges(setup_system_and_mace_model):
atoms, _, _, mace_ff = setup_system_and_mace_model

positions = np.array([[0, 0, 0], [0, 0, 10]])
atomic_numbers = np.array([6, 6])
atoms_without_edges = ase.Atoms(positions=positions, numbers=atomic_numbers)

structures = [atoms_without_edges, atoms]
result = run_batched_inference(structures, mace_ff, batch_size=2)

# For first structure, no energy if no edges
assert result[0].energy == 0.0
assert not result[0].forces.any()

# For second structure, normal result
assert result[1].energy == pytest.approx(-0.11254195, abs=1e-3)
assert result[1].forces[0][0] == pytest.approx(0.04921325, abs=1e-3)
27 changes: 27 additions & 0 deletions tests/simulation/test_ase_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,30 @@ def test_md_can_be_restarted_from_velocities_with_ase_backend(

for i in range(1, 5):
assert not np.allclose(engine.state.velocities[i], velocities_to_restore)


@pytest.mark.parametrize("atoms_cell", [None, np.eye(3) * 10.0])
def test_ase_engine_sets_cell_from_config(
setup_system_and_mace_model, atoms_cell
) -> None:
atoms, _, _, mace_ff = setup_system_and_mace_model
config_box_length = 25.0

_atoms = deepcopy(atoms)
if atoms_cell is None: # If atoms have no cell, use config.box
_atoms.set_cell(None)
_atoms.set_pbc(None)
target_cell = np.eye(3) * config_box_length
else: # If atoms have a cell, ignore config.box
_atoms.set_cell(atoms_cell)
_atoms.set_pbc(True)
target_cell = atoms_cell

md_config = ASESimulationEngine.Config(
simulation_type=SimulationType.MD,
num_steps=1,
box=config_box_length,
)
engine = ASESimulationEngine(_atoms, mace_ff, md_config)
assert (engine.atoms.get_cell() == target_cell).all()
assert engine.atoms.get_pbc().all()
Loading