Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions animaloc/registry/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Deployment-defaults registries for animaloc.

Reusable lookup tables that tell client code (CLI tools, notebooks)
which Stitcher, Evaluator, model_kwargs, normalization stats, etc. to
use for each registered model. The model classes themselves do NOT read
this — it's strictly a deployment / tooling concern. Keeping it out of
animaloc.models prevents accidental coupling between model code and
eval components.

Consumers:
from animaloc.registry.families import FAMILIES, resolve_family
"""

from .families import FAMILIES, ModelFamily, resolve_family

__all__ = ["FAMILIES", "ModelFamily", "resolve_family"]
197 changes: 197 additions & 0 deletions animaloc/registry/families.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""Per-model deployment defaults: which Stitcher, Evaluator, model
constructor kwargs, image normalization stats, and downsample ratio to
use for each registered model in `animaloc.models.MODELS`.

This is a *deployment* concern (used by tools/infer.py and any future
prediction tool), NOT a property of the model itself. Models never
import from here.

## Design

A `FAMILIES[name]` entry has the shape of `ModelFamily`:

stitcher: str # name of class in animaloc.eval.stitchers
evaluator: str # name of class in animaloc.eval.evaluators
model_kwargs: dict[str, Any] # constructor kwargs for the model class
down_ratio: int # output stride; threaded into transforms + stitcher
mean: list[float] # image normalization mean (RGB)
std: list[float] # image normalization std (RGB)
multi_class: bool # True if model outputs (heatmap, classmap), False if heatmap-only

## How tools should use it

The `resolve_family(name, *, checkpoint_meta=None, overrides=None)` helper
returns the effective config for a given model name. Resolution order
(later wins): family defaults -> checkpoint metadata -> explicit CLI
overrides.

## Extending

To register a new model family, add an entry to `FAMILIES` here, NOT in
`tools/infer.py`. The model itself only needs to be registered with
`@MODELS.register()` (in its own file under `animaloc.models`).
"""

from __future__ import annotations

import copy
from dataclasses import dataclass, field
from typing import Any, Optional


# ImageNet channel statistics (RGB order). Every config in this repo --
# HerdNet and all OWL variants -- normalizes with these values; the
# OWLD_* training configs for the DINOv3 backbones use them as well
# (verified against exp_dpt_vits_proj_r12_frozen.yaml,
# exp_dpt_vith_dinov3_overhead_generalized.yaml, etc.).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


@dataclass(frozen=True)
class ModelFamily:
    """Bundle of deployment defaults attached to one model family.

    `stitcher` and `evaluator` are class names given as strings. The
    remaining fields fall back to the defaults declared below; `mean`
    and `std` default to a fresh copy of the ImageNet statistics so no
    two instances share a list.
    """

    stitcher: str
    evaluator: str
    model_kwargs: dict[str, Any] = field(default_factory=dict)
    down_ratio: int = 2
    mean: list[float] = field(default_factory=lambda: [*_IMAGENET_MEAN])
    std: list[float] = field(default_factory=lambda: [*_IMAGENET_STD])
    multi_class: bool = False

    def as_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict that is safe to mutate.

        `model_kwargs` is deep-copied and `mean` / `std` are copied, so
        edits to the returned dict never reach this instance.
        """
        kwargs_copy = copy.deepcopy(self.model_kwargs)
        mean_copy = [*self.mean]
        std_copy = [*self.std]
        return dict(
            stitcher=self.stitcher,
            evaluator=self.evaluator,
            model_kwargs=kwargs_copy,
            down_ratio=self.down_ratio,
            mean=mean_copy,
            std=std_copy,
            multi_class=self.multi_class,
        )


# Registry of deployment defaults, keyed by the model's registered name.
FAMILIES: dict[str, ModelFamily] = {
    # HerdNet (legacy): multi-class model that returns (heatmap, classmap).
    "HerdNet": ModelFamily(
        stitcher="HerdNetStitcher",
        evaluator="HerdNetEvaluator",
        model_kwargs={
            "num_layers": 34,
            # Weights are loaded from the .pth checkpoint at inference time.
            "pretrained": False,
            "down_ratio": 2,
            "head_conv": 64,
        },
        down_ratio=2,
        multi_class=True,
    ),
    # OWL-C: the HerdNet detection branch on DLA-34, producing a
    # single-class FIDT heatmap.
    "OWLC": ModelFamily(
        stitcher="HerdNet_Detection_Branch_Stitcher",
        evaluator="HerdNet_Detection_Branch_Evaluator",
        model_kwargs={
            "num_layers": 34,
            "pretrained": False,
            "down_ratio": 2,
            "head_conv": 64,
        },
        down_ratio=2,
        multi_class=False,
    ),
    # OWL-T: DLA-34 with a Swin multiscale residual. The DLA base takes
    # `pretrained_cnn` here, not `pretrained`.
    "OWLT": ModelFamily(
        stitcher="HerdNet_Detection_Branch_Stitcher",
        evaluator="HerdNet_Detection_Branch_Evaluator",
        model_kwargs={
            "num_layers": 34,
            "pretrained_cnn": False,
            "down_ratio": 2,
            "head_conv": 64,
        },
        down_ratio=2,
        multi_class=False,
    ),
}

# OWL-D: DINOv3 ViT backbone (S/B/L/H) with a DPT decoder. The four
# variants differ only by class name, so they share one stitcher, one
# evaluator and one set of constructor kwargs. `pretrained` is False so
# the constructor does not try to fetch DINOv3 hub weights at inference;
# the checkpoint's state_dict supersedes them anyway.
_OWLD_DEFAULT_KWARGS = {
    "down_ratio": 2,
    "freeze_backbone": True,
    "pretrained": False,
}

for _owld_name in ("OWLD_S", "OWLD_B", "OWLD_L", "OWLD_H"):
    FAMILIES[_owld_name] = ModelFamily(
        stitcher="HerdNet_Detection_Branch_Stitcher",
        evaluator="HerdNet_Detection_Branch_Evaluator",
        # Each variant gets its own copy so entries never share a dict.
        model_kwargs=_OWLD_DEFAULT_KWARGS.copy(),
        down_ratio=2,
        multi_class=False,
    )


def resolve_family(
    name: str,
    *,
    checkpoint_meta: Optional[dict[str, Any]] = None,
    overrides: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Return the effective deployment config for the named model.

    Resolution order (later wins):
        1. `FAMILIES[name]` defaults
        2. Values pulled from the checkpoint metadata
        3. Explicit CLI overrides

    Every value placed in the result is copied, so mutating the returned
    dict never changes `FAMILIES`, `checkpoint_meta` or `overrides`.

    Args:
        name: Registered model name (must be a key of `FAMILIES`).
        checkpoint_meta: Optional dict pulled from `torch.load(pth_file)`.
            Only these keys are read: `mean`, `std`, `down_ratio`
            (each ignored when missing or None) and `classes` (passed
            through whenever the key is present). All other keys are
            ignored.
        overrides: Optional dict of CLI-driven overrides. Entries whose
            value is None are skipped. A dict under `model_kwargs` is
            merged into the family defaults rather than replacing them.
            Every other entry replaces, or adds, the key of the same
            name in the result.

    Returns:
        Plain dict with the resolved config. Always has the keys:
        `stitcher`, `evaluator`, `model_kwargs`, `down_ratio`, `mean`,
        `std`, `multi_class`. Plus `classes` when present in the
        metadata, plus any extra keys supplied through `overrides`.

    Raises:
        KeyError: if `name` is not in `FAMILIES`. The caller should
            catch this and report the available families to the user.
    """
    if name not in FAMILIES:
        raise KeyError(
            f"Unknown model family {name!r}. Known families: {sorted(FAMILIES.keys())}. "
            "Add an entry to animaloc/registry/families.py for new model classes."
        )

    # as_dict() already hands back copies of the family's mutable fields.
    resolved = FAMILIES[name].as_dict()

    # Checkpoint metadata: only the whitelisted keys below are honoured.
    # Values are deep-copied so the resolved config does not alias the
    # loaded checkpoint dict.
    if checkpoint_meta:
        for key in ("mean", "std", "down_ratio"):
            meta_value = checkpoint_meta.get(key)
            if meta_value is not None:
                resolved[key] = copy.deepcopy(meta_value)
        if "classes" in checkpoint_meta:
            resolved["classes"] = copy.deepcopy(checkpoint_meta["classes"])

    # CLI overrides. `model_kwargs` is MERGED (not replaced) so users
    # can override one kwarg without listing every default.
    if overrides:
        for key, value in overrides.items():
            if value is None:
                # None means "not specified", so the earlier layer is kept.
                continue
            if key == "model_kwargs" and isinstance(value, dict):
                merged = dict(resolved["model_kwargs"])
                merged.update(copy.deepcopy(value))
                resolved["model_kwargs"] = merged
            else:
                resolved[key] = copy.deepcopy(value)

    # NOTE(review): the top-level `down_ratio` and
    # `model_kwargs["down_ratio"]` are resolved independently, so
    # overriding only one of them leaves the two out of sync. Confirm
    # whether callers reconcile them before building the model/stitcher.
    return resolved
39 changes: 34 additions & 5 deletions docs/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,41 @@ and reports F1 / precision / recall / MAE / RMSE per class.

## Inference on new imagery

`tools/infer.py` runs the original `HerdNet` model end-to-end:
`tools/infer.py` runs any registered model on a folder of images and
writes detections to a CSV. The model class, stitcher, evaluator, and
default constructor kwargs are looked up from
`animaloc/registry/families.py` based on `--model`.

```bash
uv run python tools/infer.py <images_dir> <model.pth>
# Legacy HerdNet (default, backwards-compatible)
uv run python tools/infer.py <images_dir> <herdnet.pth>

# OWL-C
uv run python tools/infer.py <images_dir> <owlc.pth> --model OWLC

# OWL-D-L on CPU, writing results outside the input dir
uv run python tools/infer.py <images_dir> <owld_l.pth> \
--model OWLD_L -device cpu \
--output-dir /tmp/owl_l_results

# Override a constructor kwarg
uv run python tools/infer.py <images_dir> <owlt.pth> --model OWLT \
--model-kwarg down_ratio=4
```

Outputs land in `<images_dir>/<date>_HerdNet_results/<date>_detections.csv`.
For OWL-C / OWL-D / OWL-T inference, use `tools/test.py` with the
corresponding test config.
Supported `--model` values: `HerdNet`, `OWLC`, `OWLT`, `OWLD_S`,
`OWLD_B`, `OWLD_L`, `OWLD_H`. Run `uv run python tools/infer.py --help`
for the full flag set (including `--stitcher`, `--evaluator`,
`--num-classes`, `--mean`, `--std`, `--down-ratio`, `--lmds-*`,
`--output-dir`).

Outputs land in `<output_dir>/<date>_detections.csv` with columns
`images, x, y, labels, scores/dscores, ...`. A `species` column is added
when the checkpoint stores a `classes` mapping (saved automatically by
`tools/train.py`).

To register a new model family with `infer.py`, add an entry to
`animaloc/registry/families.py` rather than editing `tools/infer.py`.

## Tiling large images

Expand Down Expand Up @@ -171,6 +197,9 @@ WANDB_MODE=disabled uv run python tools/train.py train=owlc_smoketest
CKPT=$(ls -t outputs/*/*/best_model.pth | head -1 | xargs realpath)
WANDB_MODE=disabled uv run python tools/test.py test=owlc_smoketest \
"++test.model.pth_file=$CKPT"

# 5. Run inference on the same checkpoint
./tests/smoke_infer.sh
```

OWL-D variants additionally need DINOv3 weights under `weights/`; if
Expand Down
18 changes: 18 additions & 0 deletions tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ WANDB_MODE=disabled uv run python tools/train.py train=owld_s_smoketest
CKPT=$(ls -t outputs/*/*/best_model.pth | head -1 | xargs realpath)
WANDB_MODE=disabled uv run python tools/test.py test=owld_s_smoketest \
"++test.model.pth_file=$CKPT"

# 6. Inference smoke (auto-runs steps 2 and 3 if needed)
./tests/smoke_infer.sh
```

Expected runtime on CPU: ~1 min for forward-pass + dataset, ~30 s for
Expand Down Expand Up @@ -83,3 +86,18 @@ Training complete | Best f1_score: ... at epoch 1
The evaluation smoke run writes `metrics_results.csv`,
`confusion_matrix.csv`, `detections.csv`, and `plots/precision_recall_curve.png`
under `outputs/<date>/<time>/`.

## Inference smoke (`tests/smoke_infer.sh`)

Runs `tools/infer.py --model OWLC` against the synthetic val/ split,
using whichever OWL-C checkpoint is most recent under `outputs/`. If no
checkpoint exists, it runs the OWL-C training smoke first. Verifies
that:

- The detections CSV is created at `/tmp/owl-smoketest-infer/<date>_detections.csv`
- The CSV has > 0 rows
- The CSV header contains `images`, `x`, `y`, `labels` columns

Exit code is 0 on pass, non-zero on any failure. The script is bash, not
Python — it composes the existing Python entry points and is meant to
be the one-shot end-to-end "does inference work" check.
76 changes: 76 additions & 0 deletions tests/smoke_infer.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env bash
# Smoke test for tools/infer.py.
#
#   1. Ensures /tmp/owl-smoketest/ has the synthetic dataset (regenerates if not).
#   2. Ensures an OWL-C checkpoint is available (runs the OWL-C training smoke
#      if none is found under outputs/).
#   3. Runs tools/infer.py against the val/ split with --model OWLC.
#   4. Verifies the detections CSV exists, has > 0 rows, and contains the
#      expected schema columns.
#
# Designed to run on CPU, no GPU required, ~30 seconds end-to-end after
# the training smoke has already produced a checkpoint.
#
# Environment:
#   SMOKE_INFER_CKPT   optional path to an OWL-C checkpoint; when set,
#                      checkpoint discovery and training are skipped.
#
# Exit codes:
#   0        = smoke pass
#   non-zero = something is broken; the failing step prints details

set -euo pipefail

REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
cd "$REPO_ROOT"

SMOKETEST_DATA="/tmp/owl-smoketest"
OUTPUT_DIR="/tmp/owl-smoketest-infer"

# Print the newest outputs/*/*/best_model.pth whose run was launched with
# a `train=owlc*` override; print nothing when there is none.
# outputs/ may also hold checkpoints of other models (e.g. the OWL-D smoke
# run), and loading one of those with --model OWLC would fail.
# NOTE(review): assumes Hydra's default `.hydra/overrides.yaml` sits next
# to the checkpoint. If it does not, nothing matches and the caller falls
# back to a fresh training run.
find_owlc_checkpoint() {
    local ckpt
    while IFS= read -r ckpt; do
        if grep -qs 'train=owlc' "$(dirname "${ckpt}")/.hydra/overrides.yaml"; then
            printf '%s\n' "${ckpt}"
            return 0
        fi
    done < <(ls -t outputs/*/*/best_model.pth 2>/dev/null || true)
    return 0
}

echo "==> Step 1/4: ensure synthetic dataset at ${SMOKETEST_DATA}"
if [[ ! -d "${SMOKETEST_DATA}/val" ]]; then
    uv run python tests/make_synthetic_dataset.py
fi

echo "==> Step 2/4: ensure an OWL-C checkpoint exists"
CKPT="${SMOKE_INFER_CKPT:-}"
if [[ -z "${CKPT}" ]]; then
    CKPT="$(find_owlc_checkpoint)"
fi
if [[ -z "${CKPT}" ]]; then
    echo "    no OWL-C checkpoint found; running owlc_smoketest training (one epoch on CPU)"
    WANDB_MODE=disabled uv run python tools/train.py train=owlc_smoketest
    # The run that just finished is the newest one, so the head of the
    # time-sorted list is the OWL-C checkpoint even without Hydra metadata.
    CKPT="$(ls -t outputs/*/*/best_model.pth 2>/dev/null | head -1 || true)"
    if [[ -z "${CKPT}" ]]; then
        echo "    FAIL: training finished but no outputs/*/*/best_model.pth was written"
        exit 1
    fi
fi
CKPT="$(realpath "${CKPT}")"
echo "    using checkpoint: ${CKPT}"

echo "==> Step 3/4: run tools/infer.py --model OWLC"
mkdir -p "${OUTPUT_DIR}"
# Drop CSVs left by earlier runs so step 4 can only pass on fresh output.
rm -f "${OUTPUT_DIR}"/*_detections.csv
WANDB_MODE=disabled uv run python tools/infer.py \
    "${SMOKETEST_DATA}/val/" \
    "${CKPT}" \
    --model OWLC \
    -device cpu \
    --output-dir "${OUTPUT_DIR}"

echo "==> Step 4/4: verify output"
# `|| true` keeps `set -e` + `pipefail` from aborting on "no match", so
# the explicit FAIL message below is actually reachable.
CSV="$(ls -t "${OUTPUT_DIR}"/*_detections.csv 2>/dev/null | head -1 || true)"
if [[ -z "${CSV}" ]]; then
    echo "    FAIL: no detections CSV under ${OUTPUT_DIR}"
    exit 1
fi
# Count non-empty lines after the header. Unlike `wc -l`, this is right
# even when the last row has no trailing newline.
ROW_COUNT="$(awk 'NR > 1 && NF { n++ } END { print n + 0 }' "${CSV}")"
echo "    found: ${CSV} (${ROW_COUNT} detections)"

if [[ "${ROW_COUNT}" -le 0 ]]; then
    echo "    FAIL: empty detections CSV"
    exit 1
fi

# The exact column set depends on the family (HerdNet vs Detection_Branch);
# at minimum, every detection row must have images, x, y, labels.
# Wrapping the header and each name in commas makes this an exact
# column-name match, so e.g. `x` is not satisfied by some other column
# that merely contains the letter.
HEADER="$(head -1 "${CSV}" | tr -d '\r"')"
for col in images x y labels; do
    case ",${HEADER}," in
        *",${col},"*) ;;
        *)
            echo "    FAIL: column '${col}' missing from CSV header: ${HEADER}"
            exit 1
            ;;
    esac
done

echo
echo "OK: tools/infer.py smoke test passed (${ROW_COUNT} detections, schema OK)"
echo "    CSV: ${CSV}"
Loading