diff --git a/interatomic_potentials/configs/schnet/schnet_md17_ethanol.yaml b/interatomic_potentials/configs/schnet/schnet_md17_ethanol.yaml new file mode 100644 index 00000000..3fa0d418 --- /dev/null +++ b/interatomic_potentials/configs/schnet/schnet_md17_ethanol.yaml @@ -0,0 +1,107 @@ +Global: + do_train: True + do_eval: True + do_test: False + + label_names: ['energy'] + + graph_converter: + __class_name__: FindPointsInSpheres + __init_params__: + cutoff: 5.0 + + prim_eager_enabled: True + + +Trainer: + max_epochs: 500 + seed: 42 + output_dir: ./output/schnet_md17_ethanol + save_freq: 50 + log_freq: 50 + + start_eval_epoch: 1 + eval_freq: 5 + pretrained_model_path: null + pretrained_weight_name: null + resume_from_checkpoint: null + use_amp: False + eval_with_no_grad: True + gradient_accumulation_steps: 1 + + best_metric_indicator: 'eval_metric' + name_for_best_metric: "energy" + greater_is_better: False + + +Model: + __class_name__: SchNet + __init_params__: + n_atom_basis: 64 + n_interactions: 6 + n_filters: 64 + cutoff: 5.0 + n_gaussians: 25 + max_z: 100 + readout: "sum" + property_names: ${Global.label_names} + data_mean: 0.0 + data_std: 1.0 + loss_type: "l1_loss" + compute_forces: False + + +Optimizer: + __class_name__: Adam + __init_params__: + lr: + __class_name__: Cosine + __init_params__: + learning_rate: 1e-4 + eta_min: 1e-7 + by_epoch: False + + +Metric: + energy: + __class_name__: IgnoreNanMetricWrapper + __init_params__: + __class_name__: paddle.nn.L1Loss + __init_params__: {} + + +Dataset: + train: + dataset: + __class_name__: MD17Dataset + __init_params__: + path: "./data/md17" + molecule: "ethanol" + property_names: ${Global.label_names} + build_graph_cfg: ${Global.graph_converter} + max_samples: 50000 + num_workers: 4 + use_shared_memory: False + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: True + drop_last: False + batch_size: 64 + val: + dataset: + __class_name__: MD17Dataset + __init_params__: + path: "./data/md17" + 
molecule: "ethanol" + property_names: ${Global.label_names} + build_graph_cfg: ${Global.graph_converter} + max_samples: 10000 + num_workers: 4 + use_shared_memory: False + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 64 diff --git a/interatomic_potentials/configs/schnet/schnet_qm9_U0.yaml b/interatomic_potentials/configs/schnet/schnet_qm9_U0.yaml new file mode 100644 index 00000000..88b50ad3 --- /dev/null +++ b/interatomic_potentials/configs/schnet/schnet_qm9_U0.yaml @@ -0,0 +1,109 @@ +Global: + do_train: True + do_eval: True + do_test: False + + label_names: ['energy_U0'] + + graph_converter: + __class_name__: FindPointsInSpheres + __init_params__: + cutoff: 10.0 + + prim_eager_enabled: True + + +Trainer: + max_epochs: 200 + seed: 42 + output_dir: ./output/schnet_qm9_U0 + save_freq: 20 + log_freq: 50 + + start_eval_epoch: 1 + eval_freq: 5 + pretrained_model_path: null + pretrained_weight_name: null + resume_from_checkpoint: null + use_amp: False + eval_with_no_grad: True + gradient_accumulation_steps: 1 + + best_metric_indicator: 'eval_metric' + name_for_best_metric: "energy_U0" + greater_is_better: False + + +Model: + __class_name__: SchNet + __init_params__: + n_atom_basis: 128 + n_interactions: 6 + n_filters: 128 + cutoff: 10.0 + n_gaussians: 50 + max_z: 100 + readout: "sum" + property_names: ${Global.label_names} + data_mean: -76.1160 + data_std: 10.3238 + loss_type: "l1_loss" + compute_forces: False + + +Optimizer: + __class_name__: Adam + __init_params__: + lr: + __class_name__: Cosine + __init_params__: + learning_rate: 1e-4 + eta_min: 1e-7 + by_epoch: False + + +Metric: + energy_U0: + __class_name__: IgnoreNanMetricWrapper + __init_params__: + __class_name__: paddle.nn.L1Loss + __init_params__: {} + + +Dataset: + train: + dataset: + __class_name__: QM9Dataset + __init_params__: + path: "./data/qm9" + property_names: ${Global.label_names} + build_graph_cfg: ${Global.graph_converter} + 
cache_path: "./data/qm9" + overwrite: False + filter_unvalid: True + num_workers: 4 + use_shared_memory: False + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: True + drop_last: False + batch_size: 64 + val: + dataset: + __class_name__: QM9Dataset + __init_params__: + path: "./data/qm9" + property_names: ${Global.label_names} + build_graph_cfg: ${Global.graph_converter} + cache_path: "./data/qm9" + overwrite: False + filter_unvalid: True + num_workers: 4 + use_shared_memory: False + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 64 diff --git a/ppmat/datasets/alloy_dataset.py b/ppmat/datasets/alloy_dataset.py new file mode 100644 index 00000000..545e609b --- /dev/null +++ b/ppmat/datasets/alloy_dataset.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +AlloyDataset — tabular dataset for metallic glass alloy compositions. + +Loads Alloy_train.csv produced by tools/prepare_alloy_data.py. +Each sample is a 66-dimensional float vector: + columns 0-39: element composition fractions (40 elements) + columns 40-42: Tg, Tx, Tl (thermal transition temperatures in K) + columns 43-65: 23 GFA criteria (derived from Tg/Tx/Tl) + +The "source" column is dropped on load (same as original AlloyGAN). 
+""" + +import numpy as np +import paddle +from paddle.io import Dataset + +from ppmat.utils import logger + + +class AlloyDataset(Dataset): + """Tabular dataset for AlloyGAN training. + + Args: + path: Path to Alloy_train.csv. + categories: Optional list of dominant-element categories to filter + (e.g., ["Cu", "Fe", "Ti", "Zr"]). Default uses all entries. + normalize: Whether to normalize composition fractions to [0, 1]. + Default True (divides compositions by 100). + """ + + # Top 40 elements in order (matches CSV columns 0-39) + ELEMENTS = [ + "Cu", "Zr", "Al", "Ni", "Ti", "Ag", "Fe", "Mg", "B", "Si", + "Nb", "Y", "Ca", "La", "Co", "Be", "C", "Mo", "Pd", "P", + "Sn", "Cr", "Hf", "Zn", "Gd", "Ce", "Er", "Ga", "Au", "Nd", + "Dy", "W", "Pr", "Ta", "Sc", "Li", "Sm", "S", "Pt", "Mn", + ] + + def __init__(self, path, categories=None, normalize=True): + super().__init__() + import pandas as pd + + df = pd.read_csv(path) + + # Drop the "source" column if present (same as original code) + if "source" in df.columns: + df = df.drop(columns=["source"]) + + # Optional category filtering by dominant element + if categories is not None: + elem_cols = df.columns[:40] + dominant = df[elem_cols].idxmax(axis=1) + mask = dominant.isin(categories) + df = df[mask].reset_index(drop=True) + logger.info( + f"Filtered to categories {categories}: " + f"{len(df)} entries" + ) + + self.data = df.values.astype(np.float32) + + if normalize: + # Normalize composition fractions (0-100) to (0-1) + self.data[:, :40] = self.data[:, :40] / 100.0 + + logger.info( + f"Loaded AlloyDataset: {len(self.data)} samples, " + f"{self.data.shape[1]} features from {path}" + ) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return {"data": self.data[idx]} diff --git a/ppmat/datasets/md17_dataset.py b/ppmat/datasets/md17_dataset.py new file mode 100644 index 00000000..d75a632a --- /dev/null +++ b/ppmat/datasets/md17_dataset.py @@ -0,0 +1,250 @@ +# Copyright (c) 2025 PaddlePaddle 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MD17 dataset for molecular dynamics trajectories. + +The MD17 dataset (Chmiela et al., 2017) contains ab-initio molecular dynamics +trajectories for small organic molecules. Each snapshot includes atomic +positions, total energy, and per-atom forces. + +Reference: + S. Chmiela, A. Tkatchenko, H. E. Sauceda, I. Poltavsky, K. T. Schütt, + K.-R. Müller. Machine Learning of Accurate Energy-Conserving Molecular + Force Fields. Science Advances, 2017. +""" + +from __future__ import annotations + +import os +import os.path as osp +import pickle +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import paddle.distributed as dist +from paddle.io import Dataset + +from ppmat.models import build_graph_converter +from ppmat.utils import download, logger + +try: + from pymatgen.core import Lattice, Structure +except ImportError: + Structure = None + Lattice = None + + +# Mapping from molecule name to original MD17 NPZ filename. +MD17_FILES = { + "benzene": "md17_benzene2017.npz", + "uracil": "md17_uracil.npz", + "naphthalene": "md17_naphthalene.npz", + "aspirin": "md17_aspirin.npz", + "salicylic_acid": "md17_salicylic.npz", + "malonaldehyde": "md17_malonaldehyde.npz", + "ethanol": "md17_ethanol.npz", + "toluene": "md17_toluene.npz", +} + +# Atomic number → element symbol (for pymatgen Structure creation). 
# Atomic number -> element symbol for the species that occur in MD17
# molecules (H, C, N, O, F, S). Used when building pymatgen Structures.
_Z_TO_SYMBOL = {
    1: "H",
    6: "C",
    7: "N",
    8: "O",
    9: "F",
    16: "S",
}


class MD17Dataset(Dataset):
    """MD17 molecular dynamics trajectory dataset.

    Loads an MD17 NPZ file and converts each snapshot into a dict compatible
    with PaddleMaterials models (optionally building a PGL graph via the
    graph converter).

    The NPZ files contain:
        - ``z``: atomic numbers, shape [num_atoms]
        - ``R``: positions, shape [num_snapshots, num_atoms, 3] (Angstrom)
        - ``E``: energies, shape [num_snapshots, 1] (kcal/mol)
        - ``F``: forces, shape [num_snapshots, num_atoms, 3] (kcal/mol/A)

    Args:
        path (str): Root directory for dataset storage.
        molecule (str): Molecule name (e.g. "ethanol"). Must be a key of
            ``MD17_FILES``.
        property_names (str or list): Target property name(s); only
            "energy" and "forces" are available. Default: "energy".
        build_graph_cfg (dict, optional): Config for graph converter.
        max_samples (int, optional): Limit dataset size (for faster debugging).
        url (str, optional): Custom download URL. Default: BCS mirror.
        box_size (float): Side length (A) of the cubic cell used for
            non-periodic molecules. Default: 100.0.
        cache_graphs (bool): Whether to cache built graphs to disk.
            Default: True.

    Raises:
        ValueError: If ``molecule`` or any entry of ``property_names`` is
            unknown.
    """

    # BCS mirror for MD17 NPZ files
    default_url = "https://paddle-org.bj.bcebos.com/paddlematerials/datasets/MD17"

    def __init__(
        self,
        path: str,
        molecule: str = "ethanol",
        property_names: Union[str, List[str]] = "energy",
        *,
        build_graph_cfg: Optional[Dict] = None,
        max_samples: Optional[int] = None,
        url: Optional[str] = None,
        box_size: float = 100.0,
        cache_graphs: bool = True,
        **kwargs,
    ) -> None:
        super().__init__()

        if molecule not in MD17_FILES:
            raise ValueError(
                f"Unknown molecule '{molecule}'. Choose from: {list(MD17_FILES)}"
            )

        if isinstance(property_names, str):
            property_names = [property_names]
        # Fail fast on typos instead of silently serving energies for an
        # unknown property name in __getitem__ (previous behavior).
        unknown = set(property_names) - {"energy", "forces"}
        if unknown:
            raise ValueError(
                f"Unknown property names {sorted(unknown)}; MD17 provides "
                "'energy' and 'forces'."
            )
        self.property_names = property_names
        self.molecule = molecule
        self.box_size = box_size
        self.build_graph_cfg = build_graph_cfg

        # Paths
        os.makedirs(path, exist_ok=True)
        self.raw_dir = osp.join(path, "raw_md17")
        os.makedirs(self.raw_dir, exist_ok=True)

        npz_name = MD17_FILES[molecule]
        npz_path = osp.join(self.raw_dir, npz_name)

        # Download if needed
        if not osp.exists(npz_path):
            base_url = url or self.default_url
            full_url = f"{base_url}/{npz_name}"
            logger.info(f"Downloading MD17 {molecule} from {full_url}")
            download.download_file(full_url, npz_path)

        # Load raw data
        raw = np.load(npz_path)
        self.atomic_numbers = raw["z"].astype(np.int64)  # [num_atoms]
        self.positions = raw["R"].astype(np.float32)  # [N, num_atoms, 3]
        self.energies = raw["E"].astype(np.float32)  # [N, 1] or [N]
        self.forces = raw["F"].astype(np.float32)  # [N, num_atoms, 3]

        # Normalize energies to a consistent [N, 1] shape.
        if self.energies.ndim == 1:
            self.energies = self.energies[:, None]

        if max_samples is not None:
            self.positions = self.positions[:max_samples]
            self.energies = self.energies[:max_samples]
            self.forces = self.forces[:max_samples]

        self.num_samples = len(self.positions)
        logger.info(
            f"MD17 {molecule}: {self.num_samples} snapshots, "
            f"{len(self.atomic_numbers)} atoms/snapshot"
        )

        # Build and cache graphs if configured
        self.graphs = None
        if build_graph_cfg is not None:
            self._prepare_graphs(path, build_graph_cfg, cache_graphs)

    def _prepare_graphs(
        self, path: str, build_graph_cfg: Dict, cache_graphs: bool
    ) -> None:
        """Build (or load cached) graphs for every snapshot.

        Rank 0 builds the graphs and, when caching is enabled, persists
        them plus a completion flag; other ranks wait at the barrier and
        then load the cache. BUG FIX: when ``cache_graphs`` is False in a
        distributed run, non-zero ranks previously tried to load a cache
        that was never written — they now build their own copy instead.
        """
        graph_converter_name = build_graph_cfg.get("__class_name__", "custom")
        cutoff = build_graph_cfg.get("__init_params__", {}).get("cutoff", 5)
        cache_dir = osp.join(
            path,
            f"md17_{self.molecule}_cache_{graph_converter_name}_cutoff_{int(cutoff)}",
            "graphs",
        )
        done_flag = osp.join(cache_dir, "completed.flag")

        if cache_graphs and osp.exists(done_flag):
            logger.info(f"Loading cached graphs from {cache_dir}")
            self.graphs = self._load_cached_graphs(cache_dir)
            return

        if dist.get_rank() == 0:
            logger.info("Building graphs for MD17 dataset...")
            os.makedirs(cache_dir, exist_ok=True)
            converter = build_graph_converter(build_graph_cfg)
            self.graphs = converter(self._build_structures())
            if cache_graphs:
                self._save_cached_graphs(cache_dir)
                # Write the flag only after all graphs are on disk so a
                # crash mid-save cannot leave a "completed" partial cache.
                with open(done_flag, "w") as f:
                    f.write("done")
        if dist.is_initialized():
            dist.barrier()
        if self.graphs is None:
            # We are a non-zero rank (rank 0 always set self.graphs above).
            if cache_graphs and osp.exists(done_flag):
                self.graphs = self._load_cached_graphs(cache_dir)
            else:
                # Caching disabled: each rank builds an identical copy.
                converter = build_graph_converter(build_graph_cfg)
                self.graphs = converter(self._build_structures())

    def _build_structures(self) -> list:
        """Convert all snapshots to pymatgen Structures.

        Molecules are non-periodic, so every snapshot is placed inside a
        large cubic cell of side ``self.box_size`` Angstrom.
        """
        if Structure is None:
            raise RuntimeError("pymatgen is required: pip install pymatgen")

        lattice = Lattice.from_parameters(
            self.box_size, self.box_size, self.box_size, 90, 90, 90
        )
        # NOTE(review): the str(z) fallback for elements missing from
        # _Z_TO_SYMBOL is unlikely to be a valid pymatgen species symbol —
        # all current MD17 molecules are covered by the table, though.
        species = [_Z_TO_SYMBOL.get(z, str(z)) for z in self.atomic_numbers]

        structures = []
        for i in range(self.num_samples):
            coords = self.positions[i]  # [num_atoms, 3]
            struct = Structure(
                lattice,
                species,
                coords,
                coords_are_cartesian=True,
            )
            structures.append(struct)
        return structures

    def _save_cached_graphs(self, cache_dir: str) -> None:
        """Pickle each graph to ``cache_dir`` with a sortable 8-digit name."""
        for i, g in enumerate(self.graphs):
            with open(osp.join(cache_dir, f"{i:08d}.pkl"), "wb") as f:
                pickle.dump(g, f)

    def _load_cached_graphs(self, cache_dir: str) -> list:
        """Load all pickled graphs from ``cache_dir`` in index order."""
        files = sorted(
            f for f in os.listdir(cache_dir) if f.endswith(".pkl")
        )
        graphs = []
        for fn in files:
            with open(osp.join(cache_dir, fn), "rb") as f:
                graphs.append(pickle.load(f))
        return graphs

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return one snapshot as a dict of graph/raw data plus targets."""
        data = {}

        if self.graphs is not None:
            data["graph"] = self.graphs[idx]
        else:
            # Fallback when no graph converter was configured: raw arrays.
            data["pos"] = self.positions[idx]
            data["atomic_numbers"] = self.atomic_numbers.copy()
            data["cell"] = np.eye(3, dtype="float32") * self.box_size
            data["natoms"] = len(self.atomic_numbers)
            data["pbc"] = np.array([False, False, False], dtype=bool)

        # Properties (names validated in __init__).
        for pname in self.property_names:
            if pname == "energy":
                data["energy"] = self.energies[idx]
            else:  # "forces"
                data["forces"] = self.forces[idx]

        return data
+ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MOFDiff: Coarse-Grained Diffusion for Metal-Organic Framework Generation. + +This module implements a simplified PaddlePaddle port of MOFDiff +(https://github.com/microsoft/MOFDiff). The original relies on +torch_geometric, torch_scatter, GemNetOC, and hydra; this version +substitutes those with simple MLPs and the ppmat scatter utilities so +that it is fully self-contained and CPU-testable. + +The three-stage pipeline: + 1. **Encoder** — encodes per-node features into a graph-level latent + via a VAE bottleneck (fc_mu / fc_var → reparameterize). + 2. **CG Diffusion** — VP (Variance Preserving) noise for both + building-block type embeddings and fractional coordinates, with a + denoiser that predicts the noise components. + 3. **Lattice predictor** — an MLP that maps the latent to 6 lattice + parameters (3 lengths + 3 angles). + +Reference: Yao *et al.*, "Coarse-Grained Diffusion for Metal-Organic +Framework Generation", *ICLR 2024*. 
+""" + +import math + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from ppmat.utils.scatter import scatter_mean + +EPSILON = 1e-8 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def build_mlp(in_dim, hidden_dim, num_layers, out_dim): + """Build a simple feed-forward MLP with ReLU activations.""" + layers = [nn.Linear(in_dim, hidden_dim), nn.ReLU()] + for _ in range(num_layers - 1): + layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()] + layers.append(nn.Linear(hidden_dim, out_dim)) + return nn.Sequential(*layers) + + +# --------------------------------------------------------------------------- +# Timestep embedding +# --------------------------------------------------------------------------- + +class GaussianFourierProjection(nn.Layer): + """Gaussian Fourier embeddings for noise levels. + + Maps a scalar timestep *t* to a ``2 * embedding_size``-dimensional + vector via random Fourier features: + + out = [sin(2π · t · W), cos(2π · t · W)] + + where *W* is drawn once from N(0, scale²) and kept frozen. + """ + + def __init__(self, embedding_size=256, scale=1.0): + super().__init__() + W = paddle.randn([embedding_size]) * scale + self.register_buffer("W", W) + + def forward(self, x): + """ + Args: + x: Tensor of shape ``[B]`` or ``[B, 1]``. + + Returns: + Tensor of shape ``[B, 2 * embedding_size]``. + """ + if x.ndim == 1: + x = x.unsqueeze(-1) + x_proj = x * self.W.unsqueeze(0) * 2 * math.pi + return paddle.concat([paddle.sin(x_proj), paddle.cos(x_proj)], axis=-1) + + +# --------------------------------------------------------------------------- +# Diffusion schedules +# --------------------------------------------------------------------------- + +class VP(nn.Layer): + """Variance Preserving diffusion with a cosine schedule. 
+ + Forward process:: + + h_t = sqrt(ᾱ_t) · h_0 + sqrt(1 − ᾱ_t) · ε + + Reverse (DDPM) step:: + + h_{t-1} = (1/√α_t)(h_t − (β_t / √(1 − ᾱ_t)) · ε̂) + σ_t · z + """ + + def __init__(self, num_steps=1000, s=0.0001, power=2, clipmax=0.999): + super().__init__() + self.num_steps = num_steps + + t = np.arange(0, num_steps + 1, dtype=np.float64) + f_t = np.cos((np.pi / 2) * ((t / num_steps) + s) / (1 + s)) ** power + alpha_bars = f_t / f_t[0] + + betas = np.concatenate([[0.0], 1 - (alpha_bars[1:] / alpha_bars[:-1])]) + betas = np.clip(betas, 0, clipmax) + + # Posterior variance σ²_t = β_t · (1 − ᾱ_{t-1}) / (1 − ᾱ_t) + sigmas_sq = betas[1:] * ((1 - alpha_bars[:-1]) / (1 - alpha_bars[1:] + EPSILON)) + sigmas = np.sqrt(np.concatenate([[0.0], sigmas_sq])) + + self.register_buffer( + "alpha_bars", paddle.to_tensor(alpha_bars, dtype="float32") + ) + self.register_buffer( + "betas", paddle.to_tensor(betas, dtype="float32") + ) + self.register_buffer( + "sigmas", paddle.to_tensor(sigmas, dtype="float32") + ) + + def forward(self, h0, t): + """Forward diffusion: add noise at timestep *t*. + + Args: + h0: Clean signal ``[N, D]``. + t: Integer timesteps ``[N]``. + + Returns: + (h_t, eps): noised signal and the noise that was added. + """ + alpha_bar = paddle.gather(self.alpha_bars, t) # [N] + eps = paddle.randn(h0.shape) + sqrt_ab = paddle.sqrt(alpha_bar).unsqueeze(-1) + sqrt_1_ab = paddle.sqrt(1.0 - alpha_bar).unsqueeze(-1) + ht = sqrt_ab * h0 + sqrt_1_ab * eps + return ht, eps + + def reverse(self, ht, eps_h, t): + """Single DDPM reverse step. + + Args: + ht: Noised signal ``[N, D]``. + eps_h: Predicted noise ``[N, D]``. + t: Integer timesteps ``[N]``. + + Returns: + h_{t-1}: denoised one step. 
+ """ + alpha = (1 - paddle.gather(self.betas, t)).unsqueeze(-1) + alpha_bar = paddle.gather(self.alpha_bars, t).unsqueeze(-1) + sigma = paddle.gather(self.sigmas, t).unsqueeze(-1) + + z = paddle.where( + (t > 1).unsqueeze(-1).expand_as(ht), + paddle.randn(ht.shape), + paddle.zeros(ht.shape), + ) + coef = (1.0 - alpha) / paddle.sqrt(1.0 - alpha_bar + EPSILON) + return (1.0 / paddle.sqrt(alpha + EPSILON)) * (ht - coef * eps_h) + sigma * z + + +# --------------------------------------------------------------------------- +# Encoder / Decoder (MLP stand-ins for GemNetOC) +# --------------------------------------------------------------------------- + +class SimpleGNNEncoder(nn.Layer): + """MLP encoder that replaces GemNetOC for CPU-testable builds. + + Per-node features are projected through an MLP, then mean-pooled + per graph to produce a graph-level representation. + """ + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3): + super().__init__() + self.mlp = build_mlp(input_dim, hidden_dim, num_layers, output_dim) + + def forward(self, node_features, batch_indices, num_graphs): + """ + Args: + node_features: ``[total_nodes, input_dim]`` + batch_indices: ``[total_nodes]`` — graph id for each node. + num_graphs: int — number of graphs in the batch. 
+ + Returns: + ``[num_graphs, output_dim]`` + """ + h = self.mlp(node_features) + out = scatter_mean(h, batch_indices, dim=0, dim_size=num_graphs) + return out + + +class SimpleGNNDecoder(nn.Layer): + """MLP decoder predicting coordinate noise and type noise.""" + + def __init__( + self, + input_dim, + hidden_dim, + output_coord_dim=3, + output_type_dim=100, + num_layers=3, + ): + super().__init__() + self.coord_mlp = build_mlp(input_dim, hidden_dim, num_layers, output_coord_dim) + self.type_mlp = build_mlp(input_dim, hidden_dim, num_layers, output_type_dim) + + def forward(self, node_features): + """ + Args: + node_features: ``[total_nodes, input_dim]`` + + Returns: + (eps_x, eps_h): predicted noise for coords ``[N,3]`` + and types ``[N, num_bb_types]``. + """ + eps_x = self.coord_mlp(node_features) + eps_h = self.type_mlp(node_features) + return eps_x, eps_h + + +# --------------------------------------------------------------------------- +# Main model +# --------------------------------------------------------------------------- + +class MOFDiff(nn.Layer): + """MOFDiff: Coarse-Grained Diffusion for MOF generation. + + A simplified PaddleMaterials implementation that keeps the same + three-stage architecture as the reference while replacing the heavy + GNN backbone (GemNetOC) with lightweight MLPs. + + **Training forward** returns ``{"loss_dict": {...}}`` following the + PaddleMaterials convention (see ``DiffCSP``). + + Args: + node_feat_dim: Dimension of input per-node features. + hidden_dim: Width of all hidden MLPs. + latent_dim: Dimension of the VAE latent space. + num_bb_types: Number of building-block type classes. + max_num_bbs: Maximum number of building blocks in a MOF. + num_diffusion_steps: Number of VP diffusion steps. + fc_num_layers: Depth of each MLP sub-network. + kl_weight: Weight for the KL divergence loss term. 
+ """ + + def __init__( + self, + node_feat_dim=64, + hidden_dim=128, + latent_dim=64, + num_bb_types=50, + max_num_bbs=20, + num_diffusion_steps=100, + fc_num_layers=3, + kl_weight=0.1, + ): + super().__init__() + + # -- timestep embedding ------------------------------------------------ + self.time_embedding = GaussianFourierProjection(embedding_size=128) + time_emb_dim = 256 # 128 × 2 (sin + cos) + + # -- encoder ----------------------------------------------------------- + self.encoder = SimpleGNNEncoder( + node_feat_dim, hidden_dim, latent_dim, fc_num_layers + ) + + # -- VAE bottleneck ---------------------------------------------------- + self.fc_mu = nn.Linear(latent_dim, latent_dim) + self.fc_var = nn.Linear(latent_dim, latent_dim) + + # -- lattice predictor (6 = 3 lengths + 3 angles) --------------------- + self.fc_lattice = build_mlp(latent_dim, hidden_dim, fc_num_layers, 6) + + # -- number-of-BBs classifier ----------------------------------------- + self.fc_num_bbs = build_mlp( + latent_dim, hidden_dim, fc_num_layers, max_num_bbs + 1 + ) + + # -- denoiser ---------------------------------------------------------- + # Input: noisy coords (3) + one-hot bb type + time emb + latent + denoiser_input = 3 + num_bb_types + time_emb_dim + latent_dim + self.denoiser = SimpleGNNDecoder( + denoiser_input, hidden_dim, 3, num_bb_types, fc_num_layers + ) + + # -- diffusion process ------------------------------------------------- + self.vp_diffusion = VP(num_diffusion_steps) + self.num_diffusion_steps = num_diffusion_steps + self.num_bb_types = num_bb_types + self.latent_dim = latent_dim + self.max_num_bbs = max_num_bbs + self.kl_weight = kl_weight + + # ------------------------------------------------------------------ + # VAE helpers + # ------------------------------------------------------------------ + + def reparameterize(self, mu, log_var): + """Sample *z* ~ N(mu, σ²) via the reparameterization trick.""" + std = paddle.exp(0.5 * log_var) + eps = 
paddle.randn(std.shape) + return mu + eps * std + + # ------------------------------------------------------------------ + # Encode + # ------------------------------------------------------------------ + + def encode(self, node_features, batch_indices, num_graphs): + """Encode a batch of graphs into latent (mu, log_var, z). + + Returns: + (mu, log_var, z) — each ``[num_graphs, latent_dim]``. + """ + h = self.encoder(node_features, batch_indices, num_graphs) + mu = self.fc_mu(h) + log_var = self.fc_var(h) + z = self.reparameterize(mu, log_var) + return mu, log_var, z + + # ------------------------------------------------------------------ + # Core forward (training) + # ------------------------------------------------------------------ + + def _forward(self, batch): + """Core training forward: encode → diffuse → denoise → losses. + + Expected keys in *batch*: + + * ``node_features`` — ``[total_nodes, node_feat_dim]`` + * ``frac_coords`` — ``[total_nodes, 3]`` + * ``bb_types`` — ``[total_nodes]`` (int, 0-based class ids) + * ``batch`` — ``[total_nodes]`` (graph index per node) + * ``num_atoms`` — ``[B]`` + * ``lattice_params``— ``[B, 6]`` + """ + node_features = batch["node_features"] + frac_coords = batch["frac_coords"] + bb_types = batch["bb_types"] + batch_idx = batch["batch"] + num_atoms = batch["num_atoms"] + lattice_params = batch["lattice_params"] + + batch_size = num_atoms.shape[0] + total_nodes = node_features.shape[0] + + # --- 1. Encode ------------------------------------------------------- + mu, log_var, z = self.encode(node_features, batch_idx, batch_size) + + # --- 2. Lattice prediction loss --------------------------------------- + pred_lattice = self.fc_lattice(z) + loss_lattice = F.mse_loss(pred_lattice, lattice_params) + + # --- 3. Num-BBs classification loss ----------------------------------- + pred_num_bbs = self.fc_num_bbs(z) + loss_num_bbs = F.cross_entropy(pred_num_bbs, num_atoms) + + # --- 4. 
Sample diffusion timestep per node ---------------------------- + t = paddle.randint(1, self.num_diffusion_steps + 1, [total_nodes]) + + # --- 5. Noise the BB-type one-hot via VP diffusion -------------------- + bb_onehot = F.one_hot( + bb_types.astype("int64"), num_classes=self.num_bb_types + ).astype("float32") + noisy_h, eps_h = self.vp_diffusion.forward(bb_onehot, t) + + # --- 6. Noise the fractional coordinates via VP diffusion -------------- + noisy_x, eps_x = self.vp_diffusion.forward(frac_coords, t) + + # --- 7. Build time embedding per node --------------------------------- + t_normalised = t.astype("float32") / self.num_diffusion_steps + time_emb = self.time_embedding(t_normalised) # [N, 256] + + # --- 8. Expand latent to per-node ------------------------------------- + z_per_node = z[batch_idx] # [N, latent_dim] + + # --- 9. Denoiser input ------------------------------------------------ + denoiser_in = paddle.concat( + [noisy_x, noisy_h, time_emb, z_per_node], axis=-1 + ) + pred_eps_x, pred_eps_h = self.denoiser(denoiser_in) + + # --- 10. Reconstruction losses ---------------------------------------- + loss_coord = F.mse_loss(pred_eps_x, eps_x) + loss_type = F.mse_loss(pred_eps_h, eps_h) + + # --- 11. KL divergence ------------------------------------------------ + loss_kl = -0.5 * paddle.mean( + 1.0 + log_var - mu.pow(2) - log_var.exp() + ) + + # --- 12. Total loss --------------------------------------------------- + loss = loss_coord + loss_type + self.kl_weight * loss_kl + loss_lattice + loss_num_bbs + + return { + "loss_dict": { + "loss": loss, + "loss_coord": loss_coord, + "loss_type": loss_type, + "loss_kl": loss_kl, + "loss_lattice": loss_lattice, + "loss_num_bbs": loss_num_bbs, + } + } + + # ------------------------------------------------------------------ + # Public forward (PM convention) + # ------------------------------------------------------------------ + + def forward(self, batch, **kwargs): + """PaddleMaterials-compatible forward. 
+ + Returns: + dict with ``"loss_dict"`` containing all loss terms. + """ + return self._forward(batch) + + # ------------------------------------------------------------------ + # Inference / sampling + # ------------------------------------------------------------------ + + @paddle.no_grad() + def sample(self, z, num_atoms_per_graph, num_steps=None): + """Generate MOF structures from a latent vector *z*. + + This is a simplified DDPM reverse loop that iteratively denoises + building-block types and fractional coordinates. + + Args: + z: ``[B, latent_dim]`` — sampled latent. + num_atoms_per_graph: ``[B]`` — how many BBs per structure. + num_steps: override for diffusion steps. + + Returns: + dict with ``pred_coords``, ``pred_types``, ``pred_lattice``. + """ + if num_steps is None: + num_steps = self.num_diffusion_steps + + batch_size = z.shape[0] + total_nodes = int(num_atoms_per_graph.sum().item()) + + # Build batch index + batch_idx = paddle.repeat_interleave( + paddle.arange(batch_size), num_atoms_per_graph + ) + + # Start from pure noise + ht = paddle.randn([total_nodes, self.num_bb_types]) + xt = paddle.randn([total_nodes, 3]) + + z_per_node = z[batch_idx] + + for step in range(num_steps, 0, -1): + t = paddle.full([total_nodes], step, dtype="int64") + t_norm = t.astype("float32") / self.num_diffusion_steps + time_emb = self.time_embedding(t_norm) + + denoiser_in = paddle.concat([xt, ht, time_emb, z_per_node], axis=-1) + pred_eps_x, pred_eps_h = self.denoiser(denoiser_in) + + # VP reverse for types + ht = self.vp_diffusion.reverse(ht, pred_eps_h, t) + + # VP reverse for coordinates + xt = self.vp_diffusion.reverse(xt, pred_eps_x, t) + + pred_lattice = self.fc_lattice(z) + + return { + "pred_coords": xt, + "pred_types": ht.argmax(axis=-1), + "pred_lattice": pred_lattice, + } diff --git a/structure_generation/configs/mofdiff/README.md b/structure_generation/configs/mofdiff/README.md new file mode 100644 index 00000000..d91012ae --- /dev/null +++ 
b/structure_generation/configs/mofdiff/README.md new file mode 100644 index 00000000..d91012ae --- /dev/null +++ b/structure_generation/configs/mofdiff/README.md @@ -0,0 +1,44 @@ +# MOFDiff + +[MOFDiff: Coarse-grained Diffusion for Metal-Organic Framework Design](https://arxiv.org/abs/2310.10732) + +## Abstract + +MOFDiff is a coarse-grained diffusion model for generating Metal-Organic Frameworks (MOFs). It uses a three-stage pipeline consisting of: (1) a graph encoder with VAE bottleneck for learning latent representations, (2) a Variance Preserving (VP) diffusion process for generating building-block types and coordinates, and (3) a lattice predictor that maps latents to 6 lattice parameters. + +## Model + +MOFDiff operates on coarse-grained MOF representations where each building block is a single node. The encoder aggregates per-node features into a graph-level latent via a VAE bottleneck. A VP diffusion process with a GNN-based denoiser generates building-block type embeddings and 3D coordinates. A separate MLP predicts lattice parameters (3 lengths + 3 angles) from the latent. The total loss combines reconstruction, KL divergence, coordinate denoising, type denoising, lattice, and num-BBs classification terms.
+ +## Training + +```bash +# single-gpu training +python structure_generation/train.py -c structure_generation/configs/mofdiff/mofdiff_bw20k.yaml + +# multi-gpu training +python -m paddle.distributed.launch --gpus="0,1,2,3" structure_generation/train.py -c structure_generation/configs/mofdiff/mofdiff_bw20k.yaml +``` + +## Validation + +```bash +python structure_generation/train.py -c structure_generation/configs/mofdiff/mofdiff_bw20k.yaml Global.do_eval=True Global.do_train=False Global.do_test=False Trainer.pretrained_model_path='path/to/model.pdparams' +``` + +## Testing + +```bash +python structure_generation/train.py -c structure_generation/configs/mofdiff/mofdiff_bw20k.yaml Global.do_eval=False Global.do_train=False Global.do_test=True Trainer.pretrained_model_path='path/to/model.pdparams' +``` + +## Citation + +``` +@inproceedings{fu2024mofdiff, + title={MOFDiff: Coarse-grained Diffusion for Metal-Organic Framework Design}, + author={Fu, Xiang and Xie, Tian and Rosen, Andrew S. and Jaakkola, Tommi and Smith, Jake}, + booktitle={International Conference on Learning Representations}, + year={2024} +} +``` diff --git a/structure_generation/configs/mofdiff/mofdiff_bw20k.yaml b/structure_generation/configs/mofdiff/mofdiff_bw20k.yaml new file mode 100644 index 00000000..60612809 --- /dev/null +++ b/structure_generation/configs/mofdiff/mofdiff_bw20k.yaml @@ -0,0 +1,93 @@ +Global: + do_train: True + do_eval: False + do_test: False + num_train_timesteps: 100 + +Trainer: + max_epochs: 500 + seed: 42 + output_dir: ./output/mofdiff_bw20k + save_freq: 100 + log_freq: 10 + start_eval_epoch: 1 + eval_freq: 1 + pretrained_model_path: null + resume_from_checkpoint: null + use_amp: False + amp_level: 'O1' + eval_with_no_grad: True + gradient_accumulation_steps: 1 + best_metric_indicator: 'eval_loss' + name_for_best_metric: "loss" + greater_is_better: False + compute_metric_during_train: False + metric_strategy_during_eval: 'step' + use_visualdl: False + use_wandb:
False + use_tensorboard: False + +Model: + __class_name__: MOFDiff + __init_params__: + node_feat_dim: 64 + hidden_dim: 128 + latent_dim: 64 + num_bb_types: 50 + max_num_bbs: 20 + num_diffusion_steps: ${Global.num_train_timesteps} + fc_num_layers: 3 + +Optimizer: + __class_name__: Adam + __init_params__: + beta1: 0.9 + beta2: 0.999 + lr: + __class_name__: ReduceOnPlateau + __init_params__: + learning_rate: 0.001 + factor: 0.5 + by_epoch: True + patience: 30 + min_lr: 0.0001 + indicator: "train_loss" + indicator_name: 'loss' + +Dataset: + train: + dataset: + __class_name__: MOFDataset + __init_params__: + path: "./data/mof_bw20k/train.csv" + loader: + num_workers: 0 + use_shared_memory: False + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: True + drop_last: False + batch_size: 64 + val: + dataset: + __class_name__: MOFDataset + __init_params__: + path: "./data/mof_bw20k/val.csv" + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 32 + test: + dataset: + __class_name__: MOFDataset + __init_params__: + path: "./data/mof_bw20k/test.csv" + sampler: + __class_name__: BatchSampler + __init_params__: + shuffle: False + drop_last: False + batch_size: 32 diff --git a/test/mofdiff/__init__.py b/test/mofdiff/__init__.py new file mode 100644 index 00000000..4c7ea338 --- /dev/null +++ b/test/mofdiff/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/test/mofdiff/conftest.py b/test/mofdiff/conftest.py new file mode 100644 index 00000000..5ed2fb04 --- /dev/null +++ b/test/mofdiff/conftest.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""conftest — stub out heavy optional deps so mofdiff tests run on CPU +without pgl, ase, pymatgen, etc. installed.""" + +import importlib.util +import os +import sys +import types + +_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + + +def _ensure_package(dotted, fs_path): + """Register *dotted* as a package backed by *fs_path* if absent.""" + if dotted in sys.modules: + return sys.modules[dotted] + pkg = types.ModuleType(dotted) + pkg.__path__ = [os.path.join(_ROOT, fs_path)] + pkg.__package__ = dotted + sys.modules[dotted] = pkg + return pkg + + +def _load_module_from_file(dotted, rel_path): + """Load a single .py file as *dotted* without running parent __init__.py.""" + if dotted in sys.modules: + return sys.modules[dotted] + full = os.path.join(_ROOT, rel_path) + spec = importlib.util.spec_from_file_location(dotted, full) + mod = importlib.util.module_from_spec(spec) + sys.modules[dotted] = mod + spec.loader.exec_module(mod) + return mod + + +# 1. 
Stub heavy third-party deps +for name in ("pgl", "ase", "pymatgen", "matgl", "jarvis"): + if name not in sys.modules: + sys.modules[name] = types.ModuleType(name) + +# 2. Register minimal ppmat package tree (no __init__.py execution) +_ensure_package("ppmat", "ppmat") +_ensure_package("ppmat.utils", "ppmat/utils") +_ensure_package("ppmat.models", "ppmat/models") +_ensure_package("ppmat.models.mofdiff", "ppmat/models/mofdiff") + +# 3. Load scatter utility (the only real dependency of mofdiff.py) +_load_module_from_file("ppmat.utils.scatter", "ppmat/utils/scatter.py") + +# 4. Load the mofdiff module itself +_load_module_from_file( + "ppmat.models.mofdiff.mofdiff", "ppmat/models/mofdiff/mofdiff.py" +) diff --git a/test/mofdiff/test_mofdiff.py b/test/mofdiff/test_mofdiff.py new file mode 100644 index 00000000..e81578cb --- /dev/null +++ b/test/mofdiff/test_mofdiff.py @@ -0,0 +1,223 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for ppmat.models.mofdiff.""" + +import unittest + +import numpy as np +import paddle + +from ppmat.models.mofdiff.mofdiff import ( + VP, + GaussianFourierProjection, + MOFDiff, + SimpleGNNDecoder, + SimpleGNNEncoder, + build_mlp, +) + + +def _make_batch(batch_size=2, nodes_per_graph=3, node_feat_dim=32, num_bb_types=10): + """Create a minimal dummy batch for MOFDiff.""" + total = batch_size * nodes_per_graph + batch_idx = paddle.repeat_interleave( + paddle.arange(batch_size), + paddle.full([batch_size], nodes_per_graph, dtype="int64"), + ) + return { + "node_features": paddle.randn([total, node_feat_dim]), + "frac_coords": paddle.rand([total, 3]), + "bb_types": paddle.randint(0, num_bb_types, [total]), + "batch": batch_idx, + "num_atoms": paddle.full([batch_size], nodes_per_graph, dtype="int64"), + "lattice_params": paddle.randn([batch_size, 6]), + } + + +class TestGaussianFourierProjection(unittest.TestCase): + """GaussianFourierProjection embedding tests.""" + + def test_output_shape(self): + emb = GaussianFourierProjection(embedding_size=64, scale=1.0) + x = paddle.to_tensor([0.0, 0.5, 1.0]) + out = emb(x) + self.assertEqual(list(out.shape), [3, 128]) + + def test_determinism(self): + """Same input → same output (frozen weights).""" + emb = GaussianFourierProjection(embedding_size=32) + x = paddle.to_tensor([0.25, 0.75]) + np.testing.assert_allclose( + emb(x).numpy(), emb(x).numpy(), atol=1e-6 + ) + + def test_different_inputs(self): + emb = GaussianFourierProjection(embedding_size=32) + a = emb(paddle.to_tensor([0.0])) + b = emb(paddle.to_tensor([1.0])) + self.assertFalse(np.allclose(a.numpy(), b.numpy())) + + +class TestVPDiffusion(unittest.TestCase): + """VP (Variance Preserving) diffusion schedule tests.""" + + def setUp(self): + self.vp = VP(num_steps=100, s=0.0001, power=2) + + def test_alpha_bar_boundaries(self): + """alpha_bar_0 ~ 1 and alpha_bar_T < alpha_bar_0.""" + ab = self.vp.alpha_bars.numpy() + np.testing.assert_allclose(ab[0], 1.0, 
atol=1e-4) + self.assertLess(ab[-1], ab[0]) + + def test_alpha_bars_monotonically_decrease(self): + ab = self.vp.alpha_bars.numpy() + self.assertTrue(np.all(np.diff(ab) <= 0)) + + def test_forward_shape(self): + h0 = paddle.randn([6, 16]) + t = paddle.randint(1, 101, [6]) + ht, eps = self.vp.forward(h0, t) + self.assertEqual(list(ht.shape), [6, 16]) + self.assertEqual(list(eps.shape), [6, 16]) + + def test_forward_t0_recovers_input(self): + """At t=0 (alpha_bar=1) the noisy signal equals the clean input.""" + h0 = paddle.randn([4, 8]) + t = paddle.zeros([4], dtype="int64") + ht, _ = self.vp.forward(h0, t) + np.testing.assert_allclose(ht.numpy(), h0.numpy(), atol=1e-5) + + def test_reverse_shape(self): + ht = paddle.randn([6, 16]) + eps = paddle.randn([6, 16]) + t = paddle.randint(1, 101, [6]) + out = self.vp.reverse(ht, eps, t) + self.assertEqual(list(out.shape), [6, 16]) + + +class TestSimpleGNNEncoder(unittest.TestCase): + def test_output_shape(self): + enc = SimpleGNNEncoder(input_dim=16, hidden_dim=32, output_dim=8, num_layers=2) + feats = paddle.randn([6, 16]) + batch_idx = paddle.to_tensor([0, 0, 0, 1, 1, 1], dtype="int64") + out = enc(feats, batch_idx, 2) + self.assertEqual(list(out.shape), [2, 8]) + + +class TestSimpleGNNDecoder(unittest.TestCase): + def test_output_shapes(self): + dec = SimpleGNNDecoder( + input_dim=32, hidden_dim=64, output_coord_dim=3, + output_type_dim=10, num_layers=2, + ) + feats = paddle.randn([6, 32]) + eps_x, eps_h = dec(feats) + self.assertEqual(list(eps_x.shape), [6, 3]) + self.assertEqual(list(eps_h.shape), [6, 10]) + + +class TestMOFDiff(unittest.TestCase): + """End-to-end MOFDiff model tests.""" + + def setUp(self): + paddle.seed(42) + self.model = MOFDiff( + node_feat_dim=32, + hidden_dim=64, + latent_dim=32, + num_bb_types=10, + max_num_bbs=5, + num_diffusion_steps=50, + fc_num_layers=2, + ) + + def test_forward_returns_loss_dict(self): + """forward() must return {'loss_dict': {...}} (PM convention).""" + batch = 
_make_batch(node_feat_dim=32, num_bb_types=10) + result = self.model(batch) + self.assertIn("loss_dict", result) + for key in ("loss", "loss_coord", "loss_type", "loss_kl", + "loss_lattice", "loss_num_bbs"): + self.assertIn(key, result["loss_dict"], f"Missing {key}") + self.assertFalse( + paddle.isnan(result["loss_dict"][key]).item(), + f"{key} is NaN", + ) + + def test_forward_loss_is_positive(self): + batch = _make_batch(node_feat_dim=32, num_bb_types=10) + result = self.model(batch) + self.assertGreater(result["loss_dict"]["loss"].item(), 0.0) + + def test_encode_shape(self): + feats = paddle.randn([6, 32]) + batch_idx = paddle.to_tensor([0, 0, 0, 1, 1, 1], dtype="int64") + mu, log_var, z = self.model.encode(feats, batch_idx, 2) + self.assertEqual(list(mu.shape), [2, 32]) + self.assertEqual(list(log_var.shape), [2, 32]) + self.assertEqual(list(z.shape), [2, 32]) + + def test_sample_output_keys(self): + """sample() must return predicted coords, types, and lattice.""" + self.model.eval() + z = paddle.randn([2, 32]) + num_atoms = paddle.to_tensor([3, 3], dtype="int64") + out = self.model.sample(z, num_atoms, num_steps=3) + self.assertIn("pred_coords", out) + self.assertIn("pred_types", out) + self.assertIn("pred_lattice", out) + self.assertEqual(list(out["pred_coords"].shape), [6, 3]) + self.assertEqual(list(out["pred_types"].shape), [6]) + self.assertEqual(list(out["pred_lattice"].shape), [2, 6]) + + def test_gradient_flow(self): + """All parameters receive gradients after a forward + backward.""" + batch = _make_batch(node_feat_dim=32, num_bb_types=10) + result = self.model(batch) + result["loss_dict"]["loss"].backward() + for name, param in self.model.named_parameters(): + if not param.stop_gradient: + self.assertIsNotNone( + param.grad, + f"No gradient for {name}", + ) + + def test_build_mlp(self): + mlp = build_mlp(16, 32, 3, 8) + out = mlp(paddle.randn([4, 16])) + self.assertEqual(list(out.shape), [4, 8]) + + +class TestVPMathProperties(unittest.TestCase): + 
"""Verify mathematical invariants of the VP schedule.""" + + def test_betas_non_negative(self): + vp = VP(num_steps=200) + self.assertTrue(np.all(vp.betas.numpy() >= 0)) + + def test_forward_variance_preservation(self): + """For large sample, Var(h_t) ~ 1 when h_0 ~ N(0,1).""" + vp = VP(num_steps=500) + paddle.seed(0) + h0 = paddle.randn([10000, 4]) + t = paddle.full([10000], 250, dtype="int64") + ht, _ = vp.forward(h0, t) + var = ht.numpy().var() + np.testing.assert_allclose(var, 1.0, atol=0.15) + + +if __name__ == "__main__": + unittest.main()