From 0a614fd3c567331c3930eb588ab7d9bbcacf2f0a Mon Sep 17 00:00:00 2001 From: nxank4 Date: Fri, 27 Feb 2026 17:43:43 +0000 Subject: [PATCH 1/8] feat(extraction): add TrapPruner, MissingnessRecognizer, TargetLeakageAuditor - TrapPruner: statistical profiling + LLM verification of Gaussian noise columns - MissingnessRecognizer: MNAR pattern detection with sandbox-compiled encoders - TargetLeakageAuditor: semantic timeline evaluation for target leakage --- src/loclean/extraction/leakage_auditor.py | 240 ++++++++++++ .../extraction/missingness_recognizer.py | 307 ++++++++++++++++ src/loclean/extraction/trap_pruner.py | 343 ++++++++++++++++++ 3 files changed, 890 insertions(+) create mode 100644 src/loclean/extraction/leakage_auditor.py create mode 100644 src/loclean/extraction/missingness_recognizer.py create mode 100644 src/loclean/extraction/trap_pruner.py diff --git a/src/loclean/extraction/leakage_auditor.py b/src/loclean/extraction/leakage_auditor.py new file mode 100644 index 0000000..67bdb3d --- /dev/null +++ b/src/loclean/extraction/leakage_auditor.py @@ -0,0 +1,240 @@ +"""Semantic target leakage detection via LLM-driven timeline evaluation. + +Identifies features that mathematically or logically imply the target +variable — i.e. columns containing information generated *after* the +target event occurs. In a deterministic leakage scenario: + + P(Y | X_i) ≈ 1 + +The generative engine acts as a semantic auditor, catching logical +leakage that basic statistical tests miss by evaluating the causal +timeline of each feature relative to the target outcome. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any + +import narwhals as nw + +from loclean.utils.logging import configure_module_logger + +if TYPE_CHECKING: + from narwhals.typing import IntoFrameT + + from loclean.inference.base import InferenceEngine + +logger = configure_module_logger(__name__, level=logging.INFO) + + +class TargetLeakageAuditor: + """Detect and remove features that leak the target variable. + + For each feature column the auditor prompts the LLM with the + dataset domain description and a representative sample, asking + it to evaluate whether the feature could only be known *after* + the target outcome is determined. + + Args: + inference_engine: Ollama (or compatible) engine. + max_retries: LLM generation retry budget. + sample_n: Number of sample rows to include in the prompt. + """ + + def __init__( + self, + inference_engine: InferenceEngine, + *, + max_retries: int = 2, + sample_n: int = 10, + ) -> None: + self.inference_engine = inference_engine + self.max_retries = max_retries + self.sample_n = sample_n + + def audit( + self, + df: IntoFrameT, + target_col: str, + domain: str = "", + ) -> tuple[IntoFrameT, dict[str, Any]]: + """Audit features for target leakage and drop offenders. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column name of the prediction target. + domain: Brief text description of the dataset domain + (e.g. ``"hospital readmission prediction"``). + + Returns: + Tuple of ``(pruned_df, summary)`` where *summary* + contains ``dropped_columns`` and per-column ``verdicts``. + """ + df_nw = nw.from_native(df) # type: ignore[type-var] + + if target_col not in df_nw.columns: + raise ValueError(f"Target column '{target_col}' not found") + + feature_cols = [c for c in df_nw.columns if c != target_col] + if not feature_cols: + logger.info("No feature columns to audit.") + return df, {"dropped_columns": [], "verdicts": []} + + state = self._extract_state(df_nw, target_col, feature_cols) + prompt = self._build_prompt(state, domain) + verdicts = self._evaluate_with_llm(prompt) + + leaked = [v["column"] for v in verdicts if v.get("is_leakage")] + valid_leaked = [c for c in leaked if c in feature_cols] + + summary: dict[str, Any] = { + "dropped_columns": valid_leaked, + "verdicts": verdicts, + } + + if valid_leaked: + logger.info( + "Dropping %d leaked feature(s): %s", + len(valid_leaked), + valid_leaked, + ) + try: + pruned_nw = df_nw.drop(valid_leaked) + return nw.to_native(pruned_nw), summary # type: ignore[type-var] + except Exception as exc: + logger.warning("Failed to drop columns: %s", exc) + return df, summary + + logger.info("No target leakage detected.") + return df, summary + + # ------------------------------------------------------------------ + # State extraction + # ------------------------------------------------------------------ + + @staticmethod + def _extract_state( + df_nw: nw.DataFrame[Any], + target_col: str, + feature_cols: list[str], + sample_n: int = 10, + ) -> dict[str, Any]: + """Build structural metadata for the LLM prompt. + + Args: + df_nw: Narwhals DataFrame. + target_col: Target variable column. + feature_cols: Feature column names. + sample_n: Number of sample rows. + + Returns: + Dict with ``target_col``, ``features``, ``dtypes``, + and ``sample_rows``. + """ + n = min(df_nw.shape[0], sample_n) + sampled = df_nw.head(n) + sample_rows = sampled.rows(named=True) + + dtypes = {col: str(df_nw[col].dtype) for col in feature_cols} + + return { + "target_col": target_col, + "features": feature_cols, + "dtypes": dtypes, + "sample_rows": sample_rows, # type: ignore[dict-item] + } + + # ------------------------------------------------------------------ + # Prompt construction + # ------------------------------------------------------------------ + + @staticmethod + def _build_prompt( + state: dict[str, Any], + domain: str, + ) -> str: + """Build the LLM prompt for timeline evaluation.""" + target = state["target_col"] + features = state["features"] + dtypes = state["dtypes"] + sample_str = json.dumps(state["sample_rows"][:10], indent=2, default=str) + + domain_line = f"Dataset domain: {domain}\n" if domain else "" + + feature_info = "\n".join( + f" - {f} (dtype: {dtypes.get(f, 'unknown')})" for f in features + ) + + return ( + "You are a machine learning auditor specialising in data " + "leakage detection.\n\n" + f"{domain_line}" + f"Target variable: '{target}'\n\n" + f"Feature columns:\n{feature_info}\n\n" + f"Sample rows:\n{sample_str}\n\n" + "Task: For each feature column, evaluate whether it could " + "constitute **target leakage** — meaning the feature contains " + "information that would only be available AFTER the target " + "outcome is determined.\n\n" + "Consider:\n" + "- Temporal ordering: was this feature generated before or " + "after the target event?\n" + "- Semantic meaning: does the feature directly encode or " + "trivially derive from the target?\n" + "- Statistical signal: extremely high correlation may " + "indicate leakage, not just a good predictor.\n\n" + "Output ONLY a JSON array. For each feature, output an " + "object with exactly three keys:\n" + '- "column": the feature name\n' + '- "is_leakage": boolean\n' + '- "reason": brief explanation\n\n' + "Output ONLY the JSON array, no other text." + ) + + # ------------------------------------------------------------------ + # LLM evaluation + # ------------------------------------------------------------------ + + def _evaluate_with_llm( + self, + prompt: str, + ) -> list[dict[str, Any]]: + """Send the prompt and parse the leakage verdicts.""" + for attempt in range(1, self.max_retries + 1): + try: + raw = str(self.inference_engine.generate(prompt)).strip() + return self._parse_verdict(raw) + except (json.JSONDecodeError, ValueError, KeyError) as exc: + logger.warning( + "LLM verdict parsing failed (attempt %d/%d): %s", + attempt, + self.max_retries, + exc, + ) + logger.warning("Could not parse LLM verdicts — keeping all columns.") + return [] + + @staticmethod + def _parse_verdict(response: str) -> list[dict[str, Any]]: + """Parse the JSON verdict from the LLM response.""" + text = response.strip() + start = text.find("[") + end = text.rfind("]") + if start == -1 or end == -1: + raise ValueError("No JSON array found in LLM response") + + items: list[dict[str, Any]] = json.loads(text[start : end + 1]) + verdicts: list[dict[str, Any]] = [] + + for item in items: + verdicts.append( + { + "column": str(item["column"]), + "is_leakage": bool(item.get("is_leakage", False)), + "reason": str(item.get("reason", "")), + } + ) + + return verdicts diff --git a/src/loclean/extraction/missingness_recognizer.py b/src/loclean/extraction/missingness_recognizer.py new file mode 100644 index 0000000..e6ad695 --- /dev/null +++ b/src/loclean/extraction/missingness_recognizer.py @@ -0,0 +1,307 @@ +"""Missingness pattern recognition via LLM-driven MNAR detection. + +Identifies Missing Not At Random (MNAR) patterns where the probability +of a value being missing in feature X depends on the value of feature Y: + + P(X_missing | Y) ≠ P(X_missing) + +Uses Narwhals for backend-agnostic null analysis and an InferenceEngine +to infer structural correlations from data samples. Detected patterns +are encoded as new boolean feature columns. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any, Callable + +import narwhals as nw + +from loclean.utils.logging import configure_module_logger + +if TYPE_CHECKING: + from narwhals.typing import IntoFrameT + + from loclean.inference.base import InferenceEngine + +logger = configure_module_logger(__name__, level=logging.INFO) + +_SUFFIX = "_mnar" + + +class MissingnessRecognizer: + """Detect MNAR patterns and encode them as boolean feature flags. + + For each column containing nulls the recognizer: + + 1. Samples rows where the column is null alongside other features. + 2. Prompts the LLM to identify structural correlations. + 3. Compiles the LLM-generated ``encode_missingness`` function in a + sandbox. + 4. Applies the function across the DataFrame to create a boolean + ``{col}_mnar`` column. + + Args: + inference_engine: Ollama (or compatible) engine. + sample_size: Maximum null rows to sample per column. + max_retries: LLM code-generation retry budget. + timeout_s: Per-row execution timeout in seconds. + """ + + def __init__( + self, + inference_engine: InferenceEngine, + *, + sample_size: int = 50, + max_retries: int = 3, + timeout_s: float = 2.0, + ) -> None: + self.inference_engine = inference_engine + self.sample_size = sample_size + self.max_retries = max_retries + self.timeout_s = timeout_s + + def recognize( + self, + df: IntoFrameT, + target_cols: list[str] | None = None, + ) -> tuple[IntoFrameT, dict[str, Any]]: + """Detect MNAR patterns and add boolean feature columns. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_cols: Columns to analyse for missingness. If + ``None``, all columns containing nulls are evaluated. + + Returns: + Tuple of ``(augmented_df, summary)`` where *summary* + maps each analysed column to its pattern description + or ``None`` if no pattern was found. + """ + df_nw = nw.from_native(df) # type: ignore[type-var] + + null_cols = self._find_null_columns(df_nw) + + if target_cols is not None: + null_cols = [c for c in target_cols if c in null_cols] + + if not null_cols: + logger.info("No columns with null values to analyse.") + return df, {"patterns": {}} + + all_cols = df_nw.columns + patterns: dict[str, Any] = {} + new_columns: dict[str, list[bool]] = {} + + for col in null_cols: + context_cols = [c for c in all_cols if c != col] + sample = self._sample_null_context(df_nw, col, context_cols) + + if not sample: + patterns[col] = None + continue + + prompt = self._build_prompt(col, context_cols, sample) + fn = self._generate_and_compile(prompt) + + if fn is None: + patterns[col] = None + continue + + ok, error = self._verify_encoder(fn, sample) + if not ok: + logger.warning("Encoder for '%s' failed verification: %s", col, error) + patterns[col] = None + continue + + flags = self._apply_encoder(df_nw, fn) + col_name = f"{col}{_SUFFIX}" + new_columns[col_name] = flags + patterns[col] = { + "encoded_as": col_name, + "null_count": sum(1 for v in df_nw[col].to_list() if v is None), + "pattern_flags_true": sum(flags), + } + logger.info( + "Encoded MNAR pattern for '%s' → '%s' (%d flagged)", + col, + col_name, + sum(flags), + ) + + if new_columns: + native_ns = nw.get_native_namespace(df_nw) + rows_data: dict[str, list[Any]] = { + c: df_nw[c].to_list() for c in df_nw.columns + } + rows_data.update(new_columns) + result_nw = nw.from_dict(rows_data, backend=native_ns) + return nw.to_native(result_nw), {"patterns": patterns} + + return df, {"patterns": patterns} + + # ------------------------------------------------------------------ + # Null detection + # ------------------------------------------------------------------ + + @staticmethod + def _find_null_columns(df_nw: nw.DataFrame[Any]) -> list[str]: + """Return column names that contain at least one null.""" + return [col for col in df_nw.columns if df_nw[col].null_count() > 0] + + # ------------------------------------------------------------------ + # Sampling + # ------------------------------------------------------------------ + + @staticmethod + def _sample_null_context( + df_nw: nw.DataFrame[Any], + null_col: str, + context_cols: list[str], + max_rows: int = 50, + ) -> list[dict[str, Any]]: + """Extract rows where *null_col* is null with context values. + + Returns a list of dicts, each containing the context column + values for a row where *null_col* is missing. + """ + null_mask = df_nw[null_col].is_null() + null_rows = df_nw.filter(null_mask) + + if null_rows.shape[0] == 0: + return [] + + n = min(null_rows.shape[0], max_rows) + sampled = null_rows.head(n) + + select_cols = [c for c in context_cols if c in sampled.columns] + if not select_cols: + return [] + + return sampled.select(select_cols).rows(named=True) # type: ignore[return-value] + + # ------------------------------------------------------------------ + # Prompt construction + # ------------------------------------------------------------------ + + @staticmethod + def _build_prompt( + null_col: str, + context_cols: list[str], + sample_rows: list[dict[str, Any]], + ) -> str: + """Build the LLM prompt for pattern inference.""" + sample_str = json.dumps(sample_rows[:20], indent=2, default=str) + + return ( + "You are a data scientist analysing missing data patterns.\n\n" + f"Column '{null_col}' has missing values. Below are sample rows " + "where this column IS NULL, showing the values of the other " + "columns:\n\n" + f"Context columns: {context_cols}\n\n" + f"Sample rows (where '{null_col}' is null):\n{sample_str}\n\n" + "Task: Identify if there is a structural pattern that predicts " + f"when '{null_col}' is missing based on other column values.\n\n" + "Write a pure Python function with this exact signature:\n\n" + "def encode_missingness(row: dict) -> bool:\n" + " ...\n\n" + "The function receives a dict of ALL column values for a row " + "(including the target column) and returns True if the " + "missingness pattern is detected.\n\n" + "Rules:\n" + "- Use ONLY standard library modules (math, statistics, operator)\n" + "- Wrap logic in try/except returning False on failure\n" + "- Return a single boolean value\n" + "- Do NOT use markdown fences, comments, or prose\n" + "- Output ONLY the function code, nothing else\n\n" + "Example:\n" + "def encode_missingness(row: dict) -> bool:\n" + " try:\n" + " return row.get('category') == 'electronics' " + "and row.get('price', 0) > 500\n" + " except Exception:\n" + " return False\n" + ) + + # ------------------------------------------------------------------ + # Code generation + compilation + # ------------------------------------------------------------------ + + def _generate_and_compile( + self, + prompt: str, + ) -> Callable[[dict[str, Any]], bool] | None: + """Generate, sanitize, and compile the encoder function.""" + import re + + from loclean.utils.sandbox import compile_sandboxed + + for attempt in range(1, self.max_retries + 1): + try: + raw = str(self.inference_engine.generate(prompt)).strip() + source = re.sub(r"```(?:python)?\s*\n?", "", raw).strip() + fn = compile_sandboxed( + source, + "encode_missingness", + ["math", "statistics", "operator"], + ) + return fn # type: ignore[return-value] + except (ValueError, SyntaxError) as exc: + logger.warning( + "⚠ Code generation failed (attempt %d/%d): %s", + attempt, + self.max_retries, + exc, + ) + + logger.warning( + "Could not compile encoder after %d retries.", + self.max_retries, + ) + return None + + # ------------------------------------------------------------------ + # Verification + # ------------------------------------------------------------------ + + @staticmethod + def _verify_encoder( + fn: Callable[[dict[str, Any]], bool], + sample_rows: list[dict[str, Any]], + ) -> tuple[bool, str]: + """Test the encoder on sample rows.""" + from loclean.utils.sandbox import run_with_timeout + + for row in sample_rows[:5]: + result, error = run_with_timeout(fn, (row,), 2.0) + if error: + return False, f"Execution error: {error}" + if not isinstance(result, bool): + return False, f"Expected bool, got {type(result).__name__}" + + return True, "" + + # ------------------------------------------------------------------ + # Application + # ------------------------------------------------------------------ + + def _apply_encoder( + self, + df_nw: nw.DataFrame[Any], + fn: Callable[[dict[str, Any]], bool], + ) -> list[bool]: + """Apply the encoder across all rows.""" + from loclean.utils.sandbox import run_with_timeout + + rows: list[dict[str, Any]] = df_nw.rows(named=True) # type: ignore[assignment] + flags: list[bool] = [] + + for row in rows: + result, error = run_with_timeout(fn, (row,), self.timeout_s) + if error or not isinstance(result, bool): + flags.append(False) + else: + flags.append(result) + + return flags diff --git a/src/loclean/extraction/trap_pruner.py b/src/loclean/extraction/trap_pruner.py new file mode 100644 index 0000000..376b933 --- /dev/null +++ b/src/loclean/extraction/trap_pruner.py @@ -0,0 +1,343 @@ +"""Automated trap feature pruning via statistical profiling and LLM verification. + +Identifies columns that look like valid signals but are actually +uncorrelated Gaussian noise (trap features). Uses Narwhals for +backend-agnostic statistical profiling and an ``InferenceEngine`` +for generative verification. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any + +import narwhals as nw + +from loclean.utils.logging import configure_module_logger + +if TYPE_CHECKING: + from narwhals.typing import IntoFrameT + + from loclean.inference.base import InferenceEngine + +logger = configure_module_logger(__name__, level=logging.INFO) + + +class _ColumnProfile: + """Statistical profile for a single numeric column.""" + + __slots__ = ( + "name", + "mean", + "std", + "variance", + "skewness", + "kurtosis", + "min_val", + "max_val", + "corr_with_target", + ) + + def __init__( + self, + name: str, + mean: float, + std: float, + variance: float, + skewness: float, + kurtosis: float, + min_val: float, + max_val: float, + corr_with_target: float, + ) -> None: + self.name = name + self.mean = mean + self.std = std + self.variance = variance + self.skewness = skewness + self.kurtosis = kurtosis + self.min_val = min_val + self.max_val = max_val + self.corr_with_target = corr_with_target + + def to_dict(self) -> dict[str, Any]: + """Serialise profile to a plain dictionary.""" + return { + "name": self.name, + "mean": self.mean, + "std": self.std, + "variance": self.variance, + "skewness": self.skewness, + "kurtosis": self.kurtosis, + "min": self.min_val, + "max": self.max_val, + "corr_with_target": self.corr_with_target, + } + + +class TrapPruner: + """Identify and remove trap features from a DataFrame. + + Trap features are columns of uncorrelated Gaussian noise that + masquerade as valid signals. Detection relies entirely on + statistical distributions and target correlations — column names + are deliberately ignored. + + Args: + inference_engine: Ollama (or compatible) engine for verification. + correlation_threshold: Absolute correlation below which a + column is considered uncorrelated. Default ``0.05``. + max_retries: LLM generation retry budget. + """ + + def __init__( + self, + inference_engine: InferenceEngine, + *, + correlation_threshold: float = 0.05, + max_retries: int = 2, + ) -> None: + self.inference_engine = inference_engine + self.correlation_threshold = correlation_threshold + self.max_retries = max_retries + + def prune( + self, + df: IntoFrameT, + target_col: str, + ) -> tuple[IntoFrameT, dict[str, Any]]: + """Profile, verify, and drop trap features. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column name of the prediction target. + + Returns: + Tuple of ``(pruned_df, summary)`` where *summary* contains + ``dropped_columns`` (list of removed names) and + ``verdicts`` (per-column LLM reasoning). + """ + df_nw = nw.from_native(df) # type: ignore[type-var] + + if target_col not in df_nw.columns: + raise ValueError(f"Target column '{target_col}' not found") + + numeric_cols = [ + c for c in df_nw.columns if c != target_col and df_nw[c].dtype.is_numeric() + ] + + if not numeric_cols: + logger.info("No numeric feature columns to evaluate.") + return df, {"dropped_columns": [], "verdicts": []} + + profiles = self._profile_columns(df_nw, target_col, numeric_cols) + + col_map, prompt = self._build_prompt(profiles) + + verdicts = self._verify_with_llm(prompt, col_map) + + trap_cols = [v["column"] for v in verdicts if v.get("is_trap")] + + summary: dict[str, Any] = { + "dropped_columns": trap_cols, + "verdicts": verdicts, + } + + if trap_cols: + logger.info( + "Dropping %d trap feature(s): %s", + len(trap_cols), + trap_cols, + ) + pruned_nw = df_nw.drop(trap_cols) + return nw.to_native(pruned_nw), summary # type: ignore[type-var] + + logger.info("No trap features detected.") + return df, summary + + # ------------------------------------------------------------------ + # Statistical profiling + # ------------------------------------------------------------------ + + @staticmethod + def _profile_columns( + df_nw: nw.DataFrame[Any], + target_col: str, + numeric_cols: list[str], + ) -> list[_ColumnProfile]: + """Compute distribution statistics for each numeric column. + + All operations use the Narwhals interface. Division-by-zero + and other math errors are caught per-column. + """ + n = df_nw.shape[0] + if n < 2: + return [] + + target_series = df_nw[target_col].cast(nw.Float64) + target_mean = target_series.mean() + target_std = target_series.std() + + profiles: list[_ColumnProfile] = [] + + for col in numeric_cols: + try: + series = df_nw[col].cast(nw.Float64) + col_mean = series.mean() + col_std = series.std() + + diffs = series - col_mean + variance = (diffs * diffs).mean() + + if col_std and col_std > 0 and target_std and target_std > 0: + corr = float( + ((series - col_mean) * (target_series - target_mean)).mean() + / (col_std * target_std) + ) + else: + corr = 0.0 + + if col_std and col_std > 0: + skewness = float((diffs**3).mean() / (col_std**3)) + kurtosis = float((diffs**4).mean() / (col_std**4)) - 3.0 + else: + skewness = 0.0 + kurtosis = 0.0 + + profiles.append( + _ColumnProfile( + name=col, + mean=float(col_mean) if col_mean is not None else 0.0, + std=float(col_std) if col_std is not None else 0.0, + variance=float(variance) if variance is not None else 0.0, + skewness=skewness, + kurtosis=kurtosis, + min_val=float(series.min()), + max_val=float(series.max()), + corr_with_target=corr, + ) + ) + except (ZeroDivisionError, ValueError, OverflowError): + profiles.append( + _ColumnProfile( + name=col, + mean=0.0, + std=0.0, + variance=0.0, + skewness=0.0, + kurtosis=0.0, + min_val=0.0, + max_val=0.0, + corr_with_target=0.0, + ) + ) + + return profiles + + # ------------------------------------------------------------------ + # Prompt construction (anonymised) + # ------------------------------------------------------------------ + + @staticmethod + def _build_prompt( + profiles: list[_ColumnProfile], + ) -> tuple[dict[str, str], str]: + """Build the LLM verification prompt with anonymised column IDs. + + Returns: + Tuple of ``(col_map, prompt_text)`` where *col_map* maps + ``"col_0"`` → real column name. + """ + col_map: dict[str, str] = {} + lines: list[str] = [] + + for i, p in enumerate(profiles): + anon = f"col_{i}" + col_map[anon] = p.name + + lines.append( + f"Column {anon}: " + f"mean={p.mean:.4f}, std={p.std:.4f}, " + f"variance={p.variance:.4f}, " + f"skewness={p.skewness:.4f}, kurtosis={p.kurtosis:.4f}, " + f"min={p.min_val:.4f}, max={p.max_val:.4f}, " + f"corr_with_target={p.corr_with_target:.4f}" + ) + + profile_block = "\n".join(lines) + + prompt = ( + "You are a statistical analyst. Below are the statistical profiles " + "of several numeric columns from a dataset. Each column is " + "identified only by an anonymous ID (column names are hidden).\n\n" + f"{profile_block}\n\n" + "A **trap feature** is a column that:\n" + "1. Exhibits a distribution close to standard Gaussian " + "(skewness ≈ 0, kurtosis ≈ 0, i.e. excess kurtosis near zero).\n" + "2. Has a correlation with the target variable very close to " + "zero (|corr| < 0.05).\n\n" + "Analyse each column and output ONLY a JSON array. " + "For each column output an object with exactly three keys:\n" + '- "column": the anonymous ID (e.g. "col_0")\n' + '- "is_trap": boolean\n' + '- "reason": brief explanation\n\n' + "Output ONLY the JSON array, no other text." + ) + + return col_map, prompt + + # ------------------------------------------------------------------ + # LLM verification + # ------------------------------------------------------------------ + + def _verify_with_llm( + self, + prompt: str, + col_map: dict[str, str], + ) -> list[dict[str, Any]]: + """Send the prompt to the LLM and parse the verdict.""" + for attempt in range(1, self.max_retries + 1): + try: + raw = self.inference_engine.generate(prompt) + return self._parse_verdict(str(raw).strip(), col_map) + except (json.JSONDecodeError, ValueError, KeyError) as exc: + logger.warning( + "LLM verdict parsing failed (attempt %d/%d): %s", + attempt, + self.max_retries, + exc, + ) + logger.warning("Could not parse LLM verdicts — keeping all columns.") + return [ + {"column": real, "is_trap": False, "reason": "LLM parse failure"} + for real in col_map.values() + ] + + @staticmethod + def _parse_verdict( + response: str, + col_map: dict[str, str], + ) -> list[dict[str, Any]]: + """Parse the JSON verdict and map anonymous IDs back to real names.""" + text = response.strip() + start = text.find("[") + end = text.rfind("]") + if start == -1 or end == -1: + raise ValueError("No JSON array found in LLM response") + + items: list[dict[str, Any]] = json.loads(text[start : end + 1]) + verdicts: list[dict[str, Any]] = [] + + for item in items: + anon_id = item["column"] + real_name = col_map.get(anon_id, anon_id) + verdicts.append( + { + "column": real_name, + "is_trap": bool(item.get("is_trap", False)), + "reason": str(item.get("reason", "")), + } + ) + + return verdicts From f5bd6f9a459b52640cfef62320af6dce68db5f30 Mon Sep 17 00:00:00 2001 From: nxank4 Date: Fri, 27 Feb 2026 17:43:53 +0000 Subject: [PATCH 2/8] feat(api): wire prune_traps, recognize_missingness, audit_leakage into public API - Add all three to extraction/__init__.py lazy imports - Add Loclean class methods + module-level convenience functions - Update __all__ in loclean/__init__.py --- src/loclean/__init__.py | 223 +++++++++++++++++++++++++++++ src/loclean/extraction/__init__.py | 9 ++ 2 files changed, 232 insertions(+) diff --git a/src/loclean/__init__.py b/src/loclean/__init__.py index 96d2222..e7d0515 100644 --- a/src/loclean/__init__.py +++ b/src/loclean/__init__.py @@ -10,6 +10,7 @@ __all__ = [ "__version__", "Loclean", + "audit_leakage", "clean", "discover_features", "extract", @@ -17,6 +18,8 @@ "get_engine", "optimize_instruction", "oversample", + "prune_traps", + "recognize_missingness", "resolve_entities", "scrub", "shred_to_relations", @@ -321,6 +324,96 @@ def discover_features( ) return discoverer.discover(df, target_col) + def prune_traps( + self, + df: IntoFrameT, + target_col: str, + *, + correlation_threshold: float = 0.05, + max_retries: int = 2, + ) -> tuple[IntoFrameT, dict[str, Any]]: + """Identify and remove trap features. + + Trap features are columns of uncorrelated Gaussian noise + that masquerade as valid signals. + + Args: + df: Input DataFrame. + target_col: Target variable column. + correlation_threshold: Absolute correlation below which + a column is considered uncorrelated. + max_retries: LLM retry budget. + + Returns: + Tuple of (pruned DataFrame, summary dict). + """ + from loclean.extraction.trap_pruner import TrapPruner + + pruner = TrapPruner( + inference_engine=self.engine, + correlation_threshold=correlation_threshold, + max_retries=max_retries, + ) + return pruner.prune(df, target_col) + + def recognize_missingness( + self, + df: IntoFrameT, + target_cols: list[str] | None = None, + *, + sample_size: int = 50, + max_retries: int = 3, + ) -> tuple[IntoFrameT, dict[str, Any]]: + """Detect MNAR patterns and encode as boolean features. + + Args: + df: Input DataFrame. + target_cols: Columns to analyse (default: all with nulls). + sample_size: Max null rows to sample per column. + max_retries: LLM retry budget. + + Returns: + Tuple of (augmented DataFrame, summary dict). + """ + from loclean.extraction.missingness_recognizer import MissingnessRecognizer + + recognizer = MissingnessRecognizer( + inference_engine=self.engine, + sample_size=sample_size, + max_retries=max_retries, + ) + return recognizer.recognize(df, target_cols) + + def audit_leakage( + self, + df: IntoFrameT, + target_col: str, + domain: str = "", + *, + max_retries: int = 2, + sample_n: int = 10, + ) -> tuple[IntoFrameT, dict[str, Any]]: + """Detect and remove target-leaking features. + + Args: + df: Input DataFrame. + target_col: Target variable column. + domain: Dataset domain description. + max_retries: LLM retry budget. + sample_n: Sample rows for the prompt. + + Returns: + Tuple of (pruned DataFrame, summary dict). + """ + from loclean.extraction.leakage_auditor import TargetLeakageAuditor + + auditor = TargetLeakageAuditor( + inference_engine=self.engine, + max_retries=max_retries, + sample_n=sample_n, + ) + return auditor.audit(df, target_col, domain) + def validate_quality( self, df: IntoFrameT, @@ -853,3 +946,133 @@ def discover_features( max_retries=max_retries, ) return discoverer.discover(df, target_col) + + +def prune_traps( + df: IntoFrameT, + target_col: str, + *, + correlation_threshold: float = 0.05, + max_retries: int = 2, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> tuple[IntoFrameT, dict[str, Any]]: + """Identify and remove trap features from a DataFrame. + + Trap features are columns of uncorrelated Gaussian noise that + masquerade as valid signals. Detection relies on statistical + distributions and target correlations — column names are ignored. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column name of the prediction target. + correlation_threshold: Absolute correlation threshold. + max_retries: LLM retry budget. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. + **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + Tuple of ``(pruned_df, summary)`` where *summary* contains + ``dropped_columns`` and ``verdicts``. + """ + from loclean.extraction.trap_pruner import TrapPruner + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + pruner = TrapPruner( + inference_engine=inference_engine, + correlation_threshold=correlation_threshold, + max_retries=max_retries, + ) + return pruner.prune(df, target_col) + + +def recognize_missingness( + df: IntoFrameT, + target_cols: list[str] | None = None, + *, + sample_size: int = 50, + max_retries: int = 3, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> tuple[IntoFrameT, dict[str, Any]]: + """Detect MNAR patterns and encode as boolean feature flags. + + Identifies Missing Not At Random patterns where the probability + of a value being missing depends on other feature values. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_cols: Columns to analyse (default: all with nulls). + sample_size: Max null rows to sample per column. + max_retries: LLM retry budget. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. + **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + Tuple of ``(augmented_df, summary)`` where *summary* maps + each analysed column to its pattern description. + """ + from loclean.extraction.missingness_recognizer import MissingnessRecognizer + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + recognizer = MissingnessRecognizer( + inference_engine=inference_engine, + sample_size=sample_size, + max_retries=max_retries, + ) + return recognizer.recognize(df, target_cols) + + +def audit_leakage( + df: IntoFrameT, + target_col: str, + domain: str = "", + *, + max_retries: int = 2, + sample_n: int = 10, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> tuple[IntoFrameT, dict[str, Any]]: + """Detect and remove target-leaking features. + + Identifies features that contain information generated after the + target event, where P(Y | X_i) ≈ 1. Uses semantic timeline + evaluation via the LLM. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column name of the prediction target. + domain: Brief dataset domain description. + max_retries: LLM retry budget. + sample_n: Sample rows for the prompt. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. + **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + Tuple of ``(pruned_df, summary)`` with ``dropped_columns`` + and ``verdicts``. + """ + from loclean.extraction.leakage_auditor import TargetLeakageAuditor + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + auditor = TargetLeakageAuditor( + inference_engine=inference_engine, + max_retries=max_retries, + sample_n=sample_n, + ) + return auditor.audit(df, target_col, domain) diff --git a/src/loclean/extraction/__init__.py b/src/loclean/extraction/__init__.py index 64e9d66..460b0b8 100644 --- a/src/loclean/extraction/__init__.py +++ b/src/loclean/extraction/__init__.py @@ -9,18 +9,24 @@ if TYPE_CHECKING: from .feature_discovery import FeatureDiscovery + from .leakage_auditor import TargetLeakageAuditor + from .missingness_recognizer import MissingnessRecognizer from .optimizer import InstructionOptimizer from .oversampler import SemanticOversampler from .resolver import EntityResolver from .shredder import RelationalShredder + from .trap_pruner import TrapPruner __all__ = [ "EntityResolver", "Extractor", "FeatureDiscovery", "InstructionOptimizer", + "MissingnessRecognizer", "RelationalShredder", "SemanticOversampler", + "TargetLeakageAuditor", + "TrapPruner", "extract_dataframe_compiled", ] @@ -28,8 +34,11 @@ "EntityResolver": ".resolver", "FeatureDiscovery": ".feature_discovery", "InstructionOptimizer": ".optimizer", + "MissingnessRecognizer": ".missingness_recognizer", "RelationalShredder": ".shredder", "SemanticOversampler": ".oversampler", + "TargetLeakageAuditor": ".leakage_auditor", + "TrapPruner": ".trap_pruner", } From 67c5e494eb058d6306e53e2e251ce42b7c56c03d Mon Sep 17 00:00:00 2001 From: nxank4 Date: Fri, 27 Feb 2026 17:44:03 +0000 Subject: [PATCH 3/8] test(extraction): add unit tests for TrapPruner, MissingnessRecognizer, TargetLeakageAuditor - 13 tests each (39 total) covering profiling, prompt construction, verdict parsing, verification, and mock-LLM integration --- tests/unit/extraction/test_leakage_auditor.py | 228 ++++++++++++++++ .../extraction/test_missingness_recognizer.py | 195 ++++++++++++++ tests/unit/extraction/test_trap_pruner.py | 247 ++++++++++++++++++ 3 files changed, 670 insertions(+) create mode 100644 tests/unit/extraction/test_leakage_auditor.py create mode 100644 tests/unit/extraction/test_missingness_recognizer.py create mode 100644 tests/unit/extraction/test_trap_pruner.py diff --git a/tests/unit/extraction/test_leakage_auditor.py b/tests/unit/extraction/test_leakage_auditor.py new file mode 100644 index 0000000..5e56578 --- /dev/null +++ b/tests/unit/extraction/test_leakage_auditor.py @@ -0,0 +1,228 @@ +"""Unit tests for the TargetLeakageAuditor module.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import narwhals as nw +import polars as pl +import pytest + +from loclean.extraction.leakage_auditor import TargetLeakageAuditor + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _make_engine(response: str) -> MagicMock: + engine = MagicMock() + engine.generate.return_value = response + return engine + + +def _sample_df() -> pl.DataFrame: + return pl.DataFrame( + { + "age": [25, 30, 45, 50, 35], + "income": [50000, 60000, 80000, 90000, 55000], + "approved_date": [ + "2024-01-15", + "2024-01-20", + "2024-02-01", + "2024-02-10", + "2024-01-25", + ], + "feedback_score": [4, 5, 3, 5, 4], + "approved": [True, True, False, True, True], + } + ) + + +# ------------------------------------------------------------------ +# _extract_state +# ------------------------------------------------------------------ + + +class TestExtractState: + def test_extracts_features_and_samples(self) -> None: + df = _sample_df() + df_nw = nw.from_native(df) + features = ["age", "income", "approved_date", "feedback_score"] + state = TargetLeakageAuditor._extract_state(df_nw, "approved", features) + + assert state["target_col"] == "approved" + assert state["features"] == features + assert len(state["sample_rows"]) <= 10 + assert "age" in state["dtypes"] + + def test_respects_sample_n(self) -> None: + df = _sample_df() + df_nw = nw.from_native(df) + state = TargetLeakageAuditor._extract_state( + df_nw, "approved", ["age"], sample_n=2 + ) + assert len(state["sample_rows"]) == 2 + + +# ------------------------------------------------------------------ +# _build_prompt +# ------------------------------------------------------------------ + + +class TestBuildPrompt: + def test_includes_domain_and_target(self) -> None: + state = { + "target_col": "approved", + "features": ["age", "income"], + "dtypes": {"age": "Int64", "income": "Int64"}, + "sample_rows": [{"age": 25, "income": 50000, "approved": True}], + } + prompt = TargetLeakageAuditor._build_prompt(state, "loan approval prediction") + assert "loan approval prediction" in prompt + assert "approved" in prompt + assert "age" in prompt + assert "is_leakage" in prompt + + def test_no_domain(self) -> None: + state = { + "target_col": "y", + "features": ["x"], + "dtypes": {"x": "Float64"}, + "sample_rows": [{"x": 1.0, "y": 0}], + } + prompt = TargetLeakageAuditor._build_prompt(state, "") + assert "Dataset domain:" not in prompt + + +# ------------------------------------------------------------------ +# _parse_verdict +# ------------------------------------------------------------------ + + +class TestParseVerdict: + def test_parses_valid_json(self) -> None: + response = json.dumps( + [ + {"column": "approved_date", "is_leakage": True, "reason": "Post-event"}, + {"column": "age", "is_leakage": False, "reason": "Pre-event"}, + ] + ) + verdicts = TargetLeakageAuditor._parse_verdict(response) + assert len(verdicts) == 2 + assert verdicts[0]["column"] == "approved_date" + assert verdicts[0]["is_leakage"] is True + assert verdicts[1]["is_leakage"] is False + + def test_handles_extra_text(self) -> None: + response = ( + 'Analysis:\n[{"column": "x", "is_leakage": false, "reason": "ok"}]\nEnd.' + ) + verdicts = TargetLeakageAuditor._parse_verdict(response) + assert len(verdicts) == 1 + + def test_raises_on_no_json(self) -> None: + with pytest.raises(ValueError, match="No JSON array"): + TargetLeakageAuditor._parse_verdict("no json here") + + +# ------------------------------------------------------------------ +# audit (integration with mock LLM) +# ------------------------------------------------------------------ + + +class TestAudit: + def test_drops_leaked_columns(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "age", "is_leakage": False, "reason": "ok"}, + {"column": "income", "is_leakage": False, "reason": "ok"}, + {"column": "approved_date", "is_leakage": True, "reason": "Post-event"}, + { + "column": "feedback_score", + "is_leakage": True, + "reason": "Post-event", + }, + ] + ) + engine = _make_engine(response) + auditor = TargetLeakageAuditor(inference_engine=engine) + + pruned, summary = auditor.audit(df, "approved", "loan approval") + + assert "approved_date" not in pruned.columns + assert "feedback_score" not in pruned.columns + assert "age" in pruned.columns + assert "income" in pruned.columns + assert "approved" in pruned.columns + assert "approved_date" in summary["dropped_columns"] + assert "feedback_score" in summary["dropped_columns"] + + def test_keeps_all_if_no_leakage(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "age", "is_leakage": False, "reason": "ok"}, + {"column": "income", "is_leakage": False, "reason": "ok"}, + {"column": "approved_date", "is_leakage": False, "reason": "ok"}, + {"column": "feedback_score", "is_leakage": False, "reason": "ok"}, + ] + ) + engine = _make_engine(response) + auditor = TargetLeakageAuditor(inference_engine=engine) + + pruned, summary = auditor.audit(df, "approved") + + assert set(pruned.columns) == set(df.columns) + assert summary["dropped_columns"] == [] + + def test_missing_target_raises(self) -> None: + df = _sample_df() + engine = _make_engine("[]") + auditor = TargetLeakageAuditor(inference_engine=engine) + + with pytest.raises(ValueError, match="not found"): + auditor.audit(df, "nonexistent") + + def test_no_feature_columns(self) -> None: + df = pl.DataFrame({"target": [1, 2, 3]}) + engine = _make_engine("[]") + auditor = TargetLeakageAuditor(inference_engine=engine) + + pruned, summary = auditor.audit(df, "target") + + assert pruned.columns == ["target"] + assert summary["dropped_columns"] == [] + engine.generate.assert_not_called() + + def test_summary_contains_verdicts(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "age", "is_leakage": False, "reason": "ok"}, + ] + ) + engine = _make_engine(response) + auditor = TargetLeakageAuditor(inference_engine=engine) + + _, summary = auditor.audit(df, "approved") + + assert "verdicts" in summary + assert isinstance(summary["verdicts"], list) + + def test_domain_passed_to_prompt(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "age", "is_leakage": False, "reason": "ok"}, + ] + ) + engine = _make_engine(response) + auditor = TargetLeakageAuditor(inference_engine=engine) + + auditor.audit(df, "approved", domain="healthcare readmission") + + call_args = engine.generate.call_args[0][0] + assert "healthcare readmission" in call_args diff --git a/tests/unit/extraction/test_missingness_recognizer.py b/tests/unit/extraction/test_missingness_recognizer.py new file mode 100644 index 0000000..13e44c1 --- /dev/null +++ b/tests/unit/extraction/test_missingness_recognizer.py @@ -0,0 +1,195 @@ +"""Unit tests for the MissingnessRecognizer module.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import narwhals as nw +import polars as pl + +from loclean.extraction.missingness_recognizer import MissingnessRecognizer + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _make_engine(response: str) -> MagicMock: + engine = MagicMock() + engine.generate.return_value = response + return engine + + +_ENCODER_SRC = ( + "def encode_missingness(row: dict) -> bool:\n" + " try:\n" + " return row.get('category') == 'electronics'\n" + " except Exception:\n" + " return False\n" +) + + +def _df_with_nulls() -> pl.DataFrame: + return pl.DataFrame( + { + "price": [100.0, None, 300.0, None, 500.0, None], + "category": [ + "clothing", + "electronics", + "clothing", + "electronics", + "clothing", + "electronics", + ], + "quantity": [10, 5, 20, 3, 15, 1], + } + ) + + +# ------------------------------------------------------------------ +# _find_null_columns +# ------------------------------------------------------------------ + + +class TestFindNullColumns: + def test_detects_columns_with_nulls(self) -> None: + df = _df_with_nulls() + df_nw = nw.from_native(df) + result = MissingnessRecognizer._find_null_columns(df_nw) + assert result == ["price"] + + def test_no_nulls(self) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + df_nw = nw.from_native(df) + result = MissingnessRecognizer._find_null_columns(df_nw) + assert result == [] + + +# ------------------------------------------------------------------ +# _sample_null_context +# ------------------------------------------------------------------ + + +class TestSampleNullContext: + def test_samples_rows_where_target_is_null(self) -> None: + df = _df_with_nulls() + df_nw = nw.from_native(df) + sample = MissingnessRecognizer._sample_null_context( + df_nw, "price", ["category", "quantity"] + ) + assert len(sample) == 3 + for row in sample: + assert "category" in row + assert "quantity" in row + assert "price" not in row + + def test_respects_max_rows(self) -> None: + df = _df_with_nulls() + df_nw = nw.from_native(df) + sample = MissingnessRecognizer._sample_null_context( + df_nw, "price", ["category"], max_rows=2 + ) + assert len(sample) == 2 + + def test_empty_when_no_nulls(self) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + df_nw = nw.from_native(df) + sample = MissingnessRecognizer._sample_null_context(df_nw, "a", ["b"]) + assert sample == [] + + +# ------------------------------------------------------------------ +# _build_prompt +# ------------------------------------------------------------------ + + +class TestBuildPrompt: + def test_includes_column_name_and_sample(self) -> None: + prompt = MissingnessRecognizer._build_prompt( + "price", + ["category", "quantity"], + [{"category": "electronics", "quantity": 5}], + ) + assert "price" in prompt + assert "encode_missingness" in prompt + assert "electronics" in prompt + + def test_includes_rules(self) -> None: + prompt = MissingnessRecognizer._build_prompt("x", ["y"], [{"y": 1}]) + assert "try/except" in prompt + assert "boolean" in prompt.lower() or "bool" in prompt.lower() + + +# ------------------------------------------------------------------ +# _verify_encoder +# ------------------------------------------------------------------ + + +class TestVerifyEncoder: + def test_valid_encoder_passes(self) -> None: + def good_fn(row: dict[str, Any]) -> bool: + return True + + ok, err = MissingnessRecognizer._verify_encoder(good_fn, [{"a": 1}, {"a": 2}]) + assert ok is True + assert err == "" + + def test_non_bool_return_fails(self) -> None: + def bad_fn(row: dict[str, Any]) -> Any: + return "not a bool" + + ok, err = MissingnessRecognizer._verify_encoder(bad_fn, [{"a": 1}]) + assert ok is False + assert "bool" in err.lower() + + +# ------------------------------------------------------------------ +# recognize (integration with mock LLM) +# ------------------------------------------------------------------ + + +class TestRecognize: + def test_adds_mnar_column(self) -> None: + engine = _make_engine(_ENCODER_SRC) + recognizer = MissingnessRecognizer(inference_engine=engine, max_retries=1) + df = _df_with_nulls() + result, summary = recognizer.recognize(df) + + assert "price_mnar" in result.columns + assert "price" in summary["patterns"] + assert summary["patterns"]["price"]["encoded_as"] == "price_mnar" + + def test_no_nulls_skips(self) -> None: + engine = _make_engine("") + recognizer = MissingnessRecognizer(inference_engine=engine, max_retries=1) + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + result, summary = recognizer.recognize(df) + + assert set(result.columns) == {"a", "b"} + assert summary["patterns"] == {} + engine.generate.assert_not_called() + + def test_target_cols_filter(self) -> None: + df = pl.DataFrame( + { + "a": [1, None, 3], + "b": [None, 2, None], + "c": [10, 20, 30], + } + ) + engine = _make_engine(_ENCODER_SRC) + recognizer = MissingnessRecognizer(inference_engine=engine, max_retries=1) + _, summary = recognizer.recognize(df, target_cols=["a"]) + + assert "a" in summary["patterns"] + assert "b" not in summary["patterns"] + + def test_compile_failure_returns_none_pattern(self) -> None: + engine = _make_engine("this is not valid python at all!!!") + recognizer = MissingnessRecognizer(inference_engine=engine, max_retries=1) + df = _df_with_nulls() + result, summary = recognizer.recognize(df) + + assert "price_mnar" not in result.columns + assert summary["patterns"]["price"] is None diff --git a/tests/unit/extraction/test_trap_pruner.py b/tests/unit/extraction/test_trap_pruner.py new file mode 100644 index 0000000..03ddcd0 --- /dev/null +++ b/tests/unit/extraction/test_trap_pruner.py @@ -0,0 +1,247 @@ +"""Unit tests for the TrapPruner module.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import narwhals as nw +import polars as pl +import pytest + +from loclean.extraction.trap_pruner import TrapPruner, _ColumnProfile + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + + +def _make_engine(response: str) -> MagicMock: + engine = MagicMock() + engine.generate.return_value = response + return engine + + +def _sample_df() -> pl.DataFrame: + """DataFrame with one real feature and one Gaussian noise column.""" + import random + + random.seed(42) + n = 100 + prices = [200_000 + i * 5000 for i in range(n)] + sqft = [150 + i * 10 + random.randint(-5, 5) for i in range(n)] + noise = [random.gauss(0, 1) for _ in range(n)] + + return pl.DataFrame( + { + "sqft": sqft, + "noise_feat": noise, + "price": prices, + } + ) + + +# ------------------------------------------------------------------ +# _profile_columns +# ------------------------------------------------------------------ + + +class TestProfileColumns: + def test_basic_stats(self) -> None: + df = _sample_df() + df_nw = nw.from_native(df) + profiles = TrapPruner._profile_columns(df_nw, "price", ["sqft", "noise_feat"]) + assert len(profiles) == 2 + + sqft_p = next(p for p in profiles if p.name == "sqft") + noise_p = next(p for p in profiles if p.name == "noise_feat") + + assert abs(sqft_p.corr_with_target) > 0.5 + assert abs(noise_p.corr_with_target) < 0.2 + + def test_zero_variance_column(self) -> None: + df = pl.DataFrame( + { + "constant": [5] * 10, + "target": list(range(10)), + } + ) + df_nw = nw.from_native(df) + profiles = TrapPruner._profile_columns(df_nw, "target", ["constant"]) + assert len(profiles) == 1 + assert profiles[0].variance == 0.0 + assert profiles[0].corr_with_target == 0.0 + + def test_single_row_returns_empty(self) -> None: + df = pl.DataFrame({"a": [1], "target": [2]}) + df_nw = nw.from_native(df) + profiles = TrapPruner._profile_columns(df_nw, "target", ["a"]) + assert profiles == [] + + +# ------------------------------------------------------------------ +# _build_prompt +# ------------------------------------------------------------------ + + +class TestBuildPrompt: + def test_anonymises_column_names(self) -> None: + profiles = [ + _ColumnProfile( + name="secret_column", + mean=0.0, + std=1.0, + variance=1.0, + skewness=0.0, + kurtosis=0.0, + min_val=-3.0, + max_val=3.0, + corr_with_target=0.01, + ), + ] + col_map, prompt = TrapPruner._build_prompt(profiles) + + assert "secret_column" not in prompt + assert "col_0" in prompt + assert col_map["col_0"] == "secret_column" + + def test_multiple_columns_indexed(self) -> None: + profiles = [ + _ColumnProfile( + name=f"feat_{i}", + mean=float(i), + std=1.0, + variance=1.0, + skewness=0.0, + kurtosis=0.0, + min_val=0.0, + max_val=10.0, + corr_with_target=0.5, + ) + for i in range(3) + ] + col_map, prompt = TrapPruner._build_prompt(profiles) + assert len(col_map) == 3 + assert "col_0" in prompt + assert "col_1" in prompt + assert "col_2" in prompt + + +# ------------------------------------------------------------------ +# _parse_verdict +# ------------------------------------------------------------------ + + +class TestParseVerdict: + def test_maps_anonymous_to_real(self) -> None: + col_map = {"col_0": "noise_feat", "col_1": "real_feat"} + response = json.dumps( + [ + {"column": "col_0", "is_trap": True, "reason": "Gaussian noise"}, + {"column": "col_1", "is_trap": False, "reason": "Correlated"}, + ] + ) + + verdicts = TrapPruner._parse_verdict(response, col_map) + assert len(verdicts) == 2 + assert verdicts[0]["column"] == "noise_feat" + assert verdicts[0]["is_trap"] is True + assert verdicts[1]["column"] == "real_feat" + assert verdicts[1]["is_trap"] is False + + def test_handles_extra_text_around_json(self) -> None: + col_map = {"col_0": "feat_a"} + response = ( + "Here is the analysis:\n" + '[{"column": "col_0", "is_trap": false, "reason": "ok"}]\nDone.' + ) + + verdicts = TrapPruner._parse_verdict(response, col_map) + assert len(verdicts) == 1 + assert verdicts[0]["column"] == "feat_a" + + def test_raises_on_no_json(self) -> None: + with pytest.raises(ValueError, match="No JSON array"): + TrapPruner._parse_verdict("no json here", {}) + + +# ------------------------------------------------------------------ +# prune (integration with mock LLM) +# ------------------------------------------------------------------ + + +class TestPrune: + def test_removes_trap_columns(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "col_0", "is_trap": False, "reason": "Correlated"}, + {"column": "col_1", "is_trap": True, "reason": "Gaussian noise"}, + ] + ) + engine = _make_engine(response) + pruner = TrapPruner(inference_engine=engine) + + pruned, summary = pruner.prune(df, "price") + + assert "noise_feat" not in pruned.columns + assert "sqft" in pruned.columns + assert "price" in pruned.columns + assert "noise_feat" in summary["dropped_columns"] + + def test_keeps_all_if_no_traps(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "col_0", "is_trap": False, "reason": "ok"}, + {"column": "col_1", "is_trap": False, "reason": "ok"}, + ] + ) + engine = _make_engine(response) + pruner = TrapPruner(inference_engine=engine) + + pruned, summary = pruner.prune(df, "price") + + assert set(pruned.columns) == set(df.columns) + assert summary["dropped_columns"] == [] + + def test_returns_summary_with_verdicts(self) -> None: + df = _sample_df() + response = json.dumps( + [ + {"column": "col_0", "is_trap": False, "reason": "real"}, + {"column": "col_1", "is_trap": True, "reason": "noise"}, + ] + ) + engine = _make_engine(response) + pruner = TrapPruner(inference_engine=engine) + + _, summary = pruner.prune(df, "price") + + assert "dropped_columns" in summary + assert "verdicts" in summary + assert len(summary["verdicts"]) == 2 + + def test_missing_target_raises(self) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + engine = _make_engine("[]") + pruner = TrapPruner(inference_engine=engine) + + with pytest.raises(ValueError, match="not found"): + pruner.prune(df, "nonexistent") + + def test_no_numeric_columns(self) -> None: + df = pl.DataFrame( + { + "name": ["alice", "bob"], + "target": [1, 2], + } + ) + engine = _make_engine("[]") + pruner = TrapPruner(inference_engine=engine) + + pruned, summary = pruner.prune(df, "target") + + assert set(pruned.columns) == set(df.columns) + assert summary["dropped_columns"] == [] + engine.generate.assert_not_called() From 8c235f359899ecdc877239abcfdea67a768266ba Mon Sep 17 00:00:00 2001 From: nxank4 Date: Fri, 27 Feb 2026 18:03:02 +0000 Subject: [PATCH 4/8] feat(utils): add source_sanitizer for LLM output cleanup - Strip markdown fences, prose, and backticks - Fix unicode operators and invalid numeric literals - 17 unit tests covering all transformation stages --- src/loclean/utils/source_sanitizer.py | 137 ++++++++++++++++++++ tests/unit/utils/test_source_sanitizer.py | 147 ++++++++++++++++++++++ 2 files changed, 284 insertions(+) create mode 100644 src/loclean/utils/source_sanitizer.py create mode 100644 tests/unit/utils/test_source_sanitizer.py diff --git a/src/loclean/utils/source_sanitizer.py b/src/loclean/utils/source_sanitizer.py new file mode 100644 index 0000000..20d4940 --- /dev/null +++ b/src/loclean/utils/source_sanitizer.py @@ -0,0 +1,137 @@ +"""Deterministic source-code sanitizer for LLM-generated Python. + +Small models (phi3, etc.) frequently produce output with markdown +fences, prose preambles, non-ASCII operators, and invalid numeric +literals. This module fixes those issues mechanically — no LLM calls +required — before the code reaches ``compile_sandboxed``. +""" + +from __future__ import annotations + +import re + + +def sanitize_source(source: str) -> str: + """Clean up common LLM output artifacts from Python source code. + + Applies a sequence of deterministic transformations: + + 1. Strip markdown code fences (````python`` / `````) + 2. Remove prose before the first ``import`` / ``def`` / ``from`` + 3. Remove trailing prose after the last function body + 4. Replace non-ASCII mathematical operators + 5. Fix invalid numeric literals + 6. Strip stray inline backticks + + Args: + source: Raw LLM-generated Python source. + + Returns: + Cleaned source code ready for ``compile_sandboxed``. + """ + source = _strip_markdown_fences(source) + source = _strip_prose(source) + source = _fix_unicode_operators(source) + source = _fix_numeric_literals(source) + source = _strip_backticks(source) + return source + + +def _strip_markdown_fences(source: str) -> str: + """Remove markdown code fences wrapping the code block.""" + lines = source.split("\n") + cleaned: list[str] = [] + for line in lines: + stripped = line.strip() + if stripped.startswith("```"): + continue + cleaned.append(line) + return "\n".join(cleaned) + + +def _strip_prose(source: str) -> str: + """Remove explanatory text before/after the actual code. + + Keeps lines starting from the first ``import``, ``from``, or + ``def`` statement through the end of the last indented block. + """ + lines = source.split("\n") + + start_idx = 0 + for i, line in enumerate(lines): + stripped = line.strip() + if stripped.startswith(("import ", "from ", "def ", "class ")): + start_idx = i + break + + end_idx = len(lines) + for i in range(len(lines) - 1, start_idx - 1, -1): + stripped = lines[i].strip() + if stripped and not _is_prose_line(stripped): + end_idx = i + 1 + break + + return "\n".join(lines[start_idx:end_idx]) + + +def _is_prose_line(line: str) -> bool: + """Heuristic: a line is 'prose' if it looks like natural language.""" + if not line: + return False + if line.startswith(("#", "import ", "from ", "def ", "class ", "return ")): + return False + if line[0] in (" ", "\t", "@"): + return False + words = line.split() + if len(words) >= 4 and not any(c in line for c in ("=", "(", ")", "[", "]", ":")): + return True + return False + + +_UNICODE_MAP: dict[str, str] = { + "\u00d7": "*", # × + "\u00f7": "/", # ÷ + "\u2212": "-", # − (minus sign) + "\u2013": "-", # – (en dash) + "\u2014": "-", # — (em dash) + "\u2018": "'", # ' + "\u2019": "'", # ' + "\u201c": '"', # " + "\u201d": '"', # " + "\u2264": "<=", # ≤ + "\u2265": ">=", # ≥ + "\u2260": "!=", # ≠ +} + + +def _fix_unicode_operators(source: str) -> str: + """Replace non-ASCII mathematical and typographic characters.""" + for char, replacement in _UNICODE_MAP.items(): + source = source.replace(char, replacement) + return source + + +def _fix_numeric_literals(source: str) -> str: + """Fix invalid numeric literals commonly produced by small models. + + Patterns handled: + - ``0b2``, ``0b3`` etc. (invalid binary digits) → decimal + - Trailing currency/unit symbols on numbers (``100$``, ``50€``) + """ + source = re.sub( + r"\b0b([2-9]\d*)\b", + lambda m: m.group(1), + source, + ) + source = re.sub( + r"(\d+\.?\d*)[€$£%]+", + r"\1", + source, + ) + return source + + +def _strip_backticks(source: str) -> str: + """Remove stray inline backticks wrapping expressions.""" + source = re.sub(r"`([^`\n]+)`", r"\1", source) + return source diff --git a/tests/unit/utils/test_source_sanitizer.py b/tests/unit/utils/test_source_sanitizer.py new file mode 100644 index 0000000..edde1c4 --- /dev/null +++ b/tests/unit/utils/test_source_sanitizer.py @@ -0,0 +1,147 @@ +"""Tests for loclean.utils.source_sanitizer module.""" + +from loclean.utils.source_sanitizer import sanitize_source + + +class TestStripMarkdownFences: + """Markdown code fences should be removed.""" + + def test_python_fences(self) -> None: + source = "```python\ndef f():\n return 1\n```" + result = sanitize_source(source) + assert "```" not in result + assert "def f():" in result + + def test_triple_backtick_only(self) -> None: + source = "```\ndef f():\n return 1\n```" + result = sanitize_source(source) + assert "```" not in result + assert "def f():" in result + + def test_no_fences_passthrough(self) -> None: + source = "def f():\n return 1" + assert sanitize_source(source) == source + + +class TestStripProse: + """Leading/trailing prose should be removed.""" + + def test_leading_explanation(self) -> None: + source = ( + "Here is the corrected function:\n\n" + "def generate_features(row):\n" + " return {'a': row['x'] * 2}\n" + ) + result = sanitize_source(source) + assert result.startswith("def generate_features") + + def test_trailing_explanation(self) -> None: + source = ( + "def f(row):\n" + " return {'a': 1}\n\n" + "This function computes a simple feature by multiplying the value." + ) + result = sanitize_source(source) + assert "This function computes" not in result + assert "def f(row):" in result + + def test_import_preserved(self) -> None: + source = "import math\n\ndef f(row):\n return {'a': math.log(1)}" + result = sanitize_source(source) + assert result.startswith("import math") + + +class TestFixUnicodeOperators: + """Non-ASCII math operators should be replaced.""" + + def test_multiplication(self) -> None: + source = "def f(row):\n return row['a'] \u00d7 row['b']" + result = sanitize_source(source) + assert "\u00d7" not in result + assert "row['a'] * row['b']" in result + + def test_division(self) -> None: + source = "def f(row):\n return row['a'] \u00f7 row['b']" + result = sanitize_source(source) + assert "/" in result + + def test_minus_sign(self) -> None: + source = "def f(row):\n return row['a'] \u2212 row['b']" + result = sanitize_source(source) + assert "\u2212" not in result + assert "-" in result + + def test_smart_quotes(self) -> None: + source = "def f(row):\n return row[\u2018name\u2019]" + result = sanitize_source(source) + assert "\u2018" not in result + assert "\u2019" not in result + + def test_comparison_operators(self) -> None: + source = "def f(x):\n return x \u2264 10" + result = sanitize_source(source) + assert "<=" in result + + +class TestFixNumericLiterals: + """Invalid numeric literals should be fixed.""" + + def test_invalid_binary_digit(self) -> None: + source = "def f():\n x = 0b2\n return x" + result = sanitize_source(source) + assert "0b2" not in result + assert "2" in result + + def test_trailing_currency(self) -> None: + source = "def f():\n return 100$" + result = sanitize_source(source) + assert "$" not in result + assert "100" in result + + def test_valid_binary_untouched(self) -> None: + source = "def f():\n return 0b101" + result = sanitize_source(source) + assert "0b101" in result + + +class TestStripBackticks: + """Stray inline backticks should be removed.""" + + def test_wrapped_expression(self) -> None: + source = "def f(row):\n return `math.log(row['x'])`" + result = sanitize_source(source) + assert "`" not in result + assert "math.log(row['x'])" in result + + def test_no_backticks_passthrough(self) -> None: + source = "def f():\n return 42" + assert sanitize_source(source) == source + + +class TestEndToEnd: + """Full pipeline integration tests.""" + + def test_complete_cleanup(self) -> None: + source = ( + "Here is your function:\n\n" + "```python\n" + "import math\n\n" + "def generate_features(row):\n" + " result = {}\n" + " result[\u2018log_price\u2019] = math.log(row[\u2018price\u2019])\n" + " result['ratio'] = row['a'] \u00d7 row['b']\n" + " return result\n" + "```\n\n" + "This function generates two features." + ) + result = sanitize_source(source) + + assert "```" not in result + assert "Here is your function" not in result + assert "This function generates" not in result + assert "\u2018" not in result + assert "\u00d7" not in result + assert "import math" in result + assert "def generate_features(row):" in result + assert "math.log" in result + assert "* row['b']" in result From 6cad8a7e935ecfd130aa0d511e206387db3603e9 Mon Sep 17 00:00:00 2001 From: nxank4 Date: Fri, 27 Feb 2026 18:03:10 +0000 Subject: [PATCH 5/8] refactor(sandbox): add restricted __import__ to compile_sandboxed - LLM-generated import statements now only work for explicitly allowed modules; all others raise ImportError - Preload modules into safe_globals for direct namespace access - Updated docstring to document the restriction --- src/loclean/utils/sandbox.py | 38 +++++++++++++++++++++++++++----- tests/unit/utils/test_sandbox.py | 7 +++--- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/loclean/utils/sandbox.py b/src/loclean/utils/sandbox.py index 632020c..fcf1025 100644 --- a/src/loclean/utils/sandbox.py +++ b/src/loclean/utils/sandbox.py @@ -93,9 +93,11 @@ def compile_sandboxed( The execution environment has: * ``__builtins__`` replaced by a curated safe subset (no ``open``, - ``exec``, ``eval``, ``__import__``, ``compile``, ``exit``, - ``quit``, ``input``, ``breakpoint``, ``globals``, ``locals``, - ``vars``, ``dir``). + ``exec``, ``eval``, ``compile``, ``exit``, ``quit``, ``input``, + ``breakpoint``, ``globals``, ``locals``, ``vars``, ``dir``). + * A restricted ``__import__`` that only permits explicitly listed + modules — LLM-generated ``import`` statements work for allowed + modules but raise ``ImportError`` for anything else. * Only explicitly listed standard-library modules injected. Args: @@ -110,16 +112,40 @@ def compile_sandboxed( Raises: ValueError: If compilation fails or *fn_name* is not defined. """ - safe_globals: dict[str, Any] = {"__builtins__": _SAFE_BUILTINS.copy()} + allowed = set(allowed_modules or []) + preloaded: dict[str, Any] = {} - for mod_name in allowed_modules or []: + for mod_name in allowed: try: - safe_globals[mod_name] = importlib.import_module(mod_name) + preloaded[mod_name] = importlib.import_module(mod_name) except ImportError: logger.warning( f"[yellow]⚠[/yellow] Module '{mod_name}' not available, skipping" ) + def _restricted_import( + name: str, + globals: Any = None, + locals: Any = None, + fromlist: Any = (), + level: int = 0, + ) -> Any: + root = name.split(".")[0] + if root not in allowed: + raise ImportError( + f"Import of '{name}' is not allowed in the sandbox. " + f"Permitted modules: {sorted(allowed)}" + ) + if root in preloaded: + return preloaded[root] + return importlib.import_module(name) + + builtins = _SAFE_BUILTINS.copy() + builtins["__import__"] = _restricted_import + + safe_globals: dict[str, Any] = {"__builtins__": builtins} + safe_globals.update(preloaded) + try: exec(source, safe_globals) # noqa: S102 except Exception as exc: diff --git a/tests/unit/utils/test_sandbox.py b/tests/unit/utils/test_sandbox.py index bb03e4b..9e028fd 100644 --- a/tests/unit/utils/test_sandbox.py +++ b/tests/unit/utils/test_sandbox.py @@ -41,10 +41,9 @@ def test_eval_blocked(self) -> None: fn() def test_import_blocked(self) -> None: - source = "def f():\n return __import__('os')\n" - fn = compile_sandboxed(source, "f") - with pytest.raises(NameError): - fn() + source = "import os\ndef f():\n return os.getcwd()\n" + with pytest.raises(ValueError, match="not allowed in the sandbox"): + compile_sandboxed(source, "f") def test_import_statement_blocked(self) -> None: source = "import os\ndef f():\n return os.listdir('.')\n" From b4354023d9eb0305488fc52b94b8daae6e8e6920 Mon Sep 17 00:00:00 2001 From: nxank4 Date: Fri, 27 Feb 2026 18:03:17 +0000 Subject: [PATCH 6/8] refactor(extraction): harden retry loop and improve error messages - Wrap initial compile in try/except to catch ValueErrors - Add per-retry logging with attempt counter - Replace vague failure messages with actionable guidance (model suggestions, max_retries hint) - Add concrete code examples to LLM prompts for better output --- src/loclean/extraction/feature_discovery.py | 78 +++++++++++++++------ src/loclean/extraction/shredder.py | 73 +++++++++++++------ 2 files changed, 107 insertions(+), 44 deletions(-) diff --git a/src/loclean/extraction/feature_discovery.py b/src/loclean/extraction/feature_discovery.py index 6aaa939..5ddf5b6 100644 --- a/src/loclean/extraction/feature_discovery.py +++ b/src/loclean/extraction/feature_discovery.py @@ -103,21 +103,37 @@ def discover( return result.to_native() # type: ignore[no-any-return,return-value] source = self._propose_features(state) - fn = self._compile_function(source) - sample_rows = state["sample_rows"] - ok, error = self._verify_function(fn, sample_rows, self.timeout_s) - retries = 0 - while not ok and retries < self.max_retries: - source = self._repair_function(source, error, state) + + try: fn = self._compile_function(source) ok, error = self._verify_function(fn, sample_rows, self.timeout_s) + except ValueError as exc: + ok, error = False, str(exc) + + retries = 0 + while not ok and retries < self.max_retries: retries += 1 + logger.warning( + f"[yellow]⚠[/yellow] Retrying code generation " + f"({retries}/{self.max_retries}): {error}" + ) + source = self._repair_function(source, error, state) + try: + fn = self._compile_function(source) + ok, error = self._verify_function(fn, sample_rows, self.timeout_s) + except ValueError as exc: + ok, error = False, str(exc) if not ok: logger.warning( - f"[yellow]⚠[/yellow] Feature generation failed after " - f"{self.max_retries} retries: {error} — returning original DataFrame" + f"[yellow]⚠[/yellow] The model could not generate valid Python " + f"code after {self.max_retries} retries. This is not a library " + f"bug — smaller models (e.g. phi3) sometimes produce syntax " + f"errors or invalid logic. Returning the original DataFrame.\n" + f" [dim]Last error: {error}[/dim]\n" + f" [dim]Tip: try a larger model " + f"(model='qwen2.5-coder:7b') or increase max_retries.[/dim]" ) return df @@ -199,11 +215,28 @@ def _propose_features(self, state: dict[str, Any]) -> str: "maximise mutual information I(X_new; Y) with the target.\n\n" "Write a pure Python function with this exact signature:\n\n" "def generate_features(row: dict) -> dict:\n\n" - "The function must:\n" + "EXAMPLE (for a different dataset with columns " + "'age', 'income', 'debt'):\n\n" + "import math\n\n" + "def generate_features(row: dict) -> dict:\n" + " result = {}\n" + " try:\n" + " result['debt_to_income'] = " + "row['debt'] / row['income'] if row['income'] else None\n" + " except Exception:\n" + " result['debt_to_income'] = None\n" + " try:\n" + " result['log_income'] = " + "math.log(row['income']) if row['income'] and " + "row['income'] > 0 else None\n" + " except Exception:\n" + " result['log_income'] = None\n" + " return result\n\n" + "Now write yours for the dataset above. The function must:\n" "- Accept a dict of column_name: value pairs\n" f"- Return a dict with exactly {self.n_features} new " "key-value pairs (the new feature names and values)\n" - "- Use ONLY standard library modules (math, etc.)\n" + "- Use ONLY standard library modules (math, statistics, operator)\n" "- Wrap each calculation in try/except, defaulting to " "None on failure\n" "- Use descriptive feature names like 'ratio_a_b' or " @@ -213,12 +246,7 @@ def _propose_features(self, state: dict[str, Any]) -> str: ) raw = self.inference_engine.generate(prompt) - source = str(raw).strip() - if source.startswith("```"): - lines = source.split("\n") - lines = [line for line in lines if not line.strip().startswith("```")] - source = "\n".join(lines) - return source + return str(raw).strip() # ------------------------------------------------------------------ # Compilation @@ -230,6 +258,10 @@ def _compile_function( ) -> Callable[[dict[str, Any]], dict[str, Any]]: """Compile source code in a restricted sandbox. + Applies deterministic sanitization before compilation to fix + common LLM output artifacts (markdown fences, non-ASCII + operators, invalid literals, etc.). + Args: source: Python source containing ``generate_features``. @@ -240,8 +272,13 @@ def _compile_function( ValueError: If compilation fails or function not found. """ from loclean.utils.sandbox import compile_sandboxed + from loclean.utils.source_sanitizer import sanitize_source - return compile_sandboxed(source, "generate_features", ["math"]) + return compile_sandboxed( + sanitize_source(source), + "generate_features", + ["math", "statistics", "operator"], + ) # ------------------------------------------------------------------ # Verification @@ -312,12 +349,7 @@ def _repair_function( ) raw = self.inference_engine.generate(prompt) - repaired = str(raw).strip() - if repaired.startswith("```"): - lines = repaired.split("\n") - lines = [line for line in lines if not line.strip().startswith("```")] - repaired = "\n".join(lines) - return repaired + return str(raw).strip() # ------------------------------------------------------------------ # Application diff --git a/src/loclean/extraction/shredder.py b/src/loclean/extraction/shredder.py index eabdcd1..432e7d8 100644 --- a/src/loclean/extraction/shredder.py +++ b/src/loclean/extraction/shredder.py @@ -146,22 +146,40 @@ def shred( return self._separate_tables(results, schema, native_ns) source = self._generate_extractor(schema, samples) - extract_fn = self._compile_function(source) - ok, error = self._verify_function(extract_fn, samples, schema, self.timeout_s) - retries = 0 - while not ok and retries < self.max_retries: - source = self._repair_function(source, error, samples) + try: extract_fn = self._compile_function(source) ok, error = self._verify_function( extract_fn, samples, schema, self.timeout_s ) + except ValueError as exc: + ok, error = False, str(exc) + + retries = 0 + while not ok and retries < self.max_retries: retries += 1 + logger.warning( + f"[yellow]⚠[/yellow] Retrying code generation " + f"({retries}/{self.max_retries}): {error}" + ) + source = self._repair_function(source, error, samples) + try: + extract_fn = self._compile_function(source) + ok, error = self._verify_function( + extract_fn, samples, schema, self.timeout_s + ) + except ValueError as exc: + ok, error = False, str(exc) if not ok: logger.warning( - f"[yellow]⚠[/yellow] Code generation failed after " - f"{self.max_retries} retries: {error} — returning empty result" + f"[yellow]⚠[/yellow] The model could not generate valid Python " + f"code after {self.max_retries} retries. This is not a library " + f"bug — smaller models (e.g. phi3) sometimes produce syntax " + f"errors or invalid logic. Returning empty result.\n" + f" [dim]Last error: {error}[/dim]\n" + f" [dim]Tip: try a larger model " + f"(model='qwen2.5-coder:7b') or increase max_retries.[/dim]" ) return {} @@ -310,8 +328,26 @@ def _generate_extractor( f"Target tables:\n{table_specs}\n\n" "Sample log entries:\n" f"{json.dumps(samples[:5], ensure_ascii=False)}\n\n" + "EXAMPLE (for a different log format):\n\n" + "import re\n\n" + "def extract_relations(log: str) -> dict[str, dict]:\n" + " result = {}\n" + " try:\n" + " m = re.match(" + "r'(\\S+) (\\S+) \\[(.*?)\\] \"(\\S+)\"', log)\n" + " if m:\n" + " result['requests'] = {\n" + " 'ip': m.group(1),\n" + " 'method': m.group(4),\n" + " }\n" + " except Exception:\n" + " result['requests'] = " + "{'ip': '', 'method': ''}\n" + " return result\n\n" + "Now write yours for the log format above.\n\n" "Rules:\n" - "- Use ONLY standard library modules (re, string, etc.)\n" + "- Use ONLY standard library modules (re, json, " + "datetime, collections)\n" "- Wrap parsing logic in try/except blocks\n" "- Return empty strings for fields that cannot be parsed\n" "- Do NOT import any third-party libraries\n\n" @@ -319,12 +355,7 @@ def _generate_extractor( ) raw = self.inference_engine.generate(prompt) - source = str(raw).strip() - if source.startswith("```"): - lines = source.split("\n") - lines = [line for line in lines if not line.strip().startswith("```")] - source = "\n".join(lines) - return source + return str(raw).strip() @staticmethod def _compile_function( @@ -332,6 +363,10 @@ def _compile_function( ) -> Callable[[str], dict[str, dict[str, Any]]]: """Compile source code in a restricted sandbox. + Applies deterministic sanitization before compilation to fix + common LLM output artifacts (markdown fences, non-ASCII + operators, invalid literals, etc.). + Args: source: Python source containing ``extract_relations``. @@ -342,9 +377,10 @@ def _compile_function( ValueError: If compilation fails or function not found. """ from loclean.utils.sandbox import compile_sandboxed + from loclean.utils.source_sanitizer import sanitize_source return compile_sandboxed( - source, + sanitize_source(source), "extract_relations", ["re", "json", "datetime", "collections"], ) @@ -414,12 +450,7 @@ def _repair_function( ) raw = self.inference_engine.generate(prompt) - repaired = str(raw).strip() - if repaired.startswith("```"): - lines = repaired.split("\n") - lines = [line for line in lines if not line.strip().startswith("```")] - repaired = "\n".join(lines) - return repaired + return str(raw).strip() # ------------------------------------------------------------------ # Phase 3: Full execution + separation From d44cf11a0e98fa58c1e151dcee896f31cdb64e1d Mon Sep 17 00:00:00 2001 From: nxank4 Date: Fri, 27 Feb 2026 18:03:28 +0000 Subject: [PATCH 7/8] perf(inference): cache verified models to skip redundant Ollama checks - Add module-level _verified_models set for deduplication - Fix model_exists to handle both dict and object API responses - Use model attribute (not name) for correct Ollama registry matching --- src/loclean/inference/model_manager.py | 16 ++++++++++++++-- tests/unit/inference/test_model_manager.py | 5 +++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/loclean/inference/model_manager.py b/src/loclean/inference/model_manager.py index eb09f9b..559fa51 100644 --- a/src/loclean/inference/model_manager.py +++ b/src/loclean/inference/model_manager.py @@ -22,6 +22,8 @@ logger = configure_module_logger(__name__, level=logging.INFO) +_verified_models: set[str] = set() + def model_exists(client: Any, model: str) -> bool: """Check whether *model* is already available in the local Ollama registry. @@ -38,9 +40,14 @@ def model_exists(client: Any, model: str) -> bool: except Exception: return False - models = response.get("models", []) + models = getattr(response, "models", None) + if models is None: + models = response.get("models", []) if isinstance(response, dict) else [] + for entry in models: - name: str = entry.get("name", "") + name: str = getattr(entry, "model", None) or ( + entry.get("name", "") if isinstance(entry, dict) else "" + ) if name == model or name.startswith(f"{model}:"): return True return False @@ -64,11 +71,15 @@ def ensure_model( Raises: RuntimeError: If the pull fails or encounters an error status. """ + if model in _verified_models: + return + if model_exists(client, model): logger.info( f"[green]✓[/green] Model [bold cyan]{model}[/bold cyan] " "is already available." ) + _verified_models.add(model) return if console is None: @@ -114,3 +125,4 @@ def ensure_model( console.print( f"[green]✓[/green] Model [bold cyan]{model}[/bold cyan] pulled successfully." ) + _verified_models.add(model) diff --git a/tests/unit/inference/test_model_manager.py b/tests/unit/inference/test_model_manager.py index 248203c..100ca6a 100644 --- a/tests/unit/inference/test_model_manager.py +++ b/tests/unit/inference/test_model_manager.py @@ -6,6 +6,7 @@ import pytest from rich.console import Console +from loclean.inference import model_manager from loclean.inference.model_manager import ensure_model, model_exists @@ -48,6 +49,10 @@ def _make_test_console() -> Console: class TestEnsureModel: """Tests for ensure_model.""" + @pytest.fixture(autouse=True) + def _clear_cache(self) -> None: + model_manager._verified_models.clear() + @patch("loclean.inference.model_manager.model_exists", return_value=True) def test_model_already_exists_skips_pull(self, _mock_exists: MagicMock) -> None: client = MagicMock() From 888298a2237eb547412c1fb29316b934a31307c1 Mon Sep 17 00:00:00 2001 From: nxank4 Date: Fri, 27 Feb 2026 18:03:36 +0000 Subject: [PATCH 8/8] chore: remove LIBRARY_SUMMARY.md Superseded by examples/README.md and module docstrings. --- LIBRARY_SUMMARY.md | 462 --------------------------------------------- 1 file changed, 462 deletions(-) delete mode 100644 LIBRARY_SUMMARY.md diff --git a/LIBRARY_SUMMARY.md b/LIBRARY_SUMMARY.md deleted file mode 100644 index b656d2e..0000000 --- a/LIBRARY_SUMMARY.md +++ /dev/null @@ -1,462 +0,0 @@ -# Loclean — Library Technical Summary - -> **Version**: 0.2.2 | **Python**: ≥3.10 | **License**: Apache-2.0 -> **One-liner**: High-performance, local-first semantic data cleaning library powered by Ollama LLMs. - -## Purpose - -Loclean is an **AI-powered data cleaning and PII scrubbing library** that uses a locally-running [Ollama](https://ollama.com) instance for inference. It provides three core capabilities: - -1. **`clean()`** — Semantic column cleaning on DataFrames (extract numeric values + units from messy text) -2. **`scrub()`** — PII detection and masking/replacement in text or DataFrames -3. **`extract()`** — Structured data extraction from text using user-defined Pydantic schemas - -All DataFrame operations are **backend-agnostic** via [Narwhals](https://narwhals-dev.github.io/narwhals/), supporting pandas, Polars, PyArrow, cuDF, and Modin interchangeably. - ---- - -## Architecture Overview - -```mermaid -graph TB - subgraph "Public API (loclean/__init__.py)" - Loclean["Loclean class"] - clean["clean()"] - scrub["scrub()"] - extract["extract()"] - get_engine["get_engine()"] - end - - subgraph "Inference Layer" - ABC["InferenceEngine (ABC)"] - Ollama["OllamaEngine"] - Config["EngineConfig"] - Factory["create_engine()"] - end - - subgraph "Extraction Layer" - Extractor["Extractor"] - ExtractDF["extract_dataframe()"] - JsonRepair["json_repair"] - end - - subgraph "Privacy Layer" - PIIDetector["PIIDetector"] - RegexDet["RegexDetector"] - LLMDet["LLMDetector"] - Scrub["scrub_string() / scrub_dataframe()"] - FakeGen["FakeDataGenerator"] - end - - subgraph "Engine Layer" - NarwhalsEng["NarwhalsEngine"] - end - - subgraph "Shared" - Cache["LocleanCache (SQLite)"] - Schemas["Pydantic Schemas"] - end - - Loclean --> Ollama - Loclean --> Extractor - clean --> get_engine --> Ollama - clean --> NarwhalsEng - scrub --> Scrub --> PIIDetector - extract --> Extractor - extract --> ExtractDF - Extractor --> Ollama - Extractor --> JsonRepair - Extractor --> Cache - PIIDetector --> RegexDet - PIIDetector --> LLMDet --> Ollama - LLMDet --> Cache - Scrub --> FakeGen - NarwhalsEng --> Ollama - Factory --> Config - Factory --> Ollama - Ollama --> ABC -``` - ---- - -## File Hierarchy - -``` -loclean/ -├── pyproject.toml # Build config, deps, tool settings -├── src/loclean/ -│ ├── __init__.py # PUBLIC API: Loclean class, clean(), scrub(), extract(), get_engine() -│ ├── _version.py # __version__ = "0.2.2" -│ ├── cache.py # LocleanCache — SQLite3 persistent cache (WAL mode) -│ │ -│ ├── inference/ # ── Inference Engine Layer ── -│ │ ├── __init__.py # Re-exports: InferenceEngine, OllamaEngine, EngineConfig -│ │ ├── base.py # InferenceEngine ABC (generate, clean_batch) -│ │ ├── ollama_engine.py # OllamaEngine — Ollama HTTP client wrapper -│ │ ├── config.py # EngineConfig (Pydantic) + hierarchical config loader -│ │ ├── factory.py # create_engine() factory function -│ │ ├── schemas.py # ExtractionResult schema (reasoning/value/unit) -│ │ └── local/ # Reserved for future local engines -│ │ └── __init__.py -│ │ -│ ├── extraction/ # ── Structured Extraction Layer ── -│ │ ├── __init__.py # Re-exports: Extractor, extract_dataframe -│ │ ├── extractor.py # Extractor — prompt → generate → parse → validate → retry -│ │ ├── extract_dataframe.py # DataFrame column extraction (pandas/Polars) -│ │ └── json_repair.py # Heuristic JSON repair for malformed LLM output -│ │ -│ ├── privacy/ # ── PII Detection & Scrubbing Layer ── -│ │ ├── __init__.py # Re-exports: scrub_string, scrub_dataframe -│ │ ├── schemas.py # PIIEntity, PIIDetectionResult (Pydantic) -│ │ ├── detector.py # PIIDetector — hybrid router (regex + LLM) -│ │ ├── regex_detector.py # RegexDetector — email, phone, credit_card, ip_address -│ │ ├── llm_detector.py # LLMDetector — person, address (via engine.generate) -│ │ ├── scrub.py # scrub_string(), scrub_dataframe() + replace_entities() -│ │ └── generator.py # FakeDataGenerator (Faker) for "fake" mode -│ │ -│ ├── engine/ # ── DataFrame Processing Engine ── -│ │ └── narwhals_ops.py # NarwhalsEngine — batch processing, parallel, progress -│ │ -│ ├── cli/ # ── Command-Line Interface ── -│ │ ├── __init__.py # Typer app with "model" subgroup -│ │ ├── model.py # "model status" command -│ │ └── model_commands.py # check_connection() — Ollama connectivity check -│ │ -│ ├── utils/ # ── Utilities ── -│ │ ├── __init__.py -│ │ ├── logging.py # Rich-compatible module logger -│ │ ├── rich_output.py # Progress bars, tables, cache stats -│ │ └── resources.py # (Stub — grammar/template loaders removed) -│ │ -│ └── resources/ # ── Static Resources ── -│ └── __init__.py # (Empty — grammars/templates removed in migration) -│ -├── tests/ -│ ├── conftest.py # Shared fixtures -│ ├── unit/ # 318 tests — fast, isolated, mocked -│ │ ├── test_public_api.py # Loclean class + clean/scrub/extract functions -│ │ ├── test_cache.py # LocleanCache -│ │ ├── cli/ # CLI tests -│ │ │ ├── test_cli_init.py # App structure + routing -│ │ │ ├── test_model.py # Status command -│ │ │ └── test_model_commands.py # check_connection() -│ │ ├── inference/ # Inference tests -│ │ │ ├── test_base.py # ABC contract -│ │ │ ├── test_config.py # Config loading (env, pyproject, defaults) -│ │ │ ├── test_factory.py # Engine creation -│ │ │ └── test_schemas.py # ExtractionResult -│ │ ├── extraction/ # Extraction tests -│ │ │ ├── test_extractor.py # Extractor (37 tests) -│ │ │ ├── test_extract_dataframe.py # DataFrame extraction -│ │ │ └── test_json_repair.py # JSON repair -│ │ ├── privacy/ # Privacy tests -│ │ │ ├── test_detector.py # PIIDetector hybrid -│ │ │ ├── test_detector_functions.py # find_all_positions, resolve_overlaps -│ │ │ ├── test_llm_detector.py # LLMDetector (19 tests) -│ │ │ ├── test_regex_detector.py # RegexDetector -│ │ │ ├── test_schemas.py # PIIEntity, PIIDetectionResult -│ │ │ ├── test_scrub.py # scrub_string, scrub_dataframe -│ │ │ └── test_generator.py # FakeDataGenerator -│ │ ├── engine/ -│ │ │ └── test_narwhals_ops.py # NarwhalsEngine -│ │ └── utils/ -│ │ ├── test_logging.py -│ │ ├── test_rich_output.py -│ │ └── test_resources.py # (Stub) -│ ├── integration/ # Require live Ollama instance -│ │ ├── test_core.py -│ │ └── test_reasoning.py -│ └── scenarios/ # E2E + UX tests -│ ├── test_e2e_flows.py -│ ├── test_error_experience.py -│ └── test_ux_interface.py -│ -├── examples/ # Usage examples -├── docs-web/ # Documentation website -├── scripts/ # Build/CI scripts -├── .github/ # CI/CD workflows -└── .agent/workflows/ # Agent workflow definitions -``` - ---- - -## Core Components — Detailed Reference - -### 1. Public API (`__init__.py`) - -The module-level API is the primary entry point. All functions use a **singleton `OllamaEngine`** by default, or accept `model`/`host`/`verbose` overrides to create dedicated instances. - -| Symbol | Type | Purpose | -|--------|------|---------| -| `Loclean` | Class | OOP interface wrapping `OllamaEngine` + `Extractor` | -| `clean(df, col, instruction)` | Function | Semantic column cleaning → adds `clean_value`, `clean_unit`, `clean_reasoning` | -| `scrub(input, strategies, mode)` | Function | PII detection + masking/faking on text or DataFrame | -| `extract(input, schema)` | Function | Structured extraction via Pydantic schema | -| `get_engine()` | Function | Singleton `OllamaEngine` manager | - -**Key design**: `Loclean` class does lazy local imports (`Extractor`, `Scrub`, `BaseModel`) inside methods to keep import time fast. - ---- - -### 2. Inference Layer (`inference/`) - -#### `InferenceEngine` (ABC in `base.py`) -Two abstract methods every engine must implement: - -```python -class InferenceEngine(ABC): - @abstractmethod - def generate(self, prompt: str, schema: type[BaseModel] | None = None) -> str: ... - - @abstractmethod - def clean_batch(self, items: List[str], instruction: str) -> Dict[str, Optional[Dict[str, Any]]]: ... -``` - -#### `OllamaEngine` (`ollama_engine.py`) -- Connects to Ollama HTTP API via `ollama.Client(host=...)` -- Validates connection in `__init__` by calling `client.list()` -- `generate()` passes `schema.model_json_schema()` as the `format` kwarg to Ollama's `generate()` endpoint → Ollama constrains output to valid JSON -- `clean_batch()` iterates items, calls `generate()` with `ExtractionResult` schema, parses JSON - -#### `EngineConfig` (`config.py`) -Pydantic model with hierarchical config loading: - -``` -Priority: Runtime params > Env vars (LOCLEAN_*) > pyproject.toml [tool.loclean] > Defaults -``` - -| Field | Default | Env Var | -|-------|---------|---------| -| `engine` | `"ollama"` | `LOCLEAN_ENGINE` | -| `model` | `"phi3"` | `LOCLEAN_MODEL` | -| `host` | `"http://localhost:11434"` | `LOCLEAN_HOST` | -| `api_key` | `None` | `LOCLEAN_API_KEY` | -| `verbose` | `False` | `LOCLEAN_VERBOSE` | - -#### `create_engine()` (`factory.py`) -Factory that reads `EngineConfig.engine` and instantiates the correct backend. Only `"ollama"` is implemented; `"openai"`, `"anthropic"`, `"gemini"` raise `NotImplementedError`. - ---- - -### 3. Extraction Layer (`extraction/`) - -#### `Extractor` (`extractor.py`) -Core extraction class. Flow: - -``` -extract(text, schema, instruction?) - → _build_instruction(schema, instruction) - → check cache - → _extract_with_retry(text, schema, instruction, retry_count=0) - → build prompt: f"{instruction}\n\nInput: {text}" - → engine.generate(prompt, schema=schema) - → _parse_and_validate(raw_output, schema, ...) - → json.loads() or json_repair - → schema(**data) # Pydantic validation - → on failure → _retry_extraction (up to max_retries) - → cache result - → return validated BaseModel instance -``` - -Also has `extract_batch()` for processing lists with dedup + caching. - -#### `extract_dataframe()` (`extract_dataframe.py`) -Wraps `Extractor.extract_batch()` for DataFrame columns. Handles: -- Unique value deduplication -- Polars Struct columns vs pandas dicts -- `output_type="dict"` or `"pydantic"` - -#### `json_repair.py` -Heuristic JSON repair for truncated/malformed LLM output (bracket balancing, trailing comma removal). - ---- - -### 4. Privacy Layer (`privacy/`) - -#### Detection Architecture - -```mermaid -graph LR - PIIDetector --> RegexDetector - PIIDetector --> LLMDetector - RegexDetector -- "email, phone, credit_card, ip_address" --> PIIEntity - LLMDetector -- "person, address" --> PIIDetectionResult --> PIIEntity -``` - -**`PIIDetector`** (`detector.py`) is a hybrid router: -- **Regex strategies** (fast): `email`, `phone`, `credit_card`, `ip_address` → `RegexDetector` -- **LLM strategies** (accurate): `person`, `address` → `LLMDetector` → `engine.generate(prompt, PIIDetectionResult)` -- Merges results → `resolve_overlaps()` (longer match wins) - -#### Scrubbing - -**`scrub_string()`** and **`scrub_dataframe()`** in `scrub.py`: -- `mode="mask"` → replaces PII with `[TYPE]` (e.g., `[PERSON]`, `[EMAIL]`) -- `mode="fake"` → replaces with realistic fake data via `FakeDataGenerator` (requires `faker`, optional dep) - -#### Pydantic Schemas (`schemas.py`) - -```python -PIIType = Literal["person", "phone", "email", "credit_card", "address", "ip_address"] - -class PIIEntity(BaseModel): - type: PIIType - value: str - start: int - end: int - -class PIIDetectionResult(BaseModel): - entities: list[PIIEntity] - reasoning: str | None = None -``` - ---- - -### 5. DataFrame Engine (`engine/narwhals_ops.py`) - -**`NarwhalsEngine`** — static class for backend-agnostic batch processing: -- `process_column(df, col, engine, instruction, batch_size, parallel, max_workers)` -- Deduplicates unique values, chunks into batches -- Calls `engine.clean_batch()` per chunk -- Supports `ThreadPoolExecutor` parallel mode -- Rich progress bars via `utils/rich_output.py` -- Joins results back to original DataFrame via Narwhals - ---- - -### 6. Caching (`cache.py`) - -**`LocleanCache`** — SQLite3 persistent cache: -- Location: `~/.cache/loclean/cache.db` -- WAL mode for concurrent access -- Hash key: `SHA256("v3::{instruction}::{text}")` -- Used by both `Extractor` and `LLMDetector` -- Context manager support (`with LocleanCache() as cache`) - ---- - -### 7. CLI (`cli/`) - -Entry point: `loclean` (registered in `pyproject.toml` as script). - -``` -loclean -└── model - └── status [--host URL] # Check Ollama connection, list available models -``` - -`check_connection()` in `model_commands.py`: -- Connects via `ollama.Client(host=...)` (local import) -- Lists models in a Rich table -- Shows install instructions on failure - ---- - -## Dependencies - -### Core (required) -| Package | Purpose | -|---------|---------| -| `narwhals≥2.14.0` | Backend-agnostic DataFrame operations | -| `pydantic≥2.12.5` | Schema validation + JSON schema generation | -| `ollama≥0.4.0` | Ollama Python client | -| `json-repair≥0.27.0` | JSON repair for malformed LLM output | -| `typer≥0.12.0` | CLI framework | -| `rich≥14.0.0` | Terminal output formatting | - -### Optional extras -| Extra | Packages | Purpose | -|-------|----------|---------| -| `data` | pandas, polars, pyarrow | DataFrame backends | -| `cloud` | openai, anthropic, google-genai, instructor | Future cloud engines | -| `privacy` | faker | Fake data generation for PII replacement | -| `all` | All of the above | Everything | - ---- - -## Configuration - -### Hierarchical Priority -``` -1. Runtime params (model=, host=, verbose=) -2. Environment variables (LOCLEAN_ENGINE, LOCLEAN_MODEL, LOCLEAN_HOST, etc.) -3. pyproject.toml [tool.loclean] section -4. Hardcoded defaults (engine=ollama, model=phi3, host=localhost:11434) -``` - -### pyproject.toml example -```toml -[tool.loclean] -engine = "ollama" -model = "llama3" -host = "http://remote-server:11434" -verbose = true -``` - ---- - -## Development & Tooling - -### Setup -```bash -uv sync --all-extras --dev # Install all deps -ollama serve # Start Ollama externally -ollama pull phi3 # Pull default model -``` - -### PR Readiness Checklist -```bash -uv run ruff format . # Format -uv run ruff check . --fix # Lint -uv run mypy . # Type check -uv run python -m pytest # Test (318 unit tests) -``` - -### Test Configuration -- **Framework**: pytest + pytest-cov + pytest-mock -- **Config**: `pyproject.toml` `[tool.pytest.ini_options]` -- **Coverage**: Branch coverage, fail-under 50%, XML report -- **Markers**: `slow`, `cloud` - -### Key Design Rules (from user guidelines) -1. **Use `uv`** for all Python/pip operations -2. **Use Narwhals** for all DataFrame ops — never import pandas/polars in core logic -3. **Optional deps** wrapped in `try/except ImportError` -4. **Never commit to main** — use feature branches + PRs -5. **Atomic commits** — small, logical chunks - ---- - -## Data Flow Examples - -### `clean()` flow -``` -DataFrame → Narwhals wraps → deduplicate unique values → chunk into batches -→ OllamaEngine.clean_batch(items, instruction) per batch - → for each item: generate(prompt, schema=ExtractionResult) - → Ollama returns JSON: {"reasoning": "...", "value": 5.5, "unit": "kg"} -→ join results back to DataFrame → return native DataFrame -``` - -### `extract()` flow -``` -text + Pydantic schema → Extractor -→ build instruction from schema fields -→ check LocleanCache -→ engine.generate(prompt, schema=UserSchema) -→ Ollama returns constrained JSON -→ json.loads() → schema(**data) → Pydantic validates -→ on failure: json_repair → retry with adjusted prompt (up to 3x) -→ cache result → return validated BaseModel instance -``` - -### `scrub()` flow -``` -text + strategies=["person", "email", "phone"] -→ PIIDetector.detect(text, strategies) - → RegexDetector: email patterns, phone patterns - → LLMDetector: engine.generate(prompt, schema=PIIDetectionResult) - → merge + resolve_overlaps() -→ replace_entities(text, entities, mode="mask") -→ "Contact [PERSON] at [EMAIL] or [PHONE]" -```