From 0a614fd3c567331c3930eb588ab7d9bbcacf2f0a Mon Sep 17 00:00:00 2001 From: nxank4 Date: Fri, 27 Feb 2026 17:43:43 +0000 Subject: [PATCH 1/3] feat(extraction): add TrapPruner, MissingnessRecognizer, TargetLeakageAuditor - TrapPruner: statistical profiling + LLM verification of Gaussian noise columns - MissingnessRecognizer: MNAR pattern detection with sandbox-compiled encoders - TargetLeakageAuditor: semantic timeline evaluation for target leakage --- src/loclean/extraction/leakage_auditor.py | 240 ++++++++++++ .../extraction/missingness_recognizer.py | 307 ++++++++++++++++ src/loclean/extraction/trap_pruner.py | 343 ++++++++++++++++++ 3 files changed, 890 insertions(+) create mode 100644 src/loclean/extraction/leakage_auditor.py create mode 100644 src/loclean/extraction/missingness_recognizer.py create mode 100644 src/loclean/extraction/trap_pruner.py diff --git a/src/loclean/extraction/leakage_auditor.py b/src/loclean/extraction/leakage_auditor.py new file mode 100644 index 0000000..67bdb3d --- /dev/null +++ b/src/loclean/extraction/leakage_auditor.py @@ -0,0 +1,240 @@ +"""Semantic target leakage detection via LLM-driven timeline evaluation. + +Identifies features that mathematically or logically imply the target +variable — i.e. columns containing information generated *after* the +target event occurs. In a deterministic leakage scenario: + + P(Y | X_i) ≈ 1 + +The generative engine acts as a semantic auditor, catching logical +leakage that basic statistical tests miss by evaluating the causal +timeline of each feature relative to the target outcome. 
class TargetLeakageAuditor:
    """Detect and remove features that leak the target variable.

    The auditor builds one prompt describing every feature column
    (name, dtype and a handful of sample rows, optionally preceded by a
    domain description) and asks the LLM whether each feature could only
    be known *after* the target outcome is determined.  Columns flagged
    as leakage are dropped from the frame.

    Args:
        inference_engine: Ollama (or compatible) engine exposing
            ``generate(prompt)``.
        max_retries: LLM generation retry budget.
        sample_n: Number of sample rows to include in the prompt.
    """

    def __init__(
        self,
        inference_engine: InferenceEngine,
        *,
        max_retries: int = 2,
        sample_n: int = 10,
    ) -> None:
        self.inference_engine = inference_engine
        self.max_retries = max_retries
        self.sample_n = sample_n

    def audit(
        self,
        df: IntoFrameT,
        target_col: str,
        domain: str = "",
    ) -> tuple[IntoFrameT, dict[str, Any]]:
        """Audit features for target leakage and drop offenders.

        Args:
            df: Input DataFrame (pandas, Polars, etc.).
            target_col: Column name of the prediction target.
            domain: Brief text description of the dataset domain
                (e.g. ``"hospital readmission prediction"``).

        Returns:
            Tuple of ``(pruned_df, summary)`` where *summary* contains
            ``dropped_columns`` (columns actually removed) and the
            per-column ``verdicts`` returned by the LLM.

        Raises:
            ValueError: If *target_col* is not a column of *df*.
        """
        df_nw = nw.from_native(df)  # type: ignore[type-var]

        if target_col not in df_nw.columns:
            raise ValueError(f"Target column '{target_col}' not found")

        feature_cols = [c for c in df_nw.columns if c != target_col]
        if not feature_cols:
            logger.info("No feature columns to audit.")
            return df, {"dropped_columns": [], "verdicts": []}

        # Forward the configured sample size; the static helper's default
        # would otherwise silently override the constructor argument.
        state = self._extract_state(df_nw, target_col, feature_cols, self.sample_n)
        prompt = self._build_prompt(state, domain)
        verdicts = self._evaluate_with_llm(prompt)

        # Only act on names that really are feature columns (the LLM may
        # invent names or return the target), and drop each one once.
        known = set(feature_cols)
        flagged = (v["column"] for v in verdicts if v.get("is_leakage"))
        valid_leaked = list(dict.fromkeys(c for c in flagged if c in known))

        summary: dict[str, Any] = {
            "dropped_columns": valid_leaked,
            "verdicts": verdicts,
        }

        if not valid_leaked:
            logger.info("No target leakage detected.")
            return df, summary

        logger.info(
            "Dropping %d leaked feature(s): %s",
            len(valid_leaked),
            valid_leaked,
        )
        try:
            pruned = nw.to_native(df_nw.drop(valid_leaked))  # type: ignore[type-var]
        except Exception as exc:
            # Best effort: keep the caller's frame intact, and make the
            # summary reflect that nothing was actually removed.
            logger.warning("Failed to drop columns: %s", exc)
            summary["dropped_columns"] = []
            return df, summary
        return pruned, summary

    # ------------------------------------------------------------------
    # State extraction
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_state(
        df_nw: nw.DataFrame[Any],
        target_col: str,
        feature_cols: list[str],
        sample_n: int = 10,
    ) -> dict[str, Any]:
        """Build structural metadata for the LLM prompt.

        Args:
            df_nw: Narwhals DataFrame.
            target_col: Target variable column.
            feature_cols: Feature column names.
            sample_n: Number of leading rows to include as samples.

        Returns:
            Dict with ``target_col``, ``features``, ``dtypes``
            (stringified dtype per feature) and ``sample_rows``
            (the first rows of the frame as named dicts).
        """
        n = max(0, min(df_nw.shape[0], sample_n))
        sample_rows = df_nw.head(n).rows(named=True)

        dtypes = {col: str(df_nw[col].dtype) for col in feature_cols}

        return {
            "target_col": target_col,
            "features": feature_cols,
            "dtypes": dtypes,
            "sample_rows": sample_rows,  # type: ignore[dict-item]
        }

    # ------------------------------------------------------------------
    # Prompt construction
    # ------------------------------------------------------------------

    @staticmethod
    def _build_prompt(
        state: dict[str, Any],
        domain: str,
    ) -> str:
        """Build the LLM prompt for timeline evaluation.

        Args:
            state: Mapping produced by :meth:`_extract_state`.
            domain: Optional dataset description; omitted from the
                prompt when empty.

        Returns:
            The full prompt text requesting a JSON array of verdicts.
        """
        target = state["target_col"]
        features = state["features"]
        dtypes = state["dtypes"]
        # The row budget is decided by ``_extract_state``; no second cap
        # here, so a ``sample_n`` above 10 is honoured.  ``default=str``
        # keeps dates/decimals serialisable.
        sample_str = json.dumps(list(state["sample_rows"]), indent=2, default=str)

        domain_line = f"Dataset domain: {domain}\n" if domain else ""

        feature_info = "\n".join(
            f"  - {f} (dtype: {dtypes.get(f, 'unknown')})" for f in features
        )

        return (
            "You are a machine learning auditor specialising in data "
            "leakage detection.\n\n"
            f"{domain_line}"
            f"Target variable: '{target}'\n\n"
            f"Feature columns:\n{feature_info}\n\n"
            f"Sample rows:\n{sample_str}\n\n"
            "Task: For each feature column, evaluate whether it could "
            "constitute **target leakage** — meaning the feature contains "
            "information that would only be available AFTER the target "
            "outcome is determined.\n\n"
            "Consider:\n"
            "- Temporal ordering: was this feature generated before or "
            "after the target event?\n"
            "- Semantic meaning: does the feature directly encode or "
            "trivially derive from the target?\n"
            "- Statistical signal: extremely high correlation may "
            "indicate leakage, not just a good predictor.\n\n"
            "Output ONLY a JSON array. For each feature, output an "
            "object with exactly three keys:\n"
            '- "column": the feature name\n'
            '- "is_leakage": boolean\n'
            '- "reason": brief explanation\n\n'
            "Output ONLY the JSON array, no other text."
        )

    # ------------------------------------------------------------------
    # LLM evaluation
    # ------------------------------------------------------------------

    def _evaluate_with_llm(
        self,
        prompt: str,
    ) -> list[dict[str, Any]]:
        """Send the prompt and parse the leakage verdicts.

        Retries up to ``max_retries`` times when the response cannot be
        parsed.  Returns an empty list (i.e. "drop nothing") when every
        attempt fails.
        """
        for attempt in range(1, self.max_retries + 1):
            try:
                raw = str(self.inference_engine.generate(prompt)).strip()
                return self._parse_verdict(raw)
            except (json.JSONDecodeError, ValueError, KeyError) as exc:
                logger.warning(
                    "LLM verdict parsing failed (attempt %d/%d): %s",
                    attempt,
                    self.max_retries,
                    exc,
                )
        logger.warning("Could not parse LLM verdicts — keeping all columns.")
        return []

    @staticmethod
    def _coerce_flag(value: Any) -> bool:
        """Interpret a JSON value as a boolean flag.

        LLMs sometimes emit ``"false"`` / ``"no"`` as strings; plain
        ``bool()`` would treat any non-empty string as ``True``.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.strip().lower() in {"true", "yes", "1"}
        return bool(value)

    @staticmethod
    def _parse_verdict(response: str) -> list[dict[str, Any]]:
        """Parse the JSON verdict from the LLM response.

        The outermost ``[...]`` span is extracted so that surrounding
        prose or code fences do not break parsing.

        Raises:
            ValueError: If no JSON array is present, the array is not
                valid JSON, or an element is not an object.
            KeyError: If an element lacks the ``"column"`` key.
        """
        text = response.strip()
        start = text.find("[")
        end = text.rfind("]")
        if start == -1 or end == -1:
            raise ValueError("No JSON array found in LLM response")

        items = json.loads(text[start : end + 1])
        verdicts: list[dict[str, Any]] = []

        for item in items:
            if not isinstance(item, dict):
                # Surface as ValueError so the retry loop handles it.
                raise ValueError(
                    f"Expected JSON object in verdict array, got {type(item).__name__}"
                )
            verdicts.append(
                {
                    "column": str(item["column"]),
                    "is_leakage": TargetLeakageAuditor._coerce_flag(
                        item.get("is_leakage", False)
                    ),
                    "reason": str(item.get("reason", "")),
                }
            )

        return verdicts
# Suffix appended to a source column name to form its MNAR flag column.
_SUFFIX = "_mnar"


class MissingnessRecognizer:
    """Detect MNAR patterns and encode them as boolean feature flags.

    For each column containing nulls the recognizer:

    1. Samples rows where the column is null alongside other features.
    2. Prompts the LLM to identify structural correlations.
    3. Compiles the LLM-generated ``encode_missingness`` function in a
       sandbox.
    4. Applies the function across the DataFrame to create a boolean
       ``{col}_mnar`` column.

    Args:
        inference_engine: Ollama (or compatible) engine.
        sample_size: Maximum null rows to sample per column.
        max_retries: LLM code-generation retry budget.
        timeout_s: Per-row execution timeout in seconds, used both when
            verifying and when applying a generated encoder.
    """

    def __init__(
        self,
        inference_engine: InferenceEngine,
        *,
        sample_size: int = 50,
        max_retries: int = 3,
        timeout_s: float = 2.0,
    ) -> None:
        self.inference_engine = inference_engine
        self.sample_size = sample_size
        self.max_retries = max_retries
        self.timeout_s = timeout_s

    def recognize(
        self,
        df: IntoFrameT,
        target_cols: list[str] | None = None,
    ) -> tuple[IntoFrameT, dict[str, Any]]:
        """Detect MNAR patterns and add boolean feature columns.

        Args:
            df: Input DataFrame (pandas, Polars, etc.).
            target_cols: Columns to analyse for missingness.  If
                ``None``, all columns containing nulls are evaluated.

        Returns:
            Tuple of ``(augmented_df, summary)`` where
            ``summary["patterns"]`` maps each analysed column to a dict
            (``encoded_as``, ``null_count``, ``pattern_flags_true``) or
            ``None`` if no usable encoder was produced.
        """
        df_nw = nw.from_native(df)  # type: ignore[type-var]

        null_cols = self._find_null_columns(df_nw)

        if target_cols is not None:
            with_nulls = set(null_cols)
            # Keep the caller's order but analyse each column only once.
            null_cols = list(dict.fromkeys(c for c in target_cols if c in with_nulls))

        if not null_cols:
            logger.info("No columns with null values to analyse.")
            return df, {"patterns": {}}

        all_cols = df_nw.columns
        patterns: dict[str, Any] = {}
        new_columns: dict[str, list[bool]] = {}

        for col in null_cols:
            context_cols = [c for c in all_cols if c != col]
            # Honour the configured sample budget instead of the static
            # helper's default.
            sample = self._sample_null_context(
                df_nw, col, context_cols, self.sample_size
            )

            if not sample:
                patterns[col] = None
                continue

            prompt = self._build_prompt(col, context_cols, sample)
            fn = self._generate_and_compile(prompt)

            if fn is None:
                patterns[col] = None
                continue

            ok, error = self._verify_encoder(fn, sample, self.timeout_s)
            if not ok:
                logger.warning("Encoder for '%s' failed verification: %s", col, error)
                patterns[col] = None
                continue

            flags = self._apply_encoder(df_nw, fn)
            flagged = sum(flags)
            col_name = f"{col}{_SUFFIX}"
            new_columns[col_name] = flags
            patterns[col] = {
                "encoded_as": col_name,
                # Same null definition that selected the column above; an
                # ``is None`` scan over ``to_list()`` can disagree with it.
                "null_count": int(df_nw[col].null_count()),
                "pattern_flags_true": flagged,
            }
            logger.info(
                "Encoded MNAR pattern for '%s' → '%s' (%d flagged)",
                col,
                col_name,
                flagged,
            )

        if new_columns:
            native_ns = nw.get_native_namespace(df_nw)
            rows_data: dict[str, list[Any]] = {
                c: df_nw[c].to_list() for c in df_nw.columns
            }
            rows_data.update(new_columns)
            result_nw = nw.from_dict(rows_data, backend=native_ns)
            return nw.to_native(result_nw), {"patterns": patterns}

        return df, {"patterns": patterns}

    # ------------------------------------------------------------------
    # Null detection
    # ------------------------------------------------------------------

    @staticmethod
    def _find_null_columns(df_nw: nw.DataFrame[Any]) -> list[str]:
        """Return column names that contain at least one null."""
        return [col for col in df_nw.columns if df_nw[col].null_count() > 0]

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------

    @staticmethod
    def _sample_null_context(
        df_nw: nw.DataFrame[Any],
        null_col: str,
        context_cols: list[str],
        max_rows: int = 50,
    ) -> list[dict[str, Any]]:
        """Extract rows where *null_col* is null with context values.

        Returns a list of dicts (at most *max_rows*, taken from the top
        of the filtered frame), each containing the context column
        values for a row where *null_col* is missing.  Returns an empty
        list when there are no such rows or no context columns.
        """
        null_rows = df_nw.filter(df_nw[null_col].is_null())

        if null_rows.shape[0] == 0:
            return []

        sampled = null_rows.head(min(null_rows.shape[0], max_rows))

        select_cols = [c for c in context_cols if c in sampled.columns]
        if not select_cols:
            return []

        return sampled.select(select_cols).rows(named=True)  # type: ignore[return-value]

    # ------------------------------------------------------------------
    # Prompt construction
    # ------------------------------------------------------------------

    @staticmethod
    def _build_prompt(
        null_col: str,
        context_cols: list[str],
        sample_rows: list[dict[str, Any]],
    ) -> str:
        """Build the LLM prompt for pattern inference.

        At most the first 20 sample rows are embedded to bound prompt
        size; ``default=str`` keeps non-JSON values serialisable.
        """
        sample_str = json.dumps(sample_rows[:20], indent=2, default=str)

        return (
            "You are a data scientist analysing missing data patterns.\n\n"
            f"Column '{null_col}' has missing values. Below are sample rows "
            "where this column IS NULL, showing the values of the other "
            "columns:\n\n"
            f"Context columns: {context_cols}\n\n"
            f"Sample rows (where '{null_col}' is null):\n{sample_str}\n\n"
            "Task: Identify if there is a structural pattern that predicts "
            f"when '{null_col}' is missing based on other column values.\n\n"
            "Write a pure Python function with this exact signature:\n\n"
            "def encode_missingness(row: dict) -> bool:\n"
            "    ...\n\n"
            "The function receives a dict of ALL column values for a row "
            "(including the target column) and returns True if the "
            "missingness pattern is detected.\n\n"
            "Rules:\n"
            "- Use ONLY standard library modules (math, statistics, operator)\n"
            "- Wrap logic in try/except returning False on failure\n"
            "- Return a single boolean value\n"
            "- Do NOT use markdown fences, comments, or prose\n"
            "- Output ONLY the function code, nothing else\n\n"
            "Example:\n"
            "def encode_missingness(row: dict) -> bool:\n"
            "    try:\n"
            "        return row.get('category') == 'electronics' "
            "and row.get('price', 0) > 500\n"
            "    except Exception:\n"
            "        return False\n"
        )

    # ------------------------------------------------------------------
    # Code generation + compilation
    # ------------------------------------------------------------------

    def _generate_and_compile(
        self,
        prompt: str,
    ) -> Callable[[dict[str, Any]], bool] | None:
        """Generate, sanitize, and compile the encoder function.

        Markdown code fences are stripped from the LLM output before it
        is handed to ``compile_sandboxed``.  Returns ``None`` when every
        attempt raises ``ValueError`` or ``SyntaxError``.
        """
        import re

        from loclean.utils.sandbox import compile_sandboxed

        for attempt in range(1, self.max_retries + 1):
            try:
                raw = str(self.inference_engine.generate(prompt)).strip()
                source = re.sub(r"```(?:python)?\s*\n?", "", raw).strip()
                fn = compile_sandboxed(
                    source,
                    "encode_missingness",
                    # Presumably the modules the sandbox lets the generated
                    # code import — matches the prompt's rules.
                    ["math", "statistics", "operator"],
                )
                return fn  # type: ignore[return-value]
            except (ValueError, SyntaxError) as exc:
                logger.warning(
                    "⚠ Code generation failed (attempt %d/%d): %s",
                    attempt,
                    self.max_retries,
                    exc,
                )

        logger.warning(
            "Could not compile encoder after %d retries.",
            self.max_retries,
        )
        return None

    # ------------------------------------------------------------------
    # Verification
    # ------------------------------------------------------------------

    @staticmethod
    def _verify_encoder(
        fn: Callable[[dict[str, Any]], bool],
        sample_rows: list[dict[str, Any]],
        timeout_s: float = 2.0,
    ) -> tuple[bool, str]:
        """Test the encoder on up to five sample rows.

        Args:
            fn: Compiled encoder.
            sample_rows: Rows to probe the encoder with.
            timeout_s: Per-row execution timeout in seconds.

        Returns:
            ``(True, "")`` when every probed row yields a ``bool``
            without error, otherwise ``(False, message)``.
        """
        from loclean.utils.sandbox import run_with_timeout

        for row in sample_rows[:5]:
            result, error = run_with_timeout(fn, (row,), timeout_s)
            if error:
                return False, f"Execution error: {error}"
            if not isinstance(result, bool):
                return False, f"Expected bool, got {type(result).__name__}"

        return True, ""

    # ------------------------------------------------------------------
    # Application
    # ------------------------------------------------------------------

    def _apply_encoder(
        self,
        df_nw: nw.DataFrame[Any],
        fn: Callable[[dict[str, Any]], bool],
    ) -> list[bool]:
        """Apply the encoder across all rows.

        Rows whose evaluation errors, times out, or returns a non-bool
        are recorded as ``False``.
        """
        from loclean.utils.sandbox import run_with_timeout

        rows: list[dict[str, Any]] = df_nw.rows(named=True)  # type: ignore[assignment]
        flags: list[bool] = []

        for row in rows:
            result, error = run_with_timeout(fn, (row,), self.timeout_s)
            flags.append(False if error or not isinstance(result, bool) else result)

        return flags
class _ColumnProfile:
    """Statistical profile for a single numeric column."""

    __slots__ = (
        "name",
        "mean",
        "std",
        "variance",
        "skewness",
        "kurtosis",
        "min_val",
        "max_val",
        "corr_with_target",
    )

    def __init__(
        self,
        name: str,
        mean: float,
        std: float,
        variance: float,
        skewness: float,
        kurtosis: float,
        min_val: float,
        max_val: float,
        corr_with_target: float,
    ) -> None:
        self.name = name
        self.mean = mean
        self.std = std
        self.variance = variance
        self.skewness = skewness
        self.kurtosis = kurtosis
        self.min_val = min_val
        self.max_val = max_val
        self.corr_with_target = corr_with_target

    def to_dict(self) -> dict[str, Any]:
        """Serialise profile to a plain dictionary."""
        return {
            "name": self.name,
            "mean": self.mean,
            "std": self.std,
            "variance": self.variance,
            "skewness": self.skewness,
            "kurtosis": self.kurtosis,
            "min": self.min_val,
            "max": self.max_val,
            "corr_with_target": self.corr_with_target,
        }


class TrapPruner:
    """Identify and remove trap features from a DataFrame.

    Trap features are columns of uncorrelated Gaussian noise that
    masquerade as valid signals.  Detection relies entirely on
    statistical distributions and target correlations — column names
    are deliberately hidden from the LLM.

    Args:
        inference_engine: Ollama (or compatible) engine for verification.
        correlation_threshold: Absolute correlation below which a
            column is considered uncorrelated.  Default ``0.05``.
        max_retries: LLM generation retry budget.
    """

    def __init__(
        self,
        inference_engine: InferenceEngine,
        *,
        correlation_threshold: float = 0.05,
        max_retries: int = 2,
    ) -> None:
        self.inference_engine = inference_engine
        self.correlation_threshold = correlation_threshold
        self.max_retries = max_retries

    def prune(
        self,
        df: IntoFrameT,
        target_col: str,
    ) -> tuple[IntoFrameT, dict[str, Any]]:
        """Profile, verify, and drop trap features.

        Args:
            df: Input DataFrame (pandas, Polars, etc.).
            target_col: Column name of the prediction target.

        Returns:
            Tuple of ``(pruned_df, summary)`` where *summary* contains
            ``dropped_columns`` (list of removed names) and
            ``verdicts`` (per-column LLM reasoning).

        Raises:
            ValueError: If *target_col* is not a column of *df*.
        """
        df_nw = nw.from_native(df)  # type: ignore[type-var]

        if target_col not in df_nw.columns:
            raise ValueError(f"Target column '{target_col}' not found")

        numeric_cols = [
            c for c in df_nw.columns if c != target_col and df_nw[c].dtype.is_numeric()
        ]

        if not numeric_cols:
            logger.info("No numeric feature columns to evaluate.")
            return df, {"dropped_columns": [], "verdicts": []}

        profiles = self._profile_columns(df_nw, target_col, numeric_cols)
        if not profiles:
            # Fewer than two rows: nothing to profile, so there is nothing
            # meaningful to ask the LLM about.
            logger.info("Not enough rows to profile; keeping all columns.")
            return df, {"dropped_columns": [], "verdicts": []}

        col_map, prompt = self._build_prompt(profiles, self.correlation_threshold)

        verdicts = self._verify_with_llm(prompt, col_map)

        # Only drop columns that were actually profiled.  An unknown ID
        # echoed back by the LLM would make ``drop`` fail, and a verdict
        # naming the target column must never remove the target.
        candidates = set(col_map.values())
        trap_cols = list(
            dict.fromkeys(
                v["column"]
                for v in verdicts
                if v.get("is_trap") and v["column"] in candidates
            )
        )

        summary: dict[str, Any] = {
            "dropped_columns": trap_cols,
            "verdicts": verdicts,
        }

        if trap_cols:
            logger.info(
                "Dropping %d trap feature(s): %s",
                len(trap_cols),
                trap_cols,
            )
            pruned_nw = df_nw.drop(trap_cols)
            return nw.to_native(pruned_nw), summary  # type: ignore[type-var]

        logger.info("No trap features detected.")
        return df, summary

    # ------------------------------------------------------------------
    # Statistical profiling
    # ------------------------------------------------------------------

    @staticmethod
    def _profile_columns(
        df_nw: nw.DataFrame[Any],
        target_col: str,
        numeric_cols: list[str],
    ) -> list[_ColumnProfile]:
        """Compute distribution statistics for each numeric column.

        All operations use the Narwhals interface.  Correlation,
        skewness and excess kurtosis are normalised by the *population*
        standard deviation (square root of the mean squared deviation)
        so that numerator and denominator use the same estimator.  The
        reported ``std`` is still whatever ``Series.std()`` returns.
        Math and type errors are caught per column and yield an
        all-zero profile.

        Returns:
            One profile per entry of *numeric_cols*, or an empty list
            when the frame has fewer than two rows.
        """
        n = df_nw.shape[0]
        if n < 2:
            return []

        target_series = df_nw[target_col].cast(nw.Float64)
        target_mean = target_series.mean()
        if target_mean is None:
            # Target has no non-null values: correlation is undefined.
            target_diffs = None
            target_pop_std = 0.0
        else:
            target_diffs = target_series - target_mean
            target_var = (target_diffs * target_diffs).mean()
            target_pop_std = (
                float(target_var) ** 0.5 if target_var and target_var > 0 else 0.0
            )

        profiles: list[_ColumnProfile] = []

        for col in numeric_cols:
            try:
                series = df_nw[col].cast(nw.Float64)
                col_mean = series.mean()
                col_std = series.std()

                diffs = series - col_mean
                variance = (diffs * diffs).mean()
                pop_std = float(variance) ** 0.5 if variance and variance > 0 else 0.0

                if pop_std > 0 and target_pop_std > 0:
                    corr = float(
                        (diffs * target_diffs).mean() / (pop_std * target_pop_std)
                    )
                else:
                    corr = 0.0

                if pop_std > 0:
                    skewness = float((diffs**3).mean() / (pop_std**3))
                    kurtosis = float((diffs**4).mean() / (pop_std**4)) - 3.0
                else:
                    skewness = 0.0
                    kurtosis = 0.0

                profiles.append(
                    _ColumnProfile(
                        name=col,
                        mean=float(col_mean) if col_mean is not None else 0.0,
                        std=float(col_std) if col_std is not None else 0.0,
                        variance=float(variance) if variance is not None else 0.0,
                        skewness=skewness,
                        kurtosis=kurtosis,
                        min_val=float(series.min()),
                        max_val=float(series.max()),
                        corr_with_target=corr,
                    )
                )
            except (ZeroDivisionError, ValueError, OverflowError, TypeError):
                # TypeError covers aggregates that come back as None
                # (e.g. an all-null column) reaching float()/arithmetic.
                profiles.append(
                    _ColumnProfile(
                        name=col,
                        mean=0.0,
                        std=0.0,
                        variance=0.0,
                        skewness=0.0,
                        kurtosis=0.0,
                        min_val=0.0,
                        max_val=0.0,
                        corr_with_target=0.0,
                    )
                )

        return profiles

    # ------------------------------------------------------------------
    # Prompt construction (anonymised)
    # ------------------------------------------------------------------

    @staticmethod
    def _build_prompt(
        profiles: list[_ColumnProfile],
        correlation_threshold: float = 0.05,
    ) -> tuple[dict[str, str], str]:
        """Build the LLM verification prompt with anonymised column IDs.

        Args:
            profiles: Column profiles to describe.
            correlation_threshold: Absolute-correlation cut-off quoted in
                the trap-feature definition.

        Returns:
            Tuple of ``(col_map, prompt_text)`` where *col_map* maps
            ``"col_0"`` → real column name.
        """
        col_map: dict[str, str] = {}
        lines: list[str] = []

        for i, p in enumerate(profiles):
            anon = f"col_{i}"
            col_map[anon] = p.name

            lines.append(
                f"Column {anon}: "
                f"mean={p.mean:.4f}, std={p.std:.4f}, "
                f"variance={p.variance:.4f}, "
                f"skewness={p.skewness:.4f}, kurtosis={p.kurtosis:.4f}, "
                f"min={p.min_val:.4f}, max={p.max_val:.4f}, "
                f"corr_with_target={p.corr_with_target:.4f}"
            )

        profile_block = "\n".join(lines)

        prompt = (
            "You are a statistical analyst. Below are the statistical profiles "
            "of several numeric columns from a dataset. Each column is "
            "identified only by an anonymous ID (column names are hidden).\n\n"
            f"{profile_block}\n\n"
            "A **trap feature** is a column that:\n"
            "1. Exhibits a distribution close to standard Gaussian "
            "(skewness ≈ 0, kurtosis ≈ 0, i.e. excess kurtosis near zero).\n"
            "2. Has a correlation with the target variable very close to "
            f"zero (|corr| < {correlation_threshold}).\n\n"
            "Analyse each column and output ONLY a JSON array. "
            "For each column output an object with exactly three keys:\n"
            '- "column": the anonymous ID (e.g. "col_0")\n'
            '- "is_trap": boolean\n'
            '- "reason": brief explanation\n\n'
            "Output ONLY the JSON array, no other text."
        )

        return col_map, prompt

    # ------------------------------------------------------------------
    # LLM verification
    # ------------------------------------------------------------------

    def _verify_with_llm(
        self,
        prompt: str,
        col_map: dict[str, str],
    ) -> list[dict[str, Any]]:
        """Send the prompt to the LLM and parse the verdict.

        Retries up to ``max_retries`` times on unparsable output; when
        every attempt fails, returns a "not a trap" verdict for each
        profiled column so nothing is dropped.
        """
        for attempt in range(1, self.max_retries + 1):
            try:
                raw = self.inference_engine.generate(prompt)
                return self._parse_verdict(str(raw).strip(), col_map)
            except (json.JSONDecodeError, ValueError, KeyError) as exc:
                logger.warning(
                    "LLM verdict parsing failed (attempt %d/%d): %s",
                    attempt,
                    self.max_retries,
                    exc,
                )
        logger.warning("Could not parse LLM verdicts — keeping all columns.")
        return [
            {"column": real, "is_trap": False, "reason": "LLM parse failure"}
            for real in col_map.values()
        ]

    @staticmethod
    def _coerce_flag(value: Any) -> bool:
        """Interpret a JSON value as a boolean flag.

        LLMs sometimes emit ``"false"`` / ``"no"`` as strings; plain
        ``bool()`` would treat any non-empty string as ``True``.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.strip().lower() in {"true", "yes", "1"}
        return bool(value)

    @staticmethod
    def _parse_verdict(
        response: str,
        col_map: dict[str, str],
    ) -> list[dict[str, Any]]:
        """Parse the JSON verdict and map anonymous IDs back to real names.

        IDs missing from *col_map* are passed through unchanged so the
        caller can see (and ignore) them.

        Raises:
            ValueError: If no JSON array is present, the array is not
                valid JSON, or an element is not an object.
            KeyError: If an element lacks the ``"column"`` key.
        """
        text = response.strip()
        start = text.find("[")
        end = text.rfind("]")
        if start == -1 or end == -1:
            raise ValueError("No JSON array found in LLM response")

        items = json.loads(text[start : end + 1])
        verdicts: list[dict[str, Any]] = []

        for item in items:
            if not isinstance(item, dict):
                # Surface as ValueError so the retry loop handles it.
                raise ValueError(
                    f"Expected JSON object in verdict array, got {type(item).__name__}"
                )
            anon_id = item["column"]
            # Non-string IDs (numbers, lists) cannot be looked up safely.
            key = anon_id if isinstance(anon_id, str) else str(anon_id)
            verdicts.append(
                {
                    "column": col_map.get(key, key),
                    "is_trap": TrapPruner._coerce_flag(item.get("is_trap", False)),
                    "reason": str(item.get("reason", "")),
                }
            )

        return verdicts
src/loclean/__init__.py | 223 +++++++++++++++++++++++++++++ src/loclean/extraction/__init__.py | 9 ++ 2 files changed, 232 insertions(+) diff --git a/src/loclean/__init__.py b/src/loclean/__init__.py index 96d2222..e7d0515 100644 --- a/src/loclean/__init__.py +++ b/src/loclean/__init__.py @@ -10,6 +10,7 @@ __all__ = [ "__version__", "Loclean", + "audit_leakage", "clean", "discover_features", "extract", @@ -17,6 +18,8 @@ "get_engine", "optimize_instruction", "oversample", + "prune_traps", + "recognize_missingness", "resolve_entities", "scrub", "shred_to_relations", @@ -321,6 +324,96 @@ def discover_features( ) return discoverer.discover(df, target_col) + def prune_traps( + self, + df: IntoFrameT, + target_col: str, + *, + correlation_threshold: float = 0.05, + max_retries: int = 2, + ) -> tuple[IntoFrameT, dict[str, Any]]: + """Identify and remove trap features. + + Trap features are columns of uncorrelated Gaussian noise + that masquerade as valid signals. + + Args: + df: Input DataFrame. + target_col: Target variable column. + correlation_threshold: Absolute correlation below which + a column is considered uncorrelated. + max_retries: LLM retry budget. + + Returns: + Tuple of (pruned DataFrame, summary dict). + """ + from loclean.extraction.trap_pruner import TrapPruner + + pruner = TrapPruner( + inference_engine=self.engine, + correlation_threshold=correlation_threshold, + max_retries=max_retries, + ) + return pruner.prune(df, target_col) + + def recognize_missingness( + self, + df: IntoFrameT, + target_cols: list[str] | None = None, + *, + sample_size: int = 50, + max_retries: int = 3, + ) -> tuple[IntoFrameT, dict[str, Any]]: + """Detect MNAR patterns and encode as boolean features. + + Args: + df: Input DataFrame. + target_cols: Columns to analyse (default: all with nulls). + sample_size: Max null rows to sample per column. + max_retries: LLM retry budget. + + Returns: + Tuple of (augmented DataFrame, summary dict). 
+ """ + from loclean.extraction.missingness_recognizer import MissingnessRecognizer + + recognizer = MissingnessRecognizer( + inference_engine=self.engine, + sample_size=sample_size, + max_retries=max_retries, + ) + return recognizer.recognize(df, target_cols) + + def audit_leakage( + self, + df: IntoFrameT, + target_col: str, + domain: str = "", + *, + max_retries: int = 2, + sample_n: int = 10, + ) -> tuple[IntoFrameT, dict[str, Any]]: + """Detect and remove target-leaking features. + + Args: + df: Input DataFrame. + target_col: Target variable column. + domain: Dataset domain description. + max_retries: LLM retry budget. + sample_n: Sample rows for the prompt. + + Returns: + Tuple of (pruned DataFrame, summary dict). + """ + from loclean.extraction.leakage_auditor import TargetLeakageAuditor + + auditor = TargetLeakageAuditor( + inference_engine=self.engine, + max_retries=max_retries, + sample_n=sample_n, + ) + return auditor.audit(df, target_col, domain) + def validate_quality( self, df: IntoFrameT, @@ -853,3 +946,133 @@ def discover_features( max_retries=max_retries, ) return discoverer.discover(df, target_col) + + +def prune_traps( + df: IntoFrameT, + target_col: str, + *, + correlation_threshold: float = 0.05, + max_retries: int = 2, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> tuple[IntoFrameT, dict[str, Any]]: + """Identify and remove trap features from a DataFrame. + + Trap features are columns of uncorrelated Gaussian noise that + masquerade as valid signals. Detection relies on statistical + distributions and target correlations — column names are ignored. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column name of the prediction target. + correlation_threshold: Absolute correlation threshold. + max_retries: LLM retry budget. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. 
+ **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + Tuple of ``(pruned_df, summary)`` where *summary* contains + ``dropped_columns`` and ``verdicts``. + """ + from loclean.extraction.trap_pruner import TrapPruner + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + pruner = TrapPruner( + inference_engine=inference_engine, + correlation_threshold=correlation_threshold, + max_retries=max_retries, + ) + return pruner.prune(df, target_col) + + +def recognize_missingness( + df: IntoFrameT, + target_cols: list[str] | None = None, + *, + sample_size: int = 50, + max_retries: int = 3, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> tuple[IntoFrameT, dict[str, Any]]: + """Detect MNAR patterns and encode as boolean feature flags. + + Identifies Missing Not At Random patterns where the probability + of a value being missing depends on other feature values. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_cols: Columns to analyse (default: all with nulls). + sample_size: Max null rows to sample per column. + max_retries: LLM retry budget. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. + **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + Tuple of ``(augmented_df, summary)`` where *summary* maps + each analysed column to its pattern description. 
+ """ + from loclean.extraction.missingness_recognizer import MissingnessRecognizer + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + recognizer = MissingnessRecognizer( + inference_engine=inference_engine, + sample_size=sample_size, + max_retries=max_retries, + ) + return recognizer.recognize(df, target_cols) + + +def audit_leakage( + df: IntoFrameT, + target_col: str, + domain: str = "", + *, + max_retries: int = 2, + sample_n: int = 10, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> tuple[IntoFrameT, dict[str, Any]]: + """Detect and remove target-leaking features. + + Identifies features that contain information generated after the + target event, where P(Y | X_i) ≈ 1. Uses semantic timeline + evaluation via the LLM. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column name of the prediction target. + domain: Brief dataset domain description. + max_retries: LLM retry budget. + sample_n: Sample rows for the prompt. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. + **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + Tuple of ``(pruned_df, summary)`` with ``dropped_columns`` + and ``verdicts``. 
+ """ + from loclean.extraction.leakage_auditor import TargetLeakageAuditor + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + auditor = TargetLeakageAuditor( + inference_engine=inference_engine, + max_retries=max_retries, + sample_n=sample_n, + ) + return auditor.audit(df, target_col, domain) diff --git a/src/loclean/extraction/__init__.py b/src/loclean/extraction/__init__.py index 64e9d66..460b0b8 100644 --- a/src/loclean/extraction/__init__.py +++ b/src/loclean/extraction/__init__.py @@ -9,18 +9,24 @@ if TYPE_CHECKING: from .feature_discovery import FeatureDiscovery + from .leakage_auditor import TargetLeakageAuditor + from .missingness_recognizer import MissingnessRecognizer from .optimizer import InstructionOptimizer from .oversampler import SemanticOversampler from .resolver import EntityResolver from .shredder import RelationalShredder + from .trap_pruner import TrapPruner __all__ = [ "EntityResolver", "Extractor", "FeatureDiscovery", "InstructionOptimizer", + "MissingnessRecognizer", "RelationalShredder", "SemanticOversampler", + "TargetLeakageAuditor", + "TrapPruner", "extract_dataframe_compiled", ] @@ -28,8 +34,11 @@ "EntityResolver": ".resolver", "FeatureDiscovery": ".feature_discovery", "InstructionOptimizer": ".optimizer", + "MissingnessRecognizer": ".missingness_recognizer", "RelationalShredder": ".shredder", "SemanticOversampler": ".oversampler", + "TargetLeakageAuditor": ".leakage_auditor", + "TrapPruner": ".trap_pruner", } From 67c5e494eb058d6306e53e2e251ce42b7c56c03d Mon Sep 17 00:00:00 2001 From: nxank4 Date: Fri, 27 Feb 2026 17:44:03 +0000 Subject: [PATCH 3/3] test(extraction): add unit tests for TrapPruner, MissingnessRecognizer, TargetLeakageAuditor - 13 tests each (39 total) covering profiling, prompt construction, verdict parsing, verification, and mock-LLM integration --- tests/unit/extraction/test_leakage_auditor.py | 228 ++++++++++++++++ .../extraction/test_missingness_recognizer.py | 195 ++++++++++++++ 
tests/unit/extraction/test_trap_pruner.py | 247 ++++++++++++++++++ 3 files changed, 670 insertions(+) create mode 100644 tests/unit/extraction/test_leakage_auditor.py create mode 100644 tests/unit/extraction/test_missingness_recognizer.py create mode 100644 tests/unit/extraction/test_trap_pruner.py diff --git a/tests/unit/extraction/test_leakage_auditor.py b/tests/unit/extraction/test_leakage_auditor.py new file mode 100644 index 0000000..5e56578 --- /dev/null +++ b/tests/unit/extraction/test_leakage_auditor.py @@ -0,0 +1,228 @@ +"""Unit tests for the TargetLeakageAuditor module.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import narwhals as nw +import polars as pl +import pytest + +from loclean.extraction.leakage_auditor import TargetLeakageAuditor + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _make_engine(response: str) -> MagicMock: + engine = MagicMock() + engine.generate.return_value = response + return engine + + +def _sample_df() -> pl.DataFrame: + return pl.DataFrame( + { + "age": [25, 30, 45, 50, 35], + "income": [50000, 60000, 80000, 90000, 55000], + "approved_date": [ + "2024-01-15", + "2024-01-20", + "2024-02-01", + "2024-02-10", + "2024-01-25", + ], + "feedback_score": [4, 5, 3, 5, 4], + "approved": [True, True, False, True, True], + } + ) + + +# ------------------------------------------------------------------ +# _extract_state +# ------------------------------------------------------------------ + + +class TestExtractState: + def test_extracts_features_and_samples(self) -> None: + df = _sample_df() + df_nw = nw.from_native(df) + features = ["age", "income", "approved_date", "feedback_score"] + state = TargetLeakageAuditor._extract_state(df_nw, "approved", features) + + assert state["target_col"] == "approved" + assert state["features"] == features + assert len(state["sample_rows"]) <= 
10 + assert "age" in state["dtypes"] + + def test_respects_sample_n(self) -> None: + df = _sample_df() + df_nw = nw.from_native(df) + state = TargetLeakageAuditor._extract_state( + df_nw, "approved", ["age"], sample_n=2 + ) + assert len(state["sample_rows"]) == 2 + + +# ------------------------------------------------------------------ +# _build_prompt +# ------------------------------------------------------------------ + + +class TestBuildPrompt: + def test_includes_domain_and_target(self) -> None: + state = { + "target_col": "approved", + "features": ["age", "income"], + "dtypes": {"age": "Int64", "income": "Int64"}, + "sample_rows": [{"age": 25, "income": 50000, "approved": True}], + } + prompt = TargetLeakageAuditor._build_prompt(state, "loan approval prediction") + assert "loan approval prediction" in prompt + assert "approved" in prompt + assert "age" in prompt + assert "is_leakage" in prompt + + def test_no_domain(self) -> None: + state = { + "target_col": "y", + "features": ["x"], + "dtypes": {"x": "Float64"}, + "sample_rows": [{"x": 1.0, "y": 0}], + } + prompt = TargetLeakageAuditor._build_prompt(state, "") + assert "Dataset domain:" not in prompt + + +# ------------------------------------------------------------------ +# _parse_verdict +# ------------------------------------------------------------------ + + +class TestParseVerdict: + def test_parses_valid_json(self) -> None: + response = json.dumps( + [ + {"column": "approved_date", "is_leakage": True, "reason": "Post-event"}, + {"column": "age", "is_leakage": False, "reason": "Pre-event"}, + ] + ) + verdicts = TargetLeakageAuditor._parse_verdict(response) + assert len(verdicts) == 2 + assert verdicts[0]["column"] == "approved_date" + assert verdicts[0]["is_leakage"] is True + assert verdicts[1]["is_leakage"] is False + + def test_handles_extra_text(self) -> None: + response = ( + 'Analysis:\n[{"column": "x", "is_leakage": false, "reason": "ok"}]\nEnd.' 
+ ) + verdicts = TargetLeakageAuditor._parse_verdict(response) + assert len(verdicts) == 1 + + def test_raises_on_no_json(self) -> None: + with pytest.raises(ValueError, match="No JSON array"): + TargetLeakageAuditor._parse_verdict("no json here") + + +# ------------------------------------------------------------------ +# audit (integration with mock LLM) +# ------------------------------------------------------------------ + + +class TestAudit: + def test_drops_leaked_columns(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "age", "is_leakage": False, "reason": "ok"}, + {"column": "income", "is_leakage": False, "reason": "ok"}, + {"column": "approved_date", "is_leakage": True, "reason": "Post-event"}, + { + "column": "feedback_score", + "is_leakage": True, + "reason": "Post-event", + }, + ] + ) + engine = _make_engine(response) + auditor = TargetLeakageAuditor(inference_engine=engine) + + pruned, summary = auditor.audit(df, "approved", "loan approval") + + assert "approved_date" not in pruned.columns + assert "feedback_score" not in pruned.columns + assert "age" in pruned.columns + assert "income" in pruned.columns + assert "approved" in pruned.columns + assert "approved_date" in summary["dropped_columns"] + assert "feedback_score" in summary["dropped_columns"] + + def test_keeps_all_if_no_leakage(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "age", "is_leakage": False, "reason": "ok"}, + {"column": "income", "is_leakage": False, "reason": "ok"}, + {"column": "approved_date", "is_leakage": False, "reason": "ok"}, + {"column": "feedback_score", "is_leakage": False, "reason": "ok"}, + ] + ) + engine = _make_engine(response) + auditor = TargetLeakageAuditor(inference_engine=engine) + + pruned, summary = auditor.audit(df, "approved") + + assert set(pruned.columns) == set(df.columns) + assert summary["dropped_columns"] == [] + + def test_missing_target_raises(self) -> None: + df = _sample_df() + engine = 
_make_engine("[]") + auditor = TargetLeakageAuditor(inference_engine=engine) + + with pytest.raises(ValueError, match="not found"): + auditor.audit(df, "nonexistent") + + def test_no_feature_columns(self) -> None: + df = pl.DataFrame({"target": [1, 2, 3]}) + engine = _make_engine("[]") + auditor = TargetLeakageAuditor(inference_engine=engine) + + pruned, summary = auditor.audit(df, "target") + + assert pruned.columns == ["target"] + assert summary["dropped_columns"] == [] + engine.generate.assert_not_called() + + def test_summary_contains_verdicts(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "age", "is_leakage": False, "reason": "ok"}, + ] + ) + engine = _make_engine(response) + auditor = TargetLeakageAuditor(inference_engine=engine) + + _, summary = auditor.audit(df, "approved") + + assert "verdicts" in summary + assert isinstance(summary["verdicts"], list) + + def test_domain_passed_to_prompt(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "age", "is_leakage": False, "reason": "ok"}, + ] + ) + engine = _make_engine(response) + auditor = TargetLeakageAuditor(inference_engine=engine) + + auditor.audit(df, "approved", domain="healthcare readmission") + + call_args = engine.generate.call_args[0][0] + assert "healthcare readmission" in call_args diff --git a/tests/unit/extraction/test_missingness_recognizer.py b/tests/unit/extraction/test_missingness_recognizer.py new file mode 100644 index 0000000..13e44c1 --- /dev/null +++ b/tests/unit/extraction/test_missingness_recognizer.py @@ -0,0 +1,195 @@ +"""Unit tests for the MissingnessRecognizer module.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import narwhals as nw +import polars as pl + +from loclean.extraction.missingness_recognizer import MissingnessRecognizer + +# ------------------------------------------------------------------ +# Helpers +# 
------------------------------------------------------------------ + + +def _make_engine(response: str) -> MagicMock: + engine = MagicMock() + engine.generate.return_value = response + return engine + + +_ENCODER_SRC = ( + "def encode_missingness(row: dict) -> bool:\n" + " try:\n" + " return row.get('category') == 'electronics'\n" + " except Exception:\n" + " return False\n" +) + + +def _df_with_nulls() -> pl.DataFrame: + return pl.DataFrame( + { + "price": [100.0, None, 300.0, None, 500.0, None], + "category": [ + "clothing", + "electronics", + "clothing", + "electronics", + "clothing", + "electronics", + ], + "quantity": [10, 5, 20, 3, 15, 1], + } + ) + + +# ------------------------------------------------------------------ +# _find_null_columns +# ------------------------------------------------------------------ + + +class TestFindNullColumns: + def test_detects_columns_with_nulls(self) -> None: + df = _df_with_nulls() + df_nw = nw.from_native(df) + result = MissingnessRecognizer._find_null_columns(df_nw) + assert result == ["price"] + + def test_no_nulls(self) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + df_nw = nw.from_native(df) + result = MissingnessRecognizer._find_null_columns(df_nw) + assert result == [] + + +# ------------------------------------------------------------------ +# _sample_null_context +# ------------------------------------------------------------------ + + +class TestSampleNullContext: + def test_samples_rows_where_target_is_null(self) -> None: + df = _df_with_nulls() + df_nw = nw.from_native(df) + sample = MissingnessRecognizer._sample_null_context( + df_nw, "price", ["category", "quantity"] + ) + assert len(sample) == 3 + for row in sample: + assert "category" in row + assert "quantity" in row + assert "price" not in row + + def test_respects_max_rows(self) -> None: + df = _df_with_nulls() + df_nw = nw.from_native(df) + sample = MissingnessRecognizer._sample_null_context( + df_nw, "price", ["category"], max_rows=2 + ) + 
assert len(sample) == 2 + + def test_empty_when_no_nulls(self) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + df_nw = nw.from_native(df) + sample = MissingnessRecognizer._sample_null_context(df_nw, "a", ["b"]) + assert sample == [] + + +# ------------------------------------------------------------------ +# _build_prompt +# ------------------------------------------------------------------ + + +class TestBuildPrompt: + def test_includes_column_name_and_sample(self) -> None: + prompt = MissingnessRecognizer._build_prompt( + "price", + ["category", "quantity"], + [{"category": "electronics", "quantity": 5}], + ) + assert "price" in prompt + assert "encode_missingness" in prompt + assert "electronics" in prompt + + def test_includes_rules(self) -> None: + prompt = MissingnessRecognizer._build_prompt("x", ["y"], [{"y": 1}]) + assert "try/except" in prompt + assert "boolean" in prompt.lower() or "bool" in prompt.lower() + + +# ------------------------------------------------------------------ +# _verify_encoder +# ------------------------------------------------------------------ + + +class TestVerifyEncoder: + def test_valid_encoder_passes(self) -> None: + def good_fn(row: dict[str, Any]) -> bool: + return True + + ok, err = MissingnessRecognizer._verify_encoder(good_fn, [{"a": 1}, {"a": 2}]) + assert ok is True + assert err == "" + + def test_non_bool_return_fails(self) -> None: + def bad_fn(row: dict[str, Any]) -> Any: + return "not a bool" + + ok, err = MissingnessRecognizer._verify_encoder(bad_fn, [{"a": 1}]) + assert ok is False + assert "bool" in err.lower() + + +# ------------------------------------------------------------------ +# recognize (integration with mock LLM) +# ------------------------------------------------------------------ + + +class TestRecognize: + def test_adds_mnar_column(self) -> None: + engine = _make_engine(_ENCODER_SRC) + recognizer = MissingnessRecognizer(inference_engine=engine, max_retries=1) + df = _df_with_nulls() + 
result, summary = recognizer.recognize(df) + + assert "price_mnar" in result.columns + assert "price" in summary["patterns"] + assert summary["patterns"]["price"]["encoded_as"] == "price_mnar" + + def test_no_nulls_skips(self) -> None: + engine = _make_engine("") + recognizer = MissingnessRecognizer(inference_engine=engine, max_retries=1) + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + result, summary = recognizer.recognize(df) + + assert set(result.columns) == {"a", "b"} + assert summary["patterns"] == {} + engine.generate.assert_not_called() + + def test_target_cols_filter(self) -> None: + df = pl.DataFrame( + { + "a": [1, None, 3], + "b": [None, 2, None], + "c": [10, 20, 30], + } + ) + engine = _make_engine(_ENCODER_SRC) + recognizer = MissingnessRecognizer(inference_engine=engine, max_retries=1) + _, summary = recognizer.recognize(df, target_cols=["a"]) + + assert "a" in summary["patterns"] + assert "b" not in summary["patterns"] + + def test_compile_failure_returns_none_pattern(self) -> None: + engine = _make_engine("this is not valid python at all!!!") + recognizer = MissingnessRecognizer(inference_engine=engine, max_retries=1) + df = _df_with_nulls() + result, summary = recognizer.recognize(df) + + assert "price_mnar" not in result.columns + assert summary["patterns"]["price"] is None diff --git a/tests/unit/extraction/test_trap_pruner.py b/tests/unit/extraction/test_trap_pruner.py new file mode 100644 index 0000000..03ddcd0 --- /dev/null +++ b/tests/unit/extraction/test_trap_pruner.py @@ -0,0 +1,247 @@ +"""Unit tests for the TrapPruner module.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import narwhals as nw +import polars as pl +import pytest + +from loclean.extraction.trap_pruner import TrapPruner, _ColumnProfile + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _make_engine(response: str) -> 
MagicMock: + engine = MagicMock() + engine.generate.return_value = response + return engine + + +def _sample_df() -> pl.DataFrame: + """DataFrame with one real feature and one Gaussian noise column.""" + import random + + random.seed(42) + n = 100 + prices = [200_000 + i * 5000 for i in range(n)] + sqft = [150 + i * 10 + random.randint(-5, 5) for i in range(n)] + noise = [random.gauss(0, 1) for _ in range(n)] + + return pl.DataFrame( + { + "sqft": sqft, + "noise_feat": noise, + "price": prices, + } + ) + + +# ------------------------------------------------------------------ +# _profile_columns +# ------------------------------------------------------------------ + + +class TestProfileColumns: + def test_basic_stats(self) -> None: + df = _sample_df() + df_nw = nw.from_native(df) + profiles = TrapPruner._profile_columns(df_nw, "price", ["sqft", "noise_feat"]) + assert len(profiles) == 2 + + sqft_p = next(p for p in profiles if p.name == "sqft") + noise_p = next(p for p in profiles if p.name == "noise_feat") + + assert abs(sqft_p.corr_with_target) > 0.5 + assert abs(noise_p.corr_with_target) < 0.2 + + def test_zero_variance_column(self) -> None: + df = pl.DataFrame( + { + "constant": [5] * 10, + "target": list(range(10)), + } + ) + df_nw = nw.from_native(df) + profiles = TrapPruner._profile_columns(df_nw, "target", ["constant"]) + assert len(profiles) == 1 + assert profiles[0].variance == 0.0 + assert profiles[0].corr_with_target == 0.0 + + def test_single_row_returns_empty(self) -> None: + df = pl.DataFrame({"a": [1], "target": [2]}) + df_nw = nw.from_native(df) + profiles = TrapPruner._profile_columns(df_nw, "target", ["a"]) + assert profiles == [] + + +# ------------------------------------------------------------------ +# _build_prompt +# ------------------------------------------------------------------ + + +class TestBuildPrompt: + def test_anonymises_column_names(self) -> None: + profiles = [ + _ColumnProfile( + name="secret_column", + mean=0.0, + std=1.0, + 
variance=1.0, + skewness=0.0, + kurtosis=0.0, + min_val=-3.0, + max_val=3.0, + corr_with_target=0.01, + ), + ] + col_map, prompt = TrapPruner._build_prompt(profiles) + + assert "secret_column" not in prompt + assert "col_0" in prompt + assert col_map["col_0"] == "secret_column" + + def test_multiple_columns_indexed(self) -> None: + profiles = [ + _ColumnProfile( + name=f"feat_{i}", + mean=float(i), + std=1.0, + variance=1.0, + skewness=0.0, + kurtosis=0.0, + min_val=0.0, + max_val=10.0, + corr_with_target=0.5, + ) + for i in range(3) + ] + col_map, prompt = TrapPruner._build_prompt(profiles) + assert len(col_map) == 3 + assert "col_0" in prompt + assert "col_1" in prompt + assert "col_2" in prompt + + +# ------------------------------------------------------------------ +# _parse_verdict +# ------------------------------------------------------------------ + + +class TestParseVerdict: + def test_maps_anonymous_to_real(self) -> None: + col_map = {"col_0": "noise_feat", "col_1": "real_feat"} + response = json.dumps( + [ + {"column": "col_0", "is_trap": True, "reason": "Gaussian noise"}, + {"column": "col_1", "is_trap": False, "reason": "Correlated"}, + ] + ) + + verdicts = TrapPruner._parse_verdict(response, col_map) + assert len(verdicts) == 2 + assert verdicts[0]["column"] == "noise_feat" + assert verdicts[0]["is_trap"] is True + assert verdicts[1]["column"] == "real_feat" + assert verdicts[1]["is_trap"] is False + + def test_handles_extra_text_around_json(self) -> None: + col_map = {"col_0": "feat_a"} + response = ( + "Here is the analysis:\n" + '[{"column": "col_0", "is_trap": false, "reason": "ok"}]\nDone.' 
+ ) + + verdicts = TrapPruner._parse_verdict(response, col_map) + assert len(verdicts) == 1 + assert verdicts[0]["column"] == "feat_a" + + def test_raises_on_no_json(self) -> None: + with pytest.raises(ValueError, match="No JSON array"): + TrapPruner._parse_verdict("no json here", {}) + + +# ------------------------------------------------------------------ +# prune (integration with mock LLM) +# ------------------------------------------------------------------ + + +class TestPrune: + def test_removes_trap_columns(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "col_0", "is_trap": False, "reason": "Correlated"}, + {"column": "col_1", "is_trap": True, "reason": "Gaussian noise"}, + ] + ) + engine = _make_engine(response) + pruner = TrapPruner(inference_engine=engine) + + pruned, summary = pruner.prune(df, "price") + + assert "noise_feat" not in pruned.columns + assert "sqft" in pruned.columns + assert "price" in pruned.columns + assert "noise_feat" in summary["dropped_columns"] + + def test_keeps_all_if_no_traps(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "col_0", "is_trap": False, "reason": "ok"}, + {"column": "col_1", "is_trap": False, "reason": "ok"}, + ] + ) + engine = _make_engine(response) + pruner = TrapPruner(inference_engine=engine) + + pruned, summary = pruner.prune(df, "price") + + assert set(pruned.columns) == set(df.columns) + assert summary["dropped_columns"] == [] + + def test_returns_summary_with_verdicts(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "col_0", "is_trap": False, "reason": "real"}, + {"column": "col_1", "is_trap": True, "reason": "noise"}, + ] + ) + engine = _make_engine(response) + pruner = TrapPruner(inference_engine=engine) + + _, summary = pruner.prune(df, "price") + + assert "dropped_columns" in summary + assert "verdicts" in summary + assert len(summary["verdicts"]) == 2 + + def test_missing_target_raises(self) -> None: + df = 
pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + engine = _make_engine("[]") + pruner = TrapPruner(inference_engine=engine) + + with pytest.raises(ValueError, match="not found"): + pruner.prune(df, "nonexistent") + + def test_no_numeric_columns(self) -> None: + df = pl.DataFrame( + { + "name": ["alice", "bob"], + "target": [1, 2], + } + ) + engine = _make_engine("[]") + pruner = TrapPruner(inference_engine=engine) + + pruned, summary = pruner.prune(df, "target") + + assert set(pruned.columns) == set(df.columns) + assert summary["dropped_columns"] == [] + engine.generate.assert_not_called()