Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ppmat/models/liflow/ppmat/model/liflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .model import LiFlow
from .layers import FlowLayer, NormalizingFlow
from .utils import get_edge_vectors, atomic_number_to_index

__all__ = ['LiFlow', 'FlowLayer', 'NormalizingFlow', 'atomic_number_to_index']
60 changes: 60 additions & 0 deletions ppmat/models/liflow/ppmat/model/liflow/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .utils import radial_basis

class FlowLayer(nn.Layer):
"""Single flow layer"""
def __init__(self, hidden_dim, num_basis, r_max):
super().__init__()
self.hidden_dim = hidden_dim
self.num_basis = num_basis
self.r_max = r_max

# Radial basis function
self.rbf = nn.LayerList()
for _ in range(hidden_dim):
self.rbf.append(nn.Linear(num_basis, hidden_dim))

# Flow parameters
self.scale = nn.Linear(hidden_dim, hidden_dim)
self.shift = nn.Linear(hidden_dim, hidden_dim)
def forward(self, x, edge_indices, edge_distances):
# Get radial basis features
rbf_features = radial_basis(edge_distances, self.r_max, self.num_basis)

# Get source and target nodes
src, dst = edge_indices

# Message passing
x_src = x[src]
message = paddle.concat([x_src, rbf_features], axis=-1)

# Aggregate messages
aggregated = paddle.zeros_like(x)
aggregated = paddle.scatter(aggregated, dst.unsqueeze(-1), message, axis=0, reduce='sum')

# Flow transformation
scale = F.sigmoid(self.scale(aggregated))
shift = self.shift(aggregated)
x = x * scale + shift

return x

class NormalizingFlow(nn.Layer):
"""Normalizing flow module"""
def __init__(self, hidden_dim, num_layers, num_basis, r_max):
super().__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.num_basis = num_basis
self.r_max = r_max

# Flow layers
self.layers = nn.LayerList()
for _ in range(num_layers):
self.layers.append(FlowLayer(hidden_dim, num_basis, r_max))
def forward(self, x, edge_indices, edge_distances):
for layer in self.layers:
x = layer(x, edge_indices, edge_distances)
return x
69 changes: 69 additions & 0 deletions ppmat/models/liflow/ppmat/model/liflow/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .layers import NormalizingFlow
from .utils import get_edge_vectors, atomic_number_to_index

class LiFlow(nn.Layer):
"""LiFlow model implementation"""
def __init__(self,
hidden_dim=128,
num_layers=4,
num_basis=8,
r_max=5.0,
num_elements=100):
super().__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.num_basis = num_basis
self.r_max = r_max
self.num_elements = num_elements

# Embedding layer for atomic numbers
self.embedding = nn.Embedding(num_elements, hidden_dim)

# Normalizing flow
self.flow = NormalizingFlow(hidden_dim, num_layers, num_basis, r_max)

# Output layer for energy prediction
self.output = nn.Linear(hidden_dim, 1)
def forward(self, atomic_numbers, positions, cell=None, pbc=False):
# Convert atomic numbers to indices
atom_indices = atomic_number_to_index(atomic_numbers)
atom_indices = paddle.to_tensor(atom_indices, dtype='int64')

# Get edge vectors and distances
edge_vectors, edge_distances = get_edge_vectors(positions, cell, pbc)

# Create edge indices
num_atoms = positions.shape[0]
edge_indices = paddle.meshgrid(paddle.arange(num_atoms), paddle.arange(num_atoms))
edge_indices = paddle.stack(edge_indices, axis=0).reshape((2, -1))

# Remove self-edges
mask = edge_indices[0] != edge_indices[1]
edge_indices = edge_indices[:, mask]
edge_distances = edge_distances.reshape((-1,))[mask]

# Filter edges by distance
distance_mask = edge_distances < self.r_max
edge_indices = edge_indices[:, distance_mask]
edge_distances = edge_distances[distance_mask]

# Initialize node features
x = self.embedding(atom_indices)

# Pass through normalizing flow
x = self.flow(x, edge_indices, edge_distances)

# Aggregate node features
atomic_energies = self.output(x)
total_energy = atomic_energies.sum()

# Compute forces by differentiating energy with respect to positions
forces = -paddle.grad(total_energy, positions)[0]

return total_energy, forces
def predict(self, atomic_numbers, positions, cell=None, pbc=False):
"""Predict energy and forces"""
return self.forward(atomic_numbers, positions, cell, pbc)
28 changes: 28 additions & 0 deletions ppmat/models/liflow/ppmat/model/liflow/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import paddle
import numpy as np

def atomic_number_to_index(atomic_numbers):
"""Convert atomic numbers to indices"""
unique_atoms = sorted(list(set(atomic_numbers)))
atom_to_idx = {atom: i for i, atom in enumerate(unique_atoms)}
return [atom_to_idx[atom] for atom in atomic_numbers]

def get_edge_vectors(positions, cell=None, pbc=False):
"""Compute edge vectors between atoms"""
num_atoms = positions.shape[0]
# Compute pairwise differences
pos_diff = positions.unsqueeze(0) - positions.unsqueeze(1)
# Reshape to (num_atoms, num_atoms, 3)
edge_vectors = pos_diff.reshape((num_atoms, num_atoms, 3))
# Compute distances
edge_distances = paddle.norm(edge_vectors, axis=-1)
return edge_vectors, edge_distances

def radial_basis(r, r_max, num_basis):
"""Radial basis function"""
r = paddle.clip(r, 0, r_max)
scaled_r = 2 * r / r_max - 1
basis = []
for n in range(num_basis):
basis.append(paddle.cos(n * np.pi * scaled_r))
return paddle.stack(basis, axis=-1)
35 changes: 35 additions & 0 deletions ppmat/models/liflow/ppmat/predictor/liflow_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import paddle
from ppmat.model.liflow import LiFlow

class LiFlowPredictor:
"""Predictor for LiFlow model"""
def __init__(self, config):
self.config = config
self.model = LiFlow(
hidden_dim=config['hidden_dim'],
num_layers=config['num_layers'],
num_basis=config['num_basis'],
r_max=config['r_max'],
num_elements=config['num_elements']
)
def load_model(self, path):
"""Load model"""
state_dict = paddle.load(path)
self.model.set_state_dict(state_dict)
self.model.eval()
def predict(self, atomic_numbers, positions, cell=None, pbc=False):
"""Predict energy and forces"""
with paddle.no_grad():
energy, forces = self.model(atomic_numbers, positions, cell, pbc)
return energy.numpy(), forces.numpy()
def batch_predict(self, data_loader):
"""Batch predict"""
energies = []
forces = []
with paddle.no_grad():
for batch in data_loader:
atomic_numbers, positions, _, _ = batch
energy, force = self.model(atomic_numbers, positions)
energies.append(energy.numpy())
forces.append(force.numpy())
return energies, forces
60 changes: 60 additions & 0 deletions ppmat/models/liflow/ppmat/trainer/liflow_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import paddle
import paddle.nn as nn
import paddle.optimizer as optim
from ppmat.model.liflow import LiFlow

class LiFlowTrainer:
"""Trainer for LiFlow model"""
def __init__(self, config):
self.config = config
self.model = LiFlow(
hidden_dim=config['hidden_dim'],
num_layers=config['num_layers'],
num_basis=config['num_basis'],
r_max=config['r_max'],
num_elements=config['num_elements']
)
self.optimizer = optim.Adam(
parameters=self.model.parameters(),
learning_rate=config['learning_rate']
)
self.loss_fn = nn.MSELoss()
def train_step(self, batch):
"""Single training step"""
atomic_numbers, positions, energies, forces = batch

# Forward pass
pred_energy, pred_forces = self.model(atomic_numbers, positions)

# Compute loss
energy_loss = self.loss_fn(pred_energy, energies)
force_loss = self.loss_fn(pred_forces, forces)
total_loss = energy_loss + self.config['force_weight'] * force_loss

# Backward pass
total_loss.backward()
self.optimizer.step()
self.optimizer.clear_grad()

return total_loss.item()
def evaluate(self, data_loader):
"""Evaluate model"""
self.model.eval()
total_loss = 0
with paddle.no_grad():
for batch in data_loader:
atomic_numbers, positions, energies, forces = batch
pred_energy, pred_forces = self.model(atomic_numbers, positions)
energy_loss = self.loss_fn(pred_energy, energies)
force_loss = self.loss_fn(pred_forces, forces)
batch_loss = energy_loss + self.config['force_weight'] * force_loss
total_loss += batch_loss.item()
self.model.train()
return total_loss / len(data_loader)
def save_model(self, path):
"""Save model"""
paddle.save(self.model.state_dict(), path)
def load_model(self, path):
"""Load model"""
state_dict = paddle.load(path)
self.model.set_state_dict(state_dict)