Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
8fac8e9
Initial ox_state calc directory
Jan 28, 2026
81be364
Test calc script for aqueous iron chloride MD
Feb 5, 2026
44513fe
Core setup of Fe_oxidation_states benchmarks - fix to remove unecessa…
Feb 6, 2026
9cc0d98
Fix metrics, highlight ref on the plots, fix analysis, add rdf calcul…
Feb 13, 2026
f7a82ef
Apply suggestion from @ElliottKasoar
PKourtis Feb 16, 2026
c4bcb21
Applied PR suggestion on highighted_range for the plot scatter decora…
Feb 16, 2026
7adf1dd
Revert pyproject.toml to upstream version
Feb 18, 2026
3eef358
Cleaned up model declaration in the app.py
PKourtis Feb 19, 2026
5efad90
Each model has its own directory within outputs and the rdf tests get…
PKourtis Feb 19, 2026
996a551
Updated metrics level of theory to Experimental
PKourtis Feb 19, 2026
60469b6
Updated analysis to match outputs/model_name data directory pattern
PKourtis Feb 19, 2026
336ddb7
Added download from S3 bucket function for the input data
PKourtis Feb 19, 2026
fa78ba4
Added yes/no units
PKourtis Feb 19, 2026
9b0bed8
Fixed plot_scatter highlighted range title and plot title mixup
PKourtis Feb 19, 2026
9ac8f24
Download MD starting structures from S3 bucket and save outputs in th…
PKourtis Feb 19, 2026
bf5e038
Added model name to the scatter title
PKourtis Feb 19, 2026
52c17ff
Apply pre-commit
ElliottKasoar Mar 11, 2026
f8a4a87
Delete calc data
ElliottKasoar Mar 11, 2026
d12f970
Remove output files
ElliottKasoar Mar 12, 2026
0a23d84
Add docs link
ElliottKasoar Mar 12, 2026
c81facd
Changed dynamics function to NVT for clarity and added Elliot's minor…
PKourtis Apr 16, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Analyse aqueous Iron Chloride oxidation states."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

from ml_peg.analysis.utils.decorators import build_table, plot_scatter
from ml_peg.analysis.utils.utils import load_metrics_config
from ml_peg.app import APP_ROOT
from ml_peg.calcs import CALCS_ROOT
from ml_peg.models.get_models import get_model_names
from ml_peg.models.models import current_models

MODELS = get_model_names(current_models)

CALC_PATH = CALCS_ROOT / "physicality" / "oxidation_states" / "outputs"
OUT_PATH = APP_ROOT / "data" / "physicality" / "oxidation_states"

METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml")
DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, _ = load_metrics_config(METRICS_CONFIG_PATH)

IRON_SALTS = ["Fe2Cl", "Fe3Cl"]
TESTS = ["Fe-O RDF Peak Split", "Peak Within Experimental Ref"]
REF_PEAK_RANGE = {
"Fe<sup>+2</sup><br>Ref": [2.0, 2.2],
"Fe<sup>+3</sup><br>Ref": [1.9, 2.0],
}


def get_rdf_results(
model: str,
) -> dict[str, tuple[list[float], list[float]]]:
"""
Get a model's Fe-O RDFs for the aqueous Fe2Cl and Fe3Cl MD.

Parameters
----------
model
Name of MLIP.

Returns
-------
results
RDF Radii and intensities for the aqueous Fe2Cl and Fe3Cl systems.
"""
results = {salt: [] for salt in IRON_SALTS}

model_calc_path = CALC_PATH / model

for salt in IRON_SALTS:
rdf_file = model_calc_path / f"O-Fe_{salt}_{model}.rdf"

fe_o_rdf = np.loadtxt(rdf_file)
r = list(fe_o_rdf[:, 0])
g_r = list(fe_o_rdf[:, 1])

results[salt].append(r)
results[salt].append(g_r)

return results


def plot_rdfs(model: str, results: dict[str, tuple[list[float], list[float]]]) -> None:
"""
Plot Fe-O RDFs.

Parameters
----------
model
Name of MLIP.
results
RDF Radii and intensities for the aqueous Fe2Cl and Fe3Cl systems.
"""

@plot_scatter(
filename=OUT_PATH / f"Fe-O_{model}_RDF_scatter.json",
title=f"<b>{model} MD</b>",
x_label="r [Å]",
y_label="Fe-O G(r)",
show_line=True,
show_markers=False,
highlight_range=REF_PEAK_RANGE,
)
def plot_result() -> dict[str, tuple[list[float], list[float]]]:
"""
Plot the RDFs.

Returns
-------
model_results
Dictionary of model Fe-O RDFs for the aqueous Fe2Cl and Fe3Cl systems.
"""
return results

plot_result()


@pytest.fixture
def get_oxidation_states_passfail() -> dict[str, dict]:
"""
Test whether model RDF peaks are split and they fall within the reference range.

Returns
-------
oxidation_states_passfail
Dictionary of pass fail per model.
"""
oxidation_state_passfail = {test: {} for test in TESTS}

fe_2_ref = [2.0, 2.2]
fe_3_ref = [1.9, 2.0]

for model in MODELS:
peak_position = {}
results = get_rdf_results(model)
plot_rdfs(model, results)

for salt in IRON_SALTS:
r = results[salt][0]
g_r = results[salt][1]
peak_position[salt] = r[g_r.index(max(g_r))]

peak_difference = abs(peak_position["Fe2Cl"] - peak_position["Fe3Cl"])

oxidation_state_passfail["Fe-O RDF Peak Split"][model] = 0.0
oxidation_state_passfail["Peak Within Experimental Ref"][model] = 0.0

if peak_difference > 0.07:
oxidation_state_passfail["Fe-O RDF Peak Split"][model] = 1.0

if fe_2_ref[0] <= peak_position["Fe2Cl"] <= fe_2_ref[1]:
oxidation_state_passfail["Peak Within Experimental Ref"][model] += 0.5

if fe_3_ref[0] <= peak_position["Fe3Cl"] <= fe_3_ref[1]:
oxidation_state_passfail["Peak Within Experimental Ref"][model] += 0.5

return oxidation_state_passfail


@pytest.fixture
@build_table(
filename=OUT_PATH / "oxidation_states_table.json",
metric_tooltips=DEFAULT_TOOLTIPS,
thresholds=DEFAULT_THRESHOLDS,
)
def oxidation_states_passfail_metrics(
get_oxidation_states_passfail: dict[str, dict],
) -> dict[str, dict]:
"""
Get all oxidation states pass fail metrics.

Parameters
----------
get_oxidation_states_passfail
Dictionary of pass fail per model.

Returns
-------
dict[str, dict]
Dictionary of pass fail per model.
"""
return get_oxidation_states_passfail


def test_oxidation_states_passfail_metrics(
oxidation_states_passfail_metrics: dict[str, dict],
) -> None:
"""
Run oxidation states test.

Parameters
----------
oxidation_states_passfail_metrics
All oxidation states pass fail.
"""
return
13 changes: 13 additions & 0 deletions ml_peg/analysis/physicality/oxidation_states/metrics.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
metrics:
Fe-O RDF Peak Split:
Comment thread
PKourtis marked this conversation as resolved.
good: 1.0
bad: 0.0
unit: Yes(1)/No(0)
tooltip: Whether there is a split between Fe-O RDF peaks for different iron oxidation states
level_of_theory: Experimental
Peak Within Experimental Ref:
good: 1.0
bad: 0.0
unit: Yes(1)/No(0)
tooltip: Whether the RDF peak positions match experimental peaks
level_of_theory: Experimental
29 changes: 28 additions & 1 deletion ml_peg/analysis/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dash import dash_table
import numpy as np
import pandas as pd
import plotly.colors as pc
import plotly.graph_objects as go

from ml_peg.analysis.utils.utils import (
Expand Down Expand Up @@ -436,8 +437,10 @@ def plot_scatter(
x_label: str | None = None,
y_label: str | None = None,
show_line: bool = False,
show_markers: bool = True,
hoverdata: dict | None = None,
filename: str = "scatter.json",
highlight_range: dict = None,
) -> Callable:
"""
Plot scatter plot of MLIP results.
Expand All @@ -452,10 +455,14 @@ def plot_scatter(
Label for y-axis. Default is `None`.
show_line
Whether to show line between points. Default is False.
show_markers
Whether to show markers on the plot. Default is True.
hoverdata
Hover data dictionary. Default is `{}`.
filename
Filename to save plot as JSON. Default is "scatter.json".
highlight_range
Dictionary of rectangle title and x-axis endpoints.

Returns
-------
Expand Down Expand Up @@ -504,7 +511,13 @@ def plot_scatter_wrapper(*args, **kwargs) -> dict[str, Any]:
hovertemplate += f"<b>{key}: </b>%{{customdata[{i}]}}<br>"
customdata = list(zip(*hoverdata.values(), strict=True))

mode = "lines+markers" if show_line else "markers"
modes = []
if show_line:
modes.append("lines")
if show_markers:
modes.append("markers")

mode = "+".join(modes)

fig = go.Figure()
for mlip, value in results.items():
Expand All @@ -520,6 +533,20 @@ def plot_scatter_wrapper(*args, **kwargs) -> dict[str, Any]:
)
)

colors = pc.qualitative.Plotly

if highlight_range:
for i, (h_text, range) in enumerate(highlight_range.items()):
fig.add_vrect(
x0=range[0],
x1=range[1],
annotation_text=h_text,
annotation_position="top",
fillcolor=colors[i],
opacity=0.25,
line_width=0,
)

fig.update_layout(
title={"text": title},
xaxis={"title": {"text": x_label}},
Expand Down
90 changes: 90 additions & 0 deletions ml_peg/app/physicality/oxidation_states/app_oxidation_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Run oxidation states app."""

from __future__ import annotations

from dash import Dash
from dash.html import Div

from ml_peg.app import APP_ROOT
from ml_peg.app.base_app import BaseApp
from ml_peg.app.utils.build_callbacks import (
plot_from_table_cell,
)
from ml_peg.app.utils.load import read_plot
from ml_peg.calcs import CALCS_ROOT
from ml_peg.models.get_models import get_model_names
from ml_peg.models.models import current_models

MODELS = get_model_names(current_models)

BENCHMARK_NAME = "Iron Oxidation States"
DOCS_URL = "https://ddmms.github.io/ml-peg/user_guide/benchmarks/physicality.html#oxidation-states"
DATA_PATH = APP_ROOT / "data" / "physicality" / "oxidation_states"
REF_PATH = CALCS_ROOT / "physicality" / "oxidation_states" / "data"


class FeOxidationStatesApp(BaseApp):
"""Fe Oxidation States benchmark app layout and callbacks."""

def register_callbacks(self) -> None:
"""Register callbacks to app."""
scatter_plots = {
model: {
"Fe-O RDF Peak Split": read_plot(
DATA_PATH / f"Fe-O_{model}_RDF_scatter.json",
id=f"{BENCHMARK_NAME}-{model}-figure-Fe-O-RDF",
),
"Peak Within Experimental Ref": read_plot(
DATA_PATH / f"Fe-O_{model}_RDF_scatter.json",
id=f"{BENCHMARK_NAME}-{model}-figure-Fe-O-RDF",
),
}
for model in MODELS
}

plot_from_table_cell(
table_id=self.table_id,
plot_id=f"{BENCHMARK_NAME}-figure-placeholder",
cell_to_plot=scatter_plots,
)


def get_app() -> FeOxidationStatesApp:
"""
Get Fe Oxidation States benchmark app layout and callback registration.

Returns
-------
FeOxidationStatesApp
Benchmark layout and callback registration.
"""
return FeOxidationStatesApp(
name=BENCHMARK_NAME,
description=(
"Evaluate model ability to capture different oxidation states of Fe"
"from aqueous Fe 2Cl and Fe 3Cl MD RDFs"
),
docs_url=DOCS_URL,
table_path=DATA_PATH / "oxidation_states_table.json",
extra_components=[
Div(id=f"{BENCHMARK_NAME}-figure-placeholder"),
Div(id=f"{BENCHMARK_NAME}-struct-placeholder"),
],
)


if __name__ == "__main__":
# Create Dash app
full_app = Dash(
__name__,
assets_folder=DATA_PATH.parent.parent,
suppress_callback_exceptions=True,
)

# Construct layout and register callbacks
FeOxidationStatesApp = get_app()
full_app.layout = FeOxidationStatesApp.layout
FeOxidationStatesApp.register_callbacks()

# Run app
full_app.run(port=8054, debug=True)
Loading
Loading