diff --git a/environment.yml b/environment.yml index 9a603c1..03ae45d 100644 --- a/environment.yml +++ b/environment.yml @@ -14,7 +14,7 @@ dependencies: - git-lfs - scikit-image - # - optuna + - optuna # - proxsuite - cvxpy - transforms3d diff --git a/examples/cpd_optimization_config.yaml b/examples/cpd_optimization_config.yaml new file mode 100644 index 0000000..95f4fab --- /dev/null +++ b/examples/cpd_optimization_config.yaml @@ -0,0 +1,41 @@ +fixed_image: + path: "data/test_elastic_registration/fixed_image.n5" + input_key: "input" + aligned_key: "svd_prealignment_with_lm" + x_res: 1.0 + y_res: 1.0 + z_res: 1.0 + +moving_image: + path: "data/test_elastic_registration/moving_image.n5" + input_key: "input" + aligned_key: "rigid_alignment_with_lm" + x_res: 1.0 + y_res: 1.0 + z_res: 1.0 + +landmarks: + fixed: "examples/data/fixed_landmarks.csv" + moving: "examples/data/moving_landmarks.csv" + +prealignment: + axis_orientation: "auto" + +log_dir: "data/test_elastic_registration/cpd_optimization" + +optuna: + study_name: "cpd_optimization" + # search_space controls which parameter ranges are used. Three options: + # + # "default" — load ranges from default_cpd_ranges.yaml + # "dataset-specific" — compute beta from extent of aligned fixed image via + # suggest_cpd_ranges.py (run automatically by Snakemake + # after pre- and rigid alignment); other params from defaults + # explicit dict — specify ranges directly, e.g.: + # + # search_space: + # w: [1.0e-5, 1.0e-4, 1.0e-3, 1.0e-2] + # beta: [50.0, 100.0, 200.0] + # lmd: [0.01, 0.1, 1.0] + # maxiter: [150] + search_space: "default" diff --git a/matchmaker/cpd_parameter_tuning/README.md b/matchmaker/cpd_parameter_tuning/README.md new file mode 100644 index 0000000..708e0c5 --- /dev/null +++ b/matchmaker/cpd_parameter_tuning/README.md @@ -0,0 +1,94 @@ +# CPD Hyperparameter Tuning + +Optuna-based grid search over CPD (Coherent Point Drift) nonrigid registration parameters, +evaluated on user-provided corresponding landmarks. + +## Input landmarks + +Pairs of corresponding landmarks needs to be provided, one CSV for the fixed image and +one for the moving image. Point to them in the `landmarks` section of +`examples/cpd_optimization_config.yaml`. Each CSV has the columns `name, x, y, z`, where +coordinates are in physical µm and `name` matches between the two files so corresponding +landmarks can be paired: + +``` +name,x,y,z +landmark_1,12.4,8.1,30.0 +landmark_2,40.2,15.7,28.5 +``` + +The repository does not ship example landmark CSVs; the example config paths are placeholders. + +## Parameters + +| Parameter | Role | +|-----------|------| +| `beta` | Gaussian kernel width: controls smoothness/locality of the displacement field. Larger = smoother, more global deformation. | +| `w` | Outlier weight: fraction of points treated as noise. Higher = more robust to outliers but less accurate. | +| `lmd` | Regularization strength: Larger = stiffer, penalizes deformation more strongly. | +| `maxiter` | Maximum EM iterations. | + +### Default CPD parameters + +saved in `default_cpd_params.yaml` + +``` +beta = 100 +w = 1e-5 +lmd = 0.1 +maxiter = 100 +``` + +### Default CPD grid search ranges + +saved in `default_cpd_ranges.yaml` + +``` +DEFAULT_SEARCH_SPACE = { + "w": [1e-5, 1e-4, 1e-3, 1e-2, 1e-1], + "beta": [10.0, 50.0, 100.0, 200.0], + "lmd": [0.01, 0.1, 1.0, 10.0], + "maxiter": [100], +} +``` + +### Dataset-specific CPD grid search ranges + +`suggest_cpd_ranges.py` derives a `beta` search range proportional to the data scale. +It reads a segmentation, computes basic point-cloud statistics (spatial extent, point +density), and prints a YAML block that can be pasted into the `optuna.search_space` +section of the optimization config. + +When the config sets `search_space: "dataset-specific"`, the Snakemake workflow runs this +step automatically on the aligned fixed segmentation. It can also run manually: + +``` +python matchmaker/cpd_parameter_tuning/suggest_cpd_ranges.py \ + --path .n5 \ + --key svd_prealignment \ + --x_res 0.4 --y_res 0.4 --z_res 0.4 +``` + + +## Snakemake workflow + +Snakemake workflow for CPD parameter optimization via Optuna. + +Separate from the main registration pipeline. Performs: + 1. Embed the corresponding landmarks into both input segmentations as labeled spheres. + 2. Apply SVD prealignment to the fixed image (carrying its landmarks). + 3. Apply SVD + rigid alignment to the moving image (carrying its landmarks). + 4. Optionally compute dataset-specific beta ranges from the aligned fixed segmentation + (when `search_space: "dataset-specific"`). + 5. Run an Optuna grid search over CPD parameters, evaluating each combination by the + mean Landmark Registration Error (LRE) between corresponding landmarks after CPD. + +Output: `best_cpd_params.yaml` (drop-in replacement for the `coherent_point_drift` section +of the main registration config). + +Usage: +``` +snakemake -s workflows/cpd_optimization.smk \ + --configfile examples/cpd_optimization_config.yaml \ + --cores 1 +``` \ No newline at end of file diff --git a/matchmaker/cpd_parameter_tuning/add_landmarks.py b/matchmaker/cpd_parameter_tuning/add_landmarks.py new file mode 100644 index 0000000..e39959d --- /dev/null +++ b/matchmaker/cpd_parameter_tuning/add_landmarks.py @@ -0,0 +1,148 @@ +import json +import click +import logging +import numpy as np +import pandas as pd +from pathlib import Path + +from matchmaker.utils import read_volume, write_volume, get_attrs, setup_logging, plot_landmark_qc, PINK, CYAN + + +def landmark_label_ids(sorted_names, dtype): + """Return dict {name: label_id} placing IDs at the top of the dtype range. + + IDs are assigned from (dtype_max - n_landmarks + 1) to dtype_max in sorted-name order, + so the mapping is fully determined by the landmark names and dtype — independent of + which segmentation (fixed or moving) is being processed, ensuring consistency. + """ + max_val = int(np.iinfo(dtype).max) + offset = max_val - len(sorted_names) + 1 + return {name: offset + i for i, name in enumerate(sorted_names)} + + +def add_landmarks_to_seg(seg, landmarks_df, resolution, id_map, radius=3): + """Embed landmark spheres into a segmentation at positions given by landmarks_df. + + Each landmark is written as a filled sphere of the given radius (in voxels). + This ensures the landmark label survives affine transforms and nearest-neighbor + interpolation, where a single voxel would often be missed. + + Args: + seg: integer ZYX numpy array + landmarks_df: DataFrame with columns [name, x, y, z] in physical µm + resolution: [z_res, y_res, x_res] in µm + id_map: {landmark_name: label_id} + radius: sphere radius in voxels + + Returns modified segmentation (copy). + """ + min_lm_id = min(id_map.values()) + if int(seg.max()) >= min_lm_id: + logging.warning( + f"Existing label max ({int(seg.max())}) >= landmark ID offset ({min_lm_id}). " + "Landmark labels may overwrite nucleus labels." + ) + seg_out = seg.copy() + shape = np.array(seg.shape) + offsets = [ + (dz, dy, dx) + for dz in range(-radius, radius + 1) + for dy in range(-radius, radius + 1) + for dx in range(-radius, radius + 1) + if dz**2 + dy**2 + dx**2 <= radius**2 + ] + for row in landmarks_df.itertuples(index=False): + if row.name not in id_map: + logging.warning(f"Landmark {row.name!r} not in id_map, skipping") + continue + lbl = id_map[row.name] + z_c = row.z / resolution[0] + y_c = row.y / resolution[1] + x_c = row.x / resolution[2] + for dz, dy, dx in offsets: + z_v = int(np.clip(round(z_c + dz), 0, shape[0] - 1)) + y_v = int(np.clip(round(y_c + dy), 0, shape[1] - 1)) + x_v = int(np.clip(round(x_c + dx), 0, shape[2] - 1)) + seg_out[z_v, y_v, x_v] = lbl + logging.info(f" {row.name}: label={lbl}, center=({z_c:.1f},{y_c:.1f},{x_c:.1f}), r={radius}") + return seg_out + + +@click.command() +@click.option("--fixed_path", required=True, help="Path to the fixed raw .n5 segmentation") +@click.option("--fixed_key", required=True, help="Dataset key of the raw fixed segmentation") +@click.option("--fixed_output_key", required=True, help="Dataset key to write the fixed segmentation with landmarks") +@click.option("--fixed_landmarks_csv", required=True, help="CSV with columns [name, x, y, z] in µm for the fixed image") +@click.option("--moving_path", required=True, help="Path to the moving raw .n5 segmentation") +@click.option("--moving_key", required=True, help="Dataset key of the raw moving segmentation") +@click.option("--moving_output_key", required=True, help="Dataset key to write the moving segmentation with landmarks") +@click.option( + "--moving_landmarks_csv", + required=True, + help="CSV with columns [name, x, y, z] in µm for the moving image") +@click.option("--log_dir", required=True, help="Output directory for logs and landmark_label_ids.json") +def main(fixed_path, fixed_key, fixed_output_key, fixed_landmarks_csv, + moving_path, moving_key, moving_output_key, moving_landmarks_csv, + log_dir): + + Path(log_dir).mkdir(parents=True, exist_ok=True) + setup_logging(log_dir, "add_landmarks.log") + + logging.info(f"Reading fixed segmentation from {fixed_path}/{fixed_key}") + fixed_seg = read_volume(fixed_path, fixed_key) + fixed_attrs = dict(get_attrs(fixed_path, fixed_key)) + fixed_resolution = fixed_attrs["resolution"] + logging.info(f"Fixed shape: {fixed_seg.shape}, dtype: {fixed_seg.dtype}, resolution: {fixed_resolution}") + + logging.info(f"Reading moving segmentation from {moving_path}/{moving_key}") + moving_seg = read_volume(moving_path, moving_key) + moving_attrs = dict(get_attrs(moving_path, moving_key)) + moving_resolution = moving_attrs["resolution"] + logging.info(f"Moving shape: {moving_seg.shape}, dtype: {moving_seg.dtype}, resolution: {moving_resolution}") + + fixed_lm_df = pd.read_csv(fixed_landmarks_csv) + moving_lm_df = pd.read_csv(moving_landmarks_csv) + shared_names = sorted(set(fixed_lm_df["name"].tolist()) & set(moving_lm_df["name"].tolist())) + fixed_only = set(fixed_lm_df["name"]) - set(shared_names) + moving_only = set(moving_lm_df["name"]) - set(shared_names) + if fixed_only: + logging.warning(f"Landmarks only in fixed CSV, skipping: {sorted(fixed_only)}") + if moving_only: + logging.warning(f"Landmarks only in moving CSV, skipping: {sorted(moving_only)}") + fixed_lm_df = fixed_lm_df[fixed_lm_df["name"].isin(shared_names)] + moving_lm_df = moving_lm_df[moving_lm_df["name"].isin(shared_names)] + + id_map = landmark_label_ids(shared_names, fixed_seg.dtype) + min_lm_id = min(id_map.values()) + logging.info(f"{len(shared_names)} corresponding landmarks, ids {min_lm_id}–{max(id_map.values())}") + + logging.info(f"Embedding {len(fixed_lm_df)} landmarks into fixed segmentation") + fixed_seg_lm = add_landmarks_to_seg(fixed_seg, fixed_lm_df, fixed_resolution, id_map) + + logging.info(f"Embedding {len(moving_lm_df)} landmarks into moving segmentation") + moving_seg_lm = add_landmarks_to_seg(moving_seg, moving_lm_df, moving_resolution, id_map) + + logging.info(f"Writing fixed with landmarks to {fixed_path}/{fixed_output_key}") + write_volume(fixed_path, fixed_seg_lm, fixed_output_key, attrs=fixed_attrs) + + logging.info(f"Writing moving with landmarks to {moving_path}/{moving_output_key}") + write_volume(moving_path, moving_seg_lm, moving_output_key, attrs=moving_attrs) + + label_id_path = Path(log_dir) / "landmark_label_ids.json" + with open(label_id_path, "w") as f: + json.dump(id_map, f, indent=2) + logging.info(f"Label ID mapping written to {label_id_path}") + + plots_dir = Path(log_dir) / "plots" + plots_dir.mkdir(exist_ok=True) + plot_landmark_qc(fixed_seg_lm, id_map, + save_path=plots_dir / "fixed_landmarks_qc.png", + cell_cmap=PINK, landmark_color="red") + plot_landmark_qc(moving_seg_lm, id_map, + save_path=plots_dir / "moving_landmarks_qc.png", + cell_cmap=CYAN, landmark_color="blue") + logging.info(f"Landmark QC plots written to {plots_dir}") + + +if __name__ == "__main__": + main() diff --git a/matchmaker/cpd_parameter_tuning/cpd_optimization.py b/matchmaker/cpd_parameter_tuning/cpd_optimization.py new file mode 100644 index 0000000..255a20c --- /dev/null +++ b/matchmaker/cpd_parameter_tuning/cpd_optimization.py @@ -0,0 +1,252 @@ +import json +import threading +import click +import logging +import numpy as np +import pandas as pd +import yaml +from functools import reduce +from pathlib import Path + +import optuna +from optuna.samplers import GridSampler + +from matchmaker.utils import ( + read_volume, get_attrs, extract_centroids, run_cpd, create_pcd, setup_logging, + visualize_displacement_field, +) +from matchmaker.cpd_parameter_tuning.suggest_cpd_ranges import suggest_beta_ranges + +import matplotlib +matplotlib.use("agg") # non-interactive backend required for worker threads + +_ranges_file = Path(__file__).parent / "default_cpd_ranges.yaml" +with open(_ranges_file) as _f: + DEFAULT_SEARCH_SPACE = yaml.safe_load(_f) + + +def compute_lre(fixed_pcd, registered_pcd, id_map): + """Compute mean and per-landmark LRE (µm) between fixed and CPD-registered point clouds. + + Args: + fixed_pcd: Open3D point cloud for the fixed image, with landmark labels embedded + registered_pcd: Open3D point cloud after CPD, with moving landmark labels at new positions + id_map: {landmark_name: label_id} + + Returns: + (mean_lre, per_landmark_dict) + """ + def pcd_to_label_pos(pcd): + labels = pcd.point.label.numpy()[:, 0].astype(int) + positions = pcd.point.positions.numpy() + return dict(zip(labels, positions)) + + fixed_pos = pcd_to_label_pos(fixed_pcd) + reg_pos = pcd_to_label_pos(registered_pcd) + + distances, per_lm = [], {} + for name, lbl in id_map.items(): + if lbl in fixed_pos and lbl in reg_pos: + d = float(np.linalg.norm(fixed_pos[lbl] - reg_pos[lbl])) + distances.append(d) + per_lm[name] = d + else: + logging.warning(f"Landmark {name!r} (id={lbl}) missing from point clouds") + + mean_lre = float(np.mean(distances)) if distances else float("inf") + return mean_lre, per_lm + + +def run_optimization(fixed_pcd, moving_pcd, fixed_labels, moving_labels, + id_map, search_space, n_trials, study_name, output_dir, n_jobs=1): + """Run Optuna grid search and save results.""" + + logging.info(f"Fixed point cloud: {len(fixed_labels)} points") + logging.info(f"Moving point cloud: {len(moving_labels)} points") + min_lm_id = min(id_map.values()) + n_landmark_fixed = sum(1 for lbl in fixed_labels if lbl >= min_lm_id) + n_landmark_moving = sum(1 for lbl in moving_labels if lbl >= min_lm_id) + logging.info(f"Landmarks in fixed pcd: {n_landmark_fixed}, in moving pcd: {n_landmark_moving}") + + trial_results = [] + results_lock = threading.Lock() + + def objective(trial): + w = trial.suggest_categorical("w", search_space["w"]) + beta = trial.suggest_categorical("beta", search_space["beta"]) + lmd = trial.suggest_categorical("lmd", search_space["lmd"]) + maxiter = trial.suggest_categorical("maxiter", search_space["maxiter"]) + + logging.info(f"Trial {trial.number}: w={w}, beta={beta}, lmd={lmd}, maxiter={maxiter}") + registered_pcd = run_cpd(fixed_pcd, moving_pcd, w, beta, lmd, maxiter) + mean_lre, per_lm = compute_lre(fixed_pcd, registered_pcd, id_map) + logging.info(f" → mean LRE = {mean_lre:.2f} µm") + + plot_title = f"Trial {trial.number:04d} | w={w}, β={beta}, λ={lmd}, maxiter={maxiter} | LRE={mean_lre:.2f} µm" + for proj in ("xy", "xz", "yz"): + visualize_displacement_field( + moving_pcd, + registered_pcd, + projection=proj, + title=plot_title, + save_path=plots_dir / f"trial_{trial.number:04d}_{proj}.png", + ) + + result = { + "trial": trial.number, + "w": w, + "beta": beta, + "lmd": lmd, + "maxiter": maxiter, + "mean_lre_um": mean_lre, + "per_landmark_um": per_lm, + } + with results_lock: + trial_results.append(result) + with open(output_dir / f"trial_{trial.number:04d}.json", "w") as f: + json.dump(result, f, indent=2) + + return mean_lre + + plots_dir = output_dir / "plots" + plots_dir.mkdir(exist_ok=True) + + storage = f"sqlite:///{output_dir}/optuna_study.db" + sampler = GridSampler(search_space) + study = optuna.create_study( + study_name=study_name, + direction="minimize", + sampler=sampler, + storage=storage, + load_if_exists=True, + ) + logging.info(f"Optuna study stored at {output_dir}/optuna_study.db") + logging.info(f"Running {n_trials} trials with n_jobs={n_jobs}") + study.optimize(objective, n_trials=n_trials, n_jobs=n_jobs) + + best = study.best_trial + logging.info( + f"Best trial {best.number}: w={best.params['w']}, beta={best.params['beta']}, " + f"lmd={best.params['lmd']}, maxiter={best.params['maxiter']} " + f"→ mean LRE = {best.value:.2f} µm" + ) + + best_params = { + "coherent_point_drift": { + "w": best.params["w"], + "beta": best.params["beta"], + "lmd": best.params["lmd"], + "maxiter": best.params["maxiter"], + }, + "optimization": { + "mean_lre_um": best.value, + "n_trials_run": len(study.trials), + }, + } + best_params_path = output_dir / "best_cpd_params.yaml" + with open(best_params_path, "w") as f: + yaml.dump(best_params, f, default_flow_style=False) + logging.info(f"Best parameters written to {best_params_path}") + + summary = pd.DataFrame([ + { + "trial": r["trial"], + "w": r["w"], + "beta": r["beta"], + "lmd": r["lmd"], + "maxiter": r["maxiter"], + "mean_lre_um": r["mean_lre_um"], + } + for r in trial_results + ]) + summary_path = output_dir / "study_results.csv" + summary.sort_values("mean_lre_um").to_csv(summary_path, index=False) + logging.info(f"Study summary written to {summary_path}") + + return best_params_path + + +@click.command() +@click.option("--config", required=True, help="Path to cpd_optimization_config.yaml") +@click.option( + "--landmark_ids_json", + default=None, + help="Path to landmark_label_ids.json (default: /landmark_label_ids.json)", +) +@click.option( + "--fixed_path", default=None, help="Override fixed image n5 path from config" +) +@click.option( + "--moving_path", default=None, help="Override moving image n5 path from config" +) +def main(config, landmark_ids_json, fixed_path, moving_path): + with open(config) as f: + cfg = yaml.safe_load(f) + + output_dir = Path(cfg["log_dir"]) + output_dir.mkdir(parents=True, exist_ok=True) + setup_logging(str(output_dir), "cpd_optimization.log") + + fixed_path = fixed_path or cfg["fixed_image"]["path"] + fixed_key = cfg["fixed_image"]["aligned_key"] + + moving_path = moving_path or cfg["moving_image"]["path"] + moving_key = cfg["moving_image"]["aligned_key"] + + landmark_ids_path = Path(landmark_ids_json) if landmark_ids_json else output_dir / "landmark_label_ids.json" + with open(landmark_ids_path) as f: + id_map = {name: int(lbl) for name, lbl in json.load(f).items()} + logging.info(f"Loaded {len(id_map)} landmark label IDs from {landmark_ids_path}") + + logging.info("Extracting point clouds (done once, reused across all trials)") + + logging.info("Reading fixed image") + fixed_img = read_volume(fixed_path, fixed_key) + fixed_resolution = get_attrs(fixed_path, fixed_key)["resolution"] + logging.info(f"Fixed image shape: {fixed_img.shape}, dtype {fixed_img.dtype}") + + logging.info("Reading moving image") + moving_img = read_volume(moving_path, moving_key) + moving_resolution = get_attrs(moving_path, moving_key)["resolution"] + logging.info(f"Moving image shape: {moving_img.shape}, dtype {moving_img.dtype}") + + fixed_labels, fixed_coords = extract_centroids(fixed_img, fixed_resolution) + moving_labels, moving_coords = extract_centroids(moving_img, moving_resolution) + + fixed_pcd = create_pcd(fixed_coords, fixed_labels) + moving_pcd = create_pcd(moving_coords, moving_labels) + + search_space_cfg = cfg.get("optuna", {}).get("search_space", "default") + if isinstance(search_space_cfg, dict): + search_space = {k: list(v) for k, v in search_space_cfg.items()} + logging.info("Using manually specified search space from config") + elif search_space_cfg == "dataset-specific": + search_space = dict(DEFAULT_SEARCH_SPACE) + search_space["beta"] = suggest_beta_ranges(fixed_coords, fallback_betas=DEFAULT_SEARCH_SPACE["beta"]) + logging.info(f"Dataset-specific beta ranges: {search_space['beta']}") + else: + search_space = DEFAULT_SEARCH_SPACE + logging.info("Using default search space from default_cpd_ranges.yaml") + n_combinations = reduce(lambda a, b: a * b, (len(v) for v in search_space.values()), 1) + logging.info(f"Grid search space: {search_space}") + logging.info(f"Total combinations: {n_combinations}") + + study_name = cfg.get("optuna", {}).get("study_name", "cpd_optimization") + n_jobs = cfg.get("optuna", {}).get("n_jobs", 1) + + run_optimization( + fixed_pcd=fixed_pcd, + moving_pcd=moving_pcd, + fixed_labels=fixed_labels, + moving_labels=moving_labels, + id_map=id_map, + search_space=search_space, + n_trials=n_combinations, + study_name=study_name, + output_dir=output_dir, + n_jobs=n_jobs, + ) + + +if __name__ == "__main__": + main() diff --git a/matchmaker/cpd_parameter_tuning/default_cpd_params.yaml b/matchmaker/cpd_parameter_tuning/default_cpd_params.yaml new file mode 100644 index 0000000..9bacb84 --- /dev/null +++ b/matchmaker/cpd_parameter_tuning/default_cpd_params.yaml @@ -0,0 +1,4 @@ +beta: 100 +w: 1.0e-5 +lmd: 0.1 +maxiter: 100 diff --git a/matchmaker/cpd_parameter_tuning/default_cpd_ranges.yaml b/matchmaker/cpd_parameter_tuning/default_cpd_ranges.yaml new file mode 100644 index 0000000..246cfbe --- /dev/null +++ b/matchmaker/cpd_parameter_tuning/default_cpd_ranges.yaml @@ -0,0 +1,4 @@ +w: [1.0e-5, 1.0e-4, 1.0e-3, 1.0e-2, 1.0e-1] +beta: [10.0, 50.0, 100.0, 200.0] +lmd: [0.01, 0.1, 1.0, 10.0] +maxiter: [100] diff --git a/matchmaker/cpd_parameter_tuning/suggest_cpd_ranges.py b/matchmaker/cpd_parameter_tuning/suggest_cpd_ranges.py new file mode 100644 index 0000000..dc1e71b --- /dev/null +++ b/matchmaker/cpd_parameter_tuning/suggest_cpd_ranges.py @@ -0,0 +1,93 @@ +import logging +from pathlib import Path + +import click +import numpy as np +import yaml +from scipy.spatial import cKDTree + +from matchmaker.utils import read_volume, get_attrs, extract_centroids + +_DEFAULT_RANGES_FILE = Path(__file__).parent / "default_cpd_ranges.yaml" + + +def compute_point_spacing(coords: np.ndarray) -> float: + """Mean nearest-neighbor distance as a proxy for point spacing (µm).""" + tree = cKDTree(coords) + distances, _ = tree.query(coords, k=2) # k=2: first hit is self (dist=0) + return float(np.mean(distances[:, 1])) + + +def suggest_beta_ranges(coords: np.ndarray, fallback_betas: list | None = None) -> list: + """4 beta values as max(n*h, f*D) with h=point spacing, D=mean extent. + + Values: max(2h,0.025D), max(4h,0.05D), max(8h,0.1D), max(16h,0.2D). + The h-based terms set a lower bound relative to point density; the + D-based terms prevent pathologically small betas when landmarks are + very densely packed relative to the object scale. + + Args: + coords: (N, 3) array of point positions in µm. + fallback_betas: returned as-is when coords has fewer than 2 points. + + Returns: + List of 4 beta values rounded to 2 decimal places. + """ + if len(coords) < 2: + logging.warning("Fewer than 2 points; returning fallback beta ranges") + return fallback_betas if fallback_betas is not None else [] + + h = compute_point_spacing(coords) + D = float(np.mean(coords.max(axis=0) - coords.min(axis=0))) + logging.info(f"Mean NN spacing h={h:.2f} µm, mean extent D={D:.1f} µm") + + betas = [ + max(2 * h, 0.025 * D), + max(4 * h, 0.05 * D), + max(8 * h, 0.1 * D), + max(16 * h, 0.2 * D), + ] + betas = np.round(betas, 2).tolist() + logging.info(f"Beta ranges: {betas}") + return betas + + +@click.command() +@click.option("--path", required=True, help="Path to .n5 segmentation file") +@click.option("--key", required=True, help="Dataset key to read from .n5") +@click.option("--log_dir", required=True, type=click.Path(), help="Directory to write dataset_cpd_ranges.yaml") +@click.option("--x_res", type=float, default=None, help="X resolution in µm (overrides .n5 attrs)") +@click.option("--y_res", type=float, default=None, help="Y resolution in µm (overrides .n5 attrs)") +@click.option("--z_res", type=float, default=None, help="Z resolution in µm (overrides .n5 attrs)") +def main(path, key, log_dir, x_res, y_res, z_res): + seg = read_volume(path, key) + + if x_res is None or y_res is None or z_res is None: + attrs = dict(get_attrs(path, key)) + resolution = attrs["resolution"] + else: + resolution = [z_res, y_res, x_res] + + _, positions = extract_centroids(seg, resolution) + n_points = len(positions) + + extents = positions.max(axis=0) - positions.min(axis=0) + betas = suggest_beta_ranges(positions) + + with open(_DEFAULT_RANGES_FILE) as f: + search_space = yaml.safe_load(f) + search_space["beta"] = betas + + output_path = Path(log_dir) / "dataset_cpd_ranges.yaml" + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + yaml.dump(search_space, f, default_flow_style=False) + + click.echo(f"# Point cloud: {n_points} points, extent (x,y,z) µm: " + f"[{extents[0]:.1f}, {extents[1]:.1f}, {extents[2]:.1f}]") + click.echo(f"Dataset-specific beta values: {betas}") + click.echo(f"# Dataset-specific CPD ranges written to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/matchmaker/utils/vis.py b/matchmaker/utils/vis.py index bd4f7ef..9422d96 100644 --- a/matchmaker/utils/vis.py +++ b/matchmaker/utils/vis.py @@ -160,6 +160,63 @@ def plot_three_slices( plt.close() +def plot_landmark_qc( + seg_with_lm, + id_map, + save_path=None, + cell_cmap=None, + landmark_color="red", +): + """Three-slice QC plot: cells in cell_cmap, each landmark centroid as a labeled scatter dot. + + Landmark positions are projected onto the mid-slice of each axis so all landmarks + are visible regardless of depth, making it easy to confirm placements visually. + + Args: + seg_with_lm: ZYX integer array with landmark spheres embedded at label IDs from id_map + id_map: {landmark_name: label_id} + save_path: output path; shows interactively if None + cell_cmap: colormap for regular cells (default: PINK) + landmark_color: scatter / text color for landmarks (default: "red") + """ + if cell_cmap is None: + cell_cmap = PINK + + min_lm_id = min(id_map.values()) + cells = (seg_with_lm > 0) & (seg_with_lm < min_lm_id) + + lm_centroids = {} + for name, lbl in id_map.items(): + voxels = np.argwhere(seg_with_lm == lbl) + if len(voxels) == 0: + logging.warning(f"Landmark {name!r} (id={lbl}) not found in segmentation") + continue + lm_centroids[name] = voxels.mean(axis=0) # [z, y, x] + + fig, axes = plt.subplots(1, 3, figsize=(18, 6)) + proj_specs = [ + (0, "xy projection", cells.max(axis=0)), + (1, "xz projection", cells.max(axis=1)), + (2, "yz projection", cells.max(axis=2)), + ] + + for ax, (axis, title, s) in zip(axes, proj_specs): + ax.set_title(title) + ax.imshow(s.astype(np.float32), cmap=cell_cmap, vmin=0, vmax=1, alpha=0.5) + for name, c in lm_centroids.items(): + col, row = _slice_gc_coords(c, axis) + ax.scatter(col, row, c=landmark_color, s=15, zorder=5, linewidths=0) + ax.text(col + 2, row, name, fontsize=4, color=landmark_color, zorder=6, va="center") + ax.invert_yaxis() + + plt.tight_layout() + if save_path is None: + plt.show() + else: + plt.savefig(save_path, dpi=300, bbox_inches="tight") + plt.close() + + def plot_overlay(img1, img2, save_path=None, x_pos=None, y_pos=None, z_pos=None, gc1=None, Vt1=None, gc2=None, Vt2=None): """ @@ -311,6 +368,7 @@ def visualize_displacement_field( projection="xy", center_slice=True, max_points=2000, + title=None, ): assert ( len(projection) == 2 @@ -332,10 +390,12 @@ def visualize_displacement_field( plt.axis("equal") plt.gca().invert_yaxis() + if title is not None: + plt.title(title, fontsize=8) if save_path is None: plt.show() else: - plt.savefig(save_path, dpi=300) + plt.savefig(save_path, dpi=300, bbox_inches="tight") plt.close() diff --git a/workflows/cpd_optimization.smk b/workflows/cpd_optimization.smk new file mode 100644 index 0000000..c994d3d --- /dev/null +++ b/workflows/cpd_optimization.smk @@ -0,0 +1,167 @@ +from pathlib import Path + +root_dir = f"{Path(workflow.basedir).resolve().parent}/" +workdir: root_dir + +fixed_input_key = config["fixed_image"]["input_key"] +fixed_aligned_key = config["fixed_image"]["aligned_key"] + +moving_input_key = config["moving_image"]["input_key"] +moving_aligned_key = config["moving_image"]["aligned_key"] + +fixed_lm_csv = config["landmarks"]["fixed"] +moving_lm_csv = config["landmarks"]["moving"] + +log_dir = config["log_dir"] +axis_orientation = config["prealignment"]["axis_orientation"] + +fixed_name = config.get("fixed_name", "fixed_image") +moving_name = config.get("moving_name", "moving_image") +fixed_n5_path = f"{log_dir}/{fixed_name}.n5" +moving_n5_path = f"{log_dir}/{moving_name}.n5" + +raw_n5_key = "input" +lm_input_key = "input_with_lm" +landmark_ids_json = f"{log_dir}/prepare_landmarks/landmark_label_ids.json" + +fixed_spacing = [config["fixed_image"]["z_res"], config["fixed_image"]["y_res"], config["fixed_image"]["x_res"]] +moving_spacing = [config["moving_image"]["z_res"], config["moving_image"]["y_res"], config["moving_image"]["x_res"]] + + +rule all: + input: + best_params = f"{log_dir}/best_cpd_params.yaml", + study_results = f"{log_dir}/study_results.csv", + + +rule input_to_n5: + input: + fixed_image_path = config["fixed_image"]["path"], + moving_image_path = config["moving_image"]["path"], + output: + directory(f"{fixed_n5_path}/{raw_n5_key}/"), + directory(f"{moving_n5_path}/{raw_n5_key}/"), + fixed_image_n5 = directory(fixed_n5_path), + moving_image_n5 = directory(moving_n5_path), + params: + output_key = raw_n5_key, + log: f"{log_dir}/matchmaker.log" + shell: + f"rm -rf {{output.fixed_image_n5}};" + f"rm -rf {{output.moving_image_n5}};" + f"python matchmaker/raw_to_n5.py --input_path {{input.fixed_image_path}} --input_key {fixed_input_key} --output_path {{output.fixed_image_n5}} --output_key {{params.output_key}} --log_dir {log_dir} --x_res {config['fixed_image']['x_res']} --y_res {config['fixed_image']['y_res']} --z_res {config['fixed_image']['z_res']};" + f"python matchmaker/raw_to_n5.py --input_path {{input.moving_image_path}} --input_key {moving_input_key} --output_path {{output.moving_image_n5}} --output_key {{params.output_key}} --log_dir {log_dir} --x_res {config['moving_image']['x_res']} --y_res {config['moving_image']['y_res']} --z_res {config['moving_image']['z_res']};" + + +rule add_landmarks: + """Embed landmarks into both raw segmentations.""" + input: + fixed_n5 = fixed_n5_path, + moving_n5 = moving_n5_path, + fixed_input_ds = f"{fixed_n5_path}/{raw_n5_key}", + moving_input_ds = f"{moving_n5_path}/{raw_n5_key}", + fixed_lm_csv = fixed_lm_csv, + moving_lm_csv = moving_lm_csv, + output: + directory(f"{fixed_n5_path}/{lm_input_key}"), + directory(f"{moving_n5_path}/{lm_input_key}"), + landmark_ids = landmark_ids_json, + params: + fixed_input_key = raw_n5_key, + moving_input_key = raw_n5_key, + lm_input_key = lm_input_key, + log_dir = f"{log_dir}/prepare_landmarks", + log: f"{log_dir}/prepare_landmarks/add_landmarks.log" + shell: + "python matchmaker/cpd_parameter_tuning/add_landmarks.py " + "--fixed_path {input.fixed_n5} " + "--fixed_key {params.fixed_input_key} " + "--fixed_output_key {params.lm_input_key} " + "--fixed_landmarks_csv {input.fixed_lm_csv} " + "--moving_path {input.moving_n5} " + "--moving_key {params.moving_input_key} " + "--moving_output_key {params.lm_input_key} " + "--moving_landmarks_csv {input.moving_lm_csv} " + "--log_dir {params.log_dir}" + + +rule prealignment_with_lm: + """Run SVD prealignment on the landmark-embedded segmentations.""" + input: + fixed_ds = f"{fixed_n5_path}/{lm_input_key}", + moving_ds = f"{moving_n5_path}/{lm_input_key}", + output: + directory(f"{fixed_n5_path}/{fixed_aligned_key}"), + directory(f"{moving_n5_path}/{fixed_aligned_key}"), + transform = f"{log_dir}/prealignment_with_lm/prealignment_transform.json", + params: + fixed_path = fixed_n5_path, + moving_path = moving_n5_path, + input_key = lm_input_key, + output_key = fixed_aligned_key, + fixed_spacing = f"{config['fixed_image']['z_res']} {config['fixed_image']['y_res']} {config['fixed_image']['x_res']}", + moving_spacing = f"{config['moving_image']['z_res']} {config['moving_image']['y_res']} {config['moving_image']['x_res']}", + axis_orientation = axis_orientation, + output_dir = f"{log_dir}/prealignment_with_lm", + log: f"{log_dir}/prealignment_with_lm/prealignment.log" + shell: + "python matchmaker/prealignment.py " + "--fixed_path {params.fixed_path} " + "--fixed_key {params.input_key} " + "--fixed_spacing {params.fixed_spacing} " + "--moving_path {params.moving_path} " + "--moving_key {params.input_key} " + "--moving_spacing {params.moving_spacing} " + "--output_dir {params.output_dir} " + "--output_key {params.output_key} " + "--output_transform_path {output.transform} " + "--axis_orientation {params.axis_orientation}" + + +rule rigid_alignment_with_lm: + """Run elastix rigid alignment on the prealigned landmark-embedded segmentations.""" + input: + fixed_ds = f"{fixed_n5_path}/{fixed_aligned_key}", + moving_ds = f"{moving_n5_path}/{fixed_aligned_key}", + output: + directory(f"{moving_n5_path}/{moving_aligned_key}"), + params: + fixed_path = fixed_n5_path, + moving_path = moving_n5_path, + input_key = fixed_aligned_key, + output_key = moving_aligned_key, + output_dir = f"{log_dir}/rigid_alignment_with_lm", + log: f"{log_dir}/rigid_alignment_with_lm/rigid_alignment.log" + shell: + "python matchmaker/rigid_alignment_elastix.py " + "--fixed_path {params.fixed_path} " + "--fixed_key {params.input_key} " + "--moving_path {params.moving_path} " + "--moving_key {params.input_key} " + "--output_dir {params.output_dir} " + "--output_key {params.output_key}" + + +rule optimize_cpd: + """Run Optuna grid search over CPD parameters, evaluated by mean LRE. + Beta ranges for dataset-specific mode are computed inside cpd_optimization.py + from the extracted point clouds. + """ + input: + fixed_ds = f"{fixed_n5_path}/{fixed_aligned_key}", + moving_ds = f"{moving_n5_path}/{moving_aligned_key}", + config = workflow.configfiles[0], + landmark_ids_json = landmark_ids_json, + output: + best_params = f"{log_dir}/best_cpd_params.yaml", + study_results = f"{log_dir}/study_results.csv", + params: + fixed_path = fixed_n5_path, + moving_path = moving_n5_path, + log: f"{log_dir}/cpd_optimization.log" + shell: + "python matchmaker/cpd_parameter_tuning/cpd_optimization.py " + "--config {input.config} " + "--landmark_ids_json {input.landmark_ids_json} " + "--fixed_path {params.fixed_path} " + "--moving_path {params.moving_path}"