diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d29073c..f0f179c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v4 diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 0000000..f9672bc --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,17 @@ +version: 2 + +build: + os: ubuntu-24.04 + tools: + python: "3.12" + jobs: + # Install only the docs toolchain. The SPINEPS package itself is intentionally NOT installed: + # its heavy runtime dependencies (torch, nnunetv2, antspyx, monai, ...) can exceed Read the Docs + # build time/memory limits. mkdocstrings' griffe backend reads the package source statically + # (see the `paths` option in mkdocs.yml), so those dependencies are not needed to render the docs. + post_install: + # ruff is only used by mkdocstrings to format rendered signatures (lightweight, no build deps). + - pip install "mkdocs>=1.6" "mkdocs-material>=9.5" "mkdocstrings[python]>=0.25" ruff + +mkdocs: + configuration: mkdocs.yml diff --git a/README.md b/README.md index 28eb0b6..02d092f 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ [![codecov](https://codecov.io/gh/Hendrik-code/spineps/graph/badge.svg?token=A7FWUKO9Y4)](https://codecov.io/gh/Hendrik-code/spineps) [![tests](https://github.com/Hendrik-code/spineps/actions/workflows/tests.yml/badge.svg)](https://github.com/Hendrik-code/spineps/actions/workflows/tests.yml) [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) +[![Documentation Status](https://readthedocs.org/projects/spineps/badge/?version=latest)](https://spineps.readthedocs.io) # SPINEPS – Automatic Whole Spine Segmentation of T2w MR images using a Two-Phase Approach to Multi-class Semantic and Instance Segmentation. # and @@ -12,8 +13,26 @@ This is a segmentation pipeline to automatically, and robustly, segment the whole spine in T2w sagittal images. +## NOW SUPPORTS BOTH CT AND T2W! +There is a new release that finally supports both CT and T2W with completely independent, modality specific models. We are already working on completely modality/sequence robust version that works on everything. Stay tuned for that. + + ![pipeline_process](spineps/example/figures/pipeline_processflow.png?raw=true) +## Documentation + +πŸ“– **Online documentation: [spineps.readthedocs.io](https://spineps.readthedocs.io)** + +The documentation source lives in the [`docs/`](docs/) folder and is built with [MkDocs](https://www.mkdocs.org/) +(Material theme + [mkdocstrings](https://mkdocstrings.github.io/)). To build and preview it locally: + +```bash +pip install mkdocs mkdocs-material "mkdocstrings[python]" +mkdocs serve # then open http://127.0.0.1:8000 +``` + +Start with [`docs/index.md`](docs/index.md) and the [Getting Started](docs/getting-started.md) guide. + ## Citation If you are using SPINEPS, please cite the following: diff --git a/docs/api/architectures.md b/docs/api/architectures.md new file mode 100644 index 0000000..f6cf9f4 --- /dev/null +++ b/docs/api/architectures.md @@ -0,0 +1,19 @@ +# Architectures + +Network architectures and the vertebra label definitions used by the models. + +## spineps.architectures.read_labels + +::: spineps.architectures.read_labels + +## spineps.architectures.pl_densenet + +::: spineps.architectures.pl_densenet + +## spineps.architectures.pl_unet + +::: spineps.architectures.pl_unet + +## spineps.architectures.unet3D + +::: spineps.architectures.unet3D diff --git a/docs/api/enums.md b/docs/api/enums.md new file mode 100644 index 0000000..a009dbc --- /dev/null +++ b/docs/api/enums.md @@ -0,0 +1,11 @@ +# Enums & Config + +Enumerations and the inference-configuration model used throughout SPINEPS. + +## spineps.seg_enums + +::: spineps.seg_enums + +## spineps.utils.seg_modelconfig + +::: spineps.utils.seg_modelconfig diff --git a/docs/api/models.md b/docs/api/models.md new file mode 100644 index 0000000..9096088 --- /dev/null +++ b/docs/api/models.md @@ -0,0 +1,15 @@ +# Models + +Model discovery/loading and the segmentation/labeling model classes. + +## spineps.get_models + +::: spineps.get_models + +## spineps.seg_model + +::: spineps.seg_model + +## spineps.lab_model + +::: spineps.lab_model diff --git a/docs/api/phases.md b/docs/api/phases.md new file mode 100644 index 0000000..4fe7a62 --- /dev/null +++ b/docs/api/phases.md @@ -0,0 +1,23 @@ +# Processing Phases + +The per-phase processing functions that make up the pipeline. + +## spineps.phase_pre + +::: spineps.phase_pre + +## spineps.phase_semantic + +::: spineps.phase_semantic + +## spineps.phase_instance + +::: spineps.phase_instance + +## spineps.phase_labeling + +::: spineps.phase_labeling + +## spineps.phase_post + +::: spineps.phase_post diff --git a/docs/api/pipeline.md b/docs/api/pipeline.md new file mode 100644 index 0000000..e6832fa --- /dev/null +++ b/docs/api/pipeline.md @@ -0,0 +1,15 @@ +# Pipeline & Run + +Top-level orchestration: process a single image or a whole dataset, and shared pipeline helpers. + +## spineps.seg_run + +::: spineps.seg_run + +## spineps.seg_pipeline + +::: spineps.seg_pipeline + +## spineps.seg_utils + +::: spineps.seg_utils diff --git a/docs/api/utils.md b/docs/api/utils.md new file mode 100644 index 0000000..866b5de --- /dev/null +++ b/docs/api/utils.md @@ -0,0 +1,31 @@ +# Utilities + +Image processing, the vertebra-labeling path solver, disc labeling and other helpers. + +## spineps.utils.proc_functions + +::: spineps.utils.proc_functions + +## spineps.utils.find_min_cost_path + +::: spineps.utils.find_min_cost_path + +## spineps.utils.generate_disc_labels + +::: spineps.utils.generate_disc_labels + +## spineps.utils.filepaths + +::: spineps.utils.filepaths + +## spineps.utils.auto_download + +::: spineps.utils.auto_download + +## spineps.utils.citation_reminder + +::: spineps.utils.citation_reminder + +## spineps.utils.compat + +::: spineps.utils.compat diff --git a/docs/getting-started.md b/docs/getting-started.md new file mode 100644 index 0000000..c65c133 --- /dev/null +++ b/docs/getting-started.md @@ -0,0 +1,128 @@ +# Getting Started + +This guide walks you through installing SPINEPS and running your first segmentation. + +## Installation (Ubuntu) + +This installation assumes you are comfortable with conda and virtual environments. **The order of the +following steps matters.** + +### 1. Create a virtual environment + +```bash +conda create --name spineps python=3.11 +conda activate spineps +conda install pip +``` + +### 2. Install PyTorch + +Go to [pytorch.org/get-started/locally](https://pytorch.org/get-started/locally/) and install a PyTorch +build that matches your machine. Then confirm the install works: + +```bash +nvidia-smi # should show your GPU +python -c "import torch; print(torch.cuda.is_available())" # should print True +``` + +### 3. Install SPINEPS + +From PyPI: + +```bash +pip install spineps +``` + +Or, for local development, clone the repository, `cd` into it and run: + +```bash +pip install -e . +``` + +## Model weights + +SPINEPS automatically downloads the latest model weights on first use, so no manual setup is required. + +If you prefer to manage weights manually, download them from the GitHub release page and extract them into +a models folder with the following structure: + +```text + +β”œβ”€β”€ +β”‚ β”œβ”€β”€ inference_config.json +β”‚ └── +β”œβ”€β”€ +β”‚ β”œβ”€β”€ inference_config.json +β”‚ └── ... +``` + +Point SPINEPS at that folder via an environment variable (otherwise it defaults to `spineps/spineps/models/`): + +```bash +export SPINEPS_SEGMENTOR_MODELS= +echo ${SPINEPS_SEGMENTOR_MODELS} # verify it is set +``` + +## Usage + +### Command line + +```bash +spineps -h # top-level help +spineps sample -h # options for a single file +spineps dataset -h # options for a whole dataset +``` + +Segment a single scan: + +```bash +# T2w sagittal +spineps sample -ignore_bids_filter -ignore_inference_compatibility \ + -i /path/sub-testsample_T2w.nii.gz -model_semantic t2w -model_instance instance + +# T1w sagittal +spineps sample -ignore_bids_filter -ignore_inference_compatibility \ + -i /path/sub-testsample_T1w.nii.gz -model_semantic t1w -model_instance instance +``` + +Process a whole [BIDS](https://bids-specification.readthedocs.io/en/stable/) dataset: + +```bash +spineps dataset -i /path/to/dataset -model_semantic t2w -model_instance instance +``` + +### Adding vertebra labels (VERIDAH) + +To assign anatomical vertebra labels after segmentation, additionally pass a labeling model: + +```bash +spineps sample -i /path/sub-test_T2w.nii.gz \ + -model_semantic t2w -model_instance instance -model_labeling labeling +``` + +### From Python + +```python +from TPTBox import BIDS_FILE +from spineps import get_semantic_model, get_instance_model, process_img_nii + +semantic = get_semantic_model("t2w") +instance = get_instance_model("instance") + +# img_ref is a TPTBox BIDS_FILE pointing at the input scan +img_ref = BIDS_FILE("sub-test_T2w.nii.gz", dataset="/path/to/dataset") + +process_img_nii( + img_ref, + model_semantic=semantic, + model_instance=instance, + derivative_name="derivatives_seg", +) +``` + +See the [Pipeline](modules/pipeline.md) page for more detail on the Python entry points. + +## Troubleshooting + +- **Import issues**: re-run the install; sometimes not every dependency installs the first time. +- **PyTorch / CUDA issues**: make sure the PyTorch build matches your CUDA version (step 2 above). diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..858c541 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,63 @@ +# SPINEPS + +**SPINEPS** is a framework for out-of-the-box **whole-spine segmentation of MR images**. It segments the +spine in sagittal MR images (T2w, T1w and others) using a two-phase approach to multi-class **semantic** +and **instance** segmentation, and can additionally assign anatomical **vertebra labels** via the +**VERIDAH** labeling model. + +[![Paper](https://img.shields.io/badge/Paper-10.1007-blue)](https://link.springer.com/article/10.1007/s00330-024-11155-y) +[![PyPI version](https://badge.fury.io/py/spineps.svg)](https://pypi.python.org/pypi/spineps/) +[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) + +![Pipeline process flow](https://github.com/Hendrik-code/spineps/raw/main/spineps/example/figures/pipeline_processflow.png) + +## What it does + +Given a sagittal MR scan, the pipeline: + +1. **Semantically** segments 14 spinal structures (nine vertebra subregions, spinal cord, spinal canal, + intervertebral discs, endplate and sacrum). +2. Derives a per-vertebra **instance** mask from the vertebra subregions. +3. Optionally assigns each instance an anatomical **vertebra label** (C1–L6, sacrum) with VERIDAH. +4. Computes **centroids** (points of interest) for each vertebra, endplate and disc. +5. Renders a **snapshot** visualizing the result. + +## Quick links + +- [Getting Started](getting-started.md) β€” installation and first run. +- [Pipeline](modules/pipeline.md) β€” how the two-phase pipeline is structured. +- [Processing Phases](modules/phases.md) β€” pre-processing, semantic, instance, labeling and post-processing. +- [Models & Labeling](modules/models.md) β€” model loading and the VERIDAH labeling model. +- [API Reference](api/pipeline.md) β€” full auto-generated API documentation. + +## Quick start + +```bash +# Install +pip install spineps + +# Segment a single T2w sagittal scan +spineps sample -i /path/sub-test_T2w.nii.gz -model_semantic t2w -model_instance instance +``` + +See [Getting Started](getting-started.md) for the full installation guide (including PyTorch setup and +model weights), and the [Pipeline](modules/pipeline.md) page for calling SPINEPS from Python. + +## Citation + +If you use SPINEPS, please cite: + +```bibtex +@article{moller_spinepsautomatic_2024, + title = {{SPINEPS}β€”automatic whole spine segmentation of T2-weighted {MR} images using a two-phase approach to multi-class semantic and instance segmentation}, + doi = {10.1007/s00330-024-11155-y}, + journal = {European Radiology}, + author = {MΓΆller, Hendrik and Graf, Robert and Schmitt, Joachim and Keinert, Benjamin and SchΓΆn, Hanna and Atad, Matan and Sekuboyina, Anjany and Streckenbach, Felix and Kofler, Florian and Kroencke, Thomas and Bette, Stefanie and Willich, Stefan N. and Keil, Thomas and Niendorf, Thoralf and Pischon, Tobias and Endemann, Beate and Menze, Bjoern and Rueckert, Daniel and Kirschke, Jan S.}, + date = {2024-10-29}, +} +``` + +## License + +SPINEPS is released under the [Apache License 2.0](https://opensource.org/licenses/Apache-2.0). +Copyright 2023 Hendrik MΓΆller. diff --git a/docs/modules/models.md b/docs/modules/models.md new file mode 100644 index 0000000..7916769 --- /dev/null +++ b/docs/modules/models.md @@ -0,0 +1,46 @@ +# Models & Labeling + +## Loading models + +Models are referenced by name and resolved against the configured models directory (or downloaded +automatically). The convenience loaders live in [`spineps.get_models`](../api/models.md): + +```python +from spineps import get_semantic_model, get_instance_model, get_labeling_model + +semantic = get_semantic_model("t2w") +instance = get_instance_model("instance") +labeling = get_labeling_model("labeling") # optional, VERIDAH +``` + +Each loader looks the name up in a model-id-to-folder map, resolves remote (HTTP) entries by downloading +the weights if needed, reads the model's `inference_config.json` and instantiates the right model class. + +## Model types + +The concrete model classes are defined in [`spineps.seg_model`](../api/models.md) and +[`spineps.lab_model`](../api/models.md): + +- **`Segmentation_Model`** β€” abstract base wrapping a network plus its inference configuration. It handles + input preparation (reorientation, rescaling to the recommended zoom, padding), running the model and + mapping outputs back into the input space. +- **`Segmentation_Model_NNunet`** β€” an nnU-Net backend. +- **`Segmentation_Model_Unet3D`** β€” a 3D U-Net backend. +- **`VertLabelingClassifier`** β€” the VERIDAH vertebra-labeling classifier. + +The model type is selected from the `inference_config.json` via `ModelType` (see +[Enums & Config](../api/enums.md)). + +## Inference configuration + +Every model ships an `inference_config.json` describing its expected inputs, modality and acquisition, +recommended resolution range, label mapping and processing thresholds. It is parsed into a +`Segmentation_Inference_Config` (see [`spineps.utils.seg_modelconfig`](../api/enums.md)). + +## VERIDAH labeling + +VERIDAH ("solving Enumeration Anomaly Aware Vertebra Labeling across Imaging Sequences") assigns +anatomical labels to the detected vertebra instances. It combines a per-vertebra classifier with a +min-cost path solver that enforces a plausible ordering and is aware of enumeration anomalies such as a +13th thoracic vertebra (T13) or a sixth lumbar vertebra (L6). Enable it by passing a labeling model to +the pipeline (`-model_labeling` on the CLI, or `model_labeling=` in Python). diff --git a/docs/modules/phases.md b/docs/modules/phases.md new file mode 100644 index 0000000..2bc5fa7 --- /dev/null +++ b/docs/modules/phases.md @@ -0,0 +1,69 @@ +# Processing Phases + +Each scan flows through the following phases. The API for every phase is documented under +[Processing Phases (API)](../api/phases.md). + +## Pre-processing β€” [`spineps.phase_pre`](../api/phases.md) + +Prepares the input image: intensity normalization into a fixed range, optional N4 bias-field +correction, cropping to the non-zero region (optionally auto-cropping to the spine using VIBESeg), and +edge padding. + +## Semantic phase β€” [`spineps.phase_semantic`](../api/phases.md) + +Runs the semantic model to produce the subregion mask, then post-processes it: removing small +connected-component artifacts, optionally restricting to structures near the spinal canal, keeping only +components inside the largest spine bounding box, and filling 3D holes. + +The semantic mask uses these labels: + +| Label | Structure | +| :---: | ------------------------ | +| 41 | Arcus_Vertebrae | +| 42 | Spinosus_Process | +| 43 | Costal_Process_Left | +| 44 | Costal_Process_Right | +| 45 | Superior_Articular_Left | +| 46 | Superior_Articular_Right | +| 47 | Inferior_Articular_Left | +| 48 | Inferior_Articular_Right | +| 49 | Vertebra_Corpus_border | +| 60 | Spinal_Cord | +| 61 | Spinal_Canal | +| 62 | Endplate | +| 100 | Vertebra_Disc | +| 26 | Sacrum | + +## Instance phase β€” [`spineps.phase_instance`](../api/phases.md) + +Derives a per-vertebra instance mask from the vertebra subregions. For each corpus center of mass it +extracts a cutout, runs the instance model, and merges the overlapping per-vertebra predictions into a +single label map. It also detects and splits merged vertebral bodies. + +In the instance mask, each label `X` in `[1, 25]` is a unique vertebra; `100 + X` is that vertebra's +intervertebral disc and `200 + X` its endplate. + +## Labeling phase (VERIDAH) β€” [`spineps.phase_labeling`](../api/phases.md) + +Optional. Assigns each detected vertebra an anatomical label using a classifier and a min-cost path +solver ([`find_most_probably_sequence`](../api/utils.md)) that enforces a plausible cranio-caudal +ordering and handles transitional-vertebra anomalies (e.g. T13, L6). + +| Label | Structure | +| :-----: | ---------- | +| 1 | C1 | +| 2 – 7 | C2 – C7 | +| 8 – 19 | T1 – T12 | +| 28 | T13 | +| 20 | L1 | +| 21 – 25 | L2 – L6 | +| 26 | Sacrum | + +As in the instance mask, `100 + X` is a vertebra's IVD and `200 + X` its endplate (e.g. label 119 is the +IVD below T12). + +## Post-processing β€” [`spineps.phase_post`](../api/phases.md) + +Combines the semantic and instance masks: cleans them against each other, assigns intervertebral discs +and endplates to their parent vertebra, resolves merged vertebrae, fixes mislabeled posterior elements, +and labels the instances top-to-bottom (or via VERIDAH when a labeling model is supplied). diff --git a/docs/modules/pipeline.md b/docs/modules/pipeline.md new file mode 100644 index 0000000..96477f2 --- /dev/null +++ b/docs/modules/pipeline.md @@ -0,0 +1,69 @@ +# Pipeline + +SPINEPS processes a scan in a sequence of phases orchestrated by the functions in +[`spineps.seg_run`](../api/pipeline.md). There are two top-level entry points: + +- **`process_img_nii`** β€” process a single image. +- **`process_dataset`** β€” discover and process every suitable scan in a [BIDS](https://bids-specification.readthedocs.io/en/stable/) dataset. + +## Two-phase approach + +The core idea is to split spine segmentation into two complementary tasks: + +1. **Semantic phase** β€” a multi-class network labels every voxel with its anatomical *subregion* + (vertebra subregions, spinal cord, spinal canal, discs, endplate, sacrum). +2. **Instance phase** β€” using the vertebra subregions from the semantic mask, individual vertebrae are + separated into a per-vertebra *instance* mask. + +A combined **post-processing** step then cleans both masks against each other, assigns intervertebral +discs and endplates to their parent vertebra, and (optionally) runs the **VERIDAH labeling** model to give +each instance an anatomical vertebra label. + +```text +input scan + β”‚ pre-processing (normalize, optional N4 bias correction, crop, pad) + β–Ό +semantic phase ──► subregion (semantic) mask + β”‚ + β–Ό +instance phase ──► per-vertebra instance mask + β”‚ + β–Ό +post-processing (clean, assign IVD/endplate, optional VERIDAH labeling) + β”‚ + β–Ό +seg-spine mask Β· seg-vert mask Β· centroids (.json) Β· snapshot (.png) +``` + +## Outputs + +For each processed scan SPINEPS writes a derivatives folder next to the input containing: + +- a `seg-spine` mask (semantic / subregion segmentation), +- a `seg-vert` mask (vertebra instance segmentation), +- a centroid file (`.json`) with points of interest for each vertebra, endplate and disc, +- a snapshot `.png` visualizing the result, +- optionally: an uncertainty image, the model-resolution masks, softmax logits and debug data. + +## Calling from Python + +```python +from TPTBox import BIDS_FILE +from spineps import get_semantic_model, get_instance_model, process_img_nii + +semantic = get_semantic_model("t2w") +instance = get_instance_model("instance") + +process_img_nii( + BIDS_FILE("sub-test_T2w.nii.gz", dataset="/path/to/dataset"), + model_semantic=semantic, + model_instance=instance, +) +``` + +`process_img_nii` exposes many `proc_*` flags to toggle individual processing steps (pre-processing, +semantic/instance cleaning, hole filling, labeling, …). See the +[Pipeline & Run API reference](../api/pipeline.md) for the full signature. + +For batch processing, `process_dataset` accepts the same processing flags and applies them to every +matching scan it finds. diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..624070c --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,77 @@ +site_name: SPINEPS +site_description: SPINEPS β€” automatic whole-spine segmentation of T2-weighted MR images via a two-phase semantic and instance segmentation approach. +site_url: https://spineps.readthedocs.io +repo_url: https://github.com/Hendrik-code/spineps +repo_name: Hendrik-code/spineps +edit_uri: blob/main/ + +theme: + name: material + features: + - navigation.tabs + - navigation.sections + - navigation.top + - toc.integrate + - content.code.copy + - content.code.annotate + palette: + - scheme: default + primary: indigo + accent: indigo + toggle: + icon: material/brightness-7 + name: Switch to dark mode + - scheme: slate + primary: indigo + accent: indigo + toggle: + icon: material/brightness-4 + name: Switch to light mode + +plugins: + - search + - mkdocstrings: + handlers: + python: + # Find the package in the source tree so griffe can read it statically without the + # package (and its heavy runtime dependencies) being installed. + paths: ["."] + options: + docstring_style: google + show_source: true + show_root_heading: true + show_symbol_type_heading: true + show_symbol_type_toc: true + members_order: source + separate_signature: true + show_signature_annotations: true + unwrap_annotated: true + +markdown_extensions: + - admonition + - pymdownx.details + - pymdownx.superfences + - pymdownx.highlight: + anchor_linenums: true + - pymdownx.inlinehilite + - pymdownx.snippets: + base_path: ["."] + - attr_list + - md_in_html + - toc: + permalink: true + +nav: + - Home: index.md + - Getting Started: getting-started.md + - Modules: + - Pipeline: modules/pipeline.md + - Processing Phases: modules/phases.md + - Models & Labeling: modules/models.md + - API Reference: + - Pipeline & Run: api/pipeline.md + - Models: api/models.md + - Processing Phases: api/phases.md + - Enums & Config: api/enums.md + - Utilities: api/utils.md + - Architectures: api/architectures.md diff --git a/pyproject.toml b/pyproject.toml index a53cbac..112465f 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,6 +9,7 @@ description = "Framework for out-of-the box whole spine MRI segmentation." authors = ["Hendrik MΓΆller "] repository = "https://github.com/Hendrik-code/spineps" homepage = "https://github.com/Hendrik-code/spineps" +documentation = "https://spineps.readthedocs.io" license = "Apache License Version 2.0, January 2004" readme = "README.md" exclude = ["models", "examples"] @@ -49,6 +50,15 @@ flake8 = ">=4.0.1" tqdm = ">=4.62.3" +[tool.poetry.group.docs] +optional = true + +[tool.poetry.group.docs.dependencies] +mkdocs = ">=1.6" +mkdocs-material = ">=9.5" +mkdocstrings = { extras = ["python"], version = ">=0.25" } + + [tool.poetry-dynamic-versioning] enable = true @@ -85,6 +95,7 @@ exclude = [ "spineps/utils/plans_handler.py", "spineps/utils/predictor.py", "spineps/utils/sliding_window_prediction.py", + "spineps/utils/image.py", # vendored from spinalcordtoolbox ".toml", ] line-length = 140 @@ -136,6 +147,7 @@ select = [ "PERF", "FURB", "RUF", + "RUF059", # preview rule: unused unpacked variable (enabled via explicit-preview-rules below) ] @@ -180,6 +192,11 @@ extend-safe-fixes = ["RUF015", "C419", "C408", "B006"] # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +# Enable preview mode for the linter, but only enforce preview rules that are selected by their +# explicit code (e.g. RUF059 above) instead of pulling in the entire preview rule set. +preview = true +explicit-preview-rules = true + [tool.ruff.lint.mccabe] # Flag errors (`C901`) whenever the complexity level exceeds 5. max-complexity = 20 diff --git a/spineps/__init__.py b/spineps/__init__.py index d97acf6..cb36017 100755 --- a/spineps/__init__.py +++ b/spineps/__init__.py @@ -1,3 +1,5 @@ +"""SPINEPS: spine MRI segmentation package exposing the model loaders, pipeline phases and run entry points.""" + from spineps.entrypoint import entry_point from spineps.get_models import get_instance_model, get_labeling_model, get_semantic_model from spineps.phase_instance import predict_instance_mask diff --git a/spineps/architectures/__init__.py b/spineps/architectures/__init__.py index e69de29..3f8a5ea 100644 --- a/spineps/architectures/__init__.py +++ b/spineps/architectures/__init__.py @@ -0,0 +1 @@ +"""Neural network architectures and label definitions used by SPINEPS for spine segmentation and vertebra labeling.""" diff --git a/spineps/architectures/pl_densenet.py b/spineps/architectures/pl_densenet.py index c219e36..0f3993a 100644 --- a/spineps/architectures/pl_densenet.py +++ b/spineps/architectures/pl_densenet.py @@ -1,3 +1,5 @@ +"""DenseNet/ResNet-based classifier (PLClassifier) and model configuration for vertebra labeling.""" + from __future__ import annotations import os @@ -28,13 +30,24 @@ def resnet2( layers: list[int] | None = None, **kwargs, -): +) -> ResNet: + """Build a very small 2-stage MONAI ResNet variant ("resnet2"). + + Args: + layers (list[int] | None): Number of blocks per stage; defaults to ``[1, 1]``. + **kwargs: Additional keyword arguments forwarded to the MONAI ``_resnet`` factory. + + Returns: + ResNet: The constructed ResNet model. + """ if layers is None: layers = [1, 1] return _resnet("resnet2", ResNetBlock, layers, get_inplanes(), False, False, **kwargs) class MODEL(Enum): + """Selectable backbone architectures (DenseNet and ResNet variants) for the vertebra classifier.""" + DENSENET169 = DenseNet169 DENSENET121 = DenseNet121 RESNET10 = 10 # resnet10 @@ -49,7 +62,20 @@ def __call__( self, opt: ARGS_MODEL, remove_classification_head: bool = True, - ): + ) -> tuple[nn.Module, int]: + """Instantiate the selected backbone network. + + Args: + opt (ARGS_MODEL): Model configuration providing channels, class count and pretraining flag. + remove_classification_head (bool): If True, strip the backbone's final classification layer so it acts as a + feature extractor. + + Returns: + tuple: ``(model, linear_in_features)`` where ``linear_in_features`` is the input feature size of the removed head. + + Raises: + ValueError: If the enum member is neither a DenseNet nor a ResNet variant. + """ if "DENSENET" in self.name: return get_densenet_architecture( self.value, @@ -78,6 +104,8 @@ def __call__( @dataclass class ARGS_MODEL(Class_to_ArgParse): + """Configuration (and argparse schema) for the vertebra labeling classifier, covering backbone, heads and training options.""" + backbone: MODEL = MODEL.DENSENET169.name classification_conv: bool = False classification_linear: bool = True @@ -99,7 +127,22 @@ class ARGS_MODEL(Class_to_ArgParse): class PLClassifier(pl.LightningModule): + """LightningModule that classifies vertebrae using a shared backbone with one classification head per target group. + + The configured backbone (DenseNet/ResNet) acts as a feature extractor, and a separate head is built for each entry in + ``group_2_n_channel`` to produce that group's class logits. + """ + def __init__(self, opt: ARGS_MODEL, group_2_n_channel: dict[str, int]): + """Build the backbone, classification heads and loss/activation modules. + + Args: + opt (ARGS_MODEL): Model configuration; ``opt.num_classes`` must be an int. + group_2_n_channel (dict[str, int]): Mapping from each target group name to its number of output channels. + + Raises: + AssertionError: If ``opt.num_classes`` is not an int. + """ super().__init__() self.opt = opt assert isinstance(opt.num_classes, int), opt.num_classes @@ -125,12 +168,39 @@ def __init__(self, opt: ARGS_MODEL, group_2_n_channel: dict[str, int]): self.mse = nn.MSELoss(reduction="none") self.l2_reg_w = opt.l2_regularization_w - def forward(self, x): + def forward(self, x) -> dict[str, torch.Tensor]: + """Extract features with the backbone and apply every classification head. + + Args: + x (torch.Tensor): Input image batch fed to the backbone. + + Returns: + dict[str, torch.Tensor]: Mapping from each group name to that head's output logits. + """ features = self.net(x) return {k: v(features) for k, v in self.classification_heads.items()} - def build_classification_heads(self, linear_in: int, convolution_first: bool, fully_connected: bool): + def build_classification_heads(self, linear_in: int, convolution_first: bool, fully_connected: bool) -> nn.ModuleDict: + """Build one classification head per target group as a :class:`~torch.nn.ModuleDict`. + + Args: + linear_in (int): Number of input features coming from the backbone. + convolution_first (bool): If True, prepend a 3x3x3 convolution that halves the channels before the linear layers. + fully_connected (bool): If True, insert a hidden linear+ReLU layer (halving channels) before the output layer. + + Returns: + nn.ModuleDict: Mapping from each group name to its head, each ending in a linear layer with that group's class count. + """ + def construct_one_head(output_classes: int): + """Build a single classification head producing ``output_classes`` logits. + + Args: + output_classes (int): Number of output classes for this head. + + Returns: + nn.Sequential: The assembled head modules. + """ modules = [] n_channel = linear_in n_channel_next = linear_in @@ -150,16 +220,33 @@ def construct_one_head(output_classes: int): return nn.ModuleDict({k: construct_one_head(v) for k, v in self.group_2_n_channel.items()}) def __str__(self) -> str: + """Return the model name. + + Returns: + str: The fixed name ``"VertebraLabelingModel"``. + """ return "VertebraLabelingModel" def get_densenet_architecture( - model, + model: object, in_channel: int = 1, out_channel: int = 1, pretrained: bool = True, remove_classification_head: bool = True, -): +) -> tuple[nn.Module, int]: + """Instantiate a 3D MONAI DenseNet and optionally remove its final classification layer. + + Args: + model: A MONAI DenseNet constructor (e.g. ``DenseNet121`` or ``DenseNet169``). + in_channel (int): Number of input channels. + out_channel (int): Number of output channels for the original classification layer. + pretrained (bool): Whether to load pretrained weights. + remove_classification_head (bool): If True, drop the final classification layer to use the model as a feature extractor. + + Returns: + tuple: ``(model, linear_infeatures)`` where ``linear_infeatures`` is the input feature size of the removed head. + """ model = model( spatial_dims=3, in_channels=in_channel, @@ -173,9 +260,19 @@ def get_densenet_architecture( def get_resnet_architecture( - model, + model: object, remove_classification_head: bool = True, -): +) -> tuple[nn.Module, int]: + """Instantiate a 3D MONAI ResNet and optionally remove its fully connected head. + + Args: + model: A MONAI ResNet constructor (e.g. ``resnet18`` or ``resnet50``). + remove_classification_head (bool): If True, set the final fully connected layer to None to use the model as a + feature extractor. + + Returns: + tuple: ``(model, linear_infeatures)`` where ``linear_infeatures`` is the input feature size of the removed head. + """ model = model( spatial_dims=3, n_input_channels=1, diff --git a/spineps/architectures/pl_unet.py b/spineps/architectures/pl_unet.py index 598da4c..7868c0e 100755 --- a/spineps/architectures/pl_unet.py +++ b/spineps/architectures/pl_unet.py @@ -1,3 +1,5 @@ +"""PyTorch Lightning wrapper around the 3D U-Net used for spine segmentation training and inference.""" + from __future__ import annotations from typing import Any @@ -12,7 +14,21 @@ class PLNet(pl.LightningModule): - def __init__(self, opt=None, do2D: bool = False, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + """LightningModule wrapping a :class:`~spineps.architectures.unet3D.Unet3D` for multi-class segmentation. + + Configures a 4-class 3D U-Net with 10 input channels and provides shared loss/metric helpers (Dice scores) and softmax-based + class prediction. + """ + + def __init__(self, opt: Any = None, do2D: bool = False, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + """Build the wrapped U-Net and store training hyperparameters. + + Args: + opt: Options object; ``opt.n_epoch`` sets the number of epochs when provided. + do2D (bool): Whether the model operates in 2D mode (affects only the string representation). + *args (Any): Ignored extra positional arguments. + **kwargs (Any): Ignored extra keyword arguments. + """ super().__init__() self.save_hyperparameters() @@ -45,10 +61,28 @@ def __init__(self, opt=None, do2D: bool = False, *args: Any, **kwargs: Any) -> N self.softmax = nn.Softmax(dim=1) - def forward(self, x): + def forward(self, x) -> torch.Tensor: + """Run the wrapped U-Net on an input batch. + + Args: + x (torch.Tensor): Input tensor of shape ``(B, 10, D, H, W)``. + + Returns: + torch.Tensor: Raw class logits of shape ``(B, 4, D, H, W)``. + """ return self.network(x) def _shared_step(self, target, gt, detach2cpu: bool = False): + """Run the forward pass and compute loss plus predicted class map for a batch. + + Args: + target (torch.Tensor): Input batch fed to the network. + gt (torch.Tensor): Ground-truth class labels. + detach2cpu (bool): If True, detach ``gt``, ``logits`` and ``pred_cls`` and move them to CPU. + + Returns: + tuple: ``(loss, logits, gt, pred_cls)`` where ``pred_cls`` is the argmax over the softmax of the logits. + """ logits = self.forward(target) loss = self.loss(logits, gt) @@ -65,29 +99,67 @@ def _shared_step(self, target, gt, detach2cpu: bool = False): return loss, logits, gt, pred_cls def _shared_metric_step(self, loss, _, gt, pred_cls): + """Compute segmentation metrics (overall, foreground and per-class Dice) for a batch. + + Args: + loss (torch.Tensor): The batch loss to record. + _: Unused logits placeholder. + gt (torch.Tensor): Ground-truth class labels. + pred_cls (torch.Tensor): Predicted class labels. + + Returns: + dict: Metrics with keys ``loss``, ``dice``, ``diceFG`` (Dice ignoring the background class) and ``dice_p_cls``. + """ dice = mF.dice(pred_cls, gt, num_classes=self.n_classes) diceFG = mF.dice(pred_cls, gt, num_classes=self.n_classes, ignore_index=0) dice_p_cls = mF.dice(pred_cls, gt, average=None, num_classes=self.n_classes) return {"loss": loss.detach().cpu(), "dice": dice, "diceFG": diceFG, "dice_p_cls": dice_p_cls} def _shared_metric_append(self, metrics, outputs): + """Append each metric value to the per-key list of accumulated outputs (in place). + + Args: + metrics (dict): Metric name to value mapping for one step. + outputs (dict): Accumulator mapping each metric name to a list of values. + """ for k, v in metrics.items(): if k not in outputs: outputs[k] = [] outputs[k].append(v) def _shared_cat_metrics(self, outputs): + """Aggregate accumulated per-step metrics into mean values. + + Args: + outputs (dict): Mapping of metric name to a list of per-step tensors. + + Returns: + dict: Mean of each metric; ``dice_p_cls`` is averaged along the step dimension to keep per-class values. + """ results = {} for m, v in outputs.items(): stacked = torch.stack(v) results[m] = torch.mean(stacked) if m != "dice_p_cls" else torch.mean(stacked, dim=0) return results - def __str__(self): + def __str__(self) -> str: + """Return a short model name including the spatial mode. + + Returns: + str: ``"Unet_2D"`` or ``"Unet_3D"`` depending on ``do2D``. + """ text = "Unet" dim = "2D" if self.do2D else "3D" return text + "_" + dim def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor: + """Apply softmax along dimension 1 (the channel/class dimension). + + Args: + x (torch.Tensor): Input tensor with classes on dimension 1. + + Returns: + torch.Tensor: Tensor of the same shape with a softmax applied over dimension 1. + """ return torch.softmax(x, 1) diff --git a/spineps/architectures/read_labels.py b/spineps/architectures/read_labels.py index 96bc4c6..b10be91 100644 --- a/spineps/architectures/read_labels.py +++ b/spineps/architectures/read_labels.py @@ -1,6 +1,9 @@ +"""Vertebra label definitions, enums and mappings used for vertebra classification targets.""" + from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Iterator from dataclasses import dataclass, field from enum import Enum, auto @@ -11,12 +14,16 @@ class VertRegion(Enum): + """Spinal region of a vertebra: cervical (HWS), thoracic (BWS) or lumbar (LWS).""" + HWS = 0 BWS = 1 LWS = 2 class VertRel(Enum): + """Relative position of a vertebra at a region boundary (e.g. last cervical, first thoracic).""" + NOTHING = 0 LAST_HWK = 1 # @@ -28,6 +35,8 @@ class VertRel(Enum): class VertExact(Enum): + """Exact vertebra identity from C1 to L5 (T12 absorbs a potential T13), with 24 classes (0-23).""" + C1 = 0 C2 = 1 C3 = 2 @@ -58,6 +67,8 @@ class VertExact(Enum): class VertExactClass(Enum): + """Exact vertebra identity including an explicit T13 and L6, with 26 classes (0-25).""" + C1 = 0 C2 = 1 C3 = 2 @@ -88,11 +99,15 @@ class VertExactClass(Enum): class VertT13(Enum): + """Whether a vertebra is a (rare) supernumerary T13 or a normal vertebra.""" + Normal = 0 T13 = 1 class VertGroup(Enum): + """Coarse vertebra grouping that buckets neighbouring vertebrae into shared classes (12 groups).""" + C12 = 0 C345 = 1 C67 = 2 @@ -109,12 +124,27 @@ class VertGroup(Enum): def vert_label_to_vertrel( vertlabel: int, - last_bwk, - last_lwk, + last_bwk: int | None, + last_lwk: int | None, last_hwk=7, first_bwk=8, first_lwk=20, ) -> VertRel: + """Map a numeric vertebra label to its region-boundary relation. + + Note that ``last_hwk``, ``first_bwk`` and ``first_lwk`` are reset to their fixed defaults (7, 8, 20) inside the function. + + Args: + vertlabel (int): Numeric vertebra label to classify. + last_bwk: Label of the last thoracic vertebra, or None if unknown. + last_lwk: Label of the last lumbar vertebra, or None if unknown. + last_hwk (int): Label of the last cervical vertebra (overwritten to 7). + first_bwk (int): Label of the first thoracic vertebra (overwritten to 8). + first_lwk (int): Label of the first lumbar vertebra (overwritten to 20). + + Returns: + VertRel: The boundary relation of the given label (NOTHING if it is not a boundary vertebra). + """ last_hwk = 7 first_bwk = 8 first_lwk = 20 @@ -134,14 +164,43 @@ def vert_label_to_vertrel( def vert_class_to_region(vert_exact: VertExact) -> VertRegion: + """Map an exact vertebra class to its spinal region. + + Args: + vert_exact (VertExact): Exact vertebra identity. + + Returns: + VertRegion: HWS for C1-C7, BWS for T1-T12 and LWS for the lumbar vertebrae. + """ return VertRegion.HWS if vert_exact.value < 7 else VertRegion.BWS if 7 <= vert_exact.value < 19 else VertRegion.LWS def vert_label_to_class(vertlabel: int) -> VertExact: + """Map a numeric vertebra label to a :class:`VertExact` class. + + Label 28 (T13) is folded into T12; all other labels map to ``vertlabel - 1`` capped at 23 (L5). + + Args: + vertlabel (int): Numeric vertebra label. + + Returns: + VertExact: The corresponding exact vertebra class. + """ return VertExact.T12 if vertlabel == 28 else VertExact(min(23, vertlabel - 1)) def vert_label_to_exactclass(vertlabel: int) -> VertExactClass: + """Map a numeric vertebra label to a :class:`VertExactClass` class. + + Label 28 maps to the explicit T13 class; labels up to 19 map to ``vertlabel - 1`` (capped at 24) and higher labels to + ``vertlabel`` (capped at 25), accounting for the extra T13/L6 slots. + + Args: + vertlabel (int): Numeric vertebra label. + + Returns: + VertExactClass: The corresponding exact vertebra class. + """ return ( VertExactClass.T13 if vertlabel == 28 @@ -188,10 +247,32 @@ def vert_label_to_exactclass(vertlabel: int) -> VertExactClass: def vert_class_to_group(vert_exact: VertExact) -> VertGroup: + """Map an exact vertebra class to its coarse :class:`VertGroup`. + + Args: + vert_exact (VertExact): Exact vertebra identity. + + Returns: + VertGroup: The group that contains the given vertebra. + """ return vert_exact_to_group_dict[vert_exact] def vertgrp_sequence_to_class(vertgrp: list[VertGroup]) -> list[VertExact]: + """Resolve a top-to-bottom sequence of vertebra groups into exact vertebra classes. + + For each group, if every member of the group is present the assignment is trivial; otherwise the neighbouring group before or + after the partial run determines whether the members align from the top or from the bottom of the group. + + Args: + vertgrp (list[VertGroup]): Vertebra groups ordered from top (cranial) to bottom (caudal). + + Returns: + list[VertExact]: Exact vertebra classes for each position in the input sequence. + + Raises: + AssertionError: If a partial group has neighbours on both sides, which cannot be resolved unambiguously. + """ # input must be sorted from top to bottom already! vert_exact_list: list[VertExact] = [None] * len(vertgrp) # type: ignore @@ -221,16 +302,40 @@ def vertgrp_sequence_to_class(vertgrp: list[VertGroup]) -> list[VertExact]: class LabelType(ABC): + """Abstract base for converting one or more columns of a data entry into a model target label.""" + def __init__(self, column_name: str | list[str], *args, **kwargs) -> None: # noqa: ARG002 + """Initialize the label type with the source column name(s). + + Args: + column_name (str | list[str]): Single column name or list of column names to read from each entry; a single + string is wrapped into a one-element list. + """ if not isinstance(column_name, list): column_name = [column_name] self.column_name = column_name - def __call__(self, entry_dict: dict): + def __call__(self, entry_dict: dict) -> object: + """Read the configured columns from ``entry_dict`` and convert them into a label. + + Args: + entry_dict (dict): Mapping of column names to values for a single sample. + + Returns: + The label produced by :meth:`convert_to_label`. + """ entry = self.get_entry(entry_dict) return self.convert_to_label(entry) def get_entry(self, entry: dict) -> str | int | list[str | int]: + """Extract the configured column value(s) from a data entry. + + Args: + entry (dict): Mapping of column names to values. + + Returns: + str | int | list[str | int]: The single value if only one column is configured, otherwise a list of values. + """ entries = [entry[c] for c in self.column_name] if len(entries) == 1: return entries[0] @@ -239,24 +344,45 @@ def get_entry(self, entry: dict) -> str | int | list[str | int]: @property @abstractmethod def number_of_channel(self) -> str | int | list[str | int]: - pass + """Number of output channels (label vector length) produced by this label type.""" @abstractmethod def convert_to_label(self, entry: str): - pass + """Convert an extracted entry value into the label representation for this label type. + + Args: + entry (str): The value extracted from the data entry. + """ class EnumLabelType(LabelType): + """Label type that one-hot encodes an :class:`~enum.Enum` value into a multi-class target vector.""" + def __init__(self, enum: Enum, column_name: str, *args, **kwargs) -> None: # noqa: ARG002 + """Initialize the enum label type. + + Args: + enum (Enum): Enum class whose members define the classes; its length sets the number of channels. + column_name (str): Column name holding the enum value for each entry. + """ super().__init__(column_name) self.enum = enum self.n_channel = len(enum) @property def number_of_channel(self) -> int: + """Number of channels, equal to the number of members in the configured enum.""" return self.n_channel - def convert_to_label(self, entry: Enum): + def convert_to_label(self, entry: Enum) -> list[int]: + """One-hot encode an enum member into a label vector. + + Args: + entry (Enum): Enum member whose ``value`` indexes the hot position. + + Returns: + list[int]: A list of zeros with a single 1 at the index given by ``entry.value``. + """ label = list(np.zeros(self.number_of_channel, dtype=int)) idx = entry.value label[idx] = 1 @@ -264,14 +390,33 @@ def convert_to_label(self, entry: Enum): class BinaryLabelTypeDummy(LabelType): + """Label type for a binary attribute, one-hot encoded into two channels (true/false).""" + def __init__(self, column_name: str | list[str], *args, **kwargs) -> None: + """Initialize the binary label type. + + Args: + column_name (str | list[str]): Column name(s) holding the binary value. + """ super().__init__(column_name, *args, **kwargs) @property def number_of_channel(self) -> int: + """Number of channels, always 2 (true and false).""" return 2 def convert_to_label(self, entry: str | int) -> int: + """Convert a truthy/falsy entry into a two-channel one-hot label. + + Args: + entry (str | int): A value contained in ``TRUE_KEYS`` or ``FALSE_KEYS``. + + Returns: + list[int]: ``[1, 0]`` for true values and ``[0, 1]`` for false values. + + Raises: + AssertionError: If ``entry`` is a list, or is not recognised as a true or false value. + """ assert not isinstance(entry, list), entry if entry in TRUE_KEYS: return [1, 0] @@ -281,6 +426,8 @@ def convert_to_label(self, entry: str | int) -> int: class Target(Enum): + """Available classification targets, each mapping to a (label-type, enum/column, column-name) configuration tuple.""" + REGION = EnumLabelType, VertRegion, "vert_region" # HWS, BWS, LWS VERT = EnumLabelType, VertExact, "vert_exact" # exakt WK VERTEX = EnumLabelType, VertExactClass, "vert_exact2" # exakt WK @@ -296,11 +443,19 @@ class Target(Enum): class Objectives: + """Bundle of classification targets that builds and combines their label vectors for a single data entry.""" + def __init__( self, objectives: list[Target], as_group: bool = True, ) -> None: + """Initialize the objectives and instantiate the label type for each target. + + Args: + objectives (list[Target]): Targets to compute labels for, in order. + as_group (bool): If True, ``__call__`` returns labels grouped per target name; if False, a flat concatenated list. + """ self.__as_group = as_group self.targets: list[Target] = objectives self.__objective_labels: list[LabelType] = [] @@ -316,24 +471,40 @@ def __init__( @property def n_channel_p_group(self): + """List of channel counts, one per target objective.""" return self.__n_channel_p_group @property def n_channel(self): + """Total number of channels across all target objectives.""" return self.__n_channel @property def group_2_n_channel(self) -> dict[str, int]: + """Mapping from each target name to its number of channels.""" return {self.targets[idx].name: self.n_channel_p_group[idx] for idx in range(len(self.targets))} @property def required_dict_keys(self): + """Unique set of data-entry column names required to compute all objectives.""" return self.__required_dict_keys def __call__( self, entry: dict, ) -> list[int]: + """Compute the label(s) for all objectives from a single data entry. + + Args: + entry (dict): Data entry containing at least every key in :attr:`required_dict_keys`. + + Returns: + list[int] | dict | None: A flat concatenated label list when ``as_group`` is False, a per-target-name dict of label + lists when ``as_group`` is True, or None if a label could not be produced (e.g. a NaN binary/pathology value). + + Raises: + AssertionError: If a required key is missing from ``entry``. + """ entry_keys = entry.keys() for r in self.required_dict_keys: assert r in entry_keys, f"required {r} not in entry_keys, got {entry_keys}" @@ -357,7 +528,15 @@ def __call__( return labels if not self.__as_group else {self.targets[idx].name: labels_grouped[idx] for idx in range(len(self.targets))} -def flatten(a: list[str | int | list[str] | list[int]]): +def flatten(a: list[str | int | list[str] | list[int]]) -> Iterator[str | int]: + """Recursively flatten an arbitrarily nested list of strings and integers. + + Args: + a (list[str | int | list[str] | list[int]]): A value or (nested) list of strings and integers. + + Yields: + str | int: Each leaf string or integer in depth-first order. + """ # a = itertools.chain(*a) if isinstance(a, (str, int)): yield a @@ -369,6 +548,8 @@ def flatten(a: list[str | int | list[str] | list[int]]): ### @dataclass class SubjectInfo: + """Per-subject vertebra labelling metadata, including anomalies, the resolved label map and region boundaries.""" + subject_name: int has_anomaly_entry: bool anomaly_entry: dict @@ -385,12 +566,22 @@ class SubjectInfo: @property def has_tea(self) -> bool: + """Whether the subject has a transitional anomaly (a T11 or T13 anomaly entry). + + Returns: + bool | None: True/False based on the T11/T13 anomaly flags, or None if the subject has no anomaly entry. + """ if not self.has_anomaly_entry: return None return self.anomaly_entry["T11"] or self.anomaly_entry["T13"] @property def block(self) -> int: + """Dataset block identifier, taken from the first three digits of the subject name. + + Returns: + int: The integer formed by the first three characters of ``subject_name``. + """ return int(str(self.subject_name)[:3]) @@ -400,7 +591,24 @@ def get_subject_info( anomaly_dict: dict, vert_subfolders_int: list[int], subject_name_int: bool = True, -): +) -> SubjectInfo: + """Build a :class:`SubjectInfo` from a subject's raw vertebra labels and any anomaly overrides. + + Applies anomaly handling (label deletion, removal flags, T11/T13 remapping and explicit label overrides), derives the actual + labels, the expected double-entry labels and the last thoracic/lumbar vertebra labels. + + Args: + subject_name (str | int): Subject identifier. + anomaly_dict (dict): Mapping of subject names to anomaly entries; empty if no anomalies are known. + vert_subfolders_int (list[int]): Raw numeric vertebra labels present for the subject. + subject_name_int (bool): If True, cast ``subject_name`` to int before lookup. + + Returns: + SubjectInfo: The assembled per-subject labelling metadata. + + Raises: + AssertionError: If a ``LabelOverride`` length does not match the number of vertebra labels. + """ if subject_name_int: subject_name = int(subject_name) double_entries = [] @@ -469,6 +677,18 @@ def get_subject_info( def get_vert_entry(v: int, subject_info: SubjectInfo) -> tuple[int, dict]: + """Build the per-vertebra target entry dict for a single vertebra label. + + Applies the subject's label map to ``v`` and fills in all derived targets (relative position, exact class, exact2 class, + group, region and T13 flag). + + Args: + v (int): Raw numeric vertebra label. + subject_info (SubjectInfo): Subject metadata providing the label map and region boundaries. + + Returns: + tuple[int, dict]: The remapped actual label and a dict of target values keyed by their column names. + """ entry: dict = {} v_actual = subject_info.labelmap.get(v, v) diff --git a/spineps/architectures/unet3D.py b/spineps/architectures/unet3D.py index 41a096f..203f90a 100755 --- a/spineps/architectures/unet3D.py +++ b/spineps/architectures/unet3D.py @@ -1,3 +1,5 @@ +"""3D U-Net architecture with residual blocks used for volumetric spine segmentation.""" + from __future__ import annotations from functools import partial @@ -9,6 +11,13 @@ class Unet3D(nn.Module): + """A 3D U-Net with residual (ResNet) blocks, a symmetric encoder/decoder and skip connections. + + The encoder repeatedly applies two residual blocks followed by a strided convolution that halves each spatial dimension; a + bottleneck of two residual blocks follows; the decoder mirrors the encoder with transposed convolutions and averages in the + matching encoder skip features before a final residual block and 1x1x1 output convolution. + """ + def __init__( self, dim, @@ -21,6 +30,19 @@ def __init__( learned_variance=False, conditional_label_size=0, ): + """Build the 3D U-Net layers. + + Args: + dim (int): Base feature dimension used to derive per-level channel counts. + init_dim (int | None): Channels after the initial convolution; defaults to ``dim``. + out_dim (int | None): Number of output channels; defaults to ``channels`` (doubled if ``learned_variance``). + dim_mults (tuple[int, ...]): Per-resolution multipliers of ``dim`` defining encoder/decoder depth and widths. + channels (int): Number of input image channels. + conditional_dimensions (int): Extra conditioning channels concatenated to the input at the first convolution. + resnet_block_groups (int): Number of groups for the GroupNorm inside each residual block. + learned_variance (bool): If True, doubles the default output channels to also predict variance. + conditional_label_size (int): Size of an optional conditional label vector (stored but unused in ``forward``). + """ super().__init__() self.learned_variance = learned_variance @@ -80,7 +102,28 @@ def __init__( self.final_conv = nn.Conv3d(dim, self.out_dim, 1) self.first_forward = False - def forward(self, x, time=None, label=None, embedding=None) -> torch.Tensor: # time # noqa: ARG002 + def forward( + self, + x, + time: torch.Tensor | None = None, + label: torch.Tensor | None = None, # noqa: ARG002 + embedding: torch.Tensor | None = None, # noqa: ARG002 + ) -> torch.Tensor: # time + """Run the U-Net forward pass on a 5D input volume. + + Args: + x (torch.Tensor): Input tensor of shape ``(B, channels, D, H, W)``; each spatial dimension must be divisible by + ``2 ** (num_downsampling_levels)``. + time: Unused timestep input; replaced by a constant if None. + label: Unused optional conditioning label. + embedding: Unused optional conditioning embedding. + + Returns: + torch.Tensor: Output tensor of shape ``(B, out_dim, D, H, W)`` with the same spatial size as the input. + + Raises: + AssertionError: If any spatial dimension of ``x`` is not divisible by the total downsampling factor. + """ down_factor = 2 ** (len(self.downs) - 1) shape = x.shape assert shape[-1] % down_factor == 0, f"dimensions are not dividable by {down_factor}, {shape}, {shape[-1]}" @@ -147,13 +190,32 @@ def forward(self, x, time=None, label=None, embedding=None) -> torch.Tensor: # class Block3D(nn.Module): + """Basic 3D conv block: 3x3x3 convolution, group normalization, optional FiLM-style scale/shift and LeakyReLU.""" + def __init__(self, dim, dim_out, groups=8): + """Build the conv block. + + Args: + dim (int): Number of input channels. + dim_out (int): Number of output channels. + groups (int): Number of groups for the GroupNorm. + """ super().__init__() self.proj = nn.Conv3d(dim, dim_out, 3, padding=1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.LeakyReLU() - def forward(self, x, scale_shift=None): + def forward(self, x, scale_shift=None) -> torch.Tensor: + """Apply convolution, normalization, optional scale/shift modulation and activation. + + Args: + x (torch.Tensor): Input tensor of shape ``(B, dim, D, H, W)``. + scale_shift (tuple[torch.Tensor, torch.Tensor] | None): Optional ``(scale, shift)`` tensors applied as + ``x * (scale + 1) + shift`` after normalization. + + Returns: + torch.Tensor: Output tensor of shape ``(B, dim_out, D, H, W)``. + """ x = self.proj(x) x = self.norm(x) @@ -166,7 +228,17 @@ def forward(self, x, scale_shift=None): class ResnetBlock3D(nn.Module): + """Residual block of two :class:`Block3D` layers with a skip connection and optional time-embedding modulation.""" + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + """Build the residual block. + + Args: + dim (int): Number of input channels. + dim_out (int): Number of output channels. + time_emb_dim (int | None): If given, size of a time embedding mapped to per-channel scale and shift parameters. + groups (int): Number of groups for the GroupNorm in each inner block. + """ super().__init__() self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if time_emb_dim is not None else None @@ -174,7 +246,17 @@ def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): self.block2 = Block3D(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() - def forward(self, x, time_emb=None): + def forward(self, x, time_emb=None) -> torch.Tensor: + """Apply the two conv blocks plus residual connection, optionally modulated by a time embedding. + + Args: + x (torch.Tensor): Input tensor of shape ``(B, dim, D, H, W)``. + time_emb (torch.Tensor | None): Optional time embedding of shape ``(B, time_emb_dim)`` used to produce the + scale/shift applied in the first inner block. + + Returns: + torch.Tensor: Output tensor of shape ``(B, dim_out, D, H, W)``. + """ scale_shift = None if self.mlp is not None and time_emb is not None: time_emb = self.mlp(time_emb) @@ -188,7 +270,16 @@ def forward(self, x, time_emb=None): return h + self.res_conv(x) -def default(val, d): +def default(val: object, d: object) -> object: + """Return ``val`` if it is not None, otherwise a default value. + + Args: + val: The candidate value. + d: The fallback value, or a zero-argument callable that produces it. + + Returns: + ``val`` if not None; otherwise ``d()`` when ``d`` is callable, else ``d``. + """ if val is not None: return val return d() if isfunction(d) else d diff --git a/spineps/architectures_new/__init__.py b/spineps/architectures_new/__init__.py index e69de29..2f38ed1 100644 --- a/spineps/architectures_new/__init__.py +++ b/spineps/architectures_new/__init__.py @@ -0,0 +1 @@ +"""New U-Net architectures, the Lightning wrapper and the Dice loss for SPINEPS segmentation.""" diff --git a/spineps/architectures_new/dice.py b/spineps/architectures_new/dice.py index 5d14dc3..bf65b65 100644 --- a/spineps/architectures_new/dice.py +++ b/spineps/architectures_new/dice.py @@ -1,3 +1,5 @@ +"""Memory-efficient soft Dice loss for multi-class segmentation.""" + from __future__ import annotations import torch @@ -5,7 +7,24 @@ class MemoryEfficientSoftDiceLoss(nn.Module): + """Soft Dice computed without materializing a full one-hot target when the prediction already matches its shape. + + Returns the mean soft Dice coefficient over classes (and optionally the batch). The caller typically uses + ``1 - loss`` as the actual loss term. + """ + def __init__(self, apply_nonlin=None, batch_dice: bool = False, do_bg: bool = True, smooth: float = 1.0, ddp: bool = True): + """Configure the soft Dice computation. + + Args: + apply_nonlin (Callable | None): Optional non-linearity (e.g. softmax) applied to ``x`` before the Dice + computation. + batch_dice (bool): If ``True``, accumulate the statistics over the whole batch before dividing; otherwise + compute the Dice per sample. + do_bg (bool): If ``True``, include the background class (channel 0); otherwise drop it. + smooth (float): Smoothing constant added to the denominator for numerical stability. + ddp (bool): Flag retained for distributed-data-parallel all-gather (currently unused in this implementation). + """ super().__init__() self.do_bg = do_bg @@ -15,6 +34,17 @@ def __init__(self, apply_nonlin=None, batch_dice: bool = False, do_bg: bool = Tr self.ddp = ddp def forward(self, x, y, loss_mask=None): + """Compute the mean soft Dice coefficient between predictions and targets. + + Args: + x (torch.Tensor): Prediction tensor of shape ``(b, c, ...)``; the non-linearity is applied first if set. + y (torch.Tensor): Target tensor, either class indices of shape ``(b, 1, ...)`` / ``(b, ...)`` or a one-hot + encoding matching the shape of ``x``. + loss_mask (torch.Tensor | None): Optional mask multiplied into the statistics to ignore certain voxels. + + Returns: + torch.Tensor: Scalar mean soft Dice coefficient. + """ shp_x, shp_y = x.shape, y.shape if self.apply_nonlin is not None: diff --git a/spineps/architectures_new/pl_unet.py b/spineps/architectures_new/pl_unet.py index 9897195..636918f 100644 --- a/spineps/architectures_new/pl_unet.py +++ b/spineps/architectures_new/pl_unet.py @@ -1,3 +1,5 @@ +"""PyTorch Lightning wrapper training a 2D or 3D U-Net for spine segmentation.""" + from __future__ import annotations from argparse import Namespace @@ -16,15 +18,46 @@ def softmax_helper_dim1(x: torch.Tensor) -> torch.Tensor: + """Apply softmax over dimension 1 (the channel/class dimension). + + Args: + x (torch.Tensor): Logits tensor with the classes on dimension 1. + + Returns: + torch.Tensor: Softmax probabilities of the same shape as ``x``. + """ return torch.softmax(x, 1) def _tb_logger(module: pl.LightningModule) -> TensorBoardLogger: + """Return the module's logger cast to :class:`TensorBoardLogger`. + + Args: + module (pl.LightningModule): Lightning module whose ``logger`` is a TensorBoard logger. + + Returns: + TensorBoardLogger: The module's logger typed as a TensorBoard logger. + """ return cast(TensorBoardLogger, module.logger) class PLNet(pl.LightningModule): + """LightningModule training a 2D or 3D U-Net with a combined cross-entropy, Dice and L2 loss. + + Wraps :class:`Unet2D` or :class:`Unet3D` and handles the training/validation loops, loss computation, + Dice metric logging and optimizer configuration. + """ + def __init__(self, opt: Namespace, do2D: bool = False, *args: Any, **kwargs: Any) -> None: # noqa: ARG002 + """Build the network and configure losses, metrics and training hyperparameters. + + Args: + opt (Namespace): Configuration namespace providing ``channelwise``, ``n_epoch``, ``lr``, + ``lr_end_factor``, ``l2_reg_w`` and ``dsc_loss_w``. + do2D (bool): If ``True``, use the 2D U-Net; otherwise the 3D U-Net. + *args (Any): Unused positional arguments. + **kwargs (Any): Unused keyword arguments. + """ super().__init__() self.save_hyperparameters() @@ -58,6 +91,7 @@ def __init__(self, opt: Namespace, do2D: bool = False, *args: Any, **kwargs: Any self.val_step_outputs: dict[str, list] = {} def on_fit_start(self): + """Register custom TensorBoard scalar layouts grouping the train/val losses and Dice metrics.""" tb = _tb_logger(self).experiment layout = { "loss_split": { @@ -84,9 +118,25 @@ def on_fit_start(self): tb.add_custom_scalars(layout) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run the wrapped U-Net on the input. + + Args: + x (torch.Tensor): Input image/volume tensor. + + Returns: + torch.Tensor: Per-class logits produced by the network. + """ return self.network(x) def training_step(self, batch): + """Run a single training step: compute losses, log them and accumulate metrics. + + Args: + batch (dict): Batch with the input image under ``"target"`` and the ground-truth labels under ``"class"``. + + Returns: + torch.Tensor: The combined scalar training loss to back-propagate. + """ target, gt = batch["target"], batch["class"] losses, gt, pred_cls = self._shared_step(target, gt) loss = self._merge_losses(losses) @@ -104,6 +154,7 @@ def training_step(self, batch): return loss def on_train_epoch_end(self) -> None: + """Aggregate the accumulated training metrics, log mean/foreground Dice and clear the buffers.""" if self.train_step_outputs: metrics = self._aggregate_metrics(self.train_step_outputs) self.log("dice/train_dice", metrics["dice"], on_epoch=True) @@ -116,6 +167,12 @@ def on_train_epoch_end(self) -> None: self.train_step_outputs.clear() def validation_step(self, batch, _): + """Run a single validation step, computing losses and metrics and accumulating them for the epoch end. + + Args: + batch (dict): Batch with the input image under ``"target"`` and the ground-truth labels under ``"class"``. + _ : Batch index (unused). + """ target, gt = batch["target"], batch["class"] losses, gt, pred_cls = self._shared_step(target, gt) loss = self._merge_losses(losses).detach().cpu() @@ -126,6 +183,7 @@ def validation_step(self, batch, _): self._append_metrics(metrics, self.val_step_outputs) def on_validation_epoch_end(self): + """Aggregate the accumulated validation metrics, log losses and Dice scores and clear the buffers.""" if self.val_step_outputs: metrics = self._aggregate_metrics(self.val_step_outputs) for k, v in metrics.items(): @@ -145,6 +203,11 @@ def on_validation_epoch_end(self): self.val_step_outputs.clear() def configure_optimizers(self): + """Configure the Adam optimizer and a linear learning-rate decay schedule. + + Returns: + dict: Mapping with the ``"optimizer"`` and its ``"lr_scheduler"``. + """ optimizer = Adam(self.parameters(), lr=self.start_lr) scheduler = lr_scheduler.LinearLR( optimizer=optimizer, @@ -155,16 +218,43 @@ def configure_optimizers(self): return {"optimizer": optimizer, "lr_scheduler": scheduler} def _compute_losses(self, logits: torch.Tensor, gt: torch.Tensor) -> dict[str, torch.Tensor]: + """Compute the cross-entropy, (weighted) Dice and L2 regularization losses. + + Args: + logits (torch.Tensor): Per-class logits of shape ``(b, n_classes, ...)``. + gt (torch.Tensor): Ground-truth labels of shape ``(b, 1, ...)``. + + Returns: + dict[str, torch.Tensor]: Dictionary with keys ``"ce_loss"``, ``"dc_loss"`` and ``"l2_reg_loss"``. + """ ce_loss = self.CEL(logits, gt.squeeze(1)) dc_loss = (1 - self.DC(logits, gt)) * self.dsc_loss_w l2_reg = torch.stack([p.norm() for p in self.parameters()]).sum() * self.l2_reg_w return {"ce_loss": ce_loss, "dc_loss": dc_loss, "l2_reg_loss": l2_reg} def _merge_losses(self, losses: dict[str, torch.Tensor]) -> torch.Tensor: + """Sum the three loss components into a single scalar loss. + + Args: + losses (dict[str, torch.Tensor]): Dictionary with exactly three loss tensors. + + Returns: + torch.Tensor: The sum of the three loss values. + """ vals = list(losses.values()) return vals[0] + vals[1] + vals[2] def _shared_step(self, target: torch.Tensor, gt: torch.Tensor): + """Forward the input, compute losses and produce CPU predictions and ground truth. + + Args: + target (torch.Tensor): Input image/volume tensor. + gt (torch.Tensor): Ground-truth labels. + + Returns: + tuple: ``(losses, gt, pred_cls)`` where ``losses`` is the loss dict, ``gt`` is the ground truth moved + to CPU and ``pred_cls`` is the per-voxel arg-max class prediction on CPU. + """ logits = self.forward(target) losses = self._compute_losses(logits, gt) @@ -175,6 +265,17 @@ def _shared_step(self, target: torch.Tensor, gt: torch.Tensor): return losses, gt, pred_cls def _compute_metrics(self, loss: torch.Tensor, pred_cls: torch.Tensor, gt: torch.Tensor) -> dict: + """Compute per-class, mean and foreground Dice scores alongside the given loss. + + Args: + loss (torch.Tensor): Scalar loss value to carry through into the metrics dict. + pred_cls (torch.Tensor): Predicted class indices. + gt (torch.Tensor): Ground-truth class indices. + + Returns: + dict: Dictionary with keys ``"loss"``, ``"dice"`` (mean over all classes), ``"diceFG"`` (mean over + foreground classes) and ``"dice_p_cls"`` (per-class Dice). + """ dice_p_cls = mF.dice(pred_cls, gt, average=None, num_classes=self.n_classes) return { "loss": loss, @@ -184,11 +285,31 @@ def _compute_metrics(self, loss: torch.Tensor, pred_cls: torch.Tensor, gt: torch } def _append_metrics(self, metrics: dict, outputs: dict): + """Append each metric value to the matching list in the per-epoch output buffer. + + Args: + metrics (dict): Metrics computed for a single step. + outputs (dict): Accumulator mapping each metric name to a list of per-step values; mutated in place. + """ for k, v in metrics.items(): outputs.setdefault(k, []).append(v) def _aggregate_metrics(self, outputs: dict) -> dict: + """Average the accumulated per-step metrics over an epoch. + + Args: + outputs (dict): Accumulator mapping each metric name to a list of per-step tensors. + + Returns: + dict: Mapping from metric name to its mean; ``"dice_p_cls"`` is averaged per class (``dim=0``) while + all other metrics are reduced to a scalar. + """ return {k: torch.mean(torch.stack(v)) if k != "dice_p_cls" else torch.mean(torch.stack(v), dim=0) for k, v in outputs.items()} def __str__(self): + """Return a short name indicating whether the wrapped network is 2D or 3D. + + Returns: + str: ``"Unet_2D"`` or ``"Unet_3D"``. + """ return f"Unet_{'2D' if self.do2D else '3D'}" diff --git a/spineps/architectures_new/unet2D.py b/spineps/architectures_new/unet2D.py index e594e81..61021d3 100644 --- a/spineps/architectures_new/unet2D.py +++ b/spineps/architectures_new/unet2D.py @@ -1,3 +1,5 @@ +"""2D U-Net architecture with time, label and embedding conditioning for diffusion-style models.""" + from __future__ import annotations import itertools @@ -10,6 +12,15 @@ def default(val, d): + """Return ``val`` if it is not ``None``, otherwise the default ``d``. + + Args: + val: The value to use when it is not ``None``. + d: The fallback value, or a zero-argument callable producing the fallback. + + Returns: + ``val`` when it is not ``None``; otherwise ``d()`` if ``d`` is a function, else ``d``. + """ from inspect import isfunction if val is not None: @@ -21,11 +32,26 @@ def default(val, d): class SinusoidalPosEmb(nn.Module): + """Fixed sinusoidal positional embedding used to encode the diffusion time step.""" + def __init__(self, dim): + """Initialize the embedding. + + Args: + dim (int): Output embedding dimension. Half is used for the sine and half for the cosine components. + """ super().__init__() self.dim = dim def forward(self, x): + """Compute the sinusoidal embedding for the given scalar values. + + Args: + x (torch.Tensor): Tensor of shape ``(b,)`` with the values (e.g. time steps) to embed. + + Returns: + torch.Tensor: Embedding of shape ``(b, dim)``. + """ device = x.device half_dim = self.dim // 2 emb = math.log(10000) / (half_dim - 1) @@ -41,12 +67,28 @@ class LearnedSinusoidalPosEmb(nn.Module): """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ def __init__(self, dim): + """Initialize the learned embedding. + + Args: + dim (int): Output embedding dimension; must be even. ``dim // 2`` learnable frequencies are used. + + Raises: + AssertionError: If ``dim`` is not even. + """ super().__init__() assert (dim % 2) == 0 half_dim = dim // 2 self.weights = nn.Parameter(torch.randn(half_dim)) # type: ignore def forward(self, x): + """Compute the learned Fourier embedding for the given scalar values. + + Args: + x (torch.Tensor): Tensor of shape ``(b,)`` with the values (e.g. time steps) to embed. + + Returns: + torch.Tensor: Embedding of shape ``(b, dim + 1)`` concatenating the input with its sine and cosine features. + """ x = rearrange(x, "b -> b 1") freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) @@ -55,46 +97,116 @@ def forward(self, x): class PreNorm(nn.Module): + """Apply layer normalization to the input before passing it through a wrapped module.""" + def __init__(self, dim, fn): + """Initialize the pre-normalization wrapper. + + Args: + dim (int): Number of channels to normalize over. + fn (nn.Module): Module applied to the normalized input. + """ super().__init__() self.fn = fn self.norm = LayerNorm(dim) def forward(self, x, *args, **kwargs): + """Normalize ``x`` and forward it (with any extra arguments) through the wrapped module. + + Args: + x (torch.Tensor): Input tensor of shape ``(b, c, h, w)``. + *args: Positional arguments forwarded to the wrapped module. + **kwargs: Keyword arguments forwarded to the wrapped module. + + Returns: + torch.Tensor: Output of the wrapped module applied to the normalized input. + """ x = self.norm(x) return self.fn(x, *args, **kwargs) class Residual(nn.Module): + """Add a skip connection around a wrapped module.""" + def __init__(self, fn): + """Initialize the residual wrapper. + + Args: + fn (nn.Module): Module whose output is added to its input. + """ super().__init__() self.fn = fn def forward(self, x, *args, **kwargs): + """Forward ``x`` through the wrapped module and add the input as a residual. + + Args: + x (torch.Tensor): Input tensor. + *args: Positional arguments forwarded to the wrapped module. + **kwargs: Keyword arguments forwarded to the wrapped module. + + Returns: + torch.Tensor: ``fn(x) + x``. + """ return self.fn(x, *args, **kwargs) + x class LayerNorm(nn.Module): + """Channel-wise layer normalization for 4D ``(b, c, h, w)`` tensors with learnable scale and bias.""" + def __init__(self, dim, eps=1e-5): + """Initialize the layer normalization. + + Args: + dim (int): Number of channels to normalize over. + eps (float): Small constant added to the variance for numerical stability. + """ super().__init__() self.eps = eps self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) # type: ignore self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) # type: ignore def forward(self, x): + """Normalize ``x`` over the channel dimension and apply the learnable scale and bias. + + Args: + x (torch.Tensor): Input tensor of shape ``(b, c, h, w)``. + + Returns: + torch.Tensor: Normalized tensor of the same shape as ``x``. + """ var = torch.var(x, dim=1, unbiased=False, keepdim=True) mean = torch.mean(x, dim=1, keepdim=True) return (x - mean) / (var + self.eps).sqrt() * self.g + self.b class Block(nn.Module): + """Convolution, group normalization and SiLU activation, with optional FiLM-style scale/shift modulation.""" + def __init__(self, dim, dim_out, groups=8): + """Initialize the block. + + Args: + dim (int): Number of input channels. + dim_out (int): Number of output channels. + groups (int): Number of groups for group normalization. + """ super().__init__() self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.SiLU() def forward(self, x, scale_shift=None): + """Apply convolution, normalization, optional modulation and activation. + + Args: + x (torch.Tensor): Input tensor of shape ``(b, dim, h, w)``. + scale_shift (tuple[torch.Tensor, torch.Tensor] | None): Optional ``(scale, shift)`` tensors used to + modulate the normalized features as ``x * (scale + 1) + shift``. + + Returns: + torch.Tensor: Output tensor of shape ``(b, dim_out, h, w)``. + """ x = self.proj(x) x = self.norm(x) @@ -107,7 +219,18 @@ def forward(self, x, scale_shift=None): class ResnetBlock(nn.Module): + """Residual block of two convolutional blocks with optional time-embedding conditioning.""" + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + """Initialize the residual block. + + Args: + dim (int): Number of input channels. + dim_out (int): Number of output channels. + time_emb_dim (int | None): Dimension of the time embedding. If given, an MLP produces per-channel + scale and shift values; if ``None``, no time conditioning is applied. + groups (int): Number of groups for the group normalization in each block. + """ super().__init__() self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if time_emb_dim is not None else None @@ -116,6 +239,16 @@ def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb=None): + """Apply the two blocks with optional time conditioning and a residual connection. + + Args: + x (torch.Tensor): Input tensor of shape ``(b, dim, h, w)``. + time_emb (torch.Tensor | None): Optional time embedding of shape ``(b, time_emb_dim)`` used to derive + the scale/shift modulation of the first block. + + Returns: + torch.Tensor: Output tensor of shape ``(b, dim_out, h, w)``. + """ scale_shift = None if self.mlp is not None and time_emb is not None: time_emb = self.mlp(time_emb) @@ -130,7 +263,16 @@ def forward(self, x, time_emb=None): class LinearAttention(nn.Module): + """Multi-head linear attention over spatial positions with linear complexity in the number of pixels.""" + def __init__(self, dim, heads=4, dim_head=32): + """Initialize the linear attention module. + + Args: + dim (int): Number of input and output channels. + heads (int): Number of attention heads. + dim_head (int): Channel dimension per attention head. + """ super().__init__() self.scale = dim_head**-0.5 self.heads = heads @@ -140,6 +282,14 @@ def __init__(self, dim, heads=4, dim_head=32): self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim)) def forward(self, x): + """Apply linear attention to the spatial feature map. + + Args: + x (torch.Tensor): Input tensor of shape ``(b, dim, h, w)``. + + Returns: + torch.Tensor: Output tensor of shape ``(b, dim, h, w)``. + """ _b, _c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = (rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads) for t in qkv) @@ -156,7 +306,16 @@ def forward(self, x): class Attention(nn.Module): + """Standard multi-head softmax self-attention over spatial positions.""" + def __init__(self, dim, heads=4, dim_head=32): + """Initialize the attention module. + + Args: + dim (int): Number of input and output channels. + heads (int): Number of attention heads. + dim_head (int): Channel dimension per attention head. + """ super().__init__() self.scale = dim_head**-0.5 self.heads = heads @@ -165,6 +324,14 @@ def __init__(self, dim, heads=4, dim_head=32): self.to_out = nn.Conv2d(hidden_dim, dim, 1) def forward(self, x): + """Apply full self-attention to the spatial feature map. + + Args: + x (torch.Tensor): Input tensor of shape ``(b, dim, h, w)``. + + Returns: + torch.Tensor: Output tensor of shape ``(b, dim, h, w)``. + """ _b, _c, h, w = x.shape qkv = self.to_qkv(x).chunk(3, dim=1) q, k, v = (rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads) for t in qkv) @@ -180,6 +347,15 @@ def forward(self, x): class Unet2D(nn.Module): + """2D U-Net with residual blocks, (linear) attention and time/label/embedding conditioning. + + The network down-samples through a configurable number of resolution stages, applies a bottleneck with + full attention and up-samples again using skip connections. The diffusion time step is encoded via a + sinusoidal (or learned-sinusoidal) embedding, and the model can additionally be conditioned on a class + label and/or an external embedding vector. Optional patching folds spatial patches into the channel + dimension to reduce the spatial resolution processed by the network. + """ + def __init__( self, dim, @@ -196,6 +372,24 @@ def __init__( conditional_embedding_size=0, patch_size=1, # Improving Diffusion Model Efficiency Through Patching https://arxiv.org/abs/2207.04316; 1 means deactivated (Note: Increases Training difficulty by a lot!) ): + """Build the 2D U-Net layers. + + Args: + dim (int): Base feature dimension used to derive the channel widths and the time embedding size. + init_dim (int | None): Channels produced by the initial convolution. Defaults to ``dim``. + out_dim (int | None): Number of output channels before patch unfolding. Defaults to ``channels`` + (doubled when ``learned_variance`` is set). + dim_mults (tuple): Channel multipliers applied to ``dim`` for the successive resolution stages. + channels (int): Number of image channels. + conditional_dimensions (int): Number of additional conditioning channels concatenated to the input. + resnet_block_groups (int): Number of groups for the group normalization inside the residual blocks. + learned_variance (bool): If ``True``, double the default output channels to also predict a variance. + learned_sinusoidal_cond (bool): If ``True``, use a learned sinusoidal time embedding instead of a fixed one. + learned_sinusoidal_dim (int): Dimension of the learned sinusoidal embedding when enabled. + conditional_label_size (int): Number of classes for label conditioning; 0 disables it. + conditional_embedding_size (int): Size of an external embedding concatenated to the time embedding; 0 disables it. + patch_size (int): Spatial patch size folded into channels; 1 disables patching. + """ super().__init__() self.patch_size = patch_size self.learned_variance = learned_variance @@ -282,6 +476,14 @@ def __init__( # Improving Diffusion Model Efficiency Through Patching https://arxiv.org/abs/2207.04316 (Note: Increases Training difficulty by a lot!) def to_patches(self, x): + """Fold ``patch_size x patch_size`` spatial patches into the channel dimension. + + Args: + x (torch.Tensor): Input tensor of shape ``(B, C, H, W)`` with ``H`` and ``W`` divisible by ``patch_size``. + + Returns: + torch.Tensor: Tensor of shape ``(B, C * patch_size**2, H // patch_size, W // patch_size)``. + """ p = self.patch_size B, C, H, W = x.shape x = x.permute(0, 2, 3, 1).reshape(B, H, W // p, C * p) @@ -289,6 +491,14 @@ def to_patches(self, x): return x.permute(0, 3, 2, 1) def from_patches(self, x): + """Invert :meth:`to_patches`, unfolding the channel dimension back into spatial patches. + + Args: + x (torch.Tensor): Patched tensor of shape ``(B, C, H, W)``. + + Returns: + torch.Tensor: Tensor of shape ``(B, C // patch_size**2, H * patch_size, W * patch_size)``. + """ p = self.patch_size B, C, H, W = x.shape @@ -297,6 +507,22 @@ def from_patches(self, x): return x.permute(0, 3, 1, 2) def forward(self, x, time=None, label=None, embedding=None) -> torch.Tensor: + """Run the U-Net forward pass. + + Args: + x (torch.Tensor): Input image tensor of shape ``(b, channels (+ conditional_dimensions), h, w)``. + time (torch.Tensor | None): Diffusion time steps of shape ``(b,)``. Defaults to a tensor of ones. + label (torch.Tensor | None): Class labels of shape ``(b,)``; required if the model was built with + ``conditional_label_size != 0``. + embedding (torch.Tensor | None): External conditioning embedding; required if the model was built + with ``conditional_embedding_size != 0``. + + Returns: + torch.Tensor: Output tensor of shape ``(b, out_dim, h, w)``. + + Raises: + AssertionError: If a required ``label`` or ``embedding`` is not provided. + """ if self.patch_size != 1: x = self.to_patches(x) diff --git a/spineps/architectures_new/unet3D.py b/spineps/architectures_new/unet3D.py index eb42bd8..970b6be 100644 --- a/spineps/architectures_new/unet3D.py +++ b/spineps/architectures_new/unet3D.py @@ -1,3 +1,5 @@ +"""3D U-Net architecture built from 3D residual convolutional blocks.""" + from __future__ import annotations import itertools @@ -8,28 +10,69 @@ class Block3D(nn.Module): + """3D convolution followed by group normalization and a LeakyReLU activation.""" + def __init__(self, dim: int, dim_out: int, groups: int = 8): + """Initialize the block. + + Args: + dim (int): Number of input channels. + dim_out (int): Number of output channels. + groups (int): Number of groups for group normalization. + """ super().__init__() self.proj = nn.Conv3d(dim, dim_out, 3, padding=1) self.norm = nn.GroupNorm(groups, dim_out) self.act = nn.LeakyReLU() def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply convolution, normalization and activation. + + Args: + x (torch.Tensor): Input tensor of shape ``(b, dim, d, h, w)``. + + Returns: + torch.Tensor: Output tensor of shape ``(b, dim_out, d, h, w)``. + """ return self.act(self.norm(self.proj(x))) class ResnetBlock3D(nn.Module): + """Residual block stacking two :class:`Block3D` modules with a skip connection.""" + def __init__(self, dim: int, dim_out: int, *, groups: int = 8): + """Initialize the residual block. + + Args: + dim (int): Number of input channels. + dim_out (int): Number of output channels. + groups (int): Number of groups for the group normalization in each block. + """ super().__init__() self.block1 = Block3D(dim, dim_out, groups=groups) self.block2 = Block3D(dim_out, dim_out, groups=groups) self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the two blocks and add the (optionally projected) input as a residual. + + Args: + x (torch.Tensor): Input tensor of shape ``(b, dim, d, h, w)``. + + Returns: + torch.Tensor: Output tensor of shape ``(b, dim_out, d, h, w)``. + """ return self.block2(self.block1(x)) + self.res_conv(x) class Unet3D(nn.Module): + """3D U-Net with residual blocks for volumetric segmentation. + + The volume is down-sampled through a configurable number of resolution stages, processed by a bottleneck + of two residual blocks and up-sampled again. Skip connections from the encoder are averaged with the + decoder features at each stage, and the initial features are concatenated before the final convolution. + """ + def __init__( self, dim: int, @@ -40,6 +83,17 @@ def __init__( conditional_dimensions: int = 0, resnet_block_groups: int = 8, ): + """Build the 3D U-Net layers. + + Args: + dim (int): Base feature dimension used to derive the channel widths. + init_dim (int | None): Channels produced by the initial convolution. Defaults to ``dim``. + out_dim (int | None): Number of output channels. Defaults to ``channels``. + dim_mults (tuple): Channel multipliers applied to ``dim`` for the successive resolution stages. + channels (int): Number of input image channels. + conditional_dimensions (int): Number of additional conditioning channels concatenated to the input. + resnet_block_groups (int): Number of groups for the group normalization inside the residual blocks. + """ super().__init__() self.channels = channels @@ -87,6 +141,18 @@ def __init__( self.final_conv = nn.Conv3d(dim, self.out_dim, 1) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run the U-Net forward pass over a volume. + + Args: + x (torch.Tensor): Input volume of shape ``(b, channels (+ conditional_dimensions), d, h, w)``. Each + spatial dimension must be divisible by ``2 ** (num_down_stages - 1)``. + + Returns: + torch.Tensor: Output tensor of shape ``(b, out_dim, d, h, w)``. + + Raises: + AssertionError: If any spatial dimension is not divisible by the required down-sampling factor. + """ down_factor = 2 ** (len(self.downs) - 1) for i in (-1, -2, -3): assert x.shape[i] % down_factor == 0, f"Spatial dim {x.shape[i]} not divisible by {down_factor}, input shape={x.shape}" diff --git a/spineps/entrypoint.py b/spineps/entrypoint.py index dd3ecae..0349962 100755 --- a/spineps/entrypoint.py +++ b/spineps/entrypoint.py @@ -1,3 +1,5 @@ +"""Command-line interface for SPINEPS, wiring CLI arguments to single-image and whole-dataset processing.""" + from __future__ import annotations import argparse @@ -26,6 +28,17 @@ # TODO replace with Class_to_ArgParse and then load only from config files! def parser_arguments(parser: argparse.ArgumentParser): + """Add the shared SPINEPS processing options to an argument parser. + + Registers flags common to both the ``sample`` and ``dataset`` subcommands, such as derivatives naming, + override toggles, debug/saving options, cropping, n4 bias correction and device selection. + + Args: + parser (argparse.ArgumentParser): The parser (or subparser) to add the arguments to. + + Returns: + argparse.ArgumentParser: The same parser with the shared arguments registered. + """ parser.add_argument("-der_name", "-dn", type=str, default="derivatives_seg", metavar="", help="Name of the derivatives folder") parser.add_argument("-save_debug", "-sd", action="store_true", help="Saves a lot of debug data and intermediate results") # parser.add_argument("-save_unc_img", "-sui", action="store_true", help="Saves a uncertainty image from the subreg prediction") @@ -83,6 +96,14 @@ def parser_arguments(parser: argparse.ArgumentParser): @citation_reminder def entry_point(): + """Parse command-line arguments and dispatch to the ``sample`` or ``dataset`` workflow. + + Builds the top-level parser with the ``sample`` and ``dataset`` subcommands, parses ``sys.argv`` and + calls :func:`run_sample` or :func:`run_dataset` accordingly. + + Raises: + NotImplementedError: If an unrecognized subcommand is supplied. + """ ########################### ########################### main_parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -201,6 +222,22 @@ def entry_point(): @citation_reminder def run_sample(opt: Namespace): + """Run the full segmentation pipeline on a single input NIfTI file. + + Loads the requested semantic, instance and (optional) labeling models, wraps the input as a + ``BIDS_FILE`` and calls :func:`process_img_nii`, optionally under a cProfiler. + + Args: + opt (Namespace): Parsed CLI arguments from the ``sample`` subcommand (input path, model ids/paths, + override and saving flags, device and verbosity options). + + Returns: + int: ``1`` on completion. + + Raises: + AssertionError: If the input path's parent directory is missing, only a filename was given, or the + input file does not exist. + """ input_path = Path(opt.input).absolute() dataset = str(input_path.parent) assert os.path.exists(dataset), f"-input parent does not exist, got {dataset}" # noqa: PTH110 @@ -275,6 +312,21 @@ def run_sample(opt: Namespace): @citation_reminder def run_dataset(opt: Namespace): + """Run the segmentation pipeline over a whole (preferably BIDS) dataset directory. + + Resolves the semantic, instance and (optional) labeling models (``"auto"`` defers model selection to the + pipeline), then calls :func:`process_dataset`, optionally under a cProfiler. + + Args: + opt (Namespace): Parsed CLI arguments from the ``dataset`` subcommand (dataset directory, rawdata and + derivatives names, model ids/paths, override/compatibility/saving flags, device and verbosity options). + + Returns: + int: ``1`` on completion. + + Raises: + AssertionError: If the directory does not exist, is not a directory, or no instance model is resolved. + """ input_dir = Path(opt.directory) assert input_dir.exists(), f"-input does not exist, {input_dir}" assert input_dir.is_dir(), f"-input is not a directory, got {input_dir}" diff --git a/spineps/example/get_gpu.py b/spineps/example/get_gpu.py index f2462d4..cfe1a7a 100644 --- a/spineps/example/get_gpu.py +++ b/spineps/example/get_gpu.py @@ -1,4 +1,6 @@ -from __future__ import annotations # noqa: INP001 +"""Helpers for selecting idle GPUs and thread-aware logging when running SPINEPS in parallel.""" # noqa: INP001 + +from __future__ import annotations import time @@ -9,6 +11,16 @@ def get_gpu(verbose: bool = False, max_load: float = 0.3, max_memory: float = 0.4): + """Return the IDs of currently available GPUs below the given load and memory thresholds. + + Args: + verbose (bool): If ``True``, print the current GPU utilization before querying. + max_load (float): Maximum allowed compute load (0-1) for a GPU to count as available. + max_memory (float): Maximum allowed memory usage (0-1) for a GPU to count as available. + + Returns: + list[int]: Up to four available GPU IDs ordered by load. + """ GPUtil.showUtilization() if verbose else None device_ids = GPUtil.getAvailable( order="load", @@ -23,10 +35,33 @@ def get_gpu(verbose: bool = False, max_load: float = 0.3, max_memory: float = 0. def intersection(lst1, lst2): + """Return the set intersection of two iterables. + + Args: + lst1: First iterable. + lst2: Second iterable. + + Returns: + set: Elements present in both ``lst1`` and ``lst2``. + """ return set(lst1).intersection(lst2) def get_free_gpus(blocked_gpus=None, max_load: float = 0.3, max_memory: float = 0.4): + """Poll the GPUs repeatedly and return those consistently free and not explicitly blocked. + + Availability is sampled 15 times (a short sleep between samples) and intersected so that only GPUs that stay + idle across all samples are returned. + + Args: + blocked_gpus (dict[int, bool] | None): Mapping of GPU ID to a blocked flag; a GPU is excluded when its flag + is not ``False``. Defaults to ``{0: False, 1: False, 2: False, 3: False}``. + max_load (float): Maximum allowed compute load (0-1) for the initial availability query. + max_memory (float): Maximum allowed memory usage (0-1) for the initial availability query. + + Returns: + list[int]: IDs of GPUs that are consistently available and not blocked. + """ # print("get_free_gpus") if blocked_gpus is None: blocked_gpus = {0: False, 1: False, 2: False, 3: False} @@ -41,4 +76,10 @@ def get_free_gpus(blocked_gpus=None, max_load: float = 0.3, max_memory: float = def thread_print(fold, *text): + """Print a message prefixed with the fold identifier of the calling thread. + + Args: + fold: Identifier of the fold/thread used as the message prefix. + *text: Values to print after the prefix. + """ logger.print(f"Fold [{fold}]: ", *text) diff --git a/spineps/example/helper_parallel.py b/spineps/example/helper_parallel.py index bc1fe16..064df95 100755 --- a/spineps/example/helper_parallel.py +++ b/spineps/example/helper_parallel.py @@ -1,4 +1,6 @@ -from __future__ import annotations # noqa: INP001 +"""CLI entry point running the SPINEPS pipeline on a single image, intended to be launched in parallel.""" # noqa: INP001 + +from __future__ import annotations import sys from pathlib import Path diff --git a/spineps/example/template_roll_out.py b/spineps/example/template_roll_out.py index 60c4f08..4d1be94 100755 --- a/spineps/example/template_roll_out.py +++ b/spineps/example/template_roll_out.py @@ -1,4 +1,6 @@ -from __future__ import annotations # noqa: INP001 +"""Template script for batch-running the SPINEPS pipeline over a BIDS dataset (edit the TODO markers).""" # noqa: INP001 + +from __future__ import annotations import sys from pathlib import Path @@ -38,6 +40,17 @@ def injection_function(seg_nii: NII): + """Post-process the semantic segmentation mask before instance segmentation (placeholder hook). + + Passed as ``lambda_semantic`` to the pipeline; customize it to modify the semantic mask. By default it is + a no-op returning the mask unchanged. + + Args: + seg_nii (NII): Semantic segmentation mask produced by the pipeline. + + Returns: + NII: The (optionally modified) semantic segmentation mask. + """ # TODO do something with semantic mask return seg_nii diff --git a/spineps/get_models.py b/spineps/get_models.py index b5d908e..8940982 100755 --- a/spineps/get_models.py +++ b/spineps/get_models.py @@ -1,3 +1,5 @@ +"""Discovery, lookup and instantiation of SPINEPS segmentation and labeling models from disk or remote URLs.""" + from __future__ import annotations import os @@ -17,92 +19,86 @@ logger = No_Logger() logger.prefix = "Models" +# Shown when no model of a given kind could be found in the configured models directory. +_NO_MODELS_AVAILABLE_MSG = ( + "Found no available {kind} models. Did you set one up by downloading the model weights and " + "putting them into the folder specified by the env variable, or did you want to specify an " + "absolute path instead?" +) -def get_semantic_model(model_name: str, **kwargs) -> Segmentation_Model: - """Finds and returns a semantic model by name - Args: - model_name (str): _description_ +def _get_model_by_name( + model_name: str, + modelid2folder: dict[str, Path | str], + phase: SpinepsPhase, + kind: str, + **kwargs, +) -> Segmentation_Model | VertLabelingClassifier: + """Looks up a model by name in a model-id-to-folder map and instantiates it. - Returns: - Segmentation_Model: _description_ + Shared implementation behind get_semantic_model / get_instance_model / get_labeling_model. """ model_name = model_name.lower() - _modelid2folder_subreg = modelid2folder_semantic() - possible_keys = list(_modelid2folder_subreg.keys()) - + possible_keys = list(modelid2folder.keys()) if len(possible_keys) == 0: - logger.print( - "Found no available semantic models. Did you set one up by downloading model weights and putting them into the folder specified by the env variable or did you want to specify with an absolute path instead?", - Log_Type.FAIL, - ) + logger.print(_NO_MODELS_AVAILABLE_MSG.format(kind=kind), Log_Type.FAIL) raise KeyError(model_name) if model_name not in possible_keys: logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL) raise KeyError(model_name) - config_path = _modelid2folder_subreg[model_name] + config_path = modelid2folder[model_name] if str(config_path).startswith("http"): # Resolve HTTP - config_path = download_if_missing(model_name, config_path, phase=SpinepsPhase.SEMANTIC) + config_path = download_if_missing(model_name, config_path, phase=phase) return get_actual_model(config_path, **kwargs) -def get_instance_model(model_name: str, **kwargs) -> Segmentation_Model: - """Finds and returns an instance model by name +def get_semantic_model(model_name: str, **kwargs) -> Segmentation_Model: + """Finds and returns a semantic (subregion) model by name. Args: - model_name (str): _description_ + model_name (str): Id of the semantic model to load (case-insensitive). + **kwargs: Extra keyword arguments forwarded to the model constructor. Returns: - Segmentation_Model: _description_ + Segmentation_Model: The instantiated semantic model. + + Raises: + KeyError: If no model with the given name is available. """ - model_name = model_name.lower() - _modelid2folder_vert = modelid2folder_instance() - possible_keys = list(_modelid2folder_vert.keys()) - if len(possible_keys) == 0: - logger.print( - "Found no available instance models. Did you set one up by downloading modelweights and putting them into the folder specified by the env variable or did you want to specify with an absolute path instead?", - Log_Type.FAIL, - ) - raise KeyError(model_name) - if model_name not in possible_keys: - logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL) - raise KeyError(model_name) - config_path = _modelid2folder_vert[model_name] - if str(config_path).startswith("http"): - # Resolve HTTP - config_path = download_if_missing(model_name, config_path, phase=SpinepsPhase.INSTANCE) + return _get_model_by_name(model_name, modelid2folder_semantic(), SpinepsPhase.SEMANTIC, "semantic", **kwargs) - return get_actual_model(config_path, **kwargs) + +def get_instance_model(model_name: str, **kwargs) -> Segmentation_Model: + """Finds and returns an instance (vertebra) model by name. + + Args: + model_name (str): Id of the instance model to load (case-insensitive). + **kwargs: Extra keyword arguments forwarded to the model constructor. + + Returns: + Segmentation_Model: The instantiated instance model. + + Raises: + KeyError: If no model with the given name is available. + """ + return _get_model_by_name(model_name, modelid2folder_instance(), SpinepsPhase.INSTANCE, "instance", **kwargs) def get_labeling_model(model_name: str, **kwargs) -> VertLabelingClassifier: - """Finds and returns an instance model by name + """Finds and returns a vertebra-labeling model by name. Args: - model_name (str): _description_ + model_name (str): Id of the labeling model to load (case-insensitive). + **kwargs: Extra keyword arguments forwarded to the model constructor. Returns: - Segmentation_Model: _description_ - """ - model_name = model_name.lower() - _modelid2folder_labeling = modelid2folder_labeling() - possible_keys = list(_modelid2folder_labeling.keys()) - if len(possible_keys) == 0: - logger.print( - "Found no available labeling models. Did you set one up by downloading model weights and putting them into the folder specified by the env variable or did you want to specify with an absolute path instead?", - Log_Type.FAIL, - ) - raise KeyError(model_name) - if model_name not in possible_keys: - logger.print(f"Model with name {model_name} does not exist, options are {possible_keys}", Log_Type.FAIL) - raise KeyError(model_name) - config_path = _modelid2folder_labeling[model_name] - if str(config_path).startswith("http"): - # Resolve HTTP - config_path = download_if_missing(model_name, config_path, phase=SpinepsPhase.LABELING) + VertLabelingClassifier: The instantiated labeling classifier. - return get_actual_model(config_path, **kwargs) + Raises: + KeyError: If no model with the given name is available. + """ + return _get_model_by_name(model_name, modelid2folder_labeling(), SpinepsPhase.LABELING, "labeling", **kwargs) _modelid2folder_semantic: Optional[dict[str, Union[Path, str]]] = None @@ -111,10 +107,12 @@ def get_labeling_model(model_name: str, **kwargs) -> VertLabelingClassifier: def modelid2folder_semantic() -> dict[str, Path | str]: - """Returns the dictionary mapping semantic model ids to their corresponding path + """Returns the dictionary mapping semantic model ids to their corresponding path. + + Uses the cached mapping if available, otherwise scans the configured models directory. Returns: - _type_: _description_ + dict[str, Path | str]: Mapping from semantic model id to its folder path or download URL. """ if _modelid2folder_semantic is not None: return _modelid2folder_semantic @@ -123,10 +121,12 @@ def modelid2folder_semantic() -> dict[str, Path | str]: def modelid2folder_instance() -> dict[str, Path | str]: - """Returns the dictionary mapping instance model ids to their corresponding path + """Returns the dictionary mapping instance model ids to their corresponding path. + + Uses the cached mapping if available, otherwise scans the configured models directory. Returns: - _type_: _description_ + dict[str, Path | str]: Mapping from instance model id to its folder path or download URL. """ if _modelid2folder_instance is not None: return _modelid2folder_instance @@ -135,10 +135,12 @@ def modelid2folder_instance() -> dict[str, Path | str]: def modelid2folder_labeling() -> dict[str, Path | str]: - """Returns the dictionary mapping labeling model ids to their corresponding path + """Returns the dictionary mapping labeling model ids to their corresponding path. + + Uses the cached mapping if available, otherwise scans the configured models directory. Returns: - _type_: _description_ + dict[str, Path | str]: Mapping from labeling model id to its folder path or download URL. """ if _modelid2folder_labeling is not None: return _modelid2folder_labeling @@ -149,14 +151,23 @@ def modelid2folder_labeling() -> dict[str, Path | str]: def check_available_models( models_folder: str | Path, verbose: bool = False ) -> tuple[dict[str, Path | str], dict[str, Path | str], dict[str, Path | str]]: - """Searches through the specified directories and finds models, sorting them into the dictionaries mapping to instance or semantic models + """Searches the given directory for models and sorts them into semantic, instance and labeling id-to-folder maps. + + Recursively finds all inference_config.json files, loads each config and assigns the model to the labeling map + (classifier), the instance map (segmentation input modality) or the semantic map (everything else). The results are + cached in module-level globals. Models whose config fails to load are skipped. Args: - models_folder (str | Path): The folder to be analyzed for models - verbose (bool, optional): _description_. Defaults to False. + models_folder (str | Path): The folder to be analyzed for models. + verbose (bool, optional): If true, logs models that were skipped because their config could not be loaded. + Defaults to False. Returns: - tuple[dict[str, Path], dict[str, Path]]: modelid2folder_semantic, modelid2folder_instance + tuple[dict[str, Path | str], dict[str, Path | str], dict[str, Path | str]]: The semantic, instance and labeling + id-to-folder maps. + + Raises: + AssertionError: If models_folder does not exist. """ logger.print("Check available models...") if isinstance(models_folder, str): @@ -186,17 +197,17 @@ def check_available_models( return _modelid2folder_semantic, _modelid2folder_instance, _modelid2folder_labeling -def modeltype2class(modeltype: ModelType): - """Maps ModelType to actual Segmentation_Model Subclass +def modeltype2class(modeltype: ModelType) -> type: + """Maps a ModelType to the corresponding model class. Args: - type (ModelType): _description_ + modeltype (ModelType): The model type from the inference config. Raises: - NotImplementedError: _description_ + NotImplementedError: If the model type is not supported. Returns: - _type_: _description_ + type: The class to instantiate (Segmentation_Model_NNunet, Segmentation_Model_Unet3D or VertLabelingClassifier). """ if modeltype == ModelType.nnunet: return Segmentation_Model_NNunet @@ -213,13 +224,22 @@ def get_actual_model( use_cpu: bool = False, **kwargs, ) -> Segmentation_Model | VertLabelingClassifier: - """Creates the Model class from given path + """Creates and returns the appropriate model from a given inference config path. + + Accepts either a path to an inference_config.json file or a folder containing exactly one such file (searched + recursively). Loads the config, picks the matching model class and instantiates it. Args: - in_config (str | Path): Path to the models inference config file + in_config (str | Path): Path to the model's inference config file, or to a folder containing it. + use_cpu (bool, optional): If true, runs inference on CPU instead of GPU. Defaults to False. + **kwargs: Extra keyword arguments forwarded to the model constructor. Returns: - Segmentation_Model: The returned model + Segmentation_Model | VertLabelingClassifier: The instantiated model. + + Raises: + FileNotFoundError: If no inference_config.json is found in the given folder. + AssertionError: If more than one inference_config.json is found in the given folder. """ # if isinstance(in_config, MODELS): # in_dir = filepath_model(in_config.value, model_dir=None) diff --git a/spineps/lab_model.py b/spineps/lab_model.py index 625c45d..8769b81 100755 --- a/spineps/lab_model.py +++ b/spineps/lab_model.py @@ -1,3 +1,5 @@ +"""Vertebra-labeling classifier: crops vertebra patches and predicts their anatomical labels.""" + from __future__ import annotations import math @@ -18,6 +20,9 @@ logger = No_Logger(prefix="VertLabelingClassifier") +# Default spatial size (voxels) of the cropped patch fed to the vertebra-labeling classifier. +DEFAULT_CLASSIFIER_INPUT_SIZE = (152, 168, 32) + def unit_vector(vector): """Returns the unit vector of the vector.""" @@ -47,19 +52,17 @@ def angle_between(v1, v2, signed=True): def rotate_patch_sagitally(patch: np.ndarray, angle: float, msk: bool = False, cval: int = 0) -> np.ndarray: - """ - Rotates a patch sagitally given an angle (Assuming the patch is in (I,P,L) orientation) - - Parameters: - ---------- - patch: np.ndarray - a numpy array with (I,P,L) orientation - angle: float - angle of rotation in degrees - msk: bool, optional - flag to determine interpolation type. Interpolation order os 0 if input is a mask. - Output: - np.ndarray: rotated patch + """Rotates a patch sagittally by a given angle (assuming the patch is in (I, P, L) orientation). + + Args: + patch (np.ndarray): A numpy array in (I, P, L) orientation. + angle (float): Angle of rotation in degrees. + msk (bool, optional): If true, treats the patch as a mask and uses nearest-neighbour interpolation (order 0); + otherwise uses cubic interpolation (order 3). Defaults to False. + cval (int, optional): Constant value used to fill regions outside the rotated patch. Defaults to 0. + + Returns: + np.ndarray: The rotated patch with the same shape as the input. """ if msk: cval = 0 @@ -71,6 +74,21 @@ def rotate_patch_sagitally(patch: np.ndarray, angle: float, msk: bool = False, c class VertLabelingClassifier(Segmentation_Model): + """Classifier that assigns anatomical labels to individual vertebrae. + + For each vertebra a patch is cropped around its center of mass, optionally rotated to align with the spine axis, + normalized and center-cropped to a fixed size, then passed through a DenseNet (PLClassifier) that outputs per-head + softmax predictions. Although it subclasses Segmentation_Model to reuse config loading, it does not perform voxel + segmentation (run/segment_scan are not implemented). + + Attributes: + device (torch.device): Device the classifier runs on. + final_size (tuple[int, int, int]): Spatial size (voxels) the cropped patch is reduced to before inference. + cutout_size (tuple[int, int, int]): Patch size used when cutting out a vertebra, set from the loaded model. + totensor (ToTensor): Transform converting numpy arrays to tensors. + transform (Compose): Intensity normalization and center-crop transform applied to each patch. + """ + def __init__( self, model_folder: str | Path, @@ -79,11 +97,24 @@ def __init__( default_verbose: bool = False, default_allow_tqdm: bool = True, ): + """Initializes the vertebra-labeling classifier and its preprocessing transforms. + + Args: + model_folder (str | Path): Path to the classifier's model folder. + inference_config (Segmentation_Inference_Config | None, optional): Inference config; if None, loads it from the + model folder. Defaults to None. + use_cpu (bool, optional): If true, runs inference on CPU instead of GPU. Defaults to False. + default_verbose (bool, optional): If true, prints more information when used. Defaults to False. + default_allow_tqdm (bool, optional): If true, shows a progress bar while predicting. Defaults to True. + + Raises: + AssertionError: If the inference config expects more than one input. + """ super().__init__(model_folder, inference_config, use_cpu, default_verbose, default_allow_tqdm) assert len(self.inference_config.expected_inputs) == 1, "Unet3D cannot expect more than one input" # self.model: PLClassifier = model self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - self.final_size: tuple[int, int, int] = (152, 168, 32) + self.final_size: tuple[int, int, int] = DEFAULT_CLASSIFIER_INPUT_SIZE self.totensor = ToTensor() self.transform = Compose( [ @@ -93,6 +124,17 @@ def __init__( ) def load(self, folds: tuple[str, ...] | None = None) -> Self: # noqa: ARG002 + """Loads the classifier checkpoint and updates the preprocessing transform to the model's input size. + + Args: + folds (tuple[str, ...] | None, optional): Unused; present for interface compatibility. Defaults to None. + + Returns: + Self: This classifier with its predictor loaded and moved to the selected device. + + Raises: + AssertionError: If no matching checkpoint file is found in the model folder. + """ assert os.path.exists(self.model_folder) # noqa: PTH110 chktpath = search_path(self.model_folder, "**/*val_f1=*valf1-weights.ckpt") @@ -120,18 +162,53 @@ def run( input_nii: list[NII], verbose: bool = False, ) -> dict[OutputType, NII | None]: + """Not implemented: the classifier does not perform voxel segmentation. + + Args: + input_nii (list[NII]): Unused. + verbose (bool, optional): Unused. Defaults to False. + + Raises: + NotImplementedError: Always, since running it as a segmentation model is not meaningful. + """ raise NotImplementedError("Doesnt make sense") def segment_scan(*args, **kwargs): + """Not implemented: the classifier does not perform voxel segmentation. + + Raises: + NotImplementedError: Always, since segmenting with this model is not meaningful. + """ raise NotImplementedError("Doesnt make sense") @classmethod def from_modelfolder(cls, model_folder: str | Path): + """Not implemented: construction directly from a model folder. + + Args: + model_folder (str | Path): Path to the model folder. + + Raises: + NotImplementedError: Always; use from_checkpoint_path instead. + """ raise NotImplementedError() # find checkpoint yourself, then load from checkpoitn path @classmethod - def from_checkpoint_path(cls, checkpoint_path: str | Path): + def from_checkpoint_path(cls, checkpoint_path: str | Path) -> VertLabelingClassifier: + """Constructs a classifier from a checkpoint file path. + + Resolves the model folder as the grandparent of the checkpoint file and instantiates the classifier from it. + + Args: + checkpoint_path (str | Path): Path to the checkpoint (.ckpt) file. + + Returns: + VertLabelingClassifier: The constructed classifier. + + Raises: + AssertionError: If the checkpoint path does not exist. + """ if isinstance(checkpoint_path, str): checkpoint_path = Path(checkpoint_path) assert checkpoint_path.exists(), f"Checkpoint path does not exist: {checkpoint_path}" @@ -142,7 +219,17 @@ def from_checkpoint_path(cls, checkpoint_path: str | Path): logger.print("Model loaded from", checkpoint_path, verbose=True) return d - def run_all_position_instances(self, img: NII, com_list: list[tuple[int, int, int]]): + def run_all_position_instances(self, img: NII, com_list: list[tuple[int, int, int]]) -> dict[int, dict[str, np.ndarray]]: + """Runs the classifier on patches cropped around a list of center-of-mass positions. + + Args: + img (NII): The intensity image (reoriented in place to the default orientation). + com_list (list[tuple[int, int, int]]): Center-of-mass voxel positions, ordered top-to-bottom, one per vertebra. + + Returns: + dict[int, dict[str, np.ndarray]]: Mapping from list index to a dict with "soft" (softmax outputs) and + "pred" (argmax class) per classifier head. + """ img.reorient_() # assert coms are in PIR? # assert coms are in order top-to-bottom @@ -153,6 +240,19 @@ def run_all_position_instances(self, img: NII, com_list: list[tuple[int, int, in return predictions def run_all_seg_instances(self, img: NII, seg: NII) -> dict[int, dict[str, np.ndarray]]: + """Runs the classifier on every vertebra instance present in a segmentation mask. + + For each label in the mask, computes the patch rotation angle from the neighbouring vertebra centers of mass (to + align with the spine axis) and runs the classifier on the corresponding patch. + + Args: + img (NII): The intensity image. + seg (NII): The vertebra instance segmentation mask. + + Returns: + dict[int, dict[str, np.ndarray]]: Mapping from vertebra label to a dict with "soft" (softmax outputs) and + "pred" (argmax class) per classifier head. + """ img = img.reorient() seg = seg.reorient() # TODO assert order of seg labels are order from top to bottom @@ -175,7 +275,23 @@ def run_all_seg_instances(self, img: NII, seg: NII) -> dict[int, dict[str, np.nd predictions[v] = {"soft": logits_soft, "pred": pred_cls} return predictions - def run_given_seg_pos(self, img: NII, seg: NII, vert_label: int | None = None, angle: float | None = None): + def run_given_seg_pos( + self, img: NII, seg: NII, vert_label: int | None = None, angle: float | None = None + ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: + """Runs the classifier on the patch centered on a single vertebra defined by a segmentation. + + Selects the given vertebra label (or binarizes the mask if multiple labels are present), computes the center of its + bounding box and runs the classifier there. + + Args: + img (NII): The intensity image. + seg (NII): The segmentation mask defining the vertebra location. + vert_label (int | None, optional): Label of the vertebra to use; if None, the whole mask is used. Defaults to None. + angle (float | None, optional): Rotation angle (degrees) to align the patch with the spine axis. Defaults to None. + + Returns: + tuple[dict, dict]: The softmax outputs and argmax class predictions per classifier head. + """ if vert_label is not None: seg = seg.extract_label(vert_label) elif len(seg.unique()) > 1: @@ -188,7 +304,24 @@ def run_given_seg_pos(self, img: NII, seg: NII, vert_label: int | None = None, a center_of_crop.append(crop[i].start + (size_t // 2)) return self.run_given_center_pos(img, seg, center_of_crop, angle=angle) # type: ignore - def run_given_center_pos(self, img: NII, seg: NII, center_pos: tuple[int, int, int], angle: float | None = None): + def run_given_center_pos( + self, img: NII, seg: NII, center_pos: tuple[int, int, int], angle: float | None = None + ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: + """Crops image and segmentation patches around a center point, optionally rotates them, and runs the classifier. + + Cuts out a patch larger than the final size (with extra padding for rotation), reorients to (I, P, L), optionally + rotates sagittally by the given angle, crops back to the cutout size and runs the classifier on the patch. + + Args: + img (NII): The intensity image (or a raw array). + seg (NII): The segmentation mask used as the second channel. + center_pos (tuple[int, int, int]): Voxel position to center the patch on. + angle (float | None, optional): Rotation angle (degrees) to align the patch with the spine axis; no rotation if + None or 0. Defaults to None. + + Returns: + tuple[dict, dict]: The softmax outputs and argmax class predictions per classifier head. + """ extra_rotation_padding = 64 extra_rotation_padding_halfed = extra_rotation_padding // 2 # @@ -230,11 +363,28 @@ def run_given_center_pos(self, img: NII, seg: NII, center_pos: tuple[int, int, i return self._run_array(img_v.get_array(), seg_v.get_seg_array()) # sem_cut def _run_nii(self, img_nii: NII): + """Runs the classifier on the raw array of an NII patch. + + Args: + img_nii (NII): The patch image to classify. + + Returns: + tuple[dict, dict]: The softmax outputs and argmax class predictions per classifier head. + """ # TODO check resolution # TODO check size return self._run_array(img_nii.get_array()) def run_all_arrays(self, img_arrays: dict[int, np.ndarray]) -> dict[int, dict[str, np.ndarray]]: + """Runs the classifier on a set of pre-cut image patches. + + Args: + img_arrays (dict[int, np.ndarray]): Mapping from vertebra id to its 3D image patch. + + Returns: + dict[int, dict[str, np.ndarray]]: Mapping from vertebra id to a dict with "soft" (softmax outputs) and + "pred" (argmax class) per classifier head. + """ # TODO assert order of seg labels are order from top to bottom predictions = {} for v, arr in img_arrays.items(): @@ -243,6 +393,22 @@ def run_all_arrays(self, img_arrays: dict[int, np.ndarray]) -> dict[int, dict[st return predictions def _run_array(self, img_arr: np.ndarray, seg_arr: np.ndarray | None | torch.Tensor = None): # , seg_arr: np.ndarray): + """Applies preprocessing and runs the classifier forward pass on a single image patch. + + Converts the patch (and optional segmentation) to tensors, applies intensity normalization and center cropping, + adds the channel/batch dimensions and runs the network, returning per-head softmax probabilities and argmax classes. + + Args: + img_arr (np.ndarray): The 3D image patch. + seg_arr (np.ndarray | None | torch.Tensor, optional): Optional segmentation patch; if None, a copy of the image + is used. Defaults to None. + + Returns: + tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: Per-head softmax probabilities and per-head argmax classes. + + Raises: + AssertionError: If img_arr is not 3-dimensional. + """ assert img_arr.ndim == 3, f"Dimension mismatch, {img_arr.shape}, expected 3 dimensions" # img_arr = self.totensor(img_arr) diff --git a/spineps/phase_instance.py b/spineps/phase_instance.py index d3f82d6..9d3b4c9 100755 --- a/spineps/phase_instance.py +++ b/spineps/phase_instance.py @@ -1,3 +1,5 @@ +"""Instance phase: turn the subregion semantic mask into per-vertebra instance labels via cutout prediction and merging.""" + from __future__ import annotations # from utils.predictor import nnUNetPredictor @@ -20,8 +22,29 @@ from spineps.seg_enums import ErrCode, OutputType from spineps.seg_model import Segmentation_Model -from spineps.seg_pipeline import logger +from spineps.seg_pipeline import IVD_LABEL_OFFSET, logger from spineps.utils.proc_functions import clean_cc_artifacts +from spineps.utils.resolution import ( + INFERIOR_AXIS_PIR, + REFERENCE_VOXEL_VOLUME_MM3, + REFERENCE_ZOOM, + mm3_to_voxels, + mm_to_voxels_axis, +) + +# --- Instance-segmentation geometry and merged-corpus heuristics --- +# Physical margin kept around the segmentation when cropping before instance prediction. +INSTANCE_CROP_MARGIN_MM = 5 * min(REFERENCE_ZOOM) +# Lower bound (physical volume) for the resolution-scaled corpus/vertebra cleaning thresholds. +MIN_CLEANING_VOLUME_MM3 = 40 * REFERENCE_VOXEL_VOLUME_MM3 +# Two neighboring structures at nearly the same height (within this physical distance) are merged. +SAME_HEIGHT_MERGE_THRESHOLD_MM = 10 * REFERENCE_ZOOM[INFERIOR_AXIS_PIR] +# Relative index window of neighbors inspected when fixing a merged corpus (excludes self). +NEIGHBOR_OFFSETS = [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5] +# A merged corpus is only split when there are more than this many neighbors to estimate volume from. +MIN_NEIGHBORS_FOR_VOLUME_CHECK = 2 +# A corpus is considered merged when its volume exceeds the neighbor average by this ratio. +MERGED_CORPUS_VOLUME_RATIO = 1.5 def predict_instance_mask( @@ -36,14 +59,28 @@ def predict_instance_mask( proc_inst_largest_k_cc: int = 0, verbose: bool = False, ) -> tuple[NII | None, ErrCode]: - """Based on subregion segmentation, feeds individual arcus coms to a network to get the vertebra body segmentations + """Build a per-vertebra instance mask from a subregion semantic segmentation. + + Reorients and rescales the input to the model's recommended zoom, crops to the segmentation, derives one + cutout per corpus center of mass, runs the instance model on each cutout, merges the overlapping per-vertebra + predictions into a single label map, optionally cleans and fills it, then uncrops back to the original space. Args: - seg_nii (NII): _description_ - cutout_size (tuple[int, int, int], optional): _description_. Defaults to (128, 88, 32). + seg_nii (NII): Subregion (semantic) segmentation mask used as input. + model (Segmentation_Model): Instance model producing the per-vertebra-body cutout predictions. + debug_data (dict): Dictionary for collecting intermediate results across the pipeline. + pad_size (int, optional): Edge padding added before processing and removed afterwards. Defaults to 0. + proc_inst_fill_3d_holes (bool, optional): Whether to fill 3D holes in the final vertebra mask. Defaults to True. + proc_detect_and_solve_merged_corpi (bool, optional): Whether to detect and split merged vertebral bodies. Defaults to True. + proc_corpus_clean (bool, optional): Whether to clean small corpus connected-component artifacts. Defaults to True. + proc_inst_clean_small_cc_artifacts (bool, optional): Whether to delete small instance artifacts. Defaults to True. + proc_inst_largest_k_cc (int, optional): Keep only the largest k connected components per cutout label; 0 disables. Defaults to 0. + verbose (bool, optional): Emit additional progress logging. Defaults to False. Returns: - tuple[NII | None, ErrCode]: whole_vert_nii, errcode + tuple[NII | None, ErrCode]: The vertebra instance mask in the input space and an error code. Returns + ``(None, ErrCode.EMPTY)`` if no corpus labels are present, ``(None, ErrCode.UNKNOWN)`` if no predictions + are produced, or ``(None, errcode)`` if merging fails. """ logger.print("Predict instance mask", Log_Type.STAGE) with logger: @@ -81,7 +118,7 @@ def predict_instance_mask( debug_data["inst_uncropped_Subreg_nii_b_zms"] = seg_nii_uncropped.copy() uncropped_vert_mask = np.zeros(seg_nii_uncropped.shape, dtype=seg_nii_uncropped.dtype) logger.print("Vertebra uncropped_vert_mask empty", uncropped_vert_mask.shape, verbose=verbose) - crop = seg_nii_rdy.compute_crop(dist=5) + crop = seg_nii_rdy.compute_crop(dist=INSTANCE_CROP_MARGIN_MM / min(seg_nii_rdy.zoom)) # logger.print("Crop", crop, verbose=verbose) seg_nii_rdy.apply_crop_(crop) logger.print(f"Crop down from {uncropped_vert_mask.shape} to {seg_nii_rdy.shape}", verbose=verbose) @@ -91,11 +128,12 @@ def predict_instance_mask( # # make threshold in actual mm corpus_border_threshold = int(corpus_border_threshold / expected_zms[1]) - corpus_size_cleaning = max(int(corpus_size_cleaning / (expected_zms[0] * expected_zms[1] * expected_zms[2])), 40) - vert_size_threshold = max(int(vert_size_threshold / (expected_zms[0] * expected_zms[1] * expected_zms[2])), 40) + min_cleaning_voxels = mm3_to_voxels(MIN_CLEANING_VOLUME_MM3, expected_zms) + corpus_size_cleaning = max(int(corpus_size_cleaning / (expected_zms[0] * expected_zms[1] * expected_zms[2])), min_cleaning_voxels) + vert_size_threshold = max(int(vert_size_threshold / (expected_zms[0] * expected_zms[1] * expected_zms[2])), min_cleaning_voxels) seg_labels = seg_nii_rdy.unique() - if 49 not in seg_labels: + if Location.Vertebra_Corpus_border.value not in seg_labels: logger.print(f"no corpus ({Location.Vertebra_Corpus_border.value}) labels in this segmentation, cannot proceed", Log_Type.FAIL) return None, ErrCode.EMPTY # get all the 3vert predictions @@ -173,6 +211,26 @@ def get_corpus_coms( process_detect_and_solve_merged_corpi: bool = True, verbose: bool = False, ) -> list | None: + """Compute the center of mass of every vertebral corpus, optionally splitting merged bodies. + + Extracts the corpus region (using the dense corpus label when available, otherwise the eroded and cleaned + corpus-border label) and returns one center of mass per corpus, sorted from bottom to top (plus the dens + when present). When ``process_detect_and_solve_merged_corpi`` is set, it cross-checks the vertebra/IVD + height alternation: neighbors at the same height are merged, and a corpus whose volume exceeds the neighbor + average by ``MERGED_CORPUS_VOLUME_RATIO`` is split via a separating plane. + + Args: + seg_nii (NII): Subregion semantic mask in ("P", "I", "R") orientation. + corpus_size_cleaning (int): Voxel threshold for removing small corpus-border artifacts; 0 disables cleaning. + process_detect_and_solve_merged_corpi (bool, optional): Whether to detect and split merged corpora. Defaults to True. + verbose (bool, optional): Emit additional progress logging. Defaults to False. + + Returns: + list | None: Corpus center-of-mass coordinates ordered from bottom to top, or None if no corpus is found. + + Raises: + AssertionError: If ``seg_nii`` is not in ("P", "I", "R") orientation. + """ seg_nii.assert_affine(orientation=("P", "I", "R")) dense = False # Extract Corpus region and try to find all coms naively (some skips should not matter) @@ -225,7 +283,7 @@ def get_corpus_coms( seg_sem = seg_nii.map_labels({Location.Endplate.value: Location.Vertebra_Disc.value}, verbose=False) has_ivd: bool = Location.Vertebra_Disc.value in seg_sem.unique() subreg_cc: NII = seg_sem.get_connected_components(labels=Location.Vertebra_Disc.value) - subreg_cc[subreg_cc > 0] += 100 + subreg_cc[subreg_cc > 0] += IVD_LABEL_OFFSET # offset IVD CC labels to keep them distinct from corpus CC labels subreg_cc_n = len(subreg_cc.unique()) logger.print(f"Found {subreg_cc_n} IVD ccs (naively)", verbose=verbose) coms = subreg_cc.center_of_masses() @@ -262,7 +320,7 @@ def get_corpus_coms( verbose=verbose, ) # check if same heigh, then just merge ivd label - if abs(neighborheight - selfheight) < 10: + if abs(neighborheight - selfheight) < mm_to_voxels_axis(SAME_HEIGHT_MERGE_THRESHOLD_MM, seg_nii.zoom, INFERIOR_AXIS_PIR): logger.print("Same height, just merge") stats_by_height.pop(vl) stats_by_height = dict(sorted(stats_by_height.items(), key=lambda x: x[1][0])) @@ -272,7 +330,7 @@ def get_corpus_coms( logger.print("Merged corpi, try to fix it", verbose=verbose) neighbor_verts = { stats_by_height_keys[idx + i]: stats_by_height[stats_by_height_keys[idx + i]] - for i in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5] + for i in NEIGHBOR_OFFSETS if (idx + i) < len(stats_by_height_keys) and (idx + i) >= 0 and stats_by_height_keys[idx + i] < 99 } @@ -280,15 +338,15 @@ def get_corpus_coms( if len(neighbor_verts) == 0: logger.print("Got no neighbor vert labels to fix", Log_Type.FAIL) continue - neighbor_volumes = [k[2] for nv, k in neighbor_verts.items()] + neighbor_volumes = [k[2] for k in neighbor_verts.values()] logger.print("neighbor_volumes", neighbor_volumes, verbose=verbose) n_neighbors_without_target = len(neighbor_volumes) - 1 - if n_neighbors_without_target > 2: + if n_neighbors_without_target > MIN_NEIGHBORS_FOR_VOLUME_CHECK: argmax = np.argmax(neighbor_volumes) - n_avg_volume = sum([neighbor_volumes[i] for i in range(len(neighbor_volumes)) if i != argmax]) / n_neighbors_without_target + n_avg_volume = sum(neighbor_volumes[i] for i in range(len(neighbor_volumes)) if i != argmax) / n_neighbors_without_target diff_volume = neighbor_volumes[argmax] / n_avg_volume - if diff_volume > 1.5: + if diff_volume > MERGED_CORPUS_VOLUME_RATIO: logger.print( f"Volume difference detected in label {vl}, diff = {diff_volume}, volume = {neighbor_volumes[argmax]}, neighbor_avg = {n_avg_volume}", Log_Type.STRANGE, @@ -317,55 +375,30 @@ def get_corpus_coms( return corpus_coms -def get_separating_components(segvert: np.ndarray, max_iter: int = 10, connectivity: int = 3): - """ - Attempts to split a binary volumetric segmentation into two spatially separate components (S and T) - by iterative erosion and connected component analysis. - - This function is designed for cases where an initial segmentation is a single connected component, - but the goal is to identify two meaningful subregions. It uses morphological erosion to find a - splitting point and then recovers the two regions through dilation. - - Parameters - ---------- - segvert : np.ndarray - A 3D binary (or labeled) numpy array representing the segmented volume to split. - max_iter : int, optional - Maximum number of erosion iterations allowed to find separable components. Default is 10. - connectivity : int, optional - Connectivity used for morphological operations (e.g., 1=6-connectivity, 2=18, 3=26). Default is 3. - - Returns - ------- - spart : np.ndarray - Binary mask of the first separated component (S). - tpart : np.ndarray - Binary mask of the second separated component (T). - spart_dil : np.ndarray - Dilated version of spart until contact with tpart. - tpart_dil : np.ndarray - Dilated version of tpart until contact with spart. - stpart : np.ndarray - Combined map of dilated S and T, with values: - - 0: background - - 1: spart_dil only - - 2: tpart_dil only - - 3: overlapping region between spart_dil and tpart_dil - - Raises - ------ - Exception - If the input volume cannot be split into two parts within the allowed number of iterations, - or if resulting parts are empty. - IndentationError - If the maximum number of iterations is reached without successful separation. - - Notes - ----- - - The function assumes that `np_erode_msk`, `np_dilate_msk`, `np_connected_components`, - `np_volume`, `np_filter_connected_components` are available in the environment. - - This method is particularly useful for anatomical structures that are initially connected - (e.g., left and right organs) but should be separated for downstream analysis. +def get_separating_components( + segvert: np.ndarray, max_iter: int = 10, connectivity: int = 3 +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Split a binary volume into two spatially separate components (S and T) via erosion and dilation. + + Designed for a segmentation that is a single connected component but should be separated into two + meaningful subregions. Morphological erosion is applied until the volume breaks into multiple connected + components (the splitting point), then the two regions are recovered through dilation. Useful for + anatomical structures that are initially connected (e.g., two merged vertebral bodies). + + Args: + segvert (np.ndarray): 3D binary (or labeled) array representing the volume to split. + max_iter (int, optional): Maximum number of erosion iterations allowed to find separable components. Defaults to 10. + connectivity (int, optional): Connectivity for morphological operations (1=6-, 2=18-, 3=26-connectivity). Defaults to 3. + + Returns: + tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: A 5-tuple ``(spart, tpart, spart_dil, + tpart_dil, stpart)`` where ``spart``/``tpart`` are the two separated binary components, ``spart_dil``/``tpart_dil`` + are their dilations grown until contact, and ``stpart`` is their combined map (0=background, 1=spart_dil only, + 2=tpart_dil only, 3=overlap of the two dilations). + + Raises: + Exception: If the volume cannot be split into two parts within the iterations, or if a resulting part is empty. + IndentationError: If the maximum number of erosion iterations is reached without successful separation. """ check_connectivity = 3 vol = segvert.copy() @@ -445,51 +478,29 @@ def get_plane_split( tpart: np.ndarray, spart_dil: np.ndarray, tpart_dil: np.ndarray, -): - """ - Computes an approximate separating plane between two regions (spart and tpart) - based on their dilated overlap and returns it as a NIfTI image. - - This function determines the collision region between the dilated versions - of two separated segmentation components. It then estimates a plane orthogonal - to the vector between their centers of mass and passing through the point - of contact. The resulting binary plane is filled in the axial direction - and returned in NIfTI format for visualization or further processing. - - Parameters - ---------- - segvert : np.ndarray - Original 3D binary or labeled segmentation volume. - compare_nii : NII - NIfTI image object used as a reference for orientation and spatial metadata. - spart : np.ndarray - Binary mask of the first component (S). - tpart : np.ndarray - Binary mask of the second component (T). - spart_dil : np.ndarray - Dilated mask of spart. - tpart_dil : np.ndarray - Dilated mask of tpart. - - Returns - ------- - plane_filled_nii : NII - A NIfTI image containing a filled binary plane separating spart and tpart, - reoriented to match the input NIfTI image. If no collision is detected, - returns an empty image. - - Notes - ----- - - The function uses the collision area between the dilated masks to find - a center of mass and constructs a plane orthogonal to the vector between - the COMs of spart and tpart. - - Filling the plane ensures better visualization and compatibility with downstream tasks. - - If the dilated masks do not overlap, an empty volume is returned and a warning is logged. - - TODO - ---- - - Improve accuracy by using the line connecting both COMs and projecting the collision - point onto this vector, rather than relying on the COM of the overlap region. +) -> NII: + """Compute an approximate separating plane between two regions and return it as an NII. + + Determines the collision region between the dilated versions of the two components, takes its center of mass + as a point on the plane, and builds a plane orthogonal to the vector between the centers of mass of ``spart`` + and ``tpart``. The plane is then filled along the superior axis and returned as an NII matching the input + orientation. If the dilated masks do not touch, an empty volume is returned and a warning is logged. + + Note: + TODO: Improve accuracy by projecting the collision point onto the line connecting both COMs rather than + relying on the COM of the overlap region. + + Args: + segvert (np.ndarray): Original 3D binary or labeled segmentation volume. + compare_nii (NII): Reference image providing orientation and spatial metadata for the output. + spart (np.ndarray): Binary mask of the first component (S). + tpart (np.ndarray): Binary mask of the second component (T). + spart_dil (np.ndarray): Dilated mask of ``spart``. + tpart_dil (np.ndarray): Dilated mask of ``tpart``. + + Returns: + NII: A filled binary plane separating ``spart`` and ``tpart``, in the input orientation; an empty image + if no collision between the dilated masks is detected. """ s_dilint = spart_dil.astype(np.uint8) @@ -545,7 +556,19 @@ def get_plane_split( def split_by_plane( segvert: np.ndarray, plane_filled_nii: NII, -): +) -> np.ndarray: + """Split a vertebra volume into two parts using a filled separating plane. + + Masks the filled plane (whose values are 1 on one side of the plane and 2 on the other) to the foreground + of ``segvert``, yielding a label map that partitions the vertebra into the two sides. + + Args: + segvert (np.ndarray): Binary vertebra volume to split. + plane_filled_nii (NII): Filled separating plane (1 above, 2 below) as produced by ``get_plane_split``. + + Returns: + np.ndarray: The vertebra voxels labeled 1 (above the plane) and 2 (below the plane); background stays 0. + """ plane_filled_arr = plane_filled_nii.get_array() # 1 above, 2 below plane_filled_arr[segvert == 0] = 0 @@ -559,13 +582,36 @@ def collect_vertebra_predictions( seg_nii: NII, model: Segmentation_Model, corpus_size_cleaning: int, - cutout_size, + cutout_size: tuple[int, int, int], debug_data: dict, proc_inst_largest_k_cc: int = 0, process_detect_and_solve_merged_corpi: bool = True, proc_inst_fill_holes: bool = False, verbose: bool = False, ) -> tuple[np.ndarray | None, list[str], int]: + """Run the instance model on a cutout around each corpus center of mass and collect per-label predictions. + + Computes corpus centers of mass, and for each one extracts a ``cutout_size`` window (nudged inferiorly until + it lands on segmentation), relabels it to the model's expected labels, runs the model, post-processes the + cutout, and stores each predicted label (1/2/3, the three-vertebra hierarchy) as a binary map placed back + into the full-volume frame. + + Args: + seg_nii (NII): Subregion semantic mask used for the cutouts. + model (Segmentation_Model): Instance model producing the three-vertebra-body cutout predictions. + corpus_size_cleaning (int): Voxel threshold for cleaning small corpus artifacts when finding coms; 0 disables. + cutout_size: Cutout window size (per axis) extracted around each center of mass. + debug_data (dict): Dictionary for collecting per-cutout intermediate results. + proc_inst_largest_k_cc (int, optional): Keep only the largest k connected components per cutout label; 0 disables. Defaults to 0. + process_detect_and_solve_merged_corpi (bool, optional): Whether to detect and split merged corpora. Defaults to True. + proc_inst_fill_holes (bool, optional): Whether to fill holes in each cutout prediction. Defaults to False. + verbose (bool, optional): Emit additional progress logging. Defaults to False. + + Returns: + tuple[np.ndarray | None, list[str], int]: A hierarchical prediction array of shape + ``(n_corpus_coms, 3, *seg_shape)``, a list of ``"comidx_label"`` identifiers for the predictions actually + produced, and the number of corpus centers of mass. Returns ``(None, [], 0)`` if no corpus is found. + """ corpus_coms = get_corpus_coms( seg_nii, corpus_size_cleaning=corpus_size_cleaning, @@ -588,7 +634,9 @@ def collect_vertebra_predictions( seg_nii.shape[2], ) hierarchical_existing_predictions = [] - hierarchical_predictions = np.zeros((n_corpus_coms, 3, *shp), dtype=seg_nii.dtype) + # Holds only binary {0, 1} per-label masks, so uint8 is sufficient (the source dtype can be wider, + # which would needlessly inflate this n_coms x 3 x volume array and slow the Dice comparisons below). + hierarchical_predictions = np.zeros((n_corpus_coms, 3, *shp), dtype=np.uint8) # print("hierarchical_predictions", hierarchical_predictions.shape) vert_predict_template = np.zeros(shp, dtype=np.uint16) # print("vert_predict_template", vert_predict_template.shape) @@ -601,9 +649,10 @@ def collect_vertebra_predictions( logger.print("Vertebra collect in", seg_nii.zoom, seg_nii.orientation, seg_nii.shape, verbose=verbose) + # seg_nii_for_cut is constant across the loop; read its array once instead of copying it per centroid. + seg_arr_c = seg_nii_for_cut.get_seg_array() # iterate over sorted coms and segment vertebra from subreg for com_idx, com in enumerate(tqdm(corpus_coms, desc=logger._get_logger_prefix() + " Vertebra Body predictions")): - seg_arr_c = seg_nii_for_cut.get_seg_array() # Shift the com until there is a segmentation there (to account for mishaps in the com calculation) seg_at_com = seg_arr_c[int(com[0])][int(com[1])][int(com[2])] != 0 orig_com = (com[0], com[1], com[2]) @@ -673,7 +722,18 @@ def post_process_single_3vert_prediction( labels: list[int] | None = None, largest_cc: int = 0, fill_holes: bool = False, -): +) -> NII: + """Post-process a single three-vertebra cutout prediction by filtering components and filling holes. + + Args: + vert_nii (NII): The cutout prediction mask to clean. + labels (list[int] | None, optional): Labels to restrict the connected-component filter to. Defaults to None (all labels). + largest_cc (int, optional): Keep only the largest ``largest_cc`` connected components per label; 0 disables. Defaults to 0. + fill_holes (bool, optional): Whether to fill holes in the prediction. Defaults to False. + + Returns: + NII: The post-processed cutout prediction mask. + """ if largest_cc > 0: # 5 seems like a good number (if three, then at least center must be fully visible) vert_nii = vert_nii.filter_connected_components(max_count_component=largest_cc, labels=labels, keep_label=True) if fill_holes: @@ -682,7 +742,16 @@ def post_process_single_3vert_prediction( return vert_nii -def str_id_com_label(com_idx: int, label: int): +def str_id_com_label(com_idx: int, label: int) -> str: + """Build the string identifier for a single (corpus-com, label) prediction. + + Args: + com_idx (int): Index of the corpus center of mass. + label (int): Label index within that center's three-vertebra prediction. + + Returns: + str: The identifier ``"{com_idx}_{label}"``. + """ return str(com_idx) + "_" + str(label) @@ -695,6 +764,24 @@ def from_vert3_predictions_make_vert_mask( proc_inst_clean_small_cc_artifacts: bool = True, verbose: bool = False, ) -> tuple[NII, dict, ErrCode]: + """Merge the hierarchical three-vertebra predictions into a single vertebra instance mask. + + Each per-label prediction looks among neighboring predictions (center index -2 to +2, all three labels) for + its most-agreeing partners (by Dice), forming prediction couples. The couples are then merged into one + instance label map, optionally cleaning small connected-component artifacts. + + Args: + seg_nii (NII): Reference segmentation providing shape and spatial metadata. + vert_predictions (np.ndarray): Hierarchical predictions of shape ``(com_idx, label, *shape)``. + hierarchical_existing_predictions (list[str]): Identifiers of the predictions that were actually produced. + vert_size_threshold (int): Voxel threshold for removing small instance artifacts. + debug_data (dict): Dictionary for collecting intermediate results. + proc_inst_clean_small_cc_artifacts (bool, optional): Whether to delete small instance artifacts. Defaults to True. + verbose (bool, optional): Emit additional progress logging. Defaults to False. + + Returns: + tuple[NII, dict, ErrCode]: The merged vertebra instance mask, the (updated) debug data, and an error code. + """ # instance approach: each 1/2/3 pred finds it most agreeing partner in surrounding predictions (com idx -2 to +2 all three pred) # Then sort by agreement, and segment each (this would be able to add more vertebra than input coms if one is skipped) # each one that has been used for fixing a segmentation cannot be used again (so object loose their partners if too weird) @@ -722,17 +809,32 @@ def create_prediction_couples( hierarchical_predictions: np.ndarray, hierarchical_existing_predictions, verbose: bool = False, -): +) -> dict: + """Form and rank prediction couples across all hierarchical predictions. + + For every (center index, label) prediction, finds its best-agreeing partners and groups them into a couple, + averaging the agreement scores of duplicate couples. The result is sorted so that larger, higher-agreement + couples come first (key = ``(len(couple) + 1) * mean_agreement``). + + Args: + hierarchical_predictions (np.ndarray): Hierarchical predictions of shape ``(com_idx, label, *shape)``. + hierarchical_existing_predictions (list[str]): Identifiers of the predictions that were actually produced. + verbose (bool, optional): Emit additional progress logging. Defaults to False. + + Returns: + dict: Mapping from each couple (a tuple of ``(com_idx, label)`` members) to its mean agreement score, + ordered by descending size-weighted agreement. + """ n_predictions = hierarchical_predictions.shape[0] + # Set for O(1) membership in the inner candidate search (called 3 * n_predictions times). + existing_predictions = set(hierarchical_existing_predictions) coupled_predictions = {} # TODO try to calculate list of candidates here, take the predictions and then parallelize the find_prediction_couple for idx in range(n_predictions): for pred in range(3): - couple, agreement = find_prediction_couple( - idx, pred, hierarchical_predictions, hierarchical_existing_predictions, n_predictions, verbose - ) + couple, agreement = find_prediction_couple(idx, pred, hierarchical_predictions, existing_predictions, n_predictions, verbose) if couple is None: continue if couple not in coupled_predictions: @@ -750,7 +852,17 @@ def create_prediction_couples( return coupled_predictions -def parallel_dice(anchor, pred, cand_loc): +def parallel_dice(anchor, pred, cand_loc: tuple) -> tuple[float, tuple]: + """Compute the Dice score between two masks, tagged with a candidate location. + + Args: + anchor (np.ndarray): Anchor prediction mask. + pred (np.ndarray): Candidate prediction mask to compare against. + cand_loc: Candidate location identifier carried through unchanged. + + Returns: + tuple[float, Any]: The Dice score between ``anchor`` and ``pred`` and the passed-through ``cand_loc``. + """ return float(np_dice(anchor, pred)), cand_loc @@ -761,7 +873,25 @@ def find_prediction_couple( hierarchical_existing_predictions, n_predictions, verbose: bool = False, -): +) -> tuple[tuple | None, float]: + """Find the best-agreeing partner predictions for one anchor prediction. + + Considers candidate predictions within +/-2 of the anchor's center index (all three labels, excluding the + anchor itself), ranks them by Dice with the anchor, and keeps up to the two best whose Dice exceeds 0.3. + The anchor itself is appended, and the members are returned sorted by center index. + + Args: + idx (int): Center-of-mass index of the anchor prediction. + pred (int): Label index of the anchor prediction. + hierarchical_predictions (np.ndarray): Hierarchical predictions of shape ``(com_idx, label, *shape)``. + hierarchical_existing_predictions (list[str]): Identifiers of the predictions that were actually produced. + n_predictions (int): Total number of corpus centers of mass. + verbose (bool, optional): Emit additional progress logging. Defaults to False. + + Returns: + tuple[tuple | None, float]: The couple (a sorted tuple of ``(com_idx, label)`` members including the + anchor) and its mean partner agreement. Returns ``(None, 0)`` if the anchor prediction does not exist. + """ if str_id_com_label(idx, pred) not in hierarchical_existing_predictions: logger.print(f"{str_id_com_label(idx, pred)} not in predictions {hierarchical_existing_predictions}", verbose=verbose) return None, 0 @@ -826,6 +956,26 @@ def merge_coupled_predictions( vert_size_threshold: int = 0, verbose: bool = False, ) -> tuple[NII, dict, ErrCode]: + """Assemble the final vertebra instance mask from ranked prediction couples. + + Iterates over the couples in priority order, summing their member maps and thresholding by voxel agreement + (requiring overlap from at least two members unless the couple is small or low-agreement). Each accepted + couple is written as a new instance label into voxels not yet claimed; couples overlapping established + vertebrae by more than 60% are skipped. Small connected-component artifacts are optionally cleaned afterwards. + + Args: + seg_nii (NII): Reference segmentation providing shape and spatial metadata. + coupled_predictions (dict): Mapping from couple to mean agreement, ordered by priority. + hierarchical_predictions (np.ndarray): Hierarchical predictions of shape ``(com_idx, label, *shape)``. + debug_data (dict): Dictionary for collecting intermediate results. + proc_clean_small_cc_artifacts (bool, optional): Whether to delete small instance artifacts. Defaults to True. + vert_size_threshold (int, optional): Voxel threshold for removing small instance artifacts. Defaults to 0. + verbose (bool, optional): Emit additional progress logging. Defaults to False. + + Returns: + tuple[NII, dict, ErrCode]: The vertebra instance mask, the (updated) debug data, and an error code + (``ErrCode.OK`` on success, ``ErrCode.EMPTY`` if a couple produces an empty mask or the result is empty). + """ whole_vert_nii = seg_nii.copy() whole_vert_arr = np.zeros(whole_vert_nii.shape, dtype=np.uint16) # this is fixed segmentations from vert diff --git a/spineps/phase_labeling.py b/spineps/phase_labeling.py index 4b6bc9c..0223ff6 100644 --- a/spineps/phase_labeling.py +++ b/spineps/phase_labeling.py @@ -1,3 +1,5 @@ +"""Vertebra-labeling phase: turns top-to-bottom vertebra instances into anatomical vertebra labels via a classifier and path search.""" + from __future__ import annotations import numpy as np @@ -15,7 +17,13 @@ ) from spineps.get_models import get_actual_model from spineps.lab_model import VertLabelingClassifier -from spineps.utils.find_min_cost_path import find_most_probably_sequence +from spineps.utils.find_min_cost_path import ( + DEFAULT_REGION_STARTS, + L5_CLASS_IDX, + T11_CLASS_IDX, + T12_CLASS_IDX, + find_most_probably_sequence, +) logger = No_Logger(prefix="LabelingPhase") @@ -26,6 +34,16 @@ DIVIDE_BY_ZERO_OFFSET = 1e-8 +# Cost-matrix class indices (0-based, matching VertExact) of anatomically special vertebrae. +# T11/T12/L5 and the region starts (DEFAULT_REGION_STARTS) are imported from find_min_cost_path, +# their canonical home (the path solver that consumes them). +C1_CLASS_IDX = 0 +C2_CLASS_IDX = 1 +# Post-processing label for the (anomalous) T13 vertebra; it has no VertExact class. +T13_LABEL = 28 +# Crop margin in millimeters kept around the vertebrae before labeling. +LABELING_CROP_MARGIN_MM = 128 + def perform_labeling_step( model: VertLabelingClassifier, @@ -34,7 +52,24 @@ def perform_labeling_step( subreg_nii: NII | None = None, proc_lab_force_no_tl_anomaly: bool = False, disable_c1: bool = True, -): +) -> NII: + """Assign anatomical vertebra labels to a vertebra instance mask using the labeling classifier. + + Runs the labeling classifier on each vertebra instance, derives a globally consistent label sequence, and relabels the + instance mask accordingly. If a subregion mask is given, the classifier only sees the vertebra corpus (not the whole vertebra). + Optionally adds a missing C1 label and zeroes out any instances that could not be matched. + + Args: + model (VertLabelingClassifier): Classifier used to predict per-instance vertebra labels. + img_nii (NII): Input MRI image. + vert_nii (NII): Vertebra instance segmentation mask to be relabeled. + subreg_nii (NII | None): Subregion semantic mask; if given, vertebrae are masked to their corpus before classification. + proc_lab_force_no_tl_anomaly (bool): If True, disallow thoracolumbar (T13) transitional-vertebra anomalies. + disable_c1 (bool): If True, do not predict/add a C1 label. + + Returns: + NII: The vertebra instance mask relabeled with anatomical vertebra labels (unmatched instances set to 0). + """ if model.predictor is None: model.load() @@ -76,14 +111,36 @@ def run_model_for_vert_labeling( verbose: bool = False, proc_lab_force_no_tl_anomaly: bool = False, disable_c1: bool = True, -): +) -> tuple[dict[int, int], float, list[int], list[int], list, list, dict]: + """Run the labeling classifier over a whole image/instance pair and resolve a vertebra label sequence. + + Reorients, crops around the vertebrae, rescales to the model's recommended zoom, runs the classifier on every vertebra + instance, and uses the cheapest-cost path search to turn per-instance predictions into a consistent anatomical sequence. + + Args: + model (VertLabelingClassifier): Classifier used to predict per-instance vertebra labels. + img_nii (NII): Input MRI image. + vert_nii (NII): Vertebra instance segmentation mask. + verbose (bool): If True, print intermediate weighting/path information. + proc_lab_force_no_tl_anomaly (bool): If True, disallow thoracolumbar (T13) transitional-vertebra anomalies. + disable_c1 (bool): If True, do not predict a C1 label. + + Returns: + tuple: ``(labelmap, fcost, fpath, fpath_post, costlist, min_costs_path, predictions)`` where ``labelmap`` maps each + original instance label to its assigned vertebra label, ``fcost`` is the total path cost, ``fpath``/``fpath_post`` + are the raw and post-processed label sequences, ``costlist`` is the cost matrix as a list, ``min_costs_path`` is the + per-step minimum cost path, and ``predictions`` are the raw classifier outputs. + + Raises: + AssertionError: If the number of original instances does not match the resolved path length. + """ # reorient img = img_nii.reorient(model.inference_config.model_expected_orientation, verbose=False) vert = vert_nii.reorient(model.inference_config.model_expected_orientation, verbose=False) zms_pir = img.zoom # crop - crop = vert.compute_crop(dist=128 / min(img.zoom)) + crop = vert.compute_crop(dist=LABELING_CROP_MARGIN_MM / min(img.zoom)) img.apply_crop_(crop) vert.apply_crop_(crop) @@ -118,7 +175,27 @@ def run_model_for_vert_labeling_cutouts( boost_c2: float = 3.0, allow_cervical_skip: bool = True, verbose: bool = True, -): +) -> tuple[dict[int, int], float, list[int], list[int], list, list, dict]: + """Run the labeling classifier on precomputed per-instance image cutouts and resolve a vertebra label sequence. + + Like :func:`run_model_for_vert_labeling`, but skips reorienting/cropping/rescaling and instead consumes already-prepared + image arrays keyed by instance label. + + Args: + model (VertLabelingClassifier): Classifier used to predict per-instance vertebra labels. + img_arrays (dict[int, np.ndarray]): Mapping of vertebra instance label to its cropped image array. + disable_c1 (bool): If True, do not predict a C1 label. + boost_c2 (float): Multiplicative boost applied to a prediction whose argmax is C2. + allow_cervical_skip (bool): If True, allow the path search to skip a class within the cervical region. + verbose (bool): If True, print intermediate weighting/path information. + + Returns: + tuple: ``(labelmap, fcost, fpath, fpath_post, costlist, min_costs_path, predictions)`` (see + :func:`run_model_for_vert_labeling`). + + Raises: + AssertionError: If the number of input arrays does not match the resolved path length. + """ # reorient # img = img_nii.reorient(model.inference_config.model_expected_orientation, verbose=False) # vert = vert_nii.reorient(model.inference_config.model_expected_orientation, verbose=False) @@ -147,7 +224,15 @@ def run_model_for_vert_labeling_cutouts( return labelmap, fcost, fpath, fpath_post, costlist, min_costs_path, predictions -def region_to_vert(region_softmax_values: np.ndarray): # shape(1,3) +def region_to_vert(region_softmax_values: np.ndarray) -> np.ndarray: # shape(1,3) + """Broadcast a 3-region (cervical, thoracic, lumbar) softmax into a per-vertebra-class vector. + + Args: + region_softmax_values (np.ndarray): Length-3 region softmax values ordered cervical, thoracic, lumbar. + + Returns: + np.ndarray: Length-``VERT_CLASSES`` vector with each region's value broadcast across that region's vertebra classes. + """ vert_prediction_values = np.zeros(VERT_CLASSES) vert_prediction_values[CERV] = region_softmax_values[0] vert_prediction_values[THOR] = region_softmax_values[1] @@ -160,7 +245,20 @@ def prepare_vert( gaussian_sigma: float = 0.85, gaussian_radius: int = 2, gaussian_regionwise: bool = True, -): +) -> np.ndarray: + """Smooth and normalize a per-vertebra-class softmax vector. + + Optionally applies a 1-D Gaussian filter (either per spinal region or across all classes) and then normalizes to sum to 1. + + Args: + vert_softmax_values (np.ndarray): Length-``VERT_CLASSES`` per-class softmax values. + gaussian_sigma (float): Gaussian smoothing sigma; 0 disables smoothing. + gaussian_radius (int): Half-width of the Gaussian kernel. + gaussian_regionwise (bool): If True, smooth each spinal region independently instead of across the whole vector. + + Returns: + np.ndarray: The smoothed, sum-normalized per-class vector. + """ # gaussian region-wise softmax_values = vert_softmax_values.copy() if gaussian_sigma > 0.0: @@ -178,7 +276,21 @@ def prepare_vertgrp( gaussian_sigma: float = 0.85, gaussian_radius: int = 2, gaussian_regionwise: bool = True, -): +) -> np.ndarray: + """Expand a vertebra-group softmax to per-vertebra classes, then smooth and normalize it. + + Distributes each vertebra-group probability onto its member vertebra classes (via ``vert_group_idx_to_exact_idx_dict``), + optionally applies a 1-D Gaussian filter (per region or globally), and normalizes to sum to 1. + + Args: + vertgrp_softmax_values (np.ndarray): Per-vertebra-group softmax values. + gaussian_sigma (float): Gaussian smoothing sigma; 0 disables smoothing. + gaussian_radius (int): Half-width of the Gaussian kernel. + gaussian_regionwise (bool): If True, smooth each spinal region independently instead of across the whole vector. + + Returns: + np.ndarray: The expanded, smoothed, sum-normalized per-class vector. + """ # gaussian region-wise softmax_values = np.zeros(VERT_CLASSES) for i, g in vert_group_idx_to_exact_idx_dict.items(): @@ -193,11 +305,26 @@ def prepare_vertgrp( return softmax_values -def prepare_visible(predictions: dict, visible_w: float = 1.0, gaussian_sigma: float = 0.8, gaussian_radius: int = 2): +def prepare_visible(predictions: dict, visible_w: float = 1.0, gaussian_sigma: float = 0.8, gaussian_radius: int = 2) -> np.ndarray: + """Build a per-instance confidence-weighting chain from the classifier's "fully visible" head. + + For each instance, reads the probability of being fully visible (if the ``FULLYVISIBLE`` head is present, else assumes 1), + optionally Gaussian-smooths it along the instance axis, and converts it into a multiplicative weight in ``[0, 1]`` that + down-weights partially visible (cropped) vertebrae according to ``visible_w``. + + Args: + predictions (dict): Per-instance classifier outputs, each holding a ``"soft"`` dict of head softmax arrays. + visible_w (float): Strength of the visibility down-weighting; 0 disables it. + gaussian_sigma (float): Gaussian smoothing sigma along the instance axis; 0 disables smoothing. + gaussian_radius (int): Half-width of the Gaussian kernel. + + Returns: + np.ndarray: Per-instance multiplicative weights clipped to ``[0, 1]``. + """ # has soft and FULLYVISIBLE key predict_keys = list(predictions[list(predictions.keys())[0]]["soft"].keys()) # noqa: RUF015 if "FULLYVISIBLE" in predict_keys: - visible_chain = np.asarray([k["soft"]["FULLYVISIBLE"][1] for v, k in predictions.items()]) + visible_chain = np.asarray([k["soft"]["FULLYVISIBLE"][1] for k in predictions.values()]) else: visible_chain = np.ones(len(predictions)) if gaussian_sigma > 0.0: @@ -212,7 +339,17 @@ def prepare_visible(predictions: dict, visible_w: float = 1.0, gaussian_sigma: f return visible_chain -def prepare_region(region_softmax_values: np.ndarray, gaussian_sigma: float = 0.75, gaussian_radius: int = 1): +def prepare_region(region_softmax_values: np.ndarray, gaussian_sigma: float = 0.75, gaussian_radius: int = 1) -> np.ndarray: + """Broadcast a region softmax to per-vertebra classes, then smooth and normalize it. + + Args: + region_softmax_values (np.ndarray): Length-3 region softmax values (cervical, thoracic, lumbar). + gaussian_sigma (float): Gaussian smoothing sigma; 0 disables smoothing. + gaussian_radius (int): Half-width of the Gaussian kernel. + + Returns: + np.ndarray: The broadcast, smoothed, sum-normalized per-class vector. + """ softmax_values = region_to_vert(region_softmax_values) if gaussian_sigma > 0.0 and np.sum(softmax_values) > 0.0: softmax_values = gaussian_filter1d(softmax_values, sigma=gaussian_sigma, mode="nearest", radius=gaussian_radius) @@ -220,7 +357,20 @@ def prepare_region(region_softmax_values: np.ndarray, gaussian_sigma: float = 0. return softmax_values -def prepare_vertrel_columns(vertrel_matrix: np.ndarray, gaussian_sigma: float = 0.75, gaussian_radius: int = 1): +def prepare_vertrel_columns(vertrel_matrix: np.ndarray, gaussian_sigma: float = 0.75, gaussian_radius: int = 1) -> np.ndarray: + """Smooth and column-normalize the relative-position (VertRel) cost matrix. + + For each VertRel label (column, skipping the first), optionally Gaussian-smooths the values along the instance axis and + normalizes the column so its values stay bounded (divides by the column sum when it exceeds 1, otherwise by ``1 + sum``). + + Args: + vertrel_matrix (np.ndarray): Matrix of shape ``(n_instances, len(VertRel))`` of relative-position softmax values. + gaussian_sigma (float): Gaussian smoothing sigma along the instance axis; 0 disables smoothing. + gaussian_radius (int): Half-width of the Gaussian kernel. + + Returns: + np.ndarray: The smoothed, column-normalized relative-position matrix (modified in place and returned). + """ for i in range(1, min(len(VertRel), vertrel_matrix.shape[1])): if gaussian_sigma > 0.0 and np.sum(vertrel_matrix) > 0.0: vertrel_matrix[:, i] = gaussian_filter1d(vertrel_matrix[:, i], sigma=gaussian_sigma, mode="nearest", radius=gaussian_radius) @@ -233,14 +383,34 @@ def prepare_vertrel_columns(vertrel_matrix: np.ndarray, gaussian_sigma: float = return vertrel_matrix -def prepare_vertt13_columns(vertt13_matrix: np.ndarray): +def prepare_vertt13_columns(vertt13_matrix: np.ndarray) -> np.ndarray: + """Column-normalize the T13-anomaly (VertT13) cost matrix. + + Normalizes each VertT13 label (column, skipping the first) so it sums to 1 along the instance axis. + + Args: + vertt13_matrix (np.ndarray): Matrix of shape ``(n_instances, len(VertT13))`` of T13-anomaly softmax values. + + Returns: + np.ndarray: The column-normalized matrix (modified in place and returned). + """ for i in range(1, min(len(VertT13), vertt13_matrix.shape[1])): # normalize per column / label in this case vertt13_matrix[:, i] = vertt13_matrix[:, i] / (np.sum(vertt13_matrix[:, i]) + DIVIDE_BY_ZERO_OFFSET) return vertt13_matrix -def prepare_vertrel(vertrel_softmax_values: np.ndarray, gaussian_sigma: float = 0.75, gaussian_radius: int = 1): +def prepare_vertrel(vertrel_softmax_values: np.ndarray, gaussian_sigma: float = 0.75, gaussian_radius: int = 1) -> np.ndarray: + """Optionally Gaussian-smooth a relative-position (VertRel) softmax vector. + + Args: + vertrel_softmax_values (np.ndarray): Relative-position softmax values for a single instance. + gaussian_sigma (float): Gaussian smoothing sigma; 0 disables smoothing. + gaussian_radius (int): Half-width of the Gaussian kernel. + + Returns: + np.ndarray: The (optionally smoothed) relative-position vector; not re-normalized. + """ softmax_values = vertrel_softmax_values.copy() if gaussian_sigma > 0.0: softmax_values = gaussian_filter1d(softmax_values, sigma=gaussian_sigma, mode="nearest", radius=gaussian_radius) @@ -279,7 +449,55 @@ def find_vert_path_from_predictions( proc_lab_force_no_tl_anomaly: bool = False, # verbose: bool = False, -): +) -> tuple[float, list[int], list[int], list, list, dict]: + """Combine the classifier's prediction heads into a cost matrix and solve for the most probable vertebra label sequence. + + Builds a per-instance / per-class cost matrix by weighting and summing the available prediction heads (VERT, VERTGRP, + REGION), down-weighting by the "fully visible" chain, optionally boosting C2, and adding separate relative-position + (VertRel) and T13-anomaly (VertT13) cost terms. The cheapest monotonically increasing label path is then found with + :func:`find_most_probably_sequence` (unless ``argmax_combined_cost_matrix_instead_of_path_algorithm`` is set, which falls + back to a plain per-instance argmax). Special transitional vertebrae (T11 skip, T12/L5 repeats) and per-region skips are + permitted via the corresponding flags. Finally the path is post-processed (see :func:`fpath_post_processing`). + + Args: + predictions (dict): Per-instance classifier outputs, each holding a ``"soft"`` dict of per-head softmax arrays. + visible_w (float): Weight of the "fully visible" down-weighting (must be in ``[0, 1]``). + vert_w (float): Weight of the per-vertebra (VERT) head. + vertgrp_w (float): Weight of the vertebra-group (VERTGRP) head. + region_w (float): Weight of the spinal-region (REGION) head. + vertrel_w (float): Weight of the relative-position (VERTREL) cost term. + vertt13_w (float): Weight of the T13-anomaly (VERTT13) cost term. + disable_c1 (bool): If True, the path may not start at C1 (starts at C2 instead). + boost_c2 (float): Multiplicative boost applied to a prediction whose argmax is C2; 0 disables it. + allow_cervical_skip (bool): If True, allow skipping a class within the cervical region. + allow_thoracic_skip (bool): If True, allow skipping a class within the thoracic region. + allow_lumbar_skip (bool): If True, allow skipping a class within the lumbar region. + punish_multiple_sequence (float): Extra cost for repeating an allowed-multiple class. + punish_skip_sequence (float): Extra cost for skipping an allowed-skip class. + punish_skip_at_region_sequence (float): Extra cost for skipping at a region boundary. + region_gaussian_sigma (float): Gaussian sigma for the region head; 0 disables smoothing. + vert_gaussian_sigma (float): Gaussian sigma for the vertebra head; 0 disables smoothing. + vert_gaussian_regionwise (bool): If True, smooth the vertebra head per region. + vertgrp_gaussian_sigma (float): Gaussian sigma for the vertebra-group head; 0 disables smoothing. + vertgrp_gaussian_regionwise (bool): If True, smooth the vertebra-group head per region. + vertrel_column_norm (bool): If True, column-normalize the relative-position matrix. + vertrel_gaussian_sigma (float): Gaussian sigma used when column-normalizing the relative-position matrix. + focus_tl_gap (bool): If True, focus on the T11/T13 thoracolumbar gap (reserved for the refinement pass). + argmax_combined_cost_matrix_instead_of_path_algorithm (bool): If True, take a plain per-instance argmax instead of the + path search. + proc_lab_force_no_tl_anomaly (bool): If True, disallow T13 transitional-vertebra anomalies (no T11 skip / no T12 repeat). + verbose (bool): If True, print the active head weights. + + Returns: + tuple: ``(fcost, fpath, fpath_post, cost_matrix_list, min_costs_path, args)`` where ``fcost`` is the total path cost, + ``fpath`` is the raw class path, ``fpath_post`` is the post-processed (1-based, T13-aware) label sequence, + ``cost_matrix_list`` is the combined cost matrix as a nested list, ``min_costs_path`` is the per-step minimum cost + path, and ``args`` is a snapshot of the call arguments. + + Raises: + AssertionError: If a weight is negative, ``visible_w`` exceeds 1, or no vital classification head (VERT/VERTEXACT/ + VERTGRP) is present in the predictions. + """ args = locals() assert 0 <= visible_w, visible_w # noqa: SIM300 assert visible_w <= 1.0, f"visible_w must be <= 1.0, got {visible_w}" @@ -290,26 +508,26 @@ def find_vert_path_from_predictions( # n_vert = len(predictions) # - cost_matrix = np.zeros((n_vert, 24)) # TODO 24 fix? - relative_cost_matrix = np.zeros((n_vert, 6)) # TODO 6 fix? + cost_matrix = np.zeros((n_vert, VERT_CLASSES)) + relative_cost_matrix = np.zeros((n_vert, len(VertRel))) visible_chain = prepare_visible(predictions, visible_w) # print(visible_chain) predict_keys = list(predictions[list(predictions.keys())[0]]["soft"].keys()) # noqa: RUF015 - assert "VERT" in predict_keys or "VERTEXACT" in predict_keys or "VERTGRP" in predict_keys, ( + assert "VERT" in predict_keys or "VERTEXACT" in predict_keys or "VERTEX" in predict_keys or "VERTGRP" in predict_keys, ( f"No vital classification head found, got {predict_keys}" ) # VertRel normalize over labels if "VERTREL" in predict_keys: - vertrel_matrix = np.asarray([k["soft"]["VERTREL"] for v, k in predictions.items()]) + vertrel_matrix = np.asarray([k["soft"]["VERTREL"] for k in predictions.values()]) else: vertrel_matrix = np.zeros((n_vert, len(VertRel))) if vertrel_column_norm: vertrel_matrix = prepare_vertrel_columns(vertrel_matrix, gaussian_sigma=vertrel_gaussian_sigma) if "VERTT13" in predict_keys: - vertt13_softmax_output = np.asarray([k["soft"]["VERTT13"] for v, k in predictions.items()]) + vertt13_softmax_output = np.asarray([k["soft"]["VERTT13"] for k in predictions.values()]) else: vertt13_softmax_output = np.zeros((n_vert, len(VertT13))) vertt13_values = np.multiply( @@ -371,7 +589,7 @@ def find_vert_path_from_predictions( # normalize final_vert_pred /= np.sum(final_vert_pred) + DIVIDE_BY_ZERO_OFFSET # boost c2 if enabled - if boost_c2 > 0.0 and np.argmax(final_vert_pred) == 1: + if boost_c2 > 0.0 and np.argmax(final_vert_pred) == C2_CLASS_IDX: final_vert_pred = np.multiply(final_vert_pred, boost_c2) # then multiply with visible factor final_vert_pred = np.multiply(final_vert_pred, visible_chain[idx]) @@ -394,8 +612,8 @@ def find_vert_path_from_predictions( min_costs_path = [[]] fpath = list(np.argmax(cost_matrix, axis=1)) else: - allow_multiple_at_class = [18, 23] if not proc_lab_force_no_tl_anomaly else [23] # T12 and L5 - allow_skip_at_class = [17] if not proc_lab_force_no_tl_anomaly else [] # T11 + allow_multiple_at_class = [T12_CLASS_IDX, L5_CLASS_IDX] if not proc_lab_force_no_tl_anomaly else [L5_CLASS_IDX] + allow_skip_at_class = [T11_CLASS_IDX] if not proc_lab_force_no_tl_anomaly else [] allow_skip_at_region = [] if allow_cervical_skip: allow_skip_at_region.append(0) @@ -406,7 +624,7 @@ def find_vert_path_from_predictions( fcost, fpath, min_costs_path = find_most_probably_sequence( # input cost_matrix, - min_start_class=0 if not disable_c1 else 1, + min_start_class=C1_CLASS_IDX if not disable_c1 else C2_CLASS_IDX, region_rel_cost=relative_cost_matrix, vertt13_cost=vertt13_values, invert_cost=True, @@ -414,9 +632,9 @@ def find_vert_path_from_predictions( punish_multiple_sequence=punish_multiple_sequence, punish_skip_sequence=punish_skip_sequence, # no touch - regions=[0, 7, 19], - allow_multiple_at_class=allow_multiple_at_class, # T12 and L5 - allow_skip_at_class=allow_skip_at_class, # T11 + regions=list(DEFAULT_REGION_STARTS), + allow_multiple_at_class=allow_multiple_at_class, + allow_skip_at_class=allow_skip_at_class, # allow_skip_at_region=allow_skip_at_region, punish_skip_at_region_sequence=punish_skip_at_region_sequence, @@ -428,26 +646,49 @@ def find_vert_path_from_predictions( def fpath_post_processing(fpath) -> list[int]: + """Post-process a raw 0-based class path into the final 1-based vertebra label sequence. + + Resolves transitional-vertebra anomalies (two consecutive T12 become T12 + T13; a trailing double L5 becomes L5 + L6) and + shifts every class index by 1 to the final label convention, leaving the special T13 label untouched. + + Args: + fpath (list[int]): Raw 0-based class path from the cost/path search. + + Returns: + list[int]: The post-processed 1-based vertebra label sequence (with T13/L6 anomalies applied). + """ fpath_post = fpath[:] # Two T12 -> T12 + T13 if VertExact.T12.value in fpath_post: tidx = fpath_post.index(VertExact.T12.value) if tidx != 0 and fpath_post[tidx - 1] == VertExact.T12.value: - fpath_post[tidx] = 28 + fpath_post[tidx] = T13_LABEL elif tidx != len(fpath_post) - 1 and fpath_post[tidx + 1] == VertExact.T12.value: - fpath_post[tidx + 1] = 28 + fpath_post[tidx + 1] = T13_LABEL # Two L5 -> L5, L6 if (VertExact.L5.value in fpath_post and len(fpath_post) >= 2) and ( fpath_post[-1] == VertExact.L5.value and fpath_post[-2] == VertExact.L5.value ): fpath_post[-1] += 1 - fpath_post = [f + 1 if f != 28 else 28 for f in fpath_post] + fpath_post = [f + 1 if f != T13_LABEL else T13_LABEL for f in fpath_post] return fpath_post def is_valid_vertebra_sequence(sequence: list[VertExact] | list[int]) -> bool: + """Check whether a vertebra label sequence is anatomically contiguous top-to-bottom. + + A sequence is valid if each label follows the previous one by exactly 1, or forms one of the allowed transitional jumps at + the thoracolumbar junction (T13->L1, i.e. 28->20, and T12->L1, i.e. 18->20). ``VertExact`` inputs are first converted via + :func:`fpath_post_processing`. + + Args: + sequence (list[VertExact] | list[int]): The vertebra label sequence, either as ``VertExact`` enums or 1-based ints. + + Returns: + bool: True if the sequence is a valid contiguous vertebra run, otherwise False. + """ sequence2: list[int] = fpath_post_processing([s.value for s in sequence]) if isinstance(sequence[0], VertExact) else sequence # type: ignore # must be sequence of vertebrae for i in range(1, len(sequence2)): diff --git a/spineps/phase_post.py b/spineps/phase_post.py index 3223de7..71f48af 100644 --- a/spineps/phase_post.py +++ b/spineps/phase_post.py @@ -1,3 +1,5 @@ +"""Post-processing phase: clean and reconcile the semantic and vertebra-instance masks and attach IVD/endplate instance labels.""" + from __future__ import annotations # from utils.predictor import nnUNetPredictor @@ -21,9 +23,32 @@ ) from spineps.phase_labeling import VertLabelingClassifier, perform_labeling_step -from spineps.seg_pipeline import logger, vertebra_subreg_labels +from spineps.seg_pipeline import ENDPLATE_LABEL_OFFSET, IVD_LABEL_OFFSET, logger, vertebra_subreg_labels from spineps.utils.compat import zip_strict from spineps.utils.proc_functions import fix_wrong_posterior_instance_label +from spineps.utils.resolution import REFERENCE_VOXEL_VOLUME_MM3, REFERENCE_ZOOM, isotropic_area_to_voxels + +# --- Label-id conventions for combined post-processing --- +# Intervertebral discs (IVDs) and vertebral endplates reuse their parent vertebra's instance label, +# shifted by IVD_LABEL_OFFSET / ENDPLATE_LABEL_OFFSET (imported from seg_pipeline, their canonical home). +# The dens (odontoid process) anatomically belongs to the C2 vertebra (instance label 2). +C2_INSTANCE_LABEL = 2 +# Raw vertebra instance labels stay below this bound; anything above is an IVD/endplate/derived label. +INSTANCE_LABEL_LIMIT = 40 + +# --- Heuristic thresholds for combined post-processing --- +# Physical margin kept around the segmentation when cropping before processing. +POSTPROCESS_CROP_MARGIN_MM = 2 * min(REFERENCE_ZOOM) +# Warn when a vertebra's unmatched semantic volume exceeds this fraction of an average vertebra. +UNMATCHED_VOLUME_WARN_FRACTION = 0.5 +# Endplate splitting dilates iteratively with radius 1 up to (but excluding) this value. +MAX_ENDPLATE_DILATION = 15 +# Two stacked vertebrae are merged only if the smaller is below this fraction of the larger... +MERGED_VERTEBRA_SIZE_RATIO = 0.5 +# ...and the two masks share at least this much contact area (orientation-agnostic). +MERGED_VERTEBRA_MIN_CONTACT_MM2 = 20 * REFERENCE_VOXEL_VOLUME_MM3 ** (2.0 / 3.0) +# An articular substructure CC is reassigned when its largest overlap dominates the second by this ratio. +ARTICULAR_DOMINANCE_RATIO = 0.5 def phase_postprocess_combined( @@ -45,6 +70,37 @@ def phase_postprocess_combined( disable_c1=True, sacrum_ids=(v_name2idx["S1"],), ) -> tuple[NII, NII]: + """Run the combined semantic/instance post-processing pipeline and return cleaned, anatomically labeled masks. + + Crops both masks to the segmentation, optionally fixes superior/inferior articular inconsistencies, reconciles the + instance and semantic masks (reassigning or deleting unmatched connected components), splits accidentally merged vertebrae, + fixes mislabeled posterior elements, labels instances top-to-bottom, optionally runs the anatomical labeling classifier, + forces the sacrum and dens labels, attaches IVD and endplate instance labels (splitting endplates into superior/inferior), + and finally un-crops back to the original field of view. + + Args: + img_nii (NII): Input MRI image. + seg_nii (NII): Subregion semantic segmentation mask. + vert_nii (NII): Vertebra instance segmentation mask. + model_labeling (VertLabelingClassifier | None): Anatomical labeling classifier; if None, instances keep their + top-to-bottom labels. + debug_data (dict | None): Optional dict that intermediate results are written into; created if None. + labeling_offset (int): Offset added to the top-to-bottom instance labels. + proc_lab_force_no_tl_anomaly (bool): If True, disallow T13 transitional-vertebra anomalies during labeling. + proc_assign_missing_cc (bool): If True, reassign semantic connected components not covered by the instance mask. + proc_assign_missing_cc_fast (bool): If True, use the faster infect-based missing-CC assignment. + proc_clean_inst_by_sem (bool): If True, mask the instance mask by the semantic mask before processing. + n_vert_bodies (int | None): Number of vertebra bodies; inferred from the instance mask if None. + process_merge_vertebra (bool): If True, detect and merge accidentally split adjacent vertebrae. + proc_vertebra_inconsistency (bool): If True, reassign inconsistent articular substructures by instance overlap. + proc_assign_posterior_instance_label (bool): If True, fix wrongly labeled posterior instance elements. + verbose (bool): If True, print verbose progress. + disable_c1 (bool): If True (and ``labeling_offset >= 1``), do not predict a C1 label. + sacrum_ids (tuple): Semantic label id(s) treated as sacrum and mapped to the S1 instance label. + + Returns: + tuple[NII, NII]: The cleaned ``(seg_uncropped, vert_uncropped)`` semantic and vertebra-instance masks. + """ logger.print("Post process", Log_Type.STAGE) with logger: img_nii.assert_affine(other=seg_nii) @@ -59,7 +115,7 @@ def phase_postprocess_combined( if proc_clean_inst_by_sem: vert_nii.apply_mask(seg_nii, inplace=True) - crop_slices = seg_nii.compute_crop(dist=2) + crop_slices = seg_nii.compute_crop(dist=POSTPROCESS_CROP_MARGIN_MM / min(seg_nii.zoom)) # Save uncropped to uncrop later vert_uncropped = vert_nii.copy() @@ -110,7 +166,7 @@ def phase_postprocess_combined( logger.print("seg_nii", seg_nii_cleaned.unique()) whole_vert_nii_cleaned[seg_nii_cleaned.extract_label(sacrum_ids).get_seg_array() == 1] = v_name2idx["S1"] - whole_vert_nii_cleaned[seg_nii_cleaned == Location.Dens_axis.value] = 2 + whole_vert_nii_cleaned[seg_nii_cleaned == Location.Dens_axis.value] = C2_INSTANCE_LABEL vert_arr_cleaned, seg_arr_cleaned = add_ivd_ep_vert_label(whole_vert_nii_cleaned, seg_nii_cleaned) # # @@ -138,6 +194,27 @@ def mask_cleaning_other( proc_assign_missing_cc_fast=False, verbose: bool = False, ) -> tuple[NII, NII]: + """Reconcile the vertebra instance mask with the vertebra portion of the semantic mask. + + Extracts the vertebra subregions from the semantic mask and (optionally) reassigns semantic connected components that the + instance mask missed, either via a fast infect pass or via :func:`assign_missing_cc`; deleted components are removed from the + semantic mask. Logs a warning when the unmatched vertebra volume between the two masks is anomalously large. + + Args: + whole_vert_nii (NII): Vertebra instance segmentation mask. + seg_nii (NII): Subregion semantic segmentation mask. + n_vert_bodies (int): Number of vertebra bodies, used to scale the unmatched-volume warning. + proc_assign_missing_cc (bool): If True, reassign missed semantic components via :func:`assign_missing_cc`. + proc_assign_missing_cc_fast (bool): If True, additionally run a fast infect-based assignment first. + verbose (bool): If True, print verbose progress. + + Returns: + tuple[NII, NII]: The cleaned ``(whole_vert_nii, seg_nii)`` instance and semantic masks. + + Raises: + AssertionError: If, with ``proc_assign_missing_cc`` enabled, the instance mask still has more vertebra voxels than the + semantic mask (which should be impossible). + """ subreg_vert_nii = seg_nii.extract_label(vertebra_subreg_labels) if proc_assign_missing_cc_fast: @@ -175,7 +252,7 @@ def mask_cleaning_other( f"A volume of {n_vert_pixels_rel_diff} * avg_vertebra_volume in vertebra not matched in semantic mask, set proc_assign_missing_cc=TRUE to circumvent this", Log_Type.WARNING, ) - elif n_vert_pixels_rel_diff > 0.5: + elif n_vert_pixels_rel_diff > UNMATCHED_VOLUME_WARN_FRACTION: logger.print(f"A volume of {n_vert_pixels_rel_diff} * avg_vertebra_volume in subreg not matched by vertebra mask", Log_Type.WARNING) return whole_vert_nii.set_array(vert_arr_cleaned), seg_nii.set_array(subreg_arr) @@ -186,7 +263,29 @@ def assign_missing_cc( reference_arr: np.ndarray, verbose: bool = False, verbose_deletion: bool = False, + proc_assign_missing_dilate_first: bool = True, ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Assign reference-mask connected components that the target mask does not cover to a neighboring target label. + + Finds connected components of ``reference_arr`` (e.g. the vertebra semantic mask) that have no overlap with ``target_arr`` + (the instance mask). Optionally dilates the target mask once first to absorb thin gaps. Each remaining component is dilated + locally and assigned to the most common neighboring target label; components with no labeled neighbor are deleted from the + reference mask and recorded in a deletion map. + + Args: + target_arr (np.ndarray): Instance label array that components are assigned to. + reference_arr (np.ndarray): Reference (semantic) label array whose uncovered components are processed. + verbose (bool): If True, log each assignment. + verbose_deletion (bool): If True, log each deletion even when ``verbose`` is False. + proc_assign_missing_dilate_first (bool): If True, dilate the target mask once before searching for uncovered components. + + Returns: + tuple[np.ndarray, np.ndarray, np.ndarray]: ``(target_arr, reference_arr, deletion_map)`` with the updated instance and + reference arrays and a binary map of voxels removed from the reference mask. + + Raises: + AssertionError: If ``target_arr`` and ``reference_arr`` do not share the same shape. + """ assert target_arr.shape == reference_arr.shape deletion_map = np.zeros_like(reference_arr, dtype=np.uint8) @@ -199,6 +298,27 @@ def assign_missing_cc( logger.print("No CC had to be assigned", Log_Type.OK, verbose=verbose) return target_arr, reference_arr, deletion_map + # dilate once first + if proc_assign_missing_dilate_first: + target_arr_ = np_dilate_msk( + target_arr, + None, + n_pixel=2, + connectivity=1, + mask=reference_arr, + use_crop=False, + ) + subreg_arr_vert_rest = reference_arr.copy() + subreg_arr_vert_rest[target_arr_ != 0] = 0 + deletion_map = np.zeros(reference_arr.shape) + + label_rest = np_unique(subreg_arr_vert_rest) + if len(label_rest) == 1 and label_rest[0] == 0: + logger.print("No CC had to be assigned", Log_Type.OK, verbose=verbose) + return target_arr_, reference_arr, deletion_map + + target_arr = target_arr_ + # subreg_arr_vert_rest is not hit pixels bei vertebra prediction subreg_cc = np_connected_components_per_label(subreg_arr_vert_rest, connectivity=2) loop_counts = 0 @@ -246,7 +366,24 @@ def assign_missing_cc( return target_arr, reference_arr, deletion_map -def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII, verbose=True): +def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII, verbose=True) -> tuple[np.ndarray, np.ndarray]: + """Attach intervertebral-disc and endplate instance labels and split endplates into superior/inferior. + + Reorients both masks to PIR, computes each vertebra corpus center of mass along the inferior-superior axis, then assigns + every IVD and endplate connected component to the nearest lower vertebra. IVD voxels are written into the instance array + with ``IVD_LABEL_OFFSET`` added; endplate voxels with ``ENDPLATE_LABEL_OFFSET`` added. Endplates are further divided into + inferior/superior plates by iteratively dilating each vertebra into the endplate region. Logs the number of assigned + components and restores the original orientation before returning. + + Args: + whole_vert_nii (NII): Vertebra instance segmentation mask. + seg_nii (NII): Subregion semantic segmentation mask (must contain disc/endplate/corpus labels). + verbose (bool): If True, print endplate-detection progress. + + Returns: + tuple[np.ndarray, np.ndarray]: ``(vert_arr, seg_arr)`` arrays in the original orientation: the instance array with IVD + and endplate instance labels added, and the semantic array with endplates split into superior/inferior plates. + """ # PIR orientation = whole_vert_nii.orientation vert_t = whole_vert_nii.reorient() @@ -306,8 +443,8 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII, verbose=True): subreg_ivd = subreg_cc.copy() n_ivd_unique = len(np.unique(to_mapped_labels)) subreg_ivd = np_map_labels(subreg_ivd, label_map=mapping_cc_to_vert_label) - subreg_ivd += 100 - subreg_ivd[subreg_ivd == 100] = 0 + subreg_ivd += IVD_LABEL_OFFSET + subreg_ivd[subreg_ivd == IVD_LABEL_OFFSET] = 0 vert_arr[subreg_arr == Location.Vertebra_Disc.value] = subreg_ivd[subreg_arr == Location.Vertebra_Disc.value] n_eps = 0 n_eps_unique = 0 @@ -331,8 +468,8 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII, verbose=True): subreg_ep = ep_cc.copy() n_eps_unique = len(np.unique(list(mapping_ep_cc_to_vert_label.values()))) subreg_ep = np_map_labels(subreg_ep, label_map=mapping_ep_cc_to_vert_label) - subreg_ep += 200 - subreg_ep[subreg_ep == 200] = 0 + subreg_ep += ENDPLATE_LABEL_OFFSET + subreg_ep[subreg_ep == ENDPLATE_LABEL_OFFSET] = 0 vert_arr[subreg_arr == Location.Endplate.value] = subreg_ep[subreg_arr == Location.Endplate.value] vert_t.set_array_(vert_arr) @@ -340,10 +477,13 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII, verbose=True): out = seg_t * 0 pref = 1 old_vol = -1 - for dil in range(1, 15): + # seg_t and vert_t are not modified in this loop, so compute these invariants once. + endplate_nii = seg_t.extract_label(Location.Endplate.value) + total = endplate_nii.sum() + vert_labels_to_split = vert_t.unique() + for dil in range(1, MAX_ENDPLATE_DILATION): curr = out.extract_label([Location.Vertebral_Body_Endplate_Inferior.value, Location.Vertebral_Body_Endplate_Superior.value]) new_vol = curr.sum() - total = seg_t.extract_label(Location.Endplate.value).sum() logger.print(rf"{new_vol / total * 100:.2f}% endplates detected", end="\r") if verbose else None if old_vol == new_vol and old_vol != 0: break @@ -351,18 +491,18 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII, verbose=True): if total == new_vol: logger.print("Found all Endplates ") break - for i in vert_t.unique(): - if i >= 40: + for i in vert_labels_to_split: + if i >= INSTANCE_LABEL_LIMIT: break curr = out.extract_label([Location.Vertebral_Body_Endplate_Inferior.value, Location.Vertebral_Body_Endplate_Superior.value]) v = vert_t.extract_label(i).dilate_msk(dil, verbose=False) - end = seg_t.extract_label(Location.Endplate.value) * v + end = endplate_nii * v end *= -curr + 1 # type: ignore plates = vert_t * end plates.map_labels_( { - i + 200: Location.Vertebral_Body_Endplate_Inferior.value, - pref + 200: Location.Vertebral_Body_Endplate_Superior.value, + i + ENDPLATE_LABEL_OFFSET: Location.Vertebral_Body_Endplate_Inferior.value, + pref + ENDPLATE_LABEL_OFFSET: Location.Vertebral_Body_Endplate_Superior.value, }, verbose=False, ) @@ -381,14 +521,35 @@ def add_ivd_ep_vert_label(whole_vert_nii: NII, seg_nii: NII, verbose=True): return vert_t.set_array_(vert_arr).reorient_(orientation).get_seg_array(), seg_t.reorient_(orientation).get_seg_array() -def find_nearest_lower(seq, x): +def find_nearest_lower(seq, x) -> float: + """Return the largest element of ``seq`` strictly smaller than ``x``, or the minimum if none exists. + + Args: + seq (Sequence[float]): Values to search. + x (float): Reference value. + + Returns: + float: The greatest element below ``x``, or ``min(seq)`` when no element is below ``x``. + """ values_lower = [item for item in seq if item < x] if len(values_lower) == 0: return min(seq) return max(values_lower) -def label_instance_top_to_bottom(vert_nii: NII, labeling_offset: int = 0): +def label_instance_top_to_bottom(vert_nii: NII, labeling_offset: int = 0) -> tuple[NII, np.ndarray]: + """Relabel vertebra instances consecutively from top to bottom by center-of-mass height. + + Reorients to PIR, sorts the instances by their center of mass along the inferior-superior axis, and assigns consecutive + labels (``1 + labeling_offset`` upward) from top to bottom, then restores the original orientation. + + Args: + vert_nii (NII): Vertebra instance segmentation mask (modified in place). + labeling_offset (int): Offset added to the consecutive labels. + + Returns: + tuple[NII, np.ndarray]: The relabeled instance mask and its array of unique labels. + """ ori = vert_nii.orientation vert_nii.reorient_() vert_arr = vert_nii.get_seg_array() @@ -411,7 +572,21 @@ def assign_vertebra_inconsistency( Location.Inferior_Articular_Left, Location.Inferior_Articular_Right, ), -): +) -> None: + """Reassign articular-process components to the vertebra instance they most overlap with. + + For each given articular subregion location, finds its connected components in the semantic mask and, for each component, + reassigns its instance label to the vertebra whose overlap volume dominates the second-largest by ``ARTICULAR_DOMINANCE_RATIO``. + Updates ``vert_nii`` in place. + + Args: + vert_nii (NII): Vertebra instance segmentation mask (modified in place). + seg_nii (NII): Subregion semantic segmentation mask. + locations (tuple[Location, ...]): Articular subregion locations to reconcile. + + Returns: + None: ``vert_nii`` is modified in place. + """ seg_nii.assert_affine(shape=vert_nii.shape) seg_arr = seg_nii.get_seg_array() vert_arr = vert_nii.get_seg_array() @@ -445,7 +620,7 @@ def assign_vertebra_inconsistency( # print(biggest_volume, second_volume) - if biggest_volume[1] * 0.50 > second_volume[1]: + if biggest_volume[1] * ARTICULAR_DOMINANCE_RATIO > second_volume[1]: to_label = biggest_volume[0] - 1 # int(list(gt_volume.keys())[argmax] - 1) vert_arr[cc_map == 1] = to_label @@ -456,7 +631,20 @@ def assign_vertebra_inconsistency( vert_nii.set_array_(vert_arr) -def detect_and_solve_merged_vertebra(seg_nii: NII, vert_nii: NII): +def detect_and_solve_merged_vertebra(seg_nii: NII, vert_nii: NII) -> tuple[NII, NII]: + """Detect and merge a vertebra (typically C2) that was split into two stacked instances. + + Builds a height-sorted list of IVD components and vertebra instances. If the two topmost entries are both vertebrae, the + upper one is significantly smaller (below ``MERGED_VERTEBRA_SIZE_RATIO`` of the other), and the two masks touch over more + than ``MERGED_VERTEBRA_MIN_CONTACT_MM2`` of area, the smaller instance is merged into the larger one in ``vert_nii``. + + Args: + seg_nii (NII): Subregion semantic segmentation mask. + vert_nii (NII): Vertebra instance segmentation mask (modified in place when a merge occurs). + + Returns: + tuple[NII, NII]: The ``(seg_nii, vert_nii)`` masks (``vert_nii`` possibly with two instances merged). + """ seg_sem = seg_nii.map_labels({Location.Endplate.value: Location.Vertebra_Disc.value}, verbose=False) # get all ivd CCs from seg_sem @@ -484,14 +672,14 @@ def detect_and_solve_merged_vertebra(seg_nii: NII, vert_nii: NII): first_stats, second_stats = stats_by_height[first_key], stats_by_height[second_key] if first_stats[1] is False and second_stats[1] is False: # noqa: SIM102 # both vertebra - if first_stats[2] < 0.5 * second_stats[2]: + if first_stats[2] < MERGED_VERTEBRA_SIZE_RATIO * second_stats[2]: # first is significantly smaller than second and they are close in height # how many pixels are touching vert_firsttwo_arr = vert_nii.extract_label(first_key).get_seg_array() vert_firsttwo_arr2 = vert_nii.extract_label(second_key).get_seg_array() vert_firsttwo_arr += vert_firsttwo_arr2 + 1 contacts = np_contacts(vert_firsttwo_arr, connectivity=3) - if contacts[(1, 2)] > 20: + if contacts[(1, 2)] > isotropic_area_to_voxels(MERGED_VERTEBRA_MIN_CONTACT_MM2, vert_nii.zoom): logger.print("Found first two instance weird, will merge", Log_Type.STRANGE) vert_nii.map_labels_({first_key: second_key}, verbose=False) diff --git a/spineps/phase_pre.py b/spineps/phase_pre.py index 51da9cf..641ac8a 100644 --- a/spineps/phase_pre.py +++ b/spineps/phase_pre.py @@ -1,23 +1,62 @@ +"""Pre-processing phase: crop, N4 bias correction, and intensity normalization of the input MRI before segmentation.""" + from __future__ import annotations import inspect # from utils.predictor import nnUNetPredictor from time import perf_counter -from typing import Literal +from typing import TYPE_CHECKING, Literal from TPTBox import NII, Log_Type, to_nii +if TYPE_CHECKING: + from pathlib import Path + from spineps.seg_enums import ErrCode from spineps.seg_pipeline import logger from spineps.utils.proc_functions import n4_bias +from spineps.utils.resolution import REFERENCE_ZOOM + +# Input intensities are rescaled into this range before segmentation. +NORMALIZE_MIN_VALUE = 0 +NORMALIZE_MAX_VALUE = 1500 +# Physical margin kept around the detected spine when cropping a Vibe scan. +VIBE_CROP_MARGIN_MM = 25 * min(REFERENCE_ZOOM) def _has_logger_arg(func) -> bool: + """Check whether a callable accepts a ``logger`` keyword argument. + + Args: + func (Callable): The function whose signature is inspected. + + Returns: + bool: True if ``logger`` is among the function's parameters, else False. + """ return "logger" in inspect.signature(func).parameters -def compute_crop(nii: NII, out_file, dataset_id=100, ddevice: Literal["cpu", "cuda", "mps"] = "cuda", gpu=0, max_folds=None, logger=None): +def compute_crop( + nii: NII, out_file: str | Path, dataset_id=100, ddevice: Literal["cpu", "cuda", "mps"] = "cuda", gpu=0, max_folds=None, logger=None +) -> tuple[slice, slice, slice]: + """Run the Vibe whole-body segmentation and compute a crop region around the spine. + + Segments the input with ``run_vibeseg``, keeps only the spine-relevant labels (IVD, vertebra body, + vertebra posterior elements, and sacrum), and returns a bounding-box crop expanded by ``VIBE_CROP_MARGIN_MM``. + + Args: + nii (NII): Input MRI image to segment and crop. + out_file: Path where the Vibe segmentation output is written. + dataset_id (int, optional): Vibe model/dataset identifier passed to ``run_vibeseg``. Defaults to 100. + ddevice (Literal["cpu", "cuda", "mps"], optional): Compute device for inference. Defaults to "cuda". + gpu (int, optional): GPU index used when running on CUDA. Defaults to 0. + max_folds (int | None, optional): Maximum number of model folds to ensemble. Defaults to None (all folds). + logger (optional): Logger forwarded to ``run_vibeseg`` when that version supports it. Defaults to None. + + Returns: + tuple[slice, slice, slice]: The crop slices around the segmented spine, with a ``VIBE_CROP_MARGIN_MM`` margin. + """ from TPTBox.core.vert_constants import Full_Body_Instance_Vibe from TPTBox.segmentation import run_vibeseg @@ -34,7 +73,7 @@ def compute_crop(nii: NII, out_file, dataset_id=100, ddevice: Literal["cpu", "cu Full_Body_Instance_Vibe.sacrum, ] ) - return seg.compute_crop(0, dist=25) + return seg.compute_crop(0, dist=VIBE_CROP_MARGIN_MM / min(seg.zoom)) def preprocess_input( @@ -46,6 +85,25 @@ def preprocess_input( proc_crop_input: bool = True, verbose: bool = False, ) -> tuple[NII | None, ErrCode]: + """Pre-process an input MRI for segmentation: normalize, crop, N4-correct, and pad. + + Optionally rescales intensities to ``[NORMALIZE_MIN_VALUE, NORMALIZE_MAX_VALUE]``, crops away empty + background to speed up computation, applies N4 bias field correction on the crop, re-normalizes, + writes the processed crop back into the full image, and finally pads the volume by ``pad_size`` on every side. + + Args: + mri_nii (NII): Input grayscale MRI image. + debug_data (dict): Dictionary for collecting intermediate results (unused here, reserved for parity). + pad_size (int, optional): Number of voxels of edge padding added on each side per axis. Defaults to 4. + proc_normalize_input (bool, optional): Whether to rescale intensities into the normalization range. Defaults to True. + proc_do_n4_bias_correction (bool, optional): Whether to apply N4 bias field correction. Defaults to True. + proc_crop_input (bool, optional): Whether to crop away background before processing. Defaults to True. + verbose (bool, optional): Emit additional progress logging. Defaults to False. + + Returns: + tuple[NII | None, ErrCode]: The padded, pre-processed image and ``ErrCode.OK``; or ``(None, ErrCode.EMPTY)`` + if the input image is empty. + """ logger.print("Prepare input image", Log_Type.STAGE) mri_nii = mri_nii.copy() with logger: @@ -53,7 +111,7 @@ def preprocess_input( try: # Enforce to range [0, 1500] if proc_normalize_input: - mri_nii.normalize_to_range_(min_value=0, max_value=1500, verbose=False) + mri_nii.normalize_to_range_(min_value=NORMALIZE_MIN_VALUE, max_value=NORMALIZE_MAX_VALUE, verbose=False) crop = mri_nii.compute_crop(dist=0) if proc_crop_input else (slice(None, None), slice(None, None), slice(None, None)) else: crop = ( @@ -74,9 +132,9 @@ def preprocess_input( cropped_nii, _ = n4_bias(cropped_nii) # PIR logger.print(f"N4 Bias field correction done in {perf_counter() - n4_start} sec", verbose=True) - # Enforce to range [0, 1500] + # Enforce to range [NORMALIZE_MIN_VALUE, NORMALIZE_MAX_VALUE] if proc_normalize_input: - cropped_nii.normalize_to_range_(min_value=0, max_value=1500, verbose=logger) + cropped_nii.normalize_to_range_(min_value=NORMALIZE_MIN_VALUE, max_value=NORMALIZE_MAX_VALUE, verbose=logger) # Uncrop again # uncropped_input[crop] = cropped_nii.get_array() diff --git a/spineps/phase_semantic.py b/spineps/phase_semantic.py index c034869..1e3605b 100755 --- a/spineps/phase_semantic.py +++ b/spineps/phase_semantic.py @@ -1,3 +1,5 @@ +"""Semantic phase: predict and post-process the subregion (semantic) segmentation mask of the spine.""" + from __future__ import annotations # from utils.predictor import nnUNetPredictor @@ -8,6 +10,18 @@ from spineps.seg_model import Segmentation_Model from spineps.seg_pipeline import fill_holes_labels, logger from spineps.utils.proc_functions import clean_cc_artifacts +from spineps.utils.resolution import REFERENCE_VOXEL_VOLUME_MM3, REFERENCE_ZOOM, mm3_to_voxels, mm_to_voxels + +# Connected-component artifacts smaller than this physical volume are removed from the semantic mask. +SMALL_CC_SIZE_THRESHOLD_MM3 = 30 * REFERENCE_VOXEL_VOLUME_MM3 +# Vertical (inferior) extent in millimeters kept around the spinal canal. +CANAL_HEIGHT_MARGIN_MM = 64 +# Semantic label of S1, i.e. the sacrum. +SACRUM_LABEL = 26 +# More connected components than this in the semantic mask is unexpected and gets logged. +MAX_EXPECTED_SEMANTIC_CC = 3 +# Physical margin used when cropping around connected components in the bounding-box clean step. +CC_BBOX_MARGIN_MM = 4 * min(REFERENCE_ZOOM) def predict_semantic_mask( @@ -20,19 +34,27 @@ def predict_semantic_mask( proc_clean_small_cc_artifacts: bool = True, verbose: bool = False, ) -> tuple[NII | None, NII | None, ErrCode]: - """Predicts the semantic mask, takes care of rescaling, and back + """Predict the semantic (subregion) segmentation mask and run post-processing on it. + + Runs the model on the input MRI (resampling to the model's recommended zoom), then optionally removes + structures beyond the spinal-canal height, cleans small connected-component artifacts, restricts the mask + to the largest bounding box of connected components, and fills 3D holes. Args: - mri_nii (NII): input mri image (grayscal, must be in range 0 -> ?) - model (Segmentation_Model): Model to semantically segment with - do_n4 (bool, optional): Wherever to apply n4 bias field correction. Defaults to True. - fill_holes (bool, optional): Whether to fill holes in the output mask. Defaults to True. - clean_artifacts (bool, optional): Whether to try and clean possible artifacts. Defaults to True. - do_crop (bool, optional): Whether to apply cropping in order to speedup computation (min value in scan must be 0!). Defaults to True. - verbose (bool, optional): If you want some more infos on whats happening. Defaults to False. + mri_nii (NII): Input grayscale MRI image (intensities must start at 0). + model (Segmentation_Model): Model used to produce the semantic segmentation. + debug_data (dict): Dictionary for collecting intermediate results (e.g. the raw segmentation). + proc_fill_3d_holes (bool, optional): Whether to fill 3D holes in the output mask. Defaults to True. + proc_clean_beyond_largest_bounding_box (bool, optional): Whether to keep only connected components within + the largest bounding box. Defaults to True. + proc_remove_inferior_beyond_canal (bool, optional): Whether to remove non-sacrum structures below the + spinal-canal height. Defaults to False. + proc_clean_small_cc_artifacts (bool, optional): Whether to delete small connected-component artifacts. Defaults to True. + verbose (bool, optional): Emit additional progress logging. Defaults to False. Returns: - tuple[NII | None, NII | None, NII | None, np.ndarray, ErrCode]: seg_nii, seg_nii_modelres, unc_nii, softmax_logits, ErrCode + tuple[NII | None, NII | None, ErrCode]: The post-processed semantic mask, the softmax logits, and an + error code (``ErrCode.OK`` on success, ``ErrCode.EMPTY`` if the predicted mask is empty). """ logger.print("Predict Semantic Mask", Log_Type.STAGE) with logger: @@ -79,7 +101,7 @@ def predict_semantic_mask( ], only_delete=True, ignore_missing_labels=True, - cc_size_threshold=30, # [ + cc_size_threshold=mm3_to_voxels(SMALL_CC_SIZE_THRESHOLD_MM3, seg_nii.zoom), ), verbose=verbose, ) @@ -104,33 +126,65 @@ def predict_semantic_mask( return seg_nii, softmax_logits, ErrCode.OK -def remove_nonsacrum_beyond_canal_height(seg_nii: NII): +def remove_nonsacrum_beyond_canal_height(seg_nii: NII) -> NII: + """Remove non-sacrum labels that lie above or below the spinal-canal extent. + + Computes the inferior-axis (I) extent of the spinal canal/cord, expanded by ``CANAL_HEIGHT_MARGIN_MM``, + and zeroes out everything outside that range. The sacrum (``SACRUM_LABEL``) is kept regardless of position. + If no canal/cord is present, the mask is returned unchanged. + + Args: + seg_nii (NII): Semantic segmentation mask in ("P", "I", "R") orientation. + + Returns: + NII: The segmentation mask with structures beyond the canal height removed (sacrum preserved). + + Raises: + AssertionError: If ``seg_nii`` is not in ("P", "I", "R") orientation. + """ seg_nii.assert_affine(orientation=("P", "I", "R")) canal_nii = seg_nii.extract_label([Location.Spinal_Canal.value, Location.Spinal_Cord.value]) if canal_nii.sum() == 0: return seg_nii - crop_i = canal_nii.compute_crop(dist=64 / seg_nii.zoom[1])[1] + crop_i = canal_nii.compute_crop(dist=CANAL_HEIGHT_MARGIN_MM / seg_nii.zoom[1])[1] seg_arr = seg_nii.get_seg_array() - sacrum_arr = seg_nii.extract_label(26).get_seg_array() + sacrum_arr = seg_nii.extract_label(SACRUM_LABEL).get_seg_array() seg_arr[:, 0 : crop_i.start, :] = 0 seg_arr[:, crop_i.stop :, :] = 0 - seg_arr[sacrum_arr == 1] = 26 + seg_arr[sacrum_arr == 1] = SACRUM_LABEL return seg_nii.set_array_(seg_arr) -def semantic_bounding_box_clean(seg_nii: NII): +def semantic_bounding_box_clean(seg_nii: NII) -> NII: + """Keep only connected components that fall within the spine's growing bounding box. + + Binarizes the mask and labels its connected components. Starting from the largest component's bounding box + (expanded by ``CC_BBOX_MARGIN_MM``), it iteratively merges in any other component whose bounding box overlaps + the current region in all three axes (with extra inferior margin to tolerate gaps in the spine). Voxels + outside the resulting region, and any non-incorporated components, are removed. Components are dropped if the + binary mask has more than ``MAX_EXPECTED_SEMANTIC_CC`` parts (logged as strange). + + Args: + seg_nii (NII): Semantic segmentation mask to clean. + + Returns: + NII: The cleaned segmentation mask, restored to its original orientation. + """ ori = seg_nii.orientation seg_binary = seg_nii.reorient_().extract_label(list(seg_nii.unique())) # whole thing binary + # Resolution-aware bounding-box margin (mm -> voxels at the current zoom). + bbox_margin_dist = CC_BBOX_MARGIN_MM / min(seg_nii.zoom) + bbox_margin_vox = mm_to_voxels(CC_BBOX_MARGIN_MM, seg_nii.zoom) seg_bin_largest_k_cc_nii: NII = seg_binary.filter_connected_components( max_count_component=None, labels=1, connectivity=3, keep_label=False ) max_k = int(seg_bin_largest_k_cc_nii.max()) - if max_k > 3: + if max_k > MAX_EXPECTED_SEMANTIC_CC: logger.print(f"Found {max_k} unique connected components in semantic mask", Log_Type.STRANGE) # PIR largest_nii = seg_bin_largest_k_cc_nii.extract_label(1) # width fixed, and heigh include all connected components within bounding box, then repeat - p_slice, i_slice, r_slice = largest_nii.compute_crop(dist=4) + p_slice, i_slice, r_slice = largest_nii.compute_crop(dist=bbox_margin_dist) bboxes = [(p_slice, i_slice, r_slice)] # PIR -> fixed, extendable, extendable @@ -140,12 +194,12 @@ def semantic_bounding_box_clean(seg_nii: NII): changed = False for k in [l for l in range(2, max_k + 1) if l not in incorporated]: k_nii = seg_bin_largest_k_cc_nii.extract_label(k) - p, i, r = k_nii.compute_crop(dist=4) + p, i, r = k_nii.compute_crop(dist=bbox_margin_dist) for bbox in bboxes: i_slice_compare = slice( - max(bbox[1].start - 4, 0), bbox[1].stop + 4 - ) # more margin in inferior direction (allows for gaps of size 15 in spine) + max(bbox[1].start - bbox_margin_vox, 0), bbox[1].stop + bbox_margin_vox + ) # more margin in inferior direction (allows for gaps in spine) if overlap_slice(bbox[0], p) and overlap_slice(i_slice_compare, i) and overlap_slice(bbox[2], r): # extend bbox bboxes.append((p, i, r)) @@ -175,12 +229,15 @@ def semantic_bounding_box_clean(seg_nii: NII): return seg_nii -def overlap_slice(slice1: slice, slice2: slice): - """checks if two ranges defined by slices overlapping (including border!) +def overlap_slice(slice1: slice, slice2: slice) -> bool: + """Check whether two ranges defined by slices overlap (borders inclusive). Args: - slice1 (slice): _description_ - slice2 (slice): _description_ + slice1 (slice): First range, using its ``start`` and ``stop`` bounds. + slice2 (slice): Second range, using its ``start`` and ``stop`` bounds. + + Returns: + bool: True if the two ranges overlap or touch at a border, else False. """ slice1s = slice1.start slice1e = slice1.stop diff --git a/spineps/seg_enums.py b/spineps/seg_enums.py index e07d38f..ad4c0da 100755 --- a/spineps/seg_enums.py +++ b/spineps/seg_enums.py @@ -1,3 +1,5 @@ +"""Enumerations describing image modalities, acquisitions, model types and pipeline phases used across SPINEPS.""" + from __future__ import annotations from enum import Enum, EnumMeta, auto @@ -6,16 +8,40 @@ class MetaEnum(EnumMeta): - def __contains__(cls, item): + """Enum metaclass enabling ``item in EnumClass`` membership tests by member name.""" + + def __contains__(cls, item: object) -> bool: + """Return whether ``item`` names a member of the enum. + + Args: + item: Candidate member name to test for membership. + + Returns: + bool: True if ``item`` is a valid member name of the enum, False otherwise. + """ try: cls[item] - except ValueError: + except (KeyError, ValueError): return False return True class Enum_Compare(Enum, metaclass=MetaEnum): + """Base enum that compares equal to other enums by name/value and to plain strings by name. + + Provides string-friendly equality, hashing and representation so members can be compared + against and interchanged with their string names throughout the pipeline. + """ + def __eq__(self, __value: object) -> bool: # noqa: PYI063 + """Compare this member against another enum or a string. + + Args: + __value (object): Another enum member or a string holding a member name. + + Returns: + bool: True if the other enum matches by name and value, or the string matches this member's name. + """ if isinstance(__value, Enum): return self.name == __value.name and self.value == __value.value elif isinstance(__value, str): @@ -24,26 +50,36 @@ def __eq__(self, __value: object) -> bool: # noqa: PYI063 return False def __str__(self) -> str: + """Return the member as ``ClassName.MEMBER``. + + Returns: + str: Human-readable identifier of the member. + """ return f"{type(self).__name__}.{self.name}" def __repr__(self) -> str: + """Return the same string as :meth:`__str__`. + + Returns: + str: Human-readable identifier of the member. + """ return str(self) def __hash__(self) -> int: + """Return the member's integer value as its hash. + + Returns: + int: The member's value, used for hashing. + """ return self.value class Modality(Enum_Compare): - """Describes image modality + """Image modality of an input scan. - Args: - Enum_Compare (_type_): _description_ - - Raises: - NotImplementedError: _description_ - - Returns: - _type_: _description_ + Members cover the MRI sequences and other image types SPINEPS can handle, e.g. T2-weighted (T2w), + T1-weighted (T1w), Vibe/Dixon, CT, an existing segmentation (SEG), multi-planar reconstruction (MPR), + proton density (PD) and FLAIR. """ T2w = auto() @@ -53,9 +89,21 @@ class Modality(Enum_Compare): SEG = auto() MPR = auto() PD = auto() + FLAIR = auto() @classmethod def format_keys(cls, modalities: Self | list[Self]) -> list[str]: + """Map modality members to the BIDS/file-name string keys that denote them. + + Args: + modalities (Self | list[Self]): A single modality member or a list of them. + + Returns: + list[str]: All file-name/format keys associated with the given modalities. + + Raises: + NotImplementedError: If a modality has no associated keys defined. + """ if not isinstance(modalities, list): modalities = [modalities] result = [] @@ -78,16 +126,10 @@ def format_keys(cls, modalities: Self | list[Self]) -> list[str]: class Acquisition(Enum_Compare): - """Describes Acquisition (sag = sagittal, cor = coronal, ax = axial) - - Args: - Enum_Compare (_type_): _description_ + """Acquisition plane of a scan. - Raises: - NotImplementedError: _description_ - - Returns: - _type_: _description_ + Members denote the imaging plane: ``sag`` (sagittal), ``cor`` (coronal), ``ax`` (axial) + and ``iso`` (isotropic / no dominant plane). """ sag = auto() @@ -97,6 +139,17 @@ class Acquisition(Enum_Compare): @classmethod def format_keys(cls, acquisition: Self) -> list[str]: + """Map an acquisition member to the file-name string keys that denote it. + + Args: + acquisition (Self): The acquisition plane member. + + Returns: + list[str]: All file-name/format keys associated with the given acquisition. + + Raises: + NotImplementedError: If the acquisition has no associated keys defined. + """ if acquisition == Acquisition.ax: return ["axial", "ax", "axl"] elif acquisition == Acquisition.cor: @@ -110,18 +163,28 @@ def format_keys(cls, acquisition: Self) -> list[str]: class SpinepsPhase(Enum_Compare): + """Stage of the SPINEPS pipeline: semantic segmentation, vertebra instance segmentation or labeling.""" + SEMANTIC = auto() INSTANCE = auto() LABELING = auto() class ModelType(Enum_Compare): + """Kind of model backing an inference config: an nnU-Net, a plain U-Net or a classifier.""" + nnunet = auto() unet = auto() classifier = auto() class InputType(Enum_Compare): + """Type of input channel fed to a model. + + ``img`` is the default image input and ``seg`` a segmentation input. The remaining members + are the Dixon/Vibe channels: in-phase (``ip``), out-of-phase (``oop``), ``water`` and ``fat``. + """ + img = auto() # default: image input seg = auto() # segmentation input # For Vibe @@ -132,6 +195,8 @@ class InputType(Enum_Compare): class OutputType(Enum_Compare): + """Type of model output: a segmentation (``seg``), softmax logits or an uncertainty map (``unc``).""" + seg = auto() # seg_modelres = auto() softmax_logits = auto() @@ -139,6 +204,13 @@ class OutputType(Enum_Compare): class ErrCode(Enum_Compare): + """Status/error codes returned by pipeline steps. + + Indicates success (``OK``), that outputs already exist (``ALL_DONE``), a model/input compatibility + problem (``COMPATIBILITY``), an unknown failure (``UNKNOWN``), an empty mask or input (``EMPTY``) + or mismatched shapes (``SHAPE``). + """ + OK = auto() ALL_DONE = auto() # outputs are already there COMPATIBILITY = auto() # compatibility issue between model and input diff --git a/spineps/seg_model.py b/spineps/seg_model.py index ef99022..6159bdc 100755 --- a/spineps/seg_model.py +++ b/spineps/seg_model.py @@ -1,3 +1,5 @@ +"""Segmentation model abstractions: the abstract Segmentation_Model and its nnU-Net and Unet3D subclasses.""" + from __future__ import annotations import os @@ -21,12 +23,28 @@ threads_started = False +# Two zoom vectors are considered identical if every axis differs by less than this (mm). +ZOOM_MATCH_TOLERANCE = 1e-4 +# Legacy single-channel Unet3D divided the input label ids by this value to scale them to ~[0, 1]. +LEGACY_LABEL_NORMALIZATION = 9 -class Segmentation_Model(ABC): - """Abstract Segmentation Model class - Args: - ABC (_type_): _description_ +class Segmentation_Model(ABC): + """Abstract base class wrapping a segmentation network together with its inference configuration. + + Subclasses implement load() and run() for a concrete backend (e.g. nnU-Net or Unet3D). The class handles input + preparation (reorientation, rescaling to the recommended zoom, padding), running the model and mapping the output back + into the input space. + + Attributes: + name (str): Optional human-readable model name. + logger (No_Logger): Logger used for all model output. + use_cpu (bool): If true, runs inference on CPU instead of GPU. + inference_config (Segmentation_Inference_Config): Configuration describing expected inputs, resolution range and labels. + default_verbose (bool): Default verbosity for printing. + default_allow_tqdm (bool): Whether a progress bar is shown during segmentation by default. + model_folder (str): Path to the model's folder on disk. + predictor: The loaded backend predictor, or None until load() is called. """ def __init__( @@ -37,13 +55,18 @@ def __init__( default_verbose: bool = False, default_allow_tqdm: bool = True, ): - """Initializes the segmentation model, finding and loading the corresponding inference config for that model + """Initializes the segmentation model, finding and loading the corresponding inference config for that model. Args: - model_folder (str | Path): Path to that model's folder - inference_config (Segmentation_Inference_Config | None, optional): Path to the inference config (if different from model folder). Defaults to None. - default_verbose (bool): If true, will spam a lot more when using. Defaults to True. - default_allow_tqdm (bool, optional): If true, will showcase a progress bar while segmenting. Defaults to True. + model_folder (str | Path): Path to that model's folder. + inference_config (Segmentation_Inference_Config | None, optional): Inference config to use; if None, loads + "inference_config.json" from the model folder. Defaults to None. + use_cpu (bool, optional): If true, runs inference on CPU instead of GPU. Defaults to False. + default_verbose (bool, optional): If true, prints more information when used. Defaults to False. + default_allow_tqdm (bool, optional): If true, shows a progress bar while segmenting. Defaults to True. + + Raises: + AssertionError: If model_folder does not exist. """ self.name: str = "" assert Path(model_folder).exists(), f"model_folder does not exist, got {model_folder}" @@ -68,21 +91,28 @@ def __init__( @abstractmethod def load(self, folds: tuple[str, ...] | None = None) -> Self: - """Loads the weights from disk + """Loads the model weights from disk. + + Args: + folds (tuple[str, ...] | None, optional): Which folds to load; if None, uses the folds from the inference config. + Defaults to None. Returns: - Self: Segmentation_Model, but with loaded weights + Self: This model with its predictor loaded. """ return self def calc_recommended_resampling_zoom(self, input_zoom: ZOOMS) -> ZOOMS: - """Calculates the resolution a corresponding input should be resampled to for this model + """Calculates the resolution a corresponding input should be resampled to for this model. + + If the inference config defines a (min, max) resolution range, each axis of the input zoom is clamped into that + range; otherwise the fixed configured resolution is returned. Args: - input_zoom (ZOOMS): _description_ + input_zoom (ZOOMS): Voxel spacing (mm) of the input image, per axis. Returns: - ZOOMS: _description_ + ZOOMS: Recommended voxel spacing (mm) to resample the input to before inference. """ if len(self.inference_config.resolution_range) != 2: return self.inference_config.resolution_range @@ -96,9 +126,18 @@ def calc_recommended_resampling_zoom(self, input_zoom: ZOOMS) -> ZOOMS: return output_zoom def same_modelzoom_as_model(self, model: Self, input_zoom: ZOOMS) -> bool: + """Checks whether another model would resample a given input to the same resolution as this model. + + Args: + model (Self): The other segmentation model to compare against. + input_zoom (ZOOMS): Voxel spacing (mm) of the input image, per axis. + + Returns: + bool: True if both models' recommended resampling zooms agree on every axis within ZOOM_MATCH_TOLERANCE. + """ self_zms = self.calc_recommended_resampling_zoom(input_zoom=input_zoom) model_zms = model.calc_recommended_resampling_zoom(input_zoom=self_zms) - match: bool = bool(np.all([self_zms[i] - model_zms[i] < 1e-4 for i in range(3)])) + match: bool = bool(np.all([abs(self_zms[i] - model_zms[i]) < ZOOM_MATCH_TOLERANCE for i in range(3)])) return match @citation_reminder @@ -111,17 +150,25 @@ def segment_scan( resample_output_to_input_space: bool = True, verbose: bool = False, ) -> dict[OutputType, NII | None]: - """Segments a given input with this model + """Segments a given input with this model. + + Prepares each expected input (optional padding, reorientation to the model orientation and rescaling to the + recommended zoom), runs the model and maps the outputs back into the input space. Args: - input (Image_Reference | dict[InputType, Image_Reference]): input - pad_size (int, optional): Padding in each dimension (times two more pixels in each dim). Defaults to 4. - step_size (float | None, optional): _description_. Defaults to None. - resample_to_recommended (bool, optional): _description_. Defaults to True. - verbose (bool, optional): _description_. Defaults to False. + input_image (Image_Reference | dict[InputType, Image_Reference]): A single image, or a mapping from InputType to + image for multi-input models. + pad_size (int, optional): Padding added in each dimension (this many extra voxels on each side per axis), removed + again from the output. Defaults to 0. + step_size (float | None, optional): Sliding-window tile step size; if None, uses the config default. Defaults to None. + resample_to_recommended (bool, optional): If true, rescales each input to the model's recommended zoom. Defaults to True. + resample_output_to_input_space (bool, optional): If true, resamples and pads the outputs back to the original input + space. Defaults to True. + verbose (bool, optional): If true, prints verbose information. Defaults to False. Returns: - dict[OutputType, NII]: _description_ + dict[OutputType, NII | None]: Mapping of output type to result NII (e.g. the segmentation mask, optionally softmax + logits). """ if self.predictor is None: self.load() @@ -183,7 +230,8 @@ def segment_scan( for k, v in result.items(): if isinstance(v, NII): # and k != OutputType.seg_modelres: if resample_output_to_input_space: - v.rescale_(zms_pir, verbose=self.logger).reorient_(orientation, verbose=self.logger) + v.resample_from_to_(inputdict[self.inference_config.expected_inputs[0]]) + # v.rescale_(zms_pir, verbose=self.logger).reorient_(orientation, verbose=self.logger) v.pad_to(orig_shape, inplace=True) if k == OutputType.seg: v.map_labels_(self.inference_config.segmentation_labels, verbose=self.logger) @@ -197,36 +245,58 @@ def segment_scan( return result def modalities(self) -> list[Modality]: - """Returns the modalities this model supports + """Returns the modalities this model supports. Returns: - list[Modality]: _description_ + list[Modality]: Modalities the model was trained for, as listed in its inference config. """ return self.inference_config.modalities def acquisition(self) -> Acquisition: - """Returns the acquisition this model supports + """Returns the acquisition this model supports. Returns: - Acquisition: _description_ + Acquisition: Acquisition plane/type the model expects, as listed in its inference config. """ return self.inference_config.acquisition @abstractmethod def run(self, input_nii: list[NII], verbose: bool = False) -> dict[OutputType, NII | None]: - pass + """Runs the backend predictor on the prepared inputs. + + Args: + input_nii (list[NII]): Inputs already reoriented and rescaled to the model's expectation, in the configured order. + verbose (bool, optional): If true, prints verbose information. Defaults to False. + + Returns: + dict[OutputType, NII | None]: Mapping of output type to result NII produced by the model. + """ - def print(self, *text, verbose: bool | None = None): + def print(self, *text: object, verbose: bool | None = None): + """Logs text via the model's logger. + + Args: + *text: Items to print. + verbose (bool | None, optional): Overrides the default verbosity; if None, uses default_verbose. Defaults to None. + """ if verbose is None: verbose = self.default_verbose self.logger.print(*text, verbose=verbose) def print_self(self): - """Prints own model id""" + """Prints the model id and its inference config.""" self.print(self.modelid(include_log_name=False), verbose=True) self.print("Config:", self.inference_config, verbose=True) - def modelid(self, include_log_name: bool = False): + def modelid(self, include_log_name: bool = False) -> str: + """Returns an identifier string for this model. + + Args: + include_log_name (bool, optional): If true and a name is set, appends the config log name. Defaults to False. + + Returns: + str: The model name, or the inference config's log name if no name is set. + """ name: str = str(self.name) if name != "": if include_log_name: @@ -234,7 +304,12 @@ def modelid(self, include_log_name: bool = False): return name return self.inference_config.log_name - def dict_representation(self): + def dict_representation(self) -> dict[str, str]: + """Builds a summary dictionary describing this model. + + Returns: + dict[str, str]: Model id, model path, modalities, acquisition and resolution range as strings. + """ info = { "name": self.modelid(), # self.inference_config.__repr__() "model_path": str(self.model_folder), @@ -247,14 +322,26 @@ def dict_representation(self): # info["resolution_processed"] = str(proc_zms) return info - def __str__(self): + def __str__(self) -> str: + """Returns the model id together with its inference config representation. + + Returns: + str: Human-readable description of the model. + """ return self.modelid(include_log_name=True) + "\nConfig: " + self.inference_config.__repr__() def __repr__(self) -> str: + """Returns the same representation as __str__. + + Returns: + str: Human-readable description of the model. + """ return str(self) class Segmentation_Model_NNunet(Segmentation_Model): + """Segmentation_Model backed by an nnU-Net predictor.""" + def __init__( self, model_folder: str | Path, @@ -263,19 +350,38 @@ def __init__( default_verbose: bool = False, default_allow_tqdm: bool = True, ): + """Initializes an nnU-Net-backed segmentation model. + + Args: + model_folder (str | Path): Path to the nnU-Net model folder. + inference_config (Segmentation_Inference_Config | None, optional): Inference config; if None, loads it from the + model folder. Defaults to None. + use_cpu (bool, optional): If true, runs inference on CPU instead of GPU. Defaults to False. + default_verbose (bool, optional): If true, prints more information when used. Defaults to False. + default_allow_tqdm (bool, optional): If true, shows a progress bar while segmenting. Defaults to True. + """ super().__init__(model_folder, inference_config, use_cpu, default_verbose, default_allow_tqdm) def load(self, folds: tuple[str, ...] | None = None) -> Self: + """Loads the nnU-Net predictor and its ensemble folds from the model folder. + + Args: + folds (tuple[str, ...] | None, optional): Folds to load; if None, uses the folds from the inference config. + Defaults to None. + + Returns: + Self: This model with its nnU-Net predictor loaded. + """ global threads_started # noqa: PLW0603 if not os.path.exists(self.model_folder): # noqa: PTH110 self.print(f"Model weights not found in {self.model_folder}", Log_Type.FAIL) conf_folds = self.inference_config.available_folds if isinstance(conf_folds, int): - conf_folds = tuple([str(i) for i in range(conf_folds)]) + conf_folds = tuple(str(i) for i in range(conf_folds)) elif isinstance(conf_folds, str): conf_folds = (conf_folds,) else: - conf_folds = tuple([str(i) for i in conf_folds]) + conf_folds = tuple(str(i) for i in conf_folds) self.predictor = load_inf_model( model_folder=self.model_folder, step_size=self.inference_config.default_step_size, @@ -297,6 +403,16 @@ def run( input_nii: list[NII], verbose: bool = False, ) -> dict[OutputType, NII | None]: + """Runs nnU-Net inference on the prepared inputs. + + Args: + input_nii (list[NII]): Inputs in the model's expected orientation and resolution, in the configured order. + verbose (bool, optional): If true, prints verbose information. Defaults to False. + + Returns: + dict[OutputType, NII | None]: The segmentation mask under OutputType.seg and the softmax logits under + OutputType.softmax_logits. + """ self.print("Segmenting...") seg_nii, softmax_logits = run_inference(input_nii, self.predictor) self.print("Segmentation done!") @@ -305,6 +421,12 @@ def run( class Segmentation_Model_Unet3D(Segmentation_Model): + """Segmentation_Model backed by a single-input 3D U-Net (PyTorch Lightning PLNet). + + Used as the instance (vertebra) model: it takes a segmentation mask as input and refines it into the vertebra instance + output. Supports both the current multi-channel network and a legacy single-channel network. + """ + def __init__( self, model_folder: str | Path, @@ -313,10 +435,34 @@ def __init__( default_verbose: bool = False, default_allow_tqdm: bool = True, ): + """Initializes a 3D U-Net-backed segmentation model. + + Args: + model_folder (str | Path): Path to the model folder containing the checkpoint. + inference_config (Segmentation_Inference_Config | None, optional): Inference config; if None, loads it from the + model folder. Defaults to None. + use_cpu (bool, optional): If true, runs inference on CPU instead of GPU. Defaults to False. + default_verbose (bool, optional): If true, prints more information when used. Defaults to False. + default_allow_tqdm (bool, optional): If true, shows a progress bar while segmenting. Defaults to True. + + Raises: + AssertionError: If the inference config expects more than one input. + """ super().__init__(model_folder, inference_config, use_cpu, default_verbose, default_allow_tqdm) assert len(self.inference_config.expected_inputs) == 1, "Unet3D cannot expect more than one input" def load(self, folds: tuple[str, ...] | None = None) -> Self: # noqa: ARG002 + """Loads the 3D U-Net checkpoint, trying the current then the legacy PLNet implementation. + + Args: + folds (tuple[str, ...] | None, optional): Unused; present for interface compatibility. Defaults to None. + + Returns: + Self: This model with its 3D U-Net predictor loaded and moved to the selected device. + + Raises: + AssertionError: If exactly one checkpoint file is not found in the model folder. + """ assert os.path.exists(self.model_folder) # noqa: PTH110 chktpath = search_path(self.model_folder, "**/*weights*.ckpt") @@ -334,6 +480,21 @@ def load(self, folds: tuple[str, ...] | None = None) -> Self: # noqa: ARG002 return self def run(self, input_nii: list[NII], verbose: bool = False) -> dict[OutputType, NII | None]: + """Runs the 3D U-Net on a single input segmentation mask. + + Converts the input mask to a network tensor (one-hot encoded for the multi-channel network, or intensity-normalized + for the legacy single-channel network), runs the forward pass and returns the per-voxel argmax class as a mask. + + Args: + input_nii (list[NII]): A single-element list containing the input segmentation mask. + verbose (bool, optional): If true, prints verbose information. Defaults to False. + + Returns: + dict[OutputType, NII | None]: The predicted segmentation mask under OutputType.seg. + + Raises: + AssertionError: If more than one input is provided. + """ assert len(input_nii) == 1, "Unet3D does not support more than one input" input_nii_ = input_nii[0] @@ -355,7 +516,7 @@ def run(self, input_nii: list[NII], verbose: bool = False) -> dict[OutputType, N else: # legacy version target = target.to(torch.float32) - target /= 9 + target /= LEGACY_LABEL_NORMALIZATION target = target.unsqueeze(0) target = target.unsqueeze(0) logits = self.predictor.forward(target.to(self.device)) diff --git a/spineps/seg_pipeline.py b/spineps/seg_pipeline.py index e4f2a6a..4a1f3b4 100755 --- a/spineps/seg_pipeline.py +++ b/spineps/seg_pipeline.py @@ -1,3 +1,5 @@ +"""Segmentation-pipeline helpers: shared logger, subregion label sets, centroid computation, and pipeline version reporting.""" + from __future__ import annotations # from utils.predictor import nnUNetPredictor @@ -13,6 +15,16 @@ logger = No_Logger(prefix="SPINEPS") +# IVD and endplate instances are stored as (vertebra label + offset). These offsets are the canonical +# home for the convention; other modules import them from here. +IVD_LABEL_OFFSET = 100 +ENDPLATE_LABEL_OFFSET = 200 +# Number of derived label ids reserved per type; the ranges below cover all IVD/endplate labels and are +# stripped before centroid computation. +_MAX_DERIVED_LABELS_PER_TYPE = 34 +IVD_LABEL_RANGE = range(IVD_LABEL_OFFSET, IVD_LABEL_OFFSET + _MAX_DERIVED_LABELS_PER_TYPE) +ENDPLATE_LABEL_RANGE = range(ENDPLATE_LABEL_OFFSET, ENDPLATE_LABEL_OFFSET + _MAX_DERIVED_LABELS_PER_TYPE) + fill_holes_labels = [ Location.Vertebra_Corpus_border.value, Location.Spinal_Canal.value, @@ -41,28 +53,30 @@ def predict_centroids_from_both( seg_nii: NII, models: list[Segmentation_Model | None], parameter: dict[str, Any], -): - """Calculates the centroids of each vertebra corpus by using both semantic and instance mask +) -> poi.POI: + """Calculate the centroids of each vertebra corpus using both the semantic and instance masks. + + Strips the IVD and endplate derived instance labels from the instance mask, computes the per-vertebra centroids from the + instance and semantic masks, adds an S1 corpus centroid when sacrum is present, and records pipeline metadata (model + descriptions, version, revision, timestamp, and the given parameters) on the result. Args: - vert_nii_cleaned (NII): _description_ - seg_nii (NII): _description_ - models (list[Segmentation_Model]): _description_ - input_zms_pir (ZOOMS | None, optional): _description_. Defaults to None. + vert_nii_cleaned (NII): Cleaned vertebra instance segmentation mask. + seg_nii (NII): Subregion semantic segmentation mask. + models (list[Segmentation_Model | None]): Models used in the pipeline, recorded in the centroid metadata. + parameter (dict[str, Any]): Pipeline parameters to record on the centroid metadata. Returns: - _type_: _description_ + POI: The computed point-of-interest / centroid object with pipeline metadata attached. """ vert_nii_4_centroids = vert_nii_cleaned.copy() - labelmap = dict.fromkeys(range(100, 134), 0) - for i in range(200, 234): - labelmap[i] = 0 + labelmap = dict.fromkeys([*IVD_LABEL_RANGE, *ENDPLATE_LABEL_RANGE], 0) vert_nii_4_centroids.map_labels_(labelmap, verbose=False) ctd = poi.calc_poi_from_subreg_vert(vert_nii_4_centroids, seg_nii, verbose=logger) if v_name2idx["S1"] in vert_nii_cleaned.unique(): - s1_nii = vert_nii_cleaned.extract_label(26, inplace=False) + s1_nii = vert_nii_cleaned.extract_label(v_name2idx["S1"], inplace=False) ctd[v_name2idx["S1"], 50] = center_of_mass(s1_nii.get_seg_array()) models_repr = {} @@ -81,7 +95,12 @@ def predict_centroids_from_both( return ctd -def pipeline_version(): +def pipeline_version() -> str: + """Return the pipeline version string derived from the git commit count on ``main``. + + Returns: + str: A version like ``"v1."``, or ``"Version not found"`` if git is unavailable. + """ try: label = subprocess.check_output(["git", "rev-list", "--count", "main"]).strip() label = str(label).replace("'", "") @@ -92,7 +111,12 @@ def pipeline_version(): return "v1." + str(label) -def pipeline_revision(): +def pipeline_revision() -> str: + """Return the current git revision string for the pipeline. + + Returns: + str: ``"::"``; either part is empty if the corresponding git call fails. + """ label = "" rev = "" try: diff --git a/spineps/seg_run.py b/spineps/seg_run.py index c41b435..01597f9 100755 --- a/spineps/seg_run.py +++ b/spineps/seg_run.py @@ -1,3 +1,5 @@ +"""Top-level SPINEPS pipeline orchestration for running spine segmentation over datasets and single niftys.""" + from __future__ import annotations import math @@ -70,42 +72,61 @@ def process_dataset( log_inference_time: bool = True, verbose: bool = False, ): - """Runs the SPINEPS framework over a whole BIDS-conform dataset + """Runs the SPINEPS framework over a whole BIDS-conform dataset. + + Iterates over every subject in the BIDS dataset, queries the matching scans for each requested modality pair and runs + process_img_nii on each, producing semantic (subregion), vertebra (instance) and centroid outputs plus a snapshot. Args: - dataset_path (Path): Path to the dataset - model_instance (Segmentation_Model): Model for the vertebra segmentation - model_semantic (list[Segmentation_Model] | Segmentation_Model | None, optional): Models for the subregion segmentation. If none, will attempt to find the correct one. Defaults to None. + dataset_path (Path): Path to the BIDS dataset. + model_instance (Segmentation_Model): Model for the vertebra (instance) segmentation. + model_semantic (list[Segmentation_Model] | Segmentation_Model | None, optional): Models for the subregion (semantic) + segmentation, one per modality pair. If None, attempts to find a matching model for each modality. Defaults to None. + model_labeling (VertLabelingClassifier | None, optional): Classifier used to label the vertebra instances. Defaults to None. rawdata_name (str, optional): Name of the rawdata folder. Defaults to "rawdata". - derivative_name (str, optional): Name of the derivatives folder. Defaults to "derivatives_seg". - modalities (list[Modality_Pair] | Modality_Pair, optional): List of modalities you want to segment in the dataset. Defaults to [(Modality.T2w, Acquisition.sag)]. - - save_debug_data (bool, optional): If true, saves debug data. Increases space usage! Defaults to False. - #save_uncertainty_image (bool, optional): If true, saves a uncertainty image for the semantic segmentation. Defaults to False. - save_modelres_mask (bool, optional): If true, will additionally save the semantic mask in the resolution of the model. Defaults to False. - save_softmax_logits (bool, optional): If true, additionally saves the softmax logits (averaged over folds) as an npz. Defaults to False. - save_log_data (bool, optional): If true, will save the log to a file. Defaults to True. - - override_subreg (bool, optional): If true, will redo existing subregion segmentations. Defaults to False. - override_vert (bool, optional): If true, will redo existing vertebra segmentations. Defaults to False. - override_ctd (bool, optional): If true, will redo existing cetnroid files. Defaults to False. - - snapshot_copy_folder (Path | None | bool, optional): If given a path, will copy all created snapshots in here. Defaults to None. - do_crop_semantic (bool, optional): _description_. Defaults to True. - - proc_n4correction (bool, optional): _description_. Defaults to True. - proc_fillholes (bool, optional): If true, will use fill holes in postprocessing step. Defaults to True. - proc_clean (bool, optional): If true, will use CC cleaning in postprocessing step. Defaults to True. - proc_corpus_clean (bool, optional): _description_. Defaults to True. - proc_cleanvert (bool, optional): If true, will use CC cleaning in vertebra postprocessing. Defaults to True. - proc_assign_missing_cc (bool, optional): _description_. Defaults to True. - proc_largest_cc (int, optional): _description_. Defaults to 0. - - ignore_model_compatibility (bool, optional): If true, will ignore initialization compatibility issues. Defaults to False. - ignore_inference_compatibility (bool, optional): If true, will ignore compatibility issues between models and individual inputs. Defaults to False. - ignore_bids_filter (bool, optional): _description_. Defaults to False. - log_inference_time (bool, optional): If true, will log the inference time for each subject. Defaults to True. - verbose (bool, optional): If true, will spam your terminal with info. Defaults to False. + derivative_name (str, optional): Name of the derivatives output folder. Defaults to "derivatives_seg". + modalities (list[Modality_Pair] | Modality_Pair, optional): Modality/acquisition pairs to segment in the dataset. + Defaults to [(Modality.T2w, Acquisition.sag)]. + save_debug_data (bool, optional): If true, saves intermediate debug data. Increases space usage. Defaults to False. + save_modelres_mask (bool, optional): If true, additionally saves the semantic mask in the resolution of the model. + Defaults to False. + save_softmax_logits (bool, optional): If true, additionally saves the softmax logits (averaged over folds) as an npz. + Defaults to False. + save_log_data (bool, optional): If true, writes the log to a file in the dataset folder. Defaults to True. + override_semantic (bool, optional): If true, redoes existing semantic segmentations. Defaults to False. + override_instance (bool, optional): If true, redoes existing instance segmentations. Defaults to False. + override_postpair (bool, optional): If true, redoes the combined post-processing step. Defaults to False. + override_ctd (bool, optional): If true, redoes existing centroid files. Defaults to False. + snapshot_copy_folder (Path | None | bool, optional): If a path, copies all created snapshots there; if True, uses a + "snaps_seg" subfolder of the dataset; if None/False, no copy is made. Defaults to None. + pad_size (int, optional): Padding added in each dimension before inference. Defaults to 4. + proc_sem_crop_input (bool, optional): If true, crops the input to the foreground before semantic segmentation. Defaults to True. + proc_sem_n4_bias_correction (bool, optional): If true, applies N4 bias field correction before semantic segmentation + (MRI only). Defaults to True. + proc_sem_remove_inferior_beyond_canal (bool, optional): If true, removes semantic structures inferior to and beyond the + spinal canal. Defaults to False. + proc_sem_clean_beyond_largest_bounding_box (bool, optional): If true, removes semantic voxels outside the largest + bounding box. Defaults to True. + proc_sem_clean_small_cc_artifacts (bool, optional): If true, removes small connected-component artifacts from the + semantic mask. Defaults to True. + proc_inst_corpus_clean (bool, optional): If true, cleans the vertebra corpus during instance processing. Defaults to True. + proc_inst_clean_small_cc_artifacts (bool, optional): If true, removes small connected-component artifacts from the + instance mask. Defaults to True. + proc_inst_largest_k_cc (int, optional): If greater than 0, keeps only the largest k connected components of the instance + mask. Defaults to 0. + proc_inst_detect_and_solve_merged_corpi (bool, optional): If true, detects and splits merged vertebra corpi. Defaults to True. + proc_lab_force_no_tl_anomaly (bool, optional): If true, forces the labeling to assume no thoracolumbar transition anomaly. + Defaults to False. + proc_fill_3d_holes (bool, optional): If true, fills 3D holes during post-processing. Defaults to True. + proc_assign_missing_cc (bool, optional): If true, assigns unlabeled connected components to the nearest instance. Defaults to True. + proc_clean_inst_by_sem (bool, optional): If true, cleans the instance mask using the semantic mask. Defaults to True. + proc_vertebra_inconsistency (bool, optional): If true, detects and resolves vertebra labeling inconsistencies. Defaults to True. + ignore_model_compatibility (bool, optional): If true, ignores model/modality initialization compatibility issues. Defaults to False. + ignore_inference_compatibility (bool, optional): If true, ignores compatibility issues between models and individual inputs. + Defaults to False. + ignore_bids_filter (bool, optional): If true, disables the BIDS query filters and processes all niftys found. Defaults to False. + log_inference_time (bool, optional): If true, logs the inference time of each step. Defaults to True. + verbose (bool, optional): If true, prints verbose information. Defaults to False. """ global logger # noqa: PLW0603 logger.print(f"Initialize setup for dataset in {dataset_path}", Log_Type.BOLD) @@ -265,10 +286,10 @@ def process_img_nii( # noqa: C901 proc_normalize_input: bool = True, # Processings # Pre-processing crop - crop=None, + crop: tuple[slice, slice, slice] | None = None, auto_crop_to_spine: bool | Literal["auto"] = "auto", - auto_crop_when_max_res_leq=1.2, - auto_crop_req_crop_min_dim=200, + auto_crop_when_max_res_leq: float = 1.2, + auto_crop_req_crop_min_dim: int = 200, # Semantic proc_sem_crop_input: bool = True, proc_sem_n4_bias_correction: bool = True, @@ -299,51 +320,74 @@ def process_img_nii( # noqa: C901 timing=False, verbose: bool = False, ) -> tuple[dict[str, Path], ErrCode]: - """Runs the SPINEPS framework over one nifty - - Args: - img_ref (BIDS_FILE): input BIDS_FILE - model_instance (Segmentation_Model): Model for the vertebra segmentation - model_semantic (list[Segmentation_Model] | Segmentation_Model | None, optional): Models for the subregion segmentation. If none, will attempt to find the correct one. Defaults to None. - rawdata_name (str, optional): Name of the rawdata folder. Defaults to "rawdata". - derivative_name (str, optional): Name of the derivatives folder. Defaults to "derivatives_seg". - modalities (list[Modality_Pair] | Modality_Pair, optional): List of modalities you want to segment in the dataset. Defaults to [(Modality.T2w, Acquisition.sag)]. - - save_debug_data (bool, optional): If true, saves debug data. Increases space usage! Defaults to False. - #save_uncertainty_image (bool, optional): If true, saves a uncertainty image for the semantic segmentation. Defaults to False. - save_modelres_mask (bool, optional): If true, will additionally save the semantic mask in the resolution of the model. Defaults to False. - save_softmax_logits (bool, optional): If true, additionally saves the softmax logits (averaged over folds) as an npz. Defaults to False. - save_log_data (bool, optional): If true, will save the log to a file. Defaults to True. + """Runs the SPINEPS framework over one nifty. - override_semantic (bool, optional): If true, will redo existing semantic segmentations. Defaults to False. - override_instance (bool, optional): If true, will redo existing instance segmentations. Defaults to False. - override_ctd (bool, optional): If true, will redo existing cetnroid files. Defaults to False. - - snapshot_copy_folder (Path | None | bool, optional): If given a path, will copy all created snapshots in here. Defaults to None. - do_crop_semantic (bool, optional): _description_. Defaults to True. + Runs the full pipeline on a single input image: semantic (subregion) segmentation, vertebra (instance) segmentation, + combined post-processing/labeling, centroid computation and a snapshot. Existing outputs are reused unless overridden. + Args: + img_ref (BIDS_FILE): Input BIDS_FILE referencing the image to segment. + model_semantic (Segmentation_Model): Model for the subregion (semantic) segmentation. + model_instance (Segmentation_Model): Model for the vertebra (instance) segmentation. + model_labeling (VertLabelingClassifier | None, optional): Classifier used to label the vertebra instances. Defaults to None. + derivative_name (str, optional): Name of the derivatives output folder. Defaults to "derivatives_seg". + save_modelres_mask (bool, optional): If true, additionally saves the semantic mask in the resolution of the model. + Defaults to False. + save_softmax_logits (bool, optional): If true, additionally saves the softmax logits (averaged over folds) as an npz. + Defaults to False. + save_debug_data (bool, optional): If true, saves intermediate debug data. Increases space usage. Defaults to False. + save_raw (bool, optional): If true, saves the raw (pre-cleanup) semantic and vertebra masks. Defaults to True. + override_semantic (bool, optional): If true, redoes an existing semantic segmentation. Defaults to False. + override_instance (bool, optional): If true, redoes an existing instance segmentation. Defaults to False. + override_postpair (bool, optional): If true, redoes the combined post-processing step. Defaults to False. + override_ctd (bool, optional): If true, redoes an existing centroid file. Defaults to False. + proc_pad_size (int, optional): Padding added in each dimension before inference. Defaults to 4. + proc_normalize_input (bool, optional): If true, normalizes the input intensities (disabled automatically for CT). Defaults to True. crop: If provided, segment only within the specified crop. - auto_crop_to_spine (bool | "auto"): Speeds up high-resolution models by first predicting the spine with VIBESeg (https://link.springer.com/article/10.1007/s00330-025-12035-9) and cropping to the spine region (works for any MR or CT image). - auto_crop_when_max_res_leq: Enables automatic spine cropping when auto_crop_to_spine="auto" and the largest spacing value of the semantic model is less than or equal to this threshold. + auto_crop_to_spine (bool | "auto"): Speeds up high-resolution models by first predicting the spine with VIBESeg + (https://link.springer.com/article/10.1007/s00330-025-12035-9) and cropping to the spine region (works for any MR or + CT image). + auto_crop_when_max_res_leq: Enables automatic spine cropping when auto_crop_to_spine="auto" and the largest spacing value + of the semantic model is less than or equal to this threshold. auto_crop_req_crop_min_dim: When auto_crop_to_spine="auto", compute the crop only if the image size exceeds this value cubed. - - proc_n4correction (bool, optional): _description_. Defaults to True. - proc_fillholes (bool, optional): If true, will use fill holes in postprocessing step. Defaults to True. - proc_clean (bool, optional): If true, will use CC cleaning in postprocessing step. Defaults to True. - proc_corpus_clean (bool, optional): _description_. Defaults to True. - proc_cleanvert (bool, optional): If true, will use CC cleaning in vertebra postprocessing. Defaults to True. - proc_assign_missing_cc (bool, optional): _description_. Defaults to True. - proc_largest_cc (int, optional): _description_. Defaults to 0. - - ignore_model_compatibility (bool, optional): If true, will ignore initialization compatibility issues. Defaults to False. - ignore_inference_compatibility (bool, optional): If true, will ignore compatibility issues between models and individual inputs. Defaults to False. - ignore_bids_filter (bool, optional): _description_. Defaults to False. - log_inference_time (bool, optional): If true, will log the inference time for each subject. Defaults to True. - timing: log the timing for each step - verbose (bool, optional): If true, will spam your terminal with info. Defaults to False. + proc_sem_crop_input (bool, optional): If true, crops the input to the foreground before semantic segmentation. Defaults to True. + proc_sem_n4_bias_correction (bool, optional): If true, applies N4 bias field correction before semantic segmentation + (MRI only). Defaults to True. + proc_sem_remove_inferior_beyond_canal (bool, optional): If true, removes semantic structures inferior to and beyond the + spinal canal. Defaults to False. + proc_sem_clean_beyond_largest_bounding_box (bool, optional): If true, removes semantic voxels outside the largest + bounding box. Defaults to True. + proc_sem_clean_small_cc_artifacts (bool, optional): If true, removes small connected-component artifacts from the + semantic mask. Defaults to True. + proc_inst_corpus_clean (bool, optional): If true, cleans the vertebra corpus during instance processing. Defaults to True. + proc_inst_clean_small_cc_artifacts (bool, optional): If true, removes small connected-component artifacts from the + instance mask. Defaults to True. + proc_inst_largest_k_cc (int, optional): If greater than 0, keeps only the largest k connected components of the instance + mask. Defaults to 0. + proc_inst_detect_and_solve_merged_corpi (bool, optional): If true, detects and splits merged vertebra corpi. Defaults to True. + vertebra_instance_labeling_offset (int, optional): Offset applied when mapping instance ids to vertebra labels (set to 1 + for CT models that include C1). Defaults to 2. + proc_lab_force_no_tl_anomaly (bool, optional): If true, forces the labeling to assume no thoracolumbar transition anomaly. + Defaults to False. + proc_fill_3d_holes (bool, optional): If true, fills 3D holes during post-processing. Defaults to True. + proc_assign_missing_cc (bool, optional): If true, assigns unlabeled connected components to the nearest instance. Defaults to True. + proc_assign_missing_cc_fast (bool, optional): If true, uses the faster variant of the missing-cc assignment. Defaults to False. + proc_clean_inst_by_sem (bool, optional): If true, cleans the instance mask using the semantic mask. Defaults to True. + proc_vertebra_inconsistency (bool, optional): If true, detects and resolves vertebra labeling inconsistencies. Defaults to True. + lambda_semantic (Callable[[NII], NII] | None, optional): Optional function applied to the semantic mask before saving. + Defaults to None. + snapshot_copy_folder (Path | None, optional): If given, copies the created snapshot there. Defaults to None. + ignore_bids_filter (bool, optional): If true, builds output paths in non-strict mode. Defaults to False. + ignore_compatibility_issues (bool, optional): If true, continues despite input/model incompatibilities. Defaults to False. + log_inference_time (bool, optional): If true, logs the inference time of each step. Defaults to True. + return_output_instead_of_save (bool, optional): If true, returns the result NIIs/centroids instead of saving them. + Defaults to False. + timing (bool, optional): If true, logs the timing of each pipeline step. Defaults to False. + verbose (bool, optional): If true, prints verbose information. Defaults to False. Returns: - ErrCode: Error code depicting whether the operation was successful or not + tuple[dict[str, Path], ErrCode]: Mapping of output names to their file paths and an error code indicating success. + If return_output_instead_of_save is True, instead returns (seg_nii, vert_nii, centroids, ErrCode). """ arguments = locals() input_format = img_ref.format @@ -530,7 +574,7 @@ def process_img_nii( # noqa: C901 if not out_spine.exists() or not out_vert.exists() or done_something or override_postpair: # back to input space # - seg_nii_modelres[seg_nii_modelres == 50] = 49 + seg_nii_modelres[seg_nii_modelres == Location.Vertebra_Corpus.value] = Location.Vertebra_Corpus_border.value if not save_modelres_mask: seg_nii_back = seg_nii_modelres.resample_from_to(input_nii_) whole_vert_nii = whole_vert_nii.resample_from_to(input_nii_) @@ -646,7 +690,23 @@ def output_paths_from_input( snapshot_copy_folder: Path | str | None, input_format: str, non_strict_mode: bool = False, -): +) -> dict[str, Path]: + """Derives all pipeline output file paths for a given input image. + + Builds the BIDS-conform output paths (semantic/vertebra masks, raw masks, centroids, snapshots, logits, debug and + VIBESeg crop) used throughout the pipeline, keyed by a descriptive name. + + Args: + img_ref (BIDS_FILE): Input BIDS_FILE the outputs are derived from. + derivative_name (str): Name of the derivatives output folder. + snapshot_copy_folder (Path | str | None): If given, location to which the snapshot is additionally copied + (used to build out_snap2). + input_format (str): Format string of the input, used to name the debug and raw output subfolders. + non_strict_mode (bool, optional): If true, builds the paths in non-strict BIDS mode. Defaults to False. + + Returns: + dict[str, Path]: Mapping of output names (e.g. "out_spine", "out_vert", "out_ctd", "out_snap") to their file paths. + """ out_spine = img_ref.get_changed_path( bids_format="msk", parent=derivative_name, diff --git a/spineps/seg_utils.py b/spineps/seg_utils.py index 6dda7de..122e313 100755 --- a/spineps/seg_utils.py +++ b/spineps/seg_utils.py @@ -1,3 +1,5 @@ +"""Utilities for matching segmentation models to inputs by modality, acquisition, and resolution compatibility.""" + from __future__ import annotations # from utils.predictor import nnUNetPredictor @@ -17,6 +19,20 @@ def find_best_matching_model( modality_pair: Modality_Pair, expected_resolution: ZOOMS | None, # actual resolution here? ) -> Segmentation_Model: + """Select the segmentation model best matching a modality/acquisition pair and resolution. + + Not yet implemented: intended to iterate over model configs and pick the one best matching the requested resolution. + + Args: + modality_pair (Modality_Pair): The desired ``(modality(ies), acquisition)`` pair. + expected_resolution (ZOOMS | None): The desired voxel resolution, or None. + + Returns: + Segmentation_Model: The best-matching model (once implemented). + + Raises: + NotImplementedError: Always, as this function is not yet implemented; also for an unmapped modality pair. + """ raise NotImplementedError("find_best_matching_model()") logger.print(expected_resolution) # TODO replace with automatic going through model configs to find best matching the resolution @@ -39,16 +55,19 @@ def check_model_modality_acquisition( model: Segmentation_Model, mod_pair: Modality_Pair, verbose: bool = True, -): - """Checks if a model is compatible with a specified Modality_Pair +) -> bool: + """Check whether a model supports a given modality/acquisition pair. + + Compares the model's supported modalities and acquisition against the requested pair and logs a warning describing any + mismatch when ``verbose`` is True. Args: - model (Segmentation_Model): _description_ - mod_pair (Modality_Pair): _description_ - verbose (bool, optional): _description_. Defaults to True. + model (Segmentation_Model): The model to check. + mod_pair (Modality_Pair): The required ``(modality(ies), acquisition)`` pair. + verbose (bool): If True, log a warning when incompatible. Returns: - _type_: _description_ + bool: True if the model supports all required modalities and the acquisition, otherwise False. """ compatible = True @@ -78,10 +97,19 @@ def check_model_modality_acquisition( ignored_text = " (IGNORED)." -len_ignored_text = len(ignored_text) -def add_ignore_text(logger_texts: list[str]): +def add_ignore_text(logger_texts: list[str]) -> None: + """Mark the last accumulated log message as ignored. + + Drops the trailing character of the last message (its period) and appends an "(IGNORED)." suffix in place. + + Args: + logger_texts (list[str]): Accumulated log messages; the last entry is modified in place. + + Returns: + None: ``logger_texts`` is modified in place. + """ logger_texts[-1] = logger_texts[-1][:-1] logger_texts[-1] += ignored_text @@ -94,17 +122,22 @@ def check_input_model_compatibility( ignore_labelkey: bool = False, verbose: bool = True, ) -> bool: - """Checks if a model is compatible with a specified input + """Check whether an input image file is compatible with a model's expected modality, acquisition, and naming. + + Validates the input's format/modality, acquisition plane, and BIDS keys against what the model expects. Individual mismatches + can be tolerated via the ``ignore_*`` flags (annotated as "(IGNORED)" in the log). Debug files are always rejected, and the + image plane must be isotropic or one of the model's allowed acquisitions. Warnings are logged when ``verbose`` is True. Args: - img_ref (BIDS_FILE): _description_ - model (Segmentation_Model): _description_ - ignore_modality (bool, optional): _description_. Defaults to False. - ignore_acquisition (bool, optional): _description_. Defaults to False. - verbose (bool, optional): _description_. Defaults to True. + img_ref (BIDS_FILE): Reference to the input image file. + model (Segmentation_Model): The model to check against. + ignore_modality (bool): If True, tolerate a modality/format mismatch. + ignore_acquisition (bool): If True, tolerate an acquisition mismatch. + ignore_labelkey (bool): If True, tolerate an unexpected ``label`` key in the filename. + verbose (bool): If True, log warnings describing incompatibilities. Returns: - bool: _description_ + bool: True if the input is compatible with the model (after applying the ignore flags), otherwise False. """ model_modalities = model.modalities() model_acquisition = model.acquisition() diff --git a/spineps/utils/__init__.py b/spineps/utils/__init__.py index 905aa3b..ffe13f4 100755 --- a/spineps/utils/__init__.py +++ b/spineps/utils/__init__.py @@ -1 +1,3 @@ +"""Utility subpackage with helpers for file paths, model configs, downloads and segmentation post-processing.""" + from spineps.utils.filepaths import filepath_model diff --git a/spineps/utils/auto_download.py b/spineps/utils/auto_download.py index 4542d59..835f1f1 100644 --- a/spineps/utils/auto_download.py +++ b/spineps/utils/auto_download.py @@ -1,3 +1,5 @@ +"""Automatic download and extraction of pretrained SPINEPS model weights from the GitHub releases.""" + from __future__ import annotations import shutil @@ -26,7 +28,10 @@ SpinepsPhase.SEMANTIC.name + "_ct": current_highest_ct_version, } -instances: dict[str, Union[Path, str]] = {"instance": link + current_instance_highest_version + "/instance.zip"} +instances: dict[str, Union[Path, str]] = { + "instance": link + current_instance_highest_version + "/instance.zip", + "ct_instance": link + current_highest_ct_version + "/CT_instance.zip", +} semantic: dict[str, Union[Path, str]] = { "t2w": link + current_highest_version + "/t2w.zip", "t1w": link + current_highest_version + "/t1w.zip", @@ -40,6 +45,7 @@ download_names = { "instance": "instance_sagittal", + "ct_instance": "CT_instance", "t2w": "T2w_semantic", "t1w": "T1w_semantic", "vibe": "Vibe_semantic", @@ -49,8 +55,20 @@ } -def download_if_missing(key, url, phase: SpinepsPhase): +def download_if_missing(key: str, url: Union[Path, str], phase: SpinepsPhase) -> Path: + """Return the local model folder for a model, downloading and extracting its weights if absent. + + The target folder name combines the model's download name with the version resolved for its phase (and the + phase/key-specific override when one exists, e.g. CT models). + + Args: + key: Model key identifying the model within its phase (e.g. ``"t2w"``, ``"instance"``). + url: Release URL of the model's weights zip archive. + phase (SpinepsPhase): Pipeline phase the model belongs to. + Returns: + Path: Path to the local model folder containing the (possibly just downloaded) weights. + """ version = phase_to_version.get(f"{phase}_{key}", phase_to_version[phase.name]) out_path = Path(get_mri_segmentor_models_dir(), download_names[key] + "_" + version) if not out_path.exists(): @@ -59,7 +77,20 @@ def download_if_missing(key, url, phase: SpinepsPhase): return out_path -def download_weights(weights_url, out_path) -> None: +def download_weights(weights_url: Union[Path, str], out_path: Union[Path, str]) -> None: + """Download a weights zip archive, extract it into ``out_path`` and remove the archive. + + Shows a progress bar during download. If the extracted archive nests its contents in an extra subfolder + (no ``inference_config.json`` at the top level), the inner contents are moved up one level. Returns early + without raising if the initial size request fails. + + Args: + weights_url: URL of the weights zip archive to download. + out_path: Destination folder for the extracted weights (the archive is downloaded next to it as ``.zip``). + + Raises: + AssertionError: If the nested archive layout is detected but the extra entry is not a directory. + """ out_path = Path(out_path) logger = Print_Logger() try: diff --git a/spineps/utils/citation_reminder.py b/spineps/utils/citation_reminder.py index 2ad44ae..7160443 100644 --- a/spineps/utils/citation_reminder.py +++ b/spineps/utils/citation_reminder.py @@ -1,3 +1,5 @@ +"""Citation reminder utilities that prompt users to cite SPINEPS when the package is used.""" + from __future__ import annotations import atexit @@ -20,13 +22,13 @@ def wrapper(*args, **kwargs): if not has_reminded_citation and os.environ.get("SPINEPS_TURN_OF_CITATION_REMINDER", "FALSE") != "TRUE": print_citation_reminder() has_reminded_citation = True - func_result = func(*args, **kwargs) - return func_result + return func(*args, **kwargs) return wrapper def print_citation_reminder(): + """Print a formatted reminder with the SPINEPS GitHub and ArXiv links asking users to cite the work.""" console = Console() console.rule("Thank you for using [bold]SPINEPS[/bold]") console.print( diff --git a/spineps/utils/compat.py b/spineps/utils/compat.py index c1e74d5..489353e 100644 --- a/spineps/utils/compat.py +++ b/spineps/utils/compat.py @@ -1,7 +1,9 @@ from __future__ import annotations +from collections.abc import Iterable -def zip_strict(*iterables): + +def zip_strict(*iterables: Iterable) -> zip: """ A strict version of zip that raises a ValueError if the input iterables have different lengths. diff --git a/spineps/utils/filepaths.py b/spineps/utils/filepaths.py index 41b9a24..89380f7 100755 --- a/spineps/utils/filepaths.py +++ b/spineps/utils/filepaths.py @@ -1,3 +1,5 @@ +"""File-path helpers for locating the SPINEPS model weights directory and individual model folders.""" + from __future__ import annotations import os diff --git a/spineps/utils/find_min_cost_path.py b/spineps/utils/find_min_cost_path.py index d56666f..93c762f 100644 --- a/spineps/utils/find_min_cost_path.py +++ b/spineps/utils/find_min_cost_path.py @@ -1,3 +1,5 @@ +"""Min-cost path solver that assigns the most probable vertebra label sequence from a per-vertebra cost matrix.""" + from __future__ import annotations import sys @@ -7,8 +9,27 @@ import numpy as np from TPTBox import Log_Type, No_Logger +# Default softmax temperature used when cost smoothing is enabled. +DEFAULT_SOFTMAX_TEMP = 0.2 +# Class indices of anatomically special vertebrae within the labeling cost matrix. +T11_CLASS_IDX = 17 # a single skip is permitted at this class +T12_CLASS_IDX = 18 # transitional vertebra; may appear twice in a path +L5_CLASS_IDX = 23 # transitional vertebra; may appear twice in a path +# Region start indices (cervical, thoracic, lumbar) along the class axis. +DEFAULT_REGION_STARTS = (0, 7, 19) +# A class flagged as "multiple-allowed" may appear at most this many times in a path. +MAX_REPEATS_PER_CLASS = 2 + + +def argmin(lst: list) -> tuple[int, ...]: + """Return the index and value of the smallest element in a list. -def argmin(lst): + Args: + lst: A non-empty sequence supporting ``min`` and ``index``. + + Returns: + tuple: ``(index, value)`` of the minimum element. + """ m = min(lst) return lst.index(m), m @@ -18,14 +39,31 @@ def softmax_T(x, temp): return np.exp(np.divide(x, temp)) / np.sum(np.exp(np.divide(x, temp)), axis=0) -def c_to_region_idx(c: int, regions: list[int]): +def c_to_region_idx(c: int, regions: list[int]) -> int: + """Map a class index to the index of the spinal region it falls into. + + Args: + c (int): Class (label) index along the cost matrix's class axis. + regions (list[int]): Sorted region start indices (e.g. cervical, thoracic, lumbar). + + Returns: + int: Index of the region containing class ``c``. + """ for idx, r in enumerate(regions): if c < r: return idx - 1 return len(regions) - 1 -def internal_to_real_path(p): +def internal_to_real_path(p: list) -> list: + """Convert an internal ``(row, class)`` path into the ordered list of class indices. + + Args: + p: Iterable of ``(row, class)`` tuples representing path nodes. + + Returns: + list: Class indices ordered by ascending row (vertebra) index. + """ pat = sorted(p, key=lambda x: x[0]) pat = [x[1] for x in pat] return pat @@ -43,7 +81,7 @@ def find_most_probably_sequence( # noqa: C901 invert_cost: bool = True, # softmax_cost: bool = False, - softmax_temp: float = 0.2, + softmax_temp: float = DEFAULT_SOFTMAX_TEMP, # allow_multiple_at_class: list[int] | None = None, # T12 and L5 punish_multiple_sequence: float = 0.0, @@ -56,17 +94,56 @@ def find_most_probably_sequence( # noqa: C901 # verbose: bool = False, ) -> tuple[float, list[int], list]: + """Find the most probable vertebra-label sequence as a min-cost monotone path through a cost matrix. + + Each matrix row corresponds to a detected vertebra (top to bottom) and each column to a candidate label + class. The path moves one row down per step, normally advancing one class (diagonal). Special constraints + model spinal anatomy: certain transitional classes (e.g. T12, L5) may repeat, certain classes/regions allow + a single skip, and optional region- and T13-related transition costs adjust the path. Extra moves incur the + configured penalties; classes flagged as repeatable may appear at most ``MAX_REPEATS_PER_CLASS`` times. + + Args: + cost (np.ndarray | list[int]): 2D cost matrix of shape ``(n_vertebrae, n_classes)``. + min_start_class (int, optional): Smallest class index the path may start at. Defaults to 0. + region_rel_cost (np.ndarray | list[int] | None, optional): Per-vertebra costs for being the first/last + vertebra of each region; enables region-transition costs when given. Defaults to None. + vertt13_cost (np.ndarray | list[int] | None, optional): Per-vertebra cost contribution for the T13/T12 + (class 18) repeat case. Defaults to None. + regions (list[int] | None, optional): Region start indices along the class axis. Defaults to + ``DEFAULT_REGION_STARTS``. + invert_cost (bool, optional): Negate the cost so that high input scores are preferred. Defaults to True. + softmax_cost (bool, optional): Apply a softmax over the cost columns (deprecated path). Defaults to False. + softmax_temp (float, optional): Temperature for the softmax. Defaults to ``DEFAULT_SOFTMAX_TEMP``. + allow_multiple_at_class (list[int] | None, optional): Classes allowed to repeat (e.g. T12 and L5). + Defaults to ``[T12_CLASS_IDX, L5_CLASS_IDX]``. + punish_multiple_sequence (float, optional): Extra cost added for repeating a class. Defaults to 0.0. + allow_skip_at_class (list[int] | None, optional): Classes after which a single class may be skipped (e.g. + T11). Defaults to ``[T11_CLASS_IDX]``. + punish_skip_sequence (float, optional): Extra cost added for a class-level skip. Defaults to 0.0. + allow_skip_at_region (list[int] | None, optional): Regions in which a single skip is permitted. Defaults + to ``[0]``. + punish_skip_at_region_sequence (float, optional): Extra cost added for a region-level skip. Defaults to 0.2. + verbose (bool, optional): Enable verbose logging of the recursion. Defaults to False. + + Returns: + tuple[float, list[int], list]: The total path cost, the chosen class index per vertebra (top to bottom), + and the internal memoization table of best ``(cost, path)`` per ``(row, class)`` cell. + + Raises: + AssertionError: If ``min_start_class`` is not less than the number of classes, or if a provided + ``region_rel_cost`` does not have the expected number of columns. + """ logger = No_Logger() logger.default_verbose = verbose # default mutable arguments if allow_skip_at_region is None: allow_skip_at_region = [0] if allow_skip_at_class is None: - allow_skip_at_class = [17] + allow_skip_at_class = [T11_CLASS_IDX] if allow_multiple_at_class is None: - allow_multiple_at_class = [18, 23] + allow_multiple_at_class = [T12_CLASS_IDX, L5_CLASS_IDX] if regions is None: - regions = [0, 7, 19] + regions = list(DEFAULT_REGION_STARTS) # convert to np arrays if isinstance(cost, list): cost = np.asarray(cost) @@ -157,7 +234,7 @@ def minCostAlgo(r, c): # allow two subsequent of same class if c in allow_multiple_at_class: cost_add = punish_multiple_sequence - if c == 18: + if c == T12_CLASS_IDX: cost_add += t13_cost_single(r + 1, c) with logger: add_option_path(options, r + 1, c, cost_add) @@ -182,7 +259,7 @@ def minCostAlgo(r, c): cost_value += rel_cost(r, c, pnext, region_cur) # constraint: cannot have more than 2 T12 and L5 for amac in allow_multiple_at_class: - if amac in cnt and cnt[amac] > 2: + if amac in cnt and cnt[amac] > MAX_REPEATS_PER_CLASS: cost_value = sys.maxsize break # setting to memory diff --git a/spineps/utils/generate_disc_labels.py b/spineps/utils/generate_disc_labels.py index 3ab496e..7e91ed2 100644 --- a/spineps/utils/generate_disc_labels.py +++ b/spineps/utils/generate_disc_labels.py @@ -43,9 +43,11 @@ } -def get_parser(): - """ - Parser to generate discs labels +def get_parser() -> argparse.ArgumentParser: + """Build the command-line argument parser for disc-label generation. + + Returns: + argparse.ArgumentParser: Parser accepting the input vertebrae label path and the optional output path. """ # parse command line arguments parser = argparse.ArgumentParser(description="Generate discs labels from spineps' vertebrae segmentation.") @@ -67,8 +69,10 @@ def get_parser(): def main(): - """ - Main function to extract discs labels + """Run the disc-label generation CLI. + + Parses arguments, loads the SPINEPS vertebrae segmentation, derives single-voxel disc labels from it and + writes the result to the chosen (or default) output path. """ # Load parser parser = get_parser() @@ -96,9 +100,15 @@ def main(): print("-" * 80) -def default_name_discs(path_in, suffix="_label-discs_dlabel"): - """ - Generate default discs label name +def default_name_discs(path_in: Path | str, suffix="_label-discs_dlabel") -> Path: + """Derive the default output path for disc labels by swapping in a disc suffix. + + Args: + path_in: Path to the input vertebrae label file (may include compound extensions like ``.nii.gz``). + suffix (str, optional): Suffix inserted before the extension. Defaults to ``"_label-discs_dlabel"``. + + Returns: + Path: The default output path with the disc suffix applied. """ # Fetch suffixes path_obj = Path(path_in) @@ -109,9 +119,19 @@ def default_name_discs(path_in, suffix="_label-discs_dlabel"): return path_out -def extract_discs_label(label, mapping): - """ - Extract discs from mapping +def extract_discs_label(label: Image, mapping: dict) -> Image: + """Derive single-voxel disc labels from a vertebrae segmentation. + + Remaps vertebra label values to disc values, locates each disc's posterior tip by shifting a centerline + (interpolated through the disc centroids) posteriorly and picking the closest segmented voxel, inserts disc 2 + between discs 1 and 3 when both are present, and writes one labeled voxel per disc into the image. + + Args: + label (Image): Vertebrae segmentation image; its data is replaced in place with the disc labels. + mapping (dict): Mapping from vertebra label values to disc label values. + + Returns: + Image: The image holding the disc labels, restored to its original orientation. """ # Store input orientation orig_orientation = label.orientation @@ -175,10 +195,15 @@ def extract_discs_label(label, mapping): return label.change_orientation(orig_orientation) -def extract_centroids_3d(arr): - """ - Extract centroids and bouding boxes from a 3D numpy array - :param arr: 3D numpy array +def extract_centroids_3d(arr: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Extract connected-component centroids and bounding boxes from a 3D array, sorted along the vertical axis. + + Args: + arr (np.ndarray): 3D label array (assumed RSP orientation, so axis 1 is the superior-inferior axis). + + Returns: + tuple[np.ndarray, np.ndarray]: Integer centroid coordinates and the matching bounding boxes, both sorted + by the vertical (axis-1) coordinate, with the background component removed. """ stats = cc3d.statistics(cc3d.connected_components(arr)) centroids = stats["centroids"][1:] # Remove backgroud <0> @@ -192,15 +217,17 @@ def extract_centroids_3d(arr): return centroids_sorted.astype(int), bb_sorted -def project_point_on_line(point, line): - """ - Project the input point on the referenced line by finding the minimal distance +def project_point_on_line(point: np.ndarray, line: np.ndarray) -> tuple[np.ndarray, float]: + """Project a point onto a polyline by finding the closest line point. - :param point: coordinates of a point and its value: point = numpy.array([x y z]) - :param line: list of points coordinates which composes the line - :returns: closest coordinate to the referenced point on the line: - projected_point = numpy.array([X Y Z]) - Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox + Copied from https://github.com/spinalcordtoolbox/spinalcordtoolbox. + + Args: + point (np.ndarray): Coordinates of the point, ``numpy.array([x, y, z])``. + line (np.ndarray): Coordinates of the points composing the line. + + Returns: + tuple[np.ndarray, float]: The closest point on the line and the squared distance to it. """ # Calculate distances between the referenced point and the line then keep the closest point dist = np.sum((line - point) ** 2, axis=1) @@ -208,9 +235,16 @@ def project_point_on_line(point, line): return line[np.argmin(dist)], np.min(dist) -def closest_point_seg_to_line(discs_seg, centerline, bounding_boxes): - """ - Find closest point from segmentation to a line +def closest_point_seg_to_line(discs_seg: np.ndarray, centerline: np.ndarray, bounding_boxes: np.ndarray) -> np.ndarray: + """Find, per disc, the segmented voxel closest to a reference centerline. + + Args: + discs_seg (np.ndarray): Disc-labeled segmentation array. + centerline (np.ndarray): Coordinates of the points composing the reference line. + bounding_boxes (np.ndarray): Bounding box (slice tuple) for each disc, used to isolate it. + + Returns: + np.ndarray: Array of ``[x, y, z, disc_value]`` rows, one per disc, giving the closest voxel and its label. """ discs_list = [] for x, y, z in bounding_boxes: diff --git a/spineps/utils/proc_functions.py b/spineps/utils/proc_functions.py index 4fd5f20..3d9758a 100755 --- a/spineps/utils/proc_functions.py +++ b/spineps/utils/proc_functions.py @@ -1,3 +1,5 @@ +"""Segmentation post-processing helpers: n4 bias correction, connected-component cleaning and instance fixes.""" + from __future__ import annotations import cc3d @@ -15,6 +17,9 @@ ) from tqdm import tqdm +# Vertebra instance labels span 1..25 (cervical, thoracic and lumbar); 26 denotes the sacrum. +MAX_VERTEBRA_INSTANCE_LABEL = 25 + def n4_bias( nii: NII, @@ -22,18 +27,24 @@ def n4_bias( spline_param: int = 100, dtype2nii: bool = False, norm: int = -1, -): - """Applies n4 bias field correction to a nifty +) -> tuple[NII, NII]: + """Apply N4 bias field correction to a NIfTI image. + + Builds a foreground mask by thresholding (and filling its bounding box), runs N4 correction restricted to + that mask, optionally rescales the result to a target maximum and optionally casts back to the input dtype. Args: - nii (NII): Input nifty - threshold (int, optional): Threshold to use for masking, every input value < threshold is used. Defaults to 60. - spline_param (int, optional): _description_. Defaults to 200. - dtype2nii (bool, optional): _description_. Defaults to False. - norm (int, optional): _description_. Defaults to -1. + nii (NII): Input image to correct. + threshold (int, optional): Intensity threshold for the foreground mask; voxels below it are excluded. + Defaults to 60. + spline_param (int, optional): Spline distance parameter passed to the N4 correction. Defaults to 100. + dtype2nii (bool, optional): If True, cast the corrected image back to the input image's dtype. Defaults + to False. + norm (int, optional): If not -1, rescale the corrected image so its maximum equals this value. Defaults + to -1. Returns: - _type_: _description_ + tuple[NII, NII]: The bias-corrected image and the binary foreground mask used for correction. """ from ants.utils.convert_nibabel import from_nibabel # they keep renaming that thing. (version 0.4.2) @@ -63,20 +74,31 @@ def clean_cc_artifacts( only_delete: bool = False, ignore_missing_labels: bool = False, ) -> np.ndarray: - """Cleans artifacts based on connected components analysis + """Clean small connected-component artifacts in a segmentation mask. + + For each requested label, finds connected components below the size threshold and either deletes them or, if + they border enough other foreground voxels, relabels them by majority vote of their dilated neighborhood. Args: - mask (NII | np.ndarray): input segmentation mask - logger (Logger_Interface): logger - labels (list[int], optional): labels to analyze in the input. Defaults to [1, 2, 3]. - cc_size_threshold (int | list[int], optional): threshold on which to clean, can be a number for all labels or a list of values for each different label. Defaults to 100. - neighbor_factor_2_delete (float, optional): Percentage of existing neighbor pixels to not just delete the CC. Defaults to 0.1. - verbose (bool, optional): _description_. Defaults to True. - only_delete (bool, optional): If set, will delete each analyse CC. Defaults to False. - ignore_missing_labels (bool, optional): If true, will not crash if some labels are not found. Defaults to False. + mask (NII | np.ndarray): Input segmentation mask. + logger (Logger_Interface): Logger for progress and cleaning reports. + labels (list[int], optional): Labels to analyze. Defaults to [1, 2, 3]. + cc_size_threshold (int | list[int], optional): Minimum component size in voxels; a single value applies to + all labels, or one value per label. Defaults to 100. + neighbor_factor_2_delete (float, optional): Fraction of neighboring foreground voxels below which a + component is deleted instead of relabeled. Defaults to 0.1. + verbose (bool, optional): If True, log per-component details and show a progress bar. Defaults to True. + only_delete (bool, optional): If True, delete every analyzed component without majority-vote relabeling. + Defaults to False. + ignore_missing_labels (bool, optional): If True, skip labels not present instead of asserting. Defaults to + False. Returns: - np.ndarray: _description_ + np.ndarray: The cleaned segmentation array. + + Raises: + AssertionError: If requested labels are missing (when ``ignore_missing_labels`` is False) or the length of + ``cc_size_threshold`` does not match the number of labels. """ mask_arr = mask.get_seg_array() if isinstance(mask, NII) else mask.copy() result_arr = mask_arr.copy() @@ -160,15 +182,17 @@ def clean_cc_artifacts( def connected_components_3d(mask_image: np.ndarray, connectivity: int = 3, verbose: bool = False) -> tuple[dict, dict]: # noqa: ARG001 - """Applies 3d connected components + """Compute 3D connected components per label together with their statistics. Args: - mask_image: input mask - connectivity: in range [1,3]. For 2D images, 2 and 3 is the same. - verbose: + mask_image (np.ndarray): Input (multi-label) mask. + connectivity (int, optional): Voxel connectivity in range [1, 3]. For 2D images 2 and 3 are equivalent. + Defaults to 3. + verbose (bool, optional): Currently unused. Defaults to False. Returns: - + tuple[dict, dict]: A dict mapping each label to its connected-component array, and a dict mapping each + label to its ``cc3d`` component statistics. """ subreg_cc = np_connected_components_per_label( mask_image, @@ -178,7 +202,24 @@ def connected_components_3d(mask_image: np.ndarray, connectivity: int = 3, verbo return subreg_cc, subreg_cc_stats -def fix_wrong_posterior_instance_label(seg_sem: NII, seg_inst: NII, logger) -> NII: +def fix_wrong_posterior_instance_label(seg_sem: NII, seg_inst: NII, logger: Logger_Interface) -> NII: + """Reassign misattributed posterior vertebra fragments to the correct instance label. + + For every vertebra instance that splits into multiple connected components, each extra component consisting + only of posterior elements (arcus vertebrae and/or spinous process) is relabeled to the single neighboring + instance it touches, if any. Operates on copies and restores the original orientation before returning. + + Args: + seg_sem (NII): Semantic segmentation (subregion labels) aligned with ``seg_inst``. + seg_inst (NII): Vertebra instance segmentation to correct. + logger: Logger used to report each relabeling decision. + + Returns: + NII: The corrected instance segmentation in the original orientation. + + Raises: + AssertionError: If ``seg_sem`` and ``seg_inst`` do not share the same affine. + """ seg_sem = seg_sem.copy() seg_inst = seg_inst.copy() orientation = seg_sem.orientation @@ -188,7 +229,7 @@ def fix_wrong_posterior_instance_label(seg_sem: NII, seg_inst: NII, logger) -> N seg_inst_arr_proc = seg_inst.get_seg_array() - instance_labels = [i for i in seg_inst.unique() if 1 <= i <= 25] + instance_labels = [i for i in seg_inst.unique() if 1 <= i <= MAX_VERTEBRA_INSTANCE_LABEL] for vert in instance_labels: inst_vert = seg_inst.extract_label(vert) diff --git a/spineps/utils/resolution.py b/spineps/utils/resolution.py new file mode 100644 index 0000000..4747232 --- /dev/null +++ b/spineps/utils/resolution.py @@ -0,0 +1,91 @@ +"""Resolution-aware thresholding: convert physical (mm) thresholds to voxels at the actual image zoom. + +SPINEPS was originally tuned on T2w MR images, so several processing thresholds were hard-coded as +voxel counts that implicitly assumed the model resolution. To support inputs at other resolutions +(e.g. CT), those thresholds are now expressed in physical millimetres and converted back to voxels at +runtime using the zoom of the image being processed. The constants are derived from +``REFERENCE_ZOOM`` so that, at the reference resolution, the converted voxel counts equal the original +hard-coded values (i.e. existing T2w results are preserved exactly). +""" + +from __future__ import annotations + +import numpy as np +from TPTBox import ZOOMS + +# Canonical voxel spacing (mm) of the SPINEPS T2w semantic and instance models +# (sagittal: 0.75 mm in-plane, 1.65 mm superior-inferior). All physical thresholds in the pipeline +# are derived from this reference so behaviour is unchanged at this resolution. +REFERENCE_ZOOM: ZOOMS = (0.75, 0.75, 1.65) + +# Physical volume (mm^3) of a single voxel at the reference resolution. +REFERENCE_VOXEL_VOLUME_MM3: float = float(np.prod(REFERENCE_ZOOM)) + +# In a ``P, I, R`` oriented image, axis 1 is the superior-inferior height axis. +INFERIOR_AXIS_PIR: int = 1 + + +def mm3_to_voxels(threshold_mm3: float, zoom: ZOOMS, minimum: int = 1) -> int: + """Convert a volume threshold in mm^3 to a voxel count for the given voxel spacing. + + Args: + threshold_mm3 (float): Volume threshold in cubic millimetres. + zoom (ZOOMS): Voxel spacing (mm) of the image being processed, per axis. + minimum (int, optional): Lower bound on the returned voxel count. Defaults to 1. + + Returns: + int: The threshold expressed as a number of voxels at ``zoom`` (at least ``minimum``). + """ + voxel_volume = float(np.prod(zoom)) + return max(round(threshold_mm3 / voxel_volume), minimum) + + +def isotropic_area_to_voxels(threshold_mm2: float, zoom: ZOOMS, minimum: int = 1) -> int: + """Convert an area threshold in mm^2 to an (orientation-agnostic) voxel-adjacency count. + + Uses the geometric-mean voxel face area (``prod(zoom) ** (2/3)``) so the result does not depend on + which axes span the contact surface; suitable for voxel-contact counts whose orientation is unknown. + + Args: + threshold_mm2 (float): Area threshold in square millimetres. + zoom (ZOOMS): Voxel spacing (mm) of the image being processed, per axis. + minimum (int, optional): Lower bound on the returned voxel count. Defaults to 1. + + Returns: + int: The threshold expressed as a number of voxels at ``zoom`` (at least ``minimum``). + """ + mean_voxel_area = float(np.prod(zoom)) ** (2.0 / 3.0) + return max(round(threshold_mm2 / mean_voxel_area), minimum) + + +def mm_to_voxels(threshold_mm: float, zoom: ZOOMS, minimum: int = 0) -> int: + """Convert a distance threshold in mm to an (isotropic) voxel count using the finest spacing. + + Uses the smallest voxel spacing so the result matches the original voxel-isotropic behaviour at + the reference resolution; this is the same convention used for the labeling crop margin. + + Args: + threshold_mm (float): Distance threshold in millimetres. + zoom (ZOOMS): Voxel spacing (mm) of the image being processed, per axis. + minimum (int, optional): Lower bound on the returned voxel count. Defaults to 0. + + Returns: + int: The threshold expressed as a number of voxels at ``zoom`` (at least ``minimum``). + """ + return max(round(threshold_mm / float(min(zoom))), minimum) + + +def mm_to_voxels_axis(threshold_mm: float, zoom: ZOOMS, axis: int) -> float: + """Convert a distance threshold in mm to a (possibly fractional) voxel distance along one axis. + + Returns a float so it can be passed straight to ``NII.compute_crop(dist=...)``. + + Args: + threshold_mm (float): Distance threshold in millimetres. + zoom (ZOOMS): Voxel spacing (mm) of the image being processed, per axis. + axis (int): The axis along which the distance is measured. + + Returns: + float: The threshold expressed as a voxel distance along ``axis`` at ``zoom``. + """ + return threshold_mm / float(zoom[axis]) diff --git a/spineps/utils/seg_modelconfig.py b/spineps/utils/seg_modelconfig.py index fb70442..7ce3b8b 100755 --- a/spineps/utils/seg_modelconfig.py +++ b/spineps/utils/seg_modelconfig.py @@ -1,3 +1,5 @@ +"""Inference configuration model: parses and holds the per-model settings stored in inference_config.json files.""" + from __future__ import annotations import json @@ -8,6 +10,20 @@ from spineps.seg_enums import Acquisition, InputType, Modality, ModelType +# Number of spatial dimensions of a volumetric (3D) image. +SPATIAL_DIMS = 3 + +# Default voxel geometry and post-processing cleaning thresholds. The voxel-count +# thresholds are multiplied by the resolution scaling factor at runtime. +DEFAULT_CUTOUT_SIZE = (248, 304, 64) +DEFAULT_SACRUM_IDS = (26,) +DEFAULT_CORPUS_SIZE_CLEANING = 100 # minimum corpus component size in voxels +DEFAULT_CORPUS_BORDER_THRESHOLD = 10 +DEFAULT_VERT_SIZE_THRESHOLD = 250 # minimum vertebra size in voxels + +# Default remapping of raw model label ids onto canonical SPINEPS label ids. +DEFAULT_LABEL_MAPPING = {41: 1, 42: 2, 43: 3, 44: 4, 45: 5, 46: 6, 47: 7, 48: 8, 49: 9, 50: 9, Location.Dens_axis.value: 9, 26: 0} + class Segmentation_Inference_Config: """Bucket for saving Inference Config data""" @@ -28,17 +44,55 @@ def __init__( expected_inputs: list[InputType | str] = [InputType.img], # noqa: B006 has_c1=False, needs_corp=False, - sacrum_ids=(26,), - cutout_size=(248, 304, 64), # (264, 304, 64) # (248, 304, 64) # (264, 304, 64) - corpus_size_cleaning=100, - corpus_border_threshold=10, - vert_size_threshold=250, + sacrum_ids=DEFAULT_SACRUM_IDS, + cutout_size=DEFAULT_CUTOUT_SIZE, + corpus_size_cleaning=DEFAULT_CORPUS_SIZE_CLEANING, + corpus_border_threshold=DEFAULT_CORPUS_BORDER_THRESHOLD, + vert_size_threshold=DEFAULT_VERT_SIZE_THRESHOLD, mapping=None, **kwargs, ): - scaling_factor = np.prod(resolution_range) if len(resolution_range) == 3 else np.prod(resolution_range[0]) + """Build an inference config from raw (typically JSON-decoded) values. + + String fields are resolved to the corresponding enum members and the label dictionaries are converted to + integer label ids. Voxel-count cleaning thresholds are scaled by the resolution's voxel volume so they stay + physically meaningful across resolutions. + + Args: + logger (Logger_Interface | None): Logger for diagnostics; unknown extra kwargs are reported through it. + log_name (str): Name used as the logger prefix and to identify this config. + modality (str | tuple[str]): One or more modality names (see :class:`Modality`). + acquisition (str): Acquisition plane name (see :class:`Acquisition`). + modeltype (str): Model type name (see :class:`ModelType`). + model_expected_orientation (AX_CODES): Axis-code orientation the model expects its input in. + available_folds (int | str | tuple[str] | tuple[int]): Folds available for inference/ensembling. + inference_augmentation (bool): Whether to apply test-time augmentation during inference. + resolution_range (ZOOMS | tuple[ZOOMS, ZOOMS]): Target voxel spacing, either a single zoom or a + (min, max) range. + default_step_size (float): Default sliding-window step size used during inference. + labels (dict): Mapping of raw label keys to label names/ids resolved via ``Location``/``v_name2idx``. + expected_inputs (list[InputType | str], optional): Input channels the model expects. Defaults to + ``[InputType.img]``. + has_c1 (bool, optional): Whether the model segments the C1 vertebra. Defaults to False. + needs_corp (bool, optional): Whether the model needs the vertebral corpus present. Defaults to False. + sacrum_ids (tuple, optional): Label ids treated as sacrum. Defaults to ``DEFAULT_SACRUM_IDS``. + cutout_size (tuple, optional): Crop/cutout size in voxels. Defaults to ``DEFAULT_CUTOUT_SIZE``. + corpus_size_cleaning (int, optional): Minimum corpus component size in voxels before resolution + scaling. Defaults to ``DEFAULT_CORPUS_SIZE_CLEANING``. + corpus_border_threshold (int, optional): Border distance threshold for corpus cleaning. Defaults to + ``DEFAULT_CORPUS_BORDER_THRESHOLD``. + vert_size_threshold (int, optional): Minimum vertebra size in voxels before resolution scaling. + Defaults to ``DEFAULT_VERT_SIZE_THRESHOLD``. + mapping (dict | None, optional): Remapping of raw model label ids onto canonical ids. Defaults to a + copy of ``DEFAULT_LABEL_MAPPING``. + **kwargs: Ignored extra configuration keys, reported via ``logger``. + + Raises: + KeyError: If a label name in ``labels`` cannot be resolved to a known label id. + """ + scaling_factor = np.prod(resolution_range) if len(resolution_range) == SPATIAL_DIMS else np.prod(resolution_range[0]) if mapping is None: - mapping = {41: 1, 42: 2, 43: 3, 44: 4, 45: 5, 46: 6, 47: 7, 48: 8, 49: 9, 50: 9, Location.Dens_axis.value: 9, 26: 0} + mapping = dict(DEFAULT_LABEL_MAPPING) if not isinstance(modality, (list, tuple)): modality = [modality] @@ -75,7 +129,16 @@ def __init__( for k in kwargs: logger.print(f"Ignored inference config argument {k}", Log_Type.STRANGE) - def str_representation(self, short: bool = False): + def str_representation(self, short: bool = False) -> str: + """Render the config's attributes as a comma-separated ``'key'=value`` string. + + Args: + short (bool, optional): If True, include only the modalities, acquisition and resolution range. + Defaults to False (all attributes except ``log_name``). + + Returns: + str: The formatted representation of the selected attributes. + """ to_print = self.__dict__ if not short else ["modalities", "acquisition", "resolution_range"] sb = [] for key in self.__dict__: @@ -87,14 +150,33 @@ def str_representation(self, short: bool = False): return ", ".join(sb) - def __str__(self): + def __str__(self) -> str: + """Return the full string representation. + + Returns: + str: All attributes formatted via :meth:`str_representation`. + """ return self.str_representation() - def __repr__(self): + def __repr__(self) -> str: + """Return the short string representation. + + Returns: + str: The key attributes formatted via :meth:`str_representation` with ``short=True``. + """ return self.str_representation(short=True) -def load_inference_config(json_dir: str | Path, logger: Logger_Interface | None = None): +def load_inference_config(json_dir: str | Path, logger: Logger_Interface | None = None) -> Segmentation_Inference_Config: + """Load an inference configuration from a JSON file. + + Args: + json_dir (str | Path): Path to the ``inference_config.json`` file. + logger (Logger_Interface | None, optional): Logger forwarded to the config for diagnostics. Defaults to None. + + Returns: + Segmentation_Inference_Config: The config built from the file's contents. + """ with open(str(json_dir), encoding="utf-8") as json_file: inference_config = json.load(json_file) return Segmentation_Inference_Config(**inference_config, logger=logger) diff --git a/unit_tests/test_find_min_cost_path.py b/unit_tests/test_find_min_cost_path.py new file mode 100644 index 0000000..f05e4c6 --- /dev/null +++ b/unit_tests/test_find_min_cost_path.py @@ -0,0 +1,378 @@ +# Call 'python -m unittest' on this folder # noqa: INP001 +# coverage run -m unittest +# coverage report +# coverage html +from __future__ import annotations + +import unittest + +import numpy as np + +from spineps.utils.find_min_cost_path import ( + argmin, + c_to_region_idx, + find_most_probably_sequence, + internal_to_real_path, + softmax_T, +) + + +class Test_Argmin(unittest.TestCase): + def test_normal_case(self): + idx, val = argmin([3, 1, 2]) + self.assertEqual(idx, 1) + self.assertEqual(val, 1) + + def test_min_at_start_and_end(self): + self.assertEqual(argmin([0, 5, 9]), (0, 0)) + self.assertEqual(argmin([9, 5, 0]), (2, 0)) + + def test_tie_returns_first_index(self): + # When the minimum value appears multiple times the first index wins. + idx, val = argmin([1, 0, 0, 2]) + self.assertEqual(idx, 1) + self.assertEqual(val, 0) + + def test_single_element(self): + self.assertEqual(argmin([5]), (0, 5)) + + def test_negative_values(self): + idx, val = argmin([2.0, -3.5, -1.0, 4.0]) + self.assertEqual(idx, 1) + self.assertAlmostEqual(val, -3.5) + + +class Test_SoftmaxT(unittest.TestCase): + def test_columns_sum_to_one(self): + x = np.array([[1.0, 2.0], [2.0, 1.0], [3.0, 0.0]]) + s = softmax_T(x, 1.0) + self.assertEqual(s.shape, (3, 2)) + col_sums = s.sum(axis=0) + for v in col_sums: + self.assertAlmostEqual(v, 1.0) + + def test_ordering_preserved(self): + # softmax is monotone, so the largest score keeps the largest probability. + x = np.array([[1.0], [2.0], [3.0]]) + s = softmax_T(x, 1.0) + col = s[:, 0] + self.assertTrue(np.all(np.diff(col) > 0)) + # Largest input (row 2) maps to the largest output. + self.assertEqual(int(np.argmax(col)), 2) + self.assertEqual(int(np.argmin(col)), 0) + + def test_all_probabilities_in_unit_interval(self): + x = np.array([[0.0, 5.0], [5.0, 0.0], [2.5, 2.5]]) + s = softmax_T(x, 0.5) + self.assertTrue(np.all(s > 0.0)) + self.assertTrue(np.all(s < 1.0)) + + def test_lower_temperature_sharpens(self): + # A lower temperature pushes mass toward the max -> higher peak probability. + x = np.array([[1.0], [2.0], [4.0]]) + hot = softmax_T(x, 2.0) + cold = softmax_T(x, 0.25) + self.assertGreater(cold[2, 0], hot[2, 0]) + + +class Test_CToRegionIdx(unittest.TestCase): + def test_default_region_starts(self): + regions = [0, 7, 19] + # cervical region (0) + self.assertEqual(c_to_region_idx(0, regions), 0) + self.assertEqual(c_to_region_idx(6, regions), 0) + # thoracic region (1) begins at 7 + self.assertEqual(c_to_region_idx(7, regions), 1) + self.assertEqual(c_to_region_idx(18, regions), 1) + # lumbar region (2) begins at 19 + self.assertEqual(c_to_region_idx(19, regions), 2) + self.assertEqual(c_to_region_idx(25, regions), 2) + + def test_boundaries(self): + regions = [0, 3, 6] + # Just below / at / above each boundary. + self.assertEqual(c_to_region_idx(2, regions), 0) + self.assertEqual(c_to_region_idx(3, regions), 1) + self.assertEqual(c_to_region_idx(5, regions), 1) + self.assertEqual(c_to_region_idx(6, regions), 2) + + def test_class_below_first_start_returns_minus_one(self): + # If the first region start is > 0 then classes before it resolve to -1. + self.assertEqual(c_to_region_idx(0, [2, 5]), -1) + self.assertEqual(c_to_region_idx(1, [2, 5]), -1) + self.assertEqual(c_to_region_idx(2, [2, 5]), 0) + + +class Test_InternalToRealPath(unittest.TestCase): + def test_sorts_by_row_and_returns_classes(self): + p = [(2, "c"), (0, "a"), (1, "b")] + self.assertEqual(internal_to_real_path(p), ["a", "b", "c"]) + + def test_numeric_classes(self): + p = [(3, 30), (1, 10), (0, 5), (2, 20)] + self.assertEqual(internal_to_real_path(p), [5, 10, 20, 30]) + + def test_already_sorted_is_unchanged(self): + p = [(0, 9), (1, 8), (2, 7)] + self.assertEqual(internal_to_real_path(p), [9, 8, 7]) + + def test_single_node(self): + self.assertEqual(internal_to_real_path([(0, 42)]), [42]) + + +class Test_FindMostProbableSequence(unittest.TestCase): + @staticmethod + def _strong_diagonal(n_rows: int, n_cols: int, value: float = 10.0) -> np.ndarray: + """Build a cost matrix whose obvious optimum is the main diagonal.""" + cost = np.zeros((n_rows, n_cols), dtype=float) + for i in range(n_rows): + cost[i, i] = value + return cost + + def test_strong_diagonal_invert_cost(self): + # With invert_cost the highest scores are preferred -> follow the diagonal. + cost = self._strong_diagonal(4, 6, 10.0) + fcost, fpath, min_costs_path = find_most_probably_sequence( + cost, + invert_cost=True, + allow_skip_at_region=[], + ) + self.assertEqual(fpath, [0, 1, 2, 3]) + # Each of the four chosen diagonal cells contributes -10 after inversion. + self.assertAlmostEqual(fcost, -40.0) + # Path covers every row exactly once. + self.assertEqual(len(fpath), cost.shape[0]) + # The returned memo table mirrors the matrix shape. + self.assertEqual(len(min_costs_path), cost.shape[0]) + self.assertEqual(len(min_costs_path[0]), cost.shape[1]) + + def test_no_invert_prefers_low_cost(self): + # Without inversion the solver minimises raw cost: 0 on the diagonal, 10 elsewhere. + cost = np.full((4, 6), 10.0) + for i in range(4): + cost[i, i] = 0.0 + fcost, fpath, _ = find_most_probably_sequence( + cost, + invert_cost=False, + allow_skip_at_region=[], + ) + self.assertEqual(fpath, [0, 1, 2, 3]) + self.assertAlmostEqual(fcost, 0.0) + + def test_invert_symmetry(self): + # invert_cost=True on a matrix equals invert_cost=False on its negation. + cost = self._strong_diagonal(4, 6, 7.0) + fcost_a, fpath_a, _ = find_most_probably_sequence(cost, invert_cost=True, allow_skip_at_region=[]) + fcost_b, fpath_b, _ = find_most_probably_sequence(-cost, invert_cost=False, allow_skip_at_region=[]) + self.assertEqual(fpath_a, fpath_b) + self.assertAlmostEqual(fcost_a, fcost_b) + + def test_min_start_class_shifts_path(self): + # The diagonal of high scores starts at column 2; min_start_class must allow it. + cost = np.zeros((4, 8), dtype=float) + for i in range(4): + cost[i, i + 2] = 5.0 + fcost, fpath, _ = find_most_probably_sequence( + cost, + invert_cost=True, + allow_skip_at_region=[], + min_start_class=2, + ) + self.assertEqual(fpath, [2, 3, 4, 5]) + self.assertAlmostEqual(fcost, -20.0) + + def test_first_label_respects_min_start_class(self): + # Even when an earlier column looks attractive, the path may not start before min_start_class. + cost = np.zeros((4, 8), dtype=float) + cost[0, 0] = 100.0 # very attractive but forbidden as a start + for i in range(4): + cost[i, i + 3] = 5.0 + fcost, fpath, _ = find_most_probably_sequence( + cost, + invert_cost=True, + allow_skip_at_region=[], + min_start_class=3, + ) + self.assertGreaterEqual(fpath[0], 3) + self.assertEqual(len(fpath), cost.shape[0]) + self.assertTrue(np.isfinite(fcost)) + + def test_list_input_is_accepted(self): + # A plain nested list must be handled identically to an ndarray. + cost = [ + [7, 0, 0, 0, 0, 0], + [0, 7, 0, 0, 0, 0], + [0, 0, 7, 0, 0, 0], + [0, 0, 0, 7, 0, 0], + ] + fcost, fpath, _ = find_most_probably_sequence(cost, invert_cost=True, allow_skip_at_region=[]) + self.assertEqual(fpath, [0, 1, 2, 3]) + self.assertAlmostEqual(fcost, -28.0) + + def test_allow_multiple_at_class_enables_repeat(self): + # Class 1 is strongly preferred for two consecutive rows; repeating it captures both. + cost = np.zeros((4, 5), dtype=float) + cost[0, 0] = 10.0 + cost[1, 1] = 10.0 + cost[2, 1] = 10.0 # second consecutive class-1 vertebra + cost[3, 2] = 10.0 + fcost, fpath, _ = find_most_probably_sequence( + cost, + invert_cost=True, + allow_skip_at_region=[], + allow_multiple_at_class=[1], + punish_multiple_sequence=0.0, + ) + self.assertEqual(fpath, [0, 1, 1, 2]) + self.assertAlmostEqual(fcost, -40.0) + + def test_without_multiple_forces_diagonal(self): + # The same matrix, but repeats disallowed -> must advance every step. + cost = np.zeros((4, 5), dtype=float) + cost[0, 0] = 10.0 + cost[1, 1] = 10.0 + cost[2, 1] = 10.0 + cost[3, 2] = 10.0 + fcost, fpath, _ = find_most_probably_sequence( + cost, + invert_cost=True, + allow_skip_at_region=[], + allow_multiple_at_class=[], + punish_multiple_sequence=0.0, + ) + self.assertEqual(fpath, [0, 1, 2, 3]) + # Only rows 0 and 1 hit their high-score cells. + self.assertAlmostEqual(fcost, -20.0) + + def test_repeat_capped_when_penalty_high(self): + # A large repeat penalty makes the diagonal cheaper than repeating class 1. + cost = np.zeros((4, 5), dtype=float) + cost[0, 0] = 10.0 + cost[1, 1] = 10.0 + cost[2, 1] = 10.0 + cost[3, 2] = 10.0 + _, fpath, _ = find_most_probably_sequence( + cost, + invert_cost=True, + allow_skip_at_region=[], + allow_multiple_at_class=[1], + punish_multiple_sequence=100.0, + ) + self.assertEqual(fpath, [0, 1, 2, 3]) + + def test_allow_skip_at_class_enables_jump(self): + # After class 0 we may skip class 1 and land on class 2. + cost = np.zeros((4, 6), dtype=float) + cost[0, 0] = 10.0 + cost[1, 2] = 10.0 # reached by skipping class 1 + cost[2, 3] = 10.0 + cost[3, 4] = 10.0 + fcost, fpath, _ = find_most_probably_sequence( + cost, + invert_cost=True, + allow_skip_at_region=[], + allow_skip_at_class=[0], + punish_skip_sequence=0.0, + ) + self.assertEqual(fpath, [0, 2, 3, 4]) + self.assertAlmostEqual(fcost, -40.0) + + def test_without_skip_no_jump(self): + # The same matrix, but skipping disallowed -> the path shifts to a pure diagonal. + cost = np.zeros((4, 6), dtype=float) + cost[0, 0] = 10.0 + cost[1, 2] = 10.0 + cost[2, 3] = 10.0 + cost[3, 4] = 10.0 + fcost, fpath, _ = find_most_probably_sequence( + cost, + invert_cost=True, + allow_skip_at_region=[], + allow_skip_at_class=[], + punish_skip_sequence=0.0, + ) + self.assertEqual(fpath, [1, 2, 3, 4]) + self.assertAlmostEqual(fcost, -30.0) + + def test_region_rel_cost_pulls_region_start(self): + # An all-zero cost matrix means only the region transition cost matters. + # Region 1 starts at class index 3; making "first of region 1" attractive at + # vertebra 2 forces vertebra 2 onto class 3 -> path [1, 2, 3, 4]. + cost = np.zeros((4, 5), dtype=float) + rel = np.zeros((4, 4), dtype=float) # columns: nothing, last0, first1, last2 + rel[2, 2] = -8.0 # "first of region 1" reward at vertebra 2 + fcost, fpath, _ = find_most_probably_sequence( + cost, + region_rel_cost=rel, + regions=[0, 3], + invert_cost=True, + ) + self.assertEqual(fpath, [1, 2, 3, 4]) + self.assertAlmostEqual(fcost, -8.0) + + def test_region_rel_cost_matches_existing_simple_case(self): + # Mirrors the documented simple scenario: a strong column plus region rewards. + cost = np.array( + [ + [0, 10, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ], + dtype=float, + ) + rel = -np.array( + [ + [0, 0, 0, 0], + [0, 10, 0, 0], + [0, 0, 0, 0], + [0, 0, 11, 0], + ], + dtype=float, + ) + fcost, fpath, _ = find_most_probably_sequence( + cost, + region_rel_cost=rel, + regions=[0, 3], + ) + self.assertEqual(fpath, [1, 2, 3, 4]) + # Mean cost per vertebra, as asserted in the existing path test. + self.assertAlmostEqual(fcost / len(fpath), -5.0) + + def test_region_rel_cost_wrong_shape_raises(self): + # region_rel_cost must have (n_regions * 2) columns for the given regions. + cost = np.zeros((4, 5), dtype=float) + bad_rel = np.zeros((4, 3), dtype=float) # should be 4 columns for regions [0, 3] + with self.assertRaises(AssertionError): + find_most_probably_sequence(cost, region_rel_cost=bad_rel, regions=[0, 3]) + + def test_min_start_class_out_of_range_raises(self): + cost = np.zeros((4, 5), dtype=float) + with self.assertRaises(AssertionError): + find_most_probably_sequence(cost, min_start_class=5, allow_skip_at_region=[]) + + def test_path_properties_on_random_matrix(self): + # Robust structural properties that must hold for any valid solution. + rng = np.random.default_rng(0) + cost = rng.random((5, 8)) + fcost, fpath, min_costs_path = find_most_probably_sequence( + cost, + invert_cost=True, + allow_skip_at_region=[], + ) + # One label per vertebra (row). + self.assertEqual(len(fpath), cost.shape[0]) + # Cost is finite. + self.assertTrue(np.isfinite(fcost)) + # Labels are valid column indices and weakly increasing (monotone path). + for c in fpath: + self.assertGreaterEqual(c, 0) + self.assertLess(c, cost.shape[1]) + self.assertTrue(all(fpath[i] <= fpath[i + 1] for i in range(len(fpath) - 1))) + # Memo table shape mirrors the cost matrix. + self.assertEqual(len(min_costs_path), cost.shape[0]) + self.assertEqual(len(min_costs_path[0]), cost.shape[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/unit_tests/test_generate_disc_labels_extra.py b/unit_tests/test_generate_disc_labels_extra.py new file mode 100644 index 0000000..8e12df1 --- /dev/null +++ b/unit_tests/test_generate_disc_labels_extra.py @@ -0,0 +1,175 @@ +# Call 'python -m unittest' on this folder # noqa: INP001 +# coverage run -m unittest +# coverage report +# coverage html +from __future__ import annotations + +import unittest +from pathlib import Path + +import numpy as np +import numpy.testing as npt + +from spineps.utils.generate_disc_labels import ( + DISCS_MAP, + closest_point_seg_to_line, + default_name_discs, + extract_centroids_3d, + project_point_on_line, +) + + +class Test_ProjectPointOnLine(unittest.TestCase): + def test_point_on_axis_aligned_line(self): + # Line along the x-axis; an off-line point projects to the nearest vertex. + line = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 0.0, 0.0]]) + point = np.array([1.4, 5.0, 0.0]) + closest, dist = project_point_on_line(point, line) + # Nearest vertex is x=1 (1.4 rounds toward 1), distance^2 = 0.4^2 + 5.0^2 = 25.16 + npt.assert_allclose(closest, np.array([1.0, 0.0, 0.0])) + self.assertAlmostEqual(dist, 25.16) + + def test_point_exactly_on_vertex(self): + line = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]]) + point = np.array([2.0, 0.0, 0.0]) + closest, dist = project_point_on_line(point, line) + npt.assert_allclose(closest, np.array([2.0, 0.0, 0.0])) + self.assertAlmostEqual(dist, 0.0) + + def test_snaps_to_nearest_vertex_not_interpolated(self): + # The function returns the closest *vertex*, it does not interpolate along the segment. + line = np.array([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 2.0, 0.0], [0.0, 3.0, 0.0]]) + point = np.array([0.0, 1.6, 0.0]) + closest, dist = project_point_on_line(point, line) + # 1.6 is closer to vertex 2 than vertex 1 -> [0, 2, 0], distance^2 = 0.4^2 = 0.16 + npt.assert_allclose(closest, np.array([0.0, 2.0, 0.0])) + self.assertAlmostEqual(dist, 0.16) + + def test_returns_squared_distance(self): + # Verify the returned distance is the squared euclidean distance (not the root). + line = np.array([[0.0, 0.0, 0.0]]) + point = np.array([3.0, 4.0, 0.0]) + closest, dist = project_point_on_line(point, line) + npt.assert_allclose(closest, np.array([0.0, 0.0, 0.0])) + self.assertAlmostEqual(dist, 25.0) # 3^2 + 4^2, not 5 + + +class Test_ExtractCentroids3d(unittest.TestCase): + def test_two_components_sorted_by_vertical_axis(self): + arr = np.zeros((10, 10, 10), dtype=int) + arr[1:3, 1:3, 1:3] = 5 # lower along axis 1 (S-I axis in RSP) + arr[1:3, 6:8, 1:3] = 7 # higher along axis 1 + centroids, _bounding_boxes = extract_centroids_3d(arr) + + self.assertEqual(len(centroids), 2) + # Centroid of a 2x2x2 block starting at (1,1,1) is (1,1,1) after int truncation; same for (1,6,1). + npt.assert_array_equal(centroids, np.array([[1, 1, 1], [1, 6, 1]])) + # Sorted ascending along axis 1. + self.assertTrue(np.all(np.diff(centroids[:, 1]) >= 0)) + # Integer dtype is guaranteed by the implementation. + self.assertTrue(np.issubdtype(centroids.dtype, np.integer)) + + def test_background_component_removed(self): + arr = np.zeros((6, 6, 6), dtype=int) + arr[2:4, 2:4, 2:4] = 3 + centroids, bounding_boxes = extract_centroids_3d(arr) + # Only one foreground component, background (label 0) must be dropped. + self.assertEqual(len(centroids), 1) + self.assertEqual(len(bounding_boxes), 1) + npt.assert_array_equal(centroids[0], np.array([2, 2, 2])) + + def test_sorting_independent_of_insertion_order(self): + arr = np.zeros((12, 12, 4), dtype=int) + arr[0:2, 8:10, 0:2] = 1 # high axis 1 + arr[0:2, 0:2, 0:2] = 2 # low axis 1 + arr[0:2, 4:6, 0:2] = 3 # mid axis 1 + centroids, _ = extract_centroids_3d(arr) + # Regardless of which label was written first, output is sorted by axis-1 coord. + npt.assert_array_equal(centroids[:, 1], np.array([0, 4, 8])) + + def test_bounding_boxes_match_components(self): + arr = np.zeros((8, 8, 8), dtype=int) + arr[1:3, 1:3, 1:3] = 4 + _, bounding_boxes = extract_centroids_3d(arr) + self.assertEqual(len(bounding_boxes), 1) + # cc3d returns slice tuples; the bounding box must contain exactly the block we placed. + bb = bounding_boxes[0] + sub = arr[bb[0], bb[1], bb[2]] + self.assertEqual(sub.shape, (2, 2, 2)) + self.assertTrue(np.all(sub == 4)) + + +class Test_ClosestPointSegToLine(unittest.TestCase): + def test_picks_closest_voxel_and_preserves_label(self): + arr = np.zeros((10, 10, 10), dtype=int) + arr[1:3, 1:3, 1:3] = 5 + arr[1:3, 6:8, 1:3] = 7 + _, bounding_boxes = extract_centroids_3d(arr) + + # Centerline far in the +z direction -> nearest voxel of each disc is the one with max z. + centerline = np.array([[1.0, 1.0, 100.0], [1.0, 6.0, 100.0]]) + result = closest_point_seg_to_line(arr, centerline, bounding_boxes) + + self.assertEqual(result.shape, (2, 4)) + # Each row is [x, y, z, disc_value]; z must be 2 (top of the [1,3) z-range), labels preserved. + npt.assert_array_equal(result, np.array([[1, 1, 2, 5], [1, 6, 2, 7]])) + + def test_label_value_in_last_column(self): + arr = np.zeros((6, 6, 6), dtype=int) + arr[2:4, 2:4, 2:4] = 9 + _, bounding_boxes = extract_centroids_3d(arr) + centerline = np.array([[0.0, 0.0, 0.0]]) + result = closest_point_seg_to_line(arr, centerline, bounding_boxes) + self.assertEqual(result.shape, (1, 4)) + # Closest voxel to the origin within the [2,4) block is (2,2,2), value 9. + npt.assert_array_equal(result[0], np.array([2, 2, 2, 9])) + + def test_single_voxel_disc(self): + arr = np.zeros((5, 5, 5), dtype=int) + arr[3, 1, 4] = 11 # a single labelled voxel + _, bounding_boxes = extract_centroids_3d(arr) + centerline = np.array([[0.0, 0.0, 0.0], [10.0, 10.0, 10.0]]) + result = closest_point_seg_to_line(arr, centerline, bounding_boxes) + npt.assert_array_equal(result, np.array([[3, 1, 4, 11]])) + + +class Test_DefaultNameDiscs(unittest.TestCase): + def test_default_suffix_with_compound_extension(self): + out = default_name_discs("/data/sub-amu_T2w_dseg.nii.gz") + self.assertEqual(out, Path("/data/sub-amu_T2w_dseg_label-discs_dlabel.nii.gz")) + + def test_custom_suffix(self): + out = default_name_discs(Path("/data/foo.nii.gz"), suffix="_disc") + self.assertEqual(out, Path("/data/foo_disc.nii.gz")) + + def test_single_extension(self): + out = default_name_discs("/data/foo.mha") + self.assertEqual(out, Path("/data/foo_label-discs_dlabel.mha")) + + def test_accepts_path_object_input(self): + out = default_name_discs(Path("/data/scan.nii")) + self.assertIsInstance(out, Path) + self.assertEqual(out.name, "scan_label-discs_dlabel.nii") + + +class Test_DiscsMap(unittest.TestCase): + def test_mapping_known_values(self): + # Spot-check the static vertebra->disc remapping table. + self.assertEqual(DISCS_MAP[2], 1) + self.assertEqual(DISCS_MAP[102], 3) + self.assertEqual(DISCS_MAP[124], 25) + + def test_mapping_is_consecutive_for_thoracolumbar_block(self): + # Keys 102..124 map to consecutive disc values 3..25. + block_keys = list(range(102, 125)) + values = [DISCS_MAP[k] for k in block_keys] + self.assertEqual(values, list(range(3, 26))) + + def test_disc_value_2_is_not_directly_mapped(self): + # Disc 2 is inserted between 1 and 3 by extract_discs_label, so it is absent from the map values. + self.assertNotIn(2, DISCS_MAP.values()) + self.assertEqual(len(DISCS_MAP), 24) + + +if __name__ == "__main__": + unittest.main() diff --git a/unit_tests/test_inference_mocked.py b/unit_tests/test_inference_mocked.py new file mode 100644 index 0000000..78d2b4c --- /dev/null +++ b/unit_tests/test_inference_mocked.py @@ -0,0 +1,352 @@ +# Call 'python -m unittest' on this folder # noqa: INP001 +# coverage run -m unittest +# coverage report +# coverage html +"""Tests for the GPU/model-inference code paths, emulating the networks with unittest.mock. + +These tests reuse the established dummy-model + MagicMock pattern from test_semantic.py so the +surrounding orchestration (input preparation, cutout collection, prediction merging, labeling and +output mapping) is exercised without any real model weights or a GPU. +""" + +from __future__ import annotations + +import unittest +from pathlib import Path +from typing import ClassVar +from unittest.mock import MagicMock + +import numpy as np +import torch +from TPTBox import NII, No_Logger +from TPTBox.tests.test_utils import get_test_mri +from typing_extensions import Self + +from spineps.lab_model import VertLabelingClassifier +from spineps.phase_instance import predict_instance_mask +from spineps.phase_labeling import VERT_CLASSES, perform_labeling_step, run_model_for_vert_labeling +from spineps.seg_enums import ErrCode, OutputType +from spineps.seg_model import Segmentation_Inference_Config, Segmentation_Model + +logger = No_Logger() + + +class DummyPredictor: + """Minimal stand-in for a loaded network predictor.""" + + def __init__(self) -> None: + pass + + +def _dummy_seg_config(cutout_size: tuple[int, int, int] = (48, 48, 32)) -> Segmentation_Inference_Config: + """Build an inference config for a dummy segmentation model (no weights needed).""" + return Segmentation_Inference_Config( + logger=No_Logger(), + modality=["T2w", "SEG", "T1w"], + acquisition="sag", + log_name="DummySegModel", + modeltype="unet", + model_expected_orientation=("P", "I", "R"), + available_folds=1, + inference_augmentation=False, + resolution_range=[1.5, 1.5, 1.5], # equals the test fixture zoom -> no rescaling + default_step_size=0.5, + labels={1: 1}, + cutout_size=cutout_size, + ) + + +class Segmentation_Model_Dummy(Segmentation_Model): + """A Segmentation_Model whose load() installs a dummy predictor instead of real weights.""" + + def __init__(self, cutout_size: tuple[int, int, int] = (48, 48, 32)) -> None: + self.logger = No_Logger() + super().__init__(__file__, _dummy_seg_config(cutout_size), default_verbose=False, default_allow_tqdm=False) + + def load(self, folds: tuple[str, ...] | None = None) -> Self: # noqa: ARG002 + self.predictor = DummyPredictor() + return self + + def run(self, input_nii: list[NII], verbose: bool = False) -> dict[OutputType, NII | None]: # noqa: ARG002 + return {OutputType.seg: input_nii[0], OutputType.softmax_logits: None} + + +class Labeling_Model_Dummy(VertLabelingClassifier): + """A VertLabelingClassifier whose load() installs a dummy predictor instead of real weights.""" + + def __init__(self) -> None: + self.logger = No_Logger() + config = Segmentation_Inference_Config( + logger=self.logger, + modality=["T2w", "SEG", "T1w"], + acquisition="sag", + log_name="DummyLabelModel", + modeltype="classifier", + model_expected_orientation=("P", "I", "R"), + available_folds=1, + inference_augmentation=False, + resolution_range=[1.0, 1.0, 1.0], + default_step_size=0.5, + labels={1: 1}, + ) + super().__init__(__file__, config, default_verbose=False, default_allow_tqdm=False) + + def load(self, folds: tuple[str, ...] | None = None) -> Self: # noqa: ARG002 + self.predictor = DummyPredictor() + return self + + +def _vert_softmax(peak_class: int) -> np.ndarray: + """Return a length-VERT_CLASSES softmax-like vector peaked at ``peak_class``.""" + arr = np.full(VERT_CLASSES, 0.01, dtype=float) + arr[min(max(peak_class, 0), VERT_CLASSES - 1)] = 0.89 + return arr / arr.sum() + + +def _fake_run_all_seg_instances(img: NII, seg: NII, *args, **kwargs): # noqa: ARG001 + """Emulate VertLabelingClassifier.run_all_seg_instances with a deterministic VERT head. + + Returns one prediction per unique vertebra label in ``seg`` (in ascending order), each peaked at a + consecutive cervical class so the resulting path is a valid, increasing sequence. + """ + labels = [int(v) for v in seg.unique() if v != 0] + predictions: dict[int, dict] = {} + for offset, v in enumerate(sorted(labels)): + soft = _vert_softmax(1 + offset) # C2, C3, C4, ... + predictions[v] = {"soft": {"VERT": soft}, "pred": {"VERT": int(np.argmax(soft))}} + return predictions + + +class Test_Labeling_Inference_Mocked(unittest.TestCase): + def test_run_model_for_vert_labeling(self): + mri, _subreg, vert, _label = get_test_mri() + model = Labeling_Model_Dummy().load() + model.run_all_seg_instances = MagicMock(side_effect=_fake_run_all_seg_instances) + + labelmap, _fcost, _fpath, fpath_post, _costlist, _mcp, predictions = run_model_for_vert_labeling(model, mri, vert) + # One prediction and one labelmap entry per input vertebra (5, 6, 7). + self.assertEqual(len(predictions), 3) + self.assertEqual(len(labelmap), 3) + self.assertEqual(len(fpath_post), 3) + model.run_all_seg_instances.assert_called() + + def test_perform_labeling_step_relabels(self): + mri, subreg, vert, _label = get_test_mri() + model = Labeling_Model_Dummy().load() + model.run_all_seg_instances = MagicMock(side_effect=_fake_run_all_seg_instances) + + out = perform_labeling_step(model, mri, vert.copy(), subreg_nii=subreg) + self.assertIsInstance(out, NII) + # Same spatial frame as the input vertebra mask. + self.assertTrue(out.assert_affine(other=vert)) + # The instance labels were remapped to the labeling model's output classes. + self.assertGreater(len(out.unique()), 0) + + +class Test_Segment_Scan_Mocked(unittest.TestCase): + def test_segment_scan_padding_round_trip(self): + mri, _subreg, _vert, _label = get_test_mri() + model = Segmentation_Model_Dummy() + # run() echoes its (padded, reoriented) input back as the segmentation. + model.run = MagicMock(side_effect=lambda input_nii, verbose=False: {OutputType.seg: input_nii[0], OutputType.softmax_logits: None}) # noqa: ARG005 + + result = model.segment_scan( + mri, + pad_size=3, + resample_to_recommended=False, + resample_output_to_input_space=True, + verbose=False, + ) + seg = result[OutputType.seg] + self.assertIsInstance(seg, NII) + # Padding added before inference is removed again -> output matches the input shape. + self.assertEqual(seg.shape, mri.shape) + model.run.assert_called_once() + + def test_segment_scan_without_resample_back(self): + mri, subreg, _vert, _label = get_test_mri() + model = Segmentation_Model_Dummy() + model.run = MagicMock(return_value={OutputType.seg: subreg.copy(), OutputType.softmax_logits: None}) + + result = model.segment_scan( + mri, + pad_size=0, + resample_to_recommended=False, + resample_output_to_input_space=False, + verbose=False, + ) + self.assertIn(OutputType.seg, result) + self.assertIsInstance(result[OutputType.seg], NII) + + +class Test_Instance_Inference_Mocked(unittest.TestCase): + @staticmethod + def _fake_segment_scan(cut_nii: NII, **kwargs): # noqa: ARG004 + """Emulate the instance model: split the cutout's corpus into a 1/2/3 three-vertebra hierarchy.""" + arr = cut_nii.get_seg_array() + out = np.zeros_like(arr) + corpus = np.argwhere(arr != 0) + if len(corpus) > 0: + # Split along the axis with the largest extent into thirds (above=1, center=2, below=3). + extents = corpus.max(axis=0) - corpus.min(axis=0) + axis = int(np.argmax(extents)) + coords = corpus[:, axis] + lo, hi = coords.min(), coords.max() + 1 + third = max((hi - lo) / 3.0, 1.0) + for c in corpus: + bucket = int((c[axis] - lo) / third) + out[c[0], c[1], c[2]] = min(bucket, 2) + 1 + return {OutputType.seg: cut_nii.set_array(out), OutputType.softmax_logits: None} + + def test_predict_instance_mask_runs(self): + _mri, subreg, _vert, _label = get_test_mri() + model = Segmentation_Model_Dummy(cutout_size=(48, 48, 32)) + model.segment_scan = MagicMock(side_effect=self._fake_segment_scan) + + whole_vert_nii, errcode = predict_instance_mask( + subreg.copy(), + model, + debug_data={}, + proc_corpus_clean=False, + proc_inst_clean_small_cc_artifacts=False, + verbose=False, + ) + self.assertEqual(errcode, ErrCode.OK) + self.assertIsInstance(whole_vert_nii, NII) + self.assertEqual(whole_vert_nii.shape, subreg.shape) + # At least one vertebra instance was produced and the model was queried per corpus cutout. + self.assertGreater(len([v for v in whole_vert_nii.unique() if v != 0]), 0) + model.segment_scan.assert_called() + + def test_predict_instance_mask_empty_without_corpus(self): + _mri, subreg, _vert, _label = get_test_mri() + # Remove the corpus-border label (49) so the instance phase has nothing to work with. + no_corpus = subreg.copy() + no_corpus[no_corpus == 49] = 0 + model = Segmentation_Model_Dummy() + model.segment_scan = MagicMock(side_effect=self._fake_segment_scan) + + result, errcode = predict_instance_mask(no_corpus, model, debug_data={}, proc_corpus_clean=False) + self.assertIsNone(result) + self.assertEqual(errcode, ErrCode.EMPTY) + + +class FakeClassifierPredictor: + """Stand-in for a loaded PLClassifier: a multi-head network returning deterministic logits. + + Implements the small surface that VertLabelingClassifier._run_array uses: ``eval``, ``to``, + ``forward`` (returning a per-head logits dict) and ``softmax``. Zero logits give a uniform, + deterministic softmax, which is all the surrounding code needs. + """ + + HEADS: ClassVar[dict[str, int]] = {"VERT": VERT_CLASSES, "VERTGRP": 12, "REGION": 3, "VERTREL": 6, "VERTT13": 2, "FULLYVISIBLE": 2} + + def eval(self) -> FakeClassifierPredictor: + return self + + def to(self, device) -> FakeClassifierPredictor: # noqa: ARG002 + return self + + def softmax(self, v: torch.Tensor) -> torch.Tensor: + return torch.softmax(v, dim=1) + + def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: + batch = x.shape[0] + return {name: torch.zeros((batch, n_classes)) for name, n_classes in self.HEADS.items()} + + __call__ = forward + + +def _make_classifier_with_fake_predictor(cutout: tuple[int, int, int] = (32, 32, 16)) -> Labeling_Model_Dummy: + """Build a labeling model wired to a FakeClassifierPredictor, running entirely on CPU.""" + from monai.transforms import CenterSpatialCropd, Compose, NormalizeIntensityd + + model = Labeling_Model_Dummy() + model.device = torch.device("cpu") + model.final_size = cutout + model.cutout_size = cutout # normally set from the checkpoint in load() + model.transform = Compose( + [ + NormalizeIntensityd(keys=["img"], nonzero=True, channel_wise=False), + CenterSpatialCropd(keys=["img", "seg"], roi_size=cutout), + ] + ) + model.predictor = FakeClassifierPredictor() + return model + + +class Test_Classifier_Forward_Mocked(unittest.TestCase): + def test_run_array(self): + model = _make_classifier_with_fake_predictor() + img_arr = np.arange(32 * 32 * 16, dtype=np.float32).reshape(32, 32, 16) + seg_arr = (img_arr > img_arr.mean()).astype(np.float32) + + logits_soft, pred_cls = model._run_array(img_arr, seg_arr) + # One entry per network head, with the VERT head holding a full class vector. + self.assertEqual(set(logits_soft.keys()), set(FakeClassifierPredictor.HEADS.keys())) + self.assertEqual(logits_soft["VERT"].shape, (VERT_CLASSES,)) + self.assertEqual(set(pred_cls.keys()), set(FakeClassifierPredictor.HEADS.keys())) + self.assertEqual(np.asarray(pred_cls["VERT"]).ndim, 0) + + def test_run_array_without_seg(self): + model = _make_classifier_with_fake_predictor() + img_arr = np.ones((32, 32, 16), dtype=np.float32) + # seg defaults to a clone of the image when omitted. + logits_soft, _pred_cls = model._run_array(img_arr) + self.assertEqual(logits_soft["VERT"].shape, (VERT_CLASSES,)) + + def test_run_all_arrays(self): + model = _make_classifier_with_fake_predictor() + arrays = {5: np.ones((32, 32, 16), dtype=np.float32), 6: np.ones((32, 32, 16), dtype=np.float32)} + predictions = model.run_all_arrays(arrays) + self.assertEqual(set(predictions.keys()), {5, 6}) + for entry in predictions.values(): + self.assertIn("soft", entry) + self.assertIn("pred", entry) + self.assertEqual(entry["soft"]["VERT"].shape, (VERT_CLASSES,)) + + def test_run_all_seg_instances_full_path(self): + mri, _subreg, vert, _label = get_test_mri() + model = _make_classifier_with_fake_predictor() + # Drives reorient -> per-instance cutout -> _run_array for every label in the mask. + predictions = model.run_all_seg_instances(mri, vert) + expected = [int(v) for v in vert.unique() if v != 0] + self.assertEqual(sorted(predictions.keys()), sorted(expected)) + for entry in predictions.values(): + self.assertEqual(entry["soft"]["VERT"].shape, (VERT_CLASSES,)) + + +class Test_Same_Modelzoom(unittest.TestCase): + @staticmethod + def _model_with_zoom(zoom: tuple[float, float, float]) -> Segmentation_Model_Dummy: + # A fixed (length-3) resolution_range makes calc_recommended_resampling_zoom return it verbatim. + model = Segmentation_Model_Dummy() + model.inference_config.resolution_range = tuple(zoom) + return model + + def test_same_resolution_matches(self): + a = self._model_with_zoom((1.0, 1.0, 1.0)) + b = self._model_with_zoom((1.0, 1.0, 1.0)) + self.assertTrue(a.same_modelzoom_as_model(b, (1.0, 1.0, 1.0))) + + def test_coarser_other_model_does_not_match(self): + # Regression: model_zms > self_zms yields a negative per-axis difference that must NOT count + # as a match (the bug was comparing the signed difference against the tolerance). + a = self._model_with_zoom((1.0, 1.0, 1.0)) + b = self._model_with_zoom((2.0, 2.0, 2.0)) + self.assertFalse(a.same_modelzoom_as_model(b, (1.0, 1.0, 1.0))) + + def test_finer_other_model_does_not_match(self): + a = self._model_with_zoom((2.0, 2.0, 2.0)) + b = self._model_with_zoom((1.0, 1.0, 1.0)) + self.assertFalse(a.same_modelzoom_as_model(b, (1.0, 1.0, 1.0))) + + def test_single_coarser_axis_does_not_match(self): + # Only the inferior axis differs and the other model is coarser there (negative diff). + a = self._model_with_zoom((1.0, 1.0, 1.0)) + b = self._model_with_zoom((1.0, 1.0, 2.0)) + self.assertFalse(a.same_modelzoom_as_model(b, (1.0, 1.0, 1.0))) + + +if __name__ == "__main__": + unittest.main() diff --git a/unit_tests/test_phase_labeling.py b/unit_tests/test_phase_labeling.py new file mode 100644 index 0000000..c1fd0ee --- /dev/null +++ b/unit_tests/test_phase_labeling.py @@ -0,0 +1,310 @@ +# Call 'python -m unittest' on this folder # noqa: INP001 +# coverage run -m unittest +# coverage report +# coverage html +from __future__ import annotations + +import unittest + +import numpy as np + +from spineps.architectures.read_labels import VertExact, vert_group_idx_to_exact_idx_dict +from spineps.phase_labeling import ( + CERV, + LUMB, + T13_LABEL, + THOR, + VERT_CLASSES, + fpath_post_processing, + is_valid_vertebra_sequence, + prepare_region, + prepare_vert, + prepare_vertgrp, + prepare_vertrel, + prepare_vertrel_columns, + prepare_vertt13_columns, + prepare_visible, + region_to_vert, +) + + +class Test_region_to_vert(unittest.TestCase): + def test_shape_and_broadcast(self): + out = region_to_vert(np.array([0.2, 0.5, 0.3])) + self.assertEqual(out.shape, (VERT_CLASSES,)) + # cervical slice all == region[0] + self.assertTrue(np.allclose(out[CERV], 0.2)) + # thoracic slice all == region[1] + self.assertTrue(np.allclose(out[THOR], 0.5)) + # lumbar slice all == region[2] + self.assertTrue(np.allclose(out[LUMB], 0.3)) + + def test_slice_lengths(self): + out = region_to_vert(np.array([1.0, 2.0, 3.0])) + # CERV = slice(None, 7) -> 7 classes + self.assertEqual(out[CERV].shape[0], 7) + # THOR = slice(7, 19) -> 12 classes + self.assertEqual(out[THOR].shape[0], 12) + # LUMB = slice(19, None) -> 5 classes + self.assertEqual(out[LUMB].shape[0], 5) + # every entry filled (no zeros remain for nonzero regions) + self.assertEqual(np.count_nonzero(out), VERT_CLASSES) + + def test_zero_region_leaves_zeros(self): + out = region_to_vert(np.array([0.0, 1.0, 0.0])) + self.assertTrue(np.allclose(out[CERV], 0.0)) + self.assertTrue(np.allclose(out[LUMB], 0.0)) + self.assertTrue(np.allclose(out[THOR], 1.0)) + + +class Test_prepare_vert(unittest.TestCase): + def test_no_smoothing_normalizes(self): + v = np.zeros(VERT_CLASSES) + v[3] = 2.0 + v[10] = 1.0 + out = prepare_vert(v, gaussian_sigma=0.0) + self.assertEqual(out.shape, (VERT_CLASSES,)) + self.assertAlmostEqual(float(out.sum()), 1.0, places=5) + # with no smoothing the ratio of the two peaks is preserved (2:1) + self.assertAlmostEqual(float(out[3]), 2.0 / 3.0, places=5) + self.assertAlmostEqual(float(out[10]), 1.0 / 3.0, places=5) + + def test_does_not_mutate_input(self): + v = np.zeros(VERT_CLASSES) + v[3] = 2.0 + _ = prepare_vert(v, gaussian_sigma=0.0) + self.assertEqual(float(v[3]), 2.0) + + def test_smoothing_regionwise_sums_to_one(self): + v = np.zeros(VERT_CLASSES) + v[3] = 1.0 + out = prepare_vert(v, gaussian_sigma=0.85, gaussian_radius=2, gaussian_regionwise=True) + self.assertAlmostEqual(float(out.sum()), 1.0, places=5) + # smoothing spreads the single peak across neighbouring cervical classes + self.assertGreater(np.count_nonzero(out > 1e-6), 1) + + def test_smoothing_regionwise_does_not_leak_across_regions(self): + # peak at last cervical class (index 6); region-wise smoothing must not + # leak probability into the thoracic region (index >= 7). + v = np.zeros(VERT_CLASSES) + v[6] = 1.0 + out = prepare_vert(v, gaussian_sigma=0.85, gaussian_radius=2, gaussian_regionwise=True) + self.assertTrue(np.allclose(out[THOR], 0.0)) + self.assertTrue(np.allclose(out[LUMB], 0.0)) + + def test_smoothing_global_can_leak_across_regions(self): + # global smoothing of the same peak DOES leak into the thoracic region. + v = np.zeros(VERT_CLASSES) + v[6] = 1.0 + out = prepare_vert(v, gaussian_sigma=0.85, gaussian_radius=2, gaussian_regionwise=False) + self.assertAlmostEqual(float(out.sum()), 1.0, places=5) + self.assertGreater(float(out[7]), 0.0) + + +class Test_prepare_vertgrp(unittest.TestCase): + def test_group_expanded_to_member_classes(self): + # group index 0 (C12) maps to exact classes [0, 1] + g = np.zeros(len(vert_group_idx_to_exact_idx_dict)) + g[0] = 1.0 + out = prepare_vertgrp(g, gaussian_sigma=0.0) + self.assertEqual(out.shape, (VERT_CLASSES,)) + self.assertAlmostEqual(float(out.sum()), 1.0, places=5) + nonzero = set(np.nonzero(out)[0].tolist()) + self.assertEqual(nonzero, set(vert_group_idx_to_exact_idx_dict[0])) + # the group value is copied onto each member class, then the whole vector + # is normalized; with one group of two members each ends up at 0.5. + self.assertAlmostEqual(float(out[0]), 0.5, places=5) + self.assertAlmostEqual(float(out[1]), 0.5, places=5) + + def test_multiple_groups_normalize(self): + g = np.zeros(len(vert_group_idx_to_exact_idx_dict)) + g[0] = 1.0 # group C12 -> two member classes [0, 1], each gets value 1.0 + g[11] = 3.0 # group L56 -> single exact class [23], gets value 3.0 + out = prepare_vertgrp(g, gaussian_sigma=0.0) + self.assertAlmostEqual(float(out.sum()), 1.0, places=5) + # raw assigned mass before normalization is 1.0 + 1.0 + 3.0 = 5.0 + # group 11's single member carries 3/5 of the mass + self.assertAlmostEqual(float(out[vert_group_idx_to_exact_idx_dict[11][0]]), 3.0 / 5.0, places=5) + # each member of group 0 carries 1/5 of the mass + self.assertAlmostEqual(float(out[0]), 1.0 / 5.0, places=5) + self.assertAlmostEqual(float(out[1]), 1.0 / 5.0, places=5) + + def test_smoothing_global_branch_sums_to_one(self): + g = np.zeros(len(vert_group_idx_to_exact_idx_dict)) + g[5] = 1.0 + out = prepare_vertgrp(g, gaussian_sigma=0.85, gaussian_regionwise=False) + self.assertAlmostEqual(float(out.sum()), 1.0, places=5) + + +class Test_prepare_region(unittest.TestCase): + def test_no_smoothing_normalizes(self): + out = prepare_region(np.array([0.2, 0.5, 0.3]), gaussian_sigma=0.0) + self.assertEqual(out.shape, (VERT_CLASSES,)) + self.assertAlmostEqual(float(out.sum()), 1.0, places=5) + + def test_smoothing_normalizes(self): + out = prepare_region(np.array([0.2, 0.5, 0.3]), gaussian_sigma=0.75, gaussian_radius=1) + self.assertAlmostEqual(float(out.sum()), 1.0, places=5) + + def test_all_zero_input_stays_zero(self): + # guard `np.sum(...) > 0` skips smoothing and division leaves zeros + out = prepare_region(np.array([0.0, 0.0, 0.0]), gaussian_sigma=0.75) + self.assertEqual(float(out.sum()), 0.0) + + +class Test_prepare_vertrel(unittest.TestCase): + def test_no_smoothing_returns_copy_unchanged(self): + vr = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + out = prepare_vertrel(vr, gaussian_sigma=0.0) + self.assertTrue(np.allclose(out, vr)) + # returned object is a copy, not the same array + self.assertIsNot(out, vr) + + def test_smoothing_preserves_shape(self): + vr = np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) + out = prepare_vertrel(vr, gaussian_sigma=0.75, gaussian_radius=1) + self.assertEqual(out.shape, vr.shape) + # smoothing spreads the single peak to neighbours + self.assertGreater(np.count_nonzero(out > 1e-6), 1) + + +class Test_prepare_vertrel_columns(unittest.TestCase): + def test_column_zero_untouched(self): + m = np.zeros((3, 6)) + m[:, 0] = [5.0, 5.0, 5.0] + out = prepare_vertrel_columns(m.copy(), gaussian_sigma=0.0) + # the first column (NOTHING) is skipped by the loop -> unchanged + self.assertTrue(np.allclose(out[:, 0], [5.0, 5.0, 5.0])) + + def test_column_sum_greater_one_normalizes_to_one(self): + m = np.zeros((3, 6)) + m[:, 1] = [0.5, 0.5, 0.5] # sum 1.5 > 1 -> divide by sum + out = prepare_vertrel_columns(m.copy(), gaussian_sigma=0.0) + self.assertAlmostEqual(float(out[:, 1].sum()), 1.0, places=5) + + def test_column_sum_less_one_divides_by_one_plus_sum(self): + m = np.zeros((3, 6)) + m[:, 2] = [0.1, 0.1, 0.1] # sum 0.3 < 1 -> divide by (1 + sum) + out = prepare_vertrel_columns(m.copy(), gaussian_sigma=0.0) + expected = 0.1 / (1.0 + 0.3) + self.assertTrue(np.allclose(out[:, 2], expected, atol=1e-6)) + # the resulting column sum is below 1 + self.assertLess(float(out[:, 2].sum()), 1.0) + + def test_returns_same_array_object(self): + m = np.zeros((3, 6)) + m[:, 1] = [0.5, 0.5, 0.5] + out = prepare_vertrel_columns(m, gaussian_sigma=0.0) + self.assertIs(out, m) + + +class Test_prepare_vertt13_columns(unittest.TestCase): + def test_column_zero_untouched_rest_normalized(self): + m = np.array([[0.9, 0.1], [0.8, 0.4], [0.7, 0.5]], dtype=float) + out = prepare_vertt13_columns(m.copy()) + # first column untouched + self.assertTrue(np.allclose(out[:, 0], [0.9, 0.8, 0.7])) + # second column normalized to sum 1 + self.assertAlmostEqual(float(out[:, 1].sum()), 1.0, places=5) + # relative proportions preserved within the normalized column + self.assertAlmostEqual(float(out[0, 1]), 0.1 / 1.0, places=5) + + def test_returns_same_array_object(self): + m = np.array([[0.9, 0.1], [0.8, 0.4]], dtype=float) + out = prepare_vertt13_columns(m) + self.assertIs(out, m) + + +class Test_prepare_visible(unittest.TestCase): + def _make_preds(self, visible_pairs): + return {idx: {"soft": {"FULLYVISIBLE": pair, "VERT": [0.0] * 24}} for idx, pair in enumerate(visible_pairs)} + + def test_fullyvisible_present_weight_one(self): + preds = self._make_preds([[0.1, 0.9], [0.2, 0.8]]) + # with visible_w=1 and no smoothing, weight == FULLYVISIBLE[1] + out = prepare_visible(preds, visible_w=1.0, gaussian_sigma=0.0) + self.assertEqual(out.shape, (2,)) + self.assertTrue(np.allclose(out, [0.9, 0.8])) + # weights are clipped into [0, 1] + self.assertTrue(np.all(out >= 0.0)) + self.assertTrue(np.all(out <= 1.0)) + + def test_fullyvisible_absent_returns_ones(self): + preds = {0: {"soft": {"VERT": [0.0] * 24}}, 1: {"soft": {"VERT": [0.0] * 24}}} + out = prepare_visible(preds, visible_w=1.0, gaussian_sigma=0.0) + self.assertTrue(np.allclose(out, [1.0, 1.0])) + + def test_visible_weight_zero_disables_downweighting(self): + preds = self._make_preds([[0.6, 0.4], [0.7, 0.3]]) + # visible_w=0 -> 1 - (1 - x) * 0 = 1 for every instance + out = prepare_visible(preds, visible_w=0.0, gaussian_sigma=0.0) + self.assertTrue(np.allclose(out, [1.0, 1.0])) + + def test_partial_weight_between(self): + preds = self._make_preds([[0.5, 0.5]]) + # weight = 1 - (1 - 0.5) * 0.5 = 0.75 + out = prepare_visible(preds, visible_w=0.5, gaussian_sigma=0.0) + self.assertAlmostEqual(float(out[0]), 0.75, places=3) + + +class Test_fpath_post_processing(unittest.TestCase): + def test_plus_one_offset(self): + # plain class indices simply shift by +1 (0-based -> 1-based) + self.assertEqual(fpath_post_processing([0, 1, 2]), [1, 2, 3]) + + def test_returns_new_list(self): + src = [0, 1, 2] + out = fpath_post_processing(src) + self.assertEqual(src, [0, 1, 2]) + self.assertIsNot(out, src) + + def test_double_t12_second_becomes_t13(self): + # [T11, T12, T12] -> the second T12 turns into the special T13 label (28), + # the rest is offset by +1. T13_LABEL is left untouched by the offset. + out = fpath_post_processing([17, VertExact.T12.value, VertExact.T12.value]) + self.assertEqual(out, [18, 19, T13_LABEL]) + self.assertIn(T13_LABEL, out) + + def test_double_t12_at_start(self): + # T12 at index 0 then T12 -> next index becomes T13 + out = fpath_post_processing([VertExact.T12.value, VertExact.T12.value, 19]) + self.assertEqual(out, [19, T13_LABEL, 20]) + + def test_trailing_double_l5_becomes_l5_l6(self): + # trailing [..., L5, L5] -> the last L5 is bumped by +1 (to L6 slot 24), + # then everything offsets by +1. + out = fpath_post_processing([21, 22, VertExact.L5.value, VertExact.L5.value]) + self.assertEqual(out, [22, 23, 24, 25]) + + def test_single_trailing_l5_just_offsets(self): + out = fpath_post_processing([22, VertExact.L5.value]) + self.assertEqual(out, [23, 24]) + + +class Test_is_valid_vertebra_sequence(unittest.TestCase): + def test_valid_consecutive_ints(self): + self.assertTrue(is_valid_vertebra_sequence([1, 2, 3, 4])) + + def test_invalid_with_gap(self): + self.assertFalse(is_valid_vertebra_sequence([1, 2, 5])) + + def test_valid_t13_to_l1_jump(self): + # special allowed jump: T13 (28) -> L1 (20) + self.assertTrue(is_valid_vertebra_sequence([28, 20])) + + def test_valid_t12_to_l1_jump(self): + # special allowed jump: T12 (18) -> L1 (20) + self.assertTrue(is_valid_vertebra_sequence([18, 20])) + + def test_invalid_backwards(self): + self.assertFalse(is_valid_vertebra_sequence([5, 4, 3])) + + def test_vertexact_input_valid(self): + self.assertTrue(is_valid_vertebra_sequence([VertExact.C1, VertExact.C2, VertExact.C3])) + + def test_vertexact_input_invalid_skip(self): + self.assertFalse(is_valid_vertebra_sequence([VertExact.L1, VertExact.L3])) + + +if __name__ == "__main__": + unittest.main() diff --git a/unit_tests/test_pure_helpers.py b/unit_tests/test_pure_helpers.py new file mode 100644 index 0000000..c7edd61 --- /dev/null +++ b/unit_tests/test_pure_helpers.py @@ -0,0 +1,167 @@ +# Call 'python -m unittest' on this folder # noqa: INP001 +# coverage run -m unittest +# coverage report +# coverage html +from __future__ import annotations + +import unittest + +from spineps.phase_post import find_nearest_lower +from spineps.phase_semantic import overlap_slice +from spineps.seg_enums import Acquisition, Modality, ModelType +from spineps.seg_utils import add_ignore_text + + +class Test_Modality_format_keys(unittest.TestCase): + def test_single_modalities(self): + self.assertEqual(Modality.format_keys(Modality.CT), ["CT", "ct"]) + self.assertEqual(Modality.format_keys(Modality.SEG), ["msk", "seg"]) + self.assertEqual(Modality.format_keys(Modality.T1w), ["T1w", "t1", "T1", "T1c"]) + self.assertEqual( + Modality.format_keys(Modality.T2w), + ["T2w", "dixon", "mr", "t2", "T2", "T2c"], + ) + self.assertEqual( + Modality.format_keys(Modality.Vibe), + ["t1dixon", "vibe", "mevibe", "GRE"], + ) + self.assertEqual(Modality.format_keys(Modality.MPR), ["mpr", "MPR", "Mpr"]) + + def test_list_of_modalities(self): + # A list of modalities concatenates the per-modality keys in order. + self.assertEqual( + Modality.format_keys([Modality.CT, Modality.SEG]), + ["CT", "ct", "msk", "seg"], + ) + self.assertEqual( + Modality.format_keys([Modality.T1w, Modality.MPR]), + ["T1w", "t1", "T1", "T1c", "mpr", "MPR", "Mpr"], + ) + + def test_single_equals_singleton_list(self): + # Passing a single member is equivalent to passing a one-element list. + self.assertEqual( + Modality.format_keys(Modality.Vibe), + Modality.format_keys([Modality.Vibe]), + ) + + def test_not_implemented_modalities(self): + with self.assertRaises(NotImplementedError): + Modality.format_keys(Modality.PD) + with self.assertRaises(NotImplementedError): + Modality.format_keys(Modality.FLAIR) + # Also raises when an unsupported modality appears inside a list. + with self.assertRaises(NotImplementedError): + Modality.format_keys([Modality.CT, Modality.PD]) + + +class Test_Acquisition_format_keys(unittest.TestCase): + def test_defined_acquisitions(self): + self.assertEqual(Acquisition.format_keys(Acquisition.sag), ["sagittal", "sag"]) + self.assertEqual(Acquisition.format_keys(Acquisition.cor), ["coronal", "cor"]) + self.assertEqual(Acquisition.format_keys(Acquisition.ax), ["axial", "ax", "axl"]) + self.assertEqual(Acquisition.format_keys(Acquisition.iso), ["iso", "ISO"]) + + def test_all_four_defined_do_not_raise(self): + # NotImplementedError is not applicable for the 4 defined members. + for acq in (Acquisition.sag, Acquisition.cor, Acquisition.ax, Acquisition.iso): + keys = Acquisition.format_keys(acq) + self.assertIsInstance(keys, list) + self.assertGreater(len(keys), 0) + + +class Test_Enum_Compare(unittest.TestCase): + def test_equality_to_string_name(self): + self.assertEqual(ModelType.nnunet, "nnunet") + self.assertEqual(ModelType.unet, "unet") + self.assertEqual(ModelType.classifier, "classifier") + self.assertEqual(ModelType.nnunet.name, "nnunet") + + def test_equality_to_self(self): + self.assertEqual(ModelType.nnunet, ModelType.nnunet) + + def test_inequality_across_members(self): + self.assertNotEqual(ModelType.nnunet, ModelType.unet) + self.assertNotEqual(ModelType.nnunet, "unet") + self.assertNotEqual(ModelType.unet, "nope") + # Comparison against an unrelated type is not equal. + self.assertNotEqual(ModelType.nnunet, 5.0) + + def test_hash_usable_as_dict_key(self): + # hash() works and members are usable as dict keys. + self.assertEqual(hash(ModelType.nnunet), ModelType.nnunet.value) + mapping = { + ModelType.nnunet: "a", + ModelType.unet: "b", + ModelType.classifier: "c", + } + self.assertEqual(len(mapping), 3) + self.assertEqual(mapping[ModelType.unet], "b") + + def test_str_and_repr_format(self): + self.assertEqual(str(ModelType.nnunet), "ModelType.nnunet") + self.assertEqual(repr(ModelType.nnunet), "ModelType.nnunet") + self.assertEqual(str(Modality.T2w), "Modality.T2w") + self.assertEqual(repr(Acquisition.sag), "Acquisition.sag") + + +class Test_MetaEnum_membership(unittest.TestCase): + def test_member_name_in_enum(self): + self.assertTrue("nnunet" in ModelType) + self.assertTrue("unet" in ModelType) + self.assertTrue("classifier" in ModelType) + + def test_non_member_membership(self): + # MetaEnum.__contains__ returns False for names that are not members + # (it catches both KeyError and ValueError from the member lookup). + self.assertFalse("nope" in ModelType) + self.assertFalse("NNUNET" in ModelType) # case-sensitive: not a member + + +class Test_overlap_slice(unittest.TestCase): + def test_overlapping(self): + self.assertTrue(overlap_slice(slice(0, 10), slice(5, 15))) + self.assertTrue(overlap_slice(slice(5, 15), slice(0, 10))) + # One range fully contained in the other. + self.assertTrue(overlap_slice(slice(0, 20), slice(5, 10))) + + def test_touching_at_border(self): + # Borders are inclusive, so touching at a single point counts as overlap. + self.assertTrue(overlap_slice(slice(0, 10), slice(10, 20))) + self.assertTrue(overlap_slice(slice(10, 20), slice(0, 10))) + + def test_disjoint(self): + self.assertFalse(overlap_slice(slice(0, 5), slice(10, 15))) + self.assertFalse(overlap_slice(slice(10, 15), slice(0, 5))) + + +class Test_find_nearest_lower(unittest.TestCase): + def test_returns_largest_element_below_x(self): + self.assertEqual(find_nearest_lower([1, 5, 10, 20], 12), 10) + self.assertEqual(find_nearest_lower([1, 5, 10, 20], 20), 10) + self.assertEqual(find_nearest_lower([3, 1, 2], 3), 2) + + def test_returns_min_when_none_lower(self): + # No element strictly below x -> falls back to min(seq). + self.assertEqual(find_nearest_lower([10, 20, 30], 5), 10) + self.assertEqual(find_nearest_lower([10, 20, 30], 10), 10) + + +class Test_add_ignore_text(unittest.TestCase): + def test_mutates_last_element(self): + texts = ["First.", "Second."] + add_ignore_text(texts) + # Last char dropped from the last entry, then the ignore marker appended. + self.assertEqual(texts[-1], "Second (IGNORED).") + # Earlier entries are untouched. + self.assertEqual(texts[0], "First.") + + def test_returns_none_and_mutates_in_place(self): + texts = ["only."] + result = add_ignore_text(texts) + self.assertIsNone(result) + self.assertEqual(texts[0], "only (IGNORED).") + + +if __name__ == "__main__": + unittest.main() diff --git a/unit_tests/test_resolution.py b/unit_tests/test_resolution.py new file mode 100644 index 0000000..c7bf7ed --- /dev/null +++ b/unit_tests/test_resolution.py @@ -0,0 +1,73 @@ +# Call 'python -m unittest' on this folder # noqa: INP001 +# coverage run -m unittest +# coverage report +# coverage html +"""Tests for the resolution-aware mm<->voxel threshold helpers.""" + +from __future__ import annotations + +import unittest + +from spineps.utils.resolution import ( + REFERENCE_VOXEL_VOLUME_MM3, + REFERENCE_ZOOM, + isotropic_area_to_voxels, + mm3_to_voxels, + mm_to_voxels, + mm_to_voxels_axis, +) + +CT_ZOOM = (0.8, 0.8, 0.8) +COARSE_ZOOM = (1.5, 1.5, 1.5) + + +class Test_Resolution_Helpers(unittest.TestCase): + def test_reference_volume_roundtrips(self): + # A threshold defined as N voxels at the reference resolution must convert back to N voxels there. + for n in (1, 30, 40, 100): + mm3 = n * REFERENCE_VOXEL_VOLUME_MM3 + self.assertEqual(mm3_to_voxels(mm3, REFERENCE_ZOOM), n) + + def test_reference_distance_roundtrips(self): + for n in (2, 4, 5, 25): + mm = n * min(REFERENCE_ZOOM) + self.assertEqual(mm_to_voxels(mm, REFERENCE_ZOOM), n) + + def test_reference_axis_roundtrips(self): + # Distance along the inferior axis (1.65 mm) reproduces the voxel count there. + for n in (5, 10, 64): + mm = n * REFERENCE_ZOOM[1] + self.assertAlmostEqual(mm_to_voxels_axis(mm, REFERENCE_ZOOM, 1), n) + + def test_reference_area_roundtrips(self): + for n in (10, 20, 50): + mm2 = n * REFERENCE_VOXEL_VOLUME_MM3 ** (2.0 / 3.0) + self.assertEqual(isotropic_area_to_voxels(mm2, REFERENCE_ZOOM), n) + + def test_finer_resolution_needs_more_voxels(self): + # The same physical volume spans more voxels at a finer (smaller) spacing. + mm3 = 30 * REFERENCE_VOXEL_VOLUME_MM3 + at_ct = mm3_to_voxels(mm3, CT_ZOOM) # 0.8 mm iso -> finer than 1.65 axis + at_coarse = mm3_to_voxels(mm3, COARSE_ZOOM) # 1.5 mm iso -> coarser + self.assertGreater(at_ct, 30) + self.assertLess(at_coarse, 30) + + def test_ct_volume_value(self): + # 30 voxels at the reference is ~27.84 mm^3 -> 54 voxels at 0.8 mm iso (0.512 mm^3/voxel). + mm3 = 30 * REFERENCE_VOXEL_VOLUME_MM3 + self.assertEqual(mm3_to_voxels(mm3, CT_ZOOM), 54) + + def test_minimum_floor(self): + self.assertEqual(mm3_to_voxels(0.0, REFERENCE_ZOOM, minimum=1), 1) + self.assertEqual(mm3_to_voxels(0.0, REFERENCE_ZOOM, minimum=5), 5) + self.assertEqual(mm_to_voxels(0.0, REFERENCE_ZOOM, minimum=0), 0) + self.assertEqual(isotropic_area_to_voxels(0.0, REFERENCE_ZOOM, minimum=1), 1) + + def test_return_types_are_int(self): + self.assertIsInstance(mm3_to_voxels(50.0, CT_ZOOM), int) + self.assertIsInstance(mm_to_voxels(5.0, CT_ZOOM), int) + self.assertIsInstance(isotropic_area_to_voxels(20.0, CT_ZOOM), int) + + +if __name__ == "__main__": + unittest.main() diff --git a/unit_tests/test_semantic.py b/unit_tests/test_semantic.py index 73e991c..3e5dea4 100644 --- a/unit_tests/test_semantic.py +++ b/unit_tests/test_semantic.py @@ -101,7 +101,7 @@ def test_compatibility(self): def test_phase_preprocess(self): mri, _subreg, _vert, _label = get_test_mri() for pad_size in range(7): - origin_diff = max([d * float(pad_size) for d in mri.zoom]) + 1e-4 + origin_diff = max(d * float(pad_size) for d in mri.zoom) + 1e-4 # print(origin_diff) preprossed_input, errcode = preprocess_input(mri, debug_data={}, pad_size=pad_size, verbose=True) print(mri)