Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ppmat/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,13 @@
from ppmat.models.dimenetpp.dimenetpp import DimeNetPlusPlus
from ppmat.models.mattergen.mattergen import MatterGen
from ppmat.models.mattergen.mattergen import MatterGenWithCondition
from ppmat.models.matinvent.mattergen_compat import MatinventMatterGen
from ppmat.models.mattersim.m3gnet import M3GNet
from ppmat.models.mattersim.m3gnet_graph_converter import M3GNetGraphConvertor
from ppmat.models.megnet.megnet import MEGNetPlus
from ppmat.models.infgcn.infgcn import InfGCN
from ppmat.models.mateno.mateno import MatENO
import ppmat.models.matinvent.rl_wrapper
from ppmat.utils import download
from ppmat.utils import logger
from ppmat.utils import save_load
Expand All @@ -54,6 +56,7 @@
"MEGNetPlus",
"MatterGen",
"MatterGenWithCondition",
"MatinventMatterGen",
"DimeNetPlusPlus",
"CrystalNN",
"CHGNetGraphConverter",
Expand Down
165 changes: 165 additions & 0 deletions ppmat/models/diffcsp/diffcsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ppmat.schedulers import build_scheduler
from ppmat.utils import paddle_aux # noqa
from ppmat.utils.crystal import lattice_params_to_matrix_paddle
from ppmat.utils.scatter import scatter


def p_wrapped_normal(x, sigma, N=10, T=1.0):
Expand Down Expand Up @@ -531,3 +532,167 @@ def sample(self, batch_data, num_inference_steps=1000, **kwargs):
start_idx += structure_array["num_atoms"][i]

return {"result": result}

def add_noise(self, batch, timestep: int):
    """Produce a noised copy of ``batch`` at one diffusion step.

    Used by the reinforcement-learning fine-tuning path. Every structure
    in the batch is noised at the same ``timestep``.

    Args:
        batch: Mapping with a ``"structure_array"`` entry that provides
            ``num_atoms``, ``frac_coords``, ``atom_types`` and either
            ``lattice`` or the ``lengths``/``angles`` pair.
        timestep: Diffusion step index applied to the whole batch.

    Returns:
        A 3-tuple ``(noisy_batch, batch, t)``:
            - noisy_batch: dict holding the noised ``structure_array``
              (noised ``frac_coords`` and ``lattice``; ``atom_types`` and
              ``num_atoms`` are passed through), plus ``batch_idx`` and
              ``timesteps``.
            - batch: the input batch, returned unchanged.
            - t: int64 tensor of shape ``[batch_size]`` filled with
              ``timestep``.
    """
    arrays = batch["structure_array"]
    atoms_per_structure = arrays["num_atoms"]
    n_structures = atoms_per_structure.shape[0]

    # Index of the owning structure for every atom in the flattened batch.
    atom_to_structure = paddle.repeat_interleave(
        paddle.arange(n_structures), repeats=atoms_per_structure
    )

    # One timestep per structure, then expanded to one entry per atom.
    t = paddle.full(shape=[n_structures], fill_value=timestep, dtype="int64")
    atom_timesteps = paddle.repeat_interleave(t, repeats=atoms_per_structure)

    # Lattice: use the stored matrix if present, otherwise build it from
    # lengths/angles. The lattice noise is drawn before the coordinate
    # noise so the random stream is consumed in a fixed order.
    if "lattice" not in arrays.keys():
        clean_lattice = lattice_params_to_matrix_paddle(
            arrays["lengths"], arrays["angles"]
        )
    else:
        clean_lattice = arrays["lattice"]
    lattice_noise = paddle.randn(
        shape=clean_lattice.shape, dtype=clean_lattice.dtype
    )
    noisy_lattice = self.lattice_scheduler.add_noise(
        clean_lattice, lattice_noise, timesteps=t
    )

    # Fractional coordinates: noise per atom, then wrap back into [0, 1).
    clean_coords = arrays["frac_coords"]
    coord_noise = paddle.randn(shape=clean_coords.shape, dtype=clean_coords.dtype)
    noisy_coords = (
        self.coord_scheduler.add_noise(
            clean_coords, coord_noise, timesteps=atom_timesteps
        )
        % 1.0
    )

    noisy_structure = {
        "frac_coords": noisy_coords,
        "lattice": noisy_lattice,
        "atom_types": arrays["atom_types"],
        "num_atoms": atoms_per_structure,
    }
    noisy_batch = {
        "structure_array": noisy_structure,
        "batch_idx": atom_to_structure,
        "timesteps": t,
    }

    # Tuple layout matches what calc_sample_loss unpacks.
    return noisy_batch, batch, t

def calc_sample_loss(self, noised_input):
    """Compute a per-structure loss for reinforcement-learning fine-tuning.

    Runs the decoder on the noised structures and scores its lattice and
    coordinate outputs against the clean batch with a mean squared error.

    Args:
        noised_input: 3-tuple ``(noisy_batch, clean_batch, t)``:
            - noisy_batch: dict with a noised ``structure_array`` and a
              ``batch_idx`` tensor.
            - clean_batch: original batch; its ``structure_array`` must
              provide ``frac_coords`` and either ``lattice`` or
              ``lengths``/``angles``.
            - t: timestep tensor fed to the time embedding.

    Returns:
        A 2-tuple ``(loss, prediction_dict)``:
            - loss: weighted sum of the lattice and coordinate MSE terms,
              one value per structure.
            - prediction_dict: ``{"coords": ..., "lattice": ...}`` holding
              the raw decoder outputs.
    """
    # NOTE(review): the decoder outputs are compared directly against the
    # clean lattice / fractional coordinates. Confirm this matches what the
    # decoder is trained to predict (clean structure vs. noise/score).
    noisy_batch, clean_batch, t = noised_input
    noisy_arrays = noisy_batch["structure_array"]
    atom_to_structure = noisy_batch["batch_idx"]

    # Decoder forward pass; atom types are shifted by one to index from zero.
    pred_l, pred_x = self.decoder(
        self.time_embedding(t),
        noisy_arrays["atom_types"] - 1,
        noisy_arrays["frac_coords"],
        noisy_arrays["lattice"],
        noisy_arrays["num_atoms"],
        atom_to_structure,
    )

    # Clean targets; the lattice matrix is rebuilt from lengths/angles
    # when it is not stored explicitly.
    clean_arrays = clean_batch["structure_array"]
    if "lattice" not in clean_arrays.keys():
        target_lattice = lattice_params_to_matrix_paddle(
            clean_arrays["lengths"], clean_arrays["angles"]
        )
    else:
        target_lattice = clean_arrays["lattice"]
    target_coords = clean_arrays["frac_coords"]

    # Lattice term: squared error averaged over both matrix axes.
    lattice_sq_err = paddle.pow(pred_l - target_lattice, 2)
    lattice_loss = lattice_sq_err.mean(axis=(1, 2))

    # Coordinate term: squared error averaged per atom, then averaged over
    # the atoms belonging to each structure.
    coord_sq_err = paddle.pow(pred_x - target_coords, 2)
    per_atom_loss = coord_sq_err.mean(axis=1)
    coord_loss = scatter(per_atom_loss, atom_to_structure, dim=0, reduce="mean")

    # Optional instance attributes override the default unit weights.
    w_lattice = getattr(self, "lattice_loss_weight", 1.0)
    w_coord = getattr(self, "coord_loss_weight", 1.0)
    total_loss = w_lattice * lattice_loss + w_coord * coord_loss

    # Raw predictions are returned so the caller can build the KL regularizer.
    return total_loss, {"coords": pred_x, "lattice": pred_l}

def calc_kl_reg(self, agent_pred, prior_pred, batch):
    """Compute the agent-vs-prior regularization term for RL fine-tuning.

    The term is the mean squared difference between the agent's and the
    prior's predictions (lattice plus coordinates), not a closed-form KL
    divergence. Prior predictions are detached so no gradient reaches the
    prior model. Per-atom values are reduced per structure with the
    ``scatter`` helper.

    Args:
        agent_pred: dict with ``"coords"`` and ``"lattice"`` predictions
            from the agent model.
        prior_pred: dict with the same keys from the prior model.
        batch: Mapping providing either ``"batch_idx"`` directly or a
            ``"structure_array"`` whose ``num_atoms`` is used to rebuild it.

    Returns:
        Tensor with one regularization value per structure.

    Raises:
        ValueError: If ``batch`` has neither ``"batch_idx"`` nor
            ``"structure_array"``.
    """
    agent_coords = agent_pred["coords"]
    agent_lattice = agent_pred["lattice"]
    # Detach the prior so it acts as a fixed reference.
    prior_coords = prior_pred["coords"].detach()
    prior_lattice = prior_pred["lattice"].detach()

    # Resolve the atom -> structure index used for aggregation.
    if "batch_idx" in batch:
        atom_to_structure = batch["batch_idx"]
    elif "structure_array" not in batch:
        raise ValueError("Cannot find batch index in batch input")
    else:
        atoms_per_structure = batch["structure_array"]["num_atoms"]
        n_structures = atoms_per_structure.shape[0]
        atom_to_structure = paddle.repeat_interleave(
            paddle.arange(n_structures), repeats=atoms_per_structure
        )

    # Lattice term: squared difference averaged over both matrix axes.
    lattice_term = paddle.pow(agent_lattice - prior_lattice, 2).mean(axis=(1, 2))

    # Coordinate term: per-atom mean squared difference, then averaged over
    # the atoms of each structure.
    per_atom_term = paddle.pow(agent_coords - prior_coords, 2).mean(axis=1)
    coord_term = scatter(per_atom_term, atom_to_structure, dim=0, reduce="mean")

    return lattice_term + coord_term
72 changes: 72 additions & 0 deletions ppmat/models/matinvent/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# MatInvent

MatInvent 是一个结合扩散模型与强化学习的材料生成框架,用于定向设计具有特定性质的晶体材料。

## 功能特点

- **扩散模型集成**:支持 MatterGen 和 DiffCSP 作为生成器 backbone
- **强化学习优化**:通过奖励函数引导生成过程,实现定向材料设计
- **多属性评估**:支持多种材料性质计算器,如密度、价格、丰度等
- **记忆管理**:包含经验回放缓冲区(replay buffer)和长期记忆(long-term memory)模块,提升训练效率

## 目录结构

```
matinvent/
├── rl/ # 强化学习核心代码
│ ├── models/ # 模型套件实现
│ ├── samplers.py # 采样器
│ └── mat_invent.py # MatInvent 主类
├── rewards/ # 奖励系统
│ ├── calculators/ # 性质计算器
│ └── reward.py # 奖励函数
├── memory/ # 记忆管理(回放缓冲区与长期记忆)
│ ├── replay_buffer.py # 回放缓冲区
│ └── ltm.py # 长期记忆
├── rl_train.py # 强化学习训练脚本
└── rl_wrapper.py # 包装器,用于集成到现有训练流程
```

## 核心组件

1. **MatInvent**:强化学习训练的核心类,实现了采样、评估、微调的完整循环
2. **奖励系统**:计算材料性质并转换为奖励信号
3. **模型套件**:封装了 MatterGen 和 DiffCSP 模型的加载和使用
4. **记忆模块**:由回放缓冲区和长期记忆组成,用于存储和管理生成的材料结构

## 使用方法

### 强化学习训练

```bash
# 使用统一入口执行 RL 训练
python structure_generation/train.py \
--config structure_generation/configs/matinvent/matinvent_mattergen.yaml
```

### 采样

```bash
# 使用训练好的模型进行采样
python structure_generation/sample.py \
--config_path structure_generation/configs/matinvent/matinvent_mattergen.yaml \
--checkpoint_path output/matinvent_mattergen/models/final/model.pdparams \
--save_path results/matinvent_mattergen \
--mode by_dataloader
```

## 配置文件

配置文件位于 `structure_generation/configs/matinvent/` 目录,包含以下主要部分:

- **Global**:全局设置
- **Model**:模型配置
- **RL**:强化学习参数
- **Sample**:采样设置

## 支持的性质计算器

- **PyMatGen**:密度、价格、丰度等基本性质
- **SynScore**:可合成性评分
- **FairChem**:基于 FairChem 的性质预测
- **DFT**:基于密度泛函理论的性质计算
50 changes: 50 additions & 0 deletions ppmat/models/matinvent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
MatInvent model module.

This module contains the MatInvent material generation model and its associated
components for reinforcement learning based material discovery.

Components:
- memory: Long-term memory and replay buffer for RL
- rewards: Reward system and property calculators
- common: Shared utilities that reuse ppmat built-in functionality

Note: RL compatibility for MatterGen is now provided through the
MatterGenRLAdapter class instead of monkey patching.
"""

from ppmat.models.matinvent.common import build_model_from_config
from ppmat.models.matinvent.common import load_config
from ppmat.models.matinvent.common import prepare_output_dir
from ppmat.models.matinvent.common import setup_matinvent_logging
from ppmat.models.matinvent.memory import LongTimeMem
from ppmat.models.matinvent.memory import ReplayBuffer
from ppmat.models.matinvent.rewards import Calculator
from ppmat.models.matinvent.rewards import Reward
from ppmat.models.matinvent.rl import MatInvent

# Public API of the matinvent package: each name below is imported at the
# top of this module from the common, memory, rewards and rl submodules.
__all__ = [
"MatInvent",
"setup_matinvent_logging",
"load_config",
"build_model_from_config",
"prepare_output_dir",
"ReplayBuffer",
"LongTimeMem",
"Reward",
"Calculator",
]
Loading