Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- name: Doing editable install
shell: bash -l {0}
run: |
test -f setup.py && pip install -e ".[dev]"
pip install -e ".[dev]"

- name: Check we are starting with clean git checkout
shell: bash -l {0}
Expand All @@ -37,27 +37,27 @@ jobs:
- name: Trying to strip out notebooks
shell: bash -l {0}
run: |
nbdev_clean
nbdev-clean
git status -s # display the status to see which nbs need cleaning up
if [[ `git status --porcelain -uno` ]]; then
git status -uno
echo -e "!!! Detected unstripped out notebooks\n!!!Remember to run nbdev_install_hooks"
echo -e "!!! Detected unstripped out notebooks\n!!!Remember to run nbdev-install_hooks"
echo -e "This error can also happen if you are using an older version of nbdev relative to what is in CI. Please try to upgrade nbdev with the command `pip install -U nbdev`"
false
fi

- name: Run nbdev_export
- name: Run nbdev-export
shell: bash -l {0}
run: |
nbdev_export
nbdev-export
if [[ `git status --porcelain -uno` ]]; then
echo "::error::Notebooks and library are not in sync. Please run nbdev_export."
echo "::error::Notebooks and library are not in sync. Please run nbdev-export."
git status -uno
git diff
exit 1;
fi

- name: Run nbdev_test
- name: Run nbdev-test
shell: bash -l {0}
run: |
nbdev_test
nbdev-test
14 changes: 7 additions & 7 deletions diffdrr/data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/03_data.ipynb.

# %% ../notebooks/api/03_data.ipynb 3
# %% ../notebooks/api/03_data.ipynb #cec0581f
from __future__ import annotations

from pathlib import Path
Expand All @@ -12,10 +12,10 @@
from torchio import LabelMap, ScalarImage, Subject
from torchio.transforms import Resample

# %% auto 0
# %% auto #0
__all__ = ['load_example_ct', 'read']

# %% ../notebooks/api/03_data.ipynb 5
# %% ../notebooks/api/03_data.ipynb #126d05d3
def load_example_ct(
labels=None,
orientation="AP",
Expand All @@ -37,7 +37,7 @@ def load_example_ct(
**kwargs,
)

# %% ../notebooks/api/03_data.ipynb 6
# %% ../notebooks/api/03_data.ipynb #9eda997b-6e88-4bb1-bb9f-313302fe3c1c
from .pose import RigidTransform


Expand Down Expand Up @@ -96,7 +96,7 @@ def read(
dtype=torch.float32,
)
elif orientation == "PA":
# Rotates the C-arm about the x-axis by 90 degrees
# Rotates the C-arm about the x-axis by 90 degrees
# Reverses the direction of the y-axis
reorient = torch.tensor(
[
Expand Down Expand Up @@ -180,7 +180,7 @@ def read(

return subject

# %% ../notebooks/api/03_data.ipynb 7
# %% ../notebooks/api/03_data.ipynb #e20ad014-b2ae-41b9-8b65-e4ca865e19a2
from diffdrr.pose import RigidTransform


Expand Down Expand Up @@ -210,7 +210,7 @@ def canonicalize(subject):
subject.fiducials = affine(subject.fiducials)
return subject

# %% ../notebooks/api/03_data.ipynb 8
# %% ../notebooks/api/03_data.ipynb #ba2941e0-cb0d-44c7-9b00-4dad1ced447d
def transform_hu_to_density(volume, bone_attenuation_multiplier):
# volume can be loaded as int16, need to convert to float32 to use float bone_attenuation_multiplier
volume = volume.to(torch.float32)
Expand Down
18 changes: 9 additions & 9 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/02_detector.ipynb.

# %% ../notebooks/api/02_detector.ipynb 3
# %% ../notebooks/api/02_detector.ipynb #b758e4f3
from __future__ import annotations

import torch
from fastcore.basics import patch
from torch.nn.functional import normalize

# %% auto 0
# %% auto #0
__all__ = ['Detector', 'get_focal_length', 'get_principal_point', 'parse_intrinsic_matrix', 'make_intrinsic_matrix']

# %% ../notebooks/api/02_detector.ipynb 5
# %% ../notebooks/api/02_detector.ipynb #529b92a4-2f71-4d40-a25f-03cc4bc3eb6b
from .pose import RigidTransform


Expand Down Expand Up @@ -93,7 +93,7 @@ def intrinsic(self):
"""The 3x3 intrinsic matrix."""
return make_intrinsic_matrix(self).to(self.source)

# %% ../notebooks/api/02_detector.ipynb 6
# %% ../notebooks/api/02_detector.ipynb #b8ad63f4-0e38-4ea2-87b0-f298639dc9a9
@patch
def _initialize_carm(self: Detector):
"""Initialize the default position for the source and detector plane."""
Expand Down Expand Up @@ -137,7 +137,7 @@ def _initialize_carm(self: Detector):
self.subsamples.append(sample.tolist())
return source, target

# %% ../notebooks/api/02_detector.ipynb 7
# %% ../notebooks/api/02_detector.ipynb #063d06c3-2618-4282-accd-8fe0ab4d3faa
from .pose import RigidTransform


Expand All @@ -153,7 +153,7 @@ def forward(self: Detector, extrinsic: RigidTransform, calibration: RigidTransfo
target = pose(target)
return source, target

# %% ../notebooks/api/02_detector.ipynb 9
# %% ../notebooks/api/02_detector.ipynb #4c0f02f6-c27e-4bdc-a204-31ba5c9f73de
def get_focal_length(
intrinsic, # Intrinsic matrix (3 x 3 tensor)
delx: float, # X-direction spacing (in units length)
Expand All @@ -163,7 +163,7 @@ def get_focal_length(
fy = intrinsic[1, 1]
return abs((fx * delx) + (fy * dely)).item() / 2.0

# %% ../notebooks/api/02_detector.ipynb 10
# %% ../notebooks/api/02_detector.ipynb #a3535bdf-b819-4c42-9624-00d101b29ded
def get_principal_point(
intrinsic, # Intrinsic matrix (3 x 3 tensor)
height: int, # Y-direction length (in units pixels)
Expand All @@ -175,7 +175,7 @@ def get_principal_point(
y0 = dely * (intrinsic[1, 2] - height / 2)
return x0.item(), y0.item()

# %% ../notebooks/api/02_detector.ipynb 11
# %% ../notebooks/api/02_detector.ipynb #750cb0fe-c96a-4c76-a2cd-51a74fdc6b05
def parse_intrinsic_matrix(
intrinsic, # Intrinsic matrix (3 x 3 tensor)
height: int, # Y-direction length (in units pixels)
Expand All @@ -187,7 +187,7 @@ def parse_intrinsic_matrix(
x0, y0 = get_principal_point(intrinsic, height, width, delx, dely)
return focal_length, x0, y0

# %% ../notebooks/api/02_detector.ipynb 12
# %% ../notebooks/api/02_detector.ipynb #4e9f01cb-1dbc-4818-8521-e6785c101a82
def make_intrinsic_matrix(detector: Detector):
fx = detector.sdd / detector.delx
fy = detector.sdd / detector.dely
Expand Down
25 changes: 10 additions & 15 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/00_drr.ipynb.

# %% ../notebooks/api/00_drr.ipynb 3
# %% ../notebooks/api/00_drr.ipynb #ce95e1ac-413e-407b-87ad-0c7db3a945de
from __future__ import annotations

import numpy as np
Expand All @@ -11,10 +11,10 @@
from .detector import Detector
from .renderers import Siddon, Trilinear

# %% auto 0
# %% auto #0
__all__ = ['DRR']

# %% ../notebooks/api/00_drr.ipynb 7
# %% ../notebooks/api/00_drr.ipynb #97297d06-6772-4dc7-8af5-d1ea7b379d8d
from torchio import Subject

from .pose import RigidTransform
Expand Down Expand Up @@ -138,15 +138,15 @@ def device(self):
def dtype(self):
return self.density.dtype

# %% ../notebooks/api/00_drr.ipynb 8
# %% ../notebooks/api/00_drr.ipynb #6513c593-32b8-4676-83d4-e9e7dcf0630a
def reshape_subsampled_drr(img: torch.Tensor, detector: Detector, batch_size: int):
n_points = detector.height * detector.width
drr = torch.zeros(batch_size, n_points).to(img)
drr[:, detector.subsamples[-1]] = img
drr = drr.view(batch_size, 1, detector.height, detector.width)
return drr

# %% ../notebooks/api/00_drr.ipynb 10
# %% ../notebooks/api/00_drr.ipynb #27b19dfc-6a15-4896-9faa-20faee84dc1f
from torch.utils.checkpoint import checkpoint

from .pose import RigidTransform, convert
Expand All @@ -168,12 +168,7 @@ def forward(
if parameterization is None:
pose = args[0]
else:
pose = convert(
*args,
parameterization=parameterization,
convention=convention,
degrees=degrees,
)
pose = convert(*args, parameterization=parameterization, convention=convention, degrees=degrees)

# Create the source / target points and render the image
source, target = self.detector(pose, calibration)
Expand Down Expand Up @@ -231,7 +226,7 @@ def render(

return img

# %% ../notebooks/api/00_drr.ipynb 11
# %% ../notebooks/api/00_drr.ipynb #d17edb1b-d3b4-4f31-b110-f5811ec6c183
@patch
def set_intrinsics_(
self: DRR,
Expand Down Expand Up @@ -259,7 +254,7 @@ def set_intrinsics_(
reverse_x_axis if reverse_x_axis is not None else self.detector.reverse_x_axis,
).to(self.density)

# %% ../notebooks/api/00_drr.ipynb 12
# %% ../notebooks/api/00_drr.ipynb #5d96d0d0-d76f-406e-8bc8-9023cf0f3e01
@patch
def rescale_detector_(self: DRR, scale: float):
"""Rescale the detector plane (inplace)."""
Expand All @@ -270,7 +265,7 @@ def rescale_detector_(self: DRR, scale: float):
dely=float(self.detector.dely / scale),
)

# %% ../notebooks/api/00_drr.ipynb 13
# %% ../notebooks/api/00_drr.ipynb #93a94ef3-5449-45dc-aa62-9fcf6fad643d
@patch
def perspective_projection(
self: DRR,
Expand All @@ -294,7 +289,7 @@ def perspective_projection(

return x[..., :2]

# %% ../notebooks/api/00_drr.ipynb 14
# %% ../notebooks/api/00_drr.ipynb #802ba874-eef8-4524-be5c-bd250e5639d7
from torch.nn.functional import pad


Expand Down
26 changes: 12 additions & 14 deletions diffdrr/metrics.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/05_metrics.ipynb.

# %% ../notebooks/api/05_metrics.ipynb 3
# %% ../notebooks/api/05_metrics.ipynb #86ff7dae
from __future__ import annotations

import torch

# %% auto 0
# %% auto #0
__all__ = ['NormalizedCrossCorrelation2d', 'MultiscaleNormalizedCrossCorrelation2d', 'GradientNormalizedCrossCorrelation2d',
'MutualInformation', 'LogGeodesicSE3', 'DoubleGeodesicSE3']

# %% ../notebooks/api/05_metrics.ipynb 6
# %% ../notebooks/api/05_metrics.ipynb #a77b3608-8d2a-43b6-b902-9f905877dd40
from einops import rearrange


def to_patches(x, patch_size):
x = x.unfold(2, patch_size, step=1).unfold(3, patch_size, step=1).contiguous()
return rearrange(x, "b c p1 p2 h w -> b (c p1 p2) h w")

# %% ../notebooks/api/05_metrics.ipynb 7
# %% ../notebooks/api/05_metrics.ipynb #28930479-d8e6-4859-b5de-38a5350f510b
class NormalizedCrossCorrelation2d(torch.nn.Module):
"""Compute Normalized Cross Correlation between two batches of images."""

Expand All @@ -43,7 +43,7 @@ def norm(self, x):
std = var.sqrt()
return (x - mu) / std

# %% ../notebooks/api/05_metrics.ipynb 8
# %% ../notebooks/api/05_metrics.ipynb #8d06c00d-830c-48d9-b394-07cc83c1ed2b
class MultiscaleNormalizedCrossCorrelation2d(torch.nn.Module):
"""Compute Normalized Cross Correlation between two batches of images at multiple scales."""

Expand All @@ -62,7 +62,7 @@ def forward(self, x1, x2):
scores.append(weight * ncc(x1, x2))
return torch.stack(scores, dim=0).sum(dim=0)

# %% ../notebooks/api/05_metrics.ipynb 9
# %% ../notebooks/api/05_metrics.ipynb #a6317e99-8a0a-4dce-959f-904c21595d71
from torchvision.transforms.functional import gaussian_blur


Expand Down Expand Up @@ -92,7 +92,7 @@ def forward(self, img):
x = self.filter(x)
return x

# %% ../notebooks/api/05_metrics.ipynb 10
# %% ../notebooks/api/05_metrics.ipynb #fc39dd1d-ab40-4f7b-926d-dff305b9ab69
class GradientNormalizedCrossCorrelation2d(NormalizedCrossCorrelation2d):
"""Compute Normalized Cross Correlation between the image gradients of two batches of images."""

Expand All @@ -103,7 +103,7 @@ def __init__(self, patch_size=None, sigma=1.0, **kwargs):
def forward(self, x1, x2):
return super().forward(self.sobel(x1), self.sobel(x2))

# %% ../notebooks/api/05_metrics.ipynb 11
# %% ../notebooks/api/05_metrics.ipynb #6e422fc1-8100-4120-b226-f6f54602fe3c
from kornia.enhance.histogram import marginal_pdf, joint_pdf


Expand All @@ -118,7 +118,7 @@ def __init__(self, sigma=0.1, num_bins=256, epsilon=1e-10, normalize=True):
self.normalize = normalize

def forward(self, x1, x2):
assert x1.shape == x2.shape
assert(x1.shape == x2.shape)
B, C, H, W = x1.shape

x1 = x1.view(B, H * W, C)
Expand All @@ -138,7 +138,7 @@ def forward(self, x1, x2):

return mutual_information

# %% ../notebooks/api/05_metrics.ipynb 15
# %% ../notebooks/api/05_metrics.ipynb #b691875b-c136-4ea5-8551-fab45530e315
from .pose import RigidTransform, convert


Expand All @@ -157,7 +157,7 @@ def forward(
) -> Float[torch.Tensor, "b"]:
return pose_2.compose(pose_1.inverse()).get_se3_log().norm(dim=1)

# %% ../notebooks/api/05_metrics.ipynb 18
# %% ../notebooks/api/05_metrics.ipynb #5ac6d6f4-462b-47ea-b25b-8a2518749e6f
from .pose import so3_log_map


Expand All @@ -175,9 +175,7 @@ def __init__(
self.sdr = sdd / 2
self.eps = eps

self.rot_geo = lambda r1, r2: self.sdr * so3_log_map(
r1.transpose(-1, -2) @ r2
).norm(dim=-1)
self.rot_geo = lambda r1, r2: self.sdr * so3_log_map(r1.transpose(-1, -2) @ r2).norm(dim=-1)
self.xyz_geo = lambda t1, t2: (t1 - t2).norm(dim=-1)

def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):
Expand Down
Loading
Loading