diff --git a/tests/models/conftest.py b/tests/models/conftest.py new file mode 100644 index 00000000..2c94dade --- /dev/null +++ b/tests/models/conftest.py @@ -0,0 +1,137 @@ +# Copyright (c) 2026 Justin Davis (davisjustin302@gmail.com) +# +# MIT License +"""Shared configuration and helpers for model correctness tests.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any + +BASE_DIR = Path(__file__).parent.parent.parent +DATA_DIR = BASE_DIR / "data" + +# --------------------------------------------------------------------------- +# Ground-truth expectations +# --------------------------------------------------------------------------- +DETECTOR_EXPECTED: list[dict[str, Any]] = [ + { + "image": "data/horse.jpg", + "expected_classes": [17], # COCO horse + "min_detections": 1, + "conf_thres": 0.3, + }, +] + +CLASSIFIER_EXPECTED: list[dict[str, Any]] = [ + { + "image": "data/horse.jpg", + "expected_top_k_classes": [339], # ImageNet "sorrel" + "top_k": 5, + }, +] + +# --------------------------------------------------------------------------- +# Model configurations +# --------------------------------------------------------------------------- +# Each entry: (model_class_name, model_name_for_download, imgsz_or_None) +# Class names are strings resolved at runtime to avoid import-time dependency +# on TensorRT. +DETECTOR_MODELS: dict[str, tuple[str, str, int | None]] = { + "yolov10": ("YOLOv10", "yolov10n", None), + "yolov8": ("YOLOv8", "yolov8n", None), + "yolov11": ("YOLOv11", "yolov11n", None), + "rtdetrv1": ("RTDETRv1", "rtdetrv1_r18", None), + "rtdetrv3": ("RTDETRv3", "rtdetrv3_r18", None), + "dfine": ("DFINE", "dfine_n", None), + "rfdetr": ("RFDETR", "rfdetr_n", 384), +} + +CLASSIFIER_MODELS: dict[str, tuple[str, str, int | None]] = { + "resnet18": ("ResNet", "resnet18", None), + "efficientnet_b0": ("EfficientNet", "efficientnet_b0", None), + "mobilenet_v3_small": ("MobileNetV3", "mobilenet_v3_small", None), +} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _resolve_model_class(class_name: str) -> type: + """Resolve a model class name to the actual class from trtutils.models.""" + import trtutils.models as models + + cls = getattr(models, class_name, None) + if cls is None: + msg = f"Unknown model class: {class_name}" + raise ValueError(msg) + return cls + + +def build_model_engine( + model_class_name: str, + model_name: str, + imgsz: int | None, + cache_dir: Path | None = None, +) -> Path: + """ + Download ONNX (if needed) and build a TRT engine (if needed). + + Engines are cached under ``data/engines//``. + + Parameters + ---------- + model_class_name : str + Name of the model class in ``trtutils.models``. + model_name : str + The model variant to download (e.g. ``"yolov10n"``). + imgsz : int | None + Image size override; ``None`` uses the class default. + cache_dir : Path | None + Override for the cache root. Defaults to ``data/engines/``. + + Returns + ------- + Path + Path to the compiled TensorRT engine. + + """ + model_class = _resolve_model_class(model_class_name) + + if cache_dir is None: + cache_dir = DATA_DIR / "engines" + + model_dir = cache_dir / model_name + model_dir.mkdir(parents=True, exist_ok=True) + + onnx_dir = DATA_DIR / model_name + onnx_dir.mkdir(parents=True, exist_ok=True) + onnx_path = onnx_dir / f"{model_name}.onnx" + engine_path = model_dir / f"{model_name}.engine" + + # Return early if engine already exists + if engine_path.exists(): + return engine_path + + # Download ONNX if it doesn't exist + if not onnx_path.exists(): + download_kwargs: dict[str, Any] = { + "model": model_name, + "output": onnx_path, + } + if imgsz is not None: + download_kwargs["imgsz"] = imgsz + model_class.download(**download_kwargs) + + # Build engine + build_kwargs: dict[str, Any] = { + "onnx": onnx_path, + "output": engine_path, + "opt_level": 1, + "verbose": False, + } + if imgsz is not None: + build_kwargs["imgsz"] = imgsz + model_class.build(**build_kwargs) + + return engine_path diff --git a/tests/models/test_classifier_correctness.py b/tests/models/test_classifier_correctness.py new file mode 100644 index 00000000..ed2e594a --- /dev/null +++ b/tests/models/test_classifier_correctness.py @@ -0,0 +1,412 @@ +# Copyright (c) 2026 Justin Davis (davisjustin302@gmail.com) +# +# MIT License +"""Classifier output correctness tests (GPU required).""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +from .conftest import ( + CLASSIFIER_EXPECTED, + CLASSIFIER_MODELS, + _resolve_model_class, + build_model_engine, +) + +BASE_DIR = Path(__file__).parent.parent.parent +DATA_DIR = BASE_DIR / "data" +SIMPLE_ONNX = DATA_DIR / "simple.onnx" + + +# --------------------------------------------------------------------------- +# Helpers -- build a minimal classifier-shaped engine +# --------------------------------------------------------------------------- +def _find_classifier_onnx() -> Path: + """Find an available classifier ONNX model in the data directory.""" + # simple.onnx is always available for basic tests + if SIMPLE_ONNX.exists(): + return SIMPLE_ONNX + pytest.skip("No classifier ONNX model available") + # unreachable but keeps type checkers happy + return SIMPLE_ONNX + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture(scope="module") +def classifier_engine(build_test_engine) -> Path: + """Build and cache a classifier engine for the test module.""" + onnx_path = _find_classifier_onnx() + return build_test_engine(onnx_path) + + +@pytest.fixture(scope="module") +def classifier_engine_batch(build_test_engine) -> Path: + """Build a batch-capable classifier engine for batch tests.""" + onnx_path = _find_classifier_onnx() + try: + return build_test_engine(onnx_path, batch_size=2) + except Exception as exc: + pytest.skip(f"Batch classifier engine unavailable: {exc}") + + +# --------------------------------------------------------------------------- +# Classifier base class tests (using Classifier directly) +# --------------------------------------------------------------------------- +class TestClassifierOutput: + """Tests for Classifier output format using a simple engine.""" + + def test_classifier_run_returns_list( + self, + classifier_engine, + images, + ) -> None: + """Classifier.run() should return a list of ndarrays.""" + horse_image = images["horse"].array + from trtutils.image import Classifier + + cls = Classifier(classifier_engine, warmup=False) + outputs = cls.run(horse_image) + assert isinstance(outputs, list) + for arr in outputs: + assert isinstance(arr, np.ndarray) + + def test_classifier_run_raw( + self, + classifier_engine, + images, + ) -> None: + """run(postprocess=False) returns raw output ndarrays.""" + horse_image = images["horse"].array + from trtutils.image import Classifier + + cls = Classifier(classifier_engine, warmup=False) + outputs = cls.run(horse_image, postprocess=False) + assert isinstance(outputs, list) + for arr in outputs: + assert isinstance(arr, np.ndarray) + + def test_classifier_postprocess_format( + self, + classifier_engine, + images, + ) -> None: + """postprocess() should return list[list[ndarray]].""" + horse_image = images["horse"].array + from trtutils.image import Classifier + + cls = Classifier(classifier_engine, warmup=False) + raw = cls.run(horse_image, postprocess=False) + postprocessed = cls.postprocess(raw) + assert isinstance(postprocessed, list) + for per_image in postprocessed: + assert isinstance(per_image, list) + for arr in per_image: + assert isinstance(arr, np.ndarray) + + def test_get_classifications_format( + self, + classifier_engine, + images, + ) -> None: + """get_classifications returns list[tuple[int, float]].""" + horse_image = images["horse"].array + from trtutils.image import Classifier + + cls = Classifier(classifier_engine, warmup=False) + outputs = cls.run(horse_image) + classifications = cls.get_classifications(outputs, top_k=5) + assert isinstance(classifications, list) + assert len(classifications) <= 5 + for entry in classifications: + assert isinstance(entry, tuple) + assert len(entry) == 2 + cls_id, score = entry + assert isinstance(cls_id, (int, np.integer)) + assert isinstance(score, (float, np.floating)) + + def test_get_classifications_top_k( + self, + classifier_engine, + images, + ) -> None: + """top_k parameter should limit number of results.""" + horse_image = images["horse"].array + from trtutils.image import Classifier + + cls = Classifier(classifier_engine, warmup=False) + outputs = cls.run(horse_image) + top1 = cls.get_classifications(outputs, top_k=1) + top3 = cls.get_classifications(outputs, top_k=3) + assert len(top1) <= 1 + assert len(top3) <= 3 + + def test_get_classifications_scores_descending( + self, + classifier_engine, + images, + ) -> None: + """Classification scores should be in descending order.""" + horse_image = images["horse"].array + from trtutils.image import Classifier + + cls = Classifier(classifier_engine, warmup=False) + outputs = cls.run(horse_image) + classifications = cls.get_classifications(outputs, top_k=10) + if len(classifications) > 1: + scores = [float(s) for _, s in classifications] + for i in range(len(scores) - 1): + assert scores[i] >= scores[i + 1] - 1e-6 + + def test_callable_matches_run( + self, + classifier_engine, + images, + ) -> None: + """__call__ should produce the same result as run().""" + horse_image = images["horse"].array + from trtutils.image import Classifier + + cls = Classifier(classifier_engine, warmup=False) + out_run = cls.run(horse_image) + out_call = cls(horse_image) + assert len(out_run) == len(out_call) + + +# --------------------------------------------------------------------------- +# Batch tests +# --------------------------------------------------------------------------- +class TestClassifierBatch: + """Batch inference tests.""" + + def test_batch_run_returns_nested( + self, + classifier_engine_batch, + random_images, + ) -> None: + """Batch run with postprocess returns list[list[ndarray]].""" + from trtutils.image import Classifier + + cls = Classifier(classifier_engine_batch, warmup=False) + imgs = random_images(count=2, height=480, width=640) + outputs = cls.run(imgs) + assert isinstance(outputs, list) + assert len(outputs) == 2 + for per_image in outputs: + assert isinstance(per_image, list) + + def test_batch_get_classifications( + self, + classifier_engine_batch, + random_images, + ) -> None: + """Batch get_classifications returns list[list[tuple]].""" + from trtutils.image import Classifier + + cls = Classifier(classifier_engine_batch, warmup=False) + imgs = random_images(count=2, height=480, width=640) + outputs = cls.run(imgs) + results = cls.get_classifications(outputs, top_k=3) + assert isinstance(results, list) + assert len(results) == 2 + for per_image_cls in results: + assert isinstance(per_image_cls, list) + + +# --------------------------------------------------------------------------- +# End-to-end tests +# --------------------------------------------------------------------------- +class TestClassifierEnd2End: + """End-to-end inference tests.""" + + def test_end2end_single_image( + self, + classifier_engine, + images, + ) -> None: + """end2end() on single image returns list[tuple[int, float]].""" + horse_image = images["horse"].array + from trtutils.image import Classifier + + cls = Classifier(classifier_engine, warmup=False) + result = cls.end2end(horse_image, top_k=5) + assert isinstance(result, list) + for entry in result: + assert isinstance(entry, tuple) + assert len(entry) == 2 + + +# --------------------------------------------------------------------------- +# Data-driven multi-model classifier correctness tests +# --------------------------------------------------------------------------- +_CLASSIFIER_MODEL_IDS = list(CLASSIFIER_MODELS.keys()) + + +def _make_cls_expected_ids() -> list[str]: + """Create human-readable IDs for CLASSIFIER_EXPECTED entries.""" + return [Path(e["image"]).stem for e in CLASSIFIER_EXPECTED] + + +_CLS_EXPECTED_IDS = _make_cls_expected_ids() + + +@pytest.mark.correctness +@pytest.mark.download +class TestClassifierImageCorrectness: + """Data-driven correctness: classifiers must find expected classes.""" + + @pytest.mark.parametrize( + "expected_entry", + CLASSIFIER_EXPECTED, + ids=_CLS_EXPECTED_IDS, + ) + @pytest.mark.parametrize("model_id", _CLASSIFIER_MODEL_IDS) + def test_image_correctness( + self, + expected_entry: dict, + model_id: str, + ) -> None: + """At least one expected class appears in top-k predictions.""" + import cv2 + + cls_name, model_name, imgsz = CLASSIFIER_MODELS[model_id] + engine_path = build_model_engine(cls_name, model_name, imgsz) + + model_class = _resolve_model_class(cls_name) + classifier = model_class( + engine_path, + warmup=False, + no_warn=True, + ) + + image_path = str(BASE_DIR / expected_entry["image"]) + image = cv2.imread(image_path) + if image is None: + pytest.skip(f"Image not found: {image_path}") + + top_k = expected_entry["top_k"] + predictions = classifier.end2end(image, top_k=top_k) + + predicted_classes = [int(cls_id) for cls_id, _score in predictions] + has_match = any(c in expected_entry["expected_top_k_classes"] for c in predicted_classes) + assert has_match, ( + f"{model_id}: none of {predicted_classes} match " + f"expected {expected_entry['expected_top_k_classes']}" + ) + + del classifier + + +@pytest.mark.correctness +@pytest.mark.download +class TestClassifierOutputFormat: + """Validate classifier output format across all models.""" + + @pytest.mark.parametrize( + "expected_entry", + CLASSIFIER_EXPECTED, + ids=_CLS_EXPECTED_IDS, + ) + @pytest.mark.parametrize("model_id", _CLASSIFIER_MODEL_IDS) + def test_output_format( + self, + expected_entry: dict, + model_id: str, + ) -> None: + """Output is list[tuple[int, float]], scores in [0,1], descending.""" + import cv2 + + cls_name, model_name, imgsz = CLASSIFIER_MODELS[model_id] + engine_path = build_model_engine(cls_name, model_name, imgsz) + + model_class = _resolve_model_class(cls_name) + classifier = model_class( + engine_path, + warmup=False, + no_warn=True, + ) + + image_path = str(BASE_DIR / expected_entry["image"]) + image = cv2.imread(image_path) + if image is None: + pytest.skip(f"Image not found: {image_path}") + + top_k = expected_entry["top_k"] + predictions = classifier.end2end(image, top_k=top_k) + + # Must be a list + assert isinstance(predictions, list) + + # Each entry is (int, float) + for entry in predictions: + assert isinstance(entry, tuple), f"Expected tuple, got {type(entry)}" + assert len(entry) == 2 + cls_id, score = entry + assert isinstance(cls_id, (int, np.integer)) + assert isinstance(score, (float, np.floating)) + assert 0.0 <= float(score) <= 1.0, f"{model_id}: score {score} out of [0, 1]" + + # Scores in descending order + if len(predictions) > 1: + scores = [float(s) for _, s in predictions] + for i in range(len(scores) - 1): + assert scores[i] >= scores[i + 1] - 1e-6, ( + f"{model_id}: scores not descending at index {i}: {scores[i]} < {scores[i + 1]}" + ) + + del classifier + + +@pytest.mark.correctness +@pytest.mark.download +class TestClassifierPreprocessorConsistency: + """Verify cpu/cuda/trt preprocessors agree on top-1 class.""" + + @pytest.mark.parametrize( + "expected_entry", + CLASSIFIER_EXPECTED, + ids=_CLS_EXPECTED_IDS, + ) + @pytest.mark.parametrize("model_id", _CLASSIFIER_MODEL_IDS) + def test_preprocessor_consistency( + self, + expected_entry: dict, + model_id: str, + ) -> None: + """All preprocessors produce the same top-1 class.""" + import cv2 + + cls_name, model_name, imgsz = CLASSIFIER_MODELS[model_id] + engine_path = build_model_engine(cls_name, model_name, imgsz) + + model_class = _resolve_model_class(cls_name) + + image_path = str(BASE_DIR / expected_entry["image"]) + image = cv2.imread(image_path) + if image is None: + pytest.skip(f"Image not found: {image_path}") + + preprocessors = ["cpu", "cuda", "trt"] + top1_classes: dict[str, int] = {} + + for preproc in preprocessors: + classifier = model_class( + engine_path, + preprocessor=preproc, + warmup=False, + no_warn=True, + ) + predictions = classifier.end2end(image, top_k=1) + if predictions: + top1_classes[preproc] = int(predictions[0][0]) + del classifier + + unique_classes = set(top1_classes.values()) + assert len(unique_classes) == 1, ( + f"{model_id}: preprocessors disagree on top-1: {top1_classes}" + ) diff --git a/tests/models/test_depth_estimator_correctness.py b/tests/models/test_depth_estimator_correctness.py new file mode 100644 index 00000000..5248b003 --- /dev/null +++ b/tests/models/test_depth_estimator_correctness.py @@ -0,0 +1,225 @@ +# Copyright (c) 2026 Justin Davis (davisjustin302@gmail.com) +# +# MIT License +"""DepthEstimator output correctness tests (GPU required).""" + +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import pytest + +BASE_DIR = Path(__file__).parent.parent.parent +DATA_DIR = BASE_DIR / "data" +SIMPLE_ONNX = DATA_DIR / "simple.onnx" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _find_depth_onnx() -> Path: + """Find an available model ONNX for depth estimator testing.""" + # Use simple.onnx as a fallback -- the postprocessor will still run, + # though output semantics differ from a real depth model. + if SIMPLE_ONNX.exists(): + return SIMPLE_ONNX + pytest.skip("No ONNX model available for depth estimator test") + return SIMPLE_ONNX + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture(scope="module") +def depth_engine(build_test_engine) -> Path: + """Build and cache a depth estimator engine for the test module.""" + onnx_path = _find_depth_onnx() + return build_test_engine(onnx_path) + + +@pytest.fixture(scope="module") +def depth_engine_batch(build_test_engine) -> Path: + """Build a batch-capable depth engine for batch tests.""" + onnx_path = _find_depth_onnx() + try: + return build_test_engine(onnx_path, batch_size=2) + except Exception as exc: + pytest.skip(f"Batch depth engine unavailable: {exc}") + + +# --------------------------------------------------------------------------- +# DepthEstimator base class tests +# --------------------------------------------------------------------------- +class TestDepthEstimatorOutput: + """Tests for DepthEstimator output format.""" + + def test_run_returns_list( + self, + depth_engine, + images, + ) -> None: + """DepthEstimator.run() should return a list of ndarrays.""" + horse_image = images["horse"].array + from trtutils.image import DepthEstimator + + de = DepthEstimator(depth_engine, warmup=False) + outputs = de.run(horse_image) + assert isinstance(outputs, list) + for arr in outputs: + assert isinstance(arr, np.ndarray) + + def test_run_raw_returns_ndarrays( + self, + depth_engine, + images, + ) -> None: + """run(postprocess=False) returns raw output ndarrays.""" + horse_image = images["horse"].array + from trtutils.image import DepthEstimator + + de = DepthEstimator(depth_engine, warmup=False) + outputs = de.run(horse_image, postprocess=False) + assert isinstance(outputs, list) + for arr in outputs: + assert isinstance(arr, np.ndarray) + + def test_postprocess_format( + self, + depth_engine, + images, + ) -> None: + """postprocess() should return list[list[ndarray]].""" + horse_image = images["horse"].array + from trtutils.image import DepthEstimator + + de = DepthEstimator(depth_engine, warmup=False) + raw = de.run(horse_image, postprocess=False) + postprocessed = de.postprocess(raw) + assert isinstance(postprocessed, list) + for per_image in postprocessed: + assert isinstance(per_image, list) + for arr in per_image: + assert isinstance(arr, np.ndarray) + + def test_get_depth_maps_single_image( + self, + depth_engine, + images, + ) -> None: + """get_depth_maps for single image returns an ndarray.""" + horse_image = images["horse"].array + from trtutils.image import DepthEstimator + + de = DepthEstimator(depth_engine, warmup=False) + outputs = de.run(horse_image) + depth_map = de.get_depth_maps(outputs) + assert isinstance(depth_map, np.ndarray) + # Depth maps are typically (1, H, W) or (H, W) + assert depth_map.ndim >= 2 + + def test_depth_map_values_finite( + self, + depth_engine, + images, + ) -> None: + """Depth map values should be finite (no NaN/Inf).""" + horse_image = images["horse"].array + from trtutils.image import DepthEstimator + + de = DepthEstimator(depth_engine, warmup=False) + outputs = de.run(horse_image) + depth_map = de.get_depth_maps(outputs) + assert np.all(np.isfinite(depth_map)) + + def test_callable_matches_run( + self, + depth_engine, + images, + ) -> None: + """__call__ should produce the same result as run().""" + horse_image = images["horse"].array + from trtutils.image import DepthEstimator + + de = DepthEstimator(depth_engine, warmup=False) + out_run = de.run(horse_image) + out_call = de(horse_image) + assert len(out_run) == len(out_call) + + +# --------------------------------------------------------------------------- +# Batch tests +# --------------------------------------------------------------------------- +class TestDepthEstimatorBatch: + """Batch inference tests.""" + + def test_batch_run_returns_nested( + self, + depth_engine_batch, + random_images, + ) -> None: + """Batch run with postprocess returns list[list[ndarray]].""" + from trtutils.image import DepthEstimator + + de = DepthEstimator(depth_engine_batch, warmup=False) + imgs = random_images(count=2, height=480, width=640) + outputs = de.run(imgs) + assert isinstance(outputs, list) + assert len(outputs) == 2 + for per_image in outputs: + assert isinstance(per_image, list) + + def test_batch_get_depth_maps( + self, + depth_engine_batch, + random_images, + ) -> None: + """Batch get_depth_maps returns list[ndarray].""" + from trtutils.image import DepthEstimator + + de = DepthEstimator(depth_engine_batch, warmup=False) + imgs = random_images(count=2, height=480, width=640) + outputs = de.run(imgs) + depth_maps = de.get_depth_maps(outputs) + assert isinstance(depth_maps, list) + assert len(depth_maps) == 2 + for dm in depth_maps: + assert isinstance(dm, np.ndarray) + assert dm.ndim >= 2 + + +# --------------------------------------------------------------------------- +# End-to-end tests +# --------------------------------------------------------------------------- +class TestDepthEstimatorEnd2End: + """End-to-end inference tests.""" + + def test_end2end_single_image( + self, + depth_engine, + images, + ) -> None: + """end2end() on single image returns ndarray depth map.""" + horse_image = images["horse"].array + from trtutils.image import DepthEstimator + + de = DepthEstimator(depth_engine, warmup=False) + result = de.end2end(horse_image) + assert isinstance(result, np.ndarray) + assert result.ndim >= 2 + + def test_end2end_batch( + self, + depth_engine_batch, + random_images, + ) -> None: + """end2end() on batch returns list[ndarray].""" + from trtutils.image import DepthEstimator + + de = DepthEstimator(depth_engine_batch, warmup=False) + imgs = random_images(count=2, height=480, width=640) + result = de.end2end(imgs) + assert isinstance(result, list) + assert len(result) == 2 + for dm in result: + assert isinstance(dm, np.ndarray) diff --git a/tests/models/test_detector_correctness.py b/tests/models/test_detector_correctness.py new file mode 100644 index 00000000..c4c4ef06 --- /dev/null +++ b/tests/models/test_detector_correctness.py @@ -0,0 +1,432 @@ +# Copyright (c) 2026 Justin Davis (davisjustin302@gmail.com) +# +# MIT License +"""Detector output correctness tests (GPU required).""" + +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +from .conftest import ( + DETECTOR_EXPECTED, + DETECTOR_MODELS, + _resolve_model_class, + build_model_engine, +) + +if TYPE_CHECKING: + from trtutils.models import YOLOv10 + +BASE_DIR = Path(__file__).parent.parent.parent +DATA_DIR = BASE_DIR / "data" +YOLOV10_ONNX = DATA_DIR / "yolov10" / "yolov10n_640.onnx" + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- +@pytest.fixture(scope="module") +def yolov10_engine(build_test_engine) -> Path: + """Build and cache a YOLOv10n engine for the test module.""" + if not YOLOV10_ONNX.exists(): + pytest.skip("yolov10n_640.onnx not available") + return build_test_engine(YOLOV10_ONNX) + + +@pytest.fixture(scope="module") +def yolov10_engine_batch(build_test_engine) -> Path: + """Build and cache a batch-capable YOLOv10n engine.""" + if not YOLOV10_ONNX.exists(): + pytest.skip("yolov10n_640.onnx not available") + try: + return build_test_engine(YOLOV10_ONNX, batch_size=2) + except Exception as exc: + pytest.skip(f"Batch YOLOv10 engine unavailable: {exc}") + + +@pytest.fixture(scope="module") +def yolov10_detector(yolov10_engine) -> YOLOv10: + """Instantiate a YOLOv10 detector for the module.""" + from trtutils.compat._libs import cudart + from trtutils.models import YOLOv10 + + if not hasattr(cudart, "cudaStreamCreate"): + pytest.skip("No CUDA runtime available") + return YOLOv10(yolov10_engine, warmup=False) + + +@pytest.fixture(scope="module") +def yolov10_detector_batch(yolov10_engine_batch) -> YOLOv10: + """Instantiate a batch-capable YOLOv10 detector for batched tests.""" + from trtutils.compat._libs import cudart + from trtutils.models import YOLOv10 + + if not hasattr(cudart, "cudaStreamCreate"): + pytest.skip("No CUDA runtime available") + return YOLOv10(yolov10_engine_batch, warmup=False) + + +# --------------------------------------------------------------------------- +# Single-image tests +# --------------------------------------------------------------------------- +class TestDetectorSingleImage: + """Correctness tests for detector with a single image.""" + + def test_run_returns_list( + self, + yolov10_detector, + images, + ) -> None: + """run() with postprocess should return a list of ndarrays.""" + horse_image = images["horse"].array + outputs = yolov10_detector.run(horse_image) + assert isinstance(outputs, list) + for arr in outputs: + assert isinstance(arr, np.ndarray) + + def test_run_raw_returns_list_of_ndarrays( + self, + yolov10_detector, + images, + ) -> None: + """run() with postprocess=False returns raw output tensors.""" + horse_image = images["horse"].array + outputs = yolov10_detector.run( + horse_image, + postprocess=False, + ) + assert isinstance(outputs, list) + for arr in outputs: + assert isinstance(arr, np.ndarray) + + def test_get_detections_format( + self, + yolov10_detector, + images, + ) -> None: + """get_detections output should be list of (bbox, score, cls).""" + horse_image = images["horse"].array + outputs = yolov10_detector.run(horse_image) + dets = yolov10_detector.get_detections(outputs) + assert isinstance(dets, list) + for det in dets: + assert isinstance(det, tuple) + assert len(det) == 3 + bbox, score, cls_id = det + # bbox is (x1, y1, x2, y2) ints + assert isinstance(bbox, tuple) + assert len(bbox) == 4 + for coord in bbox: + assert isinstance(coord, (int, np.integer)) + # score is a float + assert isinstance(score, (float, np.floating)) + assert 0.0 <= float(score) <= 1.0 + # class id is an int + assert isinstance(cls_id, (int, np.integer)) + assert int(cls_id) >= 0 + + def test_end2end_format( + self, + yolov10_detector, + images, + ) -> None: + """end2end() should return detections in the same format.""" + horse_image = images["horse"].array + dets = yolov10_detector.end2end(horse_image) + assert isinstance(dets, list) + for det in dets: + assert isinstance(det, tuple) + assert len(det) == 3 + + def test_callable_matches_run( + self, + yolov10_detector, + images, + ) -> None: + """__call__ should produce the same result as run().""" + horse_image = images["horse"].array + out_run = yolov10_detector.run(horse_image) + out_call = yolov10_detector(horse_image) + assert len(out_run) == len(out_call) + + +# --------------------------------------------------------------------------- +# Batch tests +# --------------------------------------------------------------------------- +class TestDetectorBatch: + """Correctness tests for detector with batched images.""" + + def test_batch_run_returns_nested_lists( + self, + yolov10_detector_batch, + random_images, + ) -> None: + """Batch run with postprocess returns list[list[ndarray]].""" + imgs = random_images(count=2, height=480, width=640) + outputs = yolov10_detector_batch.run(imgs) + assert isinstance(outputs, list) + assert len(outputs) == 2 + for per_image in outputs: + assert isinstance(per_image, list) + for arr in per_image: + assert isinstance(arr, np.ndarray) + + def test_batch_get_detections( + self, + yolov10_detector_batch, + random_images, + ) -> None: + """Batch get_detections returns list[list[tuple]].""" + imgs = random_images(count=2, height=480, width=640) + outputs = yolov10_detector_batch.run(imgs) + dets = yolov10_detector_batch.get_detections(outputs) + assert isinstance(dets, list) + assert len(dets) == 2 + for per_image_dets in dets: + assert isinstance(per_image_dets, list) + + +# --------------------------------------------------------------------------- +# Confidence threshold tests +# --------------------------------------------------------------------------- +class TestDetectorConfThreshold: + """Test that confidence threshold filtering works.""" + + def test_high_threshold_fewer_detections( + self, + yolov10_detector, + images, + ) -> None: + """Higher conf threshold should yield fewer or equal detections.""" + horse_image = images["horse"].array + outputs = yolov10_detector.run(horse_image) + dets_low = yolov10_detector.get_detections( + outputs, + conf_thres=0.01, + ) + dets_high = yolov10_detector.get_detections( + outputs, + conf_thres=0.9, + ) + assert len(dets_high) <= len(dets_low) + + def test_all_scores_above_threshold( + self, + yolov10_detector, + images, + ) -> None: + """All returned detections should have score >= conf_thres.""" + horse_image = images["horse"].array + threshold = 0.5 + outputs = yolov10_detector.run(horse_image) + dets = yolov10_detector.get_detections( + outputs, + conf_thres=threshold, + ) + for _bbox, score, _cls_id in dets: + assert float(score) >= threshold - 1e-6 + + +# --------------------------------------------------------------------------- +# Bbox coordinate validity +# --------------------------------------------------------------------------- +class TestDetectorBboxValidity: + """Test that bounding box coordinates are valid.""" + + def test_bbox_x2_ge_x1( + self, + yolov10_detector, + images, + ) -> None: + """x2 should be >= x1 for all detections.""" + horse_image = images["horse"].array + outputs = yolov10_detector.run(horse_image) + dets = yolov10_detector.get_detections(outputs) + for (x1, y1, x2, y2), _score, _cls_id in dets: + assert int(x2) >= int(x1) + assert int(y2) >= int(y1) + + +# --------------------------------------------------------------------------- +# Data-driven multi-model correctness tests +# --------------------------------------------------------------------------- +_DETECTOR_MODEL_IDS = list(DETECTOR_MODELS.keys()) +_INFERENCE_MODES = ["end2end", "run"] + + +def _make_expected_ids() -> list[str]: + """Create human-readable IDs for DETECTOR_EXPECTED entries.""" + return [Path(e["image"]).stem for e in DETECTOR_EXPECTED] + + +_EXPECTED_IDS = _make_expected_ids() + + +def _run_detector_inference( + detector: object, + image: np.ndarray, + mode: str, +) -> list[tuple[tuple[int, int, int, int], float, int]]: + """Run inference with the given mode and return detections.""" + if mode == "end2end": + return detector.end2end(image) # type: ignore[union-attr] + if mode == "run": + outputs = detector.run(image) # type: ignore[union-attr] + return detector.get_detections(outputs) # type: ignore[union-attr] + err_msg = f"Unknown inference mode: {mode}" + raise ValueError(err_msg) + + +@pytest.mark.correctness +@pytest.mark.download +class TestDetectorImageCorrectness: + """Data-driven correctness: every detector must find expected objects.""" + + @pytest.mark.parametrize("expected_entry", DETECTOR_EXPECTED, ids=_EXPECTED_IDS) + @pytest.mark.parametrize("model_id", _DETECTOR_MODEL_IDS) + @pytest.mark.parametrize("mode", _INFERENCE_MODES) + def test_image_correctness( + self, + expected_entry: dict, + model_id: str, + mode: str, + ) -> None: + """Detector finds expected classes with sufficient detections.""" + import cv2 + + cls_name, model_name, imgsz = DETECTOR_MODELS[model_id] + engine_path = build_model_engine(cls_name, model_name, imgsz) + + model_class = _resolve_model_class(cls_name) + detector = model_class( + engine_path, + preprocessor="cpu", + warmup=False, + no_warn=True, + ) + + image_path = str(BASE_DIR / expected_entry["image"]) + image = cv2.imread(image_path) + if image is None: + pytest.skip(f"Image not found: {image_path}") + + dets = _run_detector_inference(detector, image, mode) + + # At least min_detections + assert len(dets) >= expected_entry["min_detections"], ( + f"{model_id}/{mode}: expected >= {expected_entry['min_detections']} " + f"detections, got {len(dets)}" + ) + + # At least one detection matches an expected class + detected_classes = [int(d[2]) for d in dets] + has_match = any(c in expected_entry["expected_classes"] for c in detected_classes) + assert has_match, ( + f"{model_id}/{mode}: none of {detected_classes} match " + f"expected {expected_entry['expected_classes']}" + ) + + # All scores above threshold + for _bbox, score, _cls_id in dets: + assert float(score) >= expected_entry["conf_thres"] - 1e-6, ( + f"{model_id}/{mode}: score {score} < {expected_entry['conf_thres']}" + ) + + del detector + + +@pytest.mark.correctness +@pytest.mark.download +class TestDetectorOutputValidityMultiModel: + """Validate bbox/score/class constraints across all detector models.""" + + @pytest.mark.parametrize("expected_entry", DETECTOR_EXPECTED, ids=_EXPECTED_IDS) + @pytest.mark.parametrize("model_id", _DETECTOR_MODEL_IDS) + def test_output_validity( + self, + expected_entry: dict, + model_id: str, + ) -> None: + """All detections have valid bbox, score, and class_id.""" + import cv2 + + cls_name, model_name, imgsz = DETECTOR_MODELS[model_id] + engine_path = build_model_engine(cls_name, model_name, imgsz) + + model_class = _resolve_model_class(cls_name) + detector = model_class( + engine_path, + preprocessor="cpu", + warmup=False, + no_warn=True, + ) + + image_path = str(BASE_DIR / expected_entry["image"]) + image = cv2.imread(image_path) + if image is None: + pytest.skip(f"Image not found: {image_path}") + + dets = _run_detector_inference(detector, image, "end2end") + + for bbox, score, class_id in dets: + x1, y1, x2, y2 = bbox + assert int(x2) > int(x1), f"{model_id}: invalid bbox width x1={x1}, x2={x2}" + assert int(y2) > int(y1), f"{model_id}: invalid bbox height y1={y1}, y2={y2}" + assert 0.0 <= float(score) <= 1.0, f"{model_id}: invalid score {score}" + assert int(class_id) >= 0, f"{model_id}: invalid class_id {class_id}" + + del detector + + +@pytest.mark.correctness +@pytest.mark.download +class TestPreprocessorConsistencyMultiModel: + """Verify cpu/cuda/trt preprocessors give consistent results.""" + + @pytest.mark.parametrize("expected_entry", DETECTOR_EXPECTED, ids=_EXPECTED_IDS) + @pytest.mark.parametrize("model_id", _DETECTOR_MODEL_IDS) + def test_preprocessor_consistency( + self, + expected_entry: dict, + model_id: str, + ) -> None: + """All preprocessors produce at least 1 detection, count variance <= 2.""" + import cv2 + + cls_name, model_name, imgsz = DETECTOR_MODELS[model_id] + engine_path = build_model_engine(cls_name, model_name, imgsz) + + model_class = _resolve_model_class(cls_name) + + image_path = str(BASE_DIR / expected_entry["image"]) + image = cv2.imread(image_path) + if image is None: + pytest.skip(f"Image not found: {image_path}") + + preprocessors = ["cpu", "cuda", "trt"] + counts: dict[str, int] = {} + + for preproc in preprocessors: + detector = model_class( + engine_path, + preprocessor=preproc, + warmup=False, + no_warn=True, + ) + dets = detector.end2end(image) + counts[preproc] = len(dets) + del detector + + min_count = min(counts.values()) + max_count = max(counts.values()) + + # Every preprocessor should produce at least 1 detection + for preproc, count in counts.items(): + assert count >= 1, f"{model_id}/{preproc}: expected >= 1 detection, got {count}" + + # Detection count variance across preprocessors should be small + assert max_count - min_count <= 2, f"{model_id}: detection count variance too high: {counts}" diff --git a/tests/models/test_models.py b/tests/models/test_models.py new file mode 100644 index 00000000..9d26dab7 --- /dev/null +++ b/tests/models/test_models.py @@ -0,0 +1,439 @@ +# Copyright (c) 2026 Justin Davis (davisjustin302@gmail.com) +# +# MIT License +"""Parametrized model instantiation tests for all model classes.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import MagicMock + +import pytest + +if TYPE_CHECKING: + from pathlib import Path + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _make_mock_engine( + input_names: list | None = None, + output_names: list | None = None, + input_spec: list | None = None, + output_spec: list | None = None, +) -> MagicMock: + """Create a mock TRTEngine with configurable tensor names/specs.""" + engine = MagicMock() + engine.input_names = input_names or ["images"] + engine.output_names = output_names or [ + "num_dets", + "det_boxes", + "det_scores", + "det_classes", + ] + engine.input_spec = input_spec or [("images", (1, 3, 640, 640))] + engine.output_spec = output_spec or [ + ("num_dets", (1, 1)), + ("det_boxes", (1, 100, 4)), + ("det_scores", (1, 100)), + ("det_classes", (1, 100)), + ] + return engine + + +# --------------------------------------------------------------------------- +# Model mixin class-attribute tests (no GPU needed) +# --------------------------------------------------------------------------- +DETECTOR_CLASSES = [ + "YOLOv3", + "YOLOv5", + "YOLOv7", + "YOLOv8", + "YOLOv9", + "YOLOv10", + "YOLOv11", + "YOLOv12", + "YOLOv13", + "YOLOv26", + "YOLOX", + "RTDETRv1", + "RTDETRv2", + "RTDETRv3", + "DFINE", + "DEIM", + "DEIMv2", + "RFDETR", +] + +CLASSIFIER_CLASSES = [ + "AlexNet", + "ConvNeXt", + "DenseNet", + "EfficientNet", + "EfficientNetV2", + "GoogLeNet", + "Inception", + "MaxViT", + "MNASNet", + "MobileNetV2", + "MobileNetV3", + "RegNet", + "ResNet", + "ResNeXt", + "ShuffleNetV2", + "SqueezeNet", + "SwinTransformer", + "SwinTransformerV2", + "VGG", + "ViT", + "WideResNet", +] + +DEPTH_CLASSES = ["DepthAnythingV2"] + +ALL_MODEL_NAMES = DETECTOR_CLASSES + CLASSIFIER_CLASSES + DEPTH_CLASSES + + +def _get_model_class(name: str) -> type: + """Import and return a model class by name from trtutils.models.""" + import trtutils.models as models_mod + + return getattr(models_mod, name) + + +class TestModelClassAttributes: + """Verify required class attributes are present on all Model subclasses.""" + + @pytest.mark.cpu + @pytest.mark.parametrize("cls_name", ALL_MODEL_NAMES) + def test_has_model_type(self, cls_name: str) -> None: + cls = _get_model_class(cls_name) + assert hasattr(cls, "_model_type") + assert isinstance(cls._model_type, str) + assert len(cls._model_type) > 0 + + @pytest.mark.cpu + @pytest.mark.parametrize("cls_name", ALL_MODEL_NAMES) + def test_has_friendly_name(self, cls_name: str) -> None: + cls = _get_model_class(cls_name) + assert hasattr(cls, "_friendly_name") + assert isinstance(cls._friendly_name, str) + + @pytest.mark.cpu + @pytest.mark.parametrize("cls_name", ALL_MODEL_NAMES) + def test_has_default_imgsz(self, cls_name: str) -> None: + cls = _get_model_class(cls_name) + assert hasattr(cls, "_default_imgsz") + assert isinstance(cls._default_imgsz, int) + assert cls._default_imgsz > 0 + + @pytest.mark.cpu + @pytest.mark.parametrize("cls_name", ALL_MODEL_NAMES) + def test_has_input_tensors(self, cls_name: str) -> None: + cls = _get_model_class(cls_name) + assert hasattr(cls, "_input_tensors") + assert isinstance(cls._input_tensors, list) + assert len(cls._input_tensors) > 0 + for name, kind in cls._input_tensors: + assert isinstance(name, str) + assert kind in ("image", "size") + + +class TestModelMakeShapes: + """Test the _make_shapes classmethod.""" + + @pytest.mark.cpu + @pytest.mark.parametrize("cls_name", ALL_MODEL_NAMES) + def test_make_shapes_returns_list(self, cls_name: str) -> None: + cls = _get_model_class(cls_name) + shapes = cls._make_shapes(1, cls._default_imgsz) + assert isinstance(shapes, list) + assert len(shapes) == len(cls._input_tensors) + + @pytest.mark.cpu + @pytest.mark.parametrize("cls_name", ALL_MODEL_NAMES) + def test_make_shapes_batch_size(self, cls_name: str) -> None: + cls = _get_model_class(cls_name) + batch = 4 + shapes = cls._make_shapes(batch, cls._default_imgsz) + for _name, shape in shapes: + assert shape[0] == batch + + @pytest.mark.cpu + @pytest.mark.parametrize("cls_name", ALL_MODEL_NAMES) + def test_make_shapes_image_tensor_dims(self, cls_name: str) -> None: + """Image tensors should have shape (B, 3, imgsz, imgsz).""" + cls = _get_model_class(cls_name) + imgsz = cls._default_imgsz + shapes = cls._make_shapes(1, imgsz) + for i, (_name, kind) in enumerate(cls._input_tensors): + _, shape = shapes[i] + if kind == "image": + assert len(shape) == 4 + assert shape[1] == 3 + assert shape[2] == imgsz + assert shape[3] == imgsz + elif kind == "size": + assert len(shape) == 2 + assert shape[1] == 2 + + +class TestModelValidateImgsz: + """Test image size validation.""" + + @pytest.mark.cpu + def test_valid_imgszs_reject_invalid(self) -> None: + """RTDETRv1 only allows 640; other sizes should raise.""" + cls = _get_model_class("RTDETRv1") + with pytest.raises(ValueError, match="supports only imgsz"): + cls._validate_imgsz(320) + + @pytest.mark.cpu + def test_valid_imgszs_accept_valid(self) -> None: + """RTDETRv1 should accept 640.""" + cls = _get_model_class("RTDETRv1") + cls._validate_imgsz(640) # Should not raise + + @pytest.mark.cpu + def test_divisor_reject_invalid(self) -> None: + """RFDETR requires imgsz divisible by 32.""" + cls = _get_model_class("RFDETR") + with pytest.raises(ValueError, match="divisible by"): + cls._validate_imgsz(577) + + @pytest.mark.cpu + def test_divisor_accept_valid(self) -> None: + """RFDETR should accept 576 (divisible by 32).""" + cls = _get_model_class("RFDETR") + cls._validate_imgsz(576) # Should not raise + + @pytest.mark.cpu + def test_no_restrictions_accept_any(self) -> None: + """YOLOv10 has no imgsz restrictions.""" + cls = _get_model_class("YOLOv10") + cls._validate_imgsz(123) # Should not raise + + +class TestModelDownloadValidation: + """Test download() class method validation (no actual downloads).""" + + @pytest.mark.cpu + def test_invalid_model_name_raises(self, tmp_path: Path) -> None: + """download() should raise for an invalid model name.""" + cls = _get_model_class("YOLOv10") + out_path = tmp_path / "out.onnx" + with pytest.raises(ValueError, match="Model fake_model_xyz not supported"): + cls.download("fake_model_xyz", out_path) + + @pytest.mark.cpu + def test_deimv2_wrong_imgsz_raises(self, tmp_path: Path) -> None: + """DEIMv2 with model-specific imgsz mismatch should raise.""" + cls = _get_model_class("DEIMv2") + out_path = tmp_path / "out.onnx" + with pytest.raises(ValueError, match="deimv2_atto requires imgsz of 320"): + cls.download( + "deimv2_atto", + out_path, + imgsz=640, + ) + + +class TestModelBuildValidation: + """Test build() class method validation (no actual builds).""" + + @pytest.mark.cpu + def test_unknown_kwargs_raises(self, tmp_path: Path) -> None: + """build() should raise TypeError for unknown keyword args.""" + cls = _get_model_class("YOLOv10") + fake_model_path = tmp_path / "fake.onnx" + engine_path = tmp_path / "out.engine" + with pytest.raises(TypeError, match="unexpected keyword arguments"): + cls.build( + fake_model_path, + engine_path, + totally_fake_kwarg=True, + ) + + +# --------------------------------------------------------------------------- +# Base arch classes (YOLO, DETR) -- no Model mixin +# --------------------------------------------------------------------------- +class TestBaseArchClasses: + """Test that YOLO and DETR base classes exist and are importable.""" + + @pytest.mark.cpu + def test_yolo_importable(self) -> None: + from trtutils.models import YOLO + + assert YOLO is not None + + @pytest.mark.cpu + def test_detr_importable(self) -> None: + from trtutils.models import DETR + + assert DETR is not None + + @pytest.mark.cpu + def test_yolo_has_default_imgsz(self) -> None: + from trtutils.models import YOLO + + assert hasattr(YOLO, "_default_imgsz") + assert YOLO._default_imgsz == 640 + + +# --------------------------------------------------------------------------- +# _make_shapes edge cases +# --------------------------------------------------------------------------- +class TestMakeShapesEdgeCases: + """Test _make_shapes error handling.""" + + @pytest.mark.cpu + def test_unknown_tensor_kind_raises(self) -> None: + """An unknown tensor kind should raise ValueError.""" + from trtutils.models._model import Model + + class FakeModel(Model): + _model_type = "fake" + _friendly_name = "Fake" + _default_imgsz = 640 + _input_tensors = (("input", "unknown_kind"),) + + with pytest.raises(ValueError, match="Unknown input tensor kind"): + FakeModel._make_shapes(1, 640) + + +# --------------------------------------------------------------------------- +# Model.download() -- _model_imgszs variant-specific size logic +# --------------------------------------------------------------------------- +class TestModelImgszVariants: + """Test the _model_imgszs logic in Model.download().""" + + @pytest.mark.cpu + def test_deimv2_atto_default_imgsz_is_320(self, tmp_path: Path) -> None: + """DEIMv2.download('deimv2_atto', ...) with imgsz=None should use 320.""" + from unittest.mock import patch as _patch + + cls = _get_model_class("DEIMv2") + with _patch( + "trtutils.models._model.download_model_internal", + ) as mock_download: + out_path = tmp_path / "out.onnx" + cls.download("deimv2_atto", out_path) + call_kwargs = mock_download.call_args[1] + assert call_kwargs["imgsz"] == 320 + + @pytest.mark.cpu + def test_deimv2_femto_default_imgsz_is_416(self, tmp_path: Path) -> None: + """DEIMv2.download('deimv2_femto', ...) with imgsz=None should use 416.""" + from unittest.mock import patch as _patch + + cls = _get_model_class("DEIMv2") + with _patch( + "trtutils.models._model.download_model_internal", + ) as mock_download: + out_path = tmp_path / "out.onnx" + cls.download("deimv2_femto", out_path) + call_kwargs = mock_download.call_args[1] + assert call_kwargs["imgsz"] == 416 + + @pytest.mark.cpu + def test_deimv2_non_variant_default_imgsz_is_640(self, tmp_path: Path) -> None: + """DEIMv2.download('deimv2_small', ...) with imgsz=None should use 640.""" + from unittest.mock import patch as _patch + + cls = _get_model_class("DEIMv2") + with _patch( + "trtutils.models._model.download_model_internal", + ) as mock_download: + out_path = tmp_path / "out.onnx" + cls.download("deimv2_small", out_path) + call_kwargs = mock_download.call_args[1] + assert call_kwargs["imgsz"] == 640 + + @pytest.mark.cpu + def test_deimv2_atto_explicit_correct_imgsz_accepted(self, tmp_path: Path) -> None: + """DEIMv2.download('deimv2_atto', ..., imgsz=320) should work.""" + from unittest.mock import patch as _patch + + cls = _get_model_class("DEIMv2") + with _patch( + "trtutils.models._model.download_model_internal", + ) as mock_download: + out_path = tmp_path / "out.onnx" + cls.download("deimv2_atto", out_path, imgsz=320) + call_kwargs = mock_download.call_args[1] + assert call_kwargs["imgsz"] == 320 + + @pytest.mark.cpu + def test_deimv2_atto_wrong_imgsz_raises(self, tmp_path: Path) -> None: + """DEIMv2.download('deimv2_atto', ..., imgsz=640) should raise.""" + cls = _get_model_class("DEIMv2") + out_path = tmp_path / "out.onnx" + with pytest.raises(ValueError, match="requires imgsz of 320"): + cls.download("deimv2_atto", out_path, imgsz=640) + + +# --------------------------------------------------------------------------- +# Model.download() -- general imgsz=None path +# --------------------------------------------------------------------------- +class TestModelDownloadImgszDefault: + """Test that download() uses _default_imgsz when imgsz=None.""" + + @pytest.mark.cpu + def test_yolov10_default_imgsz_640(self, tmp_path: Path) -> None: + """YOLOv10.download with imgsz=None should use 640.""" + from unittest.mock import patch as _patch + + cls = _get_model_class("YOLOv10") + with _patch( + "trtutils.models._model.download_model_internal", + ) as mock_download: + out_path = tmp_path / "out.onnx" + cls.download("yolov10n", out_path) + call_kwargs = mock_download.call_args[1] + assert call_kwargs["imgsz"] == 640 + + @pytest.mark.cpu + def test_rfdetr_default_imgsz_576(self, tmp_path: Path) -> None: + """RFDETR.download with imgsz=None should use 576.""" + from unittest.mock import patch as _patch + + cls = _get_model_class("RFDETR") + with _patch( + "trtutils.models._model.download_model_internal", + ) as mock_download: + out_path = tmp_path / "out.onnx" + cls.download("rfdetr_n", out_path) + call_kwargs = mock_download.call_args[1] + assert call_kwargs["imgsz"] == 576 + + +# --------------------------------------------------------------------------- +# nms_build_hook +# --------------------------------------------------------------------------- +class TestNmsBuildHook: + """Test the nms_build_hook function.""" + + @pytest.mark.cpu + def test_returns_hooks_list(self) -> None: + """nms_build_hook should return a dict with 'hooks' key.""" + from trtutils.models._model import nms_build_hook + + result = nms_build_hook() + assert "hooks" in result + assert isinstance(result["hooks"], list) + assert len(result["hooks"]) == 1 + + @pytest.mark.cpu + def test_custom_params_accepted(self) -> None: + """nms_build_hook should accept custom NMS parameters.""" + from trtutils.models._model import nms_build_hook + + result = nms_build_hook( + num_classes=10, + conf_threshold=0.5, + iou_threshold=0.7, + top_k=50, + ) + assert "hooks" in result