Skip to content

Commit d87f2f5

Browse files
drop cuda useless
1 parent b57d4d7 commit d87f2f5

3 files changed

Lines changed: 9 additions & 23 deletions

File tree

.github/workflows/train.yml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ jobs:
4949
aws-region: eu-west-3
5050

5151
- name: Install dependencies
52-
run: uv sync
52+
run: |
53+
uv sync
54+
uv pip uninstall onnxruntime onnxruntime-gpu 2>/dev/null || true
55+
uv pip install onnxruntime-gpu==1.22.1
5356
5457
- name: Run pipeline
5558
run: uv run dvc repro
@@ -92,11 +95,6 @@ jobs:
9295
data/processed/sequential_test/test --rev v2.0.0 \
9396
--out ./data/test/sequential_test
9497
95-
- name: Install onnxruntime-gpu for CUDA inference
96-
run: |
97-
uv pip uninstall onnxruntime onnxruntime-gpu 2>/dev/null || true
98-
uv pip install onnxruntime-gpu==1.22.1
99-
10098
- name: Run sequential evaluation
10199
run: |
102100
uv run python ./scripts/model/yolo/evaluate_sequential.py \

scripts/model/yolo/configs/best.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ epochs: 2
55
patience: 20
66
batch: 16
77
imgsz: 64
8-
lr0: 0.0002
8+
lr0: 0.0001
99
lrf: 0.1
1010
optimizer: AdamW
1111
mixup: 0.2

scripts/model/yolo/evaluate_sequential.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
):
4949
sys.modules.setdefault(_name, _mod)
5050

51-
import onnxruntime # noqa: E402
5251
from PIL import Image # noqa: E402
5352
from pyroengine.engine import Engine # noqa: E402
5453
from tqdm import tqdm # noqa: E402
@@ -78,10 +77,9 @@ def reset_state(engine: Engine, cam_key: str) -> None:
7877
engine.occlusion_masks[cam_key] = (None, {}, 0) # noqa: SLF001
7978

8079

81-
def evaluate_sequence(engine: Engine, images_dir: Path, cam_key: str, max_frames: int) -> bool:
80+
def evaluate_sequence(engine: Engine, frames: list[Path], cam_key: str, max_frames: int) -> bool:
8281
"""Return True if engine raised an alert on any frame in the sequence."""
8382
reset_state(engine, cam_key)
84-
frames = sorted(images_dir.glob("*.jpg")) + sorted(images_dir.glob("*.png"))
8583
for frame_path in frames[:max_frames]:
8684
conf = engine.predict(Image.open(frame_path), cam_id=cam_key)
8785
if conf > engine.conf_thresh:
@@ -98,8 +96,8 @@ def evaluate_category(engine: Engine, category_dir: Path, label: str, max_frames
9896
if not images_dir.exists():
9997
logger.warning(f"No images/ folder in {seq_dir}, skipping")
10098
continue
101-
frames = list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png"))
102-
alerted = evaluate_sequence(engine, images_dir, seq_dir.name, max_frames)
99+
frames = sorted(images_dir.glob("*.jpg")) + sorted(images_dir.glob("*.png"))
100+
alerted = evaluate_sequence(engine, frames, seq_dir.name, max_frames)
103101
records.append({"sequence": seq_dir.name, "alerted": alerted, "n_frames": min(len(frames), max_frames)})
104102
pbar.set_postfix({"last": seq_dir.name[:30], "alert": alerted})
105103
return records
@@ -182,17 +180,7 @@ def validate_parsed_args(args: dict) -> bool:
182180
nb_consecutive_frames=args["nb_consecutive_frames"],
183181
cache_folder=str(args["output_dir"]),
184182
)
185-
186-
# Upgrade to CUDA execution provider if available (requires onnxruntime-gpu).
187-
available_providers = onnxruntime.get_available_providers()
188-
if "CUDAExecutionProvider" in available_providers:
189-
engine.model.ort_session = onnxruntime.InferenceSession( # noqa: SLF001
190-
str(args["model_path"]),
191-
providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
192-
)
193-
logger.info("onnxruntime: using CUDAExecutionProvider")
194-
else:
195-
logger.info(f"onnxruntime: CUDA not available, using {available_providers[0]}")
183+
# Provider selection (CUDA → CoreML → CPU) is now handled inside pyroengine's Classifier.
196184

197185
wf_records = evaluate_category(engine, args["data_dir"] / "wildfire", "wildfire", args["max_frames"])
198186
fp_records = evaluate_category(engine, args["data_dir"] / "fp", "fp ", args["max_frames"])

0 commit comments

Comments
 (0)