diff --git a/examples/register_config_test_elastic.yaml b/examples/register_config_test_elastic.yaml index 570a841..b145eaa 100644 --- a/examples/register_config_test_elastic.yaml +++ b/examples/register_config_test_elastic.yaml @@ -29,8 +29,8 @@ coherent_point_drift: ILP: min_neighbours: 10 - max_dist: 30 + max_dist: 15 -mobie_export: False +mobie_export: True semantic_seg: True mobie_dataset_name: "platy1_muscles_stardist" diff --git a/matchmaker/cpd_nonrigid_registration.py b/matchmaker/cpd_nonrigid_registration.py index 965169f..eff9305 100644 --- a/matchmaker/cpd_nonrigid_registration.py +++ b/matchmaker/cpd_nonrigid_registration.py @@ -11,11 +11,9 @@ get_attrs ) - from matchmaker.utils import overlay_pcds, visualize_displacement_field, extract_centroids, run_cpd, create_pcd - def cpd_from_images(fixed_img, fixed_resolution, moving_img, moving_resolution, output_dir, w, beta, lmd, maxiter): fixed_labels, fixed_center_coords = extract_centroids(fixed_img, fixed_resolution) moving_labels, moving_center_coord = extract_centroids(moving_img, moving_resolution) @@ -23,21 +21,57 @@ def cpd_from_images(fixed_img, fixed_resolution, moving_img, moving_resolution, fixed_pcd = create_pcd(fixed_center_coords, fixed_labels) moving_pcd = create_pcd(moving_center_coord, moving_labels) - overlay_pcds(fixed_pcd, moving_pcd, projection="xz", save_path = output_dir / "pcds_before_registration_xz.png") - overlay_pcds(fixed_pcd, moving_pcd, projection="yz", save_path = output_dir / "pcds_before_registration_yz.png") - overlay_pcds(fixed_pcd, moving_pcd, projection="xy", save_path = output_dir / "pcds_before_registration_xy.png") + overlay_pcds(fixed_pcd, moving_pcd, projection="xz", save_path=output_dir / "plots/pcds_before_registration_xz.png") + overlay_pcds(fixed_pcd, moving_pcd, projection="yz", save_path=output_dir / "plots/pcds_before_registration_yz.png") + overlay_pcds(fixed_pcd, moving_pcd, projection="xy", save_path=output_dir / "plots/pcds_before_registration_xy.png") logging.info(f"Point cloud registration with parameters w={w}, beta={beta}, lmd={lmd}, maxiter={maxiter}") registered_pcd = run_cpd(fixed_pcd, moving_pcd, w, beta, lmd, maxiter) - overlay_pcds(fixed_pcd, registered_pcd, projection="xz", save_path = output_dir / "pcds_after_registration_xz.png") - overlay_pcds(fixed_pcd, registered_pcd, projection="yz", save_path = output_dir / "pcds_after_registration_yz.png") - overlay_pcds(fixed_pcd, registered_pcd, projection="xy", save_path = output_dir / "pcds_after_registration_xy.png") + overlay_pcds( + fixed_pcd, + registered_pcd, + projection="xz", + save_path=output_dir / "plots/pcds_after_registration_xz.png", + ) + overlay_pcds( + fixed_pcd, + registered_pcd, + projection="yz", + save_path=output_dir / "plots/pcds_after_registration_yz.png", + ) + overlay_pcds( + fixed_pcd, + registered_pcd, + projection="xy", + save_path=output_dir / "plots/pcds_after_registration_xy.png", + ) - visualize_displacement_field(moving_pcd, registered_pcd, projection="xz", save_path = output_dir / "displacement_field_xz.png") - visualize_displacement_field(moving_pcd, registered_pcd, projection="yz", save_path = output_dir / "displacement_field_yz.png") - visualize_displacement_field(moving_pcd, registered_pcd, projection="xy", save_path = output_dir / "displacement_field_xy.png") + visualize_displacement_field( + moving_pcd, + registered_pcd, + projection="xz", + save_path=output_dir / "plots/displacement_field_xz.png", + ) + visualize_displacement_field( + moving_pcd, + registered_pcd, + projection="xz", + save_path=output_dir / "plots/displacement_field_xz.pdf", + ) + visualize_displacement_field( + moving_pcd, + registered_pcd, + projection="yz", + save_path=output_dir / "plots/displacement_field_yz.png", + ) + visualize_displacement_field( + moving_pcd, + registered_pcd, + projection="xy", + save_path=output_dir / "plots/displacement_field_xy.png", + ) o3d.t.io.write_point_cloud(str(output_dir / "fixed_pcd.pcd"), fixed_pcd, write_ascii=True) o3d.t.io.write_point_cloud(str(output_dir / "moving_pcd.pcd"), moving_pcd, write_ascii=True) @@ -90,8 +124,5 @@ def main(fixed_path, fixed_key, moving_path, moving_key, output_dir, w, beta, lm ) - - - if __name__ == "__main__": main() \ No newline at end of file diff --git a/matchmaker/match_pointclouds.py b/matchmaker/match_pointclouds.py index 6c200c2..b99724a 100644 --- a/matchmaker/match_pointclouds.py +++ b/matchmaker/match_pointclouds.py @@ -10,8 +10,6 @@ from matchmaker.utils import sparse_ilp_matching, write_index_pairs, plot_matching_qc - - def match_points(fixed_pcd, registered_pcd, output_dir, max_dist, min_neighbours): logging.info(f"Number of points in fixed pcd: {len(fixed_pcd.point.positions)}") @@ -44,16 +42,14 @@ def match_points(fixed_pcd, registered_pcd, output_dir, max_dist, min_neighbours ] if swap_order: - plot_matching_qc(pos_2, pos_1, output_dir / "point_matching_xz.png", pairs=matched_idx_pairs, projection="xz") - plot_matching_qc(pos_2, pos_1, output_dir / "point_matching_yz.png", pairs=matched_idx_pairs, projection="yz") - plot_matching_qc(pos_2, pos_1, output_dir / "point_matching_xy.png", pairs=matched_idx_pairs, projection="xy") + plot_matching_qc(pos_2, pos_1, output_dir / "plots/point_matching_xz.png", pairs=matched_idx_pairs, projection="xz") + plot_matching_qc(pos_2, pos_1, output_dir / "plots/point_matching_yz.png", pairs=matched_idx_pairs, projection="yz") + plot_matching_qc(pos_2, pos_1, output_dir / "plots/point_matching_xy.png", pairs=matched_idx_pairs, projection="xy") else: - plot_matching_qc(pos_1, pos_2, output_dir / "point_matching_xz.png", pairs=matched_idx_pairs, projection="xz") - plot_matching_qc(pos_1, pos_2, output_dir / "point_matching_yz.png", pairs=matched_idx_pairs, projection="yz") - plot_matching_qc(pos_1, pos_2, output_dir / "point_matching_xy.png", pairs=matched_idx_pairs, projection="xy") - - + plot_matching_qc(pos_1, pos_2, output_dir / "plots/point_matching_xz.png", pairs=matched_idx_pairs, projection="xz") + plot_matching_qc(pos_1, pos_2, output_dir / "plots/point_matching_yz.png", pairs=matched_idx_pairs, projection="yz") + plot_matching_qc(pos_1, pos_2, output_dir / "plots/point_matching_xy.png", pairs=matched_idx_pairs, projection="xy") return matched_idx_pairs, matched_label_pairs diff --git a/matchmaker/prealignment.py b/matchmaker/prealignment.py index fd66066..ff496e3 100755 --- a/matchmaker/prealignment.py +++ b/matchmaker/prealignment.py @@ -1,4 +1,3 @@ -import sys import click import logging import numpy as np @@ -6,11 +5,21 @@ import matplotlib.pyplot as plt from matchmaker.data import create_point_cloud -from matchmaker.utils import (get_transformation_matrix, rotate_img, read_volume, get_attrs, write_volume, - write_transform_dict, plot_three_slices, plot_overlay, setup_logging, - get_axis_orient_matrix, resample_volume, transform_axes_vis) - -from matchmaker.utils.vis import CYAN_HEX, PINK, CYAN, PINK_HEX +from matchmaker.utils import ( + get_transformation_matrix, + rotate_img, + read_volume, + get_attrs, + write_volume, + write_transform_dict, + plot_three_slices, + plot_overlay, + setup_logging, + get_axis_orient_matrix, + transform_axes_vis, +) + +from matchmaker.utils.vis import CYAN_HEX, PINK, CYAN, PINK_HEX, LABEL, _savefig def get_SVD_transform(img, spacing, save_path=None): @@ -55,106 +64,138 @@ def get_SVD_transform(img, spacing, save_path=None): plt.subplot(1, 2, 2) plt.title("Rotated Vertices") plt.scatter(vr[:, 0], vr[:, 1], alpha=0.1) - plt.savefig(save_path, dpi=300) + _savefig(save_path) return gc, Vt def orient_axis(fixed_prealigned, moving_prealigned, output_dir): - int_prof_z = np.sum(moving_prealigned > 0, axis=(1, 2)) / np.sum(moving_prealigned > 0) - asymm_coeff_z = np.corrcoef(int_prof_z, int_prof_z)[0, 1] - np.corrcoef(int_prof_z, int_prof_z[::-1])[0, 1] - int_prof_y = np.sum(moving_prealigned > 0, axis=(0, 2)) / np.sum(moving_prealigned > 0) - asymm_coeff_y = np.corrcoef(int_prof_y, int_prof_y)[0, 1] - np.corrcoef(int_prof_y, int_prof_y[::-1])[0, 1] - int_prof_x = np.sum(moving_prealigned > 0, axis=(1, 0)) / np.sum(moving_prealigned > 0) - asymm_coeff_x = np.corrcoef(int_prof_x, int_prof_x)[0, 1] - np.corrcoef(int_prof_x, int_prof_x[::-1])[0, 1] - - logging.info(f"Asymmetry coefficients in moving image:") - logging.info(f"Z: {asymm_coeff_z}") - logging.info(f"Y: {asymm_coeff_y}") - logging.info(f"X: {asymm_coeff_x}") - - int_prof_z_fixed = np.sum(fixed_prealigned > 0, axis=(1, 2)) / np.sum(fixed_prealigned > 0) - int_prof_y_fixed = np.sum(fixed_prealigned > 0, axis=(0, 2)) / np.sum(fixed_prealigned > 0) - int_prof_x_fixed = np.sum(fixed_prealigned > 0, axis=(1, 0)) / np.sum(fixed_prealigned > 0) - - plt.figure() - plt.plot(int_prof_z_fixed, label="fixed", color=PINK_HEX) - plt.plot(int_prof_z, label="moving", color=CYAN_HEX) - plt.xlabel(f"Axis Z Coordinate") - plt.ylabel(f"Sum intensity along axis = Z") - plt.legend() - plt.savefig(f"{output_dir}/plots/axis_int_profile_Z.png", dpi=300) - - plt.figure() - plt.plot(int_prof_y_fixed, label="fixed", color=PINK_HEX) - plt.plot(int_prof_y, label="moving", color=CYAN_HEX) - plt.xlabel(f"Axis Y Coordinate") - plt.ylabel(f"Sum intensity along axis = Y") - plt.legend() - plt.savefig(f"{output_dir}/plots/axis_int_profile_Y.png", dpi=300) - - plt.figure() - plt.plot(int_prof_x_fixed, label="fixed", color=PINK_HEX) - plt.plot(int_prof_x, label="moving", color=CYAN_HEX) - plt.xlabel(f"Axis X Coordinate") - plt.ylabel(f"Sum intensity along axis = X") - plt.legend() - plt.savefig(f"{output_dir}/plots/axis_int_profile_X.png", dpi=300) - - if np.corrcoef(int_prof_z, int_prof_z_fixed)[0, 1] > np.corrcoef(int_prof_z[::-1], int_prof_z_fixed)[0, 1]: - z_correct = True - else: - z_correct = False + int_prof_z = np.sum(moving_prealigned > 0, axis=(1, 2)) / np.sum( + moving_prealigned > 0 + ) + asymm_coeff_z = ( + np.corrcoef(int_prof_z, int_prof_z)[0, 1] + - np.corrcoef(int_prof_z, int_prof_z[::-1])[0, 1] + ) + int_prof_y = np.sum(moving_prealigned > 0, axis=(0, 2)) / np.sum( + moving_prealigned > 0 + ) + asymm_coeff_y = ( + np.corrcoef(int_prof_y, int_prof_y)[0, 1] + - np.corrcoef(int_prof_y, int_prof_y[::-1])[0, 1] + ) + int_prof_x = np.sum(moving_prealigned > 0, axis=(1, 0)) / np.sum( + moving_prealigned > 0 + ) + asymm_coeff_x = ( + np.corrcoef(int_prof_x, int_prof_x)[0, 1] + - np.corrcoef(int_prof_x, int_prof_x[::-1])[0, 1] + ) - if np.corrcoef(int_prof_y, int_prof_y_fixed)[0, 1] > np.corrcoef(int_prof_y[::-1], int_prof_y_fixed)[0, 1]: - y_correct = True - else: - y_correct = False + logging.info("Asymmetry coefficients in moving image:") + logging.info(f"Z: {asymm_coeff_z}") + logging.info(f"Y: {asymm_coeff_y}") + logging.info(f"X: {asymm_coeff_x}") - if np.corrcoef(int_prof_x, int_prof_x_fixed)[0, 1] > np.corrcoef(int_prof_x[::-1], int_prof_x_fixed)[0, 1]: - x_correct = True - else: - x_correct = False - - logging.info(f"Orientations are correct: Z - {z_correct}, Y - {y_correct}, X - {x_correct}") - - if (asymm_coeff_x < asymm_coeff_y) and (asymm_coeff_x < asymm_coeff_z): - logging.info(f"Using axes Y and Z for determining orientation") - if z_correct and y_correct: - logging.info(f"Orientation along both axes is correct") - R = np.eye(4, 4) - elif z_correct and not y_correct: - logging.info(f"Rotate 180 degrees around Z axis") - R = get_axis_orient_matrix(moving_prealigned, "xyz") - else: - logging.info(f"Rotate 180 degrees around X axis") - R = get_axis_orient_matrix(moving_prealigned, "zyx") - - elif (asymm_coeff_y < asymm_coeff_x) and (asymm_coeff_y < asymm_coeff_z): - logging.info(f"Using axes X and Z for determining orientation") - if z_correct and x_correct: - logging.info(f"Orientation along both axes is correct") - R = np.eye(4, 4) - elif z_correct and not x_correct: - logging.info(f"Rotate 180 degrees around Z axis") - R = get_axis_orient_matrix(moving_prealigned, "xyz") - else: - logging.info(f"Rotate 180 degrees around Y axis") - R = get_axis_orient_matrix(moving_prealigned, "yzx") + int_prof_z_fixed = np.sum(fixed_prealigned > 0, axis=(1, 2)) / np.sum( + fixed_prealigned > 0 + ) + int_prof_y_fixed = np.sum(fixed_prealigned > 0, axis=(0, 2)) / np.sum( + fixed_prealigned > 0 + ) + int_prof_x_fixed = np.sum(fixed_prealigned > 0, axis=(1, 0)) / np.sum( + fixed_prealigned > 0 + ) + + plt.figure() + plt.plot(int_prof_z_fixed, label="fixed", color=PINK_HEX) + plt.plot(int_prof_z, label="moving", color=CYAN_HEX) + plt.xlabel("Axis Z Coordinate") + plt.ylabel("Sum intensity along axis = Z") + plt.legend() + _savefig(f"{output_dir}/plots/axis_int_profile_Z.png") + + plt.figure() + plt.plot(int_prof_y_fixed, label="fixed", color=PINK_HEX) + plt.plot(int_prof_y, label="moving", color=CYAN_HEX) + plt.xlabel("Axis Y Coordinate") + plt.ylabel("Sum intensity along axis = Y") + plt.legend() + _savefig(f"{output_dir}/plots/axis_int_profile_Y.png") + + plt.figure() + plt.plot(int_prof_x_fixed, label="fixed", color=PINK_HEX) + plt.plot(int_prof_x, label="moving", color=CYAN_HEX) + plt.xlabel("Axis X Coordinate") + plt.ylabel("Sum intensity along axis = X") + plt.legend() + _savefig(f"{output_dir}/plots/axis_int_profile_X.png") + + if ( + np.corrcoef(int_prof_z, int_prof_z_fixed)[0, 1] + > np.corrcoef(int_prof_z[::-1], int_prof_z_fixed)[0, 1] + ): + z_correct = True + else: + z_correct = False + + if ( + np.corrcoef(int_prof_y, int_prof_y_fixed)[0, 1] + > np.corrcoef(int_prof_y[::-1], int_prof_y_fixed)[0, 1] + ): + y_correct = True + else: + y_correct = False + + if ( + np.corrcoef(int_prof_x, int_prof_x_fixed)[0, 1] + > np.corrcoef(int_prof_x[::-1], int_prof_x_fixed)[0, 1] + ): + x_correct = True + else: + x_correct = False + + logging.info( + f"Orientations are correct: Z - {z_correct}, Y - {y_correct}, X - {x_correct}" + ) + if (asymm_coeff_x < asymm_coeff_y) and (asymm_coeff_x < asymm_coeff_z): + logging.info("Using axes Y and Z for determining orientation") + if z_correct and y_correct: + logging.info("Orientation along both axes is correct") + R = np.eye(4, 4) + elif z_correct and not y_correct: + logging.info("Rotate 180 degrees around Z axis") + R = get_axis_orient_matrix(moving_prealigned, "xyz") else: - logging.info(f"Using axes X and Y for determining orientation") - if x_correct and y_correct: - logging.info(f"Orientation along both axes is correct") - R = np.eye(4, 4) - elif x_correct and not y_correct: - logging.info(f"Rotate 180 degrees around X axis") - R = get_axis_orient_matrix(moving_prealigned, "zyx") - else: - logging.info(f"Rotate 180 degrees around Y axis") - R = get_axis_orient_matrix(moving_prealigned, "yzx") + logging.info("Rotate 180 degrees around X axis") + R = get_axis_orient_matrix(moving_prealigned, "zyx") + + elif (asymm_coeff_y < asymm_coeff_x) and (asymm_coeff_y < asymm_coeff_z): + logging.info("Using axes X and Z for determining orientation") + if z_correct and x_correct: + logging.info("Orientation along both axes is correct") + R = np.eye(4, 4) + elif z_correct and not x_correct: + logging.info("Rotate 180 degrees around Z axis") + R = get_axis_orient_matrix(moving_prealigned, "xyz") + else: + logging.info("Rotate 180 degrees around Y axis") + R = get_axis_orient_matrix(moving_prealigned, "yzx") + + else: + logging.info("Using axes X and Y for determining orientation") + if x_correct and y_correct: + logging.info("Orientation along both axes is correct") + R = np.eye(4, 4) + elif x_correct and not y_correct: + logging.info("Rotate 180 degrees around X axis") + R = get_axis_orient_matrix(moving_prealigned, "zyx") + else: + logging.info("Rotate 180 degrees around Y axis") + R = get_axis_orient_matrix(moving_prealigned, "yzx") - return R + return R def generate_rotation_overlays(fixed_prealigned, moving_prealigned, output_dir): @@ -206,14 +247,26 @@ def prealign_samples(fixed_img, moving_img, fixed_spacing, moving_spacing, new_s gc_fixed, Vt_fixed = get_SVD_transform(fixed_img, fixed_spacing) gc_moving, Vt_moving = get_SVD_transform(moving_img, moving_spacing) - T_fixed, fixed_shape = get_transformation_matrix(fixed_img, gc_fixed, Vt_fixed, fixed_spacing, - img_ref=moving_img, Vt_ref=Vt_moving, - spacing_ref=moving_spacing, - spacing_out=new_spacing) - T_moving, moving_shape = get_transformation_matrix(moving_img, gc_moving, Vt_moving, moving_spacing, - img_ref=fixed_img, Vt_ref=Vt_fixed, - spacing_ref=fixed_spacing, - spacing_out=new_spacing) + T_fixed, fixed_shape = get_transformation_matrix( + fixed_img, + gc_fixed, + Vt_fixed, + fixed_spacing, + img_ref=moving_img, + Vt_ref=Vt_moving, + spacing_ref=moving_spacing, + spacing_out=new_spacing, + ) + T_moving, moving_shape = get_transformation_matrix( + moving_img, + gc_moving, + Vt_moving, + moving_spacing, + img_ref=fixed_img, + Vt_ref=Vt_fixed, + spacing_ref=fixed_spacing, + spacing_out=new_spacing, + ) assert np.array_equal(fixed_shape, moving_shape) fixed_rot = rotate_img(fixed_img, T_fixed, output_shape=fixed_shape) @@ -318,14 +371,7 @@ def run_prealignment( save_path=f"{output_dir}/plots/fixed_input.pdf", gc=gc_fixed, Vt=Vt_fixed, - cmap="gnuplot2_r" - ) - plot_three_slices( - fixed_img, - save_path=f"{output_dir}/plots/fixed_input.png", - gc=gc_fixed, - Vt=Vt_fixed, - cmap="gnuplot2_r" + cmap=LABEL ) plot_three_slices( @@ -333,14 +379,7 @@ def run_prealignment( save_path=f"{output_dir}/plots/moving_input.pdf", gc=gc_moving, Vt=Vt_moving, - cmap="gnuplot2_r" - ) - plot_three_slices( - moving_img, - save_path=f"{output_dir}/plots/moving_input.png", - gc=gc_moving, - Vt=Vt_moving, - cmap="gnuplot2_r" + cmap=LABEL ) plot_three_slices( @@ -350,14 +389,6 @@ def run_prealignment( Vt=Vt_fixed, cmap=PINK, ) - plot_three_slices( - fixed_img, - save_path=f"{output_dir}/plots/fixed_input_semantic.png", - gc=gc_fixed, - Vt=Vt_fixed, - cmap=PINK, - - ) plot_three_slices( moving_img, @@ -366,13 +397,6 @@ def run_prealignment( Vt=Vt_moving, cmap=CYAN ) - plot_three_slices( - moving_img, - save_path=f"{output_dir}/plots/moving_input_semantic.png", - gc=gc_moving, - Vt=Vt_moving, - cmap=CYAN - ) plot_overlay( fixed_img, @@ -383,15 +407,6 @@ def run_prealignment( gc2=gc_moving, Vt2=Vt_moving, ) - plot_overlay( - fixed_img, - moving_img, - save_path=f"{output_dir}/plots/overlay_input.png", - gc1=gc_fixed, - Vt1=Vt_fixed, - gc2=gc_moving, - Vt2=Vt_moving, - ) plot_overlay( fixed_prealigned, @@ -405,7 +420,7 @@ def run_prealignment( # check orientation (if moving fits to fixed) if axis_orientation == "auto": - logging.info(f"Try to determine axis orientation based on intensity profile of the samples") + logging.info("Try to determine axis orientation based on intensity profile of the samples") R = orient_axis(fixed_prealigned, moving_prealigned, output_dir) # In case the automatic estimation is incorrect, generate possible rotations generate_rotation_overlays(fixed_prealigned, moving_prealigned, output_dir) @@ -443,7 +458,6 @@ def run_prealignment( moving_prealigned = rotate_img(moving_prealigned, R, output_shape=moving_prealigned.shape) T_moving = T_moving @ R - logging.info("Prealignment done.") prealignment_transform = { @@ -459,20 +473,18 @@ def run_prealignment( plot_three_slices( fixed_prealigned, - save_path=f"{output_dir}/plots/fixed_prealigned.png", - gc = (np.linalg.inv(T_fixed) @ np.append(gc_fixed, 1))[:3], - Vt = transform_axes_vis(Vt_fixed, T_fixed), + save_path=f"{output_dir}/plots/fixed_prealigned.pdf", + gc=(np.linalg.inv(T_fixed) @ np.append(gc_fixed, 1))[:3], + Vt=transform_axes_vis(Vt_fixed, T_fixed), cmap=PINK, - ) plot_three_slices( moving_prealigned, save_path=f"{output_dir}/plots/moving_prealigned.pdf", - gc = (np.linalg.inv(T_moving) @ np.append(gc_moving, 1))[:3], - Vt = transform_axes_vis(Vt_moving, T_moving), + gc=(np.linalg.inv(T_moving) @ np.append(gc_moving, 1))[:3], + Vt=transform_axes_vis(Vt_moving, T_moving), cmap=CYAN, - ) plot_overlay( @@ -501,7 +513,7 @@ def run_prealignment( @click.option("-axis_orientation", "--axis_orientation", required=True, help="How to find the correct orientation along the principal axes") @click.option("-tif", "--save_tif", is_flag=True, help="Whether to save tif or not") def main(fixed_path, fixed_key, fixed_spacing, moving_path, moving_key, moving_spacing, - output_dir, output_key, output_transform_path, axis_orientation, save_tif=False): + output_dir, output_key, output_transform_path, axis_orientation, save_tif=False): """ Perform prealignment of moving image to fixed image. @@ -515,7 +527,7 @@ def main(fixed_path, fixed_key, fixed_spacing, moving_path, moving_key, moving_s moving_path (str): Path to the moving input .n5 file. moving_key (str): Key to the moving image data in the .n5 file. output_dir (str): Directory where the results should be saved. - + output_key (str): Key to the output image data in the .n5 file. Returns: None """ diff --git a/matchmaker/raw_to_n5.py b/matchmaker/raw_to_n5.py index 2e737e7..d7fa5b2 100755 --- a/matchmaker/raw_to_n5.py +++ b/matchmaker/raw_to_n5.py @@ -1,11 +1,10 @@ import sys import click import logging -import numpy as np import tifffile as tif from pathlib import Path -from matchmaker.utils import (read_volume, write_volume, plot_three_slices, convert_to_int) +from matchmaker.utils import (read_volume, write_volume, plot_three_slices, convert_to_int, LABEL) def preprocess_tif_input(input_path, output_path, output_key, log_dir, x_res, y_res, z_res): @@ -31,7 +30,7 @@ def preprocess_tif_input(input_path, output_path, output_key, log_dir, x_res, y_ plot_three_slices( image[chan], log_dir / f"input_image_{Path(input_path).stem}_{chan}.png", - cmap="gnuplot2_r", + cmap=LABEL, ) else: @@ -39,7 +38,7 @@ def preprocess_tif_input(input_path, output_path, output_key, log_dir, x_res, y_ plot_three_slices( image, log_dir / f"input_image_{Path(input_path).stem}.png", - cmap="gnuplot2_r", + cmap=LABEL, ) write_volume(output_path, image, output_key, chunks=chunks, attrs=attrs) @@ -69,7 +68,7 @@ def preprocess_n5_input(input_path, input_key, output_path, output_key, log_dir, plot_three_slices( image[chan], log_dir / f"input_image_{Path(input_path).stem}_{chan}.png", - cmap="gnuplot2_r", + cmap=LABEL, ) else: @@ -77,7 +76,7 @@ def preprocess_n5_input(input_path, input_key, output_path, output_key, log_dir, plot_three_slices( image, save_path=log_dir / f"input_image_{Path(input_path).stem}.png", - cmap="gnuplot2_r", + cmap=LABEL, ) write_volume(output_path, image, output_key, chunks=chunks, attrs=attrs) diff --git a/matchmaker/utils/vis.py b/matchmaker/utils/vis.py index bd4f7ef..b8d244d 100644 --- a/matchmaker/utils/vis.py +++ b/matchmaker/utils/vis.py @@ -4,9 +4,20 @@ import seaborn as sns import matplotlib.colors as mcolors import logging +from pathlib import Path +from skimage.color import label2rgb from matchmaker.preprocessing import percentile_norm +# Change to 'png' or None (infer from path extension) to switch output format +PLOT_FORMAT = 'pdf' + + +def _savefig(save_path, dpi=300): + if PLOT_FORMAT is not None and save_path is not None: + save_path = Path(save_path).with_suffix(f'.{PLOT_FORMAT}') + plt.savefig(save_path, dpi=dpi) + PINK_HEX = '#FF3E96' CYAN_HEX = '#00CED1' @@ -49,6 +60,13 @@ def get_cyan_cmap(): CYAN = get_cyan_cmap() +class _LabelCmap: + pass + + +LABEL = _LabelCmap() + + def _slice_gc_coords(gc, axis): cz, cy, cx = gc if axis == 0: @@ -143,9 +161,13 @@ def plot_three_slices( for i, (axis, pos, s, a) in enumerate(slices, 1): plt.subplot(1, 3, i) plt.title(f"{'zyx'[axis]} slice at {pos}") - semantic = cmap is PINK or cmap is CYAN - display = (s > 0).astype(np.float32) if semantic else s - plt.imshow(display, cmap=cmap, alpha=a, vmin=0 if semantic else None, vmax=1 if semantic else None) + if cmap is LABEL: + display = label2rgb(s, bg_label=0, bg_color=(1, 1, 1)) + plt.imshow(display) + else: + semantic = cmap is PINK or cmap is CYAN + display = (s > 0).astype(np.float32) if semantic else s + plt.imshow(display, cmap=cmap, alpha=a, vmin=0 if semantic else None, vmax=1 if semantic else None) if gc is not None: px, py = _slice_gc_coords(gc, axis) @@ -155,7 +177,7 @@ def plot_three_slices( if save_path is None: plt.show() else: - plt.savefig(save_path, dpi=300) + _savefig(save_path) plt.close() @@ -205,7 +227,7 @@ def plot_overlay(img1, img2, save_path=None, x_pos=None, y_pos=None, z_pos=None, if save_path is None: plt.show() else: - plt.savefig(save_path, dpi=300) + _savefig(save_path) plt.close() @@ -300,7 +322,7 @@ def overlay_pcds( if save_path is None: plt.show() else: - plt.savefig(save_path, dpi=300) + _savefig(save_path) plt.close() @@ -335,7 +357,7 @@ def visualize_displacement_field( if save_path is None: plt.show() else: - plt.savefig(save_path, dpi=300) + _savefig(save_path) plt.close() @@ -390,7 +412,7 @@ def plot_matching_qc( plt.axis("equal") plt.gca().invert_yaxis() - plt.savefig(fig_name, dpi=300) + _savefig(fig_name) plt.close() diff --git a/matchmaker_vis.svg b/matchmaker_vis.svg index 9b48a81..c52d2c3 100644 --- a/matchmaker_vis.svg +++ b/matchmaker_vis.svg @@ -25,15 +25,16 @@ inkscape:deskcolor="#d1d1d1" inkscape:document-units="mm" labelstyle="default" - inkscape:zoom="1.3487996" - inkscape:cx="383.67449" - inkscape:cy="231.3168" - inkscape:window-width="1512" - inkscape:window-height="949" - inkscape:window-x="0" - inkscape:window-y="33" - inkscape:window-maximized="0" - inkscape:current-layer="layer1">TransformationID FixedID Moving12345...13475...roughly ovCompute final transformation based on matched instances as landmarksmovingmovingfinal overlayfixedfixedInputRigid alignmentPoint cloud registrationMatchingID FixedID Moving12345...13475...Prealignment - -- put together centers of mass -- rotate samples to have the - same orientationregister as close as possiblewithout deformations- create point clouds from center of instances- deform moving point cloud to match fixed point cloudmatch instances, multiple instances from bigger point cloud can match to one instance in smaller point cloudinstance segmentationsTransformationID FixedID Moving12345...13475...roughly ovCompute final transformation based on matched instances as landmarksfinal overlaymovingfixedmovingfixedInputRigid alignmentPoint cloud registrationMatchingID FixedID Moving12345...13475...Prealignment + +- put together centers of mass +- rotate samples to have the + same orientationregister as close as possiblewithout deformations- create point clouds from center of instances- deform moving point cloud to match fixed point cloudmatch instances, multiple instances from bigger point cloud can match to one instance in smaller point cloudinstance segmentations + transform="matrix(1.3333333,0,0,1.3333333,1083.3615,61.766815)" /> diff --git a/visualization.svg b/visualization.svg deleted file mode 100644 index 18441d6..0000000 --- a/visualization.svg +++ /dev/null @@ -1,70955 +0,0 @@ - - - -movingfixedinputinstance segmentationsrigid alignmentregister as close as possible without deformationPoint Cloud Registration- create point clouds from center of instances- deform moving point cloud to match fixed onematchingmatch instances, multiple instances from bigger point cloudcan match to one instance in the smaller point cloudID FixedID Moving12345...13475...roughly ovTRANSFORMATIONprealignment - -- put together centers of mass -- rotate samples to have the same orientation diff --git a/workflows/apply_transform.smk b/workflows/apply_transform.smk index eb36dda..9765046 100644 --- a/workflows/apply_transform.smk +++ b/workflows/apply_transform.smk @@ -1,8 +1,6 @@ import json import re -configfile: "examples/register_config_test_rigid_apply_transform.yaml" - moving_images = config["moving_images"] moving_paths = [item["input_path"] for item in moving_images] moving_keys = [item["input_key"] for item in moving_images] @@ -57,7 +55,6 @@ rule apply_transform_file: output_path = lambda w: output_paths[TARGET_OUTPUTS.index(w.out_file)], output_key = lambda w: output_keys[TARGET_OUTPUTS.index(w.out_file)], interpolation_order = lambda w: interpolation_orders[TARGET_OUTPUTS.index(w.out_file)], - conda: "matchmaker_env" shell: """ python matchmaker/apply_transform.py \ @@ -93,7 +90,6 @@ rule apply_transform_n5: output_path = lambda w: output_paths[TARGET_OUTPUTS.index(w.out_dir)], output_key = lambda w: output_keys[TARGET_OUTPUTS.index(w.out_dir)], interpolation_order = lambda w: interpolation_orders[TARGET_OUTPUTS.index(w.out_dir)], - conda: "matchmaker_env" shell: """ python matchmaker/apply_transform.py \ diff --git a/workflows/registration.smk b/workflows/registration.smk index 36ad138..867366e 100644 --- a/workflows/registration.smk +++ b/workflows/registration.smk @@ -4,7 +4,6 @@ from pathlib import Path root_dir = f"{Path(workflow.basedir).resolve().parent}/" print(f"working directory: {root_dir}") workdir: root_dir -configfile: "examples/register_config_test_rigid.yaml" print(config["fixed_image"]) print(config["moving_image"]) @@ -90,7 +89,6 @@ rule input_to_n5: params: output_key = raw_n5_key log: f"{log_dir}/matchmaker.log" - conda: "matchmaker_env" shell: f"rm -r {{output.fixed_image_n5}};" f"rm -r {{output.moving_image_n5}};" @@ -116,7 +114,6 @@ rule SVD_prealignment: output_key = prealignment_n5_key, axis_orientation = config["prealignment"]["axis_orientation"] log: f"{log_dir}/matchmaker.log" - conda: "matchmaker_env" shell: f"python matchmaker/prealignment.py --fixed_path {{input.fixed_image_n5}} --fixed_key {{params.input_key}} --fixed_spacing {{fixed_spacing}} --moving_path {{input.moving_image_n5}} --moving_key {{params.input_key}} --moving_spacing {{moving_spacing}} --output_dir {log_dir}/{{params.output_key}} --output_key {{params.output_key}} --output_transform_path {{output.output_transform}} --axis_orientation {{params.axis_orientation}};" @@ -137,7 +134,6 @@ rule rigid_alignment: input_key = prealignment_n5_key, output_key = rigid_alignment_n5_key log: f"{log_dir}/matchmaker.log" - conda: "matchmaker_env" shell: f"python matchmaker/rigid_alignment_elastix.py --fixed_path {{input.fixed_image_n5}} --fixed_key {{params.input_key}} --moving_path {{input.moving_image_n5}} --moving_key {{params.input_key}} --output_dir {log_dir}/{{params.output_key}} --output_key {{params.output_key}};" @@ -157,7 +153,6 @@ rule create_mobie_project: params: input_key = raw_n5_key log: f"{log_dir}/matchmaker.log" - conda: "matchmaker_env" shell: f"python matchmaker/mobie_export.py --input_path {{input.fixed_image_n5}} --input_key {{params.input_key}} --input_type {fixed_type} {'--semantic_seg' if semantic_seg else ''} --dataset_name {dataset_name} --output_dir {log_dir};" f"touch {{output.fixed_check}};" @@ -186,7 +181,6 @@ rule add_prealignment_to_mobie: params: input_key = prealignment_n5_key log: f"{log_dir}/matchmaker.log" - conda: "matchmaker_env" shell: f"python matchmaker/mobie_export.py --input_path {{input.fixed_image_n5}} --input_key {{params.input_key}} --input_type {fixed_type} {'--semantic_seg' if semantic_seg else ''} --dataset_name {dataset_name} --output_dir {log_dir};" f"touch {{output.fixed_check}};" @@ -212,7 +206,6 @@ rule add_rigid_alignment_to_mobie: params: input_key = rigid_alignment_n5_key log: f"{log_dir}/matchmaker.log" - conda: "matchmaker_env" shell: f"python matchmaker/mobie_export.py --input_path {{input.moving_image_n5}} --input_key {{params.input_key}} --input_type {moving_type} {'--semantic_seg' if semantic_seg else ''} --dataset_name {dataset_name} --output_dir {log_dir};" f"touch {{output.moving_check}};" @@ -234,7 +227,6 @@ rule cpd_nonrigid_registration: fixed_key = prealignment_n5_key, moving_key = rigid_alignment_n5_key log: f"{log_dir}/matchmaker.log" - conda: "matchmaker_env" shell: f"python matchmaker/cpd_nonrigid_registration.py --moving_path {{input.moving_image_n5}} --moving_key {{params.moving_key}} --fixed_path {{input.fixed_image_n5}} --fixed_key {{params.fixed_key}} -o {{params.log_dir}} --w {w} --beta {beta} --lmd {lmd} --maxiter {maxiter};" @@ -248,7 +240,6 @@ rule ilp_matching: params: log_dir = f"{log_dir}/match_pointclouds" log: f"{log_dir}/matchmaker.log" - conda: "matchmaker_env" shell: f"python matchmaker/match_pointclouds.py --fixed_pcd {{input.fixed_pcd}} --moving_pcd {{input.moving_pcd}} -o {{params.log_dir}} --min_neighbours {min_neighbours} --max_dist {max_dist};" @@ -273,7 +264,6 @@ rule elastix_deformable_pointset: prealigned_output_key = pointset_alignment_n5_key, log_dir = f"{log_dir}/elastix_deformable_pointset_registration" log: f"{log_dir}/matchmaker.log" - conda: "matchmaker_env" shell: f"python matchmaker/elastix_deformable_pointset_registration.py --fixed_path {{input.fixed_image_n5}} --fixed_key {{params.input_key}} --moving_path {{input.moving_image_n5}} --moving_key {{params.input_key}} --output_dir {{params.log_dir}} --output_key {{params.output_key}} --match_path {{input.match_path}} --prealigned_output_key {{params.prealigned_output_key}} --prealignment_transform {{input.prealignment_transform}};" @@ -288,7 +278,6 @@ rule add_elastix_deformable_pointset_to_mobie: params: input_key = pointset_alignment_input_space_n5_key log: f"{log_dir}/matchmaker.log" - conda: "matchmaker_env" shell: f"python matchmaker/mobie_export.py --input_path {{input.moving_image_n5}} --input_key {{params.input_key}} --input_type {moving_type} {'--semantic_seg' if semantic_seg else ''} --dataset_name {dataset_name} --output_dir {log_dir};" f"touch {{output.moving_check}};"