diff --git a/.gitignore b/.gitignore index b455cc3..4a38f5d 100644 --- a/.gitignore +++ b/.gitignore @@ -185,4 +185,6 @@ notebooks/data/*/* # ignore the data test_images tests/aerial/data -dev/ \ No newline at end of file +dev/ +*.local* +.claude \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..3000d04 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,105 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +**HIPP (Historical Image Pre-Processing)** is a Python library for preprocessing scanned historical aerial and satellite images for photogrammetric analysis (Structure from Motion). It supports: +- Aerial images with fiducial markers (USGS, NAGAP datasets) +- Declassified US reconnaissance satellite images (KH-9 Hexagon panoramic/mapping camera, KH-4/4A/4B Corona) + +## Commands + +The project uses **Hatch** as the project manager. + +```bash +hatch shell dev # Enter development environment +hatch run dev:check # Type check (mypy --strict) + lint (ruff) +hatch run dev:pytest # Run tests +hatch run dev:lab # Start Jupyter Lab on port 8333 +hatch run dev:kernel # Install IPython kernel +``` + +Direct commands (inside `hatch shell dev`): +```bash +pytest # All tests +pytest tests/aerial/test_core.py # Single test file +ruff check . # Lint +mypy src/ --strict --ignore-missing-imports --no-warn-unused-ignores --allow-untyped-calls +pre-commit install # Install git hooks (run once after cloning) +``` + +Line length is 120 characters. Pre-commit hooks run ruff + mypy on every commit. 
+ +## Architecture + +### Package Layout (`src/hipp/`) + +``` +hipp/ +├── image.py # Low-level image I/O, CLAHE, resizing +├── math.py # Geometric transforms, matrix ops +├── intrinsics.py # Intrinsics class (camera calibration parameters) +├── tools.py # GUI point picking, archive extraction, quickviews +├── aerial/ # Fiducial-based aerial image preprocessing +│ ├── core.py # Main pipeline: template creation → detection → restitution +│ ├── fiducials.py # Fiducial marker detection, matching, transformation +│ └── quality_control.py +├── kh9pc/ # KH-9 panoramic camera preprocessing +│ ├── pipeline.py # End-to-end orchestration (PipelineStep, KH9Pipeline, PipelineConfig) +│ ├── image_mosaic.py # ORB keypoint matching, RANSAC, image stitching (ImageAlignment) +│ ├── cli.py # Click command-line interface (`preproc`, `batch-preproc`) +│ ├── quality_control.py +│ ├── utils.py +│ └── restitution/ # Image rectification +│ ├── types.py # StepResult, StrategyAttempt data classes +│ ├── strategy.py # RectificationStrategy + Collimation/Poly/Flat strategies +│ ├── vertical.py # VerticalDetector (collimation line detection) +│ └── plotters.py +└── dataquery/ # USGS/NAGAP data download +``` + +### Data Flow + +**Aerial pipeline** (`hipp.aerial.core`): +1. `create_fiducial_templates()` — user picks fiducial locations on reference image +2. `iter_detect_fiducials()` — OpenCV template matching on input images +3. `filter_detected_fiducials()` — removes low-confidence matches +4. `compute_transformations()` — estimates affine/similarity transforms +5. `iter_image_restitution()` — crops, applies CLAHE, outputs standardized images + +**KH-9 pipeline** (`hipp.kh9pc.pipeline.KH9Pipeline`): +1. Extract archive → list of TIF scan strips +2. `join_images()` — stitch strips via ORB keypoints + RANSAC affine alignment +3. 
Restitute (rectify): + - Detect vertical collimation edges (`VerticalDetector`) + - Detect horizontal edges with strategy fallback: `CollimationStrategy` → `PolyStrategy` → `FlatStrategy` + - Apply analytical inverse-map transform (bilinear interpolation between fitted polynomial curves via `build_inverse_map`) +4. Generate QC reports + +Valid `PipelineConfig.steps` names (in order): `extract`, `join_images`, `quickview_mosaic`, `restitution`, `quickview_final`, `qc_report`. + +**CLI** (`python -m hipp.kh9pc`): +```bash +python -m hipp.kh9pc preproc --input scan.tgz --output-dir /out/images +python -m hipp.kh9pc preproc -i t1.tif -i t2.tif -i t3.tif --output-dir /out/images +python -m hipp.kh9pc batch-preproc --input-dir /in/scans --output-dir /out/images --n-jobs 4 +``` +`PipelineConfig.from_toml()` accepts keys: `overwrite`, `cleanup`, `steps`, and `output_height` (integer, default `22064`). + +### Key Patterns + +- **`PipelineStep`**: declarative step class with `inputs`/`outputs`/`overwrite` — enables skip-if-done logic +- **Strategy pattern** in `kh9pc/restitution/strategy.py`: multiple fallback strategies for edge detection; all inherit `RectificationStrategy` ABC which chains `_fit()` + `make_inverse_map()` + centering translation into `transform(output_height)`; strategies expose `detected_region() → ((col_off, row_top), (width, height))` +- **Pandas Series for fiducials**: coordinate data stored with named keys like `corner_top_left_x`, `midside_left_x` +- **`Intrinsics` class**: wraps focal length, pixel pitch, true fiducial coordinates in mm, principal point +- **3×3 homogeneous matrices** throughout for image transforms +- **Rasterio** for all geospatial raster I/O; **OpenCV** for image operations; `build_inverse_map` in `utils.py` for the analytical curve-interpolation warp +- **Intermediate files persisted as `.joblib`**: `vertical.joblib`, `horizontal.joblib`, `alignments.joblib` — individual steps can be re-run by 
loading these directly + +### Notebooks + +Practical usage examples live in `notebooks/`: +- `aerial_preprocessing.ipynb` — aerial fiducial workflow +- `kh9pc_preprocessing.ipynb` — full KH-9 pipeline +- `kh9pc_collimation_rectification.ipynb` — detailed rectification walkthrough diff --git a/pyproject.toml b/pyproject.toml index b582045..8fb1c6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,9 +33,13 @@ dependencies = [ "rasterio", "scikit-image", "scikit-learn", - "requests" + "requests", + "click", ] +[project.scripts] +hipp-kh9pc = "hipp.kh9pc.cli:main" + [project.urls] Documentation = "https://github.com/godinlu/hipp#readme" Issues = "https://github.com/godinlu/hipp/issues" @@ -53,13 +57,15 @@ dependencies = [ "pre-commit", "usgsxplore", "ipykernel", + "jupyter", ] [tool.hatch.envs.dev.scripts] check = [ "ruff check .", - "mypy . --strict --ignore-missing-imports --no-warn-unused-ignores --allow-untyped-calls" + "mypy src/ --strict --ignore-missing-imports --no-warn-unused-ignores --allow-untyped-calls" ] kernel = "python -m ipykernel install --user --name hipp --display-name 'Python (hipp)'" +lab = "jupyter lab --no-browser --ip=127.0.0.1 --port=8333" [tool.coverage.run] diff --git a/src/hipp/image.py b/src/hipp/image.py index 9d1c82a..3a04650 100644 --- a/src/hipp/image.py +++ b/src/hipp/image.py @@ -3,6 +3,7 @@ Description: some function for the image processing """ +import logging import warnings from pathlib import Path from typing import Callable @@ -12,13 +13,38 @@ import rasterio from numpy.typing import NDArray from rasterio.errors import NotGeoreferencedWarning -from rasterio.warp import Resampling, reproject +from rasterio.warp import Resampling from rasterio.windows import Window from scipy.interpolate import RectBivariateSpline from tqdm import tqdm warnings.filterwarnings("ignore", category=NotGeoreferencedWarning) +logger = logging.getLogger(__name__) + + +class LogProgressBar: + """Log-friendly progress bar that emits one line per 
character of width.""" + + def __init__(self, label: str, total: int, log: logging.Logger, width: int = 5) -> None: + self._label = label + self._total = total + self._log = log + self._width = width + self._last_step = -1 + + def update(self, i: int) -> None: + step = i * self._width // self._total + if step != self._last_step: + self._last_step = step + filled = i * self._width // self._total + bar = "#" * filled + "." * (self._width - filled) + self._log.info("%s [%s]", self._label, bar) + + def close(self) -> None: + bar = "#" * self._width + self._log.info("%s [%s]", self._label, bar) + def apply_clahe( image: cv2.typing.MatLike, @@ -78,11 +104,24 @@ def resize_img( def generate_quickview( - raster_filepath: str, - output_path: str, + raster_filepath: str | Path, + output_path: str | Path | None = None, scale_factor: float = 0.2, interpolation: int = Resampling.average, + jpeg_quality: int = 95, + overwrite: bool = False, ) -> None: + raster_filepath = Path(raster_filepath) + + if output_path is None: + output_path = raster_filepath.parent / "quickviews" / raster_filepath.with_suffix(".jpg").name + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + if output_path.exists() and not overwrite: + logger.info("Skipping generate_quickview: %s (already exists, overwrite=False)", str(output_path)) + return + with rasterio.open(raster_filepath) as src: width = int(src.width * scale_factor) height = int(src.height * scale_factor) @@ -95,14 +134,17 @@ def generate_quickview( if count == 1: img_cv2 = qv_img[0] # 2D array for single band else: - # transpose (b, H, W) -> (H, W, b) and transforme rgb to bgr + # transpose (b, H, W) -> (H, W, b) and convert rgb to bgr img_cv2 = cv2.cvtColor(np.transpose(qv_img, (1, 2, 0)), cv2.COLOR_RGB2BGR) - # If single band, make sure dtype is uint8 if img_cv2.dtype != np.uint8: img_cv2 = img_cv2.astype(np.uint8) - cv2.imwrite(output_path, img_cv2) + encode_params = [] + if output_path.suffix.lower() 
in (".jpg", ".jpeg"): + encode_params = [cv2.IMWRITE_JPEG_QUALITY, jpeg_quality] + + cv2.imwrite(str(output_path), img_cv2, encode_params) def resize_raster_blockwise( @@ -243,8 +285,6 @@ def remap_tif_blockwise( output_size: tuple[int, int] | None = None, block_size: int = 256, interpolation: int = cv2.INTER_CUBIC, - pbar: bool = True, - pbar_desc: str = "Remaping tif", padding: int = 2, lowres_step: int | None = None, ) -> None: @@ -276,10 +316,6 @@ def remap_tif_blockwise( interpolation : int, default=cv2.INTER_CUBIC OpenCV interpolation flag (e.g., `cv2.INTER_LINEAR`, `cv2.INTER_CUBIC`, `cv2.INTER_NEAREST`). Defines how pixel values are interpolated during remapping. - pbar : bool, default=True - Whether to display a tqdm progress bar during processing. - pbar_desc : str, default="Remaping tif" - Description text displayed in the tqdm progress bar. padding : int, default=2 Number of extra pixels to read around the computed source window, to reduce border artifacts caused by bicubic interpolation. 
@@ -333,10 +369,11 @@ def remap_tif_blockwise( for dst_x0 in range(0, output_size[0], block_size) for dst_y0 in range(0, output_size[1], block_size) ] - # Wrap block iterator with a tqdm progress bar if enabled - iterator = tqdm(blocks, desc=pbar_desc, unit="block") if pbar else blocks + n_blocks = len(blocks) + pbar = LogProgressBar(f"remapping {Path(input_path).name}", n_blocks, logger) - for dst_x0, dst_y0 in iterator: + for block_idx, (dst_x0, dst_y0) in enumerate(blocks): + pbar.update(block_idx) dst_x1 = min(dst_x0 + block_size, output_size[0]) dst_y1 = min(dst_y0 + block_size, output_size[1]) @@ -373,8 +410,8 @@ def remap_tif_blockwise( tf_ygrid = points_transformed[:, 1].reshape(ygrid.shape).astype(np.float32) # Compute source window bounds with a padding to avoid artefact on edge caused of bicubic interpolation - src_x0 = int(np.floor(tf_xgrid.min())) - padding - src_y0 = int(np.floor(tf_ygrid.min())) - padding + src_x0 = max(0, int(np.floor(tf_xgrid.min())) - padding) + src_y0 = max(0, int(np.floor(tf_ygrid.min())) - padding) src_x1 = int(np.ceil(tf_xgrid.max())) + padding src_y1 = int(np.ceil(tf_ygrid.max())) + padding @@ -401,132 +438,124 @@ def remap_tif_blockwise( dst_window = Window(col_off=dst_x0, row_off=dst_y0, width=dst_x1 - dst_x0, height=dst_y1 - dst_y0) dst.write(remapped_block, 1, window=dst_window) + pbar.close() -def warp_raster_pixels( - raster_filepath: str | Path, - output_raster_filepath: str | Path, - transformation_matrix: cv2.typing.MatLike, - output_size: None | tuple[int, int] = None, - max_workers: int = 5, - resampling: int = Resampling.cubic, - band_idx: int = 1, -) -> None: - """ - Apply a pixel-wise affine warp to a raster band and save the result to a new file. - The function reprojects the selected raster band using a custom affine transformation - (e.g., translation, rotation, scaling) provided as a 2D transformation matrix. 
- The pixel grid of the output raster is updated to reflect the transformation, while - preserving the original spatial reference system. +def match_multiple_templates( + image: cv2.typing.MatLike, + templates: list[cv2.typing.MatLike], + threshold: float = 0.8, + method: int = cv2.TM_CCOEFF_NORMED, + nms_threshold: float = 0.3, +) -> tuple[list[tuple[int, int, int, int]], list[float], list[int]]: + """ + Match multiple templates and return filtered detections. Parameters ---------- - raster_filepath : str - Path to the input raster file. - output_raster_filepath : str - Path where the warped raster will be written. - transformation_matrix : cv2.typing.MatLike - A 2×3 affine-like transformation matrix (as used in OpenCV) defining the warp - to apply in pixel space. - output_size : tuple[int, int] or None, optional - Dimensions (width, height) of the output raster. If None (default), the input - raster dimensions are used. - max_workers : int, default 5 - Number of threads to use during reprojection. - resampling : int, default rasterio.warp.Resampling.cubic - Resampling method applied during the warp (e.g., nearest, bilinear, cubic). - band_idx : int, default 1 - Index of the raster band (1-based) to process. + image : np.ndarray + Input image. + templates : list of np.ndarray + List of templates. + threshold : float + Matching score threshold. + method : int + OpenCV template matching method. + nms_threshold : float + IoU threshold for NMS. Returns ------- - None - The warped raster is written to `output_raster_filepath`. - - Notes - ----- - - Only one band is processed at a time. For multi-band rasters, call the function - once per band or extend it accordingly. - - The output transform is temporarily updated to apply the warp, then reset to the - original transform to keep the spatial reference consistent. - - No CRS transformation is performed; warping is done strictly in pixel space. 
+ boxes : list of tuple + (x, y, w, h) + scores : list of float + template_ids : list of int + Index of template for each detection """ - affine_transform = rasterio.Affine(*transformation_matrix[:2].flatten()) - # create the parent output directory if necessary - Path(output_raster_filepath).parent.mkdir(exist_ok=True, parents=True) + all_boxes: list[tuple[int, int, int, int]] = [] + all_scores: list[float] = [] + all_template_ids: list[int] = [] - with rasterio.open(raster_filepath) as src: - output_size = output_size if output_size else (src.width, src.height) - profile = src.profile.copy() - profile.update( - { - "width": output_size[0], - "height": output_size[1], - "transform": src.transform * ~affine_transform, - "compress": "lzw", - "BIGTIFF": "YES", - } - ) - with rasterio.open(output_raster_filepath, "w", **profile) as dst: - reproject( - source=rasterio.band(src, band_idx), - destination=rasterio.band(dst, band_idx), - resampling=resampling, - num_threads=max_workers, - ) - dst.transform = src.transform + # --- Collect detections --- + for tid, template in enumerate(templates): + h, w = template.shape[:2] + if h > image.shape[0] or w > image.shape[1]: + continue + res = cv2.matchTemplate(image, template, method) + ys, xs = np.where(res >= threshold) -def apply_clahe_to_tif_blockwise( - input_tif_path: str, - output_tif_path: str, - block_size: int = 256, - clip_limit: float = 2.0, - tile_grid_size: tuple[int, int] = (8, 8), -) -> None: - """ - Apply CLAHE on a GeoTIFF image block by block and save the result. + for x, y in zip(xs, ys): + all_boxes.append((int(x), int(y), int(w), int(h))) + all_scores.append(float(res[y, x])) + all_template_ids.append(tid) - Args: - input_tif_path (str): Path to input .tif image. - output_tif_path (str): Path to save the output .tif image. - block_size (int): Size of the square block/window to process. - clip_limit (float): CLAHE clip limit parameter. - tile_grid_size (tuple[int, int]): CLAHE tile grid size. 
+ if not all_boxes: + return [], [], [] - """ + # --- Apply global NMS --- + indices = cv2.dnn.NMSBoxes( + all_boxes, + all_scores, + score_threshold=threshold, + nms_threshold=nms_threshold, + ) - # Open source image with rasterio - with rasterio.open(input_tif_path) as src: - profile = src.profile.copy() + if len(indices) == 0: + return [], [], [] - # Update profile for output - profile.update( - dtype=rasterio.uint8, # CLAHE output is uint8 - count=src.count, - compress="lzw", - bigtiff="TRUE", - ) + indices_np = np.array(indices).reshape(-1) - # Create CLAHE object from OpenCV - clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size) + boxes = [all_boxes[i] for i in indices_np] + scores = [all_scores[i] for i in indices_np] + template_ids = [all_template_ids[i] for i in indices_np] - with rasterio.open(output_tif_path, "w", **profile) as dst: - height = src.height - width = src.width + return boxes, scores, template_ids - # Read only the first band (grayscale) - for row_off in range(0, height, block_size): - for col_off in range(0, width, block_size): - # Define window dimensions (may be smaller on edges) - win_width = min(block_size, width - col_off) - win_height = min(block_size, height - row_off) - window = Window(col_off, row_off, win_width, win_height) +class SubImage: + """A windowed view of a rasterio raster with coordinate conversion helpers. + + Reads band 1 of a raster within a given window (optionally resampled) and + provides ``to_global`` / ``to_local`` to convert pixel coordinates between + the sub-image space and the full raster space. 
+ """ + + def __init__( + self, + raster: str | Path | rasterio.DatasetReader, + window: Window | None, + out_shape: tuple[int, int, int] | None = None, + resampling: Resampling = Resampling.average, + ): + if isinstance(raster, rasterio.DatasetReader): + self._setup(raster, window, out_shape, resampling) + else: + with rasterio.open(raster) as src: + self._setup(src, window, out_shape, resampling) + + def _setup( + self, + src: rasterio.DatasetReader, + window: Window | None, + out_shape: tuple[int, int, int] | None, + resampling: Resampling, + ) -> None: + self.window = window or Window(0, 0, src.width, src.height) + self.band = src.read(1, window=self.window, out_shape=out_shape, resampling=resampling) + + actual_shape = self.band.shape # (height, width) after read + self.out_shape = (1, actual_shape[0], actual_shape[1]) + self._scale = np.array( + [self.window.width / actual_shape[1], self.window.height / actual_shape[0]], dtype=np.float64 + ) + self._offset = np.array([self.window.col_off, self.window.row_off], dtype=np.float64) - # Read block from source (as ndarray) - block = src.read(1, window=window) + def to_global(self, pts: NDArray[np.floating]) -> NDArray[np.floating]: + """Convert local sub-image pixel coordinates to global raster coordinates.""" + return pts * self._scale + self._offset - # Write result block to destination - dst.write(clahe.apply(block), 1, window=window) + def to_local(self, pts: NDArray[np.floating]) -> NDArray[np.floating]: + """Convert global raster pixel coordinates to local sub-image coordinates.""" + return (pts - self._offset) / self._scale diff --git a/src/hipp/kh9pc/__init__.py b/src/hipp/kh9pc/__init__.py index de4dfff..76e71fe 100644 --- a/src/hipp/kh9pc/__init__.py +++ b/src/hipp/kh9pc/__init__.py @@ -1,13 +1,26 @@ -from . import collimation_lines, core, image_mosaic from . 
import quality_control as qc -from .batch import iter_collimation_rectification, join_images, join_images_asp +from .pipeline import batch_preprocess_kh9pc, preprocess_kh9pc +from .mosaic import image_mosaic +from .restitution import ( + CollimationStrategy, + FiducialStrategy, + FlatStrategy, + MixedStrategy, + PolyStrategy, + VerticalDetector, +) +from .restitution.base import DetectionError __all__ = [ - "image_mosaic", - "join_images", - "join_images_asp", - "iter_collimation_rectification", + "batch_preprocess_kh9pc", + "preprocess_kh9pc", "qc", - "collimation_lines", - "core", + "image_mosaic", + "CollimationStrategy", + "DetectionError", + "FiducialStrategy", + "FlatStrategy", + "MixedStrategy", + "PolyStrategy", + "VerticalDetector", ] diff --git a/src/hipp/kh9pc/__main__.py b/src/hipp/kh9pc/__main__.py new file mode 100644 index 0000000..372e316 --- /dev/null +++ b/src/hipp/kh9pc/__main__.py @@ -0,0 +1,3 @@ +from hipp.kh9pc.cli import main + +main() diff --git a/src/hipp/kh9pc/batch.py b/src/hipp/kh9pc/batch.py deleted file mode 100644 index fa30e33..0000000 --- a/src/hipp/kh9pc/batch.py +++ /dev/null @@ -1,184 +0,0 @@ -""" -Copyright (c) 2025 HIPP developers -Description: Functions for applying core preprocessing functions to images batch -""" - -import os -from collections import defaultdict -from pathlib import Path - -# from hipp.image import warp_tif_blockwise_to_dst -from hipp.kh9pc.core import collimation_rectification, image_mosaic -from hipp.kh9pc.image_mosaic import compute_sequential_alignment, mosaic_images - - -def join_images_asp( - images_directory: str, - output_directory: str, - overwrite: bool = False, - threads: int = 0, - cleanup: bool = True, - verbose: bool = True, - dryrun: bool = False, -) -> None: - """ - Groups and mosaics TIF image tiles from a directory by scene ID. - - Each group of images is identified by the prefix before the first underscore in the filename. 
- Images must be named in a way that ensures alphabetical ordering corresponds to spatial/temporal logic - (e.g., img_a.tif, img_b.tif, etc.). - - Parameters: - images_directory (str): Path to the directory containing .tif image tiles. - output_directory (str): Path where the output mosaicked images will be saved. - overwrite (bool): If False and an output file already exists, it will be skipped. Default is False. - threads (int): Number of threads to use for mosaicking. Default is 0 (auto). - cleanup (bool): If True, temporary log/auxiliary files will be deleted after mosaicking. Default is True. - verbose (bool): If True, prints progress and command details. Default is True. - dryrun (bool): If True, simulates the process without executing commands. Default is False. - - Returns: - None - """ - scene_tiles = defaultdict(list) - - # Group image tiles by scene ID (assumed to be the prefix before the first underscore) - for filename in os.listdir(images_directory): - if filename.endswith(".tif") and "_" in filename: - scene_id = filename.split("_")[0] - scene_tiles[scene_id].append(os.path.join(images_directory, filename)) - - # For each scene group, create a mosaicked image - for scene_id in sorted(scene_tiles): - output_image_path = os.path.join(output_directory, f"{scene_id}.tif") - image_paths = sorted(scene_tiles[scene_id]) - - # Call image_mosaic for each group - # Sort image paths alphabetically to ensure consistent mosaicking order - image_mosaic(image_paths, output_image_path, overwrite, threads, cleanup, verbose, dryrun) - - -def join_images( - images_directory: str, - output_directory: str, - overwrite: bool = False, - verbose: bool = True, - max_workers: int = 5, -) -> None: - """ - Groups and mosaics TIF image tiles from a directory by scene ID. - - Each group of images is identified by the prefix before the first underscore in the filename. 
- Images must be named in a way that ensures alphabetical ordering corresponds to spatial/temporal logic - (e.g., img_a.tif, img_b.tif, etc.). - """ - scene_tiles = defaultdict(list) - - # Group image tiles by scene ID (assumed to be the prefix before the first underscore) - for filename in os.listdir(images_directory): - if filename.endswith(".tif") and "_" in filename: - scene_id = filename.split("_")[0] - scene_tiles[scene_id].append(os.path.join(images_directory, filename)) - - # For each scene group, create a mosaicked image - for scene_id in sorted(scene_tiles): - output_image_path = os.path.join(output_directory, f"{scene_id}.tif") - image_paths = sorted(scene_tiles[scene_id]) - - if os.path.exists(output_image_path) and not overwrite: - print(f"Skipping {output_image_path}: output already exists") - else: - matrix = compute_sequential_alignment(image_paths, verbose=verbose) - mosaic_images(matrix, output_image_path, max_workers, verbose) - - -def iter_collimation_rectification( - input_dir: str | Path, - output_dir: str | Path, - qc_dir: str | Path, - bg_px_threshold: int = 20, - collimation_line_dist: int = 21770, - transformation: str = "tps", - verbose: bool = True, - overwrite: bool = False, -) -> None: - """ - Apply collimation rectification iteratively to all raster images in a directory. - - This function loops over all `.tif` images in the input directory and applies - the `collimation_rectification()` function to each. The user can choose between - Thin Plate Spline (TPS) or Affine transformations for geometric correction. - Quality control (QC) outputs for each image are stored in the specified QC directory. - - Args: - input_dir (str | Path): - Directory containing the input raster images to rectify. - output_dir (str | Path): - Directory where rectified raster images will be saved. - qc_dir (str | Path): - Directory where quality control plots and intermediate data will be stored. 
- bg_px_threshold (int, optional): - Minimum pixel intensity difference used to detect vertical edges. Defaults to 20. - collimation_line_dist (int, optional): - Expected distance (in pixels) between the top and bottom collimation lines - in the rectified image. Defaults to 21770. - transformation (str, optional): - Type of geometric transformation to apply. - - "tps": Thin Plate Spline (non-linear, smooth correction) - - "affine": Affine (linear correction) - Defaults to "tps". - verbose (bool, optional): - If True, prints progress updates during processing. Defaults to True. - overwrite (bool, optional): - If False, skips processing for images that already have a rectified output. - If True, overwrites existing rectified images. Defaults to False. - - Returns: - None - - Workflow: - 1. Scan the `input_dir` for all `.tif` files. - 2. For each image: - a. Check if the output file already exists. - b. If not (or if `overwrite=True`), perform collimation rectification using - `collimation_rectification()`. - 3. Store rectified images in `output_dir` and QC data in `qc_dir`. - - Notes: - - This function is designed for batch rectification of multiple raster scenes. - - Each image’s intermediate data (collimation lines, grids, QC plots) - will be organized under its corresponding subdirectories in `qc_dir`. - - The same transformation type (`transformation`) is applied to all images - in the batch for consistency. - - Example: - >>> iter_collimation_rectification( - ... input_dir="raw_scenes/", - ... output_dir="rectified_scenes/", - ... qc_dir="quality_control/", - ... bg_px_threshold=25, - ... collimation_line_dist=21800, - ... transformation="tps", - ... verbose=True, - ... overwrite=False - ... 
) - """ - input_dir = Path(input_dir) - output_dir = Path(output_dir) - - for input_raster_path in sorted(input_dir.glob("*.tif")): - output_raster_path = output_dir / input_raster_path.name - - if output_raster_path.exists() and not overwrite: - if verbose: - print(f"Skipping {input_raster_path.name} : output already exists") - else: - collimation_rectification( - input_raster_path, - output_raster_path, - qc_dir, - bg_px_threshold, - collimation_line_dist, - transformation, - verbose, - ) diff --git a/src/hipp/kh9pc/cli.py b/src/hipp/kh9pc/cli.py new file mode 100644 index 0000000..35a156b --- /dev/null +++ b/src/hipp/kh9pc/cli.py @@ -0,0 +1,80 @@ +# mypy: disable-error-code="misc" +import logging +import sys +from pathlib import Path + +import click + +from hipp.kh9pc.pipeline import batch_preprocess_kh9pc, preprocess_kh9pc + + +def _configure_logging(verbosity: int) -> None: + level = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}.get(verbosity, logging.DEBUG) + logging.basicConfig( + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", datefmt="%H:%M:%S", stream=sys.stdout + ) + logging.getLogger("hipp").setLevel(level) + + +@click.group() +def main() -> None: + """KH-9 Panoramic Camera preprocessing tools.""" + + +@main.command() +@click.option( + "--input", + "-i", + "input_files", + multiple=True, + required=True, + metavar="FILE", + help="Input archive (.tgz) or tile files (.tif)", +) +@click.option( + "--output-dir", "-o", required=True, type=Path, metavar="DIR", help="Output directory for restituted images" +) +@click.option("--overwrite", is_flag=True, help="Overwrite existing outputs") +@click.option("--keep-work", is_flag=True, help="Keep intermediate working files") +@click.option("-v", "--verbose", count=True, help="Increase verbosity (-v INFO, -vv DEBUG)") +def preproc(input_files: tuple[str, ...], output_dir: Path, overwrite: bool, keep_work: bool, verbose: int) -> None: + """Preprocess a single KH-9 PC scan.""" + 
_configure_logging(verbose) + preprocess_kh9pc( + input=list(input_files) if len(input_files) > 1 else input_files[0], + output_dir=output_dir, + overwrite=overwrite, + keep_work=keep_work, + ) + + +@main.command() +@click.option( + "--input-dir", + "-i", + required=True, + type=Path, + metavar="DIR", + help="Directory containing input archives or tile subdirectories", +) +@click.option( + "--output-dir", "-o", required=True, type=Path, metavar="DIR", help="Output directory for restituted images" +) +@click.option("--n-jobs", "-j", default=1, show_default=True, help="Number of parallel jobs") +@click.option("--overwrite", is_flag=True, help="Overwrite existing outputs") +@click.option("--keep-work", is_flag=True, help="Keep intermediate working files") +@click.option("--dry-run", is_flag=True, help="Log what would be processed without running") +@click.option("-v", "--verbose", count=True, help="Increase verbosity (-v INFO, -vv DEBUG)") +def batch_preproc( + input_dir: Path, output_dir: Path, n_jobs: int, overwrite: bool, keep_work: bool, dry_run: bool, verbose: int +) -> None: + """Batch preprocess multiple KH-9 PC scans.""" + _configure_logging(verbose) + batch_preprocess_kh9pc( + input_dir=input_dir, + output_dir=output_dir, + overwrite=overwrite, + keep_work=keep_work, + n_jobs=n_jobs, + dry_run=dry_run, + ) diff --git a/src/hipp/kh9pc/collimation_lines.py b/src/hipp/kh9pc/collimation_lines.py deleted file mode 100644 index 473ae7c..0000000 --- a/src/hipp/kh9pc/collimation_lines.py +++ /dev/null @@ -1,544 +0,0 @@ -""" -Copyright (c) 2025 HIPP developers -Description: Functions to process lines for KH-9 Panoramic camera images -""" - -from pathlib import Path - -import cv2 -import matplotlib.pyplot as plt -import numpy as np -import rasterio -from numpy.typing import NDArray -from rasterio.warp import Resampling -from rasterio.windows import Window -from scipy.signal import find_peaks -from sklearn.base import BaseEstimator, RegressorMixin -from 
sklearn.linear_model import LinearRegression, RANSACRegressor -from sklearn.metrics import root_mean_squared_error -from sklearn.pipeline import Pipeline, make_pipeline -from sklearn.preprocessing import PolynomialFeatures, StandardScaler - -#################################################################################################################################### -# PUBLIC FUNCTIONS -#################################################################################################################################### - - -def detect_vertical_edges( - raster_filepath: str | Path, - px_threshold: int = 20, - width_fraction: float = 0.05, - stride: tuple[int, int] = (20, 20), - ransac_residual_threshold: float = 100, - ransac_max_trials: int = 100, - plot: bool = True, - output_plot_path: str | Path | None = None, -) -> dict[str, int]: - """ - Detect the left and right vertical edges of a raster image using RANSAC regression. - - This function extracts two vertical bands (left and right) from a raster image, - identifies strong vertical edge points based on pixel intensity changes, and fits - a robust RANSAC regression line to estimate the most probable edge position. - The detected vertical positions (in pixel coordinates) represent the image's - lateral boundaries, which can be used for geometric calibration or alignment tasks. - - Args: - raster_filepath (str | Path): - Path to the raster image file. - px_threshold (int, optional): - Minimum pixel intensity difference used to identify edge points. Defaults to 20. - width_fraction (float, optional): - Fraction of the image width used to define the left and right edge bands. Defaults to 0.05. - stride (tuple[int, int], optional): - Downsampling step (width, height) applied when reading the image to reduce computation. Defaults to (20, 20). - ransac_residual_threshold (float, optional): - Maximum distance for a data point to be classified as an inlier by the RANSAC algorithm. Defaults to 100. 
- ransac_max_trials (int, optional): - Maximum number of iterations performed by the RANSAC algorithm. Defaults to 100. - plot (bool, optional): - Whether to display the visualization of detected edges and RANSAC fits. Defaults to True. - output_plot_path (str | Path | None, optional): - Path to save the resulting plot as an image file. If None, the plot is not saved. Defaults to None. - - Returns: - dict[str, int]: - A dictionary mapping "left" and "right" to the detected x-coordinate positions (in pixels) - of the corresponding vertical edges. - - Notes: - - This function relies on helper functions `extract_vertical_edge_points()` and `vertical_ransac()`. - - The RANSAC method ensures robustness against noise and false edge detections. - - The detected vertical positions can be used to correct lateral distortions in remote sensing imagery. - """ - res = {} - fig, axes = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True) - - with rasterio.open(raster_filepath) as src: - window_width = int(src.width * width_fraction) - window_left = Window(0, 0, window_width, src.height) - window_right = Window(src.width - window_width, 0, window_width, src.height) - - windows = {"left": window_left, "right": window_right} - for i, (side, window) in enumerate(windows.items()): - out_shape = (1, window.height // stride[1], window.width // stride[0]) - band = src.read(1, window=window, out_shape=out_shape, resampling=Resampling.average) - x_local, y_local = extract_vertical_edge_points(band, px_threshold, side) - ransac_local, stats = vertical_ransac(x_local, y_local, ransac_residual_threshold, ransac_max_trials) - res[side] = int(ransac_local.estimator_.constant_ * stride[0] + window.col_off) - - axes[i].imshow(band, cmap="gray") - - inlier_mask = ransac_local.inlier_mask_ - axes[i].scatter(x_local[inlier_mask], y_local[inlier_mask], s=5, color="green", label="inliers") - axes[i].scatter(x_local[~inlier_mask], y_local[~inlier_mask], s=5, color="red", label="outliers") - - 
axes[i].axvline(x=ransac_local.estimator_.constant_, color="blue", label="RANSAC line") - stats_str = "\n".join([f"{k}: {v:.2f}" for k, v in stats.items()]) - axes[i].set_title(f"{side} edge detection \n({stats_str})") - - handles, labels = axes[0].get_legend_handles_labels() - fig.legend(handles, labels, loc="lower center", ncol=3) - - if output_plot_path: - Path(output_plot_path).parent.mkdir(parents=True, exist_ok=True) - plt.savefig(output_plot_path) - if plot: - plt.show() - else: - plt.close() - - return res - - -def detect_collimation_lines( - raster_filepath: str | Path, - height_fraction: float = 0.15, - stride: tuple[int, int] = (256, 10), - polynomial_degree: int = 2, - ransac_residual_threshold: float = 80.0, - collimation_line_dist: int = 21770, - plot: bool = True, - output_plot_path: str | Path | None = None, -) -> dict[str, Pipeline]: - """ - Detects and fits collimation lines in the top and bottom portions of a raster image. - - The function reads the input raster, extracts two horizontal windows (top and bottom), - detects peak positions in each, fits several polynomial RANSAC models, and selects - the best matching pair of top/bottom lines based on their vertical distance consistency. - - Parameters - ---------- - raster_filepath : str or Path - Path to the input raster image. - height_fraction : float, optional - Fraction of the raster height to use for top and bottom windows. - stride : tuple[int, int], optional - Downsampling stride for (x, y) directions. - polynomial_degree : int, optional - Degree of the polynomial model used in RANSAC fitting. - ransac_residual_threshold : float, optional - Maximum residual allowed for inlier detection in RANSAC. - collimation_line_dist : int, optional - Expected distance between top and bottom collimation lines. - plot : bool, optional - If True, display the results interactively. - output_plot_path : str or Path, optional - If provided, save the plot to this path. 
- - Returns - ------- - dict - Dictionary containing the best polynomial models for the top and bottom lines. - """ - - # Create figure with two subplots (top and bottom) - fig, axes = plt.subplots(2, 1, figsize=(10, 8), sharex=True, constrained_layout=True) - polys_dict, inliers_dict, peaks_dict = {}, {}, {} - - with rasterio.open(raster_filepath) as src: - # Define top and bottom windows based on height fraction - window_height = int(src.height * height_fraction) - window_top = Window(0, 0, src.width, window_height) - window_bottom = Window(0, src.height - window_height, src.width, window_height) - windows = {"top": window_top, "bottom": window_bottom} - - # Process both top and bottom sections - for i, (side, window) in enumerate(windows.items()): - # Read and downsample raster band in the selected window - out_shape = (1, window.height // stride[1], window.width // stride[0]) - band = src.read(1, window=window, out_shape=out_shape, resampling=Resampling.average) - - # Detect peaks (local maxima) in each column - peaks_local = find_column_peaks(band) - - # Convert peak coordinates from local (window) to global raster coordinates - peaks_global = peaks_local * np.array(stride) + np.array([0, window.row_off]) - - # Fit several RANSAC polynomial models to the detected peaks - polys, inlier_masks = fit_iterative_ransac_polynomials( - peaks_global, residual_threshold=ransac_residual_threshold, degree=polynomial_degree - ) - polys_dict[side] = polys - inliers_dict[side] = inlier_masks - peaks_dict[side] = peaks_global - - # Display the raster window with proper spatial extent - extent = [ - window.col_off, - window.col_off + window.width, - window.row_off + window.height, - window.row_off, - ] - axes[i].imshow(band, cmap="gray", extent=extent, aspect="auto") - axes[i].set_title(side.upper()) - - # ---- Select the best matching pair of top/bottom polynomials ---- - x = np.linspace(0, src.width, 100) - best_score, best_pair = np.inf, (0, 0) - - # Compare all 
combinations of top/bottom models - for i, poly_top in enumerate(polys_dict["top"]): - for j, poly_bottom in enumerate(polys_dict["bottom"]): - y_top = poly_top.predict(x.reshape(-1, 1)) - y_bottom = poly_bottom.predict(x.reshape(-1, 1)) - - # Compute deviation from expected collimation distance - dist = np.abs(collimation_line_dist - np.abs(y_top - y_bottom)) - score = np.mean(dist) + 10 * np.std(dist) - - if score < best_score: - best_score = score - best_pair = (i, j) - - # ---- Plot selected polynomial models and their inliers/outliers ---- - for idx, side in enumerate(["top", "bottom"]): - poly = polys_dict[side][best_pair[idx]] - peaks = peaks_dict[side] - inliers_mask = inliers_dict[side][best_pair[idx]] - - y_pred = poly.predict(x.reshape(-1, 1)) - - # Plot polynomial curve - axes[idx].plot(x, y_pred, color="red", lw=2, label="Best polynomial") - - # Plot inliers and outliers - axes[idx].scatter(peaks[inliers_mask, 0], peaks[inliers_mask, 1], s=8, color="lime", label="Inliers") - axes[idx].scatter(peaks[~inliers_mask, 0], peaks[~inliers_mask, 1], s=8, color="gray", label="Outliers") - - axes[idx].legend(loc="upper right") - - # ---- Plot display and saving ---- - if output_plot_path: - Path(output_plot_path).parent.mkdir(parents=True, exist_ok=True) - plt.savefig(output_plot_path) - - if plot: - plt.show() - else: - plt.close() - - # Return the best pair of fitted models - return { - "top": polys_dict["top"][best_pair[0]], - "bottom": polys_dict["bottom"][best_pair[1]], - } - - -def compute_source_and_target_grid( - detected_vertical_edges: dict[str, int], - detected_horizontal_ransac: dict[str, RANSACRegressor], - colimation_line_dist: int = 21770, - margin: tuple[int, int] = (0, 147), - grid_shape: tuple[int, int] = (100, 50), -) -> tuple[NDArray[np.generic], NDArray[np.generic], tuple[int, int]]: - """ - Generate source and destination control points for Thin Plate Spline (TPS) rectification - as structured 2D grids. 
- - This function creates two corresponding 2D grids of control points: - - `src_points`: distorted coordinates from the detected vertical edges and horizontal - RANSAC lines (top and bottom). Shape `(grid_shape[0], grid_shape[1], 2)`. - - `dst_points`: regular target coordinates forming a rectified rectangular grid. - Shape `(grid_shape[0], grid_shape[1], 2)`. - - The grids are generated column-wise: each column of points spans from top to bottom - between the detected/fitted top and bottom edges. - - Parameters - ---------- - detected_vertical_edges : dict[str, int] - Dictionary containing pixel positions of left and right vertical edges. - detected_horizontal_ransac : dict[str, RANSACRegressor] - Dictionary containing RANSAC models for the top and bottom horizontal edges. - colimation_line_dist : int, optional - Target distance (in pixels) between the top and bottom collimation lines in the - rectified frame. Default is 21770. - margin : tuple[int, int], optional - (horizontal, vertical) pixel margins added to all points. Default is (0, 147). - grid_shape : tuple[int, int], optional - Number of control points along (width, height) axes of the grid. Default is (100, 50). - - Returns - ------- - src_points : np.ndarray - Array of distorted source coordinates with shape `(grid_shape[0], grid_shape[1], 2)`. - Each entry contains `[x, y]` coordinates in the original image. - dst_points : np.ndarray - Array of regular destination coordinates with shape `(grid_shape[0], grid_shape[1], 2)`. - Each entry contains `[x, y]` coordinates in the rectified frame. - output_size : tuple[int, int] - Expected size `(width, height)` of the rectified raster including margins. - - Notes - ----- - - The source points are computed by evaluating the RANSAC fits for the top and bottom - horizontal edges and interpolating linearly between them for each column. 
- - The destination points form a uniform rectangular grid spanning from (0,0) to - (cropped_img_width, colimation_line_dist), shifted by the specified margin. - """ - cropped_img_width = detected_vertical_edges["right"] - detected_vertical_edges["left"] - - # --- Destination points --- - x_dst = np.linspace(0, cropped_img_width, grid_shape[0]) - y_top_dst = np.zeros_like(x_dst) - y_bottom_dst = np.full_like(x_dst, colimation_line_dist) - - dst_points = np.zeros((grid_shape[0], grid_shape[1], 2), dtype=float) - - for i, (xi, yt, yb) in enumerate(zip(x_dst, y_top_dst, y_bottom_dst)): - ys = np.linspace(yt, yb, grid_shape[1]) - xs = np.full_like(ys, xi) - dst_points[i, :, 0] = xs # x coordinates - dst_points[i, :, 1] = ys # y coordinates - - # Apply margin - dst_points += np.array(margin) - - # --- Source points --- - x_src = x_dst + detected_vertical_edges["left"] - y_top_src = detected_horizontal_ransac["top"].predict(x_src.reshape(-1, 1)) - y_bottom_src = detected_horizontal_ransac["bottom"].predict(x_src.reshape(-1, 1)) - - src_points = np.zeros((grid_shape[0], grid_shape[1], 2), dtype=float) - - for i, (xi, yt, yb) in enumerate(zip(x_src, y_top_src, y_bottom_src)): - ys = np.linspace(yt, yb, grid_shape[1]) - xs = np.full_like(ys, xi) - src_points[i, :, 0] = xs # x coordinates - src_points[i, :, 1] = ys # y coordinates - - # --- Output size --- - output_size = (cropped_img_width + 2 * margin[0], colimation_line_dist + 2 * margin[1]) - - return src_points, dst_points, output_size - - -#################################################################################################################################### -# PRIVATE FUNCTIONS -#################################################################################################################################### - - -def find_column_peaks(image: cv2.typing.MatLike, n_peaks: int = 3, distance: float = 0.2) -> NDArray[np.generic]: - """ - Detects the most prominent peaks along each column of an image. 
- - This function scans each column of the input image as a 1D signal and identifies - up to `n_peaks` local maxima based on their prominence. The detected peaks are - returned as (x, y) coordinates in image space. - - Args: - image (cv2.typing.MatLike): Input grayscale image or 2D array. - n_peaks (int, optional): Maximum number of peaks to keep per column. - Defaults to 3. - distance (float, optional): Minimum vertical separation between peaks - (as a fraction of image height). Defaults to 0.2. - - Returns: - NDArray[np.generic]: Array of shape (N, 2) containing (x, y) coordinates - of detected peaks across all columns. - - Example: - >>> peaks = find_column_peaks(image, n_peaks=2, distance=0.1) - >>> plt.scatter(peaks[:,0], peaks[:,1], s=2, color="red") - """ - peaks_x, peaks_y = [], [] - n_rows, n_cols = image.shape[:2] - distance_px = int(distance * n_rows) - - for col in range(n_cols): - signal = image[:, col].astype(int) - - # --- Detect peaks and compute their prominence --- - peaks, properties = find_peaks(signal, prominence=0, distance=distance_px) - - k = min(n_peaks, len(peaks)) - top_indices = np.argpartition(properties["prominences"], -k)[-k:] - selected_peaks = peaks[top_indices] - - for y in selected_peaks: - peaks_x.append(col) - peaks_y.append(y) - return np.column_stack((peaks_x, peaks_y)) - - -def fit_iterative_ransac_polynomials( - peaks: NDArray[np.generic], residual_threshold: float, n_ransac: int = 3, degree: int = 2 -) -> tuple[list[Pipeline], list[NDArray[np.bool]]]: - """ - Fit multiple polynomial models iteratively using RANSAC regression. - - Each iteration fits a polynomial model on the remaining (non-inlier) points, - removing detected inliers after each successful fit. This allows extraction - of several dominant polynomial trends from a set of (x, y) peak coordinates. - - Args: - peaks (NDArray[np.floating]): Array of shape (n_samples, 2) containing (x, y) points. 
- residual_threshold (float): Maximum residual for a data point to be classified as an inlier. - n_ransac (int, optional): Maximum number of RANSAC iterations/models to fit. Default is 3. - degree (int, optional): Degree of the polynomial features. Default is 2. - - Returns: - Tuple[List[Pipeline], List[NDArray[np.bool_]]]: - - models: List of fitted polynomial pipelines (StandardScaler + PolyFeatures + LinearRegression). - - inlier_masks: List of boolean masks indicating inliers for each fitted model. - """ - X = peaks[:, 0].reshape(-1, 1) - y = peaks[:, 1] - - # Initialize tracking masks - remaining_mask = np.ones_like(y, dtype=bool) - models: list[Pipeline] = [] - inlier_masks: list[NDArray[np.bool]] = [] - - for _ in range(n_ransac): - # Stop if too few points remain - if np.sum(remaining_mask) < 3: - break - - # Create polynomial regression pipeline - poly_model = make_pipeline( - StandardScaler(), - PolynomialFeatures(degree=degree), - LinearRegression(), - ) - - # Fit RANSAC on remaining points - ransac = RANSACRegressor(poly_model, residual_threshold=residual_threshold, min_samples=3) - ransac.fit(X[remaining_mask], y[remaining_mask]) - - # Store fitted model - models.append(ransac.estimator_) - - # Compute inlier mask in global coordinates - inliers_mask = np.zeros_like(y, dtype=bool) - inliers_mask[remaining_mask] = ransac.inlier_mask_ - inlier_masks.append(inliers_mask) - - # Exclude inliers for next iteration - remaining_mask[inliers_mask] = False - - return models, inlier_masks - - -def extract_vertical_edge_points( - image: cv2.typing.MatLike, px_threshold: int = 20, direction: str = "left" -) -> tuple[NDArray[np.int64], NDArray[np.int64]]: - """ - Extract candidate points corresponding to a vertical edge (left or right) in an image. - - For each image row, this function locates the first (or last) pixel - exceeding a given intensity threshold. These edge points can be used - later to fit a vertical boundary (e.g., using RANSAC). 
- - Args: - image (cv2.typing.MatLike): Grayscale image array. - px_threshold (int, optional): Pixel intensity threshold used to detect edge pixels. - Default is 20. - direction (str, optional): Edge direction to detect. - Must be either "left" (first pixel above threshold in each row) - or "right" (last pixel above threshold in each row). Default is "left". - - Returns: - tuple[NDArray[np.int64], NDArray[np.int64]]: - - x_coords: 1D array of detected x-coordinates (column indices). - - y_coords: 1D array of corresponding y-coordinates (row indices). - - Raises: - ValueError: If `direction` is not "left" or "right". - - Example: - >>> x, y = extract_vertical_edge_points(image, px_threshold=30, direction="right") - >>> plt.scatter(x, y, s=2, color='red') - """ - mask = image > px_threshold - - if direction == "left": - idx = np.argmax(mask, axis=1) - elif direction == "right": - idx = mask.shape[1] - 1 - np.argmax(mask[:, ::-1], axis=1) - else: - raise ValueError("direction must be 'left' or 'right'") - return idx[mask.any(axis=1)], np.arange(len(idx))[mask.any(axis=1)] - - -def vertical_ransac( - x: NDArray[np.generic], - y: NDArray[np.generic], - residual_threshold: float = 100, - max_trials: int = 1000, -) -> tuple[RANSACRegressor, dict[str, float]]: - """ - Fit a vertical edge model (constant x-value) using RANSAC regression. - - This function estimates the most probable vertical boundary in an image - given a set of (x, y) points that approximately follow a vertical line. - The model fitted is a constant regressor (predicting a fixed x value) - robust to outliers via the RANSAC algorithm. - - Args: - x (NDArray[np.generic]): Array of x-coordinates (column indices). - y (NDArray[np.generic]): Array of y-coordinates (row indices). - residual_threshold (float, optional): Maximum residual allowed for a point - to be classified as an inlier. Default is 100. - max_trials (int, optional): Maximum number of RANSAC iterations. Default is 1000. 
- - Returns: - tuple[RANSACRegressor, dict[str, float]]: - - ransac: Fitted RANSACRegressor model. - - stats: Dictionary containing: - * "residuals_rmse": Root Mean Squared Error (RMSE) of inliers. - * "inlier_percent": Percentage of inlier points. - - Example: - >>> x_edge, y_edge = extract_vertical_edge_points(image, px_threshold=30) - >>> model, stats = vertical_ransac(x_edge, y_edge, residual_threshold=50) - >>> print(stats) - {'residuals_rmse': 2.1, 'inlier_percent': 93.4} - """ - Y = y.reshape(-1, 1) - - class ConstantRegressor(BaseEstimator, RegressorMixin): # type: ignore[misc] - """Regressor that predicts a constant value (mean of y).""" - - def fit(self, X: NDArray[np.float64], y: NDArray[np.float64]) -> "ConstantRegressor": - self.constant_ = np.median(y) # ou moyenne - return self - - def predict(self, X: NDArray[np.float64]) -> NDArray[np.float64]: - return np.full(shape=(len(X),), fill_value=self.constant_, dtype=float) - - ransac = RANSACRegressor( - estimator=ConstantRegressor(), - max_trials=max_trials, - residual_threshold=residual_threshold, - min_samples=1, - ) - ransac.fit(Y, x) - x_pred = ransac.predict(Y) - stats = { - "residuals_rmse": root_mean_squared_error(x[ransac.inlier_mask_], x_pred[ransac.inlier_mask_]), - "inlier_percent": np.mean(ransac.inlier_mask_) * 100, - } - return ransac, stats diff --git a/src/hipp/kh9pc/core.py b/src/hipp/kh9pc/core.py deleted file mode 100644 index 7b93a00..0000000 --- a/src/hipp/kh9pc/core.py +++ /dev/null @@ -1,286 +0,0 @@ -""" -Copyright (c) 2025 HIPP developers -Description: core functions for the preprocessing of KH-9 PC images -""" - -import glob -import os -import subprocess -from pathlib import Path - -import joblib -from skimage.transform import AffineTransform, ThinPlateSplineTransform - -# import pyvips -from hipp.image import remap_tif_blockwise -from hipp.kh9pc.collimation_lines import ( - compute_source_and_target_grid, - detect_collimation_lines, - detect_vertical_edges, -) -from 
hipp.kh9pc.quality_control import ( - plot_collimation_gradient, - plot_distance_between_collimation_lines, - plot_src_and_dst_points, -) - -#################################################################################################################################### -# MAIN FUNCTIONS -#################################################################################################################################### - - -def image_mosaic( - image_paths: list[str], - output_image_path: str, - overwrite: bool = False, - threads: int = 0, - cleanup: bool = True, - verbose: bool = True, - dryrun: bool = False, -) -> None: - """ - Mosaics a list of images into a single output image using the external 'image_mosaic' command. - - Parameters: - image_paths (list[str]): List of paths to input image tiles. - output_image_path (str): Path to the output mosaic image. - overwrite (bool): If False and the output file exists, the function will skip processing. Default is False. - threads (int): Number of threads to use for processing. Default is 0 (let the tool decide). - cleanup (bool): Whether to remove temporary log and auxiliary files after processing. Default is True. - verbose (bool): If True, prints detailed progress and command information. Default is True. - dryrun (bool): If True, builds the command but does not execute it. Useful for debugging. Default is False. 
- - Returns: - None - """ - # Skip processing if the output exists and overwriting is disabled - if os.path.exists(output_image_path) and not overwrite: - if verbose: - print(f"Skipping {output_image_path}: output already exists") - return - - if verbose: - print(f"\nMosaicking {output_image_path} with {len(image_paths)} tiles...\n") - - # Build the command for the external 'image_mosaic' tool - cmd = [ - "image_mosaic", - *image_paths, - "--ot", - "byte", - "--overlap-width", - "3000", - "--threads", - str(threads), - "-o", - output_image_path, - ] - - # Display the constructed command - if verbose: - print(" ".join(cmd)) - - # Execute the command unless in dry run mode - if not dryrun: - try: - subprocess.run( - cmd, - check=True, - stdout=None if verbose else subprocess.DEVNULL, - stderr=None if verbose else subprocess.DEVNULL, - ) - except subprocess.CalledProcessError as e: - print(f"Error while processing {output_image_path}: {e}") - - # Optionally remove temporary log and auxiliary files generated by the tool - if cleanup: - for f in glob.glob(f"{output_image_path}-log-image_mosaic-*.txt") + glob.glob(f"{output_image_path}.aux.xml"): - os.remove(f) - - -def collimation_rectification( - input_raster_path: str | Path, - output_raster_path: str | Path, - qc_dir: str | Path, - bg_px_threshold: int = 20, - collimation_line_dist: int = 21770, - transformation: str = "tps", - verbose: bool = True, -) -> None: - """ - Perform collimation rectification on a raster image using a geometric transformation - (Thin Plate Spline or Affine warping). - - This function detects the horizontal and vertical collimation features in a raster image, - estimates the geometric deformation, and rectifies the image by applying an inverse - geometric transformation. The user can choose between a Thin Plate Spline (TPS) or - Affine transformation model. 
It also produces several quality control (QC) plots - illustrating each processing step, including line detection, distance consistency, - and transformation effects. - - Args: - input_raster_path (str | Path): - Path to the input raster image to be rectified. - output_raster_path (str | Path): - Path where the geometrically rectified image will be saved. - qc_dir (str | Path): - Directory where quality control plots and intermediate data will be stored. - bg_px_threshold (int, optional): - Minimum pixel intensity difference used to detect vertical edges. Defaults to 20. - collimation_line_dist (int, optional): - Expected distance (in pixels) between the top and bottom collimation lines - in the rectified image. Defaults to 21770. - transformation (str, optional): - Type of geometric transformation to apply. - - "tps": Thin Plate Spline (non-linear, smooth correction) - - "affine": Affine (linear correction) - Defaults to "tps". - verbose (bool, optional): - If True, prints progress updates during processing. Defaults to True. - - Returns: - None - - Workflow: - 1. Detect top and bottom collimation lines using RANSAC polynomial fitting. - 2. Detect left and right vertical edges using robust RANSAC regression. - 3. Estimate and plot the distance between the detected collimation lines. - 4. Compute source and destination grids from detected features. - 5. Estimate the chosen transformation model (TPS or Affine). - 6. Apply the inverse transformation to rectify the image geometry. - 7. Re-detect collimation lines after transformation for validation. - 8. Generate and save all QC plots (line detection, distances, gradients). - - Notes: - - The function assumes that the raster image contains clear collimation marks. - - All intermediate QC results are saved to `qc_dir` for traceability. - - The transformation preserves image size consistency using the computed `output_size`. 
- - The Thin Plate Spline model provides a smooth, non-linear geometric correction, - while the Affine model applies a simpler linear correction. - - Example: - >>> collimation_rectification( - ... input_raster_path="raw_scene.tif", - ... output_raster_path="rectified_scene.tif", - ... qc_dir="quality_control/", - ... bg_px_threshold=25, - ... collimation_line_dist=21800, - ... transformation="affine", - ... verbose=True - ... ) - """ - # transform to Path every paths - input_raster_path = Path(input_raster_path) - output_raster_path = Path(output_raster_path) - qc_dir = Path(qc_dir) - data_dir = qc_dir / "data" / input_raster_path.stem - data_dir.mkdir(exist_ok=True, parents=True) - - if verbose: - print(f"Collimation rectification for {input_raster_path.name} : ") - - # Detect collimation lines - if verbose: - print("\t-[1/4] Estimation of collimation lines...") - collimation_lines = detect_collimation_lines( - input_raster_path, - plot=False, - output_plot_path=qc_dir / "collimation_lines" / f"{input_raster_path.stem}.png", - ) - joblib.dump(collimation_lines, data_dir / "collimation_lines.pkl") - - # Detect vertical lines - if verbose: - print("\t-[2/4] Detection of vertical lines...") - vertical_edges = detect_vertical_edges( - input_raster_path, - bg_px_threshold, - plot=False, - output_plot_path=qc_dir / "vertical_edges" / f"{input_raster_path.stem}.png", - ) - joblib.dump(vertical_edges, data_dir / "vertical_edges.pkl") - - # make the source and destination points - src_grid, dst_grid, output_size = compute_source_and_target_grid( - vertical_edges, collimation_lines, collimation_line_dist - ) - joblib.dump(src_grid, data_dir / "src_grid.pkl") - joblib.dump(src_grid, data_dir / "dst_grid.pkl") - - src_points = src_grid.reshape(-1, 2) - dst_points = dst_grid.reshape(-1, 2) - - # plot them for quality control - plot_src_and_dst_points( - src_points, - dst_points, - output_size, - plot=False, - output_plot_path=qc_dir / "transformations" / 
f"{input_raster_path.stem}.png", - ) - - # choose the goood tranformation and set some hyperparamters - inverse_remap: ThinPlateSplineTransform | AffineTransform - if transformation == "tps": - inverse_remap = ThinPlateSplineTransform() - lowres_step = 100 - block_size = 2**13 - elif transformation == "affine": - inverse_remap = AffineTransform() - lowres_step = None - block_size = 256 - else: - raise ValueError(f"{transformation} not supported, support only 'tps' and 'affine'") - - # for remapping the hipp.image.remap_tif_blockwise use the inverse transformation function - # so we estimate our transformation with dst -> src - inverse_remap.estimate(dst_points, src_points) - - # remap the image with the previously computed function inverse_remap_function - if verbose: - print("\t-[3/4] Warping image (can take some times)...") - remap_tif_blockwise( - input_raster_path, - output_raster_path, - inverse_remap, - output_size, - block_size=block_size, - pbar_desc=f"{input_raster_path.name} remapping", - lowres_step=lowres_step, - ) - - # detect collimation lines after the transformation - if verbose: - print("\t-[4/4] Estimation of collimation lines after transformation...") - - collimation_lines_after_transform = detect_collimation_lines( - output_raster_path, - 0.05, - plot=False, - output_plot_path=qc_dir / "collimation_lines_after_transform" / f"{input_raster_path.stem}.png", - ) - joblib.dump(collimation_lines_after_transform, data_dir / "collimation_lines_after_transform.pkl") - - # Plot the distance between collimation lines for quality control - plot_distance_between_collimation_lines( - collimation_lines, - collimation_lines_after_transform, - output_size[0], - collimation_line_dist, - plot=False, - output_plot_path=qc_dir / "distance_between_collimation_lines" / f"{input_raster_path.stem}.png", - ) - - # plot both collimation gradient before and after transform - plot_collimation_gradient( - collimation_lines, - collimation_lines_after_transform, - 
output_size[0], - plot=False, - output_plot_path=qc_dir / "collimation_gradients" / f"{input_raster_path.stem}.png", - ) - - -#################################################################################################################################### -# PRIVATE FUNCTIONS -#################################################################################################################################### diff --git a/src/hipp/kh9pc/fiducial_patterns.py b/src/hipp/kh9pc/fiducial_patterns.py new file mode 100644 index 0000000..a7874e2 --- /dev/null +++ b/src/hipp/kh9pc/fiducial_patterns.py @@ -0,0 +1,166 @@ +from dataclasses import dataclass + +import numpy as np + +from numpy.typing import NDArray +from typing import Literal + + +PATTERNS = Literal[ + "regulare_sparse", "regulare_mid", "regular_dense", "segmented_mid", "segmented_dense", "serialized_time_word" +] + + +@dataclass +class DetectedPattern: + pattern: PATTERNS + points: NDArray[np.floating] + expected_width: int + score: float + + @property + def count(self) -> int: + return len(self.points) + + +def coverage_score(points: NDArray[np.floating], expected_width: int) -> float: + if len(points) == 0: + return 0.0 + result = float((np.max(points[:, 0]) - np.min(points[:, 0])) / expected_width) + return min(result, 1.0) + + +def compute_spacings(points: NDArray[np.floating]) -> NDArray[np.floating]: + sorted_points = points[np.argsort(points[:, 0])] + return np.hypot(np.diff(sorted_points[:, 0]), np.diff(sorted_points[:, 1])) + + +def compute_intra_segment_spacings(points: NDArray[np.floating]) -> NDArray[np.floating]: + """Return only intra-segment spacings, filtering out inter-segment gaps.""" + spacings = compute_spacings(points) + if len(spacings) == 0: + return spacings + return spacings[spacings < np.median(spacings) * 1.5] # type: ignore[no-any-return] + + +def theorical_spacing_from_pattern(pattern: PATTERNS) -> int: + SPARSE_SPACING: int = 19014 + MID_SPACING: int = round(SPARSE_SPACING / 5) + 
if "sparse" in pattern: + return SPARSE_SPACING + elif "mid" in pattern: + return MID_SPACING + else: + raise ValueError(f"No theorical spacing exist for the pattern {pattern}") + + +def spacing_lo_hi_from_pattern(pattern: PATTERNS) -> tuple[int, int]: + DENSE_MAX_SPACING: int = 1480 + if pattern == "serialized_time_word": + raise ValueError(f"No spacing lo hi existe for the pattern : {pattern}") + if "dense" in pattern: + return (DENSE_MAX_SPACING - 600, DENSE_MAX_SPACING) + else: + max_delta = 200 if "sparse" in pattern else 100 + spacing = theorical_spacing_from_pattern(pattern) + return (spacing - max_delta, spacing + max_delta) + + +def spacing_score_from_pattern(pattern: PATTERNS, points: NDArray[np.floating]) -> float: + if pattern == "serialized_time_word": + return 0.0 + + if len(points) == 0: + return 0.0 + + spacings = compute_spacings(points) + median_spacing = np.median(spacings) + + # test if the distribution is in bounds of the pattern else return 0 + lo, hi = spacing_lo_hi_from_pattern(pattern) + if not lo <= median_spacing <= hi: + return 0.0 + + # 1.5× sits between 1× and 2× spacing, so it cleanly separates regular from gap spacings + regular_spacings = spacings[spacings < median_spacing * 1.5] + expected_count = int(sum(round(s / median_spacing) for s in spacings)) + detection_rate = len(spacings) / expected_count + return float(coefficient_of_variation_score(regular_spacings) * detection_rate) + + +def evaluate_pattern(pattern: PATTERNS, points: NDArray[np.floating], expected_width: int) -> DetectedPattern: + return DetectedPattern( + pattern=pattern, + points=points, + expected_width=expected_width, + score=float(spacing_score_from_pattern(pattern, points) * coverage_score(points, expected_width)), + ) + + +def compute_global_src_and_dst_points( + top_pattern: DetectedPattern, bottom_pattern: DetectedPattern +) -> tuple[NDArray[np.floating], NDArray[np.floating]]: + # top and bottom fiducials distances + # computed with the median take on 
multiple images + Y_DIST: int = 23242 + + spacing = theorical_spacing_from_pattern(top_pattern.pattern) + if theorical_spacing_from_pattern(bottom_pattern.pattern) != spacing: + raise ValueError("Both pattern should have the same distribution (mid or sparse).") + + mid_actual = float((np.median(top_pattern.points[:, 1]) + np.median(bottom_pattern.points[:, 1])) / 2) + top_y_dst = mid_actual - Y_DIST / 2 + bottom_y_dst = mid_actual + Y_DIST / 2 + + top_src, top_dst = compute_src_and_dst_points(top_pattern.points, spacing, top_y_dst) + bot_src, bot_dst = compute_src_and_dst_points(bottom_pattern.points, spacing, bottom_y_dst) + + return np.vstack((top_src, bot_src)), np.vstack((top_dst, bot_dst)) + + +def compute_expected_fiducial_count(pattern: PATTERNS, expected_width: int) -> int: + """Return expected number of fiducials across an image of expected_width pixels.""" + spacing = theorical_spacing_from_pattern(pattern) + return round(expected_width / spacing) + 1 + + +######################################################################################## +# UTILS FUNCTIONS +######################################################################################## + + +def centers_xy_from_boxes(boxes: NDArray[np.floating] | NDArray[np.integer]) -> NDArray[np.floating]: + """Return (N, 2) array of box centers from (N, 4) ``[x, y, w, h]`` boxes.""" + return boxes[:, :2] + boxes[:, 2:] * 0.5 + + +def coefficient_of_variation_score(x: NDArray[np.floating]) -> float: + mean = np.mean(x) + + if mean == 0: + return 0.0 + + coefficient_of_variation = np.std(x) / mean + return float(1.0 / (1.0 + coefficient_of_variation)) + + +def compute_src_and_dst_points( + points: NDArray[np.floating], true_distance: float, y_dst: float | None = None +) -> tuple[NDArray[np.floating], NDArray[np.floating]]: + # compute y dst with median if not provideed + y_dst = y_dst or float(np.median(points[:, 1])) + + # compute sorted spacing + sorted_points = points[np.argsort(points[:, 0])] + spacing 
= np.hypot(np.diff(sorted_points[:, 0]), np.diff(sorted_points[:, 1])) + + # compute the median spacing with a filtering to remove gap between segement + median_spacing = np.median(spacing[spacing < 1.5 * true_distance]) + + idx = np.concatenate(([0], np.round(spacing / median_spacing))) + idx = np.cumulative_sum(idx) + dst_x = sorted_points[0, 0] + idx * true_distance + + dst_points = np.column_stack([dst_x, np.full_like(dst_x, y_dst)]) + + return sorted_points, dst_points diff --git a/src/hipp/kh9pc/image_mosaic.py b/src/hipp/kh9pc/image_mosaic.py deleted file mode 100644 index 214dbf3..0000000 --- a/src/hipp/kh9pc/image_mosaic.py +++ /dev/null @@ -1,637 +0,0 @@ -""" -Copyright (c) 2025 HIPP developers -Description: Functions to recreate in python the image_mosaic function from ASP -""" - -import os - -import cv2 -import numpy as np -import rasterio -import rasterio.transform -import rasterio.warp -from rasterio.windows import Window -from skimage.measure import ransac -from skimage.transform import EuclideanTransform -from tqdm import tqdm - -from hipp.math import transform_coord - - -#################################################################################################################################### -# MAIN FUNCTIONS -#################################################################################################################################### -def compute_sequential_alignment( - images_path: list[str], - overlap_width: int = 3000, - bloc_height: int = 256, - nfeature_per_block: int = 500, - ransac_max_trials: int = 1000, - ransac_residual_threshold: float = 3, - verbose: bool = True, -) -> dict[str, cv2.typing.MatLike]: - """ - Compute sequential geometric alignment transformations between a list of images. - - This function aligns each image in the list to its subsequent image by: - - Extracting matched keypoints from overlapping image regions in blocks. - - Transforming matched points to a global coordinate system. 
- - Estimating a robust Euclidean transformation using RANSAC to handle outliers. - - Accumulating transformation matrices relative to the first image. - - Args: - images_path (list[str]): List of file paths to input images to align sequentially. - overlap_width (int, optional): Width in pixels of the overlapping area between consecutive images used for keypoint matching. Defaults to 3000. - bloc_height (int, optional): Height of blocks (in pixels) to split the overlap region for local keypoint detection. Defaults to 256. - nfeature_per_block (int, optional): Number of ORB features to detect per block. Defaults to 500. - ransac_max_trials (int, optional): Maximum number of RANSAC iterations for robust transformation estimation. Defaults to 1000. - ransac_residual_threshold (float, optional): Maximum allowed residual to classify a point as an inlier in RANSAC. Defaults to 3. - verbose (bool, optional): Whether to print progress and debug information. Defaults to True. - - Returns: - dict[str, cv2.typing.MatLike]: Dictionary mapping each image path to its cumulative 3x3 homogenous transformation matrix. - The first image is assigned the identity matrix. 
- """ - # Initialize dictionary with identity transformation for the first image (reference) - transformation_matrixs = {images_path[0]: np.eye(3)} - - # Iterate through consecutive image pairs to compute relative transformations - for i in range(len(images_path) - 1): - if verbose: - print(f"Matching '{images_path[i]}' with '{images_path[i + 1]}' ...") - - # Extract globally matched keypoints from the overlap area between images - points_a, points_b = extract_global_matches_from_overlap( - images_path[i], images_path[i + 1], overlap_width, bloc_height, nfeature_per_block - ) - # Transform matched points from image A to global coordinate system using accumulated transformation - points_a_tf = [transform_coord(coord, transformation_matrixs[images_path[i]]) for coord in points_a] - - # Estimate robust Euclidean transformation using RANSAC to filter out outliers - model_robust, inliers = ransac( - (np.array(points_b, dtype=np.float32), np.array(points_a_tf, dtype=np.float32)), - EuclideanTransform, - min_samples=3, - residual_threshold=ransac_residual_threshold, - max_trials=ransac_max_trials, - ) - if verbose: - print(f"\t- Number of matching points before versus after ransac : {np.sum(inliers)}/{len(points_a)}") - - # Store cumulative transformation for the next image in the sequence - transformation_matrixs[images_path[i + 1]] = model_robust.params - return transformation_matrixs # type: ignore[return-value] - - -def mosaic_images( - transformation_matrixs_dict: dict[str, cv2.typing.MatLike], - output_tif: str, - max_worker: int = 5, - verbose: bool = True, - resampling: int = rasterio.warp.Resampling.cubic, -) -> None: - """ - Mosaic multiple images into a single output GeoTIFF using given pixel transformation matrices. - - This function warps and mosaics a collection of images into one large raster. The pixel - transformation matrices are applied as inverse affine transforms to align each image - into the output raster. 
The mosaicing process is block-based and supports multithreading. - - Parameters - ---------- - transformation_matrixs_dict : dict[str, cv2.typing.MatLike] - Dictionary mapping image file paths to their corresponding 2D transformation matrices - (affine-like, 3x3 matrices). - output_tif : str - Path where the final mosaiced GeoTIFF will be saved. - max_worker : int, optional - Number of worker threads to use during block reprojecting (default is 5). - verbose : bool, optional - If True, prints progress information (default is True). - resampling : int, optional - Resampling algorithm from `rasterio.warp.Resampling` to use for reprojection - (default is `cubic`). - - Returns - ------- - None - The mosaiced raster is written to `output_tif`. - - Notes - ----- - - The transforms applied here are purely pixel-based and ignore CRS/georeferencing information. - - The output raster is compressed (LZW), tiled (256x256), and saved as BigTIFF if required. - - At the end of processing, the transform metadata is reset to identity to avoid - incorrect geospatial metadata. 
- """ - # Get the last image path to determine the output size and metadata - last_image_path = next(reversed(transformation_matrixs_dict)) - - # Open the last image to base the output profile on - with rasterio.open(last_image_path) as src: - # Calculate output width by adding translation component from transformation matrix - output_width = src.width + int(transformation_matrixs_dict[last_image_path][0, 2]) - output_height = src.height - - # define the output image profile based on the previously computed width and height - # and add some tif optimization - profile = { - "width": output_width, - "height": output_height, - "compress": "lzw", - "driver": "GTiff", - "BIGTIFF": "YES", - "count": 1, - "tiled": True, - "blockxsize": 256, - "blockysize": 256, - "nodata": 0, - "dtype": "uint8", - } - if verbose: - print("Start the mosaicing...") - - os.makedirs(os.path.dirname(output_tif) or ".", exist_ok=True) - - with rasterio.open(output_tif, "w", **profile) as dst: - for i, (filepath, matrix) in enumerate(transformation_matrixs_dict.items()): - if verbose: - print(f"Warping {filepath} with : \n{matrix}") - - # open the coresponding image - with rasterio.open(filepath) as src: - # here we set the dst transform to the inverse of the given matrix - # Note : the transform here is juste for pixels not for geographic stuffs - dst.transform = ~rasterio.Affine(*matrix.flatten()) - - # use the reproject of rasterio with NO_GEOTRANSFORM option to specify we don't care about CRS. - # here we use this method cause it support big images with block processing and work with multi-threads. 
- rasterio.warp.reproject( - source=rasterio.band(src, 1), - destination=rasterio.band(dst, 1), - resampling=resampling, - num_threads=max_worker, - SRC_METHOD="NO_GEOTRANSFORM", # important to avoid error of CRS - init_dest_nodata=False, # important to avoid rewriting all the images with no data - ) - - # we remove the transform metadata to avoid let a wrong transform - dst.transform = rasterio.Affine.identity() - - -def mosaic_images_streaming( - transformation_matrixs_dict: dict[str, cv2.typing.MatLike], - output_tif: str, - clipping: int = 30, - max_workers: int = 5, - verbose: bool = True, -) -> None: - # Get the last image path to determine the output size and metadata - last_image_path = next(reversed(transformation_matrixs_dict)) - - root, ext = os.path.splitext(output_tif) - tmp_tif_file = f"{root}.tmp{ext}" - - # Open the last image to base the output profile on - with rasterio.open(last_image_path) as src: - # Calculate output width by adding translation component from transformation matrix - width = src.width + int(np.round(transformation_matrixs_dict[last_image_path][0, 2])) - height = src.height - - profile = { - "width": width, - "height": height, - "transform": rasterio.Affine.identity(), - "compress": "lzw", - "driver": "GTiff", - "BIGTIFF": "YES", - "count": 1, - "tiled": True, - "blockxsize": 256, - "blockysize": 256, - "nodata": 0, - "dtype": "uint8", - } - # Create the first tmp raster - nodata = 0 - if verbose: - print("Start the mosaicing...") - - with rasterio.open(output_tif, "w", **profile) as dst: - profile.update({"compress": None, "tiled": False}) - with rasterio.open(tmp_tif_file, "w+", **profile) as tmp_raster: - for i, (image_path, matrix) in enumerate(transformation_matrixs_dict.items()): - with rasterio.open(image_path) as src: - dst_transform = rasterio.Affine(*matrix.flatten()[:6]) - - if verbose: - print(f"Warping {image_path} with : \n{dst_transform}") - - rasterio.warp.reproject( - source=src.read(1), - 
destination=rasterio.band(tmp_raster, 1), - src_transform=dst_transform, - src_crs=rasterio.CRS.from_epsg(3857), - dst_crs=rasterio.CRS.from_epsg(3857), - resampling=rasterio.warp.Resampling.cubic, - src_nodata=nodata, - dst_nodata=nodata, - num_threads=max_workers, - ) - x_start = matrix[0, 2] - if i == 0: - window = Window(x_start, 0, src.width, dst.height) - else: - window_width = min(src.width - clipping, dst.width - (x_start + clipping)) - window = Window(x_start + clipping, 0, window_width, dst.height) - dst.write(tmp_raster.read(1, window=window), 1, window=window) - os.remove(tmp_tif_file) - - -def mosaic_images_buffered( - transformation_matrixs_dict: dict[str, cv2.typing.MatLike], - output_tif: str, - clipping: int = 50, - max_workers: int = 5, - qc_output: str | None = None, - verbose: bool = True, -) -> None: - """ - Create a mosaic image by warping and stitching multiple input images using provided transformation matrices. - - The function reads images from `transformation_matrixs_dict`, applies geometric transformations (affine warps) - defined by corresponding matrices, and writes the combined result to a single output GeoTIFF file. - Optionally, it generates quality control (QC) difference images between overlapping mosaicked parts. - - Parameters - ---------- - transformation_matrixs_dict : dict[str, cv2.typing.MatLike] - A dictionary mapping input image file paths (strings) to their associated 2x3 affine transformation matrices - (numpy arrays or OpenCV Mat-like) that describe how each image should be warped into the mosaic coordinate space. - The matrices are assumed to be affine transforms in pixel space. - - output_tif : str - File path for the output mosaic GeoTIFF file. - - clipping : int, optional - Number of pixels to clip on the left edge of images after the first one, to avoid visual artifacts due to warping. - Default is 30 pixels. - - max_workers : int, optional - Number of parallel threads to use for rasterio.warp.reproject calls. 
Defaults to 5. - - qc_output : str or None, optional - Directory path where quality control (QC) difference images will be saved. - If None (default), no QC images are generated. - - verbose : bool, optional - If True (default), print progress messages during processing. - - Returns - ------- - None - The function writes the mosaic directly to `output_tif` and optionally QC images to `qc_output`. - - Notes - ----- - - The output mosaic raster has a size determined by the last image's dimensions plus the horizontal translation - offset of the last transformation matrix. - - The rasterio profile for the output uses LZW compression, tiling, and a block size of 256x256 pixels. - - Images are warped using rasterio.warp.reproject with cubic resampling. - - The coordinate reference systems (CRS) are not used here because warping is done in pixel space only. - - QC difference images highlight absolute pixel differences in overlapping regions of consecutive warped images. - - The function manages memory by writing only windows corresponding to each warped image fragment. 
- """ - # Get the last image path to determine the output size and metadata - last_image_path = next(reversed(transformation_matrixs_dict)) - - # Open the last image to base the output profile on - with rasterio.open(last_image_path) as src: - # Calculate output width by adding translation component from transformation matrix - width = src.width + int(transformation_matrixs_dict[last_image_path][0, 2]) - height = src.height - - # define the output image profile based on the previously computed width and height - # and add some tif optimization - profile = { - "width": width, - "height": height, - "transform": rasterio.Affine.identity(), - "compress": "lzw", - "driver": "GTiff", - "BIGTIFF": "YES", - "count": 1, - "tiled": True, - "blockxsize": 256, - "blockysize": 256, - "nodata": 0, - "dtype": "uint8", - } - - if verbose: - print("Start the mosaicing...") - - # define the writing mode depend of the quality control - # cause without qc we don't need to read the output image - mode = "w+" if qc_output else "w" - - if os.path.dirname(output_tif): - os.makedirs(os.path.dirname(output_tif), exist_ok=True) - # open with the good mode and the good profile the output raster - with rasterio.open(output_tif, mode, **profile) as dst: - # create an empty numpy array of the final size where all warped part will be write - dst_array = np.zeros((height, width), dtype=np.uint8) - - # the cursor is used only for the qc part to get the end position in the final image - # of the previous warped fragment - cursor = 0 - - # loop in transformation matrixs - for i, (image_path, matrix) in enumerate(transformation_matrixs_dict.items()): - # open the corresponding image - with rasterio.open(image_path) as src: - if verbose: - print(f"Warping {image_path} with : \n{matrix}") - - # warp the image fragment with it's corresponding matrix in the dst_array - rasterio.warp.reproject( - source=src.read(1), - destination=dst_array, - src_transform=rasterio.Affine(*matrix.flatten()), - 
dst_transform=rasterio.Affine.identity(), - src_crs=rasterio.CRS.from_epsg(3857), - dst_crs=rasterio.CRS.from_epsg(3857), - resampling=rasterio.warp.Resampling.cubic, - src_nodata=profile["nodata"], - dst_nodata=profile["nodata"], - num_threads=max_workers, - ) - - # calculate the x_start and window_width based on the x translation and apply clipping to the left - # to avoid warping artefacts - if i == 0: - x_start = 0 - window_width = src.width - else: - x_start = int(matrix[0, 2]) + clipping - window_width = min(src.width - clipping, dst.width - x_start) - - # code block for generate all qc images - if i != 0 and qc_output: - overlap_width = cursor - x_start - ref_left_part = dst.read(1, window=Window(x_start, 0, overlap_width, dst.height)) - right_part = dst_array[:, x_start : x_start + overlap_width] - - valid_mask = ref_left_part != profile["nodata"] - - abs_diff = np.zeros_like(ref_left_part, dtype=np.uint8) - abs_diff[valid_mask] = np.abs( - ref_left_part[valid_mask].astype(np.int16) - right_part[valid_mask].astype(np.int16) - ).astype(np.uint8) - - abs_diff_file = os.path.join( - qc_output, f"diff_{chr(ord('a') + i - 1)}_{chr(ord('a') + i)}_{os.path.basename(output_tif)}" - ) - os.makedirs(qc_output, exist_ok=True) - cv2.imwrite(abs_diff_file, abs_diff) - - # write the concern window of dst_array into the final output raster - dst.write( - dst_array[:, x_start : x_start + window_width], - 1, - window=Window(x_start, 0, window_width, dst.height), - ) - cursor = x_start + window_width - - -#################################################################################################################################### -# PRIVATE FUNCTIONS -#################################################################################################################################### - - -def warp_tif_blockwise_to_dst( - input_path: str, - dst: rasterio.io.DatasetWriter, - transformation_matrix: cv2.typing.MatLike, - block_size: int = 256, - interpolation: int = 
cv2.INTER_CUBIC, - overlap: int = 8, - pbar: bool = True, - pbar_desc: str = "Warping blocks", -) -> None: - """ - Applies a geometric transformation (warping) to a raster image in a memory-efficient, - block-wise manner, with overlap between blocks to avoid seam artifacts and safe - in-place writing to the output dataset to prevent overwriting with invalid pixels. - - The function processes the output image in fixed-size blocks, computing the - corresponding source pixel coordinates for each block using the inverse of the - provided transformation matrix. Each block is extended by an 'overlap' margin - on all sides to ensure smooth transitions between adjacent blocks when warping. - - The warped pixels are then remapped from the source to the destination block using - OpenCV's `remap` function with the specified interpolation method. To handle edges - properly and avoid invalid pixels overwriting valid data, the function reads the - existing destination data, and combines it with the warped block pixels using a mask. - - Notes - ----- - - The function reads blocks from the source raster with a margin ('overlap') to avoid - artifacts at block edges after warping. - - The inverse transformation matrix is used to compute source pixel coordinates for - each destination block's extended region. - - Pixels mapped outside the source image bounds are filled with the source nodata value. - - The function combines newly warped pixels with existing destination pixels to avoid - overwriting valid data with nodata values. - - This method is designed to be memory efficient for processing large rasters that - cannot fit entirely into memory. - """ - out_width, out_height = dst.width, dst.height - - # Compute the inverse transformation matrix to map output coordinates back to source image coordinates. 
- M_inv = np.linalg.inv( - transformation_matrix # type: ignore[arg-type] - )[0:2, :] # Extract first two rows for 2D affine transform usable by cv2.remap - - with rasterio.open(input_path) as src: - src_dtype = src.dtypes[0] # Data type of source raster, uint8 for grayscale - src_nodata = src.nodata if src.nodata is not None else 0 # NoData value, fallback to 0 if undefined - - # Generate a list of blocks covering the entire output raster by stepping through width and height - blocks = [(x, y) for y in range(0, out_height, block_size) for x in range(0, out_width, block_size)] - # Wrap blocks with a progress bar if enabled - iterator = tqdm(blocks, desc=pbar_desc, unit="block") if pbar else blocks - - # Process each output block independently to limit memory use - for x_out, y_out in iterator: - # Compute block size in x and y (handle edge blocks smaller than block_size) - w = min(block_size, out_width - x_out) - h = min(block_size, out_height - y_out) - - # Extend the block boundaries by 'overlap' pixels on all sides, clipped to image bounds, - # to avoid edge artifacts when warping and enable smooth blending between blocks - x_ext = max(0, x_out - overlap) - y_ext = max(0, y_out - overlap) - w_ext = min(out_width - x_ext, w + 2 * overlap) - h_ext = min(out_height - y_ext, h + 2 * overlap) - - # Create meshgrid of pixel coordinates in the extended output block area - dst_grid_x, dst_grid_y = np.meshgrid(np.arange(x_ext, x_ext + w_ext), np.arange(y_ext, y_ext + h_ext)) - - # Prepare homogeneous coordinates (x, y, 1) for transformation - dst_pts = np.stack([dst_grid_x.ravel(), dst_grid_y.ravel(), np.ones(dst_grid_x.size)], axis=0) - - # Map output coordinates back to source image coordinates using inverse transform - src_pts = (M_inv @ dst_pts).T - - # Reshape source coordinates to 2D grids matching the extended block size - x_src = src_pts[:, 0].reshape(h_ext, w_ext).astype(np.float32) - y_src = src_pts[:, 1].reshape(h_ext, w_ext).astype(np.float32) - - # Determine 
the bounding box of the source pixels needed to sample for the current block - x_min = int(np.floor(x_src.min())) - x_max = int(np.ceil(x_src.max())) - y_min = int(np.floor(y_src.min())) - y_max = int(np.ceil(y_src.max())) - - # Skip block if it lies completely outside the source image boundaries - if x_max < 0 or y_max < 0 or x_min >= src.width or y_min >= src.height: - continue # Tout est hors champ - - # Clip the read window to source image bounds with a small margin of 2 pixels - x_min_clip = max(x_min - 2, 0) - y_min_clip = max(y_min - 2, 0) - x_max_clip = min(x_max + 2, src.width - 1) - y_max_clip = min(y_max + 2, src.height - 1) - - # Define a rasterio Window object to read the required block from source image - read_window = Window(x_min_clip, y_min_clip, x_max_clip - x_min_clip + 1, y_max_clip - y_min_clip + 1) - - # Read the source block pixels with boundless=True to allow reading outside boundaries if needed, - # filling missing values with the nodata value. - src_block = src.read(1, window=read_window, boundless=True, fill_value=src_nodata) - - # Adjust source coordinates to be relative to the read window top-left corner - x_src_shifted = x_src - x_min_clip - y_src_shifted = y_src - y_min_clip - - # Warp (remap) the source block pixels to the destination coordinate grid - warped_ext = cv2.remap( - src_block, - x_src_shifted, - y_src_shifted, - interpolation=interpolation, - borderMode=cv2.BORDER_CONSTANT, # Use constant border mode to fill out-of-bounds with nodata - borderValue=src_nodata, # type: ignore[arg-type] - ) - - # Define core block area inside the extended block by removing overlap margins on edges - x_start = overlap if x_out - overlap >= 0 else 0 - y_start = overlap if y_out - overlap >= 0 else 0 - x_end = x_start + w - y_end = y_start + h - - # Extract the central (non-overlapping) region of the warped block to avoid duplication during writing - warped_core = warped_ext[y_start:y_end, x_start:x_end] - - # Read the existing pixels from 
destination at the current block location - existing = dst.read(1, window=Window(x_out, y_out, w, h)) - - # Create mask where newly warped pixels are valid (not nodata) - mask_new = warped_core != src_nodata - - # Combine the warped pixels with existing destination pixels, - # giving priority to valid new warped pixels to avoid overwriting with black/empty pixels - combined = np.where(mask_new, warped_core, existing) - - # Write the combined result back to the destination raster at the current block window - dst.write(combined.astype(src_dtype), 1, window=Window(x_out, y_out, w, h)) - - -def extract_global_matches_from_overlap( - image_a_path: str, - image_b_path: str, - overlap_width: int = 3000, - bloc_height: int = 1024, - nfeature_per_block: int = 500, -) -> tuple[list[tuple[float, float]], list[tuple[float, float]]]: - """ - Extracts matched keypoints between the overlapping edge of two georeferenced images, - by processing them in horizontal blocks. Assumes that image A is on the left and image B is on the right. - - This function is useful to compute global tie points between adjacent raster strips (e.g., satellite or aerial images). - """ - points_a, points_b = [], [] - - with rasterio.open(image_a_path) as src_a, rasterio.open(image_b_path) as src_b: - width_a = src_a.width - height_a = src_a.height - height_b = src_b.height - - # Ensure both images have the same height for block-wise processing - assert height_a == height_b, "Both images must have the same height for block-wise matching." 
- - # Iterate over horizontal blocks - for i in range(0, src_a.height, bloc_height): - current_block_height = min(bloc_height, height_a - i) - - # Define overlapping windows: - # - image A: right edge - # - image B: left edge - window_a = Window( - col_off=width_a - overlap_width, row_off=i, width=overlap_width, height=current_block_height - ) - window_b = Window(col_off=0, row_off=i, width=overlap_width, height=current_block_height) - - # Read corresponding blocks - block_a = src_a.read(1, window=window_a) - block_b = src_b.read(1, window=window_b) - - # Match keypoints using ORB - pts_a, pts_b = match_orb_keypoints(block_a, block_b, nfeatures=nfeature_per_block) - - # Reproject local coordinates to global coordinates - pts_a_global = [(pt[0] + (width_a - overlap_width), pt[1] + i) for pt in pts_a] - pts_b_global = [(pt[0], pt[1] + i) for pt in pts_b] - - # Accumulate results - points_a.extend(pts_a_global) - points_b.extend(pts_b_global) - - return points_a, points_b - - -def match_orb_keypoints( - image_a: cv2.typing.MatLike, image_b: cv2.typing.MatLike, nfeatures: int = 500 -) -> tuple[list[tuple[float, float]], list[tuple[float, float]]]: - """ - Detect ORB keypoints and return matched coordinates between two grayscale image. - Returns - ------- - pts_a : list of tuple[float, float] - Matched keypoint coordinates from image A. - pts_b : list of tuple[float, float] - Matched keypoint coordinates from image B. 
- """ - # Initialize ORB - orb = cv2.ORB_create(nfeatures=nfeatures) # type: ignore[attr-defined] - - # Detect and compute descriptors - kp_a, des_a = orb.detectAndCompute(image_a, None) - kp_b, des_b = orb.detectAndCompute(image_b, None) - - if des_a is None or des_b is None: - return [], [] - - # Matcher - bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) - matches = bf.match(des_a, des_b) - - # Sort by match distance - matches = sorted(matches, key=lambda x: x.distance) - - # Extract matched coordinates - pts_a = [kp_a[m.queryIdx].pt for m in matches] - pts_b = [kp_b[m.trainIdx].pt for m in matches] - - return pts_a, pts_b diff --git a/src/hipp/kh9pc/kh9_image_spec.py b/src/hipp/kh9pc/kh9_image_spec.py new file mode 100644 index 0000000..841821b --- /dev/null +++ b/src/hipp/kh9pc/kh9_image_spec.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass +from typing import Literal +from pathlib import Path +import re +import rasterio + + +from hipp.kh9pc.fiducial_patterns import PATTERNS + +IMAGE_WIDTHS_PX: list[int] = [114082, 228165, 342247, 456329] +IMAGE_HEIGHT_PX: int = 21771 + + +@dataclass +class KH9ImageSpec: + expected_size: tuple[int, int] + collimation_line: bool + fiducial_type: Literal["disk", "wagon_wheel"] + top_fiducial_patterns: tuple[PATTERNS, PATTERNS] + bottom_fiducial_patterns: tuple[PATTERNS, PATTERNS] + + @classmethod + def from_raster_filepath(cls, filepath: str | Path) -> "KH9ImageSpec": + mission = KH9ImageSpec.mission_from_filepath(filepath) + expected_size = KH9ImageSpec.expected_size_from_file(filepath) + collimation_line = KH9ImageSpec.collimation_from_mission(mission) + fiducial_type = KH9ImageSpec.fiducial_type_from_mission(mission) + top_fiducial_patterns = KH9ImageSpec.top_fiducial_patterns_from_mission(mission) + bottom_fiducial_patterns = KH9ImageSpec.bottom_fiducial_patterns_from_mission(mission) + + return cls(expected_size, collimation_line, fiducial_type, top_fiducial_patterns, bottom_fiducial_patterns) + + @staticmethod + 
def mission_from_filepath(filepath: str | Path) -> int: + pattern = re.compile(r"^(D3C)(\d{4})-(\d)(\d{5})([FA])(\d{3})$") + stem = Path(filepath).stem + m = pattern.match(stem) + if m is None: + raise ValueError( + f"Cannot parse KH-9 image ID from {filepath!r}. Expected D3C{{mission}}-{{n}}{{roll}}{{F|A}}{{frame}}." + ) + mission = int(m.group(2)) + return mission + + @staticmethod + def collimation_from_mission(mission: int) -> bool: + if mission < 1201 or mission > 1219: + raise ValueError("Unrecgnized mission") + collimation_line = mission >= 1206 + return collimation_line + + @staticmethod + def fiducial_type_from_mission(mission: int) -> Literal["disk", "wagon_wheel"]: + if mission < 1201 or mission > 1219: + raise ValueError("Unrecgnized mission") + fiducial_type: Literal["disk", "wagon_wheel"] = "disk" if mission <= 1213 else "wagon_wheel" + return fiducial_type + + @staticmethod + def top_fiducial_patterns_from_mission(mission: int) -> tuple[PATTERNS, PATTERNS]: + if mission < 1201 or mission > 1219: + raise ValueError("Unrecgnized mission") + top_fiducial_patterns: tuple[PATTERNS, PATTERNS] + if mission <= 1213: + top_fiducial_patterns = ("regulare_sparse", "serialized_time_word") + elif mission <= 1217: + top_fiducial_patterns = ("segmented_mid", "serialized_time_word") + else: + top_fiducial_patterns = ("segmented_mid", "segmented_dense") + return top_fiducial_patterns + + @staticmethod + def bottom_fiducial_patterns_from_mission(mission: int) -> tuple[PATTERNS, PATTERNS]: + if mission < 1201 or mission > 1219: + raise ValueError("Unrecgnized mission") + + bottom_fiducial_patterns: tuple[PATTERNS, PATTERNS] + if mission <= 1213: + bottom_fiducial_patterns = ("regulare_sparse", "regular_dense") + else: + bottom_fiducial_patterns = ("regulare_mid", "regular_dense") + + return bottom_fiducial_patterns + + @staticmethod + def expected_size_from_file(filepath: str | Path) -> tuple[int, int]: + with rasterio.open(filepath) as src: + width = src.width + + 
expected_widths_px = sorted(IMAGE_WIDTHS_PX) + candidates = [w for w in expected_widths_px if w <= width] + if not candidates: + raise ValueError(f"Image width {width} is smaller than all known expected widths.") + return (candidates[-1], IMAGE_HEIGHT_PX) diff --git a/src/hipp/kh9pc/mosaic.py b/src/hipp/kh9pc/mosaic.py new file mode 100644 index 0000000..dc55205 --- /dev/null +++ b/src/hipp/kh9pc/mosaic.py @@ -0,0 +1,390 @@ +""" +Copyright (c) 2025 HIPP developers +Description: Functions to recreate in python the image_mosaic function from ASP +""" + +import logging +import os +import subprocess +from collections.abc import Sequence +from dataclasses import dataclass +from glob import glob +from pathlib import Path + +import cv2 +import numpy as np +import rasterio +from rasterio.vrt import WarpedVRT +from rasterio.warp import Resampling +from rasterio.windows import Window +from skimage.measure import ransac +from skimage.transform import EuclideanTransform + +from hipp.image import LogProgressBar + + +@dataclass +class ImageAlignment: + """Alignment result for a single image in a sequential alignment chain. + + Attributes + ---------- + image_path : Path + Path to the image file. + relative_transform : np.ndarray + 3x3 homogeneous transformation matrix relative to the previous image + (identity for the first/reference image). + absolute_transform : np.ndarray + 3x3 homogeneous transformation matrix in the global/mosaic coordinate system, + accumulated from the reference image. + n_matches : int + Total number of ORB keypoint matches found before RANSAC filtering + (0 for the reference image). + n_inliers : int + Number of inlier matches kept after RANSAC filtering + (0 for the reference image). 
+ """ + + image_path: Path + relative_transform: np.ndarray + absolute_transform: np.ndarray + n_matches: int + n_inliers: int + + +logger = logging.getLogger(__name__) + + +#################################################################################################################################### +# MAIN FUNCTIONS +#################################################################################################################################### +def image_mosaic( + image_paths: Sequence[str | Path], + output_tif: str | Path, + overwrite: bool = False, + resampling: int = Resampling.cubic, + overlap_width: int = 3000, + bloc_height: int = 512, + nfeature_per_block: int = 500, + ransac_max_trials: int = 1000, + ransac_residual_threshold: float = 3.0, +) -> None: + # standardize paths + output_tif = Path(output_tif) + + # manage overwrite + if output_tif.exists() and not overwrite: + logger.info("Skipping image_mosaic: %s (already exists, overwrite=False)", str(output_tif)) + return + + alignments = compute_sequential_alignments( + image_paths, + overlap_width=overlap_width, + bloc_height=bloc_height, + nfeature_per_block=nfeature_per_block, + ransac_max_trials=ransac_max_trials, + ransac_residual_threshold=ransac_residual_threshold, + ) + + write_mosaic(alignments, output_tif, resampling=resampling) + + +def compute_sequential_alignments( + image_paths: Sequence[str | Path], + overlap_width: int = 3000, + bloc_height: int = 512, + nfeature_per_block: int = 500, + ransac_max_trials: int = 1000, + ransac_residual_threshold: float = 3.0, +) -> list[ImageAlignment]: + """Compute sequential alignments between images. + + Detects ORB keypoints between consecutive images, estimates RANSAC Euclidean + transforms, and accumulates absolute transformations from the reference image. 
+ """ + # standardize path + paths: list[Path] = [Path(f) for f in image_paths] + + identity = np.eye(3) + alignments: list[ImageAlignment] = [ + ImageAlignment( + image_path=paths[0], + relative_transform=identity, + absolute_transform=identity, + n_matches=0, + n_inliers=0, + ) + ] + + for i in range(len(paths) - 1): + logger.info("Matching '%s' with '%s'", str(paths[i]), str(paths[i + 1])) + + points_a, points_b = _extract_global_matches_from_overlap( + paths[i], + paths[i + 1], + overlap_width, + bloc_height, + nfeature_per_block, + ) + + model_robust, inliers = ransac( + (np.array(points_b, dtype=np.float32), np.array(points_a, dtype=np.float32)), + EuclideanTransform, + min_samples=3, + residual_threshold=ransac_residual_threshold, + max_trials=ransac_max_trials, + ) + + n_inliers = int(np.sum(inliers)) + logger.info("Inliers after RANSAC: %d/%d", n_inliers, len(points_a)) + + relative_transform: np.ndarray = model_robust.params + absolute_transform: np.ndarray = alignments[i].absolute_transform @ relative_transform + + alignments.append( + ImageAlignment( + image_path=Path(paths[i + 1]), + relative_transform=relative_transform, + absolute_transform=absolute_transform, + n_matches=len(points_a), + n_inliers=n_inliers, + ) + ) + + return alignments + + +def write_mosaic( + alignments: list[ImageAlignment], + output_tif: str | Path, + resampling: int = Resampling.cubic, +) -> None: + """Warp and merge all aligned images into a single output GeoTIFF. + + Images are warped into the output pixel space using WarpedVRT and merged + block-by-block. Valid pixels from later images do not overwrite valid pixels + already written from earlier images. + + If any image extends above or to the left of the first image (negative coordinates + after transformation), an offset is automatically applied to all transforms so that + the full mosaic fits within the canvas without clipping. 
+ + """ + # normalize path + output_tif = Path(output_tif) + output_tif.parent.mkdir(exist_ok=True, parents=True) + + output_width, output_height, offset_x, offset_y = _compute_canvas(alignments) + + T_offset = np.array([[1, 0, -offset_x], [0, 1, -offset_y], [0, 0, 1]], dtype=float) + + fake_crs = rasterio.CRS.from_epsg(3857) + dst_transform = rasterio.Affine.identity() + + profile = { + "width": output_width, + "height": output_height, + "compress": "lzw", + "driver": "GTiff", + "BIGTIFF": "YES", + "count": 1, + "tiled": True, + "blockxsize": 256, + "blockysize": 256, + "dtype": "uint8", + } + + n_blocks = (output_width // 256 + 1) * (output_height // 256 + 1) + + logger.info("Mosaicing %d images → %s (%d×%d px)", len(alignments), str(output_tif), output_width, output_height) + + with rasterio.open(output_tif, "w+", **profile) as dst: + for i, alignment in enumerate(alignments): + logger.info("[%d/%d] %s", i + 1, len(alignments), alignment.image_path.name) + pbar = LogProgressBar(f"mosaicing {alignment.image_path.name}", n_blocks, logger) + + adjusted_transform = T_offset @ alignment.absolute_transform + + with rasterio.open(alignment.image_path) as src: + with WarpedVRT( + src, + src_transform=rasterio.Affine(*adjusted_transform.flatten()[:6]), + src_crs=fake_crs, + dst_crs=fake_crs, + resampling=resampling, + width=output_width, + height=output_height, + transform=dst_transform, + ) as vrt: + for block_idx, (_, window) in enumerate(dst.block_windows(1)): + pbar.update(block_idx) + warped = vrt.read(1, window=window) + mask = warped != 0 + if not mask.any(): + continue + existing = dst.read(1, window=window) + dst.write(np.where(mask, warped, existing), 1, window=window) + pbar.close() + + logger.info("Mosaic written to %s", str(output_tif)) + + +#################################################################################################################################### +# STANDALONE FUNCTIONS 
+#################################################################################################################################### + + +def image_mosaic_asp( + image_paths: list[str | Path], + output_image_path: str | Path, + threads: int = 0, + cleanup: bool = True, + dryrun: bool = False, +) -> None: + """ + Mosaics a list of images into a single output image using the external 'image_mosaic' command. + + Parameters + ---------- + image_paths : list[str | Path] + List of paths to input image tiles. + output_image_path : str | Path + Path to the output mosaic image. + threads : int, optional + Number of threads to use for processing. Default is 0 (let the tool decide). + cleanup : bool, optional + Whether to remove temporary log and auxiliary files after processing. Default is True. + dryrun : bool, optional + If True, builds the command but does not execute it. Default is False. + """ + str_image_paths = list(sorted([str(f) for f in image_paths])) + + cmd = [ + "image_mosaic", + *str_image_paths, + "--ot", + "byte", + "--overlap-width", + "3000", + "--threads", + str(threads), + "-o", + str(output_image_path), + ] + + logger.info("Running: %s", " ".join(cmd)) + + if not dryrun: + try: + subprocess.run(cmd, check=True, capture_output=True) + except subprocess.CalledProcessError as e: + logger.error("image_mosaic_asp failed for %s: %s", output_image_path, e) + + if cleanup: + for f in glob(f"{output_image_path}-log-image_mosaic-*.txt") + glob(f"{output_image_path}.aux.xml"): + os.remove(f) + + +#################################################################################################################################### +# PRIVATE FUNCTIONS +#################################################################################################################################### + + +def _compute_canvas(alignments: list[ImageAlignment]) -> tuple[int, int, float, float]: + """Compute output canvas dimensions and the offset needed to shift all images into positive 
coordinates. + + Returns + ------- + width : int + height : int + offset_x : float + Horizontal shift to apply so the leftmost pixel lands at x=0. + offset_y : float + Vertical shift to apply so the topmost pixel lands at y=0. + """ + all_corners: list[np.ndarray] = [] + for alignment in alignments: + with rasterio.open(alignment.image_path) as src: + w, h = src.width, src.height + corners = np.array([[0, 0, 1], [w, 0, 1], [0, h, 1], [w, h, 1]], dtype=float).T + transformed = (alignment.absolute_transform @ corners)[:2] + all_corners.append(transformed) + + stacked = np.hstack(all_corners) + min_x, min_y = stacked[0].min(), stacked[1].min() + width = int(np.ceil(stacked[0].max() - min_x)) + height = int(np.ceil(stacked[1].max() - min_y)) + return width, height, min_x, min_y + + +def _extract_global_matches_from_overlap( + image_a_path: str | Path, + image_b_path: str | Path, + overlap_width: int = 3000, + bloc_height: int = 1024, + nfeature_per_block: int = 500, +) -> tuple[list[tuple[float, float]], list[tuple[float, float]]]: + """ + Extract matched keypoints between the overlapping edge of two images, in horizontal blocks. + + Assumes image A is on the left and image B is on the right. + """ + points_a, points_b = [], [] + + with rasterio.open(image_a_path) as src_a, rasterio.open(image_b_path) as src_b: + width_a = src_a.width + height_a = src_a.height + height_b = src_b.height + + if height_a != height_b: + raise ValueError( + f"Both images must have the same height for block-wise matching ({height_a} != {height_b})." 
+ ) + + for i in range(0, height_a, bloc_height): + current_block_height = min(bloc_height, height_a - i) + + window_a = Window( + col_off=width_a - overlap_width, row_off=i, width=overlap_width, height=current_block_height + ) + window_b = Window(col_off=0, row_off=i, width=overlap_width, height=current_block_height) + + block_a = src_a.read(1, window=window_a) + block_b = src_b.read(1, window=window_b) + + pts_a, pts_b = _match_orb_keypoints(block_a, block_b, nfeatures=nfeature_per_block) + + points_a.extend([(pt[0] + (width_a - overlap_width), pt[1] + i) for pt in pts_a]) + points_b.extend([(pt[0], pt[1] + i) for pt in pts_b]) + + return points_a, points_b + + +def _match_orb_keypoints( + image_a: cv2.typing.MatLike, image_b: cv2.typing.MatLike, nfeatures: int = 500 +) -> tuple[list[tuple[float, float]], list[tuple[float, float]]]: + """ + Detect ORB keypoints and return matched coordinates between two grayscale images. + + Returns + ------- + pts_a : list of tuple[float, float] + Matched keypoint coordinates from image A. + pts_b : list of tuple[float, float] + Matched keypoint coordinates from image B. 
+ """ + orb = cv2.ORB_create(nfeatures=nfeatures) # type: ignore[attr-defined] + + kp_a, des_a = orb.detectAndCompute(image_a, None) + kp_b, des_b = orb.detectAndCompute(image_b, None) + + if des_a is None or des_b is None: + return [], [] + + bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) + matches = sorted(bf.match(des_a, des_b), key=lambda x: x.distance) + + pts_a = [kp_a[m.queryIdx].pt for m in matches] + pts_b = [kp_b[m.trainIdx].pt for m in matches] + + return pts_a, pts_b diff --git a/src/hipp/kh9pc/pipeline.py b/src/hipp/kh9pc/pipeline.py new file mode 100644 index 0000000..23e417a --- /dev/null +++ b/src/hipp/kh9pc/pipeline.py @@ -0,0 +1,152 @@ +import logging +from collections.abc import Sequence +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +import shutil + +import joblib + +from hipp.image import generate_quickview +from hipp.kh9pc.mosaic import image_mosaic +from hipp.kh9pc.quality_control import save_figures, save_metrics +from hipp.kh9pc.restitution.fiducial_strategy import FiducialStrategy +from hipp.tools import extract_archive + +logger = logging.getLogger(__name__) + + +def preprocess_kh9pc( + input: str | Path | Sequence[str | Path], output_dir: str | Path, overwrite: bool = False, keep_work: bool = False +) -> None: + # standardize path + input_paths: Path | list[Path] = Path(input) if isinstance(input, (str, Path)) else [Path(f) for f in input] + output_dir = Path(output_dir) + + # extract entity id from input + entity_id = input_paths.stem if isinstance(input_paths, Path) else input_paths[0].stem.split("_")[0] + + # create all path + output_path = output_dir / "images" / f"{entity_id}.tif" + qc_dir = output_dir / "qc" + work_dir = output_dir / "work" + + # overwrite checking + if output_path.exists() and not overwrite: + logger.info("Skipping preprocess_kh9pc: %s (already exists, overwrite=False)", str(output_path)) + return + + # START PREPROCESSING + logger.info("Start preprocessing of 
%s", entity_id) + + # STEP 1 : EXTRACTION (can be skipped if the input is a list) + if isinstance(input_paths, Path): + tiles = extract_archive(input_paths, work_dir / "extracted" / entity_id, overwrite=overwrite) + else: + tiles = input_paths + + # STEP 2 : JOIN_IMAGES + joined_image = work_dir / "joined_images" / f"{entity_id}.tif" + image_mosaic(tiles, joined_image, overwrite=overwrite) + + # QC STEP : QUICKVIEW + generate_quickview( + joined_image, + qc_dir / "mosaic_qv" / f"{entity_id}.jpg", + scale_factor=0.1, + jpeg_quality=70, + overwrite=overwrite, + ) + + # STEP 3 : RESTITUTION + strategy = FiducialStrategy().fit(joined_image) + (work_dir / "joblibs").mkdir(parents=True, exist_ok=True) + joblib.dump(strategy, work_dir / "joblibs" / f"{entity_id}.joblib") + + # QC STEP : RESTITUTION + save_figures(strategy, qc_dir / "restitution") + save_metrics(strategy, qc_dir) + + strategy.transform(output_path) + + # QC STEP : QUICKVIEW (skipped if no qc dir is provideed) + generate_quickview( + output_path, + qc_dir / "final_qv" / f"{entity_id}.jpg", + scale_factor=0.1, + jpeg_quality=70, + overwrite=overwrite, + ) + + # clean the work dir + if not keep_work: + shutil.rmtree(work_dir) + + logger.info("Finish preprocessing of %s", entity_id) + + +def search_input_dir(input_dir: str | Path) -> list[Path | list[Path]]: + """Scan a directory and return inputs ready for preprocess_kh9pc, one entry per image. + + - .tgz files at root → one Path per archive + - subdirectories with .tif → one list[Path] of tiles per subdir + - .tif files at root → grouped by entity_id prefix into list[Path] + Mixed directories are supported. 
+ """ + from itertools import groupby + + input_dir = Path(input_dir) + result: list[Path | list[Path]] = [] + + result.extend(sorted(input_dir.glob("*.tgz"))) + + for subdir in sorted(d for d in input_dir.iterdir() if d.is_dir()): + tiles = sorted(subdir.glob("*.tif")) + if tiles: + result.append(tiles) + + loose = sorted(input_dir.glob("*.tif")) + if loose: + + def _entity_id(p: Path) -> str: + return p.stem.split("_")[0] + + for _, group in groupby(loose, key=_entity_id): + result.append(list(group)) + + return result + + +def batch_preprocess_kh9pc( + input_dir: str | Path, + output_dir: str | Path, + overwrite: bool = False, + keep_work: bool = False, + n_jobs: int = 1, + dry_run: bool = False, +) -> None: + """Run preprocess_kh9pc on all images found in input_dir, logging failures without stopping the batch.""" + output_dir = Path(output_dir) + + def entity_id(inp: Path | list[Path]) -> str: + return inp.stem if isinstance(inp, Path) else inp[0].stem.split("_")[0] + + inputs = search_input_dir(input_dir) + done = [inp for inp in inputs if (output_dir / "images" / f"{entity_id(inp)}.tif").exists()] + todo = [inp for inp in inputs if inp not in done] + + logger.info("Batch preprocess — %d images found in %s", len(inputs), input_dir) + logger.info(" output_dir : %s", output_dir) + logger.info(" n_jobs : %d | keep_work : %s | overwrite : %s", n_jobs, keep_work, overwrite) + logger.info(" done : %d %s", len(done), [entity_id(i) for i in done]) + logger.info(" remaining : %d %s", len(todo), [entity_id(i) for i in todo]) + + if dry_run: + return + + with ProcessPoolExecutor(max_workers=n_jobs) as executor: + futures = {executor.submit(preprocess_kh9pc, inp, output_dir, overwrite, keep_work): inp for inp in inputs} + for future in as_completed(futures): + try: + future.result() + except Exception: + logger.error("Failed to process %s", entity_id(futures[future]), exc_info=True) diff --git a/src/hipp/kh9pc/quality_control.py b/src/hipp/kh9pc/quality_control.py index 
d7aee8a..66bee68 100644 --- a/src/hipp/kh9pc/quality_control.py +++ b/src/hipp/kh9pc/quality_control.py @@ -1,216 +1,700 @@ -""" -Copyright (c) 2025 HIPP developers -Description: Functions to generate some quality control plots -""" - -import os -import re -from collections import defaultdict +import logging +from datetime import datetime from pathlib import Path -import cv2 +import pandas as pd +from typing import Any, Iterator + +import matplotlib.pyplot as plt import numpy as np -from matplotlib import pyplot as plt -from numpy.typing import NDArray -from sklearn.linear_model import RANSACRegressor - - -def process_image_mosaicing_qc( - qc_directory: str, vmax_percentile: int = 97, scale_factor: int = 8, keep: bool = True -) -> None: - scene_tiles = defaultdict(list) - - # Group image tiles by scene ID (assumed to be the prefix before the first underscore) - for filename in sorted(os.listdir(qc_directory)): - match = re.match(r"diff_[a-z]_[a-z]_(.+)\.tif", filename) - if match: - base_name = match.group(1) - scene_tiles[base_name].append(os.path.join(qc_directory, filename)) - - for base_name, paths in scene_tiles.items(): - fig, axes = plt.subplots(1, len(paths), figsize=(15, 10)) - axes = axes.flatten() - for i, path in enumerate(paths): - img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) - assert img is not None - img_resized = cv2.resize( - img, (img.shape[1] // scale_factor, img.shape[0] // scale_factor), interpolation=cv2.INTER_CUBIC - ).astype(np.uint8) - vmax = np.percentile(img_resized, vmax_percentile) - im = axes[i].imshow(img_resized, cmap="viridis", vmin=0, vmax=vmax) - axes[i].set_title(f"{chr(ord('a') + i)} - {chr(ord('a') + i + 1)}\nMAE={np.mean(img_resized):.2f}") - axes[i].axis("off") - fig.colorbar(im, ax=axes[i]) - - if not keep: - os.remove(path) - - plt.suptitle("Overlapping images absolute differences", fontsize=16) - plt.tight_layout() - plt.savefig(os.path.join(qc_directory, f"diff_{base_name}.png")) - plt.show() - - -def 
plot_src_and_dst_points( - src_points: NDArray[np.generic], - dst_points: NDArray[np.generic], - output_size: tuple[int, int], - plot: bool = True, - output_plot_path: str | Path | None = None, -) -> None: - """Plot source and destination TPS points with boundaries and legends.""" - - # scatter des points source (en bleu) et destination (en rouge) - plt.scatter(dst_points[:, 0], dst_points[:, 1], c="red", s=1, label="Destination points", alpha=0.5) - plt.scatter(src_points[:, 0], src_points[:, 1], c="blue", s=1, label="Source points") - - # Rectangle de l’output (dans l’espace dst) - rect_x = [0, output_size[0], output_size[0], 0, 0] - rect_y = [0, 0, output_size[1], output_size[1], 0] - plt.plot(rect_x, rect_y, color="green", linewidth=1, label="Output rectangle", linestyle="--") - - plt.gca().invert_yaxis() # cohérent avec les coordonnées image - plt.xlabel("x-coordinate [pixels]") - plt.ylabel("y-coordinate [pixels]") - plt.title("Source vs Destination points") - plt.legend() - plt.tight_layout() - - # Save or display - if output_plot_path is not None: - Path(output_plot_path).parent.mkdir(parents=True, exist_ok=True) - plt.savefig(output_plot_path) - if plot: - plt.show() - plt.close() - - -def plot_collimation_gradient( - collimation_lines: dict[str, RANSACRegressor], - tf_collimation_lines: dict[str, RANSACRegressor], - width: int, - nb_points: int = 100, - plot: bool = True, - output_plot_path: str | Path | None = None, -) -> None: +import rasterio +from matplotlib import patches +from matplotlib.figure import Figure +from matplotlib.lines import Line2D +from rasterio.warp import Resampling +from rasterio.windows import Window + +from hipp.image import SubImage +from hipp.kh9pc.fiducial_patterns import ( + compute_expected_fiducial_count, + compute_intra_segment_spacings, + theorical_spacing_from_pattern, +) +from hipp.kh9pc.kh9_image_spec import KH9ImageSpec +from hipp.kh9pc.restitution.base import FittingClass, Transformation +from 
hipp.kh9pc.restitution.collimation_strategy import CollimationStrategy +from hipp.kh9pc.restitution.fiducial_strategy import FiducialStrategy +from hipp.kh9pc.restitution.flat_strategy import FlatStrategy +from hipp.kh9pc.restitution.mixed_strategy import MixedStrategy +from hipp.kh9pc.restitution.poly_strategy import PolyStrategy +from hipp.kh9pc.restitution.vertical_detector import VerticalDetector + +logger = logging.getLogger(__name__) + + +# --- Vertical --- +def plot_vertical_ruptures(detector: VerticalDetector) -> Figure: + """Band profiles with detected rupture positions for left and right edges.""" + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + + for ax, side, result in zip(axes, ["left", "right"], [detector.left_, detector.right_]): + ax.plot(result.profile, color="gray") + ax.axvline(x=result.rupture_local, color="red", label=f"rupture (local={result.rupture_local})") + ax.set_title(f"{side} column-sum profile (global col={result.position})") + ax.set_xlabel("local column index") + ax.set_ylabel("column sum") + ax.legend() + + return fig + + +def plot_vertical_edges( + detector: VerticalDetector, + margin_fraction: float = 0.03, + plot_res: float = 0.05, +) -> Figure: + """Thumbnails around the left and right edge positions.""" + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + + with rasterio.open(detector.raster_filepath_) as src: + margin = int(src.width * margin_fraction) + + for ax, side, edge_col in zip(axes, ["left", "right"], detector.edges_): + col_off = max(0, edge_col - margin) + col_end = min(src.width, edge_col + margin) + window = Window(col_off, 0, col_end - col_off, src.height) + out_shape = (1, int(src.height * plot_res), int(window.width * plot_res)) + band = src.read(1, window=window, out_shape=out_shape, resampling=Resampling.average) + + ax.imshow(band, cmap="gray", aspect="auto") + ax.axvline(x=(edge_col - col_off) * plot_res, color="red") + ax.set_title(f"{side} edge 
(col={edge_col})") + ax.axis("off") + + return fig + + +# --- Flat --- +def plot_flat_ruptures(detector: FlatStrategy) -> Figure: + """Band profiles (collapsed horizontally) with detected rupture row for top and bottom.""" + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + + for ax, side, result in zip(axes, ["top", "bottom"], [detector.top_, detector.bottom_]): + profile = result.sub_image.band.flatten() + ax.plot(profile, color="steelblue", linewidth=1) + ax.axvline(result.rupture_local, color="red", linewidth=1.5, label=f"rupture={result.rupture_local}") + ax.set_title(f"{side} band profile") + ax.set_xlabel("row index (downsampled)") + ax.set_ylabel("intensity") + ax.legend(fontsize=8) + + return fig + + +def plot_flat_edges(detector: FlatStrategy, margin_fraction: float = 0.03) -> Figure: + """Thumbnails around the top and bottom edge positions with detected line overlaid.""" + left, _ = detector.vertical_detector.edges_ + roi_w = detector.vertical_detector.detected_width_ + + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + + with rasterio.open(detector.raster_filepath_) as src: + margin = int(margin_fraction * src.height) + + for ax, side, result in zip(axes, ["top", "bottom"], [detector.top_, detector.bottom_]): + row_off = max(0, result.position - margin) + row_end = min(src.height, result.position + margin) + win_h = row_end - row_off + thumb = src.read( + 1, + window=Window(left, row_off, roi_w, win_h), + out_shape=(512, 512), + resampling=Resampling.average, + ) + line_row = (result.position - row_off) / win_h * 512 + ax.imshow(thumb, cmap="gray", aspect="auto") + ax.axhline(line_row, color="yellow", linewidth=1.5) + ax.set_title(f"{side} edge — position={result.position} px") + ax.axis("off") + + return fig + + +# --- Poly --- + + +def plot_poly_edges(detector: PolyStrategy) -> Figure: + """Subimage thumbnails with RANSAC inliers/outliers and polynomial model for top and bottom edges.""" + fig, axes = 
plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + + for ax, side, result in zip(axes, ["top", "bottom"], [detector.top_, detector.bottom_]): + ax.imshow(result.sub_image.band, cmap="gray", aspect="auto") + + inlier_mask = result.model.inlier_mask_ + pts = result.ruptures_local + ax.scatter(pts[~inlier_mask, 0], pts[~inlier_mask, 1], s=12, c="red", label="outliers") + ax.scatter(pts[inlier_mask, 0], pts[inlier_mask, 1], s=12, c="green", label="inliers") + + x_global = result.ruptures_global[:, 0].astype(float) + y_global_pred = result.model.predict(x_global.reshape(-1, 1)) + global_pred = np.column_stack([x_global, y_global_pred.ravel()]) + local_pred = result.sub_image.to_local(global_pred) + ax.plot(local_pred[:, 0], local_pred[:, 1], color="blue", linewidth=1, label="model") + + ax.set_title(f"{side} edge") + ax.legend(loc="best", fontsize=8) + ax.axis("off") + + return fig + + +def plot_poly_distortions(detector: PolyStrategy) -> Figure: + """Residual distortion curves (deviation from mean) for top and bottom polynomial fits.""" + fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True) + + ax.plot(detector.top_.distortion[:, 0], detector.top_.distortion[:, 1], label="top") + ax.plot(detector.bottom_.distortion[:, 0], detector.bottom_.distortion[:, 1], label="bottom") + ax.axhline(0, color="gray", linewidth=0.8, linestyle="--") + ax.invert_yaxis() + ax.legend() + ax.set_title("global distortion (top & bottom)") + ax.set_xlabel("column (px)") + ax.set_ylabel("distortion (px)") + + return fig + + +# --- Collimation --- + + +def plot_collimation_edges(detector: CollimationStrategy) -> Figure: + """Subimage thumbnails with RANSAC inliers/outliers and polynomial model for top and bottom collimation lines.""" + fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True) + + for ax, side, result in zip(axes, ["top", "bottom"], [detector.top_, detector.bottom_]): + ax.imshow(result.sub_image.band, cmap="gray", aspect="auto") + + inliers = 
result.model.inlier_mask_ + peaks = result.peaks_local + ax.scatter(peaks[~inliers, 0], peaks[~inliers, 1], s=12, c="red", label="outliers") + ax.scatter(peaks[inliers, 0], peaks[inliers, 1], s=12, c="green", label="inliers") + + y_global_pred = result.model.predict(result.peaks_global[:, 0].reshape(-1, 1)) + global_pred = np.column_stack([result.peaks_global[:, 0], y_global_pred]) + local_pred = result.sub_image.to_local(global_pred) + ax.plot(local_pred[:, 0], local_pred[:, 1], color="blue", linewidth=1, label="model") + + ax.set_title(f"{side} collimation line") + ax.legend(loc="best", fontsize=8) + ax.axis("off") + + return fig + + +def plot_collimation_distortions(detector: CollimationStrategy) -> Figure: + """Residual distortion curves for top and bottom collimation line fits.""" + fig, ax = plt.subplots(figsize=(6, 4), constrained_layout=True) + + for side, result in zip(["top", "bottom"], [detector.top_, detector.bottom_]): + ax.plot(result.distortion[:, 0], result.distortion[:, 1], label=side) + + ax.invert_yaxis() + ax.legend() + ax.set_title("global distortion (top & bottom)") + ax.set_xlabel("column (px)") + ax.set_ylabel("distortion (px)") + + return fig + + +# --- Fiducial --- + + +_PATTERN_COLORS: dict[str, str] = { + "regulare_sparse": "red", + "regulare_mid": "orange", + "regular_dense": "gold", + "segmented_mid": "limegreen", + "segmented_dense": "cyan", + "serialized_time_word": "violet", +} + + +def _coord_index(centers_xy: np.ndarray) -> dict[tuple[float, float], int]: + return {(float(cx), float(cy)): i for i, (cx, cy) in enumerate(centers_xy)} + + +def plot_fiducial_filtering(detector: FiducialStrategy) -> Figure: + """Pattern detection diagnostics: spatial scatter and feature space for top and bottom sides. + + Each row corresponds to one side (top / bottom). The left column shows detections in + global image space (cx vs cy) with the fitted polynomial edge overlaid. 
The right column + shows the raw feature space (matching score vs residual to the edge model). + + Valid patterns are highlighted with distinct colours; unmatched detections appear in light gray. """ - Plot the gradient of collimation lines before and after a transformation. - - This function computes and plots the gradients (first derivatives) of the - top and bottom collimation lines both before and after a geometric transformation. - It can either display the plot or save it to a specified path. - - Args: - collimation_lines (dict[str, RANSACRegressor]): - Dictionary containing RANSAC models for the "top" and "bottom" collimation lines before transformation. - tf_collimation_lines (dict[str, RANSACRegressor]): - Dictionary containing RANSAC models for the "top" and "bottom" collimation lines after transformation. - width (int): - Image width used to define the x-range for prediction. - nb_points (int, optional): - Number of points to sample along the x-axis. Defaults to 100. - plot (bool, optional): - If True, the plot will be displayed. If False, the figure will be closed after saving. Defaults to True. - output_plot_path (str | Path | None, optional): - Path to save the plot as an image file. If None, the plot is not saved. Defaults to None. 
- - Returns: - None + fig, axes = plt.subplots(2, 2, figsize=(14, 8), constrained_layout=True) + fig.suptitle(f"Fiducial pattern detection ({detector.raster_filepath_.stem})", fontsize=12, fontweight="bold") + + sides = ("top", "bottom") + results = (detector.top_, detector.bottom_) + edge_models = [detector.poly_strategy.top_.model, detector.poly_strategy.bottom_.model] + cmap = plt.get_cmap("tab10") + _noise = (0.85, 0.85, 0.85, 1.0) + + for row, (side, result, edge_model) in enumerate(zip(sides, results, edge_models)): + ax_spatial, ax_feat = axes[row] + + centers_xy = result.centers_xy + features = result.features + coord_idx = _coord_index(centers_xy) + + ax_spatial.scatter(centers_xy[:, 0], centers_xy[:, 1], c=[_noise], s=10, linewidths=0) + ax_feat.scatter(features[:, 0], features[:, 1], c=[_noise], s=10, linewidths=0) + + legend_handles: list[Line2D] = [] + + for i, (name, pattern) in enumerate(result.patterns.items()): + if pattern.count == 0: + continue + color = cmap(i % 10) + score = pattern.score + star = " ★" if score > detector.min_score_threshold else "" + + indices = [coord_idx[k] for pt in pattern.points if (k := (float(pt[0]), float(pt[1]))) in coord_idx] + if indices: + idx = np.array(indices) + ax_spatial.scatter(centers_xy[idx, 0], centers_xy[idx, 1], c=[color], s=25, linewidths=0) + ax_feat.scatter(features[idx, 0], features[idx, 1], c=[color], s=25, linewidths=0) + + legend_handles.append( + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor=color, + markersize=6, + label=f"{name}{star} score={score:.3f} n={pattern.count}", + ) + ) + + x_grid = np.linspace(0, float(centers_xy[:, 0].max()), 300) + y_pred = edge_model.predict(x_grid.reshape(-1, 1)).ravel() + ax_spatial.plot(x_grid, y_pred, color="steelblue", linewidth=1.0, linestyle="--") + + ax_spatial.invert_yaxis() + ax_spatial.set_title(f"{side} — spatial ({len(centers_xy)} detections)") + ax_spatial.set_xlabel("cx (px)") + ax_spatial.set_ylabel("cy (px)") + + 
ax_feat.legend( + handles=legend_handles, loc="upper left", bbox_to_anchor=(1.02, 1.0), fontsize=7, borderaxespad=0 + ) + ax_feat.set_title(f"{side} — feature space") + ax_feat.set_xlabel("score") + ax_feat.set_ylabel("residual (px)") + + return fig + + +def plot_fiducial_distortions(detector: FiducialStrategy) -> Figure: + """Fiducial center y-deviation from mean, per valid pattern, for top and bottom sides. + + In the ideal case all points lie at 0. Divergence reveals scan distortion. """ - x = np.linspace(0, width, nb_points) - X = x.reshape(-1, 1) - y_top = collimation_lines["top"].predict(X) - y_bottom = collimation_lines["bottom"].predict(X) - y_tf_top = tf_collimation_lines["top"].predict(X) - y_tf_bottom = tf_collimation_lines["bottom"].predict(X) - - plt.plot(x, np.gradient(y_top, x), label="Top before transform") - plt.plot(x, np.gradient(y_bottom, x), label="Bottom before tranform") - plt.plot(x, np.gradient(y_tf_top, x), label="Top after transform") - plt.plot(x, np.gradient(y_tf_bottom, x), label="Bottom after transform") - - # Add title and axis labels - plt.title("Collimation Line Gradients Before and After Transformation") - plt.xlabel("Horizontal position (pixels)") - plt.ylabel("Gradient value") - - plt.legend() - - if output_plot_path is not None: - Path(output_plot_path).parent.mkdir(exist_ok=True, parents=True) - plt.savefig(output_plot_path) - - if plot: - plt.show() - else: - plt.close() - - -def plot_distance_between_collimation_lines( - collimation_lines: dict[str, RANSACRegressor], - tf_collimation_lines: dict[str, RANSACRegressor], - width: int, - true_distance_between_collimation: int, - nb_points: int = 100, - plot: bool = True, - output_plot_path: str | Path | None = None, -) -> None: + fig, ax = plt.subplots(figsize=(14, 4), constrained_layout=True) + fig.suptitle(f"Fiducial distortion — {detector.raster_filepath_.stem}", fontsize=12, fontweight="bold") + + for side, result in zip(["top", "bottom"], [detector.top_, detector.bottom_]): 
+ for name, pattern in result.patterns.items(): + if pattern.score <= detector.min_score_threshold or pattern.count < 8: + continue + x = pattern.points[:, 0].astype(np.float64) + y = pattern.points[:, 1].astype(np.float64) + ax.scatter( + x, + y - y.mean(), + s=8, + marker="x" if side == "bottom" else "o", + label=f"{side} · {name} (n={len(x)})", + ) + + ax.axhline(0.0, color="gray", linewidth=0.8, linestyle=":") + ax.invert_yaxis() + ax.legend(fontsize=8) + ax.set_xlabel("column (px)") + ax.set_ylabel("distortion (px)") + + return fig + + +def plot_fiducial_detected_profiles(detector: FiducialStrategy, window_height_fraction: float = 0.08) -> Figure: + """Detected fiducial centers overlaid on the top and bottom image strips, one scatter per valid pattern.""" + sides_results = [detector.top_, detector.bottom_] + n_insets = max( + max(sum(1 for p in r.patterns.values() if p.score > detector.min_score_threshold) for r in sides_results), + 1, + ) + + fig = plt.figure(figsize=(18, 5), constrained_layout=True) + fig.suptitle(f"Fiducial detected profiles — {detector.raster_filepath_.stem}", fontsize=18, fontweight="bold") + gs = fig.add_gridspec(2, 1 + n_insets, width_ratios=[14] + [1] * n_insets) + main_axes = [fig.add_subplot(gs[row, 0]) for row in range(2)] + inset_slots = [[fig.add_subplot(gs[row, col + 1]) for col in range(n_insets)] for row in range(2)] + + with rasterio.open(detector.raster_filepath_) as src: + window_height = int(src.height * window_height_fraction) + windows = [ + Window(0, 0, src.width, window_height), + Window(0, src.height - window_height, src.width, window_height), + ] + edge_models = [detector.poly_strategy.top_.model, detector.poly_strategy.bottom_.model] + + def _spacing_info(cx: np.ndarray) -> str: + if len(cx) >= 2: + dists = np.diff(np.sort(cx.astype(np.float64))) + return f"spacing mean={float(dists.mean()):.1f}px std={float(dists.std()):.1f}px" + return "spacing n/a" + + for row, (ax, side, window, result, edge_model) in enumerate( 
+ zip(main_axes, ["top", "bottom"], windows, sides_results, edge_models) + ): + sub_img = SubImage(src, window, (1, 512, 4096)) + ax.imshow(sub_img.band, cmap="gray", aspect="auto") + + ax_handles: list[Line2D] = [] + inset_data: list[tuple[str, np.ndarray]] = [] + + for name, pattern in result.patterns.items(): + if pattern.score <= detector.min_score_threshold: + continue + color = _PATTERN_COLORS.get(name, "white") + score = pattern.score + centers = pattern.points.astype(np.float64) + + if len(centers) > 0: + centers_local = sub_img.to_local(centers) + ax.scatter(centers_local[:, 0], centers_local[:, 1], c=color, s=20, zorder=3) + ax_handles.append( + Line2D( + [0], + [0], + marker="o", + color="w", + markerfacecolor=color, + markersize=7, + label=f"{side} {name} | score={score:.3f} | fiducials={pattern.count} | {_spacing_info(centers[:, 0])}", + ) + ) + mp = mean_patch_from_centers(src, centers) + if mp is not None: + inset_data.append((color, mp)) + + x_edge = np.linspace(0, src.width, 500) + edge_local = sub_img.to_local(np.column_stack([x_edge, edge_model.predict(x_edge.reshape(-1, 1)).ravel()])) + ax.plot(edge_local[:, 0], edge_local[:, 1], color="steelblue", linewidth=1.0, linestyle="--") + + ax.legend(handles=ax_handles, loc="lower center", bbox_to_anchor=(0.5, 1.0), fontsize=15, frameon=True) + ax.axis("off") + + for col, inset_ax in enumerate(inset_slots[row]): + if col < len(inset_data): + color, patch = inset_data[col] + inset_ax.imshow(patch, cmap="gray") + for spine in inset_ax.spines.values(): + spine.set_edgecolor(color) + spine.set_linewidth(2) + inset_ax.set_xticks([]) + inset_ax.set_yticks([]) + else: + inset_ax.set_visible(False) + + return fig + + +def plot_fiducial_detected_boxes(detector: FiducialStrategy) -> tuple[Figure, Figure]: + """One figure per side showing every detected fiducial box as a cropped patch. + + Boxes are colour-coded by pattern; unmatched detections appear in gray. 
""" - Plot the distance between the top and bottom collimation lines before and after transformation. - - This function computes and visualizes the vertical distance between two collimation lines - (top and bottom) across the image width, both before and after a geometric transformation. - It also overlays the expected true distance as a reference line to assess rectification accuracy. - - Args: - collimation_lines (dict[str, RANSACRegressor]): - Dictionary containing RANSAC models for the "top" and "bottom" collimation lines before transformation. - tf_collimation_lines (dict[str, RANSACRegressor]): - Dictionary containing RANSAC models for the "top" and "bottom" collimation lines after transformation. - width (int): - Image width used to define the x-range for prediction. - true_distance_between_collimation (int): - Expected true distance (in pixels) between the top and bottom collimation lines. - nb_points (int, optional): - Number of x-samples used to evaluate the fitted lines. Defaults to 100. - plot (bool, optional): - If True, displays the plot. If False, closes it after saving. Defaults to True. - output_plot_path (str | Path | None, optional): - Path to save the plot image. If None, the plot is not saved. Defaults to None. 
- - Returns: - None + figures: list[Figure] = [] + cmap = plt.get_cmap("tab10") + + for side, side_result in zip(("top", "bottom"), (detector.top_, detector.bottom_)): + boxes = side_result.boxes + scores = side_result.scores + centers_xy = side_result.centers_xy + n = len(boxes) + + coord_to_pattern: dict[tuple[float, float], tuple[str, Any]] = {} + for i, (name, pattern) in enumerate(side_result.patterns.items()): + color = cmap(i % 10) + for pt in pattern.points: + coord_to_pattern[(float(pt[0]), float(pt[1]))] = (name, color) + + _noise_color = (0.85, 0.85, 0.85, 1.0) + + grid = max(1, int(np.ceil(np.sqrt(n)))) + fig, axes_2d = plt.subplots(grid, grid, figsize=(grid * 2, grid * 2), squeeze=False, constrained_layout=True) + fig.suptitle(f"Detected fiducial boxes — {side} ({n} boxes)", fontsize=11, fontweight="bold") + axes = axes_2d.flatten() + + with rasterio.open(detector.raster_filepath_) as src: + for ax, box, score, (cx, cy) in zip(axes, boxes, scores, centers_xy): + x, y, w, h = box + band = src.read(1, window=Window(x, y, w, h)) + ax.imshow(band, cmap="gray", interpolation="nearest") + + match = coord_to_pattern.get((float(cx), float(cy))) + if match is not None: + pattern_name, color = match + label_str = pattern_name + else: + color = _noise_color + label_str = "unmatched" + + ax.set_title(f"{label_str} {score:.3f}", fontsize=7, color=color) + ax.axis("off") + + for ax in axes[n:]: + ax.axis("off") + + figures.append(fig) + + return figures[0], figures[1] + + +def mean_patch_from_centers( + src: str | Path | rasterio.DatasetReader, + centers: np.ndarray, + half_size: int = 50, +) -> np.ndarray | None: + """Compute the mean image patch (band 1) around a set of pixel centers. + + Uses an incremental float64 accumulator so peak memory is O(patch_size²) + regardless of the number of centers. Out-of-bounds regions are zero-padded + before averaging. Centers that fall entirely outside the raster are silently skipped. 
""" - x = np.linspace(0, width, nb_points) - X = x.reshape(-1, 1) - y_top = collimation_lines["top"].predict(X) - y_bottom = collimation_lines["bottom"].predict(X) - y_tf_top = tf_collimation_lines["top"].predict(X) - y_tf_bottom = tf_collimation_lines["bottom"].predict(X) - - dist_before_transformation = np.abs(y_top - y_bottom) - dist_after_transformation = np.abs(y_tf_top - y_tf_bottom) - - plt.plot(x, dist_before_transformation, label="Before transformation") - plt.plot(x, dist_after_transformation, label="After transformation") - - plt.axhline( - y=true_distance_between_collimation, - color="red", - linestyle="--", - label=f"True distance : {true_distance_between_collimation}", + if not isinstance(src, rasterio.DatasetReader): + with rasterio.open(src) as opened: + return mean_patch_from_centers(opened, centers, half_size) + + size = 2 * half_size + accumulator = np.zeros((size, size), dtype=np.float64) + count = 0 + + x0s: np.ndarray = centers[:, 0].astype(np.intp) - half_size + y0s: np.ndarray = centers[:, 1].astype(np.intp) - half_size + + for x0, y0 in zip(x0s, y0s): + x0c = max(0, int(x0)) + y0c = max(0, int(y0)) + x1c = min(src.width, int(x0) + size) + y1c = min(src.height, int(y0) + size) + if x1c <= x0c or y1c <= y0c: + continue + patch = src.read(1, window=Window(x0c, y0c, x1c - x0c, y1c - y0c)) + accumulator[y0c - y0 : y0c - y0 + patch.shape[0], x0c - x0 : x0c - x0 + patch.shape[1]] += patch + count += 1 + + return (accumulator / count).astype(np.float32) if count > 0 else None + + +# --- Transform --- + + +def plot_deformation_grid( + transform: Transformation, + num: int = 20, + figsize: tuple[int, int] = (6, 6), +) -> Figure: + """Visualize the deformation field by plotting warped grid lines.""" + with rasterio.open(transform.raster_filepath) as src: + w, h = src.width, src.height + + xs = np.linspace(0, w - 1, num, dtype=np.float32) + ys = np.linspace(0, h - 1, num, dtype=np.float32) + + fig, ax = plt.subplots(figsize=figsize) + + for y in ys: + 
line = np.stack([xs, np.full_like(xs, y)], axis=-1) + warped_line = transform.deformation(line) + ax.plot(warped_line[:, 0], warped_line[:, 1], color="gray", lw=0.8, alpha=0.7) + + for x in xs: + line = np.stack([np.full_like(ys, x), ys], axis=-1) + warped_line = transform.deformation(line) + ax.plot(warped_line[:, 0], warped_line[:, 1], color="gray", lw=0.8, alpha=0.7) + + ax.set_title("Warped deformation grid") + ax.invert_yaxis() + + return fig + + +def plot_crop_area(transform: Transformation, figsize: tuple[int, int] = (6, 6)) -> Figure: + """Visualize the crop region within the original image frame.""" + fig, ax = plt.subplots(figsize=figsize) + + with rasterio.open(transform.raster_filepath) as src: + w, h = src.width, src.height + + crop_x, crop_y = transform.crop_offset + crop_w, crop_h = transform.output_size + + ax.add_patch(patches.Rectangle((0, 0), w, h, fill=False, edgecolor="black", linewidth=2, label="Original image")) + ax.add_patch( + patches.Rectangle((crop_x, crop_y), crop_w, crop_h, fill=True, alpha=0.3, color="orange", label="Crop region") ) - plt.title("Distance Between Collimation Lines Before and After Transformation") - plt.xlabel("Horizontal position (pixels)") - plt.ylabel("Distance between lines (pixels)") + ax.scatter(crop_x, crop_y, color="red", marker="+", s=10, label="Crop origin (0,0 in crop space)") + + ax.set_xlim(0, w) + ax.set_ylim(0, h) + ax.set_aspect("auto") + ax.set_box_aspect(h / (w / 2)) + ax.invert_yaxis() + ax.set_title(f"Crop visualization\ncrop_offset = ({crop_x}, {crop_y}), size = ({crop_w}, {crop_h})") + ax.legend(loc="upper left", bbox_to_anchor=(1.02, 1)) - plt.legend() + return fig - if output_plot_path is not None: - Path(output_plot_path).parent.mkdir(exist_ok=True, parents=True) - plt.savefig(output_plot_path) - if plot: - plt.show() +# --- Dispatch --- + + +def _vertical_metrics(detector: VerticalDetector) -> dict[str, Any]: + expected_width = 
KH9ImageSpec.from_raster_filepath(detector.raster_filepath_).expected_size[0] + return {"expected_width": expected_width, "detected_width": detector.detected_width_} + + +def _poly_metrics(strategy: PolyStrategy) -> dict[str, Any]: + expected_height = KH9ImageSpec.from_raster_filepath(strategy.raster_filepath_).expected_size[1] + x = np.linspace(*strategy.vertical_detector.edges_, strategy.grid_shape[0]).reshape(-1, 1) + heights = strategy.bottom_.model.predict(x) - strategy.top_.model.predict(x) + return { + **_vertical_metrics(strategy.vertical_detector), + "expected_height": expected_height, + "detected_height": float(np.mean(heights)), + "detected_height_std": float(np.std(heights)), + } + + +def _fiducial_metrics(strategy: FiducialStrategy) -> dict[str, Any]: + primary_top = strategy.kh9_image_spec_.top_fiducial_patterns[0] + primary_bottom = strategy.kh9_image_spec_.bottom_fiducial_patterns[0] + top_pattern = strategy.top_.patterns[primary_top] + bottom_pattern = strategy.bottom_.patterns[primary_bottom] + expected_width = strategy.kh9_image_spec_.expected_size[0] + + top_spacings = compute_intra_segment_spacings(top_pattern.points) if len(top_pattern.points) > 1 else np.array([]) + bot_spacings = ( + compute_intra_segment_spacings(bottom_pattern.points) if len(bottom_pattern.points) > 1 else np.array([]) + ) + + return { + **_poly_metrics(strategy.poly_strategy), + "primary_top_pattern": primary_top, + "primary_bottom_pattern": primary_bottom, + "top_expected_fiducial_count": compute_expected_fiducial_count(primary_top, expected_width), + "top_detected_fiducial_count": top_pattern.count, + "top_true_spacing": theorical_spacing_from_pattern(primary_top), + "top_detected_mean_spacing": float(np.mean(top_spacings)) if len(top_spacings) else float("nan"), + "top_detected_std_spacing": float(np.std(top_spacings)) if len(top_spacings) else float("nan"), + "bottom_expected_fiducial_count": compute_expected_fiducial_count(primary_bottom, expected_width), + 
"bottom_detected_fiducial_count": bottom_pattern.count, + "bottom_true_spacing": theorical_spacing_from_pattern(primary_bottom), + "bottom_detected_mean_spacing": float(np.mean(bot_spacings)) if len(bot_spacings) else float("nan"), + "bottom_detected_std_spacing": float(np.std(bot_spacings)) if len(bot_spacings) else float("nan"), + } + + +def get_metrics(fitting_class: FittingClass) -> dict[str, Any] | None: + """Return metrics for the effective strategy, or None if not supported.""" + if isinstance(fitting_class, MixedStrategy): + return None if fitting_class.is_failed else get_metrics(fitting_class.selected_strategy_) + if isinstance(fitting_class, FiducialStrategy): + return {"strategy": "FiducialStrategy", **_fiducial_metrics(fitting_class)} + if isinstance(fitting_class, PolyStrategy): + return {"strategy": "PolyStrategy", **_poly_metrics(fitting_class)} + if isinstance(fitting_class, VerticalDetector): + return {"strategy": "VerticalDetector", **_vertical_metrics(fitting_class)} + return None + + +def save_metrics(fitting_class: FittingClass, output_dir: str | Path) -> None: + """Write or update one row in metrics.csv, keyed by image name.""" + metrics = get_metrics(fitting_class) + if metrics is None: + return + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + csv_path = output_dir / "metrics.csv" + row = {"image": fitting_class.raster_filepath_.stem, "processed_at": datetime.now().isoformat(), **metrics} + df_new = pd.DataFrame([row]) + if csv_path.exists(): + df = pd.read_csv(csv_path) + df = df[df["image"] != row["image"]] + df = pd.concat([df, df_new], ignore_index=True) else: - plt.close() + df = df_new + df.to_csv(csv_path, index=False) + + +def save_figures(fitting_class: FittingClass, output_dir: str | Path) -> None: + output_dir = Path(output_dir) + gen = get_figures(fitting_class) + while True: + try: + name, fig = next(gen) + (output_dir / name).mkdir(parents=True, exist_ok=True) + fig.savefig(output_dir / name / 
f"{fitting_class.raster_filepath_.stem}.png") + plt.close(fig) + except StopIteration: + break + except Exception as e: + logger.warning("Skipping QC figure: %s", e) + + +def get_figures(fitting_class: FittingClass, plot_transformation: bool = True) -> Iterator[tuple[str, Figure]]: + """Yield (name, figure) pairs for all QC plots of a fitted FittingClass instance.""" + if isinstance(fitting_class, VerticalDetector): + yield "vertical_edges", plot_vertical_edges(fitting_class) + yield "vertical_ruptures", plot_vertical_ruptures(fitting_class) + return + if isinstance(fitting_class, FlatStrategy): + yield from get_figures(fitting_class.vertical_detector, plot_transformation=False) + yield "flat_edges", plot_flat_edges(fitting_class) + yield "flat_ruptures", plot_flat_ruptures(fitting_class) + if plot_transformation: + yield "crop_area", plot_crop_area(fitting_class.transformation_) + return + if isinstance(fitting_class, PolyStrategy): + yield from get_figures(fitting_class.vertical_detector, plot_transformation=False) + yield "poly_edges", plot_poly_edges(fitting_class) + yield "poly_distortions", plot_poly_distortions(fitting_class) + if plot_transformation: + yield "deformation_grid", plot_deformation_grid(fitting_class.transformation_) + yield "crop_area", plot_crop_area(fitting_class.transformation_) + return + if isinstance(fitting_class, CollimationStrategy): + yield from get_figures(fitting_class.poly_strategy, plot_transformation=False) + yield "collimation_edges", plot_collimation_edges(fitting_class) + yield "collimation_distortions", plot_collimation_distortions(fitting_class) + if plot_transformation: + yield "deformation_grid", plot_deformation_grid(fitting_class.transformation_) + yield "crop_area", plot_crop_area(fitting_class.transformation_) + return + if isinstance(fitting_class, FiducialStrategy): + yield from get_figures(fitting_class.poly_strategy, plot_transformation=False) + yield "fiducial_filtering", plot_fiducial_filtering(fitting_class) + 
yield "fiducial_distortions", plot_fiducial_distortions(fitting_class) + yield "fiducial_detected_profiles", plot_fiducial_detected_profiles(fitting_class) + # yield from zip(("fiducial_boxes_top", "fiducial_boxes_bottom"), plot_fiducial_detected_boxes(fitting_class)) + if plot_transformation: + yield "deformation_grid", plot_deformation_grid(fitting_class.transformation_) + yield "crop_area", plot_crop_area(fitting_class.transformation_) + return + if isinstance(fitting_class, MixedStrategy): + yield from get_figures(fitting_class.selected_strategy_, plot_transformation=plot_transformation) diff --git a/src/hipp/kh9pc/restitution/__init__.py b/src/hipp/kh9pc/restitution/__init__.py new file mode 100644 index 0000000..a12838e --- /dev/null +++ b/src/hipp/kh9pc/restitution/__init__.py @@ -0,0 +1,15 @@ +from .collimation_strategy import CollimationStrategy +from .fiducial_strategy import FiducialStrategy +from .flat_strategy import FlatStrategy +from .mixed_strategy import MixedStrategy +from .poly_strategy import PolyStrategy +from .vertical_detector import VerticalDetector + +__all__ = [ + "CollimationStrategy", + "FiducialStrategy", + "FlatStrategy", + "PolyStrategy", + "MixedStrategy", + "VerticalDetector", +] diff --git a/src/hipp/kh9pc/restitution/base.py b/src/hipp/kh9pc/restitution/base.py new file mode 100644 index 0000000..84a686c --- /dev/null +++ b/src/hipp/kh9pc/restitution/base.py @@ -0,0 +1,114 @@ +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Self + +import numpy as np +from numpy.typing import NDArray +from sklearn.linear_model import LinearRegression, RANSACRegressor +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import PolynomialFeatures, StandardScaler + +logger = logging.getLogger(__name__) + +DEFAULT_OUTPUT_HEIGHT: int = 22064 +"""Standard output height in pixels for restituted KH-9 PC images (22064 px at nominal scan 
resolution).""" + + +class DetectionError(Exception): + """Raised when no valid detections are found during fitting.""" + + +class FittingClass(ABC): + def __init__(self) -> None: + self.__raster_filepath_: Path | None = None + + @property + def raster_filepath_(self) -> Path: + if self.__raster_filepath_ is None: + raise RuntimeError("Call fit() before.") + return self.__raster_filepath_ + + @property + def is_fitted(self) -> bool: + return self.__raster_filepath_ is not None + + @property + @abstractmethod + def is_failed(self) -> bool: ... + + def fit(self, raster_filepath: str | Path) -> Self: + raster_filepath = Path(raster_filepath) + logger.info("[%s] %s - start fit...", self.__class__.__name__, raster_filepath.name) + fit_res = self._fit(raster_filepath) + self.__raster_filepath_ = raster_filepath + logger.info( + "[%s] %s - finish fit : [%s]", + self.__class__.__name__, + raster_filepath.name, + "FAILED" if self.is_failed else "SUCCESS", + ) + return fit_res + + @abstractmethod + def _fit(self, raster_filepath: Path) -> Self: ... + + +class RestitutionStrategy(FittingClass): + @abstractmethod + def transform(self, output_path: str | Path) -> None: ... + + @property + @abstractmethod + def transformation_(self) -> "Transformation": ... + + +@dataclass +class Transformation: + raster_filepath: Path + deformation: Callable[[NDArray[np.float32]], NDArray[np.float32]] + crop_offset: tuple[float, float] = (0, 0) + output_size: tuple[int, int] = (0, 0) + + def inverse_remap(self, coords: NDArray[np.float32]) -> NDArray[np.float32]: + coords = coords + np.array([self.crop_offset[0], self.crop_offset[1]], dtype=coords.dtype) + return self.deformation(coords) + + +def fit_ransac_poly( + x: NDArray[np.generic], + y: NDArray[np.generic], + degree: int = 3, + residual_threshold: float = 100, + max_trials: int = 100, +) -> RANSACRegressor: + """Fit a polynomial regression with RANSAC on 1D data. 
Returns the fitted RANSACRegressor.""" + poly_model = make_pipeline( + PolynomialFeatures(degree=degree), + StandardScaler(), + LinearRegression(), + ) + + min_samples = min(degree * 3, len(x)) + ransac = RANSACRegressor( + poly_model, residual_threshold=residual_threshold, min_samples=min_samples, max_trials=max_trials + ) + ransac.fit(x.reshape(-1, 1), y) + return ransac + + +def detect_ruptures(vec: NDArray[np.number], threshold: float, reverse_scan: bool = False) -> NDArray[np.integer]: + """Detect indices where the signal drops below a threshold (falling edges). + + If reverse_scan is True, scan from the end and return indices in original coordinates. + """ + if reverse_scan: + vec = vec[::-1] + + idx = np.where((vec[1:] <= threshold) & (vec[:-1] > threshold))[0] + 1 + + if reverse_scan: + idx = len(vec) - 1 - idx + + return idx diff --git a/src/hipp/kh9pc/restitution/collimation_strategy.py b/src/hipp/kh9pc/restitution/collimation_strategy.py new file mode 100644 index 0000000..dc579b1 --- /dev/null +++ b/src/hipp/kh9pc/restitution/collimation_strategy.py @@ -0,0 +1,198 @@ +from dataclasses import dataclass, field +from pathlib import Path +from typing import Self + +import numpy as np +import rasterio +from numpy.typing import NDArray +from rasterio.warp import Resampling +from rasterio.windows import Window +from scipy.ndimage import gaussian_filter1d +from skimage.transform import ThinPlateSplineTransform +from sklearn.linear_model import RANSACRegressor + +from hipp.image import SubImage, remap_tif_blockwise +from hipp.kh9pc.restitution.base import fit_ransac_poly +from hipp.kh9pc.restitution.base import DEFAULT_OUTPUT_HEIGHT, RestitutionStrategy, Transformation +from hipp.kh9pc.restitution.poly_strategy import PolyStrategy + + +@dataclass +class CollimationResult: + peaks_local: NDArray[np.integer] + peaks_global: NDArray[np.integer] + distortion: NDArray[np.floating] + inlier_ratio: float + model: RANSACRegressor + sub_image: SubImage + + +@dataclass 
+class CollimationStrategy(RestitutionStrategy): + poly_strategy: PolyStrategy = field(default_factory=PolyStrategy) + polynomial_degree: int = 5 + ransac_residual_threshold: float = 80.0 + ransac_max_trials: int = 1000 + grid_shape: tuple[int, int] = (100, 50) + stride: int = 10 + refinement_fraction: float = 0.03 + max_width_peak: int = 200 + collimation_line_dist: int = ( + 21770 # known physical distance between top/bottom collimation lines at nominal scan resolution + ) + min_inliers_threshold: float = 0.5 + output_width: int | None = None + output_height: int | None = DEFAULT_OUTPUT_HEIGHT + + def __post_init__(self) -> None: + super().__init__() + self._results: dict[str, CollimationResult] = {} + self.__transformation_: Transformation | None = None + + @property + def is_failed(self) -> bool: + return min(self.top_.inlier_ratio, self.bottom_.inlier_ratio) < self.min_inliers_threshold + + @property + def top_(self) -> CollimationResult: + if "top" not in self._results: + raise RuntimeError("Call fit() before") + return self._results["top"] + + @property + def bottom_(self) -> CollimationResult: + if "bottom" not in self._results: + raise RuntimeError("Call fit() before") + return self._results["bottom"] + + @property + def transformation_(self) -> Transformation: + if self.__transformation_ is None: + self.__transformation_ = self._compute_transformation() + return self.__transformation_ + + def _fit(self, raster_filepath: Path) -> Self: + if not self.poly_strategy.is_fitted or raster_filepath != self.poly_strategy.raster_filepath_: + self.poly_strategy.fit(raster_filepath) + + col_off, col_end = self.poly_strategy.vertical_detector.edges_ + window_width = self.poly_strategy.vertical_detector.detected_width_ + col_center = (col_off + col_end) // 2 + + with rasterio.open(raster_filepath) as src: + window_height = int(src.height * self.refinement_fraction) + out_shape = (1, window_height // self.stride, self.grid_shape[0]) + + top_edge = 
int(self.poly_strategy.top_.model.predict(np.array([[col_center]])).flat[0]) + bot_edge = int(self.poly_strategy.bottom_.model.predict(np.array([[col_center]])).flat[0]) + + for side, window in { + "top": Window(col_off, top_edge, window_width, window_height), + "bottom": Window(col_off, bot_edge - window_height, window_width, window_height), + }.items(): + sub_image = SubImage(src, window, out_shape, resampling=Resampling.average) + self._results[side] = self._process_side(sub_image, side) + + return self + + def _process_side(self, sub_image: SubImage, side: str) -> CollimationResult: + _, w = sub_image.band.shape + + peaks_local = np.zeros((w, 2), dtype=int) + for col in range(w): + vec = sub_image.band[:, col] + idx = detect_collimation_peak(vec, max_peak_width=self.max_width_peak // self.stride) + peaks_local[col, 0] = col + peaks_local[col, 1] = idx + + peaks_global = sub_image.to_global(peaks_local).astype(int) + + model = fit_ransac_poly( + peaks_global[:, 0], + peaks_global[:, 1], + degree=self.polynomial_degree, + residual_threshold=self.ransac_residual_threshold, + max_trials=self.ransac_max_trials, + ) + + inlier_ratio = float(model.inlier_mask_.mean()) + + y_global_pred = model.predict(peaks_global[:, 0].reshape(-1, 1)) + y_distortion = y_global_pred - y_global_pred.mean() + distortion = np.column_stack([peaks_global[:, 0], y_distortion]) + + return CollimationResult( + peaks_local=peaks_local, + peaks_global=peaks_global, + distortion=distortion, + inlier_ratio=inlier_ratio, + model=model, + sub_image=sub_image, + ) + + def _compute_transformation(self) -> Transformation: + left, right = self.poly_strategy.vertical_detector.edges_ + detected_width = self.poly_strategy.vertical_detector.detected_width_ + output_width = self.output_width or detected_width + + x = np.linspace(left, right, self.grid_shape[0]) + + y_top_src = self.top_.model.predict(x.reshape(-1, 1)) + y_bot_src = self.bottom_.model.predict(x.reshape(-1, 1)) + + top = 
int(np.median(y_top_src)) + bot = top + self.collimation_line_dist + detected_height = bot - top + output_height = self.output_height or detected_height + + y_top_dst = np.full_like(x, top) + y_bot_dst = np.full_like(x, bot) + + src = np.column_stack((np.concatenate((x, x)), np.concatenate((y_top_src, y_bot_src)))) + dst = np.column_stack((np.concatenate((x, x)), np.concatenate((y_top_dst, y_bot_dst)))) + + # inverse source destination (important) + deformation = ThinPlateSplineTransform().from_estimate(dst, src) + + # ---- CENTERING TO OUTPUT ---- + pad_x = (output_width - detected_width) / 2 + pad_y = (output_height - detected_height) / 2 + + crop_offset = (int(left - pad_x), int(top - pad_y)) + + return Transformation( + self.raster_filepath_, + deformation, + crop_offset=crop_offset, + output_size=(output_width, output_height), + ) + + def transform(self, output_path: str | Path) -> None: + tf = self.transformation_ + + remap_tif_blockwise( + tf.raster_filepath, + output_path, + tf.inverse_remap, + tf.output_size, + block_size=2**13, + lowres_step=100, + ) + + +def detect_collimation_peak(x: NDArray[np.number], max_peak_width: int, sigma: int = 2) -> int: + smooth = gaussian_filter1d(x, sigma=sigma) + + grad = np.gradient(smooth) + + idx_max = np.argmax(grad) + idx_min = np.argmin(grad) + + if abs(idx_max - idx_min) < max_peak_width and idx_max != idx_min: + w_start = min(idx_max, idx_min) + w_end = max(idx_max, idx_min) + idx = np.argmax(smooth[w_start:w_end]) + w_start + else: + idx = np.argmax(smooth) # fallback + + return int(idx) diff --git a/src/hipp/kh9pc/restitution/fiducial_strategy.py b/src/hipp/kh9pc/restitution/fiducial_strategy.py new file mode 100644 index 0000000..37da2a9 --- /dev/null +++ b/src/hipp/kh9pc/restitution/fiducial_strategy.py @@ -0,0 +1,351 @@ +from dataclasses import dataclass, field +from pathlib import Path +from typing import Literal, Self + +import cv2 +import numpy as np +import rasterio +from numpy.typing import NDArray +from 
def _load_kind(kind: str) -> list[cv2.typing.MatLike]:
    """Load, as grayscale images, every template file registered for ``kind``.

    Paths come from ``_KIND_TEMPLATES``; an unknown ``kind`` has no paths.
    Files that OpenCV cannot read are skipped.

    Raises
    ------
    DetectionError
        If no template could be loaded at all.
    """
    loaded: list[cv2.typing.MatLike] = []
    for template_path in _KIND_TEMPLATES.get(kind, []):
        image = cv2.imread(str(template_path), cv2.IMREAD_GRAYSCALE)
        # a None result marks a file that could not be read; ignore it
        if image is not None:
            loaded.append(image)

    if len(loaded) == 0:
        raise DetectionError(f"No templates loaded for kind {kind!r} — check {_TEMPLATE_DIR}")
    return loaded
+ + The strategy relies on a fitted ``PolyStrategy`` to locate the top/bottom image + edges, then slides overlapping blocks along the horizontal axis and runs + multi-template matching on the strip above the top edge and below the bottom + edge. A final global NMS pass deduplicates matches that span adjacent blocks. + + Parameters + ---------- + poly_strategy: + Fitted (or to-be-fitted) strategy that provides the horizontal edge models. + kh9_image_spec: + Image specification describing which fiducial template set to use + (``fiducial_type`` field: ``"disk"`` for missions D3C1201–D3C1213, + ``"wagon_wheel"`` for D3C1214–D3C1219). If ``None`` (default), the spec + is inferred from the image filename at fit time via + ``KH9ImageSpec.from_raster_filepath``. + block_width: + Width in pixels of each scanning block. + threshold: + Minimum template-matching score to keep a detection. + nms_threshold: + IoU threshold for non-maximum suppression within a block and globally. + horizontal_margins: + ``(left, right)`` fractional margins relative to the detected image width. + The search window is inset by ``left * width`` on the left and + ``right * width`` on the right of the vertical edges. 
@property
def is_failed(self) -> bool:
    """Whether the fit lacks a confident primary pattern on either side.

    A side passes when at least one of its patterns whose name contains
    ``"mid"`` or ``"sparse"`` scores strictly above ``min_score_threshold``.
    The strategy is failed as soon as one side does not pass.
    """
    # resolve both sides first so a missing result raises before any scoring
    side_results = (self.top_, self.bottom_)

    for side_result in side_results:
        side_ok = False
        for pattern_name, pattern in side_result.patterns.items():
            if not ("mid" in pattern_name or "sparse" in pattern_name):
                continue
            if pattern.score > self.min_score_threshold:
                side_ok = True
                break
        if not side_ok:
            return True

    return False
def _scan_side(
    self, src: rasterio.DatasetReader, col_start: int, col_end: int, side: Literal["top", "bottom"]
) -> tuple[NDArray[np.int_], NDArray[np.float64], NDArray[np.int_]]:
    """Run template matching block by block along one side of the raster.

    The range ``[col_start, col_end)`` is walked in steps of ``block_width``.
    Each block is widened by ``_BLOCK_MARGIN * block_width`` on both sides and
    clamped to the scan range, so neighbouring blocks overlap. Detections of
    all blocks are concatenated; de-duplication is left to the caller.

    Parameters
    ----------
    src:
        Open rasterio dataset the blocks are read from.
    col_start, col_end:
        Column bounds of the scan.
    side:
        ``"top"`` reads rows from 0 down to the predicted top edge;
        ``"bottom"`` reads rows from the predicted bottom edge to the raster end.

    Returns
    -------
    ``(boxes, scores, template_ids)``; boxes are ``[x, y, w, h]`` rows with
    ``x, y`` converted to raster coordinates through ``SubImage.to_global``.
    """
    all_boxes: list[list[int]] = []
    all_scores: list[float] = []
    all_template_ids: list[int] = []

    for block_origin in range(col_start, col_end, self.block_width):
        # widen the block by the overlap margin, clamped to the scan range
        x_min = max(col_start, int(block_origin - self.block_width * _BLOCK_MARGIN))
        x_max = min(col_end, int(block_origin + self.block_width * (1 + _BLOCK_MARGIN)))
        span = x_max - x_min
        x_mid = np.array([[x_min + span // 2]])

        # the edge model is evaluated once, at the block centre
        if side == "top":
            edge_row = int(self.poly_strategy.top_.model.predict(x_mid).flat[0])
            window = Window(x_min, 0, span, edge_row)
        else:
            edge_row = int(self.poly_strategy.bottom_.model.predict(x_mid).flat[0])
            window = Window(x_min, edge_row, span, src.height - edge_row)

        # nothing to read when the strip is empty
        if not (window.height > 0 and window.width > 0):
            continue

        block = SubImage(src, window)
        local_boxes, block_scores, block_ids = match_multiple_templates(
            image=block.band,
            templates=self.templates_,
            threshold=self.threshold,
            nms_threshold=self.nms_threshold,
        )

        # shift box origins from block coordinates to raster coordinates
        for bx, by, bw, bh in local_boxes:
            gx, gy = block.to_global(np.array([bx, by], dtype=np.float64))
            all_boxes.append([int(gx), int(gy), bw, bh])
        all_scores.extend(block_scores)
        all_template_ids.extend(block_ids)

    if all_boxes:
        boxes_arr = np.array(all_boxes, dtype=np.int_).reshape(-1, 4)
    else:
        boxes_arr = np.empty((0, 4), dtype=np.int_)

    scores_arr = np.array(all_scores, dtype=np.float64)
    ids_arr = np.array(all_template_ids, dtype=np.int_)
    return boxes_arr, scores_arr, ids_arr
def _score_clusters(
    self,
    labels: NDArray[np.int_],
    centers_xy: NDArray[np.floating],
    fiducial_pattern: PATTERNS,
) -> DetectedPattern:
    """Return the best-scoring evaluation of ``fiducial_pattern`` over all clusters.

    Every cluster label other than ``-1`` that holds at least 5 points is
    evaluated with ``evaluate_pattern``. The baseline is the evaluation of an
    empty point set, which is returned when no cluster scores strictly higher.
    """
    expected_width = self.kh9_image_spec_.expected_size[0]
    no_points = np.empty((0, 2), dtype=np.float64)
    best = evaluate_pattern(fiducial_pattern, no_points, expected_width)

    cluster_ids, cluster_sizes = np.unique(labels, return_counts=True)
    for cluster_id, cluster_size in zip(cluster_ids, cluster_sizes):
        # -1 is skipped (presumably the clustering's noise label), as are tiny clusters
        if cluster_id == -1 or cluster_size < 5:
            continue

        candidate = evaluate_pattern(fiducial_pattern, centers_xy[labels == cluster_id], expected_width)
        if candidate.score > best.score:
            best = candidate

    return best
def _compute_transformation(self) -> Transformation:
    """Build the output ``Transformation`` from the primary top/bottom patterns.

    A thin-plate spline is estimated from the point pairs returned by
    ``compute_global_src_and_dst_points``. The crop window has the expected
    image size and is centred on the middle of the destination points
    vertically and on the mapped vertical edges horizontally.

    Raises
    ------
    DetectionError
        If ``is_failed`` is true.
    """
    if self.is_failed:
        raise DetectionError("Can't compute the transformation with a failed estimation")

    spec = self.kh9_image_spec_

    # only the first (primary) pattern of each side is used, since its theoretical spacing is known
    top_detected = self.top_.patterns[spec.top_fiducial_patterns[0]]
    bottom_detected = self.bottom_.patterns[spec.bottom_fiducial_patterns[0]]

    src_pts, dst_pts = compute_global_src_and_dst_points(top_detected, bottom_detected)

    dst_rows = dst_pts[:, 1]
    y_center = (dst_rows.min() + dst_rows.max()) / 2

    # push the detected vertical edges through the forward (src -> dst) spline
    # to get the horizontal centre in destination space
    src_to_dst = ThinPlateSplineTransform().from_estimate(src_pts, dst_pts)
    col_left, col_right = self.poly_strategy.vertical_detector.edges_
    edge_probe = np.array([[col_left, y_center], [col_right, y_center]], dtype=np.float32)
    mapped_edges = src_to_dst(edge_probe)
    x_center = float((mapped_edges[0, 0] + mapped_edges[1, 0]) / 2)

    final_width, final_height = spec.expected_size
    offset_x = int(x_center - final_width / 2)
    offset_y = int(y_center - final_height / 2)

    # the remap needs the inverse mapping, so the spline is estimated from dst to src (important)
    dst_to_src = ThinPlateSplineTransform().from_estimate(dst_pts, src_pts)

    return Transformation(
        self.raster_filepath_,
        dst_to_src,
        crop_offset=(offset_x, offset_y),
        output_size=(final_width, final_height),
    )
@dataclass
class FlatStrategy(RestitutionStrategy):
    """Restitution strategy that crops/pads the image without deforming it.

    One row position is detected at the top and one at the bottom of the
    raster, between the vertical edges given by ``vertical_detector``. The
    resulting transformation uses an identity coordinate mapping.

    Parameters
    ----------
    vertical_detector:
        Detector providing the left/right edges; fitted on demand.
    background_threshold:
        Threshold forwarded to ``detect_ruptures``.
    height_fraction:
        Fraction of the raster height searched at the top and at the bottom.
    stride:
        Row down-sampling factor of the search windows.
    output_width, output_height:
        Output size; a falsy value falls back to the detected size.
    """

    vertical_detector: VerticalDetector = field(default_factory=VerticalDetector)
    background_threshold: int = 20
    height_fraction: float = 0.15
    stride: int = 10
    output_width: int | None = None
    output_height: int | None = DEFAULT_OUTPUT_HEIGHT

    def __post_init__(self) -> None:
        """Initialise the fitted-state slots."""
        super().__init__()
        self._results: dict[str, FlatResult] = {}
        self.__transformation_: Transformation | None = None

    @property
    def is_failed(self) -> bool:
        """Always ``False``: this strategy defines no failure criterion."""
        return False

    @property
    def top_(self) -> FlatResult:
        """Detected top edge. Raises ``RuntimeError`` if not available."""
        if "top" not in self._results:
            raise RuntimeError("Call fit() before")
        return self._results["top"]

    @property
    def bottom_(self) -> FlatResult:
        """Detected bottom edge. Raises ``RuntimeError`` if not available."""
        if "bottom" not in self._results:
            raise RuntimeError("Call fit() before")
        return self._results["bottom"]

    @property
    def transformation_(self) -> Transformation:
        """Transformation of the current fit, computed on first access and cached."""
        if self.__transformation_ is None:
            self.__transformation_ = self._compute_transformation()
        return self.__transformation_

    def _fit(self, raster_filepath: Path) -> Self:
        """Detect the top and bottom edge rows of ``raster_filepath``."""
        # Drop everything derived from a previous fit. Without this, refitting
        # on another raster would keep serving the cached transformation (and,
        # after a partial failure, a stale side result) of the earlier one.
        self._results = {}
        self.__transformation_ = None

        if not self.vertical_detector.is_fitted or raster_filepath != self.vertical_detector.raster_filepath_:
            self.vertical_detector.fit(raster_filepath)

        col_off, _ = self.vertical_detector.edges_
        window_width = self.vertical_detector.detected_width_

        with rasterio.open(raster_filepath) as src:
            window_height = int(src.height * self.height_fraction)
            # presumably resamples each window to a single column of
            # window_height // stride rows — confirm against SubImage
            out_shape = (1, window_height // self.stride, 1)

            for side, window in {
                "top": Window(col_off, 0, window_width, window_height),
                "bottom": Window(col_off, src.height - window_height, window_width, window_height),
            }.items():
                sub_image = SubImage(src, window, out_shape)
                self._results[side] = self._process_side(sub_image, side)

        return self

    def _process_side(self, sub_image: SubImage, side: str) -> FlatResult:
        """Locate the first rupture of one side and convert it to a raster row.

        Raises
        ------
        RuntimeError
            If ``detect_ruptures`` finds nothing.
        """
        # the top window is scanned in reverse (reverse_scan=True)
        ruptures = detect_ruptures(sub_image.band.flatten(), self.background_threshold, reverse_scan=(side == "top"))
        if len(ruptures) == 0:
            raise RuntimeError(f"No rupture detected on the {side} edge.")
        rupture_local = int(ruptures[0])
        position = int(sub_image.to_global(np.array([0.0, rupture_local]))[1])
        return FlatResult(position=position, rupture_local=rupture_local, sub_image=sub_image)

    def _compute_transformation(self) -> Transformation:
        """Build an identity-mapping transformation centred on the detected frame."""
        left, _ = self.vertical_detector.edges_
        detected_width = self.vertical_detector.detected_width_
        output_width = self.output_width or detected_width

        top = self.top_.position
        bot = self.bottom_.position
        detected_height = bot - top
        output_height = self.output_height or detected_height

        # split the size difference evenly so the detected frame is centred in the output
        pad_x = (output_width - detected_width) / 2
        pad_y = (output_height - detected_height) / 2

        crop_offset = (int(left - pad_x), int(top - pad_y))

        return Transformation(
            self.raster_filepath_,
            lambda coords: coords,
            crop_offset=crop_offset,
            output_size=(output_width, output_height),
        )

    def transform(self, output_path: str | Path) -> None:
        """Write the restituted raster to ``output_path``."""
        tf = self.transformation_
        remap_tif_blockwise(tf.raster_filepath, output_path, tf.inverse_remap, tf.output_size)
@dataclass
class MixedStrategy(RestitutionStrategy):
    """Try several restitution strategies in order and keep the first usable one.

    All member strategies are wired to one shared ``PolyStrategy`` (and its
    vertical detector). At fit time the members are fitted one after another;
    the first one that fits without raising and is not ``is_failed`` is selected.
    """

    strategies: list[RestitutionStrategy] = field(
        default_factory=lambda: [FiducialStrategy(), CollimationStrategy(), PolyStrategy(), FlatStrategy()]
    )
    poly_strategy: PolyStrategy = field(default_factory=PolyStrategy)

    def __post_init__(self) -> None:
        """Share the poly strategy and vertical detector across all members."""
        super().__init__()
        self.__selected_strategy_: RestitutionStrategy | None = None

        shared_poly = self.poly_strategy
        shared_detector = shared_poly.vertical_detector

        for slot, member in enumerate(self.strategies):
            if hasattr(member, "vertical_detector"):
                setattr(member, "vertical_detector", shared_detector)
            if hasattr(member, "poly_strategy"):
                setattr(member, "poly_strategy", shared_poly)

            # a standalone PolyStrategy is swapped for the shared one to avoid recomputation
            is_standalone_poly = isinstance(member, PolyStrategy) and member is not shared_poly
            if is_standalone_poly:
                self.strategies[slot] = shared_poly

    def _require_fitted(self) -> None:
        """Raise ``RuntimeError`` unless ``fit()`` has been called."""
        if not self.is_fitted:
            raise RuntimeError("call fit() before")

    @property
    def is_failed(self) -> bool:
        """True when no member strategy was selected by the last fit."""
        self._require_fitted()
        return self.__selected_strategy_ is None

    @property
    def selected_strategy_(self) -> RestitutionStrategy:
        """The strategy selected by the last fit; raises if every member failed."""
        self._require_fitted()
        chosen = self.__selected_strategy_
        if chosen is None:
            raise RuntimeError("All strategies failed")
        return chosen

    @property
    def failed_strategies(self) -> list[RestitutionStrategy]:
        """Members listed before the selected one, or all members if none was selected."""
        self._require_fitted()
        chosen = self.__selected_strategy_
        if chosen is None:
            return self.strategies
        return self.strategies[: self.strategies.index(chosen)]

    @property
    def transformation_(self) -> Transformation:
        """Transformation of the selected strategy."""
        return self.selected_strategy_.transformation_

    def transform(self, output_path: str | Path) -> None:
        """Delegate the output writing to the selected strategy."""
        self.selected_strategy_.transform(output_path)

    def _fit(self, raster_filepath: Path) -> "MixedStrategy":
        """Fit members in order and select the first that succeeds."""
        self.__selected_strategy_ = None

        detector = self.poly_strategy.vertical_detector
        if not detector.is_fitted or raster_filepath != detector.raster_filepath_:
            detector.fit(raster_filepath)

        for candidate in self.strategies:
            try:
                # a member already fitted on this raster is not refitted
                if not candidate.is_fitted or raster_filepath != candidate.raster_filepath_:
                    candidate.fit(raster_filepath)
            except Exception:
                # a failing member is logged and skipped so the next one can be tried
                logger.warning("%s failed for %s", type(candidate).__name__, raster_filepath.name, exc_info=True)
                continue

            if candidate.is_failed:
                continue

            self.__selected_strategy_ = candidate
            break

        return self
def _fit(self, raster_filepath: Path) -> Self:
    """Detect and model the top and bottom edges of ``raster_filepath``.

    The vertical detector is fitted first when it has not been fitted on this
    raster. One window per side, ``height_fraction`` of the raster height tall
    and spanning the detected width, is then read and handed to
    ``_process_side``; the results are stored under ``"top"`` and ``"bottom"``.
    """
    # Drop everything derived from a previous fit. Without this, refitting on
    # another raster would keep serving the cached transformation (and, after a
    # partial failure, a stale side result) of the earlier one.
    self._results = {}
    self.__transformation_ = None

    if not self.vertical_detector.is_fitted or raster_filepath != self.vertical_detector.raster_filepath_:
        self.vertical_detector.fit(raster_filepath)

    col_off, _ = self.vertical_detector.edges_
    window_width = self.vertical_detector.detected_width_

    with rasterio.open(raster_filepath) as src:
        window_height = int(src.height * self.height_fraction)
        # presumably resamples each window to window_height // stride rows by
        # grid_shape[0] columns — confirm against SubImage
        out_shape = (1, window_height // self.stride, self.grid_shape[0])

        for side, window in {
            "top": Window(col_off, 0, window_width, window_height),
            "bottom": Window(col_off, src.height - window_height, window_width, window_height),
        }.items():
            sub_image = SubImage(src, window, out_shape)
            self._results[side] = self._process_side(sub_image, side)

    return self
def _compute_transformation(self) -> Transformation:
    """Build the transformation that straightens the fitted top and bottom edges.

    Both edge models are sampled at ``grid_shape[0]`` columns between the
    vertical edges. Each sampled edge point is paired with a point at the same
    column on a straight row (the truncated median of that edge), and a
    thin-plate spline is estimated from the pairs. The crop offset centres the
    detected frame in the requested output size.
    """
    left_edge, right_edge = self.vertical_detector.edges_
    detected_width = self.vertical_detector.detected_width_
    output_width = self.output_width if self.output_width else detected_width

    xs = np.linspace(left_edge, right_edge, self.grid_shape[0])
    xs_column = xs.reshape(-1, 1)

    top_curve = self.top_.model.predict(xs_column)
    bottom_curve = self.bottom_.model.predict(xs_column)

    top_row = int(np.median(top_curve))
    bottom_row = int(np.median(bottom_curve))
    detected_height = bottom_row - top_row
    output_height = self.output_height if self.output_height else detected_height

    # straightened targets: every sample of an edge lands on that edge's median row
    top_flat = np.full_like(xs, top_row)
    bottom_flat = np.full_like(xs, bottom_row)

    xs_twice = np.concatenate((xs, xs))
    curved_pts = np.column_stack((xs_twice, np.concatenate((top_curve, bottom_curve))))
    flat_pts = np.column_stack((xs_twice, np.concatenate((top_flat, bottom_flat))))

    # the remap needs the inverse mapping, so the spline goes from flat to curved (important)
    deformation = ThinPlateSplineTransform().from_estimate(flat_pts, curved_pts)

    # ---- CENTERING TO OUTPUT ----
    pad_x = (output_width - detected_width) / 2
    pad_y = (output_height - detected_height) / 2
    offset = (int(left_edge - pad_x), int(top_row - pad_y))

    return Transformation(
        self.raster_filepath_,
        deformation,
        crop_offset=offset,
        output_size=(output_width, output_height),
    )
@dataclass
class VerticalEdgeResult:
    """Detected edge: global position, local rupture index, sub-image, and column-sum profile."""

    # column of the edge in raster coordinates, obtained by passing
    # ``rupture_local`` through ``sub_image.to_global``
    position: int
    # index of the first rupture found in ``profile``
    rupture_local: int
    # the window the edge was searched in
    sub_image: SubImage
    # per-column sum of the thresholded (0/1) sub-image band
    # NOTE(review): a sum of a binary image may be an integer array — confirm the floating annotation
    profile: NDArray[np.floating]
def _fit(self, raster_filepath: Path) -> "VerticalDetector":
    """Locate the left edge, then the right edge, and store both.

    The left edge is searched from column 0. The right edge is searched in a
    window starting ``search_half_width`` columns before the column where it is
    expected (left position plus the expected image width). If either edge is
    missing the detector is flagged as failed and returned early.
    """
    self._results = {}
    self._failed = False

    expected_width = KH9ImageSpec.from_raster_filepath(raster_filepath).expected_size[0]

    left_edge: VerticalEdgeResult | None = None
    right_edge: VerticalEdgeResult | None = None

    with rasterio.open(raster_filepath) as dataset:
        left_edge = self._detect_edge(dataset, col_off=0, reverse_scan=True, side="left")
        if left_edge is not None:
            self._results["left"] = left_edge

            # the right edge is only searched once the left one anchors its expected column
            right_window_start = left_edge.position + expected_width - self.search_half_width
            right_edge = self._detect_edge(dataset, col_off=right_window_start, reverse_scan=False, side="right")
            if right_edge is not None:
                self._results["right"] = right_edge

    if left_edge is None or right_edge is None:
        self._failed = True
        return self

    logger.info(
        "VerticalDetector: left=%d, right=%d, detected width=%d (expected=%d, diff=%+d px)",
        left_edge.position,
        right_edge.position,
        self.detected_width_,
        expected_width,
        self.detected_width_ - expected_width,
    )
    return self
def extract_archive(archive_path: str | Path, output_dir: str | Path, overwrite: bool = False) -> list[Path]:
    """Extract an archive into ``output_dir`` and return the sorted extracted file paths.

    Format selection, in this order:
    - file name ending in ``.zip``: ``zipfile``
    - anything ``tarfile.is_tarfile`` recognises: ``tarfile`` (``filter="data"``)
    - file name ending in ``.7z``: ``py7zr`` (optional dependency)
    - file name ending in ``.rar``: ``rarfile`` (optional dependency)

    After a successful extraction a sentinel file ``.extracted`` is created in
    ``output_dir``. When the sentinel already exists and ``overwrite`` is False,
    nothing is extracted and the current content of ``output_dir`` is listed.
    The sentinel itself is never part of the returned list.

    Raises
    ------
    ImportError
        If the optional package needed for a ``.7z`` or ``.rar`` archive is missing.
    ValueError
        If the archive format is not supported.
    """
    archive_path = Path(archive_path)
    output_dir = Path(output_dir)
    sentinel = output_dir / ".extracted"

    def _list_extracted() -> list[Path]:
        # every regular file below output_dir, except the sentinel marker
        return sorted(p for p in output_dir.rglob("*") if p.is_file() and p != sentinel)

    if sentinel.exists() and not overwrite:
        logger.info("Skipping extract_archive: %s (already exists, overwrite=False)", str(output_dir))
        return _list_extracted()

    output_dir.mkdir(parents=True, exist_ok=True)
    # remove the marker first, so an interrupted extraction is not mistaken for a complete one
    sentinel.unlink(missing_ok=True)

    lowered_name = archive_path.name.lower()

    logger.info("Start extracting %s in %s", str(archive_path), str(output_dir))
    if lowered_name.endswith(".zip"):
        with zipfile.ZipFile(archive_path) as zip_archive:
            zip_archive.extractall(output_dir)

    elif tarfile.is_tarfile(archive_path):
        with tarfile.open(archive_path) as tar_archive:
            tar_archive.extractall(output_dir, filter="data")

    elif lowered_name.endswith(".7z"):
        try:
            import py7zr
        except ImportError as e:
            raise ImportError("Install 'py7zr' to extract .7z archives: pip install py7zr") from e
        with py7zr.SevenZipFile(archive_path, mode="r") as seven_zip:
            seven_zip.extractall(output_dir)

    elif lowered_name.endswith(".rar"):
        try:
            import rarfile
        except ImportError as e:
            raise ImportError("Install 'rarfile' to extract .rar archives: pip install rarfile") from e
        with rarfile.RarFile(archive_path) as rar_archive:
            rar_archive.extractall(output_dir)

    else:
        raise ValueError(
            f"Unsupported archive format: '{archive_path.suffix}'. "
            "Supported: .zip, .tar, .tar.gz, .tgz, .tar.bz2, .tar.xz, .tar.zst, .7z, .rar"
        )

    sentinel.touch()
    logger.info("Extraction of %s finish !", str(archive_path))
    return _list_extracted()