refactor(tools): generalize infer.py to accept any registered model#10
Draft
idchacon28 wants to merge 4 commits into
…ults
Introduces a single source of truth for per-model deployment defaults
(stitcher, evaluator, model_kwargs, normalization stats, down_ratio,
multi_class flag) that tools like `tools/infer.py` need but should not
hard-code.
Placed under `animaloc/registry/` rather than `animaloc/models/` so the
model classes themselves do not pick up an accidental dependency on
eval components.
Includes entries for the seven registered models in this repo:
HerdNet, OWLC, OWLT, OWLD_S, OWLD_B, OWLD_L, OWLD_H
The `resolve_family(name, *, checkpoint_meta, overrides)` helper
returns the effective config with this resolution order:
family defaults -> checkpoint metadata (mean/std/classes saved by
tools/train.py) -> explicit CLI overrides (with model_kwargs
merged, not replaced).
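
The resolution order above can be sketched roughly as follows. This is a minimal illustration, not the code in `animaloc/registry/families.py`: the single-entry `FAMILIES` table, its field layout, the `down_ratio` value, and the `extra_flag` kwarg are all assumptions made for the example. Only the three-step precedence and the merge of `model_kwargs` come from the description.

```python
from copy import deepcopy

# Hypothetical one-entry table; the real FAMILIES has seven entries and may
# use a different structure (e.g. dataclasses instead of dicts).
FAMILIES = {
    "HerdNet": {
        "stitcher": "HerdNetStitcher",
        "evaluator": "HerdNetEvaluator",
        "model_kwargs": {"pretrained": False},
        "mean": (0.485, 0.456, 0.406),  # ImageNet statistics
        "std": (0.229, 0.224, 0.225),
        "down_ratio": 2,  # illustrative value only
    },
}


def resolve_family(name, *, checkpoint_meta=None, overrides=None):
    """Return the effective config: defaults -> checkpoint -> overrides."""
    # 1. Family defaults (deep-copied so the table is never mutated).
    cfg = deepcopy(FAMILIES[name])

    # 2. Checkpoint metadata: only the keys saved at training time.
    for key in ("classes", "mean", "std"):
        if checkpoint_meta and checkpoint_meta.get(key) is not None:
            cfg[key] = checkpoint_meta[key]

    # 3. Explicit overrides; unset (None) values are skipped.
    for key, value in (overrides or {}).items():
        if value is None:
            continue
        if key == "model_kwargs":
            # Merge, so overriding one kwarg keeps the other defaults.
            cfg["model_kwargs"] = {**cfg["model_kwargs"], **value}
        else:
            cfg[key] = value
    return cfg
```

Merging `model_kwargs` key by key is what lets a caller change a single constructor argument without restating every default.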
Notable design choices:
* Families disable pretrained-weight fetching via `pretrained=False`
(HerdNet, OWLD_*) and `pretrained_cnn=False` (OWLT), so inference does
not re-fetch backbone weights -- the checkpoint state_dict supersedes them.
* Normalization defaults to ImageNet stats. Verified against every
OWL training config in this repo (incl. all DINOv3 ViT runs); the
user trains with ImageNet stats throughout.
Smoke-tested:
- All 7 family names resolve.
- resolve_family() with checkpoint_meta + overrides correctly merges
model_kwargs and overrides scalar fields.
- Unknown family name raises KeyError with an actionable message
listing known families and the file to edit.
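
An "actionable" unknown-family error could look something like the sketch below. The helper name `get_family` and the exact wording of the message are hypothetical; the commit only states that the error is a `KeyError` listing the known families and the file to edit.

```python
# Path taken from the PR text; the helper itself is a hypothetical sketch.
FAMILIES_FILE = "animaloc/registry/families.py"


def get_family(name, families):
    """Look up a family entry, raising a KeyError that says how to fix it."""
    try:
        return families[name]
    except KeyError:
        known = ", ".join(sorted(families))
        raise KeyError(
            f"Unknown model family {name!r}. Known families: {known}. "
            f"To register a new one, add an entry in {FAMILIES_FILE}."
        ) from None
```

Raising with `from None` hides the bare inner `KeyError`, so the user sees only the descriptive message.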
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
`tools/infer.py` was hardcoded to the legacy multi-class HerdNet model:
hardcoded `from animaloc.models import HerdNet`, hardcoded HerdNet
stitcher/evaluator, a non-recoverable `assert num_classes == 7`, and a
hardcoded `classes={1:..6:..}` dict. Single-class OWL models (OWLC,
OWLT, OWLD_S/B/L/H) could not run inference through this script -- the
docs routed users to `tools/test.py` (which needs ground-truth
annotations) as a workaround.
This commit rewrites infer.py to look up the model class, stitcher,
evaluator, default kwargs, and normalization stats from the
`animaloc.registry.families.FAMILIES` table added in the previous
commit. Any registered model that has a family entry is usable.
## New CLI surface (all flags optional except positional root + pth)
--model NAME from FAMILIES keys; default HerdNet (back-compat)
--model-kwarg KEY=VAL override a constructor kwarg (repeatable)
--stitcher NAME override family-default stitcher class
--evaluator NAME override family-default evaluator class
--num-classes N explicit override (HerdNet without metadata only)
--mean R,G,B override normalization mean
--std R,G,B override normalization std
--down-ratio N override down_ratio
--lmds-kernel-size H,W LMDS kernel (default 3,3)
--lmds-adapt-ts FLOAT LMDS adaptive threshold (default 0.2)
--lmds-neg-ts FLOAT LMDS negative threshold (HerdNet family only)
--output-dir PATH default <root>/<date>_<model>_results
-size -over -device -pf -rot --skip-model-inference (unchanged)
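
As a rough illustration of how the repeatable `--model-kwarg KEY=VAL` flag and the `R,G,B` triplets might be parsed with `argparse`, consider the sketch below. The helper names (`parse_kwarg`, `parse_triplet`, `build_parser`) and the choice of `ast.literal_eval` for value coercion are assumptions, not a description of the actual `tools/infer.py`; only a subset of the flags is shown.

```python
import argparse
import ast


def parse_kwarg(text):
    """Turn 'KEY=VAL' into (key, value), coercing VAL to a Python literal."""
    key, sep, raw = text.partition("=")
    if not sep or not key:
        raise argparse.ArgumentTypeError(f"expected KEY=VAL, got {text!r}")
    try:
        value = ast.literal_eval(raw)  # 'False' -> False, '3' -> 3
    except (ValueError, SyntaxError):
        value = raw  # anything else stays a plain string
    return key, value


def parse_triplet(text):
    """Turn 'R,G,B' into a tuple of three floats."""
    parts = tuple(float(p) for p in text.split(","))
    if len(parts) != 3:
        raise argparse.ArgumentTypeError(f"expected R,G,B, got {text!r}")
    return parts


def build_parser():
    parser = argparse.ArgumentParser(description="inference CLI sketch")
    parser.add_argument("root")
    parser.add_argument("pth")
    parser.add_argument("--model", default="HerdNet")
    parser.add_argument("--model-kwarg", action="append", default=[],
                        type=parse_kwarg, metavar="KEY=VAL")
    parser.add_argument("--mean", type=parse_triplet, default=None,
                        metavar="R,G,B")
    parser.add_argument("--std", type=parse_triplet, default=None,
                        metavar="R,G,B")
    return parser
```

Leaving `--mean` and `--std` at `None` when they are not passed makes it easy to tell "not overridden" apart from an explicit value during resolution.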
## Resolution order
For each setting: FAMILIES[name] defaults -> checkpoint metadata
(`classes`, `mean`, `std`) -> explicit CLI override. `model_kwargs`
is merged rather than replaced so users override one kwarg without
listing every default.
## Behavior changes that matter
* Output dir is configurable (`--output-dir`) and the folder name now
includes the model (e.g. `20260605_OWLC_results`), not the hardcoded
`_HerdNet_results`. Default location is unchanged for HerdNet.
* `assert num_classes == 7` is gone. For HerdNet without `classes`
metadata, the layer-shape probe (`model.cls_head.2.weight`) is kept
as a last-resort fallback. For OWL families, num_classes is not
passed (the constructor doesn't accept it).
* `state_dict` loading is now `strict=False` with explicit
missing/unexpected key warnings. Catches partial-load checkpoints
without crashing immediately, but tells the user what happened.
* `.map(classes) + .dropna()` chain is gone. Detection rows whose
label is unmapped now keep the raw label (as string) in `species`
and emit a single warning listing unmapped labels.
* `pretrained=False` is set in every family's `model_kwargs` so the
constructor never re-fetches DINOv3 or DLA-34 weights at inference
time (the checkpoint's state_dict supersedes them).
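
Two of the behavior changes above can be illustrated without PyTorch or pandas. The sketch below is a plain-Python approximation under stated assumptions: `report_key_mismatch` mimics the missing/unexpected key lists that PyTorch's `load_state_dict(..., strict=False)` returns, and `species_for` mimics the "keep the raw label as a string, warn once" rule. Both function names and the warning texts are hypothetical.

```python
import warnings


def report_key_mismatch(model_keys, checkpoint_keys):
    """Warn about keys present on only one side, instead of crashing."""
    missing = sorted(set(model_keys) - set(checkpoint_keys))
    unexpected = sorted(set(checkpoint_keys) - set(model_keys))
    if missing:
        warnings.warn(f"{len(missing)} missing key(s) in state_dict: {missing}")
    if unexpected:
        warnings.warn(
            f"{len(unexpected)} unexpected key(s) in state_dict: {unexpected}"
        )
    return missing, unexpected


def species_for(labels, classes):
    """Map labels to species names; unmapped labels keep their raw value."""
    species = [classes.get(label, str(label)) for label in labels]
    unmapped = sorted({l for l in labels if l not in classes}, key=str)
    if unmapped:
        # A single warning for the whole batch, not one per row.
        warnings.warn(f"unmapped labels kept as raw strings: {unmapped}")
    return species
```

Compared with a `.map(classes)` followed by `.dropna()`, no detection rows are silently discarded when a label has no entry in `classes`.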
## Sanity checks added at startup
* One-shot dummy forward on `torch.zeros(1, 3, size, size)` to detect
model/stitcher shape mismatches early with a clear error instead of
a deep tuple-unpack failure inside LMDS.
* Unwrap `LossWrapper`'s `(output, output_dict)` and ignore `None`
entries in tuple outputs (e.g. OWLD_S returns `(heatmap, None)`)
before counting outputs.
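
The output-normalization step of that startup check might look roughly like the sketch below. It uses plain Python objects as stand-ins for tensors so it runs without PyTorch. The function names, the rule "a 2-tuple whose second element is a dict is a `LossWrapper`-style return", and the error wording are all assumptions made for illustration.

```python
def normalize_outputs(result):
    """Flatten a model's forward() result into a list of real outputs."""
    # Assumed LossWrapper-style return: (output, output_dict) -> output.
    if (isinstance(result, tuple) and len(result) == 2
            and isinstance(result[1], dict)):
        result = result[0]
    # Tuple/list outputs: drop None placeholders, e.g. (heatmap, None).
    if isinstance(result, (tuple, list)):
        return [item for item in result if item is not None]
    # Single output.
    return [result]


def check_outputs(result, expected_count):
    """Fail early with a readable message if the output count is wrong."""
    outputs = normalize_outputs(result)
    if len(outputs) != expected_count:
        raise RuntimeError(
            f"model returned {len(outputs)} output(s) but the stitcher "
            f"expects {expected_count}; check --model and --stitcher"
        )
    return outputs
```

In a real run, `result` would come from a forward pass on `torch.zeros(1, 3, size, size)`, and `expected_count` would come from the selected family.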
## Smoke validation
* `tools/infer.py /tmp/owl-smoketest/val/ <OWLC_ckpt> --model OWLC
-device cpu` produces a 1856-row detections.csv with columns
`images, labels, dscores, x, y, count_1, species`.
* `tools/infer.py /tmp/owl-smoketest/val/ <OWLD_S_ckpt>
--model OWLD_S -device cpu` produces 2270 rows via the DINOv3
ViT-S/16 frozen backbone. End-to-end DINOv3 + animaloc inference
works.
* `tools/infer.py /tmp/owl-smoketest/val/ <OWLC_ckpt>` (no --model,
defaults to HerdNet against an OWL checkpoint) fails cleanly with
"4 missing key(s) in state_dict" warning + LMDS shape error.
## Deferred to follow-up
* The Evaluator path is still used as a wrapper because it already
implements stitching + LMDS. Ground-truth values are dummy (x=0,
y=0, label=1) and metrics are discarded. A future PR can factor out
a pure inference function that does not go through Evaluator.
* Adding `model_name` / `stitcher_name` / `evaluator_name` to the
checkpoint metadata in `tools/train.py` so `--model` becomes fully
auto-detected. Today the user still has to pass `--model` for
non-HerdNet checkpoints.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The OWL-C training and evaluation smoke runs were already in
tests/README.md. This commit adds the inference smoke run that
exercises the newly generalized tools/infer.py.
`tests/smoke_infer.sh`:
1. Generates the synthetic dataset at /tmp/owl-smoketest/ if missing
2. Runs the OWL-C training smoke if no checkpoint exists under outputs/
3. Runs `tools/infer.py /tmp/owl-smoketest/val/ <ckpt> --model OWLC
-device cpu --output-dir /tmp/owl-smoketest-infer/`
4. Verifies the detections CSV exists, has > 0 rows, and contains
the columns `images, x, y, labels`
Exit code 0 on pass, non-zero on any failure. Runs in ~30 seconds on
CPU when the training smoke has already produced a checkpoint, or
~90 seconds if it has to train one first.
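
The verification in step 4 is done by a shell script, but the same check can be sketched in Python for clarity. The function name `verify_detections` and the error messages are hypothetical; the required columns are the ones listed above.

```python
import csv
from pathlib import Path

# Columns the smoke test is described as checking for.
REQUIRED_COLUMNS = {"images", "x", "y", "labels"}


def verify_detections(csv_path):
    """Return the row count, or raise if the detections CSV is unusable."""
    path = Path(csv_path)
    if not path.is_file():
        raise FileNotFoundError(f"detections CSV not found: {path}")
    with path.open(newline="") as handle:
        reader = csv.DictReader(handle)
        missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"missing column(s): {sorted(missing)}")
        n_rows = sum(1 for _ in reader)
    if n_rows == 0:
        raise ValueError("detections CSV has a header but no rows")
    return n_rows
```

A caller would turn any raised exception into a non-zero exit code to match the script's pass/fail contract.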
`tests/README.md`:
- Adds step 6 (./tests/smoke_infer.sh) to the smoke-test sequence
- Documents what the inference smoke script does and what it verifies
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Replaces the previous "tools/infer.py runs the original HerdNet model
end-to-end ... For OWL-C / OWL-D / OWL-T inference, use tools/test.py"
note with the real story: tools/infer.py now accepts --model <name>
for any registered model, including the OWL family.
Adds:
- Per-model invocation examples (HerdNet default, OWLC, OWLD_L,
OWLT with --model-kwarg override)
- List of supported --model values
- Pointer to `uv run python tools/infer.py --help` for the full
flag set
- Pointer to `animaloc/registry/families.py` for adding new model
families
Also updates the "Verifying the install (smoke tests)" section to
include `./tests/smoke_infer.sh` as step 5.
Verified `uv run mkdocs build --strict` still succeeds.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@idchacon28 please read the following Contributor License Agreement (CLA). If you agree with the CLA, please reply with the following information.
Contributor License Agreement
Contribution License Agreement
This Contribution License Agreement (“Agreement”) is agreed to by the party signing below (“You”),
# Generalize `tools/infer.py` to accept any registered model

## Why

`tools/infer.py` was hardcoded to the legacy multi-class HerdNet model:

- `from animaloc.models import LossWrapper, HerdNet`
- `from animaloc.eval import HerdNetStitcher, HerdNetEvaluator`
- `assert num_classes == 7` (literal)
- `classes = {1:'class_1', ..., 6:'class_6'}`

Result: single-class OWL models (OWLC, OWLT, OWLD_S/B/L/H) could not run inference through this script. The docs routed users to `tools/test.py`, which requires ground-truth annotations and is therefore useless for actual inference on new imagery.

## What
Two commits do the work, plus tests and docs:

1. `feat(registry): add animaloc/registry/families.py`: a `FAMILIES` table mapping each registered model to its stitcher, evaluator, default `model_kwargs`, normalization stats, `down_ratio`, and `multi_class` flag, plus a `resolve_family(name, *, checkpoint_meta, overrides)` helper. Lives under `animaloc/registry/` so any future tool can import it without dragging eval deps into `animaloc.models`.
2. `refactor(tools): generalize infer.py`: … `--model HerdNet`).
3. `test(smoke): add tests/smoke_infer.sh`: … `tools/infer.py --model OWLC`, and verifies the CSV schema.
4. `docs(infer): update training.md`

## New CLI surface
Full flag set: `--model`, `--model-kwarg KEY=VAL` (repeatable), `--stitcher`, `--evaluator`, `--num-classes`, `--mean`, `--std`, `--down-ratio`, `--lmds-kernel-size`, `--lmds-adapt-ts`, `--lmds-neg-ts`, `--output-dir`. Plus the existing `-size -over -device -pf -rot --skip-model-inference`.

## Resolution order (per setting)
1. `FAMILIES[name]` defaults
2. Checkpoint metadata (`classes`, `mean`, `std` saved by `tools/train.py`)
3. Explicit CLI overrides

`model_kwargs` is merged rather than replaced, so users override one kwarg without listing every default.

## Behavior changes that matter
* Output dir is configurable (`--output-dir`) and the default folder name now mentions the model (`20260605_OWLC_results`, not the hardcoded `_HerdNet_results`). HerdNet default location is unchanged.
* `assert num_classes == 7` is gone. For HerdNet without `classes` metadata, the layer-shape probe (`model.cls_head.2.weight`) is kept as a last-resort fallback. OWL families don't take `num_classes`.
* `state_dict` loading is now `strict=False` with explicit missing/unexpected key warnings (catches partial-load checkpoints; tells the user what happened).
* `.map(classes) + .dropna()` chain is gone. Detection rows whose label is unmapped now keep the raw label (as string) in `species` and emit one warning listing unmapped labels.
* `pretrained=False` is set in every family's `model_kwargs` so the constructor never re-fetches DINOv3 or DLA-34 weights at inference time. The checkpoint's `state_dict` supersedes them.

## Sanity checks added at startup
* One-shot dummy forward on `torch.zeros(1, 3, size, size)` to detect model/stitcher shape mismatches early. Unwraps `LossWrapper.forward`'s `(output, output_dict)` tuple and ignores `None` entries in tuple outputs (e.g. OWLD_S returns `(heatmap, None)`).

## Smoke validation (CPU, Python 3.11)
* `tools/infer.py val/ <OWLC_ckpt> --model OWLC -device cpu`
* `tools/infer.py val/ <OWLD_S_ckpt> --model OWLD_S -device cpu`
* `tools/infer.py val/ <OWLC_ckpt>` (default HerdNet against OWL ckpt)
* `./tests/smoke_infer.sh`
* `uv run mkdocs build --strict`

## Deferred to follow-up PRs
* … `animaloc/eval/evaluators.py` (the whole file is pasted twice; verified the first-half classes are operationally bound, but it's a mess).
* Adding `model_name` / `stitcher_name` / `evaluator_name` to `tools/train.py`'s checkpoint metadata so `--model` could become fully auto-detected.
* … `Evaluator` with dummy ground truth.

## Backwards compatibility
* `tools/infer.py images/ ckpt.pth`: … `--model HerdNet`.
* `-size 512 -over 160 -device cuda` (etc.)
* `_HerdNet_results` output folder for HerdNet: … `_HerdNet_` since `--model` defaults to HerdNet).

The only intentional break is the removal of `assert num_classes == 7`. Any legacy HerdNet checkpoint with a different `num_classes` will now actually load.

## Branch
`feat/generalize-infer`, 4 commits, all conventional-commits style:

## How to test locally