diff --git a/README.md b/README.md index b3929fa..bef4265 100644 --- a/README.md +++ b/README.md @@ -188,7 +188,31 @@ Type maps that starts with `m` (such as `mH`) or `OW` or `HW` will be recognized Two MM atoms will not build edges with each other. Such GNN+DPRc model can be directly used in AmberTools24. +## MACE-OFF Pretrained Models + +DeePMD-GNN supports loading MACE-OFF foundation models for use with DeePMD-kit and MD packages: + +```python +from deepmd_gnn import load_mace_off_model, convert_mace_off_to_deepmd + +# Load MACE-OFF model into DeePMD-GNN wrapper +model = load_mace_off_model("small") # or "medium", "large" + +# Convert to frozen format for MD simulations +convert_mace_off_to_deepmd("small", "frozen_model.pth") +# Now use with LAMMPS, AMBER/sander through DeePMD-kit +``` + +The frozen model can be used with: + +- **LAMMPS**: Set `DP_PLUGIN_PATH` and use `pair_style deepmd` +- **AMBER/sander**: Through DeePMD-kit's AMBER interface (supports QM/MM with DPRc) +- **Other MD packages**: Through DeePMD-kit's C++ interface + +For QM/MM simulations, use the DPRc mechanism: QM atoms use standard symbols (H, C, O), MM atoms use 'm' prefix (mH, mC) or HW/OW. + ## Examples -- [examples/water](examples/water) -- [examples/dprc](examples/dprc) +- [examples/water](examples/water) - Basic MACE and NequIP training +- [examples/dprc](examples/dprc) - DPRc (QM/MM) with mixed atom types +- [examples/mace_off](examples/mace_off) - Using MACE-OFF pretrained models with DeePMD-kit diff --git a/deepmd_gnn/__init__.py b/deepmd_gnn/__init__.py index 3a34b90..a5fee5c 100644 --- a/deepmd_gnn/__init__.py +++ b/deepmd_gnn/__init__.py @@ -4,11 +4,19 @@ from ._version import __version__ from .argcheck import mace_model_args +from .mace_off import ( + convert_mace_off_to_deepmd, + download_mace_off_model, + load_mace_off_model, +) __email__ = "jinzhe.zeng@ustc.edu.cn" __all__ = [ "__version__", + "convert_mace_off_to_deepmd", + "download_mace_off_model", + "load_mace_off_model", "mace_model_args", ] diff --git a/deepmd_gnn/__main__.py b/deepmd_gnn/__main__.py index cbfb7a3..8f2b063 100644 --- a/deepmd_gnn/__main__.py +++ b/deepmd_gnn/__main__.py @@ -1,5 +1,99 @@ """Main entry point for the command line interface.""" +import argparse +import logging +import sys +from pathlib import Path + +from deepmd_gnn.mace_off import convert_mace_off_to_deepmd, download_mace_off_model + +# Setup logging +logging.basicConfig( + level=logging.INFO, + format="%(message)s", +) +logger = logging.getLogger(__name__) + + +def main() -> int: + """Run the main CLI.""" + parser = argparse.ArgumentParser( + description="DeePMD-GNN utilities", + prog="deepmd-gnn", + ) + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + # MACE-OFF download command + download_parser = subparsers.add_parser( + "download-mace-off", + help="Download MACE-OFF pretrained models", + ) + download_parser.add_argument( + "model", + choices=["small", "medium", "large"], + help="MACE-OFF model size to download", + ) + download_parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Directory to cache the downloaded model (default: ~/.cache/deepmd-gnn/mace-off/)", + ) + download_parser.add_argument( + "--force", + action="store_true", + help="Force re-download even if file exists", + ) + + # MACE-OFF convert command + convert_parser = subparsers.add_parser( + "convert-mace-off", + help="Convert MACE-OFF model to DeePMD format", + ) + convert_parser.add_argument( + "model", + choices=["small", "medium", "large"], + help="MACE-OFF model size to convert", + ) + convert_parser.add_argument( + "-o", + "--output", + type=str, + default="frozen_model.pth", + help="Output file name for the frozen model (default: frozen_model.pth)", + ) + convert_parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Directory where MACE-OFF models are cached", + ) + + args = parser.parse_args() + + if args.command == "download-mace-off": + cache_dir = Path(args.cache_dir) if args.cache_dir else None + model_path = download_mace_off_model( + model_name=args.model, + cache_dir=cache_dir, + force_download=args.force, + ) + logger.info("Model downloaded to: %s", model_path) + return 0 + + if args.command == "convert-mace-off": + cache_dir = Path(args.cache_dir) if args.cache_dir else None + output_path = convert_mace_off_to_deepmd( + model_name=args.model, + output_file=args.output, + cache_dir=cache_dir, + ) + logger.info("Converted model saved to: %s", output_path) + return 0 + + parser.print_help() + return 1 + + if __name__ == "__main__": - msg = "This module is not meant to be executed directly." - raise NotImplementedError(msg) + sys.exit(main()) diff --git a/deepmd_gnn/mace_off.py b/deepmd_gnn/mace_off.py new file mode 100644 index 0000000..a723b50 --- /dev/null +++ b/deepmd_gnn/mace_off.py @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Support for loading pretrained MACE models into DeePMD-GNN.""" + +import logging +import os +from pathlib import Path +from typing import Optional +from urllib.request import urlretrieve + +import torch + +from deepmd_gnn.mace import ELEMENTS, MaceModel + +# Setup logger for this module +logger = logging.getLogger(__name__) + +# URLs for MACE-OFF pretrained models +MACE_OFF_MODELS = { + "small": "https://github.com/ACEsuit/mace-off/releases/download/mace_off_small/mace_off_small.model", + "medium": "https://github.com/ACEsuit/mace-off/releases/download/mace_off_medium/mace_off_medium.model", + "large": "https://github.com/ACEsuit/mace-off/releases/download/mace_off_large/mace_off_large.model", +} + + +def get_mace_off_cache_dir() -> Path: + """Get the cache directory for MACE-OFF models. + + Uses the XDG_CACHE_HOME environment variable if set, + otherwise uses ~/.cache/deepmd-gnn/mace-off/ + + Returns + ------- + cache_dir : Path + Path to cache directory + """ + if "XDG_CACHE_HOME" in os.environ: + cache_dir = Path(os.environ["XDG_CACHE_HOME"]) / "deepmd-gnn" / "mace-off" + else: + cache_dir = Path.home() / ".cache" / "deepmd-gnn" / "mace-off" + + cache_dir.mkdir(parents=True, exist_ok=True) + return cache_dir + + +def download_mace_off_model( + model_name: str = "small", + cache_dir: Optional[Path] = None, + force_download: bool = False, +) -> Path: + """Download a MACE-OFF pretrained model. + + Parameters + ---------- + model_name : str, optional + Name of the model to download: "small", "medium", or "large" + Default is "small" + cache_dir : Path, optional + Directory to cache the downloaded model + If None, uses the default cache directory + force_download : bool, optional + If True, download even if file exists + Default is False + + Returns + ------- + model_path : Path + Path to the downloaded model file + + Raises + ------ + ValueError + If model_name is not recognized + + Examples + -------- + >>> model_path = download_mace_off_model("small") + >>> print(f"Model downloaded to: {model_path}") + """ + if model_name not in MACE_OFF_MODELS: + msg = ( + f"Unknown MACE-OFF model: {model_name}. " + f"Available models: {list(MACE_OFF_MODELS.keys())}" + ) + raise ValueError(msg) + + if cache_dir is None: + cache_dir = get_mace_off_cache_dir() + else: + cache_dir = Path(cache_dir) + cache_dir.mkdir(parents=True, exist_ok=True) + + # Determine local file path + url = MACE_OFF_MODELS[model_name] + filename = f"mace_off_{model_name}.model" + model_path = cache_dir / filename + + # Download if needed + if not model_path.exists() or force_download: + logger.info("Downloading MACE-OFF %s model...", model_name) + logger.info("URL: %s", url) + logger.info("Destination: %s", model_path) + urlretrieve(url, model_path) # noqa: S310 + logger.info("Download complete!") + else: + logger.info("Using cached model: %s", model_path) + + return model_path + + +def load_mace_off_model( + model_name: str = "small", + cache_dir: Optional[Path] = None, + device: str = "cpu", +) -> MaceModel: + """Load a MACE-OFF pretrained model as a DeePMD-GNN MaceModel. + + This function downloads a MACE-OFF pretrained model and wraps it + in the DeePMD-GNN MaceModel interface, making it compatible with + DeePMD-kit and AMBER/sander through the existing integration. + + Parameters + ---------- + model_name : str, optional + Name of the model to load: "small", "medium", or "large" + Default is "small" + cache_dir : Path, optional + Directory where models are cached + If None, uses the default cache directory + device : str, optional + Device to load the model on ("cpu" or "cuda") + Default is "cpu" + + Returns + ------- + model : MaceModel + Loaded MACE model wrapped in DeePMD-GNN's MaceModel interface, + ready for use with DeePMD-kit and MD packages + + Examples + -------- + >>> from deepmd_gnn.mace_off import load_mace_off_model + >>> model = load_mace_off_model("small") + >>> # Now use with DeePMD-kit: dp freeze, then use in LAMMPS/Amber + """ + # Download model if necessary + model_path = download_mace_off_model( + model_name=model_name, + cache_dir=cache_dir, + force_download=False, + ) + + # Load the pretrained MACE model + logger.info("Loading MACE-OFF %s model from %s...", model_name, model_path) + # Note: weights_only=False is required for MACE models as they contain + # custom objects. Only use with trusted MACE-OFF models from official sources. + mace_model = torch.load(str(model_path), map_location=device, weights_only=False) + + # Extract configuration from the pretrained model + # MACE models have attributes we need to extract + if not hasattr(mace_model, "atomic_numbers"): + msg = "Loaded model does not appear to be a valid MACE model (missing atomic_numbers)" + raise ValueError(msg) + + atomic_numbers = mace_model.atomic_numbers.tolist() + + # Validate atomic numbers are in valid range + if any(z < 1 or z > len(ELEMENTS) for z in atomic_numbers): + msg = f"Invalid atomic numbers found: {atomic_numbers}. Must be between 1 and {len(ELEMENTS)}" + raise ValueError(msg) + + # Convert atomic numbers to element symbols + type_map = [ELEMENTS[z - 1] for z in atomic_numbers] + + # Extract model hyperparameters + # These are stored in the MACE model's configuration + if not hasattr(mace_model, "r_max") or not hasattr(mace_model, "num_interactions"): + msg = "Loaded model missing required attributes (r_max, num_interactions)" + raise ValueError(msg) + + r_max = float(mace_model.r_max) + num_interactions = int(mace_model.num_interactions) + + # Helper function to get attribute with default and warning + def get_attr_with_default( + obj: object, + attr: str, + default: object, + warn: bool = True, + ) -> object: + """Get attribute from object with default value and optional warning. + + Parameters + ---------- + obj : object + Object to get attribute from + attr : str + Attribute name + default : object + Default value if attribute not found + warn : bool, optional + Whether to print warning when using default + + Returns + ------- + object + Attribute value or default + """ + value = getattr(obj, attr, None) + if value is None: + if warn: + logger.warning( + "Using default %s=%s (not found in model)", + attr, + default, + ) + return default + return value + + # Get other parameters with defaults if not available + num_radial_basis = get_attr_with_default(mace_model, "num_bessel", 8) + num_cutoff_basis = get_attr_with_default(mace_model, "num_polynomial_cutoff", 5) + max_ell = get_attr_with_default(mace_model, "max_ell", 3) + correlation = get_attr_with_default(mace_model, "correlation", 3) + radial_mlp = get_attr_with_default(mace_model, "radial_MLP", [64, 64, 64]) + + # Get hidden_irreps with validation + if hasattr(mace_model, "hidden_irreps") and mace_model.hidden_irreps is not None: + hidden_irreps = str(mace_model.hidden_irreps) + else: + hidden_irreps = "128x0e + 128x1o" + logger.warning("Using default hidden_irreps (not found in model)") + + # Determine interaction class name + interaction_cls_name = ( + mace_model.interactions[0].__class__.__name__ + if hasattr(mace_model, "interactions") and len(mace_model.interactions) > 0 + else "RealAgnosticResidualInteractionBlock" + ) + + # Create MaceModel with the extracted configuration + logger.info("Creating DeePMD-GNN MaceModel wrapper...") + logger.debug("Type map: %s", type_map) + logger.debug("r_max: %s", r_max) + logger.debug("num_interactions: %s", num_interactions) + + deepmd_model = MaceModel( + type_map=type_map, + sel=100, # This will be auto-determined during training/usage + r_max=r_max, + num_radial_basis=num_radial_basis, + num_cutoff_basis=num_cutoff_basis, + max_ell=max_ell, + interaction=interaction_cls_name, + num_interactions=num_interactions, + hidden_irreps=hidden_irreps, + correlation=correlation, + radial_MLP=radial_mlp, + ) + + # Load the pretrained weights into the DeePMD model + # The MaceModel.model is a ScaleShiftMACE instance, same as MACE-OFF + logger.info("Loading pretrained weights...") + try: + deepmd_model.model.load_state_dict(mace_model.state_dict(), strict=True) + except RuntimeError as e: + msg = f"Failed to load pretrained weights: {e}. Model architectures may not match." + raise RuntimeError(msg) from e + + deepmd_model.eval() + + logger.info("MACE-OFF model successfully loaded into DeePMD-GNN wrapper!") + logger.info("You can now use this with DeePMD-kit (dp freeze) and MD packages.") + + return deepmd_model + + +def convert_mace_off_to_deepmd( + model_name: str = "small", + output_file: str = "frozen_model.pth", + cache_dir: Optional[Path] = None, +) -> Path: + """Convert a MACE-OFF model to frozen DeePMD format for use with MD packages. + + This function loads a MACE-OFF pretrained model, wraps it in the + DeePMD-GNN MaceModel interface, and saves it as a frozen model + that can be used directly with LAMMPS, AMBER/sander, and other + MD packages through the DeePMD-kit interface. + + Parameters + ---------- + model_name : str, optional + Name of the MACE-OFF model: "small", "medium", or "large" + Default is "small" + output_file : str, optional + Output file name for the frozen model + Default is "frozen_model.pth" + cache_dir : Path, optional + Directory where MACE-OFF models are cached + + Returns + ------- + output_path : Path + Path to the frozen model file + + Notes + ----- + The frozen model can be used with: + - LAMMPS: Set DP_PLUGIN_PATH and use pair_style deepmd + - AMBER/sander: Use through DeePMD-kit's AMBER interface + - Other MD packages: Through DeePMD-kit's C++ interface + + For QM/MM simulations with sander, use MM type prefixes ('m', 'HW', 'OW') + in your type_map to designate MM atoms. + + Examples + -------- + >>> from deepmd_gnn.mace_off import convert_mace_off_to_deepmd + >>> model_path = convert_mace_off_to_deepmd("small", "mace_small.pth") + >>> # Now use in LAMMPS, AMBER, etc. through DeePMD-kit + """ + # Load the MACE-OFF model as a MaceModel + model = load_mace_off_model(model_name, cache_dir=cache_dir) + + # Create output path + output_path = Path(output_file) + + # Try to freeze the model using TorchScript + logger.info("Freezing model to %s...", output_path) + try: + scripted_model = torch.jit.script(model) + torch.jit.save(scripted_model, str(output_path)) + logger.info("Model successfully frozen and saved to: %s", output_path) + logger.info("\nYou can now use this model with:") + logger.info(" - LAMMPS: Set DP_PLUGIN_PATH and use pair_style deepmd") + logger.info(" - AMBER/sander: Use DeePMD-kit's AMBER interface") + logger.info(" - For QM/MM: Use 'm' prefix or HW/OW for MM atom types") + except (RuntimeError, torch.jit.Error) as e: + logger.warning("TorchScript compilation failed (%s)", e) + logger.warning("Saving model in PyTorch format instead...") + torch.save(model, str(output_path)) + logger.info("Model saved to: %s", output_path) + + return output_path + + +__all__ = [ + "MACE_OFF_MODELS", + "convert_mace_off_to_deepmd", + "download_mace_off_model", + "get_mace_off_cache_dir", + "load_mace_off_model", +] diff --git a/examples/mace_off/README.md b/examples/mace_off/README.md new file mode 100644 index 0000000..bc06cd2 --- /dev/null +++ b/examples/mace_off/README.md @@ -0,0 +1,106 @@ +# Using MACE-OFF Pretrained Models with DeePMD-GNN + +This example demonstrates how to load MACE-OFF foundation models and use them with DeePMD-kit for MD simulations, including QM/MM simulations with AMBER/sander. + +## Overview + +MACE-OFF models are pretrained foundation models for molecular systems. This integration allows you to: + +1. Download MACE-OFF models (small, medium, large) +2. Load them into DeePMD-GNN's MaceModel wrapper +3. Use them with MD packages (LAMMPS, AMBER/sander) through DeePMD-kit + +## Architecture + +``` +MACE-OFF pretrained model + ↓ +DeePMD-GNN MaceModel wrapper + ↓ +DeePMD-kit interface + ↓ +MD packages (LAMMPS, AMBER/sander, etc.) +``` + +## Quick Start + +### 1. Download and Convert MACE-OFF Model + +```python +from deepmd_gnn import convert_mace_off_to_deepmd + +# Download and convert MACE-OFF small model to frozen DeePMD format +model_path = convert_mace_off_to_deepmd("small", "frozen_model.pth") +``` + +Available models: + +- `"small"`: Fast, good for screening and QM/MM +- `"medium"`: Balanced speed and accuracy +- `"large"`: Best accuracy, slower + +### 2. Use with LAMMPS + +```bash +# Set plugin path +export DP_PLUGIN_PATH=/path/to/libdeepmd_gnn.so + +# In LAMMPS input: +pair_style deepmd frozen_model.pth +pair_coeff * * +``` + +### 3. Use with AMBER/sander for QM/MM + +The MACE-OFF model integrates with AMBER through DeePMD-kit's existing interface. + +#### QM/MM Type Map Convention + +For QM/MM simulations, use the DPRc mechanism: + +- **QM atoms**: Standard element symbols (H, C, N, O, etc.) +- **MM atoms**: Prefix with 'm' (mH, mC, etc.) or use HW/OW for water + +#### Integration with sander + +The frozen_model.pth works with sander through DeePMD-kit's AMBER interface. See DeePMD-kit and AMBER documentation for setup details. + +## Programmatic Usage + +### Load Model Directly + +```python +from deepmd_gnn import load_mace_off_model + +# Load MACE-OFF model as DeePMD-GNN MaceModel +model = load_mace_off_model("small") + +# This is now a MaceModel instance that can be used with DeePMD-kit +print(f"Type map: {model.get_type_map()}") +print(f"Cutoff: {model.get_rcut()}") +``` + +## Model Details + +### MACE-OFF Models + +| Model | Parameters | Speed | Best For | +| ------ | ---------- | ------ | ----------------------------------- | +| small | ~1M | Fast | QM/MM, screening, quick simulations | +| medium | ~5M | Medium | Production MD runs | +| large | ~20M | Slow | High-accuracy calculations | + +### Energy and Force Units + +MACE models use: + +- **Energy**: eV +- **Forces**: eV/Angstrom +- **Coordinates**: Angstrom + +## References + +- MACE: https://github.com/ACEsuit/mace +- MACE-OFF: Foundation models for molecular simulation +- DeePMD-kit: https://github.com/deepmodeling/deepmd-kit +- DeePMD-GNN: https://gitlab.com/RutgersLBSR/deepmd-gnn diff --git a/examples/mace_off/load_mace_off.py b/examples/mace_off/load_mace_off.py new file mode 100644 index 0000000..be16b9b --- /dev/null +++ b/examples/mace_off/load_mace_off.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Example script to load and convert MACE-OFF models.""" + +import sys + +from deepmd_gnn import convert_mace_off_to_deepmd, load_mace_off_model + + +def main() -> int: + """Load and convert MACE-OFF model.""" + print("=" * 70) + print("MACE-OFF Model Loading Example") + print("=" * 70) + print() + + # Example 1: Load model programmatically + print("Example 1: Loading MACE-OFF model into DeePMD-GNN wrapper") + print("-" * 70) + try: + model = load_mace_off_model("small") + print("✓ Model loaded successfully!") + print(f" Type map: {model.get_type_map()}") + print(f" Cutoff radius: {model.get_rcut()} Å") + print() + except Exception as e: + print(f"✗ Failed to load model: {e}") + print() + + # Example 2: Convert to frozen format + print("Example 2: Converting to frozen DeePMD format") + print("-" * 70) + try: + frozen_path = convert_mace_off_to_deepmd( + "small", + "mace_off_frozen.pth", + ) + print(f"✓ Model converted and saved to: {frozen_path}") + print() + print("You can now use this model with:") + print(" - LAMMPS (with DP_PLUGIN_PATH set)") + print(" - AMBER/sander (through DeePMD-kit interface)") + print(" - Other MD packages supported by DeePMD-kit") + print() + except Exception as e: + print(f"✗ Failed to convert model: {e}") + print() + + print("=" * 70) + print("For QM/MM simulations:") + print(" - Use standard elements (H, C, N, O) for QM atoms") + print(" - Use 'm' prefix (mH, mC) or HW/OW for MM atoms") + print(" - The DPRc mechanism handles QM/MM separation automatically") + print("=" * 70) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/pyproject.toml b/pyproject.toml index 151a663..7e04897 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ keywords = [ ] [project.scripts] +deepmd-gnn = "deepmd_gnn.__main__:main" [project.entry-points."deepmd.pt"] mace = "deepmd_gnn.mace:MaceModel" @@ -103,6 +104,13 @@ convention = "numpy" "ANN", "D101", "D102", + "PT027", # unittest style assertRaises is fine +] +"examples/**/*.py" = [ + "T201", # print allowed in examples + "S603", # subprocess without shell allowed in examples + "EXE001", # shebang allowed in examples + "BLE001", # blind exception catching allowed in examples ] "docs/conf.py" = [ "ERA001", diff --git a/tests/__init__.py b/tests/__init__.py index d420712..44a4a59 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1,3 @@ """Tests.""" + +import deepmd.pt.model # noqa: F401 diff --git a/tests/test_mace_off.py b/tests/test_mace_off.py new file mode 100644 index 0000000..40133aa --- /dev/null +++ b/tests/test_mace_off.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for MACE-OFF model loading.""" + +import tempfile +import unittest +from pathlib import Path + +import pytest +import torch +from torch import nn + +from deepmd_gnn.mace import ELEMENTS, MaceModel +from deepmd_gnn.mace_off import ( + MACE_OFF_MODELS, + convert_mace_off_to_deepmd, + download_mace_off_model, + get_mace_off_cache_dir, + load_mace_off_model, +) + + +class TestMaceOffDownload(unittest.TestCase): + """Test MACE-OFF download functionality.""" + + def test_get_cache_dir(self): + """Test cache directory creation.""" + cache_dir = get_mace_off_cache_dir() + assert isinstance(cache_dir, Path) + assert cache_dir.exists() + + def test_mace_off_models_defined(self): + """Test that MACE-OFF models are defined.""" + assert "small" in MACE_OFF_MODELS + assert "medium" in MACE_OFF_MODELS + assert "large" in MACE_OFF_MODELS + for url in MACE_OFF_MODELS.values(): + assert isinstance(url, str) + assert url.startswith("http") + + def test_invalid_model_name(self): + """Test that invalid model names raise errors.""" + with self.assertRaises(ValueError): + download_mace_off_model("invalid_model") + + @pytest.mark.slow + def test_download_real_model(self): + """Test downloading a real MACE-OFF model.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Download the small model (fastest) + model_path = download_mace_off_model("small", cache_dir=tmpdir) + + # Verify the file was downloaded + assert model_path.exists() + assert model_path.stat().st_size > 0 + + # Test that downloading again uses cache + model_path_2 = download_mace_off_model("small", cache_dir=tmpdir) + assert model_path == model_path_2 + + @pytest.mark.slow + def test_download_force(self): + """Test force download.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Download the model first + model_path = download_mace_off_model("small", cache_dir=tmpdir) + original_size = model_path.stat().st_size + + # Modify the file to simulate corruption + model_path.write_text("corrupted") + + # Force re-download + model_path_2 = download_mace_off_model( + "small", + cache_dir=tmpdir, + force_download=True, + ) + + # Verify it was re-downloaded (size should be restored) + assert model_path == model_path_2 + assert model_path_2.stat().st_size == original_size + + +class TestMaceOffLoading(unittest.TestCase): + """Test MACE-OFF model loading functionality.""" + + @pytest.mark.slow + def test_load_real_mace_off_model(self): + """Test loading a real MACE-OFF model.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Load the small model + model = load_mace_off_model("small", cache_dir=tmpdir) + + # Verify the model is a MaceModel instance + assert isinstance(model, MaceModel) + + # Verify it has the expected attributes + assert hasattr(model, "get_type_map") + assert hasattr(model, "get_rcut") + + # Verify type_map is properly set + type_map = model.get_type_map() + assert isinstance(type_map, list) + assert len(type_map) > 0 + + # Verify rcut is set + rcut = model.get_rcut() + assert isinstance(rcut, float) + assert rcut > 0 + + def test_load_model_invalid_atomic_numbers(self): + """Test that invalid atomic numbers are caught during loading.""" + # Create a minimal mock model file to test validation + with tempfile.TemporaryDirectory() as tmpdir: + mock_model_path = Path(tmpdir) / "invalid_model.pth" + + # Create a mock model with invalid atomic numbers + class MockModel(nn.Module): + def __init__(self): + super().__init__() + self.atomic_numbers = torch.tensor([0, 150]) # Invalid + self.r_max = 5.0 + self.num_interactions = 2 + + mock_model = MockModel() + torch.save(mock_model, mock_model_path) + + # Attempt to load and validate + # Note: This test validates the atomic number checking logic + loaded = torch.load(mock_model_path, weights_only=False) + # Check if atomic numbers would be validated + atomic_numbers = loaded.atomic_numbers.tolist() + # Should fail validation + invalid = any(z < 1 or z > len(ELEMENTS) for z in atomic_numbers) + assert invalid, "Invalid atomic numbers should be detected" + + +class TestMaceOffConversion(unittest.TestCase): + """Test MACE-OFF model conversion functionality.""" + + @pytest.mark.slow + def test_convert_real_model_to_deepmd(self): + """Test converting a real MACE-OFF model to DeePMD format.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_file = Path(tmpdir) / "frozen_model.pth" + + # Convert the small model + result = convert_mace_off_to_deepmd( + "small", + str(output_file), + cache_dir=tmpdir, + ) + + # Verify the output file was created + assert result == output_file + assert output_file.exists() + assert output_file.stat().st_size > 0 + + # Verify the frozen model can be loaded + frozen_model = torch.jit.load(str(output_file)) + assert frozen_model is not None + + +if __name__ == "__main__": + unittest.main()