diff --git a/src/gasbench/benchmark.py b/src/gasbench/benchmark.py index 8f27d32..c250670 100644 --- a/src/gasbench/benchmark.py +++ b/src/gasbench/benchmark.py @@ -37,6 +37,9 @@ async def run_benchmark( holdouts_only: bool = False, content_category: Optional[str] = None, score_composition: Optional[Dict[str, float]] = None, + n_aug_per_dataset: int = 0, + aug_weight: float = 0.2, + aug_cache_dir: Optional[str] = None, ) -> Dict: """ Args: @@ -122,6 +125,9 @@ async def run_benchmark( holdouts_only, content_category, score_composition, + n_aug_per_dataset=n_aug_per_dataset, + aug_weight=aug_weight, + aug_cache_dir=aug_cache_dir, ) benchmark_results["benchmark_score"] = benchmark_score @@ -218,6 +224,9 @@ async def execute_benchmark( holdouts_only: bool = False, content_category: Optional[str] = None, score_composition: Optional[Dict[str, float]] = None, + n_aug_per_dataset: int = 0, + aug_weight: float = 0.2, + aug_cache_dir: Optional[str] = None, ) -> float: """Execute the actual benchmark evaluation.""" @@ -245,6 +254,9 @@ async def execute_benchmark( holdouts_only=holdouts_only, content_category=content_category, score_composition=score_composition, + n_aug_per_dataset=n_aug_per_dataset, + aug_weight=aug_weight, + aug_cache_dir=aug_cache_dir, ) benchmark_score = benchmark_results.get("image_results", {}).get("benchmark_score", 0.0) elif modality == "video": @@ -268,6 +280,9 @@ async def execute_benchmark( holdouts_only=holdouts_only, content_category=content_category, score_composition=score_composition, + n_aug_per_dataset=n_aug_per_dataset, + aug_weight=aug_weight, + aug_cache_dir=aug_cache_dir, ) benchmark_score = benchmark_results.get("video_results", {}).get("benchmark_score", 0.0) elif modality == "audio": diff --git a/src/gasbench/benchmarks/common.py b/src/gasbench/benchmarks/common.py index 5b46c09..d3ad65d 100644 --- a/src/gasbench/benchmarks/common.py +++ b/src/gasbench/benchmarks/common.py @@ -41,6 +41,8 @@ class BenchmarkRunConfig: holdouts_only: bool = False # If True, only run holdout datasets (requires holdout_config_path) content_category: Optional[str] = None # Filter datasets by content_category (e.g. "faces") score_composition: Optional[Dict[str, float]] = None # Target score weight share per provenance class, e.g. {"public": 0.5, "holdout": 0.3, "gasstation": 0.2}; weights all metrics incl. sn34_score + n_aug_per_dataset: int = 0 # Number of samples per dataset to re-evaluate with robustness augmentations (0 = disabled) + aug_weight: float = 0.2 # Weight of aug_sn34_score in blended final score (when n_aug_per_dataset > 0) @dataclass diff --git a/src/gasbench/benchmarks/image_bench.py b/src/gasbench/benchmarks/image_bench.py index bb32e89..11b2f2b 100644 --- a/src/gasbench/benchmarks/image_bench.py +++ b/src/gasbench/benchmarks/image_bench.py @@ -11,6 +11,7 @@ from ..processing.media import process_image_sample from ..processing.transforms import ( apply_random_augmentations, + apply_robustness_augmentations, ) from ..config import ( DEFAULT_IMAGE_BATCH_SIZE, @@ -19,8 +20,28 @@ from ..dataset.iterator import DatasetIterator from .utils.inference import process_model_output -from .recording import BenchmarkRunRecorder, log_dataset_summary +from .recording import BenchmarkRunRecorder, log_dataset_summary, build_sample_id from .common import BenchmarkRunConfig, build_plan, create_tracker, finalize_run + +_IMG_AUG_VERSION = "img_v1" + + +def _aug_cache_path(cache_dir: str, sample_id: str) -> str: + return os.path.join(cache_dir, sample_id[:2], f"{sample_id}_{_IMG_AUG_VERSION}.npy") + + +def _write_aug_cache(path: str, array) -> None: + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp = f"{path}.tmp.{os.getpid()}" + try: + import numpy as _np + _np.save(tmp, array) + os.replace(tmp, path) + except Exception: + try: + os.unlink(tmp) + except OSError: + pass import pandas as pd logger = get_logger(__name__) @@ -46,6 +67,8 @@ def __init__( crop_prob, num_workers=8, max_queue_size=8, + robustness_pass=False, + aug_cache_dir=None, ): self.dataset_iterator = dataset_iterator self.target_size = target_size @@ -55,6 +78,8 @@ def __init__( self.crop_prob = crop_prob self.num_workers = num_workers self.max_queue_size = max_queue_size + self.robustness_pass = robustness_pass + self.aug_cache_dir = aug_cache_dir self.batch_queue = Queue(maxsize=max_queue_size) self.stop_event = threading.Event() @@ -78,13 +103,31 @@ def _read_and_preprocess(self, sample, sample_index, dataset_name): return None sample_seed = None if self.seed is None else (self.seed + sample_index) - aug_hwc, _, _, _ = apply_random_augmentations( - image_array, - self.target_size, - seed=sample_seed, - level=self.augment_level, - crop_prob=self.crop_prob, - ) + if self.robustness_pass: + if self.aug_cache_dir: + sid = build_sample_id(sample) + cache_path = _aug_cache_path(self.aug_cache_dir, sid) + if os.path.exists(cache_path): + aug_hwc = np.load(cache_path) + else: + aug_hwc, _, _, _ = apply_robustness_augmentations( + image_array, self.target_size, seed=sample_seed + ) + _write_aug_cache(cache_path, aug_hwc) + else: + aug_hwc, _, _, _ = apply_robustness_augmentations( + image_array, + self.target_size, + seed=sample_seed, + ) + else: + aug_hwc, _, _, _ = apply_random_augmentations( + image_array, + self.target_size, + seed=sample_seed, + level=self.augment_level, + crop_prob=self.crop_prob, + ) aug_chw = np.transpose(aug_hwc, (2, 0, 1)) sample_meta = {k: v for k, v in sample.items() if k not in _HEAVY_SAMPLE_KEYS} @@ -187,6 +230,7 @@ def process_batch( batch_metadata, tracker: BenchmarkRunRecorder, batch_id: int, + aug_pass: bool = False, ): """push a batch of images through the model and record rows in tracker.""" if not batch_images: @@ -219,6 +263,7 @@ def process_batch( batch_id=batch_id, batch_size=len(batch_images), sample_seed=sample_seed, + aug_pass=aug_pass, ) @@ -244,6 +289,9 @@ async def run_image_benchmark( holdouts_only: bool = False, content_category: Optional[str] = None, score_composition: dict = None, + n_aug_per_dataset: int = 0, + aug_weight: float = 0.2, + aug_cache_dir: Optional[str] = None, ) -> pd.DataFrame: """Test model on benchmark image datasets for AI-generated content detection.""" @@ -276,6 +324,8 @@ async def run_image_benchmark( holdouts_only=holdouts_only, content_category=content_category, score_composition=score_composition, + n_aug_per_dataset=n_aug_per_dataset, + aug_weight=aug_weight, ) plan = build_plan(logger, run_config, input_specs) @@ -368,6 +418,79 @@ async def run_image_benchmark( f"Dataset error for {dataset_config.name}: {str(e)[:100]}" ) + if n_aug_per_dataset > 0: + aug_seed = seed if seed is not None else 42 + logger.info( + f"Starting augmentation robustness pass: {n_aug_per_dataset} samples/dataset" + + (f" (aug_cache_dir={aug_cache_dir})" if aug_cache_dir else "") + ) + for dataset_idx, dataset_config in enumerate(plan.available_datasets): + logger.info( + f"Robustness pass {dataset_idx + 1}/{len(plan.available_datasets)}: " + f"{dataset_config.name}" + ) + try: + is_gasstation = "gasstation" in dataset_config.name.lower() + should_download = ( + download_latest_gasstation_data if is_gasstation else True + ) if not skip_missing else False + + aug_iterator = DatasetIterator( + dataset_config, + max_samples=n_aug_per_dataset, + cache_dir=cache_dir, + download=should_download, + hf_token=hf_token, + seed=seed, + lazy_read=True, + ) + + if skip_missing and aug_iterator.get_total_cached_count() == 0: + continue + + aug_pipeline = PrefetchPipeline( + dataset_iterator=aug_iterator, + target_size=plan.target_size, + batch_size=batch_size, + seed=aug_seed, + augment_level=augment_level, + crop_prob=crop_prob, + robustness_pass=True, + aug_cache_dir=aug_cache_dir, + ) + + aug_batch_id = 0 + try: + for batch_data in aug_pipeline: + aug_batch_id += 1 + batch_images = [item["image"] for item in batch_data] + batch_metadata = [ + ( + item["label"], + item["sample"], + item["sample_index"], + item["dataset_name"], + item["sample_seed"], + ) + for item in batch_data + ] + process_batch( + session, + input_specs, + batch_images, + batch_metadata, + tracker, + aug_batch_id, + aug_pass=True, + ) + finally: + aug_pipeline.close() + + except Exception as e: + logger.error( + f"Robustness pass failed for {dataset_config.name}: {e}" + ) + df = finalize_run( config=run_config, plan=plan, diff --git a/src/gasbench/benchmarks/recording.py b/src/gasbench/benchmarks/recording.py index 6167193..8edb605 100644 --- a/src/gasbench/benchmarks/recording.py +++ b/src/gasbench/benchmarks/recording.py @@ -57,6 +57,7 @@ def add_ok( batch_id: int, batch_size: int, sample_seed: Optional[int], + aug_pass: bool = False, ): # Normalize probabilities to a list of floats for parquet friendliness try: @@ -77,6 +78,7 @@ def add_ok( "target_width": self.target_width, "augment_level": self.augment_level, "crop_prob": self.crop_prob, + "aug_pass": bool(aug_pass), "dataset_name": dataset_name, "iteration_index": int(sample_index), "media_type": sample.get("media_type"), @@ -119,10 +121,12 @@ def add_ok( self.rows.append(row) # Maintain incremental counters so per-dataset logging is O(1). - ds = self._dataset_counts.setdefault(dataset_name, {"ok": 0, "correct": 0, "skipped": 0}) - ds["ok"] += 1 - if bool(predicted == label): - ds["correct"] += 1 + # Aug-pass rows are not counted here — they are reported separately. + if not aug_pass: + ds = self._dataset_counts.setdefault(dataset_name, {"ok": 0, "correct": 0, "skipped": 0}) + ds["ok"] += 1 + if bool(predicted == label): + ds["correct"] += 1 def add_skip( self, @@ -440,7 +444,26 @@ def compute_metrics_from_df( "sn34_score": 0.0, } - ok_df = df[df["status"] == "ok"].copy() + all_ok_df = df[df["status"] == "ok"].copy() + if all_ok_df.empty: + return { + "benchmark_score": 0.0, + "avg_inference_time_ms": 0.0, + "p95_inference_time_ms": 0.0, + "binary_mcc": 0.0, + "binary_cross_entropy": 0.0, + "binary_brier": 0.25, # random baseline + "sn34_score": 0.0, + } + + # Split base pass from augmentation robustness pass + if "aug_pass" in all_ok_df.columns: + ok_df = all_ok_df[~all_ok_df["aug_pass"].fillna(False)].copy() + aug_ok_df = all_ok_df[all_ok_df["aug_pass"].fillna(False)].copy() + else: + ok_df = all_ok_df + aug_ok_df = pd.DataFrame() + if ok_df.empty: return { "benchmark_score": 0.0, @@ -458,6 +481,7 @@ def compute_metrics_from_df( # ALL metrics. Legacy path: holdout_weight scales holdout samples in the # accuracy (benchmark_score) only, preserving historical behavior. composition_fields: Dict[str, Any] = {} + class_weights: Dict[str, float] = {} sample_weights = np.ones(len(ok_df), dtype=float) if score_composition and "dataset_name" in ok_df.columns: provenance = ok_df["dataset_name"].map(classify_sample_provenance) @@ -518,6 +542,7 @@ def compute_metrics_from_df( weight=float(weight), ) + base_sn34 = metrics.compute_sn34_score() result.update( { "benchmark_score": accuracy, @@ -526,17 +551,95 @@ def compute_metrics_from_df( "binary_mcc": metrics.calculate_binary_mcc(), "binary_cross_entropy": metrics.calculate_binary_cross_entropy(), "binary_brier": metrics.calculate_brier(), - "sn34_score": metrics.compute_sn34_score(), + "sn34_score": base_sn34, } ) result.update(composition_fields) + + if not aug_ok_df.empty: + result.update( + _compute_aug_metrics(aug_ok_df, base_sn34, ok_df, class_weights if score_composition else {}) + ) + return result +def _compute_aug_metrics( + aug_df: "pd.DataFrame", + base_sn34: float, + base_df: "pd.DataFrame", + class_weights: Dict[str, float], +) -> Dict[str, Any]: + """Compute augmentation robustness metrics from the aug-pass rows. + + Returns a dict with aug_sn34_score, augmentation_robustness, + aug_weighted_sn34_score, and per-sample degradation stats. + All fields are only present when aug rows exist. + """ + aug_metrics = Metrics() + aug_sample_weights = np.ones(len(aug_df), dtype=float) + if class_weights and "dataset_name" in aug_df.columns: + provenance = aug_df["dataset_name"].map(classify_sample_provenance) + aug_sample_weights = provenance.map(class_weights).fillna(1.0).astype(float).values + + for (_, r), weight in zip(aug_df.iterrows(), aug_sample_weights): + try: + probs = [ + float(x) + for x in ( + r["probs"].tolist() if hasattr(r["probs"], "tolist") else list(r["probs"]) + ) + ] + except Exception: + probs = [] + aug_metrics.update(int(r["label"]), int(r["predicted"]), probs, weight=float(weight)) + + aug_sn34 = aug_metrics.compute_sn34_score() + robustness_ratio = (aug_sn34 / base_sn34) if base_sn34 > 0 else 0.0 + + out: Dict[str, Any] = { + "aug_sn34_score": aug_sn34, + "augmentation_robustness": robustness_ratio, + "aug_total_samples": int(len(aug_df)), + } + + # Per-sample degradation: join augmented rows to their base counterpart via sample_id + if "sample_id" in base_df.columns and "sample_id" in aug_df.columns: + def _prob_correct(r): + try: + probs = list(r["probs"].tolist() if hasattr(r["probs"], "tolist") else r["probs"]) + label = int(r["label"]) + return float(probs[label]) if label < len(probs) else None + except Exception: + return None + + base_pc = base_df[["sample_id", "label", "probs", "correct"]].copy() + base_pc["prob_correct"] = base_pc.apply(_prob_correct, axis=1) + + aug_pc = aug_df[["sample_id", "label", "probs", "correct"]].copy() + aug_pc["prob_correct"] = aug_pc.apply(_prob_correct, axis=1) + + paired = base_pc[["sample_id", "prob_correct", "correct"]].merge( + aug_pc[["sample_id", "prob_correct", "correct"]], + on="sample_id", + suffixes=("_base", "_aug"), + ).dropna(subset=["prob_correct_base", "prob_correct_aug"]) + + if not paired.empty: + paired["prob_degradation"] = paired["prob_correct_base"] - paired["prob_correct_aug"] + out["aug_paired_samples"] = int(len(paired)) + out["aug_mean_prob_degradation"] = float(paired["prob_degradation"].mean()) + out["aug_p95_prob_degradation"] = float(np.percentile(paired["prob_degradation"], 95)) + + return out + + def compute_per_dataset_from_df(df: pd.DataFrame) -> Dict[str, Any]: if df is None or df.empty: return {} ok_df = df[df["status"] == "ok"] + if "aug_pass" in ok_df.columns: + ok_df = ok_df[~ok_df["aug_pass"].fillna(False)] if ok_df.empty: return {} diff --git a/src/gasbench/benchmarks/video_bench.py b/src/gasbench/benchmarks/video_bench.py index 94315cb..a7c886b 100644 --- a/src/gasbench/benchmarks/video_bench.py +++ b/src/gasbench/benchmarks/video_bench.py @@ -13,6 +13,7 @@ from ..processing.media import process_video_bytes_sample, process_video_frames_sample from ..processing.transforms import ( apply_random_augmentations, + apply_video_robustness_augmentations, extract_num_frames_from_input_specs, ) from ..config import DEFAULT_VIDEO_BATCH_SIZE @@ -20,8 +21,28 @@ from ..dataset.iterator import DatasetIterator from .utils.inference import process_model_output -from .recording import BenchmarkRunRecorder, log_dataset_summary +from .recording import BenchmarkRunRecorder, log_dataset_summary, build_sample_id from .common import BenchmarkRunConfig, build_plan, create_tracker, finalize_run + +_VID_AUG_VERSION = "vid_v1" + + +def _aug_cache_path(cache_dir: str, sample_id: str) -> str: + return os.path.join(cache_dir, sample_id[:2], f"{sample_id}_{_VID_AUG_VERSION}.npy") + + +def _write_aug_cache(path: str, array) -> None: + os.makedirs(os.path.dirname(path), exist_ok=True) + tmp = f"{path}.tmp.{os.getpid()}" + try: + import numpy as _np + _np.save(tmp, array) + os.replace(tmp, path) + except Exception: + try: + os.unlink(tmp) + except OSError: + pass import pandas as pd logger = get_logger(__name__) @@ -49,6 +70,8 @@ def __init__( max_queue_size=6, num_frames=16, frame_rate=None, + robustness_pass=False, + aug_cache_dir=None, ): self.dataset_iterator = dataset_iterator self.target_size = target_size @@ -60,6 +83,8 @@ def __init__( self.max_queue_size = max_queue_size self.num_frames = num_frames self.frame_rate = frame_rate + self.robustness_pass = robustness_pass + self.aug_cache_dir = aug_cache_dir self.batch_queue = Queue(maxsize=max_queue_size) self.stop_event = threading.Event() @@ -92,13 +117,31 @@ def _read_and_preprocess(self, sample, sample_index, dataset_name): sample_seed = None if self.seed is None else (self.seed + sample_index) try: - aug_thwc, _, _, _ = apply_random_augmentations( - video_array, - self.target_size, - seed=sample_seed, - level=self.augment_level, - crop_prob=self.crop_prob, - ) + if self.robustness_pass: + if self.aug_cache_dir: + sid = build_sample_id(sample) + cache_path = _aug_cache_path(self.aug_cache_dir, sid) + if os.path.exists(cache_path): + aug_thwc = np.load(cache_path) + else: + aug_thwc, _, _, _ = apply_video_robustness_augmentations( + video_array, self.target_size, seed=sample_seed + ) + _write_aug_cache(cache_path, aug_thwc) + else: + aug_thwc, _, _, _ = apply_video_robustness_augmentations( + video_array, + self.target_size, + seed=sample_seed, + ) + else: + aug_thwc, _, _, _ = apply_random_augmentations( + video_array, + self.target_size, + seed=sample_seed, + level=self.augment_level, + crop_prob=self.crop_prob, + ) except Exception as e: logger.error(f"Video augmentation failed: {e}") return None @@ -207,6 +250,7 @@ def process_video_batch( batch_metadata, tracker: BenchmarkRunRecorder, batch_id: int, + aug_pass: bool = False, ): """push a batch of videos through the model and record rows in tracker.""" if not batch_videos: @@ -251,6 +295,7 @@ def process_video_batch( batch_id=batch_id, batch_size=len(batch_videos), sample_seed=sample_seed, + aug_pass=aug_pass, ) @@ -276,6 +321,9 @@ async def run_video_benchmark( holdouts_only: bool = False, content_category: Optional[str] = None, score_composition: dict = None, + n_aug_per_dataset: int = 0, + aug_weight: float = 0.2, + aug_cache_dir: Optional[str] = None, ) -> pd.DataFrame: """Test model on benchmark video datasets for AI-generated content detection.""" @@ -327,6 +375,8 @@ async def run_video_benchmark( holdouts_only=holdouts_only, content_category=content_category, score_composition=score_composition, + n_aug_per_dataset=n_aug_per_dataset, + aug_weight=aug_weight, ) plan = build_plan(logger, run_config, input_specs) if not plan: @@ -430,6 +480,82 @@ async def run_video_benchmark( f"Dataset error for {dataset_config.name}: {str(e)[:100]}" ) + if n_aug_per_dataset > 0: + aug_seed = seed if seed is not None else 42 + logger.info( + f"Starting video augmentation robustness pass: " + f"{n_aug_per_dataset} samples/dataset" + + (f" (aug_cache_dir={aug_cache_dir})" if aug_cache_dir else "") + ) + for dataset_idx, dataset_config in enumerate(plan.available_datasets): + logger.info( + f"Robustness pass {dataset_idx + 1}/{len(plan.available_datasets)}: " + f"{dataset_config.name}" + ) + try: + is_gasstation = "gasstation" in dataset_config.name.lower() + should_download = ( + download_latest_gasstation_data if is_gasstation else True + ) if not skip_missing else False + + aug_iterator = DatasetIterator( + dataset_config, + max_samples=n_aug_per_dataset, + cache_dir=cache_dir, + download=should_download, + hf_token=hf_token, + seed=seed, + lazy_read=True, + ) + + if skip_missing and aug_iterator.get_total_cached_count() == 0: + continue + + aug_pipeline = VideoPrefetchPipeline( + dataset_iterator=aug_iterator, + target_size=plan.target_size, + batch_size=batch_size, + seed=aug_seed, + augment_level=augment_level, + crop_prob=crop_prob, + num_frames=num_frames, + frame_rate=frame_rate, + robustness_pass=True, + aug_cache_dir=aug_cache_dir, + ) + + aug_batch_id = 0 + try: + for batch_data in aug_pipeline: + aug_batch_id += 1 + batch_videos = [item["video"] for item in batch_data] + batch_metadata = [ + ( + item["label"], + item["sample"], + item["sample_index"], + item["dataset_name"], + item["sample_seed"], + ) + for item in batch_data + ] + process_video_batch( + session, + input_specs, + batch_videos, + batch_metadata, + tracker, + aug_batch_id, + aug_pass=True, + ) + finally: + aug_pipeline.close() + + except Exception as e: + logger.error( + f"Video robustness pass failed for {dataset_config.name}: {e}" + ) + cache_info = video_cache.get_cache_info() df = finalize_run( config=run_config, diff --git a/src/gasbench/cli.py b/src/gasbench/cli.py index aa3940a..1da9b93 100644 --- a/src/gasbench/cli.py +++ b/src/gasbench/cli.py @@ -163,6 +163,9 @@ def command_run(args): score_composition=score_composition, dataset_filters=getattr(args, "datasets", None), content_category=args.content_category, + n_aug_per_dataset=getattr(args, "n_aug_per_dataset", 0), + aug_weight=getattr(args, "aug_weight", 0.2), + aug_cache_dir=getattr(args, "aug_cache_dir", None), ) ) @@ -574,6 +577,34 @@ def main(): metavar="CATEGORY", help="Only run datasets matching a content_category (e.g. faces, documents)", ) + run_parser.add_argument( + "--n-aug-per-dataset", + type=int, + default=0, + metavar="N", + help="Number of samples per dataset to re-evaluate with a fixed robustness " + "augmentation suite (JPEG compression + downscale + blur). When > 0, " + "adds aug_sn34_score, augmentation_robustness, and per-sample degradation " + "stats to the results. Default: 0 (disabled).", + ) + run_parser.add_argument( + "--aug-weight", + type=float, + default=0.2, + metavar="W", + help="Weight of aug_sn34_score in the blended augmentation score " + "(only used when --n-aug-per-dataset > 0). Default: 0.2.", + ) + run_parser.add_argument( + "--aug-cache-dir", + type=str, + default=None, + metavar="PATH", + help="Directory to cache pre-augmented arrays for the robustness pass. " + "On first run, augmented arrays are saved; subsequent runs load from cache " + "instead of re-augmenting. Recommended for repeated bmcore runs. " + "Cache is versioned — a suite change auto-invalidates.", + ) run_parser.set_defaults(func=command_run) diff --git a/src/gasbench/processing/transforms.py b/src/gasbench/processing/transforms.py index 48d5338..f9824ce 100644 --- a/src/gasbench/processing/transforms.py +++ b/src/gasbench/processing/transforms.py @@ -76,6 +76,250 @@ def ensure_mask_3d(mask: np.ndarray) -> np.ndarray: return mask +def apply_robustness_augmentations( + image_array, + target_size, + seed=None, + jpeg_quality=55, + scale_factor=0.5, + webp_quality=75, +): + """Fixed augmentation suite for image augmentation robustness evaluation. + + Simulates the dominant real-world internet distribution pipeline: + 1. Downscale + upscale — thumbnail/CDN resize chain + 2. JPEG roundtrip at jpeg_quality — first platform upload (e.g. WhatsApp ~55) + 3. WebP roundtrip at webp_quality — CDN/platform re-host (Facebook, Google) + 4. Second JPEG roundtrip at 80 — re-share / re-host recompression + + Step 3 exercises the cross-codec re-hosting case: many platforms serve + WebP, whose VP8 intra coding leaves a different artifact family than JPEG + DCT, so a detector that survives repeated JPEG can still collapse on it. + Pass webp_quality=None to skip it and recover the JPEG-only chain. + + Returns the same 4-tuple as apply_random_augmentations for drop-in use + in PrefetchPipeline when robustness_pass=True. Deterministic given seed + so paired base/augmented samples share sample_id for degradation join. + """ + if seed is not None: + np.random.seed(seed) + + img = image_array.copy() + if img.dtype != np.uint8: + img = np.clip(img, 0, 255).astype(np.uint8) + + h, w = img.shape[:2] + + # Downscale then upscale — simulates share/thumbnail pipeline artifacts + if scale_factor < 1.0: + small_h = max(1, int(round(h * scale_factor))) + small_w = max(1, int(round(w * scale_factor))) + img = cv2.resize(img, (small_w, small_h), interpolation=cv2.INTER_AREA) + img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR) + + # First JPEG pass — heavy platform compression (WhatsApp/Telegram ~q55) + img = compress_image_jpeg_pil(img, quality=jpeg_quality) + + # Cross-codec re-host — CDN/platform WebP transcode (Facebook, Google) + if webp_quality is not None: + img = compress_image_webp_pil(img, quality=webp_quality) + + # Second JPEG pass — lighter re-share recompression (Twitter/Instagram ~q80) + img = compress_image_jpeg_pil(img, quality=80) + + # Resize to model input size (same crop+resize as base pipeline) + tforms = get_base_transforms(target_size, (1.0, 1.0)) + aug_hwc, _ = tforms(img, None, reuse_params=False) + + params = { + "jpeg_quality": jpeg_quality, + "scale_factor": scale_factor, + "webp_quality": webp_quality, + "jpeg_quality_2": 80, + } + return aug_hwc, None, "robustness", params + + +def _decode_video_rgb(tmp_path, num_frames): + """Decode up to num_frames RGB frames from a video file, padding the last + frame if the decoder returns fewer. Returns a (T, H, W, 3) uint8 array or + None on failure.""" + cap = cv2.VideoCapture(tmp_path) + decoded = [] + while len(decoded) < num_frames: + ret, frame = cap.read() + if not ret: + break + decoded.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + if not decoded: + return None + while len(decoded) < num_frames: + decoded.append(decoded[-1]) + return np.stack(decoded[:num_frames], axis=0) + + +def _h264_roundtrip_ffmpeg(video_array, crf, fps): + """Faithful H.264 roundtrip via the ffmpeg CLI using a real ``-crf`` value. + + This is the only path that reproduces the FaceForensics++ CRF protocol + exactly; cv2's VideoWriter quality knob does not map to CRF and is ignored + on many OpenCV builds. Returns the decoded (T, H, W, 3) uint8 array, or + None if ffmpeg is unavailable or the roundtrip fails (caller falls back). + """ + import shutil + import subprocess + import tempfile + import os + + if shutil.which("ffmpeg") is None: + return None + + T, H, W, C = video_array.shape + tmp_path = None + try: + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + tmp_path = f.name + cmd = [ + "ffmpeg", "-y", "-loglevel", "error", + "-f", "rawvideo", "-pix_fmt", "rgb24", + "-s", f"{W}x{H}", "-r", str(int(fps)), "-i", "-", + "-c:v", "libx264", "-crf", str(int(crf)), + "-pix_fmt", "yuv420p", tmp_path, + ] + proc = subprocess.run( + cmd, input=np.ascontiguousarray(video_array).tobytes(), + capture_output=True, + ) + if proc.returncode != 0 or not os.path.exists(tmp_path) or os.path.getsize(tmp_path) == 0: + return None + return _decode_video_rgb(tmp_path, T) + except Exception: + return None + finally: + if tmp_path: + try: + os.unlink(tmp_path) + except Exception: + pass + + +def _h264_roundtrip_cv2(video_array, crf, fps): + """Best-effort H.264 roundtrip via cv2's avc1 writer. CRF cannot be set + directly, so it is approximated through VIDEOWRITER_PROP_QUALITY (a perceptual + 0-100 knob that some builds ignore). Returns (T, H, W, 3) uint8 or None.""" + import tempfile + import os + + T, H, W, C = video_array.shape + # Map CRF [18, 51] → cv2 quality [100, 0] linearly (approximate only) + cv2_quality = max(0, min(100, round((51 - crf) / 33.0 * 100))) + + tmp_path = None + try: + with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as f: + tmp_path = f.name + + fourcc = cv2.VideoWriter_fourcc(*"avc1") + writer = cv2.VideoWriter(tmp_path, fourcc, float(fps), (W, H)) + if not writer.isOpened(): + writer.release() + return None + + writer.set(cv2.VIDEOWRITER_PROP_QUALITY, cv2_quality) + for t in range(T): + writer.write(cv2.cvtColor(video_array[t], cv2.COLOR_RGB2BGR)) + writer.release() + return _decode_video_rgb(tmp_path, T) + except Exception: + return None + finally: + if tmp_path: + try: + os.unlink(tmp_path) + except Exception: + pass + + +def apply_video_robustness_augmentations( + video_array, + target_size, + seed=None, + crf=23, + fps=25, + scale_factor=1.0, +): + """H.264 compression roundtrip for video augmentation robustness evaluation. + + Mirrors the FaceForensics++ c23/c40 evaluation protocol — encode to H.264 + at a given CRF then decode back, simulating platform re-encoding pipelines. + CRF 23 = light (FF++ c23, YouTube-tier), CRF 40 = heavy (FF++ c40, + WhatsApp/Messenger-tier). + + Encoding tries, in order: + 1. ffmpeg CLI with a real ``-crf`` value — faithful FF++ reproduction. + 2. cv2 avc1 writer with an approximate quality mapping — used only if + ffmpeg is not on PATH. + 3. per-frame JPEG at an equivalent severity — last-resort fallback that + still preserves chroma-subsampling artifacts. + + scale_factor < 1.0 first downscales every frame (resolution ladder) before + encoding, mirroring platform transcodes that drop 1080p → 720p → 480p. Left + at 1.0 by default so the CRF-only pass stays faithful to the FF++ protocol; + set it (e.g. 0.5) to additionally exercise resolution degradation. + + Returns the same 4-tuple as apply_random_augmentations for drop-in use in + VideoPrefetchPipeline when robustness_pass=True. + """ + if seed is not None: + np.random.seed(seed) + + if video_array.dtype != np.uint8: + video_array = np.clip(video_array, 0, 255).astype(np.uint8) + + # Resolution ladder — downscale frames before encoding (platform transcode) + if scale_factor < 1.0: + T, H, W, C = video_array.shape + sh = max(2, int(round(H * scale_factor))) + sw = max(2, int(round(W * scale_factor))) + video_array = np.stack( + [cv2.resize(video_array[t], (sw, sh), interpolation=cv2.INTER_AREA) + for t in range(T)], + axis=0, + ) + + T, H, W, C = video_array.shape + + # H.264 with yuv420p (both the ffmpeg -crf and cv2 avc1 encoders) requires + # even width and height; libx264 rejects odd dimensions outright. Trim a + # trailing row/column when needed so encoding succeeds instead of silently + # erroring into the fallback path (which would also mis-set method). At most + # one pixel per axis is dropped, and the frame is resized to target anyway. + even_h, even_w = H - (H % 2), W - (W % 2) + if (even_h, even_w) != (H, W) and even_h >= 2 and even_w >= 2: + video_array = video_array[:, :even_h, :even_w, :] + T, H, W, C = video_array.shape + + method = "ffmpeg_crf" + compressed = _h264_roundtrip_ffmpeg(video_array, crf, fps) + if compressed is None: + method = "cv2_avc1" + compressed = _h264_roundtrip_cv2(video_array, crf, fps) + if compressed is None: + # Fallback: per-frame JPEG at quality approximating the requested CRF severity + method = "jpeg_fallback" + fallback_q = max(20, min(95, round(100 - (crf - 18) * 2.3))) + compressed = compress_video_frames_jpeg_torchvision(video_array, quality=fallback_q) + + # Resize to target via base transforms (no random crop/flip) + aug_thwc, _, _, _ = apply_random_augmentations( + compressed, target_size, seed=seed, level=0, crop_prob=0.0 + ) + + params = {"crf": crf, "fps": fps, "scale_factor": scale_factor, "method": method} + return aug_thwc, None, "robustness_video", params + + def apply_random_augmentations( inputs, target_size, @@ -709,12 +953,51 @@ def compress_image_jpeg_pil(image_hwc: np.ndarray, quality: int = 75) -> np.ndar pil_img = Image.fromarray(image_hwc, mode="RGB") buffer = BytesIO() - pil_img.save(buffer, format="JPEG", quality=int(quality)) + # subsampling=2 forces 4:2:0 chroma subsampling regardless of quality/Pillow + # version. This is the operation the social-media compression literature + # identifies as destroying high-frequency DCT fingerprints, so we pin it + # rather than letting Pillow pick subsampling per quality level. + pil_img.save(buffer, format="JPEG", quality=int(quality), subsampling=2) buffer.seek(0) decoded_pil = Image.open(buffer).convert("RGB") return np.array(decoded_pil) +def compress_image_webp_pil(image_hwc: np.ndarray, quality: int = 75) -> np.ndarray: + """ + Compress a single image using a PIL WebP (lossy) round-trip at fixed quality. + + Facebook, Google, and many CDNs re-host uploads as WebP, whose VP8 intra + coding leaves a different artifact family than JPEG's DCT blocks. Including + a WebP pass alongside the JPEG passes exercises detectors against the + cross-codec re-hosting that real distribution chains produce. + + Args: + image_hwc: numpy array (H, W, C), dtype uint8, RGB + quality: WebP quality (default 75) + + Returns: + numpy array (H, W, C), dtype uint8, RGB + """ + if image_hwc is None: + return image_hwc + if image_hwc.dtype != np.uint8: + image_hwc = np.clip(image_hwc, 0, 255).astype(np.uint8) + if image_hwc.ndim != 3 or image_hwc.shape[2] != 3: + return image_hwc + + try: + pil_img = Image.fromarray(image_hwc, mode="RGB") + buffer = BytesIO() + pil_img.save(buffer, format="WEBP", quality=int(quality), method=4) + buffer.seek(0) + decoded_pil = Image.open(buffer).convert("RGB") + return np.array(decoded_pil) + except Exception: + # WebP support is missing in some Pillow builds; fall back to original. + return image_hwc + + def compress_video_frames_jpeg_torchvision(video_thwc: np.ndarray, quality: int = 75) -> np.ndarray: """ Compress each frame of a video using torchvision's encode_jpeg/decode_jpeg at fixed quality.