diff --git a/docs/source/user_guide/benchmarks/electrolytes.rst b/docs/source/user_guide/benchmarks/electrolytes.rst new file mode 100644 index 000000000..950613148 --- /dev/null +++ b/docs/source/user_guide/benchmarks/electrolytes.rst @@ -0,0 +1,65 @@ +============ +Electrolytes +============ + +SSE-MD +====== + +Summary +------- + +Performance in predicting the structural dynamics of 49 solid-state electrolyte (SSE) +systems via long-timescale molecular dynamics. Systems include Li, Cs, Cu, and +Na-containing ionic conductors spanning a temperature range of 300–1300 K. Radial +distribution functions (RDFs) computed from model trajectories are compared against +*ab initio* molecular dynamics (AIMD) reference data. + + +Metrics +------- + +1. RDF Score + +Minimum RDF similarity score across all systems and element pairs. + +For each system, a 1 ns NVT molecular dynamics trajectory is generated using +a Nosé-Hoover chain thermostat at the target temperature. After discarding an +equilibration period of 5 ps, pairwise radial distribution functions (RDFs) are +computed for all unique element pair combinations. Each RDF is compared to the +corresponding AIMD reference using the normalised mean absolute error metric +described in Schran et al. (2021): + +.. math:: + + \epsilon = \frac{\sum |g_\mathrm{model}(r) - g_\mathrm{ref}(r)|}{\sum g_\mathrm{model}(r) + \sum g_\mathrm{ref}(r)} + +The per-pair score is defined as :math:`1 - \epsilon`, and the per-system score +is the minimum score across all element pairs. The reported RDF Score is the +minimum across all systems. A score of 1.0 indicates perfect agreement with the +reference data. + +* C. Schran, F. L. Thiemann, P. Rowe, E. A. Müller, O. Marsalek, A. Michaelides, + "Machine learning potentials for complex aqueous systems made simple", + Proceedings of the National Academy of Sciences 118, e2110077118 (2021). + +Computational cost +------------------ + +High: tests are likely to take hours to days to run on GPU. 
+ + +Data availability +----------------- + +Input structures: + +* 49 solid-state electrolyte structures comprising Li, Cs, Cu, and Na-containing + ionic conductors at temperatures between 300 K and 1300 K. + +Reference data: + +* AIMD reference RDFs computed with VASP using the PBE functional from + +* López, C., Rurali, R. & Cazorla, C. How Concerted Are Ionic Hops in +Inorganic Solid-State Electrolytes? J. Am. Chem. Soc. 146, 8269–8279 (2024). + diff --git a/docs/source/user_guide/benchmarks/index.rst b/docs/source/user_guide/benchmarks/index.rst index ad3c82f96..9975876a4 100644 --- a/docs/source/user_guide/benchmarks/index.rst +++ b/docs/source/user_guide/benchmarks/index.rst @@ -14,5 +14,6 @@ Benchmarks bulk_crystal lanthanides non_covalent_interactions + electrolytes tm_complexes conformers diff --git a/ml_peg/analysis/electrolytes/SSEMD/analyse_SSEMD.py b/ml_peg/analysis/electrolytes/SSEMD/analyse_SSEMD.py new file mode 100644 index 000000000..d1def8256 --- /dev/null +++ b/ml_peg/analysis/electrolytes/SSEMD/analyse_SSEMD.py @@ -0,0 +1,496 @@ +"""Analyse SSE-MD benchmark.""" + +from __future__ import annotations + +import itertools +import math +import os +from pathlib import Path +import pickle + +from ase import Atoms, io +from MDAnalysis import Universe +import numpy as np +import pytest + +from ml_peg.analysis.utils.decorators import build_table, plot_parity +from ml_peg.analysis.utils.utils import build_d3_name_map, load_metrics_config +from ml_peg.app import APP_ROOT +from ml_peg.calcs import CALCS_ROOT +from ml_peg.calcs.electrolytes.SSEMD.calc_SSEMD import ( + DELTA_T_FS, + FRAME_FREQUENCY, + N_EQUI_FRAMES, +) +from ml_peg.models.get_models import get_model_names +from ml_peg.models.models import current_models + +MODELS = get_model_names(current_models) +D3_MODEL_NAMES = build_d3_name_map(MODELS) +CALC_PATH = CALCS_ROOT / "electrolytes" / "SSEMD" / "outputs" +OUT_PATH = APP_ROOT / "electrolytes" / "SSEMD" + +METRICS_CONFIG_PATH = 
def get_rmax_from_cell(lattice_vectors: np.ndarray) -> float:
    """
    Derive the largest usable RDF cutoff for a periodic cell.

    Every non-zero integer combination of the lattice vectors with
    coefficients between -2 and 2 is examined; half the length of the
    shortest resulting translation is returned.

    Parameters
    ----------
    lattice_vectors
        3x3 array whose rows are the cell lattice vectors.

    Returns
    -------
    float
        Half the minimum lattice image distance.
    """
    shortest = np.inf
    for coeffs in itertools.product((-2, -1, 0, 1, 2), repeat=3):
        if not any(coeffs):
            # The zero translation is the cell itself, not a periodic image.
            continue
        image_length = np.linalg.norm(np.asarray(coeffs) @ lattice_vectors)
        shortest = min(shortest, image_length)
    return shortest * 0.5


def get_element_pairs(species_in_system: list) -> list:
    """
    Build the element pairs for which RDFs are computed.

    Parameters
    ----------
    species_in_system
        Sorted list of unique element symbols.

    Returns
    -------
    list
        List of ``[element_a, element_b]`` pairs with ``a <= b``.
    """
    species = tuple(species_in_system)
    return [
        [first, second]
        for first in species
        for second in species
        if first <= second
    ]
def metric_pnas(rdf_ref: dict, model_rdf: dict) -> dict[str, float]:
    """
    Compute normalised MAEs relative to reference RDF data.

    For every element pair in the reference the error is
    ``sum(|g_ref - g_model|) / (sum(g_ref) + sum(g_model))``, evaluated over
    all bins except the last one.

    Inspired by and partially taken from:
    https://github.com/MarsalekGroup/aml/blob/main/aml/score/util.py

    C. Schran, F. L. Thiemann, P. Rowe, E. A. Müller, O. Marsalek,
    A. Michaelides, "Machine learning potentials for complex aqueous
    systems made simple", PNAS 118, e2110077118 (2021),
    10.1073/pnas.2110077118

    Parameters
    ----------
    rdf_ref
        Reference RDFs ``{pair: (bins, values)}``.
    model_rdf
        Model RDFs ``{pair: (bins, values)}``.

    Returns
    -------
    dict[str, float]
        Normalised MAE per element pair. Pairs for which both RDFs sum to
        zero are reported as ``0.0``.

    Raises
    ------
    KeyError
        If a reference pair is missing from ``model_rdf``.
    ValueError
        If reference and model RDFs for a pair have different lengths.
    """
    error: dict[str, float] = {}
    for name, data in rdf_ref.items():
        # The final bin is excluded from the comparison.
        # NOTE(review): presumably because its width can differ from the
        # others -- confirm.
        ref_vals = np.asarray(data[1][:-1], dtype=float)
        mod_vals = np.asarray(model_rdf[name][1][:-1], dtype=float)

        # Fail loudly rather than letting NumPy broadcast mismatched grids.
        if ref_vals.shape != mod_vals.shape:
            raise ValueError(
                f"RDF length mismatch for pair '{name}': "
                f"reference has {ref_vals.size} bins, model has {mod_vals.size}"
            )

        denominator = np.sum(ref_vals) + np.sum(mod_vals)
        if denominator == 0.0:
            # Both curves are identically zero: treat as perfect agreement
            # instead of propagating NaN from 0 / 0.
            error[name] = 0.0
            continue

        error[name] = float(np.sum(np.abs(ref_vals - mod_vals)) / denominator)
    return error


def compute_rdf_score(g_aimd: dict, g_model: dict) -> float:
    """
    Compute RDF similarity score using PNAS metric.

    Returns the minimum ``(1 - error)`` across all element pairs for a
    single system. A score of 1.0 indicates perfect agreement.

    Parameters
    ----------
    g_aimd
        Reference RDFs for one system.
    g_model
        Model RDFs for one system.

    Returns
    -------
    float
        Minimum RDF similarity score across element pairs.

    Raises
    ------
    ValueError
        If the reference contains no element pairs.
    """
    errors = metric_pnas(g_aimd, g_model)
    if not errors:
        raise ValueError("No element pairs available to compute an RDF score")
    return float(min(1.0 - err for err in errors.values()))
def compute_model_rdfs(model_name: str) -> dict[str, dict]:
    """
    Compute RDFs from a model's MD trajectory outputs.

    Reads the saved ``.traj`` files produced by ``calc_SSEMD.py``, discards
    the first ``N_EQUI_FRAMES`` frames of each trajectory, and computes RDFs
    for every element pair in each system. Trajectories with no frames left
    after discarding the equilibration frames are skipped.

    Parameters
    ----------
    model_name
        Name of the MLIP model.

    Returns
    -------
    dict[str, dict]
        Mapping of ``system_name -> {pair_label: (bins, rdf_values)}``.
        Empty if the model has no output directory.
    """
    model_dir = CALC_PATH / model_name
    if not model_dir.exists():
        return {}

    # Spacing between saved frames; identical for every trajectory
    time_between_frames = DELTA_T_FS * FRAME_FREQUENCY

    system_rdfs: dict[str, dict] = {}
    for traj_file in sorted(model_dir.glob("*.traj")):
        # Strip model name suffix to recover system name
        system_name = traj_file.stem.removesuffix(f"_{model_name}")

        # Read trajectory, skipping the equilibration frames
        ase_traj = io.read(str(traj_file), index=f"{N_EQUI_FRAMES}:")

        # A run that stopped during equilibration leaves nothing to analyse;
        # skip it rather than failing on ``ase_traj[0]`` below.
        if not ase_traj:
            continue

        mda_traj = ase2mda(ase_traj, time_between_frames)
        # Cell of the first production frame is used for cutoff and wrapping
        cell = np.array(ase_traj[0].cell)
        rmax = get_rmax_from_cell(cell)
        element_pairs = get_element_pairs(
            sorted(set(ase_traj[0].get_chemical_symbols()))
        )
        system_rdfs[system_name] = compute_rdfs_all(
            mda_traj, rmax, element_pairs, BIN_SIZE, cell
        )

    return system_rdfs
@pytest.fixture
def ssemd_errors(rdf_scores: dict[str, list]) -> dict[str, float | None]:
    """
    Average the per-system RDF scores of every model.

    Systems without a score (``None`` entries) are ignored when averaging.

    Parameters
    ----------
    rdf_scores
        Per-system RDF scores for every model.

    Returns
    -------
    dict[str, float | None]
        Mean RDF score per model, or ``None`` if no data available.

    Notes
    -----
    NOTE(review): the ``metrics.yml`` tooltip and the user guide describe the
    RDF Score as the minimum across systems, whereas the mean is computed
    here -- confirm which aggregation is intended.
    """

    def _mean_or_none(scores):
        """Return the mean of the non-``None`` scores, or ``None`` if empty."""
        available = [score for score in scores if score is not None]
        return float(np.mean(available)) if available else None

    return {
        model_name: _mean_or_none(rdf_scores.get(model_name, []))
        for model_name in MODELS
    }
class SSEMDApp(BaseApp):
    """Layout and callback wiring for the SSE-MD benchmark app."""

    def register_callbacks(self) -> None:
        """Load the saved score figure and link it to the metrics table."""
        figure_path = DATA_PATH / "figure_ssemd_scores.json"
        scatter = read_plot(figure_path, id=f"{BENCHMARK_NAME}-figure")

        # Selecting the "RDF Score" column fills the placeholder with the plot
        column_plots = {"RDF Score": scatter}
        plot_from_table_column(
            table_id=self.table_id,
            plot_id=f"{BENCHMARK_NAME}-figure-placeholder",
            column_to_plot=column_plots,
        )


def get_app() -> SSEMDApp:
    """
    Build the SSE-MD benchmark app.

    Returns
    -------
    SSEMDApp
        Benchmark layout and callback registration.
    """
    description = (
        "RDF similarity scores for solid-state electrolyte systems, "
        "comparing MLIP MD trajectories against AIMD reference data."
    )
    # Empty container that the table callback populates with the figure
    figure_placeholder = Div(id=f"{BENCHMARK_NAME}-figure-placeholder")

    return SSEMDApp(
        name=BENCHMARK_NAME,
        description=description,
        docs_url=DOCS_URL,
        table_path=DATA_PATH / "ssemd_metrics_table.json",
        extra_components=[figure_placeholder],
    )


if __name__ == "__main__":
    # Stand-alone entry point: serve only this benchmark
    full_app = Dash(__name__, assets_folder=DATA_PATH.parent.parent)

    # Build the layout first, then hook up its callbacks
    ssemd_app = get_app()
    full_app.layout = ssemd_app.layout
    ssemd_app.register_callbacks()

    full_app.run(port=8056, debug=True)
def get_systems(data_dir: Path) -> Generator[tuple[Path, float, str], None, None]:
    """
    Discover all SSE RDF systems from the extracted data directory.

    Walks the directory tree looking for POSCAR files under the structure
    ``{system}/stoichiometric/{temperature}K/POSCAR``.

    Parameters
    ----------
    data_dir
        Path to the top-level SSEs_data directory.

    Yields
    ------
    tuple[Path, float, str]
        ``(poscar_dir, temperature, system_name)`` for each system, in sorted
        path order.
    """
    for poscar_file in sorted(data_dir.rglob(pattern="POSCAR")):
        temp_dir: Path = poscar_file.parent
        compound_dir: Path = temp_dir.parent.parent

        # Directory name is "<temperature>K"
        temperature = float(temp_dir.name.rstrip("K"))
        system_name = f"{compound_dir.name}_{temp_dir.parent.name}_{temp_dir.name}"
        yield temp_dir, temperature, system_name


@pytest.mark.parametrize(argnames="mlip", argvalues=MODELS.items())
def test_ssemd_benchmark(mlip: tuple[str, Any]) -> None:
    """
    Run SSE RDF benchmark test.

    Runs NVT molecular dynamics using a Nosé-Hoover chain thermostat
    for each system, writing a ``.log`` and ``.traj`` file per system to
    ``OUT_PATH / model_name``.

    Parameters
    ----------
    mlip
        Name of model and model to get calculator.
    """
    # Imported here because the module-level import is commented out; without
    # it the ``download_s3_data`` call below raises NameError.
    from ml_peg.calcs.utils.utils import download_s3_data

    model_name, model = mlip
    calc: Calculator = model.get_calculator()

    timestep: float = DELTA_T_FS * units.fs
    tdamp: float = 100 * timestep

    data_dir = (
        download_s3_data(
            key="inputs/electrolytes/SSE/SSE.zip",
            filename="SSE.zip",
        )
        / "SSE"
    )

    # Write output directory (shared by every system of this model)
    write_dir: Path = OUT_PATH / model_name
    write_dir.mkdir(parents=True, exist_ok=True)

    # TODO: Check if it is possible to parallelize over systems
    for poscar_dir, temperature, system_name in get_systems(data_dir=data_dir):
        poscar_file: Path = poscar_dir / "POSCAR"
        atoms_initial: Atoms | list[Atoms] = io.read(
            filename=poscar_file, format="vasp"
        )

        atoms: Atoms = atoms_initial.copy()  # type: ignore[assignment]
        atoms.calc = copy(calc)

        # Store metadata up front so it is attached to the atoms object the
        # trajectory writer sees, instead of re-reading the whole trajectory
        # afterwards and annotating frames that are then discarded.
        atoms.info["system"] = system_name
        atoms.info["temperature"] = temperature
        atoms.info["delta_t"] = DELTA_T_FS
        atoms.info["nsteps"] = NSTEPS

        # Same seed for every system so initial velocities are reproducible
        rng = np.random.RandomState(seed=SEED)
        MaxwellBoltzmannDistribution(
            atoms, temperature_K=temperature, force_temp=True, rng=rng
        )
        Stationary(atoms)
        ZeroRotation(atoms)

        file_name = f"{system_name}_{model_name}"
        log_path: Path = write_dir / f"{file_name}.log"
        traj_path: Path = write_dir / f"{file_name}.traj"

        md_nvt = NoseHooverChainNVT(
            atoms=atoms,
            timestep=timestep,
            temperature_K=temperature,
            tdamp=tdamp,
            tchain=TCHAIN,
            logfile=str(log_path),
        )

        # Close the writer even if the run fails so saved frames are flushed
        with Trajectory(filename=str(traj_path), mode="w", atoms=atoms) as traj:
            md_nvt.attach(function=traj.write, interval=FRAME_FREQUENCY)
            md_nvt.run(steps=NSTEPS)