Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/xvr/register/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,16 @@ class RegisterBase(ABC):
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
RegisterBase._registry[cls.__name__.replace("Register", "").lower()] = cls

base_sig = signature(RegisterBase.__init__)
cls.__init__.__signature__ = _merge_signatures(signature(cls.__init__), base_sig)

if "__call__" in cls.__dict__:
cls.__call__.__signature__ = _merge_signatures(
signature(cls.__call__),
signature(RegisterBase.__call__),
)

@classmethod
def create(cls, method: str, **kwargs):
if method not in cls._registry:
Expand Down
26 changes: 25 additions & 1 deletion src/xvr/register/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..model.inference import predict_pose
from ..model.modules.network import load_model
from .base import RegisterBase
from .logging import RegistrationResult


class RegisterFixed(RegisterBase):
Expand All @@ -15,6 +16,29 @@ class RegisterFixed(RegisterBase):
when initializing from a prior scan or a clinical estimate.
"""

def __call__(
self,
filename: str,
rot: tuple[float, float, float],
xyz: tuple[float, float, float],
orientation: str = "AP",
isocenter: bool = True,
**kwargs,
) -> RegistrationResult:
"""Run registration with a manually specified initial pose.

Args:
filename: Path to the X-ray image.
rot: Rotation angles (in degrees) as (rx, ry, rz).
xyz: Translation (in mm) as (x, y, z).
orientation: Starting orientation of the volume, e.g. "AP" or "lateral".
isocenter: If True, centers the pose at the subject isocenter.
**kwargs: See RegisterBase.__call__ for remaining arguments.
"""
return super().__call__(
filename, rot=rot, xyz=xyz, orientation=orientation, isocenter=isocenter, **kwargs
)

def get_initial_pose_estimate(
self,
_img,
Expand Down Expand Up @@ -68,7 +92,7 @@ def get_initial_pose_estimate(
"""Predict the initial pose from the X-ray using a neural network.

Args:
gt: Preprocessed ground truth X-ray image.
img: Preprocessed ground truth X-ray image.
intrinsics: Camera intrinsics for the X-ray.

Returns:
Expand Down