Skip to content
Open
15 changes: 15 additions & 0 deletions src/gasbench/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ async def run_benchmark(
holdouts_only: bool = False,
content_category: Optional[str] = None,
score_composition: Optional[Dict[str, float]] = None,
n_aug_per_dataset: int = 0,
aug_weight: float = 0.2,
aug_cache_dir: Optional[str] = None,
) -> Dict:
"""
Args:
Expand Down Expand Up @@ -122,6 +125,9 @@ async def run_benchmark(
holdouts_only,
content_category,
score_composition,
n_aug_per_dataset=n_aug_per_dataset,
aug_weight=aug_weight,
aug_cache_dir=aug_cache_dir,
)

benchmark_results["benchmark_score"] = benchmark_score
Expand Down Expand Up @@ -218,6 +224,9 @@ async def execute_benchmark(
holdouts_only: bool = False,
content_category: Optional[str] = None,
score_composition: Optional[Dict[str, float]] = None,
n_aug_per_dataset: int = 0,
aug_weight: float = 0.2,
aug_cache_dir: Optional[str] = None,
) -> float:
"""Execute the actual benchmark evaluation."""

Expand Down Expand Up @@ -245,6 +254,9 @@ async def execute_benchmark(
holdouts_only=holdouts_only,
content_category=content_category,
score_composition=score_composition,
n_aug_per_dataset=n_aug_per_dataset,
aug_weight=aug_weight,
aug_cache_dir=aug_cache_dir,
)
benchmark_score = benchmark_results.get("image_results", {}).get("benchmark_score", 0.0)
elif modality == "video":
Expand All @@ -268,6 +280,9 @@ async def execute_benchmark(
holdouts_only=holdouts_only,
content_category=content_category,
score_composition=score_composition,
n_aug_per_dataset=n_aug_per_dataset,
aug_weight=aug_weight,
aug_cache_dir=aug_cache_dir,
)
benchmark_score = benchmark_results.get("video_results", {}).get("benchmark_score", 0.0)
elif modality == "audio":
Expand Down
2 changes: 2 additions & 0 deletions src/gasbench/benchmarks/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class BenchmarkRunConfig:
holdouts_only: bool = False # If True, only run holdout datasets (requires holdout_config_path)
content_category: Optional[str] = None # Filter datasets by content_category (e.g. "faces")
score_composition: Optional[Dict[str, float]] = None # Target score weight share per provenance class, e.g. {"public": 0.5, "holdout": 0.3, "gasstation": 0.2}; weights all metrics incl. sn34_score
n_aug_per_dataset: int = 0 # Number of samples per dataset to re-evaluate with robustness augmentations (0 = disabled)
aug_weight: float = 0.2 # Weight of aug_sn34_score in blended final score (when n_aug_per_dataset > 0)


@dataclass
Expand Down
139 changes: 131 additions & 8 deletions src/gasbench/benchmarks/image_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..processing.media import process_image_sample
from ..processing.transforms import (
apply_random_augmentations,
apply_robustness_augmentations,
)
from ..config import (
DEFAULT_IMAGE_BATCH_SIZE,
Expand All @@ -19,8 +20,28 @@
from ..dataset.iterator import DatasetIterator

from .utils.inference import process_model_output
from .recording import BenchmarkRunRecorder, log_dataset_summary
from .recording import BenchmarkRunRecorder, log_dataset_summary, build_sample_id
from .common import BenchmarkRunConfig, build_plan, create_tracker, finalize_run

_IMG_AUG_VERSION = "img_v1"


def _aug_cache_path(cache_dir: str, sample_id: str) -> str:
    """Return the on-disk location of the cached augmented array for a sample.

    The file lives under ``cache_dir`` in a sub-directory named after the
    first two characters of ``sample_id``; the file name is the sample id
    followed by the augmentation version tag and a ``.npy`` suffix.
    """
    shard_dir = sample_id[:2]
    file_name = "{}_{}.npy".format(sample_id, _IMG_AUG_VERSION)
    return os.path.join(cache_dir, shard_dir, file_name)


def _write_aug_cache(path: str, array) -> None:
    """Best-effort atomic write of ``array`` to ``path`` in ``.npy`` format.

    The array is first saved to a process-specific temporary file next to
    ``path`` and then moved into place with ``os.replace``, so a reader never
    sees a partially written cache entry.

    Args:
        path: Final destination of the cache file.
        array: Array to serialise with ``numpy.save``.

    Any failure while saving or renaming is swallowed (the temporary file is
    removed if it exists); a failure to create the parent directory still
    propagates.
    """
    parent = os.path.dirname(path)
    if parent:
        # os.makedirs("") raises, so only create a directory when there is one.
        os.makedirs(parent, exist_ok=True)
    tmp = f"{path}.tmp.{os.getpid()}"
    try:
        import numpy as _np

        # Save through an open file object instead of the temp file *name*:
        # given a name that does not end in ".npy", numpy.save appends ".npy",
        # so the data would land in "<tmp>.npy" and os.replace(tmp, path)
        # would never find it (leaving the cache empty and a stray file).
        with open(tmp, "wb") as fh:
            _np.save(fh, array)
        os.replace(tmp, path)
    except Exception:
        # Deliberately best-effort: a failed cache write must not abort the
        # caller. Only clean up the temporary file.
        try:
            os.unlink(tmp)
        except OSError:
            pass
import pandas as pd

logger = get_logger(__name__)
Expand All @@ -46,6 +67,8 @@ def __init__(
crop_prob,
num_workers=8,
max_queue_size=8,
robustness_pass=False,
aug_cache_dir=None,
):
self.dataset_iterator = dataset_iterator
self.target_size = target_size
Expand All @@ -55,6 +78,8 @@ def __init__(
self.crop_prob = crop_prob
self.num_workers = num_workers
self.max_queue_size = max_queue_size
self.robustness_pass = robustness_pass
self.aug_cache_dir = aug_cache_dir

self.batch_queue = Queue(maxsize=max_queue_size)
self.stop_event = threading.Event()
Expand All @@ -78,13 +103,31 @@ def _read_and_preprocess(self, sample, sample_index, dataset_name):
return None

sample_seed = None if self.seed is None else (self.seed + sample_index)
aug_hwc, _, _, _ = apply_random_augmentations(
image_array,
self.target_size,
seed=sample_seed,
level=self.augment_level,
crop_prob=self.crop_prob,
)
if self.robustness_pass:
if self.aug_cache_dir:
sid = build_sample_id(sample)
cache_path = _aug_cache_path(self.aug_cache_dir, sid)
if os.path.exists(cache_path):
aug_hwc = np.load(cache_path)
else:
aug_hwc, _, _, _ = apply_robustness_augmentations(
image_array, self.target_size, seed=sample_seed
)
_write_aug_cache(cache_path, aug_hwc)
else:
aug_hwc, _, _, _ = apply_robustness_augmentations(
image_array,
self.target_size,
seed=sample_seed,
)
else:
aug_hwc, _, _, _ = apply_random_augmentations(
image_array,
self.target_size,
seed=sample_seed,
level=self.augment_level,
crop_prob=self.crop_prob,
)
aug_chw = np.transpose(aug_hwc, (2, 0, 1))

sample_meta = {k: v for k, v in sample.items() if k not in _HEAVY_SAMPLE_KEYS}
Expand Down Expand Up @@ -187,6 +230,7 @@ def process_batch(
batch_metadata,
tracker: BenchmarkRunRecorder,
batch_id: int,
aug_pass: bool = False,
):
"""push a batch of images through the model and record rows in tracker."""
if not batch_images:
Expand Down Expand Up @@ -219,6 +263,7 @@ def process_batch(
batch_id=batch_id,
batch_size=len(batch_images),
sample_seed=sample_seed,
aug_pass=aug_pass,
)


Expand All @@ -244,6 +289,9 @@ async def run_image_benchmark(
holdouts_only: bool = False,
content_category: Optional[str] = None,
score_composition: dict = None,
n_aug_per_dataset: int = 0,
aug_weight: float = 0.2,
aug_cache_dir: Optional[str] = None,
) -> pd.DataFrame:
"""Test model on benchmark image datasets for AI-generated content detection."""

Expand Down Expand Up @@ -276,6 +324,8 @@ async def run_image_benchmark(
holdouts_only=holdouts_only,
content_category=content_category,
score_composition=score_composition,
n_aug_per_dataset=n_aug_per_dataset,
aug_weight=aug_weight,
)

plan = build_plan(logger, run_config, input_specs)
Expand Down Expand Up @@ -368,6 +418,79 @@ async def run_image_benchmark(
f"Dataset error for {dataset_config.name}: {str(e)[:100]}"
)

if n_aug_per_dataset > 0:
aug_seed = seed if seed is not None else 42
logger.info(
f"Starting augmentation robustness pass: {n_aug_per_dataset} samples/dataset"
+ (f" (aug_cache_dir={aug_cache_dir})" if aug_cache_dir else "")
)
for dataset_idx, dataset_config in enumerate(plan.available_datasets):
logger.info(
f"Robustness pass {dataset_idx + 1}/{len(plan.available_datasets)}: "
f"{dataset_config.name}"
)
try:
is_gasstation = "gasstation" in dataset_config.name.lower()
should_download = (
download_latest_gasstation_data if is_gasstation else True
) if not skip_missing else False

aug_iterator = DatasetIterator(
dataset_config,
max_samples=n_aug_per_dataset,
cache_dir=cache_dir,
download=should_download,
hf_token=hf_token,
seed=seed,
lazy_read=True,
)

if skip_missing and aug_iterator.get_total_cached_count() == 0:
continue

aug_pipeline = PrefetchPipeline(
dataset_iterator=aug_iterator,
target_size=plan.target_size,
batch_size=batch_size,
seed=aug_seed,
augment_level=augment_level,
crop_prob=crop_prob,
robustness_pass=True,
aug_cache_dir=aug_cache_dir,
)

aug_batch_id = 0
try:
for batch_data in aug_pipeline:
aug_batch_id += 1
batch_images = [item["image"] for item in batch_data]
batch_metadata = [
(
item["label"],
item["sample"],
item["sample_index"],
item["dataset_name"],
item["sample_seed"],
)
for item in batch_data
]
process_batch(
session,
input_specs,
batch_images,
batch_metadata,
tracker,
aug_batch_id,
aug_pass=True,
)
finally:
aug_pipeline.close()

except Exception as e:
logger.error(
f"Robustness pass failed for {dataset_config.name}: {e}"
)

df = finalize_run(
config=run_config,
plan=plan,
Expand Down
Loading