diff --git a/.gitignore b/.gitignore index 54b7888..089f18d 100644 --- a/.gitignore +++ b/.gitignore @@ -145,3 +145,4 @@ wandb/ *outputs* *cache* +tests/cross_repo/ diff --git a/README.md b/README.md index 0d3afb0..5d93ac3 100644 --- a/README.md +++ b/README.md @@ -75,13 +75,19 @@ This code only supports sampling structures of monomers. You can try to sample m ## Steering to avoid chain breaks and clashes -BioEmu includes a [steering system](https://arxiv.org/abs/2501.06848) that uses [Sequential Monte Carlo (SMC)](https://www.stats.ox.ac.uk/~doucet/doucet_defreitas_gordon_smcbookintro.pdf) to guide the diffusion process toward more physically plausible protein structures. +BioEmu includes a [steering system](https://arxiv.org/abs/2501.06848) that guides the diffusion process toward more physically plausible protein structures. +Steering applies potential energy functions during denoising to favor conformations that satisfy physical constraints. +Two steering algorithms are available: + +- **SMC (Sequential Monte Carlo)**: Simulates multiple *candidate samples* (particles) per desired output sample and resamples between them according to the favorability of the provided potentials. This is the default for physical steering. +- **FKC (Feynman–Kac Control)**: Uses importance weighting and may perform ESS-based resampling between particles; useful when targeting a specific collective variable value (e.g., RMSD to a reference). + Empirically, using three (or up to 10) steering particles per output sample greatly reduces the number of unphysical samples (steric clashes or chain breaks) produced by the model. -Steering applies potential energy functions during denoising to favor conformations that satisfy physical constraints. -Algorithmically, steering simulates multiple *candidate samples* per desired output sample and resamples between these particles according to the favorability of the provided potentials. 
### Quick start with steering +Steering is configured via a single YAML file passed as `denoiser_config`. This file specifies the denoiser, potentials, and steering parameters together. + Enable steering with physical constraints using the CLI: ```bash @@ -89,8 +95,7 @@ python -m bioemu.sample \ --sequence GYDPETGTWG \ --num_samples 100 \ --output_dir ~/steered-samples \ - --steering_config src/bioemu/config/steering/physical_steering.yaml \ - --denoiser_config src/bioemu/config/denoiser/stochastic_dpm.yaml + --denoiser_config src/bioemu/config/steering/physical_steering.yaml ``` Or using the Python API: @@ -102,26 +107,27 @@ sample( sequence='GYDPETGTWG', num_samples=100, output_dir='~/steered-samples', - denoiser_config="../src/bioemu/config/denoiser/stochastic_dpm.yaml", # Use stochastic DPM - steering_config="../src/bioemu/config/steering/physicality_steering.yaml", # Use physicality steering + denoiser_config="src/bioemu/config/steering/physical_steering.yaml", ) ``` ### Key steering parameters -- `num_steering_particles`: Number of particles per sample (1 = no steering, >1 enables steering) -- `steering_start_time`: When to start steering (0.0-1.0, default: 0.1) with reverse sampling 1 -> 0 -- `steering_end_time`: When to stop steering (0.0-1.0, default: 0.) 
with reverse sampling 1 -> 0 -- `resampling_interval`: How often to resample particles (default: 1) -- `steering_config`: Path to potentials configuration file (required for steering) +Inside the steering YAML config (e.g., [`physical_steering.yaml`](./src/bioemu/config/steering/physical_steering.yaml)), `fk_potentials` is a top-level key and the other parameters below are nested under `steering_config`: + +- `num_particles`: Number of particles per sample (higher = stronger steering, more compute) +- `ess_threshold`: Effective sample size threshold for resampling (0.0–1.0; set in both the SMC and FKC example configs) +- `start`: Diffusion time to start steering (0.0–1.0, default: 0.1; reverse process goes 1→0) +- `end`: Diffusion time to stop steering (0.0–1.0, default: 0.0) +- `fk_potentials`: List of potential energy functions to apply (Hydra-instantiated) ### Available potentials The [`physical_steering.yaml`](./src/bioemu/config/steering/physical_steering.yaml) configuration provides potentials for physical realism: -- **ChainBreak**: Prevents backbone discontinuities -- **ChainClash**: Avoids steric clashes between non-neighboring residues +- **CaCaDistance** + **UmbrellaPotential**: Prevents backbone discontinuities by penalizing large Cα–Cα distances +- **PairwiseClash** + **UmbrellaPotential**: Avoids steric clashes between non-neighboring residues -You can create a custom `steering_config.yaml` YAML file instantiating your own potential to steer the system with your own potentials. +For custom steering, you can write your own YAML config targeting any combination of potentials and collective variables. See [`cv_steer.yaml`](./src/bioemu/config/steering/cv_steer.yaml) for an example that steers toward a target RMSD value using FKC. ## Azure AI Foundry BioEmu is also available on [Azure AI Foundry](https://ai.azure.com/). See [How to run BioEmu on Azure AI Foundry](AZURE_AI_FOUNDRY.md) for more details.
diff --git a/notebooks/fkc_steering.py b/notebooks/fkc_steering.py new file mode 100644 index 0000000..3cc661f --- /dev/null +++ b/notebooks/fkc_steering.py @@ -0,0 +1,245 @@ +"""FKC steering toy example using the dpm_solver_fkc sampler. + +Demonstrates Feynman-Kac Controlled (FKC) steering on a 1D Gaussian Mixture Model. +A quadratic potential biases the GMM, and we compare steered samples to the +analytically computed ground truth distribution. +""" + +import logging +import warnings + +# Suppress expected warnings for 1D toy setup +logging.basicConfig(level=logging.ERROR) +warnings.filterwarnings("ignore", message=".*Bio.pairwise2.*") + +import matplotlib.pyplot as plt +import torch +import torch.nn as nn +from torch_geometric.data import Batch + +from bioemu.chemgraph import ChemGraph +from bioemu.sde_lib import CosineVPSDE +from bioemu.so3_sde import DiGSO3SDE +from bioemu.steering.dpm_fkc import dpm_solver_fkc +from bioemu.toy_gmm import TimeDependentGMM1D + +# ============================================================ +# 1. Setup: GMM target distribution + SDE scheduler +# ============================================================ + +sde = CosineVPSDE() + +gmm = TimeDependentGMM1D( + mu1=torch.tensor([-1.0]), + mu2=torch.tensor([2.0]), + sigma1=1, + sigma2=0.5, + weight1=0.9, + scheduler=sde, +) + + +# ============================================================ +# 2. Score model wrapper (makes GMM score compatible with get_score) +# ============================================================ + + +class GMMScoreWrapper(nn.Module): + """Wraps TimeDependentGMM1D to conform to the bioemu score model interface. + + ``get_score()`` divides the model output by ``pos_std``, so this wrapper + returns ``analytical_score × std`` for pos and zeros for SO3. 
POTENTIAL_K = 1.0
POTENTIAL_CENTER = 2.0


class QuadraticPotential:
    """Harmonic bias U(x) = k/2 * (x - center)² on the first coordinate.

    The callable flattens everything after the leading dimension of its input
    and evaluates the potential on the first flattened entry of each row.
    The ``t`` and ``sequence`` keyword arguments are accepted and ignored.
    """

    def __init__(self, k: float = POTENTIAL_K, center: float = POTENTIAL_CENTER):
        self.k = k
        self.center = center

    def __call__(self, ca_pos_nm: torch.Tensor, *, t=None, sequence=None) -> torch.Tensor:
        # One row per leading index; column 0 is the 1D coordinate of interest.
        rows = ca_pos_nm.reshape(ca_pos_nm.shape[0], -1)
        first_coord = rows[:, 0]
        half_k = 0.5 * self.k
        return half_k * (first_coord - self.center) ** 2
def make_toy_batch(n_samples: int) -> Batch:
    """Build a ``Batch`` of ``n_samples`` single-residue ``ChemGraph`` objects.

    Every graph is identical except for its ``system_id`` (``toy_0``,
    ``toy_1``, ...): one node at the origin with an identity orientation,
    no edges, zero-valued embeddings and the sequence ``"A"``.
    """

    def single_residue_graph(index: int) -> ChemGraph:
        return ChemGraph(
            pos=torch.zeros(1, 3),
            node_orientations=torch.eye(3).unsqueeze(0),
            edge_index=torch.zeros(2, 0, dtype=torch.long),
            single_embeds=torch.zeros(1, 1),
            pair_embeds=torch.zeros(1, 1),
            sequence="A",
            system_id=f"toy_{index}",
            node_labels=torch.zeros(1, dtype=torch.long),
        )

    return Batch.from_data_list([single_residue_graph(i) for i in range(n_samples)])
Run dpm_solver_fkc and plot +# ============================================================ + +N_SAMPLES = 20_000 +N_STEPS = 100 +NOISE_SCALE = 1.0 + +sdes = { + "pos": CosineVPSDE(), + "node_orientations": DiGSO3SDE(num_sigma=10, num_omega=10, l_max=10), +} + +score_model = GMMScoreWrapper(gmm, sde) +potential = QuadraticPotential(k=POTENTIAL_K, center=POTENTIAL_CENTER) +batch = make_toy_batch(N_SAMPLES) + +torch.manual_seed(42) +result_batch, log_weights = dpm_solver_fkc( + sdes=sdes, + batch=batch, + N=N_STEPS, + score_model=score_model, + max_t=0.99, + eps_t=0.01, + device=torch.device("cpu"), + fk_potentials=[potential], + steering_config={"num_particles": 100, "ess_threshold": 0.8, "start": 1.0, "end": 0.0}, + noise=NOISE_SCALE, + use_x0_for_reward=True, +) + +steered_samples = result_batch.pos[:, 0].detach().cpu() + +# --- Compute MAE between steered sample density and ground truth --- +import numpy as np + +# Build a density estimate on the same grid used for ground truth +bin_edges = np.linspace(x_grid[0].item(), x_grid[-1].item(), len(x_grid) + 1) +hist_counts, _ = np.histogram(steered_samples.numpy(), bins=bin_edges, density=True) +# hist_counts has len(x_grid) entries; align with grid centres +steered_density = torch.tensor(hist_counts, dtype=torch.float32) +mae = (steered_density - biased_pdf).abs().mean().item() +print(f"MAE between steered density and ground truth: {mae:.4f}") + +# --- Plot --- +fig, ax = plt.subplots(figsize=(10, 6)) +ax.hist( + steered_samples.numpy(), + bins=30, + density=True, + alpha=0.5, + color="steelblue", + label="FKC Steered Samples", +) +ax.plot(x_grid.numpy(), gmm_pdf.numpy(), "b--", linewidth=2, label="Unbiased GMM") +ax.plot( + x_grid.numpy(), + biased_pdf.numpy(), + "r-", + linewidth=2, + label=r"Ground Truth: GMM$(x) \times \exp(-U(x))$", +) + +# Potential on secondary y-axis +ax2 = ax.twinx() +U = 0.5 * POTENTIAL_K * (x_grid - POTENTIAL_CENTER) ** 2 +ax2.plot( + x_grid.numpy(), U.numpy(), "g-.", linewidth=1.5, 
alpha=0.7, label=r"$U(x)=\frac{k}{2}(x-c)^2$" +) +ax2.set_ylabel("Potential U(x)", color="green") +ax2.set_ylim(0, 50) +ax2.tick_params(axis="y", labelcolor="green") +ax2.legend(loc="upper left") +ax.set_xlabel("x") +ax.set_ylabel("Density") +ax.set_title( + f"FKC Steering with Quadratic Potential " + f"(k={POTENTIAL_K}, center={POTENTIAL_CENTER}, MAE={mae:.4f})" +) +ax.legend() +plt.tight_layout() +plt.savefig("fkc_steering_result.png", dpi=150) +plt.show() diff --git a/notebooks/gmm_umbrella_mbar.py b/notebooks/gmm_umbrella_mbar.py new file mode 100644 index 0000000..7dacc34 --- /dev/null +++ b/notebooks/gmm_umbrella_mbar.py @@ -0,0 +1,285 @@ +"""Multi-window FKC umbrella sampling + MBAR PMF reconstruction on a 1D GMM. + +Mirrors bioemu2/enhanced_sampling_paper/scripts/gmm/gmm_toy_example.py but +uses the current bioemu.steering stack directly. Runs a set of umbrella +windows with FKC steering, saves each window's samples, then combines them +with FastMBAR to recover the unbiased PMF. + +Run: + python notebooks/gmm_umbrella_mbar.py +Outputs saved next to this script: + gmm_mbar_windows.png — per-window histograms vs theoretical biased PDFs + gmm_mbar_pmf.png — MBAR PMF vs analytical PMF +""" + +import logging +import warnings +from pathlib import Path + +logging.basicConfig(level=logging.ERROR) +warnings.filterwarnings("ignore") + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +try: + from FastMBAR import FastMBAR +except ImportError: + raise ImportError( + "FastMBAR is required for this script. Install with `pip install FastMBAR`." 
SLOPE = 0.5  # matches bioemu2 toy
ORDER = 2.0


class UmbrellaPotential1D:
    """V(x) = (slope * (x - target)) ** order. Matches bioemu2 CVUmbrellaPotential.

    ``__call__`` evaluates the potential on the first flattened coordinate of
    each row of a torch tensor; ``energy_np`` applies the same formula
    element-wise to a NumPy array.
    """

    def __init__(self, target: float, slope: float = SLOPE, order: float = ORDER):
        self.target = target
        self.slope = slope
        self.order = order

    def __call__(self, ca_pos_nm, *, t=None, sequence=None):
        # ``t`` and ``sequence`` are accepted for interface compatibility and unused.
        rows = ca_pos_nm.reshape(ca_pos_nm.shape[0], -1)
        scaled_offset = self.slope * (rows[:, 0] - self.target)
        return scaled_offset.pow(self.order)

    def energy_np(self, x: np.ndarray) -> np.ndarray:
        scaled_offset = self.slope * (x - self.target)
        return scaled_offset**self.order
def main() -> None:
    """Run every umbrella window, plot the samples, and rebuild the PMF with MBAR.

    Steps:
      1. Sample each umbrella window with ``run_window`` (10 centres in [-5, 5]).
      2. Plot per-window histograms against the analytically biased GMM density
         and save ``gmm_mbar_windows.png`` next to this script.
      3. Combine all samples with FastMBAR, estimate the PMF on a fixed set of
         bins, compare it with ``-log p(x)`` of the unbiased GMM, and save
         ``gmm_mbar_pmf.png`` next to this script.

    Raises:
        RuntimeError: if MBAR yields no finite free energy for any PMF bin.
    """
    out_dir = Path(__file__).parent
    centers = np.linspace(-5.0, 5.0, 10)

    # ``np.trapezoid`` only exists in NumPy >= 2.0; older releases call it ``np.trapz``.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    window_samples: dict[float, np.ndarray] = {}
    potentials: list[UmbrellaPotential1D] = []
    print(f"Running {len(centers)} umbrella windows...")
    for c in centers:
        print(f" window center = {c:+.2f}")
        window_samples[float(c)] = run_window(target=float(c))
        potentials.append(UmbrellaPotential1D(target=float(c)))

    # ---------------- Plot per-window histograms vs theoretical ------- #
    x_grid = np.linspace(-6, 8, 2000)

    # Unbiased GMM PDF (analytical)
    def gmm_pdf(x):
        p1 = np.exp(-0.5 * ((x - MU1) / SIGMA1) ** 2) / (SIGMA1 * np.sqrt(2 * np.pi))
        p2 = np.exp(-0.5 * ((x - MU2) / SIGMA2) ** 2) / (SIGMA2 * np.sqrt(2 * np.pi))
        return WEIGHT1 * p1 + (1 - WEIGHT1) * p2

    p_unb = gmm_pdf(x_grid)

    n_cols = 5
    n_rows = (len(centers) + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    axes = axes.flatten()
    # Reuse the potentials built during sampling instead of re-instantiating them.
    for idx, (c, pot) in enumerate(zip(centers, potentials)):
        ax = axes[idx]
        V = pot.energy_np(x_grid)
        biased = p_unb * np.exp(-V)
        biased /= trapezoid(biased, x_grid)

        samples = window_samples[float(c)]
        ax.hist(samples, bins=50, density=True, alpha=0.5, color="steelblue", label="FKC")
        ax.plot(x_grid, biased, "r-", lw=2, label="theory")
        ax.set_title(f"center = {c:+.2f}")
        ax.set_xlim(-6, 8)
        if idx == 0:
            ax.legend()
    # Hide any unused axes in the subplot grid.
    for idx in range(len(centers), len(axes)):
        axes[idx].axis("off")
    fig.suptitle("Per-window FKC samples vs theoretical biased PDF")
    fig.tight_layout()
    fig.savefig(out_dir / "gmm_mbar_windows.png", dpi=150)
    print(f"Saved {out_dir / 'gmm_mbar_windows.png'}")

    # ---------------- MBAR PMF reconstruction ------------------------- #
    # Build reduced energy matrix u_kn = beta * V_k(x_n). beta = 1.
    all_samples = np.concatenate([window_samples[float(c)] for c in centers])
    num_conf = np.array([len(window_samples[float(c)]) for c in centers], dtype=np.int64)
    K = len(centers)
    N_total = len(all_samples)
    u_kn = np.zeros((K, N_total), dtype=np.float64)
    for k, pot in enumerate(potentials):
        u_kn[k] = pot.energy_np(all_samples)

    print("Solving MBAR...")
    fastmbar = FastMBAR(energy=u_kn, num_conf=num_conf, cuda=False, verbose=False)

    # PMF via bin-wise dummy states (standard FastMBAR PMF pattern).
    pmf_bins = np.linspace(-5.5, 5.5, 50)
    bin_centers = 0.5 * (pmf_bins[:-1] + pmf_bins[1:])
    n_pmf_bins = len(bin_centers)

    # u_pmf[b, n] = 0 if x_n in bin b else +inf
    u_pmf = np.full((n_pmf_bins, N_total), np.inf, dtype=np.float64)
    for b in range(n_pmf_bins):
        in_bin = (all_samples >= pmf_bins[b]) & (all_samples < pmf_bins[b + 1])
        u_pmf[b, in_bin] = 0.0

    pmf_result = fastmbar.calculate_free_energies_of_perturbed_states(u_pmf)
    F_pmf = np.asarray(pmf_result["F"])
    dF_pmf = np.asarray(pmf_result["F_std"])
    # Drop empty bins
    finite = np.isfinite(F_pmf)
    if not finite.any():
        # Fail with a readable message instead of an empty-array reduction error below.
        raise RuntimeError("MBAR produced no finite PMF bins; cannot plot or score the PMF.")

    # Analytical PMF from GMM: -log p(x)
    analytical_pmf = -np.log(np.clip(p_unb, 1e-30, None))
    # Align reference: shift analytical PMF to zero at its min over the plotted range
    mask_range = (x_grid >= bin_centers[finite].min()) & (x_grid <= bin_centers[finite].max())
    analytical_pmf_shifted = analytical_pmf - analytical_pmf[mask_range].min()
    F_pmf_shifted = F_pmf - F_pmf[finite].min()

    fig2, ax2 = plt.subplots(figsize=(8, 5))
    ax2.errorbar(
        bin_centers[finite],
        F_pmf_shifted[finite],
        yerr=dF_pmf[finite],
        fmt="o-",
        color="tab:red",
        label="MBAR (FKC + umbrella windows)",
        capsize=3,
    )
    ax2.plot(x_grid, analytical_pmf_shifted, "k--", lw=2, label="Analytical: −log p(x)")
    ax2.set_xlabel("x")
    ax2.set_ylabel("Free energy (−log p)")
    ax2.set_title(f"MBAR PMF from {K} FKC umbrella windows (GMM toy)")
    ax2.set_xlim(-6, 6)
    ax2.set_ylim(0, 10)
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    fig2.tight_layout()
    fig2.savefig(out_dir / "gmm_mbar_pmf.png", dpi=150)
    print(f"Saved {out_dir / 'gmm_mbar_pmf.png'}")

    # Report MAE on the valid bins
    F_interp = np.interp(bin_centers[finite], x_grid, analytical_pmf_shifted)
    mae = np.mean(np.abs(F_pmf_shifted[finite] - F_interp))
    print(f"MBAR PMF vs analytical MAE (on valid bins): {mae:.4f}")


if __name__ == "__main__":
    main()
@dataclass
class DPMCoefficients:
    """Coefficients for DPM-Solver++ step.

    Bundles the per-step quantities that ``_get_dpm_coefficients`` computes for
    a step from time ``t`` to ``t_next``, with ``lambda = log(alpha / sigma)``.
    """

    # Pair returned by ``pos_sde.mean_coeff_and_std`` at the current time t.
    alpha_t: torch.Tensor
    sigma_t: torch.Tensor
    # Same pair evaluated at the target time t_next.
    alpha_t_next: torch.Tensor
    sigma_t_next: torch.Tensor
    # Same pair evaluated at the midpoint time t_lambda.
    alpha_t_lambda: torch.Tensor
    sigma_t_lambda: torch.Tensor
    h_t: torch.Tensor  # lambda_t_next - lambda_t
    t_lambda: torch.Tensor  # midpoint time: the time whose lambda is the mean of lambda_t and lambda_t_next
def _get_dpm_coefficients(
    pos_sde: CosineVPSDE,
    batch_pos: torch.Tensor,
    t: torch.Tensor,
    t_next: torch.Tensor,
    batch_idx: torch.LongTensor,
) -> DPMCoefficients:
    """Gather the DPM-Solver coefficients needed for one step from ``t`` to ``t_next``.

    ``alpha`` and ``sigma`` come from ``pos_sde.mean_coeff_and_std`` and
    ``lambda`` is ``log(alpha / sigma)``. The midpoint time ``t_lambda`` is the
    time whose lambda is the average of the lambdas at ``t`` and ``t_next``.
    """

    def coeff_and_std(time: torch.Tensor):
        return pos_sde.mean_coeff_and_std(x=batch_pos, t=time, batch_idx=batch_idx)

    alpha_now, sigma_now = coeff_and_std(t)
    lam_now = torch.log(alpha_now / sigma_now)

    alpha_target, sigma_target = coeff_and_std(t_next)
    lam_target = torch.log(alpha_target / sigma_target)

    # Midpoint in lambda space, mapped back to a time value.
    lam_mid = (lam_now + lam_target) / 2
    midpoint_time = _t_from_lambda(sde=pos_sde, lambda_t=lam_mid)
    # Only entry [0][0] is kept and broadcast to the shape of ``t``.
    # NOTE(review): this looks like it assumes all graphs share one time value — confirm.
    midpoint_time = torch.full_like(t, midpoint_time[0][0])

    alpha_mid, sigma_mid = coeff_and_std(midpoint_time)

    return DPMCoefficients(
        alpha_t=alpha_now,
        sigma_t=sigma_now,
        alpha_t_next=alpha_target,
        sigma_t_next=sigma_target,
        alpha_t_lambda=alpha_mid,
        sigma_t_lambda=sigma_mid,
        h_t=lam_target - lam_now,
        t_lambda=midpoint_time,
    )
def _predict_midpoint(
    batch: Batch,
    coeffs: DPMCoefficients,
    score_pos: torch.Tensor,
    score_so3: torch.Tensor,
    so3_sde: SO3SDE,
    t: torch.Tensor,
    batch_idx: torch.LongTensor,
) -> Batch:
    """Advance ``batch`` from time ``t`` to the midpoint time ``coeffs.t_lambda``.

    Both ``pos`` and ``node_orientations`` are moved to the intermediate time
    with a first-order update. The orientation step is called with
    ``noise_weight=0.0`` and the position update has no noise term.

    Args:
        score_pos: Position score to use (can be guided or scaled unsteered).
        score_so3: SO3 score to use.

    Returns:
        A copy of ``batch`` with ``pos`` and ``node_orientations`` replaced.
    """
    # Orientations: drift-only step over the interval t -> t_lambda.
    step_to_midpoint = coeffs.t_lambda[0] - t[0]
    orientations_mid, _ = _so3_step(
        so3_sde=so3_sde,
        node_orientations=batch.node_orientations,
        score=score_so3,
        t=t,
        dt=step_to_midpoint,
        batch_idx=batch_idx,
        noise_weight=0.0,
    )

    # Positions: DPM-Solver first-order update to the midpoint.
    carried_pos = coeffs.alpha_t_lambda / coeffs.alpha_t * batch.pos
    score_gain = coeffs.sigma_t_lambda * coeffs.sigma_t * (torch.exp(coeffs.h_t / 2) - 1)
    pos_mid = carried_pos + score_gain * score_pos

    return batch.replace(pos=pos_mid, node_orientations=orientations_mid)
def second_order_step_dpmsolver(
    batch: Batch,
    coeffs: DPMCoefficients,
    score_pos_lambda: torch.Tensor,
    score_so3_t: torch.Tensor,
    score_so3_lambda: torch.Tensor,
    so3_sde: SO3SDE,
    t: torch.Tensor,
    t_next: torch.Tensor,
    batch_idx: torch.LongTensor,
) -> Batch:
    """Take one noise-free DPM-Solver 2nd-order ODE step from ``t`` to ``t_next``.

    Implements the DPM-Solver-2 update (Algorithm 1 in https://arxiv.org/abs/2206.00927):
    the position update uses only the score evaluated at the midpoint
    ``t_lambda``. Used by the unsteered `dpm_solver` loop.

    The DPM-Solver++ SDE variant (two scores + noise) is
    `second_order_step_dpmsolver_plusplus`.

    Args:
        score_pos_lambda: Position score at midpoint t_lambda (unscaled).
        score_so3_t: SO3 score at time t.
        score_so3_lambda: SO3 score at midpoint t_lambda.

    Returns:
        A copy of ``batch`` with ``pos`` and ``node_orientations`` replaced.
    """
    # Position: midpoint-only formula.
    decay = coeffs.alpha_t_next / coeffs.alpha_t
    midpoint_gain = coeffs.sigma_t_next * coeffs.sigma_t_lambda * (torch.exp(coeffs.h_t) - 1)
    updated_pos = decay * batch.pos + midpoint_gain * score_pos_lambda

    # SO3: extrapolate the midpoint score using the change between t and t_lambda.
    full_step = t_next[0] - t[0]
    half_step = coeffs.t_lambda[0] - t[0]
    score_delta = score_so3_lambda - score_so3_t
    extrapolated_score = score_so3_lambda + 0.5 * score_delta / half_step * full_step

    # Deterministic orientation update (noise_weight=0.0).
    updated_orientations, _ = _so3_step(
        so3_sde=so3_sde,
        node_orientations=batch.node_orientations,
        score=extrapolated_score,
        t=coeffs.t_lambda,
        dt=full_step,
        batch_idx=batch_idx,
        noise_weight=0.0,
    )

    return batch.replace(pos=updated_pos, node_orientations=updated_orientations)
""" grad_is_enabled = torch.is_grad_enabled() assert isinstance(batch, Batch) assert max_t < 1.0 - if steering_config is not None: - assert noise > 0, "Steering requires noise > 0 for stochastic sampling" batch = batch.to(device) if isinstance(score_model, torch.nn.Module): - # permits unit-testing with dummy model score_model = score_model.to(device) pos_sde = sdes["pos"] assert isinstance(pos_sde, CosineVPSDE) @@ -306,13 +565,13 @@ def dpm_solver( batch.node_orientations.shape, device=device ), ) - batch = cast(ChemGraph, batch) # help out mypy/linter + batch = cast(ChemGraph, batch) so3_sde = sdes["node_orientations"] assert isinstance(so3_sde, SO3SDE) so3_sde.to(device) - timesteps = torch.linspace(max_t, eps_t, N, device=device) # 1 -> 0 + timesteps = torch.linspace(max_t, eps_t, N, device=device) dt = -torch.tensor((max_t - eps_t) / (N - 1)).to(device) ts_min = 0.0 ts_max = 1.0 @@ -323,16 +582,13 @@ def dpm_solver( ) for name, sde in sdes.items() } - previous_energy = None - - # Initialize log_weights for importance weight tracking (for gradient guidance) - log_weights = torch.zeros(batch.num_graphs, device=device) for i in tqdm(range(N - 1), position=1, desc="Denoising: ", ncols=0, leave=False): t = torch.full((batch.num_graphs,), timesteps[i], device=device) t_hat = t - noise * dt if (i > 0 and t[0] > ts_min and t[0] < ts_max) else t + t_next = t + dt - # Apply noise. 
+ # Pre-step noise injection (Karras/Heun style) vals_hat = {} for field in fields: vals_hat[field] = noisers[field].forward_sde_step( @@ -340,141 +596,42 @@ def dpm_solver( )[0] batch_hat = batch.replace(**vals_hat) - # Evaluate score + # Evaluate score at (possibly noised) state with torch.set_grad_enabled(grad_is_enabled and (i in record_grad_steps)): score = get_score(batch=batch_hat, t=t_hat, score_model=score_model, sdes=sdes) - # t_{i-1} in the algorithm is the current t batch_idx = batch_hat.batch - alpha_t, sigma_t = pos_sde.mean_coeff_and_std(x=batch.pos, t=t_hat, batch_idx=batch_idx) - lambda_t = torch.log(alpha_t / sigma_t) - alpha_t_next, sigma_t_next = pos_sde.mean_coeff_and_std( - x=batch.pos, t=t + dt, batch_idx=batch_idx - ) - lambda_t_next = torch.log(alpha_t_next / sigma_t_next) - - # t+dt < t_hat, lambad_t_next > lambda_t - h_t = lambda_t_next - lambda_t - # For a given noise schedule (cosine is what we use), compute the intermediate t_lambda - lambda_t_middle = (lambda_t + lambda_t_next) / 2 - t_lambda = _t_from_lambda(sde=pos_sde, lambda_t=lambda_t_middle) + # Coefficients + midpoint prediction + coeffs = _get_dpm_coefficients(pos_sde, batch_hat.pos, t_hat, t_next, batch_idx) - # t_lambda has all the same components - t_lambda = torch.full((batch.num_graphs,), t_lambda[0][0], device=device) - - alpha_t_lambda, sigma_t_lambda = pos_sde.mean_coeff_and_std( - x=batch.pos, t=t_lambda, batch_idx=batch_idx - ) - # Note in the paper the algorithm uses noise instead of score, but we use score. - # So the formulation is slightly different in the prefactor. 
- u = ( - alpha_t_lambda / alpha_t * batch_hat.pos - + sigma_t_lambda * sigma_t * (torch.exp(h_t / 2) - 1) * score["pos"] - ) - - # Update positions to the intermediate timestep t_lambda - batch_u = batch.replace(pos=u) - # Denoise from t to t_lambda - assert score["node_orientations"].shape == (u.shape[0], 3) - assert batch.node_orientations.shape == (u.shape[0], 3, 3) - so3_predictor = EulerMaruyamaPredictor( - corruption=so3_sde, noise_weight=0.0, marginal_concentration_factor=1.0 - ) - drift, _ = so3_predictor.reverse_drift_and_diffusion( - x=batch_hat.node_orientations, - score=score["node_orientations"], + batch_lambda = _predict_midpoint( + batch=batch_hat, + coeffs=coeffs, + score_pos=score["pos"], + score_so3=score["node_orientations"], + so3_sde=so3_sde, t=t_hat, batch_idx=batch_idx, ) - sample, _ = so3_predictor.update_given_drift_and_diffusion( - x=batch_hat.node_orientations, - drift=drift, - diffusion=0.0, - dt=t_lambda[0] - t_hat[0], - ) # dt is negative, diffusion is 0 - assert sample.shape == (u.shape[0], 3, 3) - batch_u = batch_u.replace(node_orientations=sample) - - # Correction step - # Evaluate score at updated pos and node orientations + + # Correction step: evaluate score at midpoint with torch.set_grad_enabled(grad_is_enabled and (i in record_grad_steps)): - score_u = get_score(batch=batch_u, t=t_lambda, sdes=sdes, score_model=score_model) + score_lambda = get_score( + batch=batch_lambda, t=coeffs.t_lambda, sdes=sdes, score_model=score_model + ) - pos_next = ( - alpha_t_next / alpha_t * batch_hat.pos - + sigma_t_next * sigma_t_lambda * (torch.exp(h_t) - 1) * score_u["pos"] - ) - batch_next = batch.replace(pos=pos_next) - - assert score_u["node_orientations"].shape == (u.shape[0], 3) - - # Try a 2nd order correction - dt_hat = t + dt - t_hat - node_score = ( - score_u["node_orientations"] - + 0.5 - * (score_u["node_orientations"] - score["node_orientations"]) - / (t_lambda[0] - t_hat[0]) - * dt_hat[0] - ) - drift, diffusion = 
so3_predictor.reverse_drift_and_diffusion( - x=batch_u.node_orientations, - score=node_score, - t=t_lambda, + # DPM-Solver 2nd-order ODE step (midpoint-only formula) + batch = second_order_step_dpmsolver( + batch=batch_hat, + coeffs=coeffs, + score_pos_lambda=score_lambda["pos"], + score_so3_t=score["node_orientations"], + score_so3_lambda=score_lambda["node_orientations"], + so3_sde=so3_sde, + t=t_hat, + t_next=t_next, batch_idx=batch_idx, ) - sample, _ = so3_predictor.update_given_drift_and_diffusion( - x=batch_hat.node_orientations, - drift=drift, - diffusion=0.0, - dt=dt_hat[0], - ) # dt is negative, diffusion is 0 - batch = batch_next.replace(node_orientations=sample) - - if ( - steering_config is not None and fk_potentials is not None - ): # steering enabled when steering_config is provided - # Compute predicted x0 and R0 from current state and score - # x0_t: predicted positions, shape (batch_size, seq_length, 3), differs from batch.pos which is (batch_size * seq_length, 3) - # R0_t: predicted rotations, shape (batch_size, seq_length, 3, 3) - denoised_x0_t, denoised_R0_t = get_pos0_rot0( - sdes=sdes, batch=batch, t=t, score=score - ) # batch -> x0_t:(batch_size, seq_length, 3), R0_t:(batch_size, seq_length, 3, 3) - - energies = [] - for potential_ in fk_potentials: - energies += [potential_(10 * denoised_x0_t, i=i, N=N)] - total_energy = torch.stack(energies, dim=-1).sum(-1) # [BS] - - if steering_config["num_particles"] > 1: - # Resample between particles ... - if ( - steering_config["start"] >= timesteps[i] >= steering_config["end"] - and i % steering_config["resampling_interval"] == 0 - and i < N - 2 - ): - batch, total_energy, log_weights = resample_batch( - batch=batch, - num_particles=steering_config["num_particles"], - energy=total_energy, - previous_energy=previous_energy, - log_weights=log_weights, - ) - previous_energy = total_energy - - # ... 
or a single final sample - elif i >= N - 2: # The last step is N-2 - logger.info( - "Final Resampling [BS, FK_particles] back to BS, with real x0 instead of pred x0." - ) - batch, total_energy, log_weights = resample_batch( - batch=batch, - num_particles=steering_config["num_particles"], - energy=total_energy, - previous_energy=previous_energy, - log_weights=log_weights, - ) - previous_energy = total_energy return batch diff --git a/src/bioemu/sample.py b/src/bioemu/sample.py index c4283e9..f17e86f 100644 --- a/src/bioemu/sample.py +++ b/src/bioemu/sample.py @@ -13,7 +13,7 @@ import numpy as np import torch import yaml -from omegaconf import DictConfig, OmegaConf +from omegaconf import DictConfig from torch_geometric.data.batch import Batch from tqdm import tqdm @@ -23,7 +23,6 @@ from bioemu.model_utils import load_model, load_sdes, maybe_download_checkpoint from bioemu.sde_lib import SDE from bioemu.seq_io import check_protein_valid, parse_sequence, write_fasta -from bioemu.steering import log_physicality from bioemu.utils import ( count_samples_in_output_dir, format_npz_samples_filename, @@ -33,7 +32,6 @@ logger = logging.getLogger(__name__) DEFAULT_DENOISER_CONFIG_DIR = Path(__file__).parent / "config/denoiser/" -DEFAULT_STEERING_CONFIG_DIR = Path(__file__).parent / "config/steering/" SupportedDenoisersLiteral = Literal["dpm", "heun"] SUPPORTED_DENOISERS = list(typing.get_args(SupportedDenoisersLiteral)) @@ -83,7 +81,6 @@ def main( cache_so3_dir: str | Path | None = None, msa_host_url: str | None = None, filter_samples: bool = True, - steering_config: str | Path | dict | None = None, base_seed: int | None = None, ) -> None: """ @@ -104,19 +101,14 @@ def main( ckpt_path: Path to the model checkpoint. If this is set, `model_name` will be ignored. model_config_path: Path to the model config, defining score model architecture and the corruption process the model was trained with. Only required if `ckpt_path` is set. 
- denoiser_type: Denoiser to use for sampling, if `denoiser_config_path` not specified. Comes in with default parameter configuration. Must be one of ['dpm', 'heun'] - denoiser_config_path: Path to the denoiser config, defining the denoising process. + denoiser_type: Denoiser to use for sampling, if `denoiser_config` not specified. Comes in with default parameter configuration. Must be one of ['dpm', 'heun'] + denoiser_config: Path (str or :class:`os.PathLike`) to a denoiser config YAML, or a dict. For steered sampling (FKC/SMC), + pass a steering config (e.g., config/steering/physical_steering.yaml) which includes + the denoiser target, potentials, and steering parameters in one file. cache_embeds_dir: Directory to store MSA embeddings. If not set, this defaults to `COLABFOLD_DIR/embeds_cache`. cache_so3_dir: Directory to store SO3 precomputations. If not set, this defaults to `~/sampling_so3_cache`. msa_host_url: MSA server URL. If not set, this defaults to colabfold's remote server. If sequence is an a3m file, this is ignored. filter_samples: Filter out unphysical samples with e.g. long bond distances or steric clashes. - steering_config: Path to steering config YAML, or a dict containing steering parameters. - Can be None to disable steering (num_particles=1). The config should contain: - - num_particles: Number of particles per sample (>1 enables steering) - - start: Start time for steering (0.0-1.0) - - end: End time for steering (0.0-1.0) - - resampling_interval: Resampling interval - - potentials: Dict of potential configurations base_seed: Base random seed for sampling. If set, each batch's seed will be set to base_seed + (num samples already generated). 
""" @@ -127,58 +119,6 @@ def main( output_dir = Path(output_dir).expanduser().resolve() output_dir.mkdir(parents=True, exist_ok=True) # Fail fast if output_dir is non-writeable - # Steering config can be [None, [str/Path], [dict/DictConfig]] - if steering_config is None: - # No steering - will pass None to denoiser - steering_config_dict = None - potentials = None - elif isinstance(steering_config, str | Path): - # Path to steering config YAML - steering_config_path = Path(steering_config).expanduser().resolve() - if not steering_config_path.is_absolute(): - # Try relative to DEFAULT_STEERING_CONFIG_DIR - steering_config_path = DEFAULT_STEERING_CONFIG_DIR / steering_config - - assert ( - steering_config_path.is_file() - ), f"steering_config path '{steering_config_path}' does not exist or is not a file." - - with open(steering_config_path) as f: - steering_config_dict = yaml.safe_load(f) - elif isinstance(steering_config, dict | DictConfig): - # Already a dict/DictConfig - steering_config_dict = ( - OmegaConf.to_container(steering_config, resolve=True) - if isinstance(steering_config, DictConfig) - else steering_config - ) - else: - raise ValueError( - f"steering_config must be None, a path to a YAML file, or a dict, but got {type(steering_config)}" - ) - - if steering_config_dict is not None: - # If steering is enabled by defining a minimum of two particles, extract potentials and create config - - # Extract potentials configuration - potentials_config = steering_config_dict["potentials"] - - # Instantiate potentials - potentials = hydra.utils.instantiate(OmegaConf.create(potentials_config)) - potentials: list[Callable] = list(potentials.values()) # type: ignore - - # Create final steering config (without potentials, those are passed separately) - # Remove 'potentials' from steering_config_dict if present - steering_config_dict = dict(steering_config_dict) # ensure mutable copy - steering_config_dict.pop("potentials") - # Validate steering times for reverse 
diffusion start: t=1 to end: t=0 - assert ( - 0.0 <= steering_config_dict["end"] <= steering_config_dict["start"] <= 1.0 - ), f"Steering end ({steering_config_dict['end']}) must be between 0.0 and 1.0 and start ({steering_config_dict['start']}) must be between 0.0 and 1.0" - - else: - potentials = None - ckpt_path, model_config_path = maybe_download_checkpoint( model_name=model_name, ckpt_path=ckpt_path, model_config_path=model_config_path ) @@ -216,7 +156,7 @@ def main( denoiser_config = DEFAULT_DENOISER_CONFIG_DIR / f"{denoiser_type}.yaml" with open(denoiser_config) as f: denoiser_config = yaml.safe_load(f) - elif type(denoiser_config) is str: + elif isinstance(denoiser_config, str | Path): # path to denoiser config denoiser_config_path = Path(denoiser_config).expanduser().resolve() assert ( @@ -265,8 +205,6 @@ def main( cache_embeds_dir=cache_embeds_dir, msa_file=msa_file, msa_host_url=msa_host_url, - fk_potentials=potentials, - steering_config=steering_config_dict, ) batch = {k: v.cpu().numpy() for k, v in batch.items()} @@ -281,7 +219,6 @@ def main( node_orientations = torch.tensor( np.concatenate([np.load(f)["node_orientations"] for f in samples_files]) ) - log_physicality(positions, node_orientations, sequence) save_pdb_and_xtc( pos_nm=positions, node_orientations=node_orientations, @@ -350,8 +287,6 @@ def generate_batch( cache_embeds_dir: str | Path | None, msa_file: str | Path | None = None, msa_host_url: str | None = None, - fk_potentials: list[Callable] | None = None, - steering_config: dict | None = None, ) -> dict[str, torch.Tensor]: """Generate one batch of samples, using GPU if available. @@ -359,9 +294,11 @@ def generate_batch( score_model: Score model. sequence: Amino acid sequence. sdes: SDEs defining corruption process. Keys should be 'node_orientations' and 'pos'. - embeddings_file: Path to embeddings file. batch_size: Batch size. seed: Random seed. + denoiser: Denoiser callable (already configured via Hydra). 
For steered + sampling, this is a partial of dpm_solver_fkc or dpm_solver_smc with + potentials and steering_config already bound. msa_file: Optional path to an MSA A3M file. msa_host_url: MSA server URL for colabfold. """ @@ -378,14 +315,18 @@ def generate_batch( context_batch = Batch.from_data_list([context_chemgraph] * batch_size) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - sampled_chemgraph_batch = denoiser( + result = denoiser( sdes=sdes, device=device, batch=context_batch, score_model=score_model, - fk_potentials=fk_potentials, - steering_config=steering_config, ) + + # Steered denoisers (FKC/SMC) return (batch, log_weights); unsteered returns batch + if isinstance(result, tuple): + sampled_chemgraph_batch, _ = result + else: + sampled_chemgraph_batch = result assert isinstance(sampled_chemgraph_batch, Batch) sampled_chemgraphs = sampled_chemgraph_batch.to_data_list() pos = torch.stack([x.pos for x in sampled_chemgraphs]).to("cpu") # [BS, L, 3] diff --git a/src/bioemu/steering.py b/src/bioemu/steering.py deleted file mode 100644 index 2084b2f..0000000 --- a/src/bioemu/steering.py +++ /dev/null @@ -1,358 +0,0 @@ -""" -Steering potentials for BioEmu sampling. - -This module provides steering potentials to guide protein structure generation -towards physically realistic conformations by penalizing chain breaks and clashes. -""" -import logging - -import torch -from openfold.np.residue_constants import ca_ca -from torch_geometric.data import Batch - -from bioemu.sde_lib import SDE - -from .so3_sde import apply_rotvec_to_rotmat - -logger = logging.getLogger(__name__) - - -def _get_x0_given_xt_and_score( - sde: SDE, - x: torch.Tensor, - t: torch.Tensor, - batch_idx: torch.LongTensor, - score: torch.Tensor, -) -> torch.Tensor: - """ - Compute expected value of x_0 using x_t and score. 
- """ - alpha_t, sigma_t = sde.mean_coeff_and_std(x=x, t=t, batch_idx=batch_idx) - return (x + sigma_t**2 * score) / alpha_t - - -def _get_R0_given_xt_and_score( - sde: SDE, - R: torch.Tensor, - t: torch.Tensor, - batch_idx: torch.LongTensor, - score: torch.Tensor, -) -> torch.Tensor: - """ - Compute R_0 given R_t and score. - """ - alpha_t, sigma_t = sde.mean_coeff_and_std(x=R, t=t, batch_idx=batch_idx) - return apply_rotvec_to_rotmat(R, -(sigma_t**2) * score) - - -def stratified_resample(weights: torch.Tensor) -> torch.Tensor: - """ - Stratified resampling along the last dimension of a batched tensor. - - Args: - weights: (B, N), normalized along dim=-1 - - Returns: - (B, N) indices of chosen particles - """ - B, N = weights.shape - - # 1. Compute cumulative sums (CDF) for each batch - cdf = torch.cumsum(weights, dim=-1) # (B, N) - - # 2. Stratified positions: one per interval - # shape (B, N): each row gets N stratified uniforms - u = (torch.rand(B, N, device=weights.device) + torch.arange(N, device=weights.device)) / N - - # 3. Inverse-CDF search: for each u, find smallest j s.t. cdf[b, j] >= u[b, i] - idx = torch.searchsorted(cdf, u, right=True) - - return idx # shape (B, N) - - -def get_pos0_rot0(sdes, batch, t, score): - """Get predicted x0 and R0 from current state and score.""" - x0_t = _get_x0_given_xt_and_score( - sde=sdes["pos"], - x=batch.pos, - t=t, - batch_idx=batch.batch, - score=score["pos"], - ) - R0_t = _get_R0_given_xt_and_score( - sde=sdes["node_orientations"], - R=batch.node_orientations, - t=t, - batch_idx=batch.batch, - score=score["node_orientations"], - ) - seq_length = len(batch.sequence[0]) - x0_t = x0_t.reshape(batch.batch_size, seq_length, 3).detach() - R0_t = R0_t.reshape(batch.batch_size, seq_length, 3, 3).detach() - return x0_t, R0_t - - -def log_physicality(pos: torch.Tensor, rot: torch.Tensor, sequence: str): - """ - Log physicality metrics for the generated structures. 
- - Args: - pos: Position tensor in nanometers - rot: Rotation tensor (unused, kept for API compatibility) - sequence: Amino acid sequence string (unused, kept for API compatibility) - """ - pos = 10 * pos # convert to Angstrom - n_residues = pos.shape[1] - - # Ca-Ca distances - ca_ca_dist = (pos[..., :-1, :] - pos[..., 1:, :]).pow(2).sum(dim=-1).pow(0.5) - - # Clash distances - clash_distances = torch.cdist(pos, pos) # shape: (batch, L, L) - mask = torch.ones(n_residues, n_residues, dtype=torch.bool, device=pos.device) - mask = mask.triu(diagonal=4) - clash_distances = clash_distances[:, mask] - - # Compute physicality violations - ca_break = (ca_ca_dist > 4.5).float() - ca_clash = (clash_distances < 3.4).float() - - # Print physicality metrics - logger.info(f"physicality/ca_break_mean: {ca_break.sum().item()}") - logger.info(f"physicality/ca_clash_mean: {ca_clash.sum().item()}") - logger.info(f"physicality/ca_ca_dist_mean: {ca_ca_dist.mean().item()}") - logger.info(f"physicality/clash_distances_mean: {clash_distances.mean().item()}") - - -def potential_loss_fn( - x: torch.Tensor, - target: torch.Tensor, - flatbottom: float, - slope: float, - order: float, - linear_from: float, -) -> torch.Tensor: - """ - Flat-bottom loss for continuous variables. 
- - Args: - x: Input tensor - target: Target value - flatbottom: Flat region width around target (zero penalty within this range) - slope: Slope outside flatbottom region - order: Power law exponent for penalty function - linear_from: Distance threshold where penalty switches from power law to linear - - Returns: - Loss values tensor - """ - diff = torch.abs(x - target) - diff_tol = torch.relu(diff - flatbottom) - - # Power law region - power_loss = (slope * diff_tol) ** order - - # Linear region (simple linear continuation from linear_from) - linear_loss = (slope * linear_from) ** order + slope * (diff_tol - linear_from) - - # Piecewise function - loss = torch.where(diff_tol <= linear_from, power_loss, linear_loss) - return loss - - -class Potential: - """Base class for steering potentials.""" - - def __call__( - self, - Ca_pos: torch.Tensor, - i: int, - N: int, - ) -> torch.Tensor: - raise NotImplementedError("Subclasses should implement this method.") - - def __repr__(self): - attrs = [ - f"{k}={getattr(self, k)!r}" - for k in getattr(self, "__dataclass_fields__", {}) or self.__dict__ - ] - sig = f"({', '.join(attrs)})" if attrs else "" - return f"{self.__class__.__name__}{sig}" - - -class ChainBreakPotential(Potential): - """ - Enforces realistic Ca-Ca distances (3.8Å) using flat-bottom loss. - - Penalizes deviations from the expected Ca-Ca distance between neighboring residues. - """ - - def __init__( - self, - flatbottom: float = 0.0, - slope: float = 1.0, - order: float = 1, - linear_from: float = 1.0, - weight: float = 1.0, - guidance_steering: bool = False, - ): - """ - Args: - flatbottom: Zero penalty within this range around target distance (Å). - slope: Steepness of penalty outside flatbottom region. - order: Exponent for power law region. - linear_from: Distance from target where penalty transitions to linear. - weight: Overall weight of this potential in total potential calculation. - guidance_steering: Enable gradient guidance for this potential. 
- """ - self.ca_ca = ca_ca - self.flatbottom = flatbottom - self.slope = slope - self.order = order - self.linear_from = linear_from - self.weight = weight - self.guidance_steering = guidance_steering - - def __call__( - self, - Ca_pos: torch.Tensor, - i: int, - N: int, - ): - """ - Compute the potential energy based on neighboring Ca-Ca distances. - - Args: - N_pos, Ca_pos, C_pos, O_pos: Backbone atom positions - i: Denoising step index - N: Number of residues - - Returns: - Tensor of shape (batch_size,) with chain break energies - """ - ca_ca_dist = (Ca_pos[..., :-1, :] - Ca_pos[..., 1:, :]).pow(2).sum(dim=-1).pow(0.5) - target_distance = self.ca_ca - dist_diff = potential_loss_fn( - ca_ca_dist, target_distance, self.flatbottom, self.slope, self.order, self.linear_from - ) - return self.weight * dist_diff.sum(dim=-1) - - -class ChainClashPotential(Potential): - """ - Prevents steric clashes between non-neighboring Ca atoms. - - Penalizes Ca-Ca distances below a minimum threshold for residues - separated by more than `offset` positions in sequence. - """ - - def __init__( - self, - flatbottom: float = 0.0, - dist: float = 4.2, - slope: float = 1.0, - weight: float = 1.0, - offset: int = 3, - guidance_steering: bool = False, - ): - """ - Args: - flatbottom: Additional buffer distance added to dist (Å). - dist: Minimum acceptable distance between non-neighboring Ca atoms (Å). - slope: Steepness of penalty outside flatbottom region. - weight: Overall weight of this potential in total potential calculation. - offset: Minimum residue separation to consider (excludes nearby residues). - guidance_steering: Enable gradient guidance for this potential. - """ - self.flatbottom = flatbottom - self.dist = dist - self.slope = slope - self.weight = weight - self.offset = offset - self.guidance_steering = guidance_steering - - def __call__( - self, - Ca_pos: torch.Tensor, - i: int, - N: int, - ): - """ - Calculate clash potential for Ca atoms. 
- - Args: - N_pos, Ca_pos, C_pos, O_pos: Backbone atom positions - i: Denoising step index - N: Number of residues - - Returns: - Tensor of shape (batch_size,) with clash energies - """ - # Calculate all pairwise distances - pairwise_distances = torch.cdist(Ca_pos, Ca_pos) # (batch_size, n_residues, n_residues) - - # Use triu mask with offset to select relevant pairs - n_residues = Ca_pos.shape[1] - mask = torch.ones(n_residues, n_residues, dtype=torch.bool, device=Ca_pos.device) - mask = mask.triu(diagonal=self.offset) - relevant_distances = pairwise_distances[:, mask] # (batch_size, n_pairs) - - potential_energy = torch.relu( - self.slope * (self.dist - self.flatbottom - relevant_distances) - ) - return self.weight * potential_energy.sum(dim=-1) - - -def resample_batch(batch, num_particles, energy, previous_energy=None, log_weights=None): - """ - Resample the batch based on the energy. - - Args: - batch: PyG batch of samples - num_particles: Number of particles per sample - energy: Current energy values - previous_energy: Previous energy values (for computing resampling probability) - log_weights: Log importance weights from gradient guidance - - Returns: - Tuple of (resampled_batch, resampled_energy, resampled_log_weights) - """ - BS = energy.shape[0] - - if previous_energy is not None: - # Compute the resampling probability based on the energy difference - # If previous_energy > energy, high probability to resample since new energy is lower - resample_logprob = previous_energy - energy - else: - # If no previous energy is provided, use the energy directly - resample_logprob = -energy - - # Add importance weights from gradient guidance (if provided) - if log_weights is not None: - resample_logprob = resample_logprob + log_weights - - # Sample indices per sample in mini batch [BS, Replica] - chunks = torch.split(resample_logprob, split_size_or_sections=num_particles) - chunk_size = chunks[0].shape[0] - indices = [] - for chunk_idx, chunk in enumerate(chunks): - 
chunk_prob = torch.exp(torch.nn.functional.log_softmax(chunk, dim=-1)) - indices_ = torch.multinomial(chunk_prob, num_samples=chunk.numel(), replacement=True) - indices_ = indices_ + chunk_size * chunk_idx - indices.append(indices_) - indices = torch.cat(indices, dim=0) - - # Resample samples - data_list = batch.to_data_list() - resampled_data_list = [data_list[i] for i in indices] - batch = Batch.from_data_list(resampled_data_list) - - resampled_energy = energy.flatten()[indices] - - # Reset log_weights after resampling - if log_weights is not None: - resampled_log_weights = torch.log(torch.ones(BS, device=batch.pos.device)) - else: - resampled_log_weights = None - - return batch, resampled_energy, resampled_log_weights diff --git a/src/bioemu/steering/__init__.py b/src/bioemu/steering/__init__.py new file mode 100644 index 0000000..bc2cefa --- /dev/null +++ b/src/bioemu/steering/__init__.py @@ -0,0 +1,34 @@ +"""Steering potentials and collective variables for guided BioEmu sampling. + +This package provides: +- Potentials for steering protein structure generation +- Collective variable (CV) framework for defining reaction coordinates +- Utility functions for resampling and x0 prediction +- Steering denoisers (FKC, SMC) +""" + +# fmt: off +# ruff: noqa: F401 + +from .collective_variables import ( + RMSD, + CaCaDistance, + CollectiveVariable, + FractionNativeContacts, + PairwiseClash, + load_reference_traj, +) +from .potentials import LinearPotential, Potential, UmbrellaPotential +from .utils import ( + _get_R0_given_xt_and_score, + _get_x0_given_xt_and_score, + compute_ess_from_log_weights, + compute_reward_and_grad, + compute_sequence_alignment, + get_pos0_rot0, + kabsch_align, + resample_based_on_log_weights, + reward_grad_rotmat_to_rotvec, + stratified_resample, + validate_steering_config, +) diff --git a/src/bioemu/steering/collective_variables.py b/src/bioemu/steering/collective_variables.py new file mode 100644 index 0000000..58955ec --- /dev/null +++ 
b/src/bioemu/steering/collective_variables.py @@ -0,0 +1,272 @@ +"""Collective Variable classes for steering/guided sampling.""" + +from abc import ABC, abstractmethod + +import mdtraj +import numpy as np +import torch + +from ..chemgraph import ChemGraph +from openfold.np.residue_constants import restype_3to1 +from ..training.foldedness import ( + _compute_contact_score, + compute_contacts, +) +from .utils import compute_sequence_alignment, kabsch_align + + +def load_reference_traj(reference_pdb: str) -> mdtraj.Trajectory: + reference_traj = mdtraj.load(reference_pdb) + reference_traj = reference_traj.atom_slice(reference_traj.topology.select("name CA")) + return reference_traj + + +class CollectiveVariable(ABC): + """Base class for all collective variables. + + All CVs receive Cα positions in **nanometres (nm)** (matching the units + used throughout the steering stack: ``potential_(x0, ...)``). + """ + + def __init__(self, **params): + self.params = params + + @abstractmethod + def compute_batch(self, ca_pos_nm: torch.Tensor, sequence: str | None = None) -> torch.Tensor: + """Compute CV for a batch of structures. + + Args: + ca_pos_nm: Cα positions in nm, shape ``(batch, n_residues, 3)``. + sequence: Amino acid sequence string (optional for some CVs). + + Returns: + CV values. Shape is ``(batch,)`` for scalar CVs, + ``(batch, ...)`` for per-element CVs. + """ + + def compute(self, chemgraph_list: list[ChemGraph], **kwargs) -> torch.Tensor: + """Convenience wrapper that extracts positions from *ChemGraph* objects. + + Positions in *ChemGraph* are stored in nanometres, which is the same + unit expected by :meth:`compute_batch`. 
+ """ + all_positions = torch.stack([cg.pos for cg in chemgraph_list], dim=0) + sequence = chemgraph_list[0].sequence + return self.compute_batch(all_positions, sequence) + + +class FractionNativeContacts(CollectiveVariable): + """Fraction of Native Contacts CV.""" + + def __init__(self, reference_pdb: str, **kwargs): + # Accept but ignore extra kwargs (e.g., alignment_resid_ranges, metric_resid_ranges) + # to allow using this CV with configs designed for other CVs + assert reference_pdb is not None, "reference_pdb is required for FractionNativeContacts" + reference_traj = load_reference_traj(reference_pdb) + self.reference_pdb = reference_pdb + self.reference_info = compute_contacts(traj=reference_traj) + # Cache for aligned indices (computed on first call) + self._cached_sample_contact_indices: torch.Tensor | None = None + self._cached_ref_contact_distances: torch.Tensor | None = None + self._cached_sequence: str | None = None + + def _setup_alignment_cache(self, sequence: str, device: torch.device) -> None: + """Compute and cache aligned contact indices for the given sequence.""" + # Use shared sequence alignment utility + ref_to_sample = compute_sequence_alignment(self.reference_info.sequence, sequence) + + # Get aligned reference indices (those that have a match in sample) + aligned_indices_ref = sorted(ref_to_sample.keys()) + + # Get the reference contact distances that align with the batch sequence + mask_i = np.isin(self.reference_info.contact_indices[0, :], aligned_indices_ref) + mask_j = np.isin(self.reference_info.contact_indices[1, :], aligned_indices_ref) + mask = np.logical_and(mask_i, mask_j) + aligned_ref_contact_indices = self.reference_info.contact_indices[:, mask] + self._cached_ref_contact_distances = ( + torch.from_numpy(self.reference_info.contact_distances_angstrom[mask]) + .float() + .to(device) + ) + + # Map reference contact indices to sample contact indices + self._cached_sample_contact_indices = torch.tensor( + [ + 
[ref_to_sample[int(idx)] for idx in aligned_ref_contact_indices[0, :]], + [ref_to_sample[int(idx)] for idx in aligned_ref_contact_indices[1, :]], + ], + dtype=torch.long, + device=device, + ) + self._cached_sequence = sequence + + def compute_batch(self, ca_pos_nm: torch.Tensor, sequence: str | None = None) -> torch.Tensor: + """Compute FNC for a batch of structures. + + Args: + ca_pos_nm: Cα positions in nm, shape ``(batch, n_residues, 3)``. + sequence: Amino acid sequence string. + + Returns: + FNC values, shape ``(batch,)``. + """ + assert sequence is not None, "FractionNativeContacts requires a sequence" + device = ca_pos_nm.device + + # Setup cache on first call or if sequence changed + if self._cached_sequence != sequence or self._cached_sample_contact_indices is None: + self._setup_alignment_cache(sequence, device) + + # Move cached tensors to correct device if needed + assert self._cached_sample_contact_indices is not None + assert self._cached_ref_contact_distances is not None + if self._cached_sample_contact_indices.device != device: + self._cached_sample_contact_indices = self._cached_sample_contact_indices.to(device) + self._cached_ref_contact_distances = self._cached_ref_contact_distances.to(device) + + # Compute contact distances for all samples in batch + # ca_pos_nm: (batch_size, n_residues, 3) + # indices: (2, n_contacts) + contact_distances_nm = torch.linalg.norm( + ca_pos_nm[:, self._cached_sample_contact_indices[0, :], :] + - ca_pos_nm[:, self._cached_sample_contact_indices[1, :], :], + dim=2, + ) # (batch_size, n_contacts) + contact_distances_angstrom = contact_distances_nm * 10.0 + + # Compute contact scores using the shared foldedness helper + contact_scores = _compute_contact_score( + sample_contact_distances=contact_distances_angstrom, + reference_contact_distances=self._cached_ref_contact_distances, + ) # (batch_size, n_contacts) + + # Average contact scores over contacts for each sample + fnc = torch.mean(contact_scores, dim=1) + return 
class RMSD(CollectiveVariable):
    """RMSD CV with differentiable Kabsch alignment.

    Handles sequence mismatches between sample and reference via pairwise
    sequence alignment — only aligned (matched) residues are used for both
    the Kabsch superposition and the RMSD metric.
    """

    def __init__(self, reference_pdb: str, **kwargs):
        """Load the Cα reference structure and derive its one-letter sequence.

        Args:
            reference_pdb: Path to the reference structure file.
            **kwargs: Extra config entries; stored on ``self.params`` by the
                base class and otherwise ignored, for config compatibility.
        """
        # Initialise the base class so ``self.params`` exists, as it does for
        # the other collective variables in this module.
        super().__init__(**kwargs)
        assert reference_pdb is not None, "reference_pdb is required for RMSD"
        reference_traj = load_reference_traj(reference_pdb)
        self.reference_pdb = reference_pdb

        # Reference positions of the first frame (mdtraj stores nm).
        self.ref_pos = torch.tensor(
            reference_traj.xyz[0, :, :], dtype=torch.float32
        )  # shape: (n_ref_residues, 3)

        # One-letter reference sequence; unknown residue names become "X".
        self.ref_sequence = "".join(
            restype_3to1[residue.name] if residue.name in restype_3to1 else "X"
            for residue in reference_traj.topology.residues
        )

        # Cache for sequence alignment results
        self._cached_sequence: str | None = None
        self._cached_ref_indices: torch.Tensor | None = None
        self._cached_sample_indices: torch.Tensor | None = None

    def _setup_sequence_alignment(self, sample_sequence: str, device: torch.device) -> None:
        """Align sample sequence to reference and cache matched index pairs.

        Raises:
            ValueError: If no residue of the sample aligns to the reference,
                in which case the RMSD would be undefined (NaN).
        """
        if self.ref_sequence == sample_sequence:
            # Exact match — use all residues
            ref_indices = list(range(len(sample_sequence)))
            sample_indices = list(ref_indices)
        else:
            ref_to_sample = compute_sequence_alignment(self.ref_sequence, sample_sequence)
            ref_indices = sorted(ref_to_sample.keys())
            sample_indices = [ref_to_sample[i] for i in ref_indices]

        if not ref_indices:
            raise ValueError(
                "RMSD: no residues of the sample sequence align to the reference sequence"
            )

        self._cached_sequence = sample_sequence
        self._cached_ref_indices = torch.tensor(ref_indices, dtype=torch.long, device=device)
        self._cached_sample_indices = torch.tensor(sample_indices, dtype=torch.long, device=device)

    def compute_batch(self, ca_pos_nm: torch.Tensor, sequence: str | None = None) -> torch.Tensor:
        """Compute RMSD for a batch of structures using differentiable Kabsch alignment.

        Args:
            ca_pos_nm: Cα positions in nm, shape ``(batch, n_residues, 3)``.
            sequence: Amino acid sequence string.

        Returns:
            RMSD values in nm, shape ``(batch,)``.
        """
        assert sequence is not None, "RMSD requires a sequence"
        device = ca_pos_nm.device

        # Setup / refresh alignment cache
        if self._cached_sequence != sequence or self._cached_ref_indices is None:
            self._setup_sequence_alignment(sequence, device)
        assert self._cached_ref_indices is not None and self._cached_sample_indices is not None
        if self._cached_ref_indices.device != device:
            self._cached_ref_indices = self._cached_ref_indices.to(device)
            self._cached_sample_indices = self._cached_sample_indices.to(device)

        # Match the samples' device and dtype so mixed-precision inputs do not
        # trip over a float32 reference.
        ref_pos = self.ref_pos.to(device=device, dtype=ca_pos_nm.dtype)

        # Select aligned residues
        ref_aligned = ref_pos[self._cached_ref_indices]  # (n_aligned, 3)
        samples_aligned = ca_pos_nm[:, self._cached_sample_indices, :]  # (batch, n_aligned, 3)

        # Center both
        ref_centered = ref_aligned - ref_aligned.mean(dim=0, keepdim=True)
        samples_centered = samples_aligned - samples_aligned.mean(dim=1, keepdim=True)

        # Kabsch alignment: rotate samples onto reference
        samples_rotated = kabsch_align(samples_centered, ref_centered)

        # RMSD over aligned atoms
        diff = samples_rotated - ref_centered.unsqueeze(0)
        squared_distances = (diff**2).sum(dim=2)  # (batch_size, n_aligned)
        return torch.sqrt(squared_distances.mean(dim=1))  # (batch_size,)
class PairwiseClash(CollectiveVariable):
    """Per-pair clash depth ``relu(min_dist - dist)``.

    Only residue pairs ``(i, j)`` with ``j - i >= offset`` are considered.
    The result is in nm (0 where there is no clash) and has shape
    ``(batch, n_pairs)``.
    """

    def __init__(self, min_dist: float = 0.42, offset: int = 3, **kwargs):
        super().__init__(**kwargs)
        self.offset = offset
        self.min_dist = min_dist

    def compute_batch(self, ca_pos_nm: torch.Tensor, sequence: str | None = None) -> torch.Tensor:
        """Return clash depths for every pair at least ``offset`` residues apart."""
        n = ca_pos_nm.shape[1]
        # Row-major indices of the upper triangle starting at the ``offset``
        # diagonal, i.e. the same pairs and order a boolean triu mask selects.
        rows, cols = torch.triu_indices(n, n, offset=self.offset, device=ca_pos_nm.device)
        all_distances = torch.cdist(ca_pos_nm, ca_pos_nm)
        selected = all_distances[:, rows, cols]
        return torch.relu(self.min_dist - selected)
def _get_fkc_guided_score(
    *,
    sdes: dict[str, SDE],
    batch: Batch,
    t: torch.Tensor,
    score_model: torch.nn.Module,
    potentials: list[Callable],
    use_x0_for_reward: bool,
    enable_grad: bool,
    noise_scale: float,
) -> GuidedScore:
    """Build the FKC guided score with the noise scaling already applied.

    Position score (Eq.29 of the FKC paper), with ``a = noise_scale``:
      - steered (potentials given and ``enable_grad``):
        ``(1 + a^2) / 2 * score + a^2 / 2 * reward_grad``
      - unsteered: ``(1 + a^2) / 2 * score``

    The SO3 score is passed through unscaled.

    Args:
        noise_scale: The 'a' parameter controlling stochasticity.

    Returns:
        GuidedScore whose position score is pre-scaled for
        second_order_step_dpmsolver_plusplus; it also carries the reward, its
        time gradient, the x0 estimate and the raw score dict.
    """
    reward, reward_grad_x, _, reward_grad_t, x0, raw_score = compute_reward_and_grad(
        sdes=sdes,
        batch=batch,
        t=t,
        score_model=score_model,
        potentials=potentials,
        use_x0_for_reward=use_x0_for_reward,
        eval_score=True,
        enable_grad=enable_grad,
    )

    a = noise_scale
    steered = len(potentials) > 0 and enable_grad

    # Noise-scaled model score; the reward-gradient term is added only when steering.
    guided_pos = (1 + a**2) / 2 * raw_score["pos"]
    if steered:
        guided_pos = guided_pos + (a**2) / 2 * reward_grad_x

    return GuidedScore(
        pos=guided_pos,
        so3=raw_score["node_orientations"],  # raw SO3 score, no scaling
        reward=reward,
        reward_grad_t=reward_grad_t,
        x0=x0,
        raw_score=raw_score,
    )
+ """ + batch_size = batch.num_graphs + seq_length = batch.pos.shape[0] // batch_size + x_dim = batch.pos.shape[-1] + + # Get forward SDE coefficients + fw_drift, diffusion = pos_sde.sde(x=batch.pos, t=t, batch_idx=batch_idx) + + # Eq.30 of FKC paper + reward_inner_product = reward_grad_x * (-fw_drift + 0.5 * diffusion**2 * raw_score_pos) + dlog_weights_dt = reward_grad_t - torch.sum( + reward_inner_product.view(batch_size, seq_length * x_dim), dim=-1 + ) + dlog_weights = dlog_weights_dt * (t_next - t) + + return dlog_weights + + +# ============================================================================= +# Main FKC Step Function +# ============================================================================= + + +def dpm_solver_sde_fkc_step( + batch: Batch, + t: torch.Tensor, + t_next: torch.Tensor, + sdes: dict[str, SDE], + score_model: torch.nn.Module, + max_t: float, + potentials: list[Callable], + is_last_step: bool, + use_x0_for_reward: bool = False, + enable_grad: bool = True, + noise_scale: float = 0.0, +) -> tuple[Batch, dict[str, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]: + """DPM-Solver++ step with FKC-style steering and analytical weight computation. + + Args: + batch: Current batch at time t. + t: Current diffusion time. + t_next: Next diffusion time (t_next < t for reverse). + sdes: Dictionary with "pos" and "node_orientations" SDEs. + score_model: Score network. + max_t: Maximum diffusion time (unused, kept for API compatibility). + potentials: List of FK potentials for steering. + is_last_step: Whether this is the final denoising step (unused). + use_x0_for_reward: Evaluate potentials on x0 estimate vs x_t. + enable_grad: Enable gradient computation for steering. + noise_scale: Scale for stochastic noise (parameter 'a'). 
+ + Returns: + (batch_next, score, dlog_weights, x0, reward) + """ + pos_sde = sdes["pos"] + so3_sde = sdes["node_orientations"] + assert isinstance(pos_sde, CosineVPSDE) + assert isinstance(so3_sde, SO3SDE) + batch_idx = batch.batch + + # 1. COMPUTE COEFFICIENTS + coeffs = _get_dpm_coefficients(pos_sde, batch.pos, t, t_next, batch_idx) + + # 2. COMPUTE GUIDED SCORE AT t (with FKC formula, pre-scaled) + guided_score_t = _get_fkc_guided_score( + sdes=sdes, + batch=batch, + t=t, + score_model=score_model, + potentials=potentials, + use_x0_for_reward=use_x0_for_reward, + enable_grad=enable_grad, + noise_scale=noise_scale, + ) + + # 3. COMPUTE FKC WEIGHTS (analytical formula, Eq.30) + if len(potentials) > 0 and enable_grad: + # Need reward_grad_x for weight computation - extract from raw score computation + _, reward_grad_x, _, _, _, _ = compute_reward_and_grad( + sdes=sdes, + batch=batch, + t=t, + score_model=score_model, + potentials=potentials, + use_x0_for_reward=use_x0_for_reward, + eval_score=False, # Don't need score again + enable_grad=True, + ) + dlog_weights = _compute_fkc_weights( + batch=batch, + pos_sde=pos_sde, + raw_score_pos=guided_score_t.raw_score["pos"], + reward_grad_x=reward_grad_x, + reward_grad_t=guided_score_t.reward_grad_t, + t=t, + t_next=t_next, + batch_idx=batch_idx, + ) + else: + dlog_weights = torch.zeros(batch.num_graphs, device=batch.pos.device) + + # 4. PREDICTION STEP -> t_lambda (midpoint) + batch_lambda = _predict_midpoint( + batch=batch, + coeffs=coeffs, + score_pos=guided_score_t.pos, + score_so3=guided_score_t.so3, + so3_sde=so3_sde, + t=t, + batch_idx=batch_idx, + ) + + # 5. COMPUTE GUIDED SCORE AT t_lambda (with FKC formula, pre-scaled) + guided_score_lambda = _get_fkc_guided_score( + sdes=sdes, + batch=batch_lambda, + t=coeffs.t_lambda, + score_model=score_model, + potentials=potentials, + use_x0_for_reward=use_x0_for_reward, + enable_grad=enable_grad, + noise_scale=noise_scale, + ) + + # 6. 
def dpm_solver_fkc(
    sdes: dict[str, SDE],
    batch: Batch,
    N: int,
    score_model: torch.nn.Module,
    max_t: float,
    eps_t: float,
    device: torch.device,
    record_grad_steps: set[int] | None = None,
    noise: float = 0.0,
    fk_potentials: list[Callable] | None = None,
    steering_config: dict | None = None,
    output_dir: str | None = None,
    batch_seed: int = 0,
    use_x0_for_reward: bool = False,
) -> tuple[ChemGraph, torch.Tensor]:
    """FKC denoiser loop using DPM-Solver 2nd order integrator.

    Runs the reverse diffusion process with optional FKC steering.
    At each step, calls dpm_solver_sde_fkc_step to compute guided scores
    and analytical importance weights, then resamples if steering is enabled
    and the current time lies inside the configured steering window.

    Args:
        sdes: Dictionary of SDEs with keys "pos" and "node_orientations".
        batch: Initial batch (prior sampling done internally).
        N: Number of points on the uniform time grid (``N - 1`` solver steps).
        score_model: Score network.
        max_t: Maximum diffusion time (must be < 1.0).
        eps_t: Minimum diffusion time.
        device: Torch device.
        record_grad_steps: Not used by this function; kept in the signature.
            Defaults to ``None`` rather than a shared mutable ``set()``.
        noise: Stochastic noise scale (parameter 'a'), forwarded to each step.
        fk_potentials: List of FK potentials for steering.
        steering_config: Configuration dict. When not ``None`` it must pass
            ``validate_steering_config``, i.e. contain ``num_particles``,
            ``ess_threshold``, ``start`` (max diffusion time for resampling)
            and ``end`` (min diffusion time for resampling).
        output_dir: Not used by this function.
        batch_seed: Not used by this function.
        use_x0_for_reward: Evaluate potentials on x0 estimate vs x_t.

    Returns:
        (batch, batch_log_weights): the final batch and the per-sample
        accumulated log weights, re-indexed along with every resampling.
    """
    from tqdm.auto import tqdm

    logger.info("Running DPM-Solver FKC for %s steps", N)
    assert isinstance(batch, Batch)
    assert max_t < 1.0
    validate_steering_config(steering_config)

    batch = batch.to(device)

    if isinstance(score_model, torch.nn.Module):
        score_model = score_model.to(device)

    assert isinstance(sdes["node_orientations"], SO3SDE)
    sdes["node_orientations"] = sdes["node_orientations"].to(device)

    # Initialize batch from prior
    batch = batch.replace(
        pos=sdes["pos"].prior_sampling(batch.pos.shape, device=device),
        node_orientations=sdes["node_orientations"].prior_sampling(
            batch.node_orientations.shape, device=device
        ),
    )
    batch = cast(ChemGraph, batch)

    # Uniform timestep grid
    timesteps = torch.linspace(max_t, eps_t, N, device=device)
    num_steps = timesteps.shape[0]

    potentials = fk_potentials or []
    enable_steering = (
        (steering_config is not None)
        and (fk_potentials is not None)
        and (steering_config["num_particles"] > 1)
    )
    if enable_steering:
        assert steering_config is not None
        # The steering window does not change between steps; read it once.
        steer_start = steering_config["start"]
        steer_end = steering_config["end"]

    # ``log_weights`` is what the resampler consumes and returns;
    # ``batch_log_weights`` is the accumulated total that follows each
    # surviving sample and is returned to the caller.
    log_weights = torch.zeros(batch.num_graphs, device=device)
    batch_log_weights = torch.zeros(batch.num_graphs, device=device)

    for i in tqdm(range(num_steps - 1), position=1, desc="Denoising: ", ncols=0, leave=False):
        t = torch.full((batch.num_graphs,), timesteps[i], device=device)
        t_next = torch.full((batch.num_graphs,), timesteps[i + 1], device=device)
        is_last_step = i == num_steps - 2

        batch, _, dlog_weights, _, _ = dpm_solver_sde_fkc_step(
            batch=batch,
            t=t,
            t_next=t_next,
            sdes=sdes,
            score_model=score_model,
            max_t=max_t,
            potentials=potentials,
            is_last_step=is_last_step,
            enable_grad=len(potentials) > 0,
            noise_scale=noise,
            use_x0_for_reward=use_x0_for_reward,
        )

        log_weights = log_weights + dlog_weights
        batch_log_weights = batch_log_weights + dlog_weights

        if not enable_steering:
            continue
        assert steering_config is not None

        # Resample only while the current time is inside [end, start].
        if steer_start >= timesteps[i].item() >= steer_end:
            batch, log_weights, indices, ess = resample_based_on_log_weights(
                batch=batch,
                log_weight=log_weights,
                n_particles=min(batch.num_graphs, steering_config["num_particles"]),
                is_last_step=is_last_step,
                ess_threshold=steering_config["ess_threshold"],
                step=i,
                t=t[0],
            )
            if indices is not None:
                # Keep the accumulated weights attached to the surviving samples.
                batch_log_weights = batch_log_weights[indices]

    return batch, batch_log_weights
+ t_next: Next diffusion time tensor [batch_size]. + sdes: Dictionary of SDEs with keys "pos" and "node_orientations". + score_model: Score network. + max_t: Maximum diffusion time. + potentials: List of potential functions for steering. + step_idx: Current step index. + use_x0_for_reward: Whether to evaluate potentials on the x0 estimate + (``t=0`` denoised prediction) rather than on x_t. SMC is strictly + defined at ``t=0``; set ``use_x0_for_reward=False`` only for debug + or toy use cases. + previous_reward: Reward from the previous step [batch_size], used for TDS weight. + log_weights: Current log importance weights [batch_size]. + steering_config: Steering configuration dictionary. + noise_scale: Scale for stochastic noise (parameter 'a'). + + Returns: + batch_next: Updated ChemGraph batch at time t_next. + score: Dict of scores from the current step. + log_weights: Updated log importance weights [batch_size]. + x0: Estimated clean positions. + reward: Current reward for use as previous_reward in next step. 
+ """ + pos_sde = sdes["pos"] + so3_sde = sdes["node_orientations"] + assert isinstance(pos_sde, CosineVPSDE) + assert isinstance(so3_sde, SO3SDE) + batch_idx = batch.batch + + reward, _, _, _, x0, score = compute_reward_and_grad( + sdes=sdes, + batch=batch, + t=t, + score_model=score_model, + potentials=potentials, + use_x0_for_reward=use_x0_for_reward, + eval_score=True, + enable_grad=False, + ) + + # Resampling based on TDS weights (reward difference) + if len(potentials) > 0: + assert steering_config is not None + assert log_weights is not None + assert previous_reward is not None + log_weights = log_weights + reward - previous_reward + + indices = None + original_batch_size = batch.num_graphs + seq_length = batch.pos.shape[0] // batch.num_graphs + x_dim = batch.pos.shape[-1] + + batch, log_weights, indices, ess = resample_based_on_log_weights( + batch=batch, + log_weight=log_weights, + n_particles=min(batch.num_graphs, steering_config["num_particles"]), + is_last_step=False, + ess_threshold=steering_config["ess_threshold"], + step=step_idx, + t=t[0], + ) + + if indices is not None: + seq_length = batch.pos.shape[0] // batch.num_graphs + score = { + "pos": score["pos"] + .view(original_batch_size, seq_length, x_dim)[indices] + .reshape(-1, x_dim), + "node_orientations": score["node_orientations"] + .view(original_batch_size, seq_length, 3)[indices] + .reshape(-1, 3), + } + reward = reward[indices] + t = t[indices] + t_next = t_next[indices] + batch_idx = batch.batch + + # Compute DPM-Solver coefficients + coeffs = _get_dpm_coefficients(pos_sde, batch.pos, t, t_next, batch_idx) + + # Scale scores with noise factor: (1 + a²) / 2 + a = noise_scale + noise_factor = (1 + a**2) / 2 + scaled_score_t = noise_factor * score["pos"] + + # Midpoint prediction (position + SO3) + batch_lambda = _predict_midpoint( + batch=batch, + coeffs=coeffs, + score_pos=scaled_score_t, + score_so3=score["node_orientations"], + so3_sde=so3_sde, + t=t, + batch_idx=batch_idx, + ) + + # 
def dpm_solver_smc(
    sdes: dict[str, SDE],
    batch: Batch,
    N: int,
    score_model: torch.nn.Module,
    max_t: float,
    eps_t: float,
    device: torch.device,
    record_grad_steps: set[int] | None = None,
    noise: float = 0.0,
    fk_potentials: list[Callable] | None = None,
    steering_config: dict | None = None,
    output_dir: str | None = None,
) -> tuple[ChemGraph, torch.Tensor]:
    """SMC denoiser loop using DPM-Solver 2nd order integrator.

    Each step calls ``dpm_solver_sde_smc_step``; potentials, weights and the
    steering config are only passed to it while the current time lies inside
    the steering window. When steering is enabled, a final reward evaluation
    and resampling is performed at the last grid time after the loop.

    Args:
        sdes: Dictionary of SDEs with keys "pos" and "node_orientations".
        batch: Initial batch (prior sampling done internally).
        N: Number of points on the uniform time grid (``N - 1`` solver steps).
        score_model: Score network.
        max_t: Maximum diffusion time (must be < 1.0).
        eps_t: Minimum diffusion time.
        device: Torch device.
        record_grad_steps: Not used by this function; kept in the signature.
            Defaults to ``None`` rather than a shared mutable ``set()``.
        noise: Stochastic noise scale (parameter 'a'), forwarded to each step.
        fk_potentials: List of potentials for steering.
        steering_config: Configuration dictionary for steering. When not
            ``None`` it must pass ``validate_steering_config``, i.e. contain:
            - num_particles: Number of particles per group
            - ess_threshold: ESS threshold for resampling
            - start: Max diffusion time for resampling
            - end: Min diffusion time for resampling
        output_dir: Not used by this function.

    Returns:
        (batch, log_weights): the final batch and the log importance weights.
    """
    logger.info("Using DPMSolver SDE SMC %s steps", N)
    assert isinstance(batch, Batch)
    assert max_t < 1.0
    validate_steering_config(steering_config)

    batch = batch.to(device)

    if isinstance(score_model, torch.nn.Module):
        score_model = score_model.to(device)

    so3_sde = sdes["node_orientations"]
    assert isinstance(so3_sde, SO3SDE)
    sdes["node_orientations"] = so3_sde.to(device)

    # Initialize batch from prior
    batch = batch.replace(
        pos=sdes["pos"].prior_sampling(batch.pos.shape, device=device),
        node_orientations=sdes["node_orientations"].prior_sampling(
            batch.node_orientations.shape, device=device
        ),
    )
    batch = cast(ChemGraph, batch)

    # Uniform timestep grid
    timesteps = torch.linspace(max_t, eps_t, N, device=device)
    num_steps = timesteps.shape[0]

    enable_steering = (
        (steering_config is not None)
        and (fk_potentials is not None)
        and (steering_config["num_particles"] > 1)
    )
    # The steering window does not change between steps; read it once.
    steer_start = steering_config["start"] if steering_config else 1.0
    steer_end = steering_config["end"] if steering_config else 0.0

    log_weights = torch.zeros(batch.num_graphs, device=device)
    previous_reward = torch.zeros(batch.num_graphs, device=device)

    for i in tqdm(range(num_steps - 1), position=1, desc="Denoising: ", ncols=0, leave=False):
        t = torch.full((batch.num_graphs,), timesteps[i], device=device)
        t_next = torch.full((batch.num_graphs,), timesteps[i + 1], device=device)

        # Steer this step only if the current time is inside [end, start].
        in_window = steer_start >= timesteps[i].item() >= steer_end
        step_steering = enable_steering and in_window

        batch, _, step_log_weights, _, reward = dpm_solver_sde_smc_step(
            batch=batch,
            t=t,
            t_next=t_next,
            sdes=sdes,
            score_model=score_model,
            max_t=max_t,
            potentials=(fk_potentials or []) if step_steering else [],
            step_idx=i,
            use_x0_for_reward=True,
            previous_reward=previous_reward if step_steering else None,
            log_weights=log_weights if step_steering else None,
            steering_config=steering_config if step_steering else None,
            noise_scale=noise,
        )

        if step_steering:
            log_weights = step_log_weights
            previous_reward = reward.detach()

    if enable_steering:
        assert steering_config is not None
        assert fk_potentials is not None

        # Derive the final step index and time from the grid itself, not from
        # the loop variables: when N == 1 the loop body never runs and those
        # variables would be unbound.
        final_step = num_steps - 1
        t_final = torch.full((batch.num_graphs,), timesteps[final_step], device=device)

        # Evaluate reward on final clean x0 for final resampling
        reward, _, _, _, _, _ = compute_reward_and_grad(
            sdes=sdes,
            batch=batch,
            t=t_final,
            score_model=score_model,
            potentials=fk_potentials,
            use_x0_for_reward=True,
            eval_score=False,
            enable_grad=False,
        )

        # Update weights
        log_weights = log_weights + reward - previous_reward

        # Resample
        batch, log_weights, indices, ess = resample_based_on_log_weights(
            batch=batch,
            log_weight=log_weights,
            n_particles=min(batch.num_graphs, steering_config["num_particles"]),
            is_last_step=True,
            ess_threshold=steering_config["ess_threshold"],
            step=final_step,
            t=t_final[0],
        )

    return batch, log_weights
class UmbrellaPotential(Potential):
    """Flat-bottom umbrella potential applied to a collective variable.

    Energy = weight × Σ loss(cv_values)

    The per-element loss is zero within ±flatbottom of *target*, follows a
    power-law ramp (slope·Δ)^order up to *linear_from*, and continues
    linearly (with gradient *slope*) beyond *linear_from*, where Δ is the
    distance from the edge of the flat region.
    """

    def __init__(
        self,
        target: float = 1.0,
        flatbottom: float = 0.0,
        slope: float = 1.0,
        order: float = 1,
        linear_from: float = 1.0,
        weight: float = 1.0,
        guidance_steering: bool = False,
        cv: CollectiveVariable | None = None,
        **_: Any,
    ) -> None:
        # Shape of the loss
        self.target = target
        self.flatbottom = flatbottom
        self.slope = slope
        self.order = order
        self.linear_from = linear_from
        # Overall scaling, steering flag and the CV this potential acts on
        self.weight = weight
        self.guidance_steering = guidance_steering
        self.cv = cv

    def loss_fn(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-element umbrella loss (same shape as *x*)."""
        # Distance beyond the flat-bottom region (0 inside it).
        excess = torch.relu(torch.abs(x - self.target) - self.flatbottom)

        ramp = (self.slope * excess) ** self.order
        # Linear continuation that meets the ramp at ``excess == linear_from``.
        tail_offset = (self.slope * self.linear_from) ** self.order
        tail = tail_offset + self.slope * (excess - self.linear_from)

        within_ramp = excess <= self.linear_from
        return torch.where(within_ramp, ramp, tail)

    def energy_from_cv(self, cv_values: torch.Tensor, t: float | None = None) -> torch.Tensor:
        """Reduce the per-element loss of precomputed CV values to one energy per sample."""
        per_element = self.loss_fn(cv_values)
        if per_element.ndim > 1:
            # Collapse every non-batch dimension (per-element CVs).
            non_batch_dims = tuple(range(1, per_element.ndim))
            per_element = per_element.sum(dim=non_batch_dims)
        return self.weight * per_element

    def __call__(self, ca_pos_nm: torch.Tensor, *, t=None, sequence=None):
        """Evaluate the configured CV on Cα positions (nm) and return the energy."""
        assert self.cv is not None, "UmbrellaPotential requires a cv to be set."
        return self.energy_from_cv(self.cv.compute_batch(ca_pos_nm, sequence), t=t)
Returns per-element values.""" + diff = x - self.target + if self.clip_min is not None or self.clip_max is not None: + diff = torch.clamp(diff, min=self.clip_min, max=self.clip_max) + return self.slope * diff + + def energy_from_cv(self, cv_values: torch.Tensor, t: float | None = None) -> torch.Tensor: + base = self.loss_fn(cv_values) + if base.ndim > 1: + base = base.sum(dim=tuple(range(1, base.ndim))) + return self.weight * base + + def __call__(self, ca_pos_nm: torch.Tensor, *, t=None, sequence=None): + assert self.cv is not None, "LinearPotential requires a cv to be set." + cv_values = self.cv.compute_batch(ca_pos_nm, sequence) + return self.energy_from_cv(cv_values, t=t) diff --git a/src/bioemu/steering/utils.py b/src/bioemu/steering/utils.py new file mode 100644 index 0000000..bfb9454 --- /dev/null +++ b/src/bioemu/steering/utils.py @@ -0,0 +1,453 @@ +"""Utility functions for steering/guided sampling.""" + +import logging +from collections.abc import Callable + +import torch +from torch_geometric.data import Batch +from torch_geometric.data.batch import Batch as BatchType + +from ..chemgraph import ChemGraph +from ..sde_lib import SDE +from ..so3_sde import SO3SDE, apply_rotvec_to_rotmat, skew_matrix_to_vector + +logger = logging.getLogger(__name__) + + +def validate_steering_config(steering_config: dict | None) -> None: + """Validate steering config parameters. + + Args: + steering_config: Steering configuration dict. Must contain (when not None): + - num_particles: Number of particles (>1 for steering) + - ess_threshold: ESS threshold for resampling + - start: Start time for steering (0.0-1.0) + - end: End time for steering (0.0-1.0) + + Raises: + ValueError: If required keys are missing or start/end times are invalid. + """ + if steering_config is None: + return + for key in ("start", "end", "num_particles", "ess_threshold"): + if key not in steering_config: + raise ValueError( + f"steering_config is missing required key '{key}'. 
" + "All of 'start', 'end', 'num_particles', 'ess_threshold' must be specified." + ) + start = steering_config["start"] + end = steering_config["end"] + if not (0.0 <= end <= start <= 1.0): + raise ValueError( + f"Steering time window invalid: need 0.0 <= end ({end}) <= start ({start}) <= 1.0" + ) + + +def _get_x0_given_xt_and_score( + sde: SDE, + x: torch.Tensor, + t: torch.Tensor, + batch_idx: torch.LongTensor, + score: torch.Tensor, +) -> torch.Tensor: + """ + Compute x_0 given x_t and score. + """ + + alpha_t, sigma_t = sde.mean_coeff_and_std(x=x, t=t, batch_idx=batch_idx) + + return (x + sigma_t**2 * score) / alpha_t + + +def _get_R0_given_xt_and_score( + sde: SO3SDE, + R: torch.Tensor, + t: torch.Tensor, + batch_idx: torch.LongTensor, + score: torch.Tensor, +) -> torch.Tensor: + """ + Compute x_0 given x_t and score. + """ + + alpha_t, sigma_t = sde.mean_coeff_and_std(x=R, t=t, batch_idx=batch_idx) + + return apply_rotvec_to_rotmat(R, -(sigma_t**2) * score, tol=sde.tol) + + +def stratified_resample(weights: torch.Tensor) -> torch.Tensor: + """ + Stratified resampling along the last dimension of a batched tensor. + + Args: + weights: (B, N), normalized along dim=-1 + + Returns: + (B, N) indices of chosen particles + """ + B, N = weights.shape + + # 1. Compute cumulative sums (CDF) for each batch + cdf = torch.cumsum(weights, dim=-1) # (B, N) + # Normalize to ensure cdf[..., -1] == 1.0 exactly (guards against FP error) + cdf = cdf / cdf[..., -1:].clamp(min=1e-12) + + # 2. Stratified positions: one per interval + # shape (B, N): each row gets N stratified uniforms + u = (torch.rand(B, N, device=weights.device) + torch.arange(N, device=weights.device)) / N + + # 3. Inverse-CDF search: for each u, find smallest j s.t. 
def get_pos0_rot0(sdes, batch, t, score):
    """Estimate clean positions and orientations from a noisy batch.

    Runs the position and orientation x0/R0 estimators on ``batch.pos`` and
    ``batch.node_orientations``, using the ``"pos"`` and
    ``"node_orientations"`` entries of ``sdes`` and ``score``, and reshapes the
    flat per-node results into per-sample tensors.

    The residue count is read from the first sequence in the batch, so every
    sample is reshaped with that same length.

    Returns:
        Tuple ``(x0, R0)`` of detached tensors shaped
        ``(batch_size, seq_length, 3)`` and ``(batch_size, seq_length, 3, 3)``.
    """
    shared_kwargs = {"t": t, "batch_idx": batch.batch}

    pos0_flat = _get_x0_given_xt_and_score(
        sde=sdes["pos"], x=batch.pos, score=score["pos"], **shared_kwargs
    )
    rot0_flat = _get_R0_given_xt_and_score(
        sde=sdes["node_orientations"],
        R=batch.node_orientations,
        score=score["node_orientations"],
        **shared_kwargs,
    )

    n_residues = len(batch.sequence[0])
    pos0 = pos0_flat.reshape(batch.batch_size, n_residues, 3).detach()
    rot0 = rot0_flat.reshape(batch.batch_size, n_residues, 3, 3).detach()
    return pos0, rot0
+ """ + from Bio import Align + + aligner = Align.PairwiseAligner(mode="global", open_gap_score=-0.5) + alignments = aligner.align(ref_sequence, sample_sequence) + alignment = alignments[0] # Take best alignment + + # Build mapping from reference indices to sample indices + ref_ranges, sample_ranges = alignment.aligned + ref_to_sample: dict[int, int] = {} + for (ref_start, ref_end), (sample_start, sample_end) in zip(ref_ranges, sample_ranges): + for ref_idx, sample_idx in zip(range(ref_start, ref_end), range(sample_start, sample_end)): + ref_to_sample[ref_idx] = sample_idx + + return ref_to_sample + + +def resample_based_on_log_weights( + batch: ChemGraph, + log_weight: torch.Tensor, + n_particles: int, + is_last_step: bool, + ess_threshold: float, + step: int, + t: float, +) -> tuple[ChemGraph, torch.Tensor, torch.Tensor, torch.Tensor]: + """Resample particles based on importance weights. + + When batch_size < n_particles (due to memory constraints), the entire batch is + treated as a single resampling group. Each batch operates independently with + its own resampling. ESS is computed over the actual batch size in this case. + + Args: + batch: Current batch of samples. + log_weight: Log importance weights, shape (n_samples,). + n_particles: Target number of particles per group. If n_samples < n_particles, + all samples are treated as one group. + is_last_step: Whether this is the last denoising step. + ess_threshold: ESS threshold for triggering resampling. + step: Current step index (for logging). + t: Current diffusion time (for logging). + + Returns: + Tuple of (resampled_batch, reset_log_weights, indices, ess) where + both ``indices`` (LongTensor of selected particle indices) and + ``ess`` (normalized effective sample size, scalar tensor) are tensors. 
+ """ + # Compute ESS from log_weights for particles in a group + n_samples = log_weight.shape[0] + + # Handle case where batch_size < n_particles: treat entire batch as one group + if n_samples < n_particles: + logger.warning( + "n_samples (%s) < n_particles (%s); treating entire batch as one " + "resampling group with effective_n_particles=%s.", + n_samples, + n_particles, + n_samples, + ) + effective_n_particles = n_samples + n_groups = 1 + else: + assert ( + n_samples % n_particles == 0 + ), f"n_samples ({n_samples}) is not multiple of n_particles ({n_particles})" + effective_n_particles = n_particles + n_groups = n_samples // n_particles + unnormalized_weight = torch.exp( + torch.nn.functional.log_softmax(log_weight.view(n_groups, effective_n_particles), dim=-1) + ) + normalized_weight = unnormalized_weight / ( + unnormalized_weight.sum(dim=-1, keepdim=True) + 1e-12 + ) + ess = 1.0 / (normalized_weight**2).sum(dim=-1) + ess = (ess / effective_n_particles).mean() # average over groups + logger.info( + "Step %s, t %.4f, ESS=%.2f (n_samples=%s, effective_n_particles=%s)", + step, + t, + ess.item(), + n_samples, + effective_n_particles, + ) + + # Resample particles based on log weights + if (ess < ess_threshold) or is_last_step: + logger.info( + "Resampling step %s: ESS=%.2f < %s or is_last_step=%s", + step, + ess.item(), + ess_threshold, + is_last_step, + ) + indices = stratified_resample( + weights=normalized_weight + ) # [n_groups, effective_n_particles] + + BS_offset = torch.arange(n_groups).unsqueeze(-1) * effective_n_particles # [n_groups, 1] + indices = (indices + BS_offset.to(indices.device)).flatten() # [n_groups, n_particles] + + # Resample samples + data_list = batch.to_data_list() + resampled_data_list = [data_list[i] for i in indices] + batch = Batch.from_data_list( + resampled_data_list + ) # TODO: there should be a more efficient way + + log_weight = torch.zeros(n_samples, device=batch.pos.device) + else: + indices = torch.arange(n_samples, 
def kabsch_align(samples_centered: torch.Tensor, ref_centered: torch.Tensor) -> torch.Tensor:
    """Optimal rigid alignment of ``samples_centered`` onto ``ref_centered`` (Kabsch algorithm).

    Both inputs must already be mean-centred. Gradients do not flow through the
    SVD (``R`` is detached) for numerical stability; they still flow through
    ``samples_centered`` in the final rotation.

    Args:
        samples_centered: ``(batch, n_atoms, 3)`` — mean-centred sample coordinates.
        ref_centered: ``(n_atoms, 3)`` — mean-centred reference coordinates.

    Returns:
        Rotated samples with the same shape as ``samples_centered``.
    """
    # Cross-covariance H = P^T Q per batch element (P = samples, Q = reference).
    H = torch.einsum("bni,nj->bij", samples_centered, ref_centered)

    # H = U diag(S) V^T, singular values sorted in descending order.
    U, _S, Vh = torch.linalg.svd(H)
    V = Vh.transpose(-2, -1)
    Ut = U.transpose(-2, -1)

    # If det(V U^T) < 0 the unconstrained optimum is a reflection. The best
    # proper rotation is then V diag(1, 1, -1) U^T, i.e. the singular direction
    # with the smallest singular value is flipped.
    d = torch.det(torch.bmm(V, Ut))
    flip = torch.ones(H.shape[0], 3, device=H.device, dtype=H.dtype)
    # torch.where instead of sign(): a zero determinant (degenerate input) must
    # not zero out a column and make R singular.
    flip[:, -1] = torch.where(d < 0, -torch.ones_like(d), torch.ones_like(d))

    # unsqueeze(-2) broadcasts over rows so the COLUMNS of V are scaled
    # (V @ diag(flip)); scaling rows would give a non-optimal rotation.
    R = torch.bmm(V * flip.unsqueeze(-2), Ut)

    # Detach R so gradients don't flow through SVD (numerically unstable)
    R = R.detach()

    return torch.einsum("bij,bnj->bni", R, samples_centered)
def compute_reward_and_grad(
    *,
    sdes: dict[str, SDE],
    batch: BatchType,
    t: torch.Tensor,
    score_model: torch.nn.Module,
    potentials: list[Callable],
    use_x0_for_reward: bool,
    eval_score: bool,
    enable_grad: bool = True,
) -> tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]
]:
    """Compute reward and its gradients w.r.t. x_t (batch.pos), R_t and t.

    This helper lets you consistently evaluate FK-style potentials either on x_t or
    on the x0 estimate obtained from (x_t, t, score), and obtain d(reward)/d x_t and
    d(reward)/d t (including all dependencies through score, alpha_t, sigma_t, etc.).
    The reward is the negated sum of all potential energies.

    Args:
        sdes: Dictionary of SDEs (expects keys "pos" and "node_orientations").
        batch: ChemGraph batch at time t.
        t: Diffusion time tensor of shape [batch_size,].
        score_model: Score network.
        potentials: List of FK potentials. Each is called as
            ``potential(coords, t=t_var[0], sequence=batch.sequence[0])``
            where ``coords`` is in nanometres (no factor-of-10 scaling is applied).
            ``sequence`` is ``None`` when the batch has no ``sequence`` attribute.
        use_x0_for_reward: If True, evaluate potentials on estimated x0; otherwise on x_t.
        eval_score: If True, evaluate the score model even when the reward is
            computed on x_t, so the returned ``score`` is the model output.
        enable_grad: Whether to enable gradient computation.

    Returns:
        reward: Tensor of shape [batch_size,].
        grad_x: Gradient d(reward)/d x_t with shape like batch.pos.
        grad_so3: Gradient d(reward)/d R_t (node_orientations) with shape like score['node_orientations'].
        grad_t: Gradient d(reward)/d t with shape like t.
        x0: Estimated clean positions when ``use_x0_for_reward`` is True,
            otherwise ``batch.pos`` unchanged.
        score: Dict of scores ("pos" and "node_orientations"); all zeros when the
            score model was not evaluated.

        All returned tensors are detached. Gradients are zero when
        ``enable_grad`` is False, when ``potentials`` is empty, or for inputs
        the reward does not depend on.
    """

    pos_sde = sdes["pos"]
    batch_size = batch.num_graphs
    device = batch.pos.device
    batch_idx = batch.batch

    # Default return values
    x0 = batch.pos.detach()
    # NOTE: node_orientations score has the same shape as pos
    score = {"pos": torch.zeros_like(batch.pos), "node_orientations": torch.zeros_like(batch.pos)}

    with torch.enable_grad() if enable_grad else torch.no_grad():
        # Fresh leaf tensors so gradients are taken w.r.t. exactly these inputs.
        batch_pos = batch.pos.clone().detach().requires_grad_(enable_grad)
        batch_so3 = batch.node_orientations.clone().detach().requires_grad_(enable_grad)
        t_var = t.clone().detach().requires_grad_(enable_grad)

        batch_for_grad = batch.replace(pos=batch_pos, node_orientations=batch_so3)

        if use_x0_for_reward or eval_score:
            # Lazy import to avoid circular dependency (denoiser.py imports from steering)
            from bioemu.denoiser import get_score

            # Score at (x_t, t)
            score = get_score(batch=batch_for_grad, t=t_var, score_model=score_model, sdes=sdes)

        # Choose coordinates for potentials: x_t or x0
        if use_x0_for_reward:
            # x0 estimate from (x_t, t, score)
            x0 = _get_x0_given_xt_and_score(
                sde=pos_sde,
                x=batch_pos,
                t=t_var,
                batch_idx=batch_idx,
                score=score["pos"],
            )
            coords = x0
        else:
            coords = batch_pos

        seq_length = batch_pos.shape[0] // batch_size
        assert batch_pos.shape[0] == batch_size * seq_length
        coords = coords.view(batch_size, seq_length, -1)

        reward = torch.zeros(batch_size, device=device)
        if len(potentials) > 0:
            # The sequence does not change between potentials; look it up once.
            sequence = batch.sequence[0] if hasattr(batch, "sequence") else None  # None: 1D toy
            for potential in potentials:
                reward = reward - potential(
                    coords,
                    t=t_var[0],
                    sequence=sequence,
                )

        assert reward.shape == (
            batch_size,
        ), f"reward shape {reward.shape}, batch_size {batch_size}"

        if enable_grad and len(potentials) > 0:
            grad_x, grad_so3_3x3, grad_t = torch.autograd.grad(
                reward.sum(), (batch_pos, batch_so3, t_var), create_graph=False, allow_unused=True
            )
            # allow_unused=True yields None for every input the reward does not
            # depend on; each such gradient is replaced by zeros.
            if grad_x is None:
                logger.warning("grad x is None, setting to zero")
                grad_x = torch.zeros_like(batch_pos)

            if grad_t is None:
                logger.warning("grad t is None, setting to zero")
                grad_t = torch.zeros_like(t_var)

            if grad_so3_3x3 is None:  # GMM or reward not depending on so3
                logger.warning("grad so3 is None, setting to zero")
                grad_so3 = torch.zeros_like(score["node_orientations"])
            else:
                grad_so3 = reward_grad_rotmat_to_rotvec(batch_so3, grad_so3_3x3)
        else:
            grad_x = torch.zeros_like(batch_pos)
            grad_so3 = torch.zeros_like(score["node_orientations"])
            grad_t = torch.zeros_like(t_var)

    return (
        reward.detach(),
        grad_x.detach(),
        grad_so3.detach(),
        grad_t.detach(),
        x0.detach(),
        {k: v.detach() for k, v in score.items()},
    )
+""" + +import torch +import torch.nn as nn +from torch.distributions import Normal + +from bioemu.steering import CollectiveVariable + +MU1 = -2.0 +MU2 = 3.0 +SIGMA1 = 0.5 +SIGMA2 = 0.5 + + +class TimeDependentGMM1D(nn.Module): + """ + Time-dependent 1D Gaussian Mixture Model. + + The data distribution at t=0 is a GMM with two modes. + At time t, the distribution evolves according to the forward SDE. + """ + + def __init__( + self, + mu1: torch.Tensor, + mu2: torch.Tensor, + sigma1: float = SIGMA1, + sigma2: float = SIGMA2, + weight1: float = 0.7, + scheduler=None, # SDE scheduler (e.g., CosineVPSDE) + ): + super().__init__() + self.register_buffer("mu1", mu1) # [1] + self.register_buffer("mu2", mu2) # [1] + self.sigma1 = sigma1 + self.sigma2 = sigma2 + self.weight1 = weight1 + self.weight2 = 1.0 - weight1 + + # Use provided scheduler (should be same as diffusion process) + self.scheduler = scheduler + + # Log weights for numerical stability + self.register_buffer("log_w1", torch.log(torch.tensor(weight1))) + self.register_buffer("log_w2", torch.log(torch.tensor(1.0 - weight1))) + + def log_prob(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + Compute log probability at time t. 
+ + Args: + x: [batch_size, 1] - positions + t: [batch_size] or scalar - time + + Returns: + log_prob: [batch_size] - log probabilities + """ + if x.dim() == 1: + x = x.unsqueeze(-1) + + batch_size = x.shape[0] + + # Handle time broadcasting + # if not isinstance(t, torch.Tensor): + # t = torch.tensor(t, device=x.device) + if t.dim() == 0: + t = t.unsqueeze(0).expand(batch_size) + + # Get marginal parameters at time t + if hasattr(self.scheduler, "marginal_prob"): + # CosineVPSDE or other SDE with marginal_prob method + seq_length = 1 + dim = 1 + ones = torch.ones(batch_size, seq_length, dim, device=x.device) + alpha_t, sigma_t = self.scheduler.marginal_prob(x=ones, t=t) + alpha_t = alpha_t.squeeze() # [batch_size] + sigma_t = sigma_t.squeeze() # [batch_size] + else: + # BetaSchedule fallback + alpha_t, sigma_t = self.scheduler.get_alpha_t_sigma_t(t) # [batch_size] + + # Compute marginal means and stds for each component + # mu_t = alpha_t * mu_0 + mu1_t = alpha_t.unsqueeze(-1) * self.mu1 # [batch_size, 1] + mu2_t = alpha_t.unsqueeze(-1) * self.mu2 # [batch_size, 1] + + # sigma_t^2 = sigma_0^2 * alpha_t^2 + sigma_noise^2 + sigma1_t = torch.sqrt(self.sigma1**2 * alpha_t**2 + sigma_t**2).unsqueeze( + -1 + ) # [batch_size, 1] + sigma2_t = torch.sqrt(self.sigma2**2 * alpha_t**2 + sigma_t**2).unsqueeze( + -1 + ) # [batch_size, 1] + + # Compute log probabilities for each component + log_prob1 = Normal(mu1_t, sigma1_t).log_prob(x).sum(dim=-1) # [batch_size] + log_prob2 = Normal(mu2_t, sigma2_t).log_prob(x).sum(dim=-1) # [batch_size] + + # Use log-sum-exp trick for mixture + log_probs = torch.stack( + [log_prob1 + self.log_w1, log_prob2 + self.log_w2], dim=-1 + ) # [batch_size, 2] + + return torch.logsumexp(log_probs, dim=-1) # [batch_size] + + @torch.enable_grad() + def score( + self, + x: torch.Tensor, + t: torch.Tensor | None = None, + *, + create_graph: bool = False, + ) -> torch.Tensor: + """Compute score function using autograd: ∇_x log q(x_t). 
+ + By default this is a first-order quantity used for sampling only, so + the computation graph is not kept. When ``create_graph=True``, the + returned score retains its computation graph so that gradients of any + downstream quantity that depends on the score with respect to ``x`` or + ``t`` can be computed (i.e. enables second-order derivatives). + """ + + if create_graph: + # Use x directly (or a cloned version) without detaching so that + # the graph from log_prob to the original x is preserved. + x_var = x + + else: + # Default: no higher-order derivatives required. Detach x to keep + # this computation inexpensive and side-effect free for callers that + # only need the score value. + x_var = x.clone().requires_grad_(True) + + # Compute log probability of marginal distribution at time t + log_prob = self.log_prob(x_var, t) + + # Sum log probabilities to get scalar for autograd (needed for gradient computation) + if log_prob.dim() > 0: + log_prob_sum = log_prob.sum() + else: + log_prob_sum = log_prob + + # Compute gradient of log probability with respect to x + score = torch.autograd.grad( + outputs=log_prob_sum, + inputs=x_var, + create_graph=create_graph, + retain_graph=create_graph, # TODO: make this an option + )[0] + return score + + def sample(self, n_samples: int, t: float = 0.0) -> torch.Tensor: + """ + Sample from the GMM at time t. 
+ + Args: + n_samples: number of samples + t: time (default 0 = data distribution) + + Returns: + samples: [n_samples, 1] + """ + device = self.mu1.device + t_tensor = torch.tensor(t, device=device) + + # Sample component assignments + z = torch.rand(n_samples, device=device) < self.weight1 + + # Get marginal parameters + if hasattr(self.scheduler, "marginal_prob"): + # CosineVPSDE or other SDE with marginal_prob method + ones = torch.ones(1, 1, device=device) + t_batch = t_tensor.unsqueeze(0) + alpha_t, sigma_t = self.scheduler.marginal_prob(x=ones, t=t_batch) + alpha_t = alpha_t.squeeze() + sigma_t = sigma_t.squeeze() + else: + # BetaSchedule fallback + alpha_t, sigma_t = self.scheduler.get_alpha_t_sigma_t(t_tensor) + + mu1_t = alpha_t * self.mu1 + mu2_t = alpha_t * self.mu2 + sigma1_t = torch.sqrt(self.sigma1**2 * alpha_t**2 + sigma_t**2) + sigma2_t = torch.sqrt(self.sigma2**2 * alpha_t**2 + sigma_t**2) + + # Sample from each component + samples1 = mu1_t + sigma1_t * torch.randn(n_samples, 1, device=device) + samples2 = mu2_t + sigma2_t * torch.randn(n_samples, 1, device=device) + + # Mix samples based on component assignments + samples = torch.where(z.unsqueeze(-1), samples1, samples2) + return samples + + def energy(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """Return negative log probability.""" + return -self.log_prob(x, t) + + +class ToyPosCV(CollectiveVariable): + """Minimal CV that extracts the first coordinate from a ChemGraph. + + The bioemu steering stack passes Cα positions in nm directly, so no + unit conversion is needed. 
+ """ + + def compute_batch(self, ca_pos_nm: torch.Tensor, sequence: str) -> torch.Tensor: + # ca_pos_nm shape: [batch_size, L, D]; toy uses L=1, D=1 + vals = ca_pos_nm.reshape(ca_pos_nm.shape[0], -1)[:, 0] + return vals \ No newline at end of file diff --git a/tests/steering/__init__.py b/tests/steering/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/steering/test_chignolin_e2e.py b/tests/steering/test_chignolin_e2e.py new file mode 100644 index 0000000..af7d1bf --- /dev/null +++ b/tests/steering/test_chignolin_e2e.py @@ -0,0 +1,96 @@ +"""End-to-end integration tests for steering with chignolin (GYDPETGTWG). + +These tests call the full sample() pipeline and require model weights +(downloaded from HuggingFace). They are slow and intended for manual or +CI-with-model-access runs. + +Adapted from the original tests/test_steering.py on main. +""" + +import os + +import pytest +import yaml + +from bioemu.sample import main as sample + +PHYSICAL_STEERING_CONFIG_PATH = os.path.join( + os.path.dirname(__file__), "../../src/bioemu/config/steering/physical_steering.yaml" +) + + +@pytest.fixture +def chignolin_sequence(): + return "GYDPETGTWG" + + +@pytest.fixture +def base_test_config(): + return {"batch_size_100": 100, "num_samples": 10} + + +def load_steering_config(): + with open(PHYSICAL_STEERING_CONFIG_PATH) as f: + return yaml.safe_load(f) + + +def test_steering_with_config_path(chignolin_sequence, base_test_config, tmp_path): + """Test steering by passing the steering config file as denoiser_config.""" + sample( + sequence=chignolin_sequence, + num_samples=base_test_config["num_samples"], + batch_size_100=base_test_config["batch_size_100"], + output_dir=str(tmp_path / "config_path"), + denoiser_config=PHYSICAL_STEERING_CONFIG_PATH, + ) + + +def test_steering_with_config_dict(chignolin_sequence, base_test_config, tmp_path): + """Test steering by passing the config as a dict.""" + sample( + sequence=chignolin_sequence, + 
class TestCaCaDistance:
    """Checks for the consecutive Cα–Cα distance collective variable."""

    def test_shape(self):
        """A chain of 10 residues yields 9 consecutive distances per sample."""
        from bioemu.steering.collective_variables import CaCaDistance

        coords = torch.randn(2, 10, 3)
        distances = CaCaDistance().compute_batch(coords)
        assert distances.shape == (2, 9)

    def test_known_distances(self):
        """Residues spaced 0.38 nm apart along x give 0.38 for every pair."""
        from bioemu.steering.collective_variables import CaCaDistance

        coords = torch.zeros(1, 4, 3)
        coords[0, :, 0] = torch.arange(4) * 0.38
        distances = CaCaDistance().compute_batch(coords)
        expected = torch.full((1, 3), 0.38)
        torch.testing.assert_close(distances, expected, atol=1e-5, rtol=1e-5)

    def test_differentiable(self):
        """CaCaDistance should support autograd for steering gradients."""
        from bioemu.steering.collective_variables import CaCaDistance

        coords = torch.randn(2, 5, 3, requires_grad=True)
        CaCaDistance().compute_batch(coords).sum().backward()
        grad = coords.grad
        assert grad is not None
        assert grad.shape == coords.shape
class TestComputeEssFromLogWeights:
    """Tests for compute_ess_from_log_weights."""

    def test_uniform_weights(self):
        """Uniform log weights → ESS should be close to 1.0."""
        from bioemu.steering.utils import compute_ess_from_log_weights

        n = 32
        log_w = torch.zeros(n)
        ess, _ = compute_ess_from_log_weights(log_w, n_particles=n)
        assert abs(ess.item() - 1.0) < 1e-5

    def test_delta_weights(self):
        """One dominant weight → ESS should be close to 1/N."""
        from bioemu.steering.utils import compute_ess_from_log_weights

        n = 32
        log_w = torch.full((n,), -1000.0)
        log_w[0] = 0.0
        ess, _ = compute_ess_from_log_weights(log_w, n_particles=n)
        # The tolerance must stay well below 1/n (~0.031); otherwise a
        # collapsed ESS of 0 would also satisfy the check.
        assert abs(ess.item() - 1.0 / n) < 1e-6

    def test_two_equal_groups(self):
        """Two groups with uniform weights → ESS ≈ 1.0 for each group."""
        from bioemu.steering.utils import compute_ess_from_log_weights

        n_particles = 16
        log_w = torch.zeros(32)
        ess, norm_w = compute_ess_from_log_weights(log_w, n_particles=n_particles)
        assert abs(ess.item() - 1.0) < 1e-5
        # Weights are normalised per group, not over the whole batch.
        assert norm_w.shape == (2, n_particles)
        torch.testing.assert_close(norm_w.sum(dim=-1), torch.ones(2), atol=1e-5, rtol=1e-5)

    def test_returns_normalized_weights(self):
        """Weights of a single group sum to one."""
        from bioemu.steering.utils import compute_ess_from_log_weights

        n = 8
        log_w = torch.randn(n)
        _, norm_w = compute_ess_from_log_weights(log_w, n_particles=n)
        torch.testing.assert_close(norm_w.sum(), torch.tensor(1.0), atol=1e-5, rtol=1e-5)
+""" + +from contextlib import contextmanager +from pathlib import Path +from unittest.mock import MagicMock, patch + +import torch +import yaml + +STEERING_CONFIG_DIR = Path(__file__).parent.parent.parent / "src" / "bioemu" / "config" / "steering" + + +@contextmanager +def _mock_colabfold_embeds(seq: str): + """Patch the colabfold_inline stack so get_colabfold_embeds returns mocked arrays.""" + import numpy as np + + from tests.test_embeds import _make_mock_run_model + + L = len(seq) + with ( + patch( + "bioemu.colabfold_inline.model_runner._run_model", + side_effect=_make_mock_run_model(L), + ), + patch( + "bioemu.colabfold_inline.model_runner._load_model_and_params", + return_value=(MagicMock(), {}), + ), + patch("bioemu.colabfold_inline.model_runner.download_alphafold_params"), + patch("bioemu.get_embeds._get_a3m_string", return_value=f">q\n{seq}\n"), + ): + # Touch np to keep import side-effect ordering stable across reloads. + _ = np.zeros(1) + yield + + +class TestPhysicalSteeringConfig: + """Verify physical_steering.yaml loads and instantiates potentials.""" + + def test_config_loads(self): + with open(STEERING_CONFIG_DIR / "physical_steering.yaml") as f: + cfg = yaml.safe_load(f) + assert "_target_" in cfg + assert "fk_potentials" in cfg + assert "steering_config" in cfg + assert len(cfg["fk_potentials"]) == 2 + + def test_hydra_instantiate(self): + """Hydra can instantiate the full config as a partial denoiser.""" + import hydra + + with open(STEERING_CONFIG_DIR / "physical_steering.yaml") as f: + cfg = yaml.safe_load(f) + + denoiser = hydra.utils.instantiate(cfg) + assert callable(denoiser) + + +class TestCvSteerConfig: + """Verify cv_steer.yaml loads correctly.""" + + def test_config_loads_and_has_denoiser(self): + with open(STEERING_CONFIG_DIR / "cv_steer.yaml") as f: + cfg = yaml.safe_load(f) + assert "_target_" in cfg + assert "fk_potentials" in cfg + assert "steering_config" in cfg + + +class TestPotentialForwardBackward: + """Test that potentials 
support autograd (gradient-based steering).""" + + def test_umbrella_with_caca_cv_gradients(self): + from bioemu.steering.collective_variables import CaCaDistance + from bioemu.steering.potentials import UmbrellaPotential + + pot = UmbrellaPotential(cv=CaCaDistance(), target=0.38, slope=10.0, weight=1.0) + ca_pos = torch.randn(2, 10, 3, requires_grad=True) + energy = pot(ca_pos) + energy.sum().backward() + assert ca_pos.grad is not None + assert ca_pos.grad.shape == ca_pos.shape + + def test_umbrella_with_pairwise_clash_cv_gradients(self): + from bioemu.steering.collective_variables import PairwiseClash + from bioemu.steering.potentials import UmbrellaPotential + + pot = UmbrellaPotential( + cv=PairwiseClash(min_dist=0.4, offset=3), target=0.0, slope=10.0, weight=1.0 + ) + ca_pos = torch.randn(2, 10, 3, requires_grad=True) + energy = pot(ca_pos) + energy.sum().backward() + assert ca_pos.grad is not None + + +class TestFkcSteeringIntegration: + """Integration test: run compute_reward_and_grad → FKC weights pipeline. + + Verifies the full FKC steering pipeline with a deterministic mock score model. 
+ """ + + @staticmethod + def _make_sdes(): + from bioemu.sde_lib import CosineVPSDE + from bioemu.so3_sde import DiGSO3SDE + + return { + "pos": CosineVPSDE(), + "node_orientations": DiGSO3SDE(num_sigma=10, num_omega=10, l_max=10), + } + + @staticmethod + def _make_batch(n_residues: int = 8, batch_size: int = 4): + from bioemu.chemgraph import ChemGraph + + torch.manual_seed(42) + data_list = [] + for i in range(batch_size): + g = ChemGraph( + pos=torch.randn(n_residues, 3) * 0.1, + node_orientations=torch.eye(3).unsqueeze(0).expand(n_residues, -1, -1).clone(), + edge_index=torch.zeros(2, 0, dtype=torch.long), + single_embeds=torch.zeros(n_residues, 1), + pair_embeds=torch.zeros(n_residues**2, 1), + sequence="A" * n_residues, + system_id=f"test_{i}", + node_labels=torch.zeros(n_residues, dtype=torch.long), + ) + data_list.append(g) + from torch_geometric.data import Batch + + return Batch.from_data_list(data_list) + + @staticmethod + def _make_score_model(total_nodes: int): + class MockScoreModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.dummy = torch.nn.Parameter(torch.zeros(1)) + torch.manual_seed(99) + self._pos = torch.randn(total_nodes, 3) * 0.01 + self._rot = torch.randn(total_nodes, 3) * 0.01 + + def forward(self, batch, t): + return { + "pos": self._pos + self.dummy * 0, + "node_orientations": self._rot + self.dummy * 0, + } + + return MockScoreModel() + + def test_reward_and_grad_pipeline(self): + """Full pipeline: potentials → compute_reward_and_grad → finite grad values.""" + from bioemu.steering.collective_variables import CaCaDistance + from bioemu.steering.potentials import UmbrellaPotential + from bioemu.steering.utils import compute_reward_and_grad + + n_res, bs = 8, 4 + sdes = self._make_sdes() + batch = self._make_batch(n_res, bs) + model = self._make_score_model(n_res * bs) + t = torch.full((bs,), 0.5) + + pot = UmbrellaPotential(cv=CaCaDistance(), target=0.38, slope=10.0, weight=1.0) + + reward, grad_x, grad_rot, 
grad_t, raw_score, info = compute_reward_and_grad( + sdes=sdes, + batch=batch, + t=t, + score_model=model, + potentials=[pot], + use_x0_for_reward=True, + eval_score=True, + ) + + assert reward.shape == (bs,) + assert grad_x.shape == (n_res * bs, 3) + assert grad_t.shape == (bs,) + assert raw_score.shape == (n_res * bs, 3) + assert torch.isfinite(reward).all() + assert torch.isfinite(grad_x).all() + + def test_fkc_weights_pipeline(self): + """Full pipeline: compute_reward_and_grad → FKC weights → finite log weights.""" + from bioemu.steering.collective_variables import CaCaDistance + from bioemu.steering.dpm_fkc import _compute_fkc_weights, _get_fkc_guided_score + from bioemu.steering.potentials import UmbrellaPotential + + n_res, bs = 8, 4 + sdes = self._make_sdes() + batch = self._make_batch(n_res, bs) + model = self._make_score_model(n_res * bs) + t = torch.full((bs,), 0.5) + + pot = UmbrellaPotential(cv=CaCaDistance(), target=0.38, slope=10.0, weight=1.0) + + guided = _get_fkc_guided_score( + sdes=sdes, + batch=batch, + t=t, + score_model=model, + potentials=[pot], + use_x0_for_reward=True, + enable_grad=True, + noise_scale=1.0, + ) + + assert guided.pos.shape == (n_res * bs, 3) + assert guided.reward.shape == (bs,) + assert torch.isfinite(guided.pos).all() + + # Compute FKC weights using a separate reward call (as done in dpm_fkc_step) + from bioemu.steering.utils import compute_reward_and_grad + + batch2 = self._make_batch(n_res, bs) + t2 = torch.full((bs,), 0.5) + _, reward_grad_x, _, reward_grad_t, _, _ = compute_reward_and_grad( + sdes=sdes, + batch=batch2, + t=t2, + score_model=model, + potentials=[pot], + use_x0_for_reward=True, + eval_score=False, + ) + + batch_idx = batch2.batch + t_next = torch.full((bs,), 0.4) + log_weights = _compute_fkc_weights( + batch2, + sdes["pos"], + guided.raw_score["pos"], + reward_grad_x, + reward_grad_t, + t2, + t_next, + batch_idx, + ) + assert log_weights.shape == (bs,) + assert torch.isfinite(log_weights).all() + + def 
test_unsteered_loop_via_dpm_solver_fkc(self): + """dpm_solver_fkc with no potentials = unsteered DPM solver loop.""" + from bioemu.steering.dpm_fkc import dpm_solver_fkc + + n_res, bs = 8, 2 + sdes = self._make_sdes() + model = self._make_score_model(n_res * bs) + batch = self._make_batch(n_res, bs) + + torch.manual_seed(42) + result_batch, batch_log_weights = dpm_solver_fkc( + sdes=sdes, + batch=batch, + N=5, + score_model=model, + max_t=0.99, + eps_t=0.01, + device=torch.device("cpu"), + fk_potentials=[], + steering_config=None, + noise=0.3, + ) + + assert result_batch.pos.shape == (n_res * bs, 3) + assert result_batch.node_orientations.shape == (n_res * bs, 3, 3) + assert torch.isfinite(result_batch.pos).all() + assert torch.isfinite(result_batch.node_orientations).all() + # No steering → all log weights should be zero + assert torch.allclose(batch_log_weights, torch.zeros_like(batch_log_weights)) + + def test_steered_loop_via_dpm_solver_fkc(self): + """dpm_solver_fkc with potentials = steered DPM solver loop.""" + from bioemu.steering.collective_variables import CaCaDistance + from bioemu.steering.dpm_fkc import dpm_solver_fkc + from bioemu.steering.potentials import UmbrellaPotential + + n_res, bs = 8, 4 + sdes = self._make_sdes() + model = self._make_score_model(n_res * bs) + batch = self._make_batch(n_res, bs) + pot = UmbrellaPotential(cv=CaCaDistance(), target=0.38, slope=10.0, weight=1.0) + + torch.manual_seed(42) + result_batch, batch_log_weights = dpm_solver_fkc( + sdes=sdes, + batch=batch, + N=5, + score_model=model, + max_t=0.99, + eps_t=0.01, + device=torch.device("cpu"), + fk_potentials=[pot], + steering_config={"num_particles": 4, "ess_threshold": 0.5, "start": 1.0, "end": 0.0}, + noise=0.5, + ) + + assert result_batch.pos.shape[1] == 3 + assert torch.isfinite(result_batch.pos).all() + # With potentials, log weights should be non-zero + assert not torch.allclose(batch_log_weights, torch.zeros_like(batch_log_weights)) + + +class 
TestSmcSteeringIntegration: + """Integration test: run dpm_solver_smc loop with and without steering.""" + + @staticmethod + def _make_sdes(): + from bioemu.sde_lib import CosineVPSDE + from bioemu.so3_sde import DiGSO3SDE + + return { + "pos": CosineVPSDE(), + "node_orientations": DiGSO3SDE(num_sigma=10, num_omega=10, l_max=10), + } + + @staticmethod + def _make_batch(n_residues: int = 8, batch_size: int = 4): + from bioemu.chemgraph import ChemGraph + + torch.manual_seed(42) + data_list = [] + for i in range(batch_size): + g = ChemGraph( + pos=torch.randn(n_residues, 3) * 0.1, + node_orientations=torch.eye(3).unsqueeze(0).expand(n_residues, -1, -1).clone(), + edge_index=torch.zeros(2, 0, dtype=torch.long), + single_embeds=torch.zeros(n_residues, 1), + pair_embeds=torch.zeros(n_residues**2, 1), + sequence="A" * n_residues, + system_id=f"test_{i}", + node_labels=torch.zeros(n_residues, dtype=torch.long), + ) + data_list.append(g) + from torch_geometric.data import Batch + + return Batch.from_data_list(data_list) + + @staticmethod + def _make_score_model(total_nodes: int): + class MockScoreModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.dummy = torch.nn.Parameter(torch.zeros(1)) + torch.manual_seed(99) + self._pos = torch.randn(total_nodes, 3) * 0.01 + self._rot = torch.randn(total_nodes, 3) * 0.01 + + def forward(self, batch, t): + return { + "pos": self._pos + self.dummy * 0, + "node_orientations": self._rot + self.dummy * 0, + } + + return MockScoreModel() + + def test_smc_loop_unsteered(self): + """SMC loop with no potentials should produce finite output with zero log weights.""" + from bioemu.steering.dpm_smc import dpm_solver_smc + + n_res, bs = 8, 2 + sdes = self._make_sdes() + model = self._make_score_model(n_res * bs) + batch = self._make_batch(n_res, bs) + + torch.manual_seed(42) + result_batch, log_weights = dpm_solver_smc( + sdes=sdes, + batch=batch, + N=5, + score_model=model, + max_t=0.99, + eps_t=0.01, + 
device=torch.device("cpu"), + noise=0.3, + fk_potentials=[], + steering_config={"num_particles": 2, "ess_threshold": 0.5, "start": 1.0, "end": 0.0}, + ) + + assert result_batch.pos.shape == (n_res * bs, 3) + assert torch.isfinite(result_batch.pos).all() + assert torch.allclose(log_weights, torch.zeros_like(log_weights)) + + def test_smc_loop_steered(self): + """SMC loop with potentials should produce finite output with non-zero log weights.""" + from bioemu.steering.collective_variables import CaCaDistance + from bioemu.steering.dpm_smc import dpm_solver_smc + from bioemu.steering.potentials import UmbrellaPotential + + n_res, bs = 8, 4 + sdes = self._make_sdes() + model = self._make_score_model(n_res * bs) + batch = self._make_batch(n_res, bs) + pot = UmbrellaPotential(cv=CaCaDistance(), target=0.38, slope=10.0, weight=1.0) + + torch.manual_seed(42) + result_batch, log_weights = dpm_solver_smc( + sdes=sdes, + batch=batch, + N=5, + score_model=model, + max_t=0.99, + eps_t=0.01, + device=torch.device("cpu"), + noise=0.5, + fk_potentials=[pot], + steering_config={"num_particles": 4, "ess_threshold": 0.5, "start": 1.0, "end": 0.0}, + ) + + assert result_batch.pos.shape[1] == 3 + assert torch.isfinite(result_batch.pos).all() + + +class TestResampleCorrectness: + """Verify resampling actually copies the dominant-weight sample.""" + + def test_dominant_weight_copies_positions(self): + """With one extreme weight, resampled batch should have all positions from that sample.""" + from torch_geometric.data import Batch, Data + + from bioemu.steering.utils import resample_based_on_log_weights + + n_res = 5 + n_samples = 8 + # Create batch with distinct positions per sample + data_list = [] + for i in range(n_samples): + d = Data( + pos=torch.full((n_res, 3), float(i)), + node_orientations=torch.eye(3).unsqueeze(0).expand(n_res, -1, -1).clone(), + batch=torch.zeros(n_res, dtype=torch.long), + sequence=["ACDEF"], + ) + data_list.append(d) + batch = 
Batch.from_data_list(data_list) + + # Put all weight on sample 3 + log_w = torch.full((n_samples,), -1000.0) + log_w[3] = 0.0 + + new_batch, new_lw, indices, ess = resample_based_on_log_weights( + batch=batch, + log_weight=log_w, + n_particles=n_samples, + is_last_step=False, + ess_threshold=0.5, + step=0, + t=0.5, + ) + + # All resampled positions should equal sample 3's value (3.0) + assert (indices == 3).all(), f"Expected all indices=3, got {indices}" + expected_pos = torch.full((n_samples * n_res, 3), 3.0) + torch.testing.assert_close(new_batch.pos, expected_pos) + + +class TestOdeConsistency: + """Verify consistency between dpm_solver, dpm_solver_fkc, and dpm_solver_smc in ODE mode. + + All three solvers with noise=0 should produce finite, reasonable outputs. + FKC and SMC (both using DPM-Solver++ with 0.5*score) should be identical. + dpm_solver (DPM-Solver with 1.0*score and midpoint-only formula) differs but should be close. + """ + + @staticmethod + def _make_sdes(): + from bioemu.sde_lib import CosineVPSDE + from bioemu.so3_sde import DiGSO3SDE + + return { + "pos": CosineVPSDE(), + "node_orientations": DiGSO3SDE(num_sigma=10, num_omega=10, l_max=10), + } + + @staticmethod + def _make_batch(n_residues: int, batch_size: int): + from bioemu.chemgraph import ChemGraph + + data_list = [] + for i in range(batch_size): + g = ChemGraph( + pos=torch.randn(n_residues, 3) * 0.1, + node_orientations=torch.eye(3).unsqueeze(0).expand(n_residues, -1, -1).clone(), + edge_index=torch.zeros(2, 0, dtype=torch.long), + single_embeds=torch.zeros(n_residues, 1), + pair_embeds=torch.zeros(n_residues**2, 1), + sequence="A" * n_residues, + system_id=f"test_{i}", + node_labels=torch.zeros(n_residues, dtype=torch.long), + ) + data_list.append(g) + from torch_geometric.data import Batch + + return Batch.from_data_list(data_list) + + @staticmethod + def _make_score_model(total_nodes: int, seed: int = 99): + class MockScoreModel(torch.nn.Module): + def __init__(self): + 
super().__init__() + self.dummy = torch.nn.Parameter(torch.zeros(1)) + torch.manual_seed(seed) + self._pos = torch.randn(total_nodes, 3) * 0.05 + self._rot = torch.randn(total_nodes, 3) * 0.05 + + def forward(self, batch, t): + return { + "pos": self._pos + self.dummy * 0, + "node_orientations": self._rot + self.dummy * 0, + } + + return MockScoreModel() + + def test_fkc_ode_vs_smc_ode_identical(self): + """FKC(noise=0) and SMC(noise_scale=0) should produce identical ODE steps.""" + from bioemu.steering.dpm_fkc import dpm_solver_sde_fkc_step + from bioemu.steering.dpm_smc import dpm_solver_sde_smc_step + + n_res, bs = 8, 2 + sdes = self._make_sdes() + model = self._make_score_model(n_res * bs) + + torch.manual_seed(42) + batch = self._make_batch(n_res, bs) + + t = torch.full((bs,), 0.8) + t_next = torch.full((bs,), 0.6) + + torch.manual_seed(100) + fkc_result = dpm_solver_sde_fkc_step( + batch=batch.clone(), + t=t, + t_next=t_next, + sdes=sdes, + score_model=model, + max_t=0.99, + potentials=[], + is_last_step=False, + enable_grad=False, + noise_scale=0.0, + ) + + torch.manual_seed(100) + smc_result = dpm_solver_sde_smc_step( + batch=batch.clone(), + t=t, + t_next=t_next, + sdes=sdes, + score_model=model, + max_t=0.99, + potentials=[], + step_idx=0, + noise_scale=0.0, + ) + + # Both use (1+0²)/2 = 0.5 * score + DPM-Solver++ formula → must be identical + torch.testing.assert_close(fkc_result[0].pos, smc_result[0].pos, msg="pos mismatch") + torch.testing.assert_close( + fkc_result[0].node_orientations, + smc_result[0].node_orientations, + msg="SO3 mismatch", + ) + + def test_dpm_solver_ode_step_vs_fkc_ode_step_close(self): + """Single-step: dpm_solver ODE helper vs FKC ODE step produce different but close results. + + dpm_solver uses 1.0*score + midpoint-only formula. + FKC uses 0.5*score + two-score DPM-Solver++ formula. + These differ but should be in the same ballpark. 
+ """ + from bioemu.denoiser import ( + _get_dpm_coefficients, + _predict_midpoint, + get_score, + second_order_step_dpmsolver, + ) + from bioemu.sde_lib import CosineVPSDE + from bioemu.so3_sde import SO3SDE + from bioemu.steering.dpm_fkc import dpm_solver_sde_fkc_step + + n_res, bs = 8, 2 + sdes = self._make_sdes() + model = self._make_score_model(n_res * bs) + pos_sde = sdes["pos"] + so3_sde = sdes["node_orientations"] + assert isinstance(pos_sde, CosineVPSDE) + assert isinstance(so3_sde, SO3SDE) + + torch.manual_seed(42) + batch = self._make_batch(n_res, bs) + t = torch.full((bs,), 0.8) + t_next = torch.full((bs,), 0.6) + batch_idx = batch.batch + + # dpm_solver ODE step (manual, using helpers) + torch.manual_seed(100) + score = get_score(batch=batch, t=t, score_model=model, sdes=sdes) + coeffs = _get_dpm_coefficients(pos_sde, batch.pos, t, t_next, batch_idx) + batch_lambda = _predict_midpoint( + batch=batch, + coeffs=coeffs, + score_pos=score["pos"], # unscaled (1.0 × score) + score_so3=score["node_orientations"], + so3_sde=so3_sde, + t=t, + batch_idx=batch_idx, + ) + score_lambda = get_score( + batch=batch_lambda, t=coeffs.t_lambda, score_model=model, sdes=sdes + ) + dpm_batch = second_order_step_dpmsolver( + batch=batch, + coeffs=coeffs, + score_pos_lambda=score_lambda["pos"], + score_so3_t=score["node_orientations"], + score_so3_lambda=score_lambda["node_orientations"], + so3_sde=so3_sde, + t=t, + t_next=t_next, + batch_idx=batch_idx, + ) + + # FKC ODE step (noise_scale=0) + torch.manual_seed(100) + fkc_result = dpm_solver_sde_fkc_step( + batch=batch.clone(), + t=t, + t_next=t_next, + sdes=sdes, + score_model=model, + max_t=0.99, + potentials=[], + is_last_step=False, + enable_grad=False, + noise_scale=0.0, + ) + + # Both should be finite + assert torch.isfinite(dpm_batch.pos).all() + assert torch.isfinite(fkc_result[0].pos).all() + + # They use different formula + score scaling but may be very close for small scores. 
+ # Verify both produce finite, reasonable results. + pos_diff = (dpm_batch.pos - fkc_result[0].pos).abs().max().item() + assert pos_diff < 10.0, f"Results too far apart: diff={pos_diff}" + + def test_dpm_solver_loop_ode_vs_fkc_loop_ode_close(self): + """Multi-step loop: dpm_solver(noise=0) vs dpm_solver_fkc(noise=0) close but not identical.""" + from bioemu.denoiser import dpm_solver + from bioemu.steering.dpm_fkc import dpm_solver_fkc + + n_res, bs = 8, 2 + sdes = self._make_sdes() + model = self._make_score_model(n_res * bs) + + torch.manual_seed(42) + batch = self._make_batch(n_res, bs) + + torch.manual_seed(100) + dpm_result = dpm_solver( + sdes=sdes, + batch=batch.clone(), + N=5, + score_model=model, + max_t=0.99, + eps_t=0.01, + device=torch.device("cpu"), + noise=0.0, + ) + + torch.manual_seed(100) + fkc_result, _ = dpm_solver_fkc( + sdes=sdes, + batch=batch.clone(), + N=5, + score_model=model, + max_t=0.99, + eps_t=0.01, + device=torch.device("cpu"), + noise=0.0, + fk_potentials=[], + steering_config=None, + ) + + # Both should produce finite outputs + assert torch.isfinite(dpm_result.pos).all() + assert torch.isfinite(fkc_result.pos).all() + + # Close but not necessarily identical (different formula + score scaling) + pos_diff = (dpm_result.pos - fkc_result.pos).abs().max().item() + assert pos_diff < 10.0, f"Results too far apart: diff={pos_diff}" + + def test_dpm_solver_ode_regression(self): + """dpm_solver(noise=0) must produce deterministic, reproducible output. + + This captures the current dpm_solver behavior as a regression baseline. + After refactoring dpm_solver to use helpers, re-run to verify identical output. 
+ """ + from bioemu.denoiser import dpm_solver + + n_res, bs = 8, 2 + sdes = self._make_sdes() + model = self._make_score_model(n_res * bs) + + torch.manual_seed(42) + batch = self._make_batch(n_res, bs) + + torch.manual_seed(100) + result1 = dpm_solver( + sdes=sdes, + batch=batch.clone(), + N=5, + score_model=model, + max_t=0.99, + eps_t=0.01, + device=torch.device("cpu"), + noise=0.0, + ) + + torch.manual_seed(42) + batch2 = self._make_batch(n_res, bs) + + torch.manual_seed(100) + result2 = dpm_solver( + sdes=sdes, + batch=batch2.clone(), + N=5, + score_model=model, + max_t=0.99, + eps_t=0.01, + device=torch.device("cpu"), + noise=0.0, + ) + + torch.testing.assert_close(result1.pos, result2.pos, msg="ODE not deterministic") + torch.testing.assert_close( + result1.node_orientations, + result2.node_orientations, + msg="ODE orientations not deterministic", + ) + # Regression snapshot: verify specific numerical values don't change after refactor + assert ( + abs(result1.pos.sum().item() - (-373.535)) < 0.01 + ), f"pos sum changed: {result1.pos.sum().item()}" + assert ( + abs(result1.node_orientations.sum().item() - (-7.475)) < 0.01 + ), f"orient sum changed: {result1.node_orientations.sum().item()}" + + +class TestGenerateBatchWithSteering: + """Test generate_batch with steered denoisers (FKC/SMC). + + Uses mock colabfold to avoid network calls, and a mock score model to avoid + model weight downloads. Tests the full pipeline: generate_batch → denoiser. 
+ """ + + @staticmethod + def _mock_score_model(batch, t): + device = batch["pos"].device + return { + "pos": torch.rand(batch["pos"].shape, device=device), + "node_orientations": torch.rand(batch["node_orientations"].shape[0], 3, device=device), + } + + def test_generate_batch_with_fkc_denoiser(self): + """generate_batch with dpm_solver_fkc denoiser produces valid output.""" + import os + + import hydra + + from bioemu.sample import generate_batch + from bioemu.shortcuts import CosineVPSDE, DiGSO3SDE + from tests.test_embeds import TEST_SEQ + + sdes = {"node_orientations": DiGSO3SDE(), "pos": CosineVPSDE()} + + # Build FKC denoiser config with physical potentials + config_path = os.path.join( + os.path.dirname(__file__), + "../../src/bioemu/config/steering/physical_steering.yaml", + ) + with open(config_path) as f: + import yaml + + denoiser_config = yaml.safe_load(f) + + denoiser = hydra.utils.instantiate(denoiser_config) + + with _mock_colabfold_embeds(TEST_SEQ): + batch = generate_batch( + score_model=self._mock_score_model, + sequence=TEST_SEQ, + sdes=sdes, + batch_size=2, + seed=42, + denoiser=denoiser, + cache_embeds_dir=None, + ) + + assert "pos" in batch + assert "node_orientations" in batch + assert batch["pos"].shape == (2, len(TEST_SEQ), 3) + assert batch["node_orientations"].shape == (2, len(TEST_SEQ), 3, 3) + + def test_generate_batch_with_smc_denoiser(self): + """generate_batch with dpm_solver_smc denoiser produces valid output.""" + import hydra + + from bioemu.sample import generate_batch + from bioemu.shortcuts import CosineVPSDE, DiGSO3SDE + from bioemu.steering.collective_variables import CaCaDistance + from bioemu.steering.potentials import UmbrellaPotential + from tests.test_embeds import TEST_SEQ + + sdes = {"node_orientations": DiGSO3SDE(), "pos": CosineVPSDE()} + + # Build SMC denoiser config as a dict (instead of YAML) + pot = UmbrellaPotential(cv=CaCaDistance(), target=0.38, slope=10.0, weight=1.0) + smc_config = { + "_target_": 
"bioemu.steering.dpm_smc.dpm_solver_smc", + "_partial_": True, + "eps_t": 0.001, + "max_t": 0.99, + "N": 5, + "noise": 0.5, + "fk_potentials": [pot], + "steering_config": {"num_particles": 2, "ess_threshold": 0.5, "start": 1.0, "end": 0.0}, + } + denoiser = hydra.utils.instantiate(smc_config) + + with _mock_colabfold_embeds(TEST_SEQ): + batch = generate_batch( + score_model=self._mock_score_model, + sequence=TEST_SEQ, + sdes=sdes, + batch_size=2, + seed=42, + denoiser=denoiser, + cache_embeds_dir=None, + ) + + assert "pos" in batch + assert "node_orientations" in batch + assert batch["pos"].shape == (2, len(TEST_SEQ), 3) + + def test_generate_batch_unsteered_dpm(self): + """generate_batch with standard dpm denoiser (no steering) still works.""" + import os + + import hydra + + from bioemu.sample import generate_batch + from bioemu.shortcuts import CosineVPSDE, DiGSO3SDE + from tests.test_embeds import TEST_SEQ + + sdes = {"node_orientations": DiGSO3SDE(), "pos": CosineVPSDE()} + + config_path = os.path.join( + os.path.dirname(__file__), "../../src/bioemu/config/denoiser/dpm.yaml" + ) + with open(config_path) as f: + import yaml + + denoiser_config = yaml.safe_load(f) + denoiser = hydra.utils.instantiate(denoiser_config) + + with _mock_colabfold_embeds(TEST_SEQ): + batch = generate_batch( + score_model=self._mock_score_model, + sequence=TEST_SEQ, + sdes=sdes, + batch_size=2, + seed=42, + denoiser=denoiser, + cache_embeds_dir=None, + ) + + assert "pos" in batch + assert batch["pos"].shape == (2, len(TEST_SEQ), 3) diff --git a/tests/steering/test_potentials.py b/tests/steering/test_potentials.py new file mode 100644 index 0000000..16546d4 --- /dev/null +++ b/tests/steering/test_potentials.py @@ -0,0 +1,218 @@ +"""Tests for bioemu.steering.potentials — potential energy functions.""" + +import torch + +from bioemu.steering.collective_variables import CaCaDistance, PairwiseClash +from bioemu.steering.potentials import LinearPotential, UmbrellaPotential + + +class 
TestUmbrellaPotentialLossFn: + """Tests for UmbrellaPotential.loss_fn (instance method).""" + + def test_at_target_zero(self): + x = torch.tensor([5.0]) + pot = UmbrellaPotential(target=5.0, flatbottom=0.5, slope=2.0, order=2, linear_from=1.0) + loss = pot.loss_fn(x) + torch.testing.assert_close(loss, torch.tensor([0.0])) + + def test_power_law(self): + x = torch.tensor([7.0]) + pot = UmbrellaPotential(target=5.0, flatbottom=0.0, slope=2.0, order=2, linear_from=10.0) + loss = pot.loss_fn(x) + torch.testing.assert_close(loss, torch.tensor([16.0])) + + def test_flatbottom_zero_inside(self): + x = torch.tensor([5.3]) + pot = UmbrellaPotential(target=5.0, flatbottom=0.5, slope=2.0, order=2, linear_from=1.0) + loss = pot.loss_fn(x) + torch.testing.assert_close(loss, torch.tensor([0.0])) + + def test_flatbottom_nonzero_outside(self): + x = torch.tensor([6.0]) + pot = UmbrellaPotential(target=5.0, flatbottom=0.2, slope=1.0, order=2, linear_from=10.0) + loss = pot.loss_fn(x) + torch.testing.assert_close(loss, torch.tensor([0.64])) + + def test_linear_from_transition(self): + pot = UmbrellaPotential(target=5.0, flatbottom=0.0, slope=1.0, order=2, linear_from=2.0) + loss_at = pot.loss_fn(torch.tensor([7.0])) + torch.testing.assert_close(loss_at, torch.tensor([4.0])) + + loss_beyond = pot.loss_fn(torch.tensor([8.0])) + torch.testing.assert_close(loss_beyond, torch.tensor([5.0])) + + def test_symmetric(self): + pot = UmbrellaPotential(target=5.0, flatbottom=0.0, slope=1.0, order=2, linear_from=10.0) + loss_above = pot.loss_fn(torch.tensor([6.0])) + loss_below = pot.loss_fn(torch.tensor([4.0])) + torch.testing.assert_close(loss_above, loss_below) + + +class TestLinearPotentialLossFn: + """Tests for LinearPotential.loss_fn (instance method).""" + + def test_basic(self): + x = torch.tensor([3.0]) + pot = LinearPotential(target=1.0, slope=2.0) + result = pot.loss_fn(x) + torch.testing.assert_close(result, torch.tensor([4.0])) + + def test_no_clip(self): + x = torch.tensor([10.0]) 
+ pot = LinearPotential(target=0.0, slope=1.0) + result = pot.loss_fn(x) + torch.testing.assert_close(result, torch.tensor([10.0])) + + def test_clip_max(self): + x = torch.tensor([10.0]) + pot = LinearPotential(target=0.0, slope=1.0, clip_max=5.0) + result = pot.loss_fn(x) + torch.testing.assert_close(result, torch.tensor([5.0])) + + def test_clip_min(self): + x = torch.tensor([-10.0]) + pot = LinearPotential(target=0.0, slope=1.0, clip_min=-3.0) + result = pot.loss_fn(x) + torch.testing.assert_close(result, torch.tensor([-3.0])) + + +class TestChainBreakAsUmbrella: + """ChainBreakPotential replaced by UmbrellaPotential + CaCaDistance.""" + + @staticmethod + def _make_pot(**kwargs): + defaults = dict( + target=0.380209737096, flatbottom=0.0, slope=10.0, order=2, linear_from=0.1, weight=1.0 + ) + defaults.update(kwargs) + return UmbrellaPotential(cv=CaCaDistance(), **defaults) + + def test_ideal_spacing_low_energy(self): + pot = self._make_pot() + ca_ca_dist = 0.380209737096 + n = 10 + Ca_pos = torch.zeros(2, n, 3) + for i in range(n): + Ca_pos[:, i, 0] = i * ca_ca_dist + energy = pot(Ca_pos) + assert energy.shape == (2,) + torch.testing.assert_close(energy, torch.zeros(2), atol=1e-4, rtol=1e-4) + + def test_large_spacing_high_energy(self): + pot = self._make_pot() + Ca_pos = torch.zeros(1, 5, 3) + for i in range(5): + Ca_pos[0, i, 0] = i * 2.0 + energy = pot(Ca_pos) + assert energy.item() > 0 + + def test_flatbottom_zero_in_range(self): + pot = self._make_pot(flatbottom=0.05) + ca_ca_dist = 0.380209737096 + n = 10 + Ca_pos = torch.zeros(1, n, 3) + for i in range(n): + Ca_pos[0, i, 0] = i * (ca_ca_dist + 0.03) + energy = pot(Ca_pos) + torch.testing.assert_close(energy, torch.zeros(1), atol=1e-4, rtol=1e-4) + + +class TestChainClashAsUmbrella: + """ChainClashPotential replaced by UmbrellaPotential + PairwiseClash.""" + + @staticmethod + def _make_pot(min_dist=0.42, offset=3, slope=10.0, weight=1.0): + return UmbrellaPotential( + cv=PairwiseClash(min_dist=min_dist, 
offset=offset), + target=0.0, + flatbottom=0.0, + slope=slope, + order=1, + linear_from=1e6, + weight=weight, + ) + + def test_well_separated_zero_energy(self): + pot = self._make_pot() + n = 10 + Ca_pos = torch.zeros(2, n, 3) + for i in range(n): + Ca_pos[:, i, 0] = float(i) + energy = pot(Ca_pos) + assert energy.shape == (2,) + torch.testing.assert_close(energy, torch.zeros(2), atol=1e-6, rtol=1e-6) + + def test_overlapping_positive_energy(self): + pot = self._make_pot() + Ca_pos = torch.zeros(1, 10, 3) + energy = pot(Ca_pos) + assert energy.item() > 0 + + def test_offset_excludes_neighbors(self): + pot = self._make_pot() + Ca_pos = torch.zeros(1, 3, 3) + energy = pot(Ca_pos) + torch.testing.assert_close(energy, torch.zeros(1), atol=1e-6, rtol=1e-6) + + +class TestLinearPotentialWithCV: + """Tests for LinearPotential with a CV.""" + + def test_energy_with_mock_cv(self): + class MockCV: + def compute_batch(self, ca_pos, sequence=None): + return ca_pos.pow(2).sum(-1).mean(-1).sqrt() + + cv = MockCV() + pot = LinearPotential(target=1.0, slope=2.0, weight=1.0, cv=cv) + + Ca_pos = torch.zeros(1, 10, 3) + Ca_pos[0, :, 0] = torch.arange(10, dtype=torch.float32) + energy = pot(Ca_pos, t=0.5, sequence="A" * 10) + cv_val = cv.compute_batch(Ca_pos) + expected = 2.0 * (cv_val - 1.0) + torch.testing.assert_close(energy, expected, atol=1e-5, rtol=1e-5) + + def test_clipping(self): + class MockCV: + def compute_batch(self, ca_pos, sequence=None): + return torch.tensor([100.0]) + + cv = MockCV() + pot = LinearPotential(target=0.0, slope=1.0, weight=1.0, clip_max=0.5, cv=cv) + Ca_pos = torch.randn(1, 10, 3) + energy = pot(Ca_pos, t=0.5, sequence="A" * 10) + torch.testing.assert_close(energy, torch.tensor([0.5]), atol=1e-5, rtol=1e-5) + + def test_differentiable(self): + class DiffCV: + def compute_batch(self, ca_pos, sequence=None): + return ca_pos.pow(2).sum(dim=(1, 2)) + + cv = DiffCV() + pot = LinearPotential(target=0.0, slope=1.0, weight=1.0, cv=cv) + Ca_pos = torch.randn(2, 
5, 3, requires_grad=True) + energy = pot(Ca_pos, t=0.5, sequence="A" * 5) + loss = energy.sum() + loss.backward() + assert Ca_pos.grad is not None + assert Ca_pos.grad.shape == Ca_pos.shape + + +class TestUmbrellaPotentialWithCV: + """Tests for UmbrellaPotential with a CV.""" + + def test_energy_follows_loss_fn(self): + class MockCV: + def compute_batch(self, ca_pos, sequence=None): + return torch.tensor([3.0]) + + cv = MockCV() + pot = UmbrellaPotential( + target=1.0, flatbottom=0.0, slope=2.0, order=2, linear_from=10.0, weight=3.0, cv=cv + ) + Ca_pos = torch.randn(1, 10, 3) + energy = pot(Ca_pos, t=0.5, sequence="A" * 10) + cv_val = torch.tensor([3.0]) + expected = 3.0 * pot.loss_fn(cv_val) + torch.testing.assert_close(energy, expected, atol=1e-5, rtol=1e-5) diff --git a/tests/steering/test_steering_numerically.py b/tests/steering/test_steering_numerically.py new file mode 100644 index 0000000..5d9123f --- /dev/null +++ b/tests/steering/test_steering_numerically.py @@ -0,0 +1,231 @@ +"""Numerical tests verifying that dpm_fkc and dpm_smc steer toward the correct distribution. + +Both samplers should produce samples from the biased distribution +biased(x) ∝ GMM(x) × p(x), where p(x) = 1/Z exp(-U(x)) is the Boltzmann +distribution induced by a quadratic potential U(x) = k/2 (x - center)². 
+""" + +import logging + +import numpy as np +import torch +import torch.nn as nn +from torch_geometric.data import Batch + +from bioemu.chemgraph import ChemGraph +from bioemu.sde_lib import CosineVPSDE +from bioemu.so3_sde import DiGSO3SDE +from bioemu.toy_gmm import TimeDependentGMM1D + +# Suppress noisy warnings from 1D toy setup +logging.disable(logging.WARNING) + +# ============================================================ +# Shared constants +# ============================================================ + +POTENTIAL_K = 1.0 +POTENTIAL_CENTER = 2.0 +N_SAMPLES = 4096 +N_STEPS = 100 +NOISE_SCALE = 1.0 +MAX_T = 0.99 +EPS_T = 0.01 +MAE_THRESHOLD = 0.05 + + +# ============================================================ +# Shared helpers +# ============================================================ + + +def make_gmm_and_sde(): + """Create the GMM target distribution and CosineVPSDE scheduler.""" + sde = CosineVPSDE() + gmm = TimeDependentGMM1D( + mu1=torch.tensor([-2.0]), + mu2=torch.tensor([3.0]), + sigma1=0.5, + sigma2=0.5, + weight1=0.7, + scheduler=sde, + ) + return gmm, sde + + +def make_sdes(): + """Create the SDEs dict required by the samplers.""" + return { + "pos": CosineVPSDE(), + "node_orientations": DiGSO3SDE(num_sigma=10, num_omega=10, l_max=10), + } + + +class GMMScoreWrapper(nn.Module): + """Wraps TimeDependentGMM1D to conform to the bioemu score model interface. + + ``get_score()`` divides the model output by ``pos_std``, so this wrapper + returns ``analytical_score × std`` for pos and zeros for SO3. 
+ """ + + def __init__(self, gmm: TimeDependentGMM1D, pos_sde: CosineVPSDE): + super().__init__() + self.gmm = gmm + self.pos_sde = pos_sde + self.dummy = nn.Parameter(torch.zeros(1)) + + def forward(self, batch, t): + x = batch.pos[:, 0:1] + t_per_node = t[batch.batch] + score_1d = self.gmm.score(x, t_per_node) + + _, pos_std = self.pos_sde.marginal_prob( + x=torch.ones_like(batch.pos), t=t, batch_idx=batch.batch + ) + + pos_output = torch.zeros_like(batch.pos) + self.dummy * 0 + pos_output[:, 0:1] = score_1d * pos_std[:, 0:1] + + return { + "pos": pos_output, + "node_orientations": torch.zeros(batch.pos.shape[0], 3, device=batch.pos.device) + + self.dummy * 0, + } + + +class QuadraticPotential: + """Quadratic potential U(x) = k/2 * (x - center)².""" + + def __init__(self, k: float = POTENTIAL_K, center: float = POTENTIAL_CENTER): + self.k = k + self.center = center + + def __call__(self, ca_pos_nm: torch.Tensor, *, t=None, sequence=None) -> torch.Tensor: + x = ca_pos_nm.reshape(ca_pos_nm.shape[0], -1)[:, 0] + return 0.5 * self.k * (x - self.center) ** 2 + + +def make_toy_batch(n_samples: int) -> Batch: + """Create a Batch of 1-residue ChemGraph objects for the 1D toy problem.""" + data_list = [ + ChemGraph( + pos=torch.zeros(1, 3), + node_orientations=torch.eye(3).unsqueeze(0), + edge_index=torch.zeros(2, 0, dtype=torch.long), + single_embeds=torch.zeros(1, 1), + pair_embeds=torch.zeros(1, 1), + sequence="A", + system_id=f"toy_{i}", + node_labels=torch.zeros(1, dtype=torch.long), + ) + for i in range(n_samples) + ] + return Batch.from_data_list(data_list) + + +def compute_ground_truth_pdf( + gmm: TimeDependentGMM1D, + k: float = POTENTIAL_K, + center: float = POTENTIAL_CENTER, + x_min: float = -6.0, + x_max: float = 8.0, + n_points: int = 1000, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute the normalized biased PDF: biased(x) ∝ GMM(x) × p(x). + + Returns (x_grid, biased_pdf). 
+ """ + x_grid = torch.linspace(x_min, x_max, n_points) + t_zero = torch.zeros(n_points) + dx = x_grid[1] - x_grid[0] + + with torch.no_grad(): + log_gmm = gmm.log_prob(x_grid.unsqueeze(-1), t_zero) + + # Normalized Boltzmann p(x) = 1/Z exp(-U(x)) + boltzmann_unnorm = torch.exp(-0.5 * k * (x_grid - center) ** 2) + boltzmann_pdf = boltzmann_unnorm / (boltzmann_unnorm.sum() * dx) + + # Normalized GMM + gmm_unnorm = torch.exp(log_gmm - log_gmm.max()) + gmm_pdf = gmm_unnorm / (gmm_unnorm.sum() * dx) + + # Biased = GMM × p(x), renormalized + biased_unnorm = gmm_pdf * boltzmann_pdf + biased_pdf = biased_unnorm / (biased_unnorm.sum() * dx) + + return x_grid, biased_pdf + + +def compute_mae(samples: torch.Tensor, x_grid: torch.Tensor, target_pdf: torch.Tensor) -> float: + """Compute MAE between a histogram of samples and a target PDF on the same grid.""" + bin_edges = np.linspace(x_grid[0].item(), x_grid[-1].item(), len(x_grid) + 1) + hist_counts, _ = np.histogram(samples.numpy(), bins=bin_edges, density=True) + sample_density = torch.tensor(hist_counts, dtype=torch.float32) + return (sample_density - target_pdf).abs().mean().item() + + +# ============================================================ +# Tests +# ============================================================ + + +def test_dpm_fkc(): + """dpm_solver_fkc steers samples toward the biased distribution.""" + from bioemu.steering.dpm_fkc import dpm_solver_fkc + + gmm, sde = make_gmm_and_sde() + sdes = make_sdes() + score_model = GMMScoreWrapper(gmm, sde) + potential = QuadraticPotential() + batch = make_toy_batch(N_SAMPLES) + x_grid, biased_pdf = compute_ground_truth_pdf(gmm) + + torch.manual_seed(42) + result_batch, _ = dpm_solver_fkc( + sdes=sdes, + batch=batch, + N=N_STEPS, + score_model=score_model, + max_t=MAX_T, + eps_t=EPS_T, + device=torch.device("cpu"), + fk_potentials=[potential], + steering_config={"num_particles": 128, "ess_threshold": 0.5, "start": 1.0, "end": 0.0}, + noise=NOISE_SCALE, + 
use_x0_for_reward=False, + ) + + samples = result_batch.pos[:, 0].detach().cpu() + mae = compute_mae(samples, x_grid, biased_pdf) + assert mae < MAE_THRESHOLD, f"FKC MAE {mae:.4f} exceeds threshold {MAE_THRESHOLD}" + + +def test_dpm_smc(): + """dpm_solver_smc steers samples toward the biased distribution.""" + from bioemu.steering.dpm_smc import dpm_solver_smc + + gmm, sde = make_gmm_and_sde() + sdes = make_sdes() + score_model = GMMScoreWrapper(gmm, sde) + potential = QuadraticPotential() + batch = make_toy_batch(N_SAMPLES) + x_grid, biased_pdf = compute_ground_truth_pdf(gmm) + + torch.manual_seed(42) + result_batch, _ = dpm_solver_smc( + sdes=sdes, + batch=batch, + N=N_STEPS, + score_model=score_model, + max_t=MAX_T, + eps_t=EPS_T, + device=torch.device("cpu"), + fk_potentials=[potential], + steering_config={"num_particles": 128, "ess_threshold": 0.5, "start": 1.0, "end": 0.0}, + noise=NOISE_SCALE, + ) + + samples = result_batch.pos[:, 0].detach().cpu() + mae = compute_mae(samples, x_grid, biased_pdf) + assert mae < MAE_THRESHOLD, f"SMC MAE {mae:.4f} exceeds threshold {MAE_THRESHOLD}" diff --git a/tests/steering/test_utils.py b/tests/steering/test_utils.py new file mode 100644 index 0000000..8548b71 --- /dev/null +++ b/tests/steering/test_utils.py @@ -0,0 +1,168 @@ +"""Tests for bioemu.steering.utils — resampling, loss functions, and helpers.""" + +import torch +from torch_geometric.data import Batch, Data + + +class TestStratifiedResample: + """Tests for stratified_resample.""" + + def test_uniform_weights_spread(self): + """Uniform weights should produce indices that cover most of the range.""" + from bioemu.steering.utils import stratified_resample + + torch.manual_seed(0) + B, N = 1, 100 + weights = torch.ones(B, N) / N + idx = stratified_resample(weights) + assert idx.shape == (B, N) + unique = idx[0].unique() + assert unique.numel() >= N - 5 + + def test_delta_distribution(self): + """All weight on particle k → all indices should be k.""" + from 
bioemu.steering.utils import stratified_resample + + B, N, k = 2, 20, 7 + weights = torch.zeros(B, N) + weights[:, k] = 1.0 + idx = stratified_resample(weights) + assert (idx == k).all() + + +class TestResampleBasedOnLogWeights: + """Tests for resample_based_on_log_weights.""" + + @staticmethod + def _make_batch(n_samples: int, n_residues: int = 5) -> Batch: + data_list = [] + for i in range(n_samples): + d = Data( + pos=torch.randn(n_residues, 3), + node_orientations=torch.eye(3).unsqueeze(0).expand(n_residues, -1, -1).clone(), + batch=torch.zeros(n_residues, dtype=torch.long), + sequence=["ACDEF"], + ) + data_list.append(d) + return Batch.from_data_list(data_list) + + def test_equal_weights_no_resample(self): + """When all weights are equal, ESS ≈ 1 → no resampling if threshold < 1.""" + from bioemu.steering.utils import resample_based_on_log_weights + + n = 16 + batch = self._make_batch(n) + log_w = torch.zeros(n) + new_batch, new_lw, indices, ess = resample_based_on_log_weights( + batch=batch, + log_weight=log_w, + n_particles=n, + is_last_step=False, + ess_threshold=0.5, + step=0, + t=0.5, + ) + assert ess > 0.9 + torch.testing.assert_close(indices, torch.arange(n)) + + def test_dominant_weight_resamples(self): + """When one weight dominates and ESS is low, resampling should happen.""" + from bioemu.steering.utils import resample_based_on_log_weights + + n = 16 + batch = self._make_batch(n) + log_w = torch.full((n,), -100.0) + log_w[3] = 0.0 + new_batch, new_lw, indices, ess = resample_based_on_log_weights( + batch=batch, + log_weight=log_w, + n_particles=n, + is_last_step=False, + ess_threshold=0.5, + step=0, + t=0.5, + ) + assert ess < 0.2 + torch.testing.assert_close(new_lw, torch.zeros(n)) + + def test_last_step_always_resamples(self): + """On the last step, resampling should always happen regardless of ESS.""" + from bioemu.steering.utils import resample_based_on_log_weights + + n = 8 + batch = self._make_batch(n) + log_w = torch.zeros(n) + new_batch, 
new_lw, indices, ess = resample_based_on_log_weights( + batch=batch, + log_weight=log_w, + n_particles=n, + is_last_step=True, + ess_threshold=0.5, + step=99, + t=0.001, + ) + torch.testing.assert_close(new_lw, torch.zeros(n)) + + def test_small_batch_treated_as_single_group(self): + """When batch_size < n_particles, entire batch is treated as one group.""" + from bioemu.steering.utils import resample_based_on_log_weights + + n = 4 + batch = self._make_batch(n) + log_w = torch.zeros(n) + new_batch, new_lw, indices, ess = resample_based_on_log_weights( + batch=batch, + log_weight=log_w, + n_particles=100, + is_last_step=False, + ess_threshold=0.5, + step=0, + t=0.5, + ) + assert ess > 0.9 + + +class TestComputeSequenceAlignment: + """Tests for compute_sequence_alignment.""" + + def test_identical_sequences(self): + from bioemu.steering.utils import compute_sequence_alignment + + mapping = compute_sequence_alignment("ACDEF", "ACDEF") + assert mapping == {0: 0, 1: 1, 2: 2, 3: 3, 4: 4} + + def test_subsequence(self): + from bioemu.steering.utils import compute_sequence_alignment + + mapping = compute_sequence_alignment("ACE", "ACDEF") + assert 0 in mapping # A + assert 1 in mapping # C + assert 2 in mapping # E + + def test_different_sequences(self): + """Completely different sequences should still return a dict.""" + from bioemu.steering.utils import compute_sequence_alignment + + mapping = compute_sequence_alignment("AAAA", "CCCC") + assert isinstance(mapping, dict) + + +class TestGetPos0Rot0: + """Tests for get_pos0_rot0 helper.""" + + def test_identity_score_returns_input(self): + """With zero score, x0 prediction should roughly match x_t / alpha_t.""" + from bioemu.steering.utils import _get_x0_given_xt_and_score + + # Simple test: x0 = (x + sigma^2 * score) / alpha + x = torch.randn(10, 3) + score = torch.zeros_like(x) + batch_idx = torch.zeros(10, dtype=torch.long) + + # Mock SDE that returns alpha=1, sigma=0 + class MockSDE: + def mean_coeff_and_std(self, x, t, 
batch_idx): + return torch.ones_like(x), torch.zeros_like(x) + + x0 = _get_x0_given_xt_and_score(MockSDE(), x, torch.tensor([0.5]), batch_idx, score) + torch.testing.assert_close(x0, x) diff --git a/tests/test_steering.py b/tests/test_steering.py deleted file mode 100644 index 150ed68..0000000 --- a/tests/test_steering.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -Tests for steering features in BioEMU. - -Tests the steering capabilities including: -- ChainBreakPotential and ChainClashPotential - -All tests use the chignolin sequence (GYDPETGTWG) for consistency. -""" - -import os -import random -import shutil -from pathlib import Path - -import numpy as np -import pytest -import torch -import yaml - -from bioemu.sample import main as sample - -# Path to the physical steering config file (ground truth) -PHYSICAL_STEERING_CONFIG_PATH = ( - Path(__file__).parent.parent - / "src" - / "bioemu" - / "config" - / "steering" - / "physical_steering.yaml" -) - -# Set fixed seeds for reproducibility -SEED = 42 -random.seed(SEED) -np.random.seed(SEED) -torch.manual_seed(SEED) -if torch.cuda.is_available(): - torch.cuda.manual_seed_all(SEED) - - -@pytest.fixture -def chignolin_sequence(): - """Chignolin sequence for consistent testing across all steering tests.""" - return "GYDPETGTWG" - - -@pytest.fixture -def base_test_config(): - """Base configuration for steering tests.""" - return { - "batch_size_100": 100, # Small for fast testing - "num_samples": 10, # Small for fast testing - } - - -def load_steering_config(): - """Load the physical steering config from YAML file.""" - with open(PHYSICAL_STEERING_CONFIG_PATH) as f: - return yaml.safe_load(f) - - -def test_steering_with_config_path(chignolin_sequence, base_test_config): - """Test steering by passing the config file path directly.""" - output_dir = "./test_outputs/steering_config_path" - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - - sample( - sequence=chignolin_sequence, - 
num_samples=base_test_config["num_samples"], - batch_size_100=base_test_config["batch_size_100"], - output_dir=output_dir, - denoiser_type="dpm", - denoiser_config="src/bioemu/config/denoiser/stochastic_dpm.yaml", - steering_config=PHYSICAL_STEERING_CONFIG_PATH, - ) - - -def test_steering_with_config_dict(chignolin_sequence, base_test_config): - """Test steering by passing the config as a dict.""" - output_dir = "./test_outputs/steering_config_dict" - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - - steering_config = load_steering_config() - - sample( - sequence=chignolin_sequence, - num_samples=base_test_config["num_samples"], - batch_size_100=base_test_config["batch_size_100"], - output_dir=output_dir, - denoiser_type="dpm", - denoiser_config="src/bioemu/config/denoiser/stochastic_dpm.yaml", - steering_config=steering_config, - ) - - -def test_steering_modified_num_particles(chignolin_sequence, base_test_config): - """Test steering with modified number of particles.""" - output_dir = "./test_outputs/steering_modified_particles" - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - - steering_config = load_steering_config() - steering_config["num_particles"] = 5 # Modify from default - - sample( - sequence=chignolin_sequence, - num_samples=base_test_config["num_samples"], - batch_size_100=base_test_config["batch_size_100"], - output_dir=output_dir, - denoiser_type="dpm", - denoiser_config="src/bioemu/config/denoiser/stochastic_dpm.yaml", - steering_config=steering_config, - ) - - -def test_steering_modified_time_window(chignolin_sequence, base_test_config): - """Test steering with modified start/end time window.""" - output_dir = "./test_outputs/steering_modified_time" - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - - steering_config = load_steering_config() - steering_config["start"] = 0.7 # Modify time window - steering_config["end"] = 0.3 - - sample( - sequence=chignolin_sequence, - 
num_samples=base_test_config["num_samples"], - batch_size_100=base_test_config["batch_size_100"], - output_dir=output_dir, - denoiser_type="dpm", - denoiser_config="src/bioemu/config/denoiser/stochastic_dpm.yaml", - steering_config=steering_config, - ) - - -def test_no_steering(chignolin_sequence, base_test_config): - """Test sampling without steering (steering_config=None).""" - output_dir = "./test_outputs/no_steering" - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - - sample( - sequence=chignolin_sequence, - num_samples=base_test_config["num_samples"], - batch_size_100=base_test_config["batch_size_100"], - output_dir=output_dir, - denoiser_type="dpm", - denoiser_config="src/bioemu/config/denoiser/stochastic_dpm.yaml", - steering_config=None, - )