diff --git a/src/loclean/__init__.py b/src/loclean/__init__.py index a684429..96d2222 100644 --- a/src/loclean/__init__.py +++ b/src/loclean/__init__.py @@ -11,11 +11,16 @@ "__version__", "Loclean", "clean", + "discover_features", "extract", "extract_compiled", "get_engine", "optimize_instruction", + "oversample", + "resolve_entities", "scrub", + "shred_to_relations", + "validate_quality", ] _ENGINE_INSTANCE: Optional[OllamaEngine] = None @@ -56,11 +61,48 @@ def get_engine( return OllamaEngine(**kwargs) +def _resolve_engine( + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> OllamaEngine: + """Return a shared or custom OllamaEngine instance. + + When all arguments are ``None`` and there are no extra kwargs, + returns the process-wide singleton. Otherwise builds a fresh + client with the supplied overrides. + + Args: + model: Ollama model tag. + host: Ollama server URL. + verbose: Enable detailed logging. + **engine_kwargs: Forwarded to ``OllamaEngine``. + + Returns: + OllamaEngine instance. + """ + if model is None and host is None and verbose is None and not engine_kwargs: + return get_engine() + + kwargs: dict[str, Any] = {} + if model is not None: + kwargs["model"] = model + if host is not None: + kwargs["host"] = host + if verbose is not None: + kwargs["verbose"] = verbose + kwargs.update(engine_kwargs) + return OllamaEngine(**kwargs) + + class Loclean: """Primary user-facing API for structured data extraction via Ollama. Connects to a running Ollama instance and uses Pydantic schemas to - enforce structured JSON output from LLMs. + enforce structured JSON output from LLMs. A single ``OllamaEngine`` + instance is shared across every wrapper method, preventing redundant + network sockets and reducing memory overhead. 
Example:: @@ -97,6 +139,10 @@ def __init__( """ self.engine = OllamaEngine(model=model, host=host, verbose=verbose) + # ------------------------------------------------------------------ + # Core extraction + # ------------------------------------------------------------------ + def extract( self, text: str, @@ -130,6 +176,179 @@ def extract( extractor = Extractor(inference_engine=self.engine, max_retries=max_retries) return extractor.extract(text, schema, instruction) + # ------------------------------------------------------------------ + # Advanced capabilities + # ------------------------------------------------------------------ + + def clean( + self, + df: IntoFrameT, + target_col: str, + instruction: str = "Extract the numeric value and unit as-is.", + *, + batch_size: int = 50, + parallel: bool = False, + max_workers: Optional[int] = None, + ) -> IntoFrameT: + """Clean a column in a DataFrame using semantic extraction. + + Args: + df: Input DataFrame. + target_col: Column to clean. + instruction: Instruction to guide the LLM. + batch_size: Unique values per batch. + parallel: Enable parallel processing. + max_workers: Worker threads for parallel processing. + + Returns: + DataFrame with added cleaning columns. + """ + return NarwhalsEngine.process_column( + df, + target_col, + self.engine, + instruction, + batch_size=batch_size, + parallel=parallel, + max_workers=max_workers, + ) + + def resolve_entities( + self, + df: IntoFrameT, + target_col: str, + *, + threshold: float = 0.8, + ) -> IntoFrameT: + """Canonicalize a messy string column via entity resolution. + + Args: + df: Input DataFrame. + target_col: Column with messy string values. + threshold: Semantic-distance threshold ε in ``(0, 1]``. + + Returns: + DataFrame with an added ``{target_col}_canonical`` column. 
+ """ + from loclean.extraction.resolver import EntityResolver + + resolver = EntityResolver(inference_engine=self.engine, threshold=threshold) + return resolver.resolve(df, target_col) + + def oversample( + self, + df: IntoFrameT, + target_col: str, + target_value: Any, + n: int, + schema: type, + *, + batch_size: int = 10, + ) -> IntoFrameT: + """Generate synthetic minority-class records. + + Args: + df: Input DataFrame. + target_col: Column identifying the class label. + target_value: Minority class value. + n: Number of synthetic records. + schema: Pydantic model defining record structure. + batch_size: Records per LLM batch. + + Returns: + DataFrame with synthetic records appended. + """ + from loclean.extraction.oversampler import SemanticOversampler + + sampler = SemanticOversampler( + inference_engine=self.engine, batch_size=batch_size + ) + return sampler.oversample(df, target_col, target_value, n, schema) + + def shred_to_relations( + self, + df: IntoFrameT, + target_col: str, + *, + sample_size: int = 30, + max_retries: int = 3, + ) -> dict[str, Any]: + """Shred a log column into relational DataFrames. + + Args: + df: Input DataFrame. + target_col: Column with unstructured log text. + sample_size: Entries to sample for inference. + max_retries: Repair budget for code generation. + + Returns: + Dict mapping table names to native DataFrames. + """ + from loclean.extraction.shredder import RelationalShredder + + shredder = RelationalShredder( + inference_engine=self.engine, + sample_size=sample_size, + max_retries=max_retries, + ) + return shredder.shred(df, target_col) + + def discover_features( + self, + df: IntoFrameT, + target_col: str, + *, + n_features: int = 5, + max_retries: int = 3, + ) -> IntoFrameT: + """Discover and apply feature crosses. + + Args: + df: Input DataFrame. + target_col: Target variable column. + n_features: Number of features to propose. + max_retries: Repair budget for code generation. 
+ + Returns: + DataFrame augmented with new feature columns. + """ + from loclean.extraction.feature_discovery import FeatureDiscovery + + discoverer = FeatureDiscovery( + inference_engine=self.engine, + n_features=n_features, + max_retries=max_retries, + ) + return discoverer.discover(df, target_col) + + def validate_quality( + self, + df: IntoFrameT, + rules: list[str], + *, + batch_size: int = 20, + sample_size: int = 100, + ) -> dict[str, Any]: + """Evaluate data quality against natural-language rules. + + Args: + df: Input DataFrame. + rules: Natural-language constraint strings. + batch_size: Rows per processing batch. + sample_size: Maximum rows to evaluate. + + Returns: + Dict with compliance report. + """ + from loclean.validation.quality_gate import QualityGate + + gate = QualityGate( + inference_engine=self.engine, + batch_size=batch_size, + sample_size=sample_size, + ) + return gate.evaluate(df, rules) + def clean( df: IntoFrameT, @@ -170,18 +389,7 @@ def clean( if target_col not in df_nw.columns: raise ValueError(f"Column '{target_col}' not found in DataFrame") - if model is None and host is None and verbose is None and not engine_kwargs: - engine = get_engine() - else: - kwargs: dict[str, Any] = {} - if model is not None: - kwargs["model"] = model - if host is not None: - kwargs["host"] = host - if verbose is not None: - kwargs["verbose"] = verbose - kwargs.update(engine_kwargs) - engine = OllamaEngine(**kwargs) + engine = _resolve_engine(model, host, verbose, **engine_kwargs) return NarwhalsEngine.process_column( df, @@ -234,18 +442,7 @@ def scrub( inference_engine = None if needs_llm: - if model is None and host is None and verbose is None and not engine_kwargs: - inference_engine = get_engine() - else: - kwargs_filtered: dict[str, Any] = {} - if model is not None: - kwargs_filtered["model"] = model - if host is not None: - kwargs_filtered["host"] = host - if verbose is not None: - kwargs_filtered["verbose"] = verbose - 
kwargs_filtered.update(engine_kwargs) - inference_engine = OllamaEngine(**kwargs_filtered) + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) if isinstance(input_data, str): return scrub_string( @@ -313,18 +510,7 @@ def extract( from loclean.extraction.extract_dataframe import extract_dataframe from loclean.extraction.extractor import Extractor - if model is None and host is None and verbose is None and not engine_kwargs: - inference_engine = get_engine() - else: - kwargs_filtered: dict[str, Any] = {} - if model is not None: - kwargs_filtered["model"] = model - if host is not None: - kwargs_filtered["host"] = host - if verbose is not None: - kwargs_filtered["verbose"] = verbose - kwargs_filtered.update(engine_kwargs) - inference_engine = OllamaEngine(**kwargs_filtered) + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) cache = LocleanCache() @@ -393,18 +579,7 @@ def extract_compiled( from loclean.extraction.extract_dataframe import extract_dataframe_compiled - if model is None and host is None and verbose is None and not engine_kwargs: - inference_engine = get_engine() - else: - kwargs_filtered: dict[str, Any] = {} - if model is not None: - kwargs_filtered["model"] = model - if host is not None: - kwargs_filtered["host"] = host - if verbose is not None: - kwargs_filtered["verbose"] = verbose - kwargs_filtered.update(engine_kwargs) - inference_engine = OllamaEngine(**kwargs_filtered) + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) return extract_dataframe_compiled( df, @@ -461,18 +636,7 @@ def optimize_instruction( from loclean.extraction.optimizer import InstructionOptimizer - if model is None and host is None and verbose is None and not engine_kwargs: - inference_engine = get_engine() - else: - kwargs_filtered: dict[str, Any] = {} - if model is not None: - kwargs_filtered["model"] = model - if host is not None: - kwargs_filtered["host"] = host - if verbose is not None: - 
kwargs_filtered["verbose"] = verbose - kwargs_filtered.update(engine_kwargs) - inference_engine = OllamaEngine(**kwargs_filtered) + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) optimizer = InstructionOptimizer( inference_engine=inference_engine, @@ -485,3 +649,207 @@ def optimize_instruction( baseline_instruction=baseline_instruction, sample_size=sample_size, ) + + +def resolve_entities( + df: IntoFrameT, + target_col: str, + *, + threshold: float = 0.8, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> IntoFrameT: + """Canonicalize a messy string column via semantic entity resolution. + + Groups similar string variations under a single authoritative label + using the local Ollama engine. A new ``{target_col}_canonical`` + column is appended to the returned DataFrame. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column containing messy string values. + threshold: Semantic-distance threshold ε in ``(0, 1]``. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. + **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + DataFrame with an added ``{target_col}_canonical`` column. + """ + from loclean.extraction.resolver import EntityResolver + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + resolver = EntityResolver( + inference_engine=inference_engine, + threshold=threshold, + ) + return resolver.resolve(df, target_col) + + +def validate_quality( + df: IntoFrameT, + rules: list[str], + *, + batch_size: int = 20, + sample_size: int = 100, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> dict[str, Any]: + """Evaluate data quality against natural-language rules. 
+ + Checks sampled rows for compliance and returns a structured + report with compliance rate and per-failure reasoning. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + rules: Natural-language constraint strings. + batch_size: Rows per processing batch. + sample_size: Maximum rows to evaluate. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. + **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + Dictionary with ``total_rows``, ``passed_rows``, + ``compliance_rate``, and ``failures``. + """ + from loclean.validation.quality_gate import QualityGate + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + gate = QualityGate( + inference_engine=inference_engine, + batch_size=batch_size, + sample_size=sample_size, + ) + return gate.evaluate(df, rules) + + +def oversample( + df: IntoFrameT, + target_col: str, + target_value: Any, + n: int, + schema: type, + *, + batch_size: int = 10, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> IntoFrameT: + """Generate synthetic minority-class records and append them. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column identifying the class label. + target_value: Value of the minority class to oversample. + n: Number of synthetic records to generate. + schema: Pydantic model defining the record structure. + batch_size: Records per LLM generation batch. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. + **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + DataFrame with synthetic records appended. 
+ """ + from loclean.extraction.oversampler import SemanticOversampler + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + sampler = SemanticOversampler( + inference_engine=inference_engine, + batch_size=batch_size, + ) + return sampler.oversample(df, target_col, target_value, n, schema) + + +def shred_to_relations( + df: IntoFrameT, + target_col: str, + *, + sample_size: int = 30, + max_retries: int = 3, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> dict[str, Any]: + """Shred an unstructured log column into relational DataFrames. + + Uses the Ollama engine to infer a relational schema, generate + a parsing function, and separate the column into multiple tables. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column containing unstructured log text. + sample_size: Number of entries to sample for inference. + max_retries: Repair budget for code generation. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. + **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + Dictionary mapping table names to native DataFrames. + """ + from loclean.extraction.shredder import RelationalShredder + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + shredder = RelationalShredder( + inference_engine=inference_engine, + sample_size=sample_size, + max_retries=max_retries, + ) + return shredder.shred(df, target_col) + + +def discover_features( + df: IntoFrameT, + target_col: str, + *, + n_features: int = 5, + max_retries: int = 3, + model: Optional[str] = None, + host: Optional[str] = None, + verbose: Optional[bool] = None, + **engine_kwargs: Any, +) -> IntoFrameT: + """Discover and apply feature crosses to a DataFrame. 
+ + Uses the Ollama engine to propose mathematical transformations + that maximise mutual information with the target variable. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column name of the target variable. + n_features: Number of new features to propose. + max_retries: Repair budget for code generation. + model: Optional Ollama model tag override. + host: Optional Ollama server URL override. + verbose: Enable detailed logging. + **engine_kwargs: Additional arguments forwarded to OllamaEngine. + + Returns: + DataFrame augmented with new feature columns. + """ + from loclean.extraction.feature_discovery import FeatureDiscovery + + inference_engine = _resolve_engine(model, host, verbose, **engine_kwargs) + + discoverer = FeatureDiscovery( + inference_engine=inference_engine, + n_features=n_features, + max_retries=max_retries, + ) + return discoverer.discover(df, target_col) diff --git a/src/loclean/cache.py b/src/loclean/cache.py index fd5a7b9..1d4df57 100644 --- a/src/loclean/cache.py +++ b/src/loclean/cache.py @@ -51,6 +51,13 @@ def _init_db(self) -> None: last_access TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); """) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS code_cache ( + hash_key TEXT PRIMARY KEY, + source_code TEXT NOT NULL, + last_access TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ); + """) self.conn.commit() def _hash(self, text: str, instruction: str) -> str: @@ -149,6 +156,52 @@ def set_batch( except Exception as e: logger.error(f"Error writing to cache: {e}") + def get_code(self, key: str) -> Optional[str]: + """Retrieve cached source code by hash key. + + Args: + key: SHA256 hash key. + + Returns: + Source code string if found, ``None`` on miss. 
+ """ + cursor = self.conn.cursor() + try: + cursor.execute( + "SELECT source_code FROM code_cache WHERE hash_key = ?", + (key,), + ) + row = cursor.fetchone() + if row is None: + return None + cursor.execute( + "UPDATE code_cache SET last_access = CURRENT_TIMESTAMP " + "WHERE hash_key = ?", + (key,), + ) + return row[0] # type: ignore[no-any-return] + except Exception as e: + logger.error(f"Error reading code cache: {e}") + return None + + def set_code(self, key: str, source: str) -> None: + """Store source code in the cache. + + Args: + key: SHA256 hash key. + source: Python source code string. + """ + cursor = self.conn.cursor() + try: + cursor.execute( + "INSERT OR REPLACE INTO code_cache " + "(hash_key, source_code) VALUES (?, ?)", + (key, source), + ) + self.conn.commit() + except Exception as e: + logger.error(f"Error writing code cache: {e}") + def close(self) -> None: """Close the database connection.""" self.conn.close() diff --git a/src/loclean/extraction/__init__.py b/src/loclean/extraction/__init__.py index 3cf9e87..64e9d66 100644 --- a/src/loclean/extraction/__init__.py +++ b/src/loclean/extraction/__init__.py @@ -1,7 +1,43 @@ """Extraction module for structured data extraction using Pydantic schemas.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + from .extract_dataframe import extract_dataframe_compiled from .extractor import Extractor -from .optimizer import InstructionOptimizer -__all__ = ["Extractor", "InstructionOptimizer", "extract_dataframe_compiled"] +if TYPE_CHECKING: + from .feature_discovery import FeatureDiscovery + from .optimizer import InstructionOptimizer + from .oversampler import SemanticOversampler + from .resolver import EntityResolver + from .shredder import RelationalShredder + +__all__ = [ + "EntityResolver", + "Extractor", + "FeatureDiscovery", + "InstructionOptimizer", + "RelationalShredder", + "SemanticOversampler", + "extract_dataframe_compiled", +] + +_LAZY_IMPORTS: dict[str, str] = { + 
"EntityResolver": ".resolver", + "FeatureDiscovery": ".feature_discovery", + "InstructionOptimizer": ".optimizer", + "RelationalShredder": ".shredder", + "SemanticOversampler": ".oversampler", +} + + +def __getattr__(name: str) -> object: + module_path = _LAZY_IMPORTS.get(name) + if module_path is not None: + import importlib + + module = importlib.import_module(module_path, __name__) + return getattr(module, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/loclean/extraction/feature_discovery.py b/src/loclean/extraction/feature_discovery.py new file mode 100644 index 0000000..6aaa939 --- /dev/null +++ b/src/loclean/extraction/feature_discovery.py @@ -0,0 +1,375 @@ +"""Generative feature cross discovery via LLM-driven transformation proposals. + +Automates feature engineering by prompting the Ollama engine to propose +mathematical transformations between existing columns that maximise +mutual information :math:`I(X_{new}; Y)` with the target variable. +The proposed function is compiled via ``exec`` and applied natively +across the Narwhals DataFrame. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any, Callable + +import narwhals as nw + +from loclean.utils.cache_keys import compute_code_key +from loclean.utils.logging import configure_module_logger + +if TYPE_CHECKING: + from narwhals.typing import IntoFrameT + + from loclean.cache import LocleanCache + from loclean.inference.base import InferenceEngine + +logger = configure_module_logger(__name__, level=logging.INFO) + + +class FeatureDiscovery: + """Propose and compile feature crosses using an LLM. + + Extracts structural metadata (column names, dtypes, sample rows) + and prompts the engine to write a ``generate_features`` function + that produces *n_features* new columns from existing ones. + + Args: + inference_engine: Engine for generative requests. + n_features: Number of new features to propose. 
+ max_retries: Repair budget for the compilation loop. + """ + + def __init__( + self, + inference_engine: InferenceEngine, + n_features: int = 5, + max_retries: int = 3, + timeout_s: float = 2.0, + cache: LocleanCache | None = None, + ) -> None: + if n_features < 1: + raise ValueError("n_features must be ≥ 1") + if max_retries < 1: + raise ValueError("max_retries must be ≥ 1") + if timeout_s <= 0: + raise ValueError("timeout_s must be > 0") + self.inference_engine = inference_engine + self.n_features = n_features + self.max_retries = max_retries + self.timeout_s = timeout_s + self.cache = cache + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def discover( + self, + df: IntoFrameT, + target_col: str, + ) -> IntoFrameT: + """Discover and apply feature crosses to the DataFrame. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column name of the target variable. + + Returns: + DataFrame augmented with new feature columns, or the + original DataFrame unchanged if generation fails. 
+ """ + df_nw = nw.from_native(df) # type: ignore[type-var] + if target_col not in df_nw.columns: + raise ValueError(f"Column '{target_col}' not found in DataFrame") + + state = self._extract_state(df_nw, target_col) + + cache_key = compute_code_key( + columns=state["columns"], + dtypes=list(state["dtypes"].values()), + target_col=target_col, + module_prefix="feature_discovery", + ) + + cached_source = self.cache.get_code(cache_key) if self.cache else None + if cached_source is not None: + fn = self._compile_function(cached_source) + sample_rows = state["sample_rows"] + ok, _ = self._verify_function(fn, sample_rows, self.timeout_s) + if ok: + logger.info("[green]✓[/green] Cache hit — reusing compiled features") + result = self._apply_to_dataframe(df_nw, fn, self.timeout_s) + return result.to_native() # type: ignore[no-any-return,return-value] + + source = self._propose_features(state) + fn = self._compile_function(source) + + sample_rows = state["sample_rows"] + ok, error = self._verify_function(fn, sample_rows, self.timeout_s) + retries = 0 + while not ok and retries < self.max_retries: + source = self._repair_function(source, error, state) + fn = self._compile_function(source) + ok, error = self._verify_function(fn, sample_rows, self.timeout_s) + retries += 1 + + if not ok: + logger.warning( + f"[yellow]⚠[/yellow] Feature generation failed after " + f"{self.max_retries} retries: {error} — returning original DataFrame" + ) + return df + + if self.cache: + self.cache.set_code(cache_key, source) + + result = self._apply_to_dataframe(df_nw, fn, self.timeout_s) + + logger.info( + "[green]✓[/green] Discovered and applied " + f"[bold]{self.n_features}[/bold] new feature columns" + ) + + return result.to_native() # type: ignore[no-any-return,return-value] + + # ------------------------------------------------------------------ + # State preparation + # ------------------------------------------------------------------ + + @staticmethod + def _extract_state( + df_nw: 
nw.DataFrame[Any], + target_col: str, + sample_n: int = 10, + ) -> dict[str, Any]: + """Build structural metadata for prompting. + + Args: + df_nw: Narwhals DataFrame. + target_col: Target variable column. + sample_n: Number of sample rows to include. + + Returns: + Dict with ``columns``, ``dtypes``, ``target_col``, + and ``sample_rows``. + """ + columns = df_nw.columns + dtypes = {col: str(df_nw[col].dtype) for col in columns} + all_rows: list[dict[str, Any]] = df_nw.rows(named=True) # type: ignore[assignment] + + if len(all_rows) <= sample_n: + sample_rows = all_rows + else: + step = len(all_rows) / sample_n + sample_rows = [all_rows[int(i * step)] for i in range(sample_n)] + + return { + "columns": columns, + "dtypes": dtypes, + "target_col": target_col, + "sample_rows": sample_rows, + } + + # ------------------------------------------------------------------ + # Generative proposal + # ------------------------------------------------------------------ + + def _propose_features(self, state: dict[str, Any]) -> str: + """Prompt the engine to write a generate_features function. + + Args: + state: Structural metadata from _extract_state. + + Returns: + Python source code string. 
+ """ + col_info = json.dumps(state["dtypes"], indent=2) + samples = json.dumps(state["sample_rows"][:5], ensure_ascii=False, default=str) + + prompt = ( + "You are an expert feature engineer.\n\n" + "Given a dataset with these columns and types:\n" + f"{col_info}\n\n" + f"Target variable: {state['target_col']}\n\n" + "Sample rows:\n" + f"{samples}\n\n" + f"Propose exactly {self.n_features} mathematical " + "transformations between existing columns that would " + "maximise mutual information I(X_new; Y) with the target.\n\n" + "Write a pure Python function with this exact signature:\n\n" + "def generate_features(row: dict) -> dict:\n\n" + "The function must:\n" + "- Accept a dict of column_name: value pairs\n" + f"- Return a dict with exactly {self.n_features} new " + "key-value pairs (the new feature names and values)\n" + "- Use ONLY standard library modules (math, etc.)\n" + "- Wrap each calculation in try/except, defaulting to " + "None on failure\n" + "- Use descriptive feature names like 'ratio_a_b' or " + "'log_amount'\n\n" + "Return ONLY the Python function code. " + "No markdown fences, no explanations." + ) + + raw = self.inference_engine.generate(prompt) + source = str(raw).strip() + if source.startswith("```"): + lines = source.split("\n") + lines = [line for line in lines if not line.strip().startswith("```")] + source = "\n".join(lines) + return source + + # ------------------------------------------------------------------ + # Compilation + # ------------------------------------------------------------------ + + @staticmethod + def _compile_function( + source: str, + ) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Compile source code in a restricted sandbox. + + Args: + source: Python source containing ``generate_features``. + + Returns: + The compiled function. + + Raises: + ValueError: If compilation fails or function not found. 
+ """ + from loclean.utils.sandbox import compile_sandboxed + + return compile_sandboxed(source, "generate_features", ["math"]) + + # ------------------------------------------------------------------ + # Verification + # ------------------------------------------------------------------ + + @staticmethod + def _verify_function( + fn: Callable[[dict[str, Any]], dict[str, Any]], + sample_rows: list[dict[str, Any]], + timeout_s: float = 2.0, + ) -> tuple[bool, str]: + """Test the compiled function against sample rows with timeout. + + Args: + fn: Compiled feature generation function. + sample_rows: Rows to test. + timeout_s: Maximum seconds per row execution. + + Returns: + Tuple of (success, error_message). + """ + from loclean.utils.sandbox import run_with_timeout + + test_rows = sample_rows[:5] + + for row in test_rows: + result, error = run_with_timeout(fn, (row,), timeout_s) + + if error: + return False, (f"Function failed for row {str(row)[:100]}: {error}") + + if not isinstance(result, dict): + return False, (f"Expected dict return, got {type(result).__name__}") + + if not result: + return False, "Function returned empty dict" + + return True, "" + + # ------------------------------------------------------------------ + # Repair + # ------------------------------------------------------------------ + + def _repair_function( + self, + source: str, + error: str, + state: dict[str, Any], + ) -> str: + """Ask the engine to fix a broken feature function. + + Args: + source: Current source code. + error: Error message from verification. + state: Structural metadata for context. + + Returns: + Repaired Python source code string. + """ + samples = json.dumps(state["sample_rows"][:3], ensure_ascii=False, default=str) + prompt = ( + "The following Python function has a bug.\n\n" + f"Source:\n{source}\n\n" + f"Error:\n{error}\n\n" + f"Sample input rows:\n{samples}\n\n" + "Fix the function. Return ONLY the corrected Python code, " + "no markdown fences." 
+ ) + + raw = self.inference_engine.generate(prompt) + repaired = str(raw).strip() + if repaired.startswith("```"): + lines = repaired.split("\n") + lines = [line for line in lines if not line.strip().startswith("```")] + repaired = "\n".join(lines) + return repaired + + # ------------------------------------------------------------------ + # Application + # ------------------------------------------------------------------ + + @staticmethod + def _apply_to_dataframe( + df_nw: nw.DataFrame[Any], + fn: Callable[[dict[str, Any]], dict[str, Any]], + timeout_s: float = 2.0, + ) -> nw.DataFrame[Any]: + """Map the feature function across all rows with timeout. + + Args: + df_nw: Narwhals DataFrame. + fn: Compiled feature generation function. + timeout_s: Maximum seconds per row execution. + + Returns: + Augmented Narwhals DataFrame with new feature columns. + """ + from loclean.utils.sandbox import run_with_timeout + + rows: list[dict[str, Any]] = df_nw.rows(named=True) # type: ignore[assignment] + + new_col_data: dict[str, list[Any]] = {} + + for row in rows: + result, error = run_with_timeout(fn, (row,), timeout_s) + features: dict[str, Any] = result if isinstance(result, dict) else {} + + if error: + logger.debug(f"Row execution failed: {error}") + + if not new_col_data: + for key in features: + new_col_data[key] = [] + + for key in new_col_data: + new_col_data[key].append(features.get(key)) + + if not new_col_data: + return df_nw + + native_ns = nw.get_native_namespace(df_nw) + new_df = nw.from_dict(new_col_data, backend=native_ns) + + original_rows: list[dict[str, Any]] = df_nw.rows(named=True) # type: ignore[assignment] + combined_data: dict[str, list[Any]] = { + col: [r[col] for r in original_rows] for col in df_nw.columns + } + for col in new_col_data: + combined_data[col] = new_df[col].to_list() + + return nw.from_dict(combined_data, backend=native_ns) diff --git a/src/loclean/extraction/oversampler.py b/src/loclean/extraction/oversampler.py new file mode 100644 
index 0000000..1b6088d --- /dev/null +++ b/src/loclean/extraction/oversampler.py @@ -0,0 +1,333 @@ +"""Semantic synthetic oversampling via LLM-driven record generation. + +Replaces geometric interpolation (SMOTE) with generative modelling. +The :class:`SemanticOversampler` produces structurally valid minority-class +records that satisfy a user-provided Pydantic schema and maintain the +logical correlations present in the original data sample. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any, Type + +import narwhals as nw +from narwhals.typing import IntoFrameT +from pydantic import BaseModel, Field, ValidationError + +from loclean.extraction.json_repair import repair_json +from loclean.utils.logging import configure_module_logger + +if TYPE_CHECKING: + from loclean.inference.base import InferenceEngine + +logger = configure_module_logger(__name__, level=logging.INFO) + + +class _GeneratedBatch(BaseModel): + """Wrapper schema for batch-generated records.""" + + records: list[dict[str, Any]] = Field( + ..., + description="List of generated records matching the target schema", + ) + + +class SemanticOversampler: + """Generate synthetic minority-class records using an LLM. + + Instead of computing $x_{new} = x_i + \\lambda(x_{neighbor} - x_i)$, + uses semantic generation to produce records that satisfy structural + dependencies found in a valid data sample. + + Args: + inference_engine: Engine used for generative requests. + batch_size: Maximum records to request per LLM call. + max_retries: Maximum generation rounds before giving up. 
+ """ + + def __init__( + self, + inference_engine: InferenceEngine, + batch_size: int = 10, + max_retries: int = 5, + ) -> None: + if batch_size < 1: + raise ValueError("batch_size must be ≥ 1") + if max_retries < 1: + raise ValueError("max_retries must be ≥ 1") + self.inference_engine = inference_engine + self.batch_size = batch_size + self.max_retries = max_retries + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def oversample( + self, + df: IntoFrameT, + target_col: str, + target_value: Any, + n: int, + schema: Type[BaseModel], + ) -> IntoFrameT: + """Generate *n* synthetic minority-class records and append them. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column identifying the class label. + target_value: Value of the minority class to oversample. + n: Number of synthetic records to generate. + schema: Pydantic model defining the record structure. + + Returns: + DataFrame of the same native type with synthetic rows appended. + + Raises: + ValueError: If *target_col* is missing or no minority rows exist. 
+ """ + df_nw = nw.from_native(df) # type: ignore[type-var] + if target_col not in df_nw.columns: + raise ValueError(f"Column '{target_col}' not found in DataFrame") + + sample_rows = self._sample_rows(df_nw, target_col, target_value) + if not sample_rows: + raise ValueError(f"No rows found where '{target_col}' == {target_value!r}") + + existing_keys = self._build_key_set(df_nw) + generated: list[dict[str, Any]] = [] + retries = 0 + + while len(generated) < n and retries < self.max_retries: + remaining = n - len(generated) + request_count = min(remaining, self.batch_size) + + batch = self._generate_batch(sample_rows, schema, request_count) + + validated = self._validate_and_filter( + batch, schema, existing_keys, generated + ) + generated.extend(validated) + retries += 1 + + if len(generated) < n: + logger.warning( + f"[yellow]⚠[/yellow] Generated {len(generated)}/{n} " + f"records after {self.max_retries} retries" + ) + + if not generated: + return df_nw.to_native() # type: ignore[return-value] + + native_ns = nw.get_native_namespace(df_nw) + col_data: dict[str, list[Any]] = {col: [] for col in df_nw.columns} + for rec in generated: + for col in df_nw.columns: + col_data[col].append(rec.get(col)) + + synthetic_nw = nw.from_dict(col_data, backend=native_ns) + result = nw.concat([df_nw, synthetic_nw]) + + logger.info( + f"[green]✓[/green] Oversampled " + f"[bold]{len(generated)}[/bold] synthetic records " + f"for '{target_col}' == {target_value!r}" + ) + + return result.to_native() # type: ignore[no-any-return,return-value] + + # ------------------------------------------------------------------ + # Minority sampling + # ------------------------------------------------------------------ + + @staticmethod + def _sample_rows( + df_nw: nw.DataFrame[Any], + target_col: str, + target_value: Any, + n: int = 20, + ) -> list[dict[str, Any]]: + """Extract a representative sample of minority-class rows. + + Args: + df_nw: Narwhals DataFrame. + target_col: Class-label column. 
+ target_value: Minority class value. + n: Maximum sample size. + + Returns: + List of row dicts from the minority class. + """ + minority = df_nw.filter(nw.col(target_col) == target_value) + all_rows: list[dict[str, Any]] = minority.rows(named=True) # type: ignore[assignment] + + if len(all_rows) <= n: + return all_rows + + step = len(all_rows) / n + return [all_rows[int(i * step)] for i in range(n)] + + # ------------------------------------------------------------------ + # Batch generation + # ------------------------------------------------------------------ + + def _generate_batch( + self, + sample_rows: list[dict[str, Any]], + schema: Type[BaseModel], + count: int, + ) -> list[dict[str, Any]]: + """Prompt the engine to generate a batch of synthetic records. + + Args: + sample_rows: Representative minority-class rows. + schema: Pydantic model defining record structure. + count: Number of records to request. + + Returns: + List of raw record dicts (unvalidated). + """ + schema_fields = { + name: { + "type": str(info.annotation), + "description": (info.description if info.description else ""), + } + for name, info in schema.model_fields.items() + } + + prompt = ( + "You are a synthetic data generator.\n\n" + "Generate exactly {count} NEW records that match this schema:\n" + "{schema}\n\n" + "Here are real example rows for reference:\n" + "{samples}\n\n" + "Rules:\n" + "- Each record MUST conform to the schema structure\n" + "- Maintain logical correlations seen in the examples\n" + "- Do NOT copy the example rows exactly\n" + "- Generated values must be physically plausible\n\n" + 'Return a JSON object with key "records" containing ' + "a list of {count} record dictionaries." 
+ ).format( + count=count, + schema=json.dumps(schema_fields, indent=2), + samples=json.dumps(sample_rows[:5], ensure_ascii=False, default=str), + ) + + raw = self.inference_engine.generate(prompt, schema=_GeneratedBatch) + return self._parse_batch_response(raw) + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + @staticmethod + def _parse_batch_response(raw: Any) -> list[dict[str, Any]]: + """Best-effort parse of the LLM response into a list of records. + + Args: + raw: Raw output from the inference engine. + + Returns: + List of record dicts (may be empty on total failure). + """ + if isinstance(raw, dict): + records = raw.get("records", []) + if isinstance(records, list): + return records + return [] + + text = str(raw) if not isinstance(raw, str) else raw + + try: + parsed = json.loads(text) + if isinstance(parsed, dict): + records = parsed.get("records", []) + return records if isinstance(records, list) else [] + if isinstance(parsed, list): + return parsed + except (json.JSONDecodeError, TypeError): + pass + + try: + repaired = repair_json(text) + if isinstance(repaired, dict): + records = repaired.get("records", []) + return records if isinstance(records, list) else [] + parsed_r = json.loads(repaired) # type: ignore[arg-type] + if isinstance(parsed_r, dict): + records = parsed_r.get("records", []) + return records if isinstance(records, list) else [] + if isinstance(parsed_r, list): + return parsed_r + except Exception: + pass + + logger.warning( + "[yellow]⚠[/yellow] Could not parse batch response. Returning empty batch." 
+ ) + return [] + + # ------------------------------------------------------------------ + # Validation and deduplication + # ------------------------------------------------------------------ + + @staticmethod + def _validate_and_filter( + candidates: list[dict[str, Any]], + schema: Type[BaseModel], + existing_keys: set[frozenset[tuple[str, Any]]], + already_generated: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """Validate candidates against the schema and deduplicate. + + Args: + candidates: Raw record dicts from the LLM. + schema: Pydantic model for validation. + existing_keys: Fingerprints of original DataFrame rows. + already_generated: Records already accepted in prior rounds. + + Returns: + List of validated, deduplicated record dicts. + """ + gen_keys = {_row_key(r) for r in already_generated} + combined_keys = existing_keys | gen_keys + + valid: list[dict[str, Any]] = [] + for rec in candidates: + try: + instance = schema(**rec) + clean_rec = instance.model_dump() + except (ValidationError, TypeError, Exception): + continue + + key = _row_key(clean_rec) + if key in combined_keys: + continue + + combined_keys.add(key) + valid.append(clean_rec) + + return valid + + @staticmethod + def _build_key_set( + df_nw: nw.DataFrame[Any], + ) -> set[frozenset[tuple[str, Any]]]: + """Build a set of row fingerprints for deduplication. + + Args: + df_nw: Narwhals DataFrame. + + Returns: + Set of frozensets, each representing one row. 
+ """ + rows: list[dict[str, Any]] = df_nw.rows(named=True) # type: ignore[assignment] + return {_row_key(r) for r in rows} + + +def _row_key(row: dict[str, Any]) -> frozenset[tuple[str, Any]]: + """Create a hashable fingerprint for a row dict.""" + return frozenset((k, str(v)) for k, v in sorted(row.items())) diff --git a/src/loclean/extraction/resolver.py b/src/loclean/extraction/resolver.py new file mode 100644 index 0000000..a36b5ff --- /dev/null +++ b/src/loclean/extraction/resolver.py @@ -0,0 +1,233 @@ +"""Semantic entity resolution via LLM-driven canonicalization. + +This module provides :class:`EntityResolver`, which canonicalizes messy string +columns by: + +1. Extracting unique values from a Narwhals column. +2. Prompting the local Ollama engine to group semantically similar strings + under a single canonical label (respecting a distance threshold ε). +3. Mapping the canonical dictionary back across the original DataFrame column. +""" + +import json +import logging +from typing import TYPE_CHECKING, Any + +import narwhals as nw +from narwhals.typing import IntoFrameT +from pydantic import BaseModel, Field + +from loclean.extraction.json_repair import repair_json +from loclean.utils.logging import configure_module_logger + +if TYPE_CHECKING: + from loclean.inference.base import InferenceEngine + +logger = configure_module_logger(__name__, level=logging.INFO) + + +class _CanonicalMapping(BaseModel): + """Internal schema used to constrain the LLM's structured output.""" + + mapping: dict[str, str] = Field( + ..., + description=( + "Dictionary mapping each input string variation " + "to its canonical (authoritative) form" + ), + ) + + +class EntityResolver: + """Canonicalize messy string values via semantic entity resolution. + + Groups similar string variations into a single authoritative label + using a generative model. The *threshold* parameter (ε) controls + how aggressively strings are merged: a higher value allows more + distant strings to be grouped. 
+ """ + + def __init__( + self, + inference_engine: "InferenceEngine", + threshold: float = 0.8, + max_retries: int = 3, + ) -> None: + """Initialize the resolver. + + Args: + inference_engine: Engine used for semantic evaluation. + threshold: Semantic-distance threshold ε in ``(0, 1]``. + Pairs with ``d(x, y) < ε`` are merged. + max_retries: Currently reserved for future retry logic. + """ + if not 0 < threshold <= 1: + raise ValueError("threshold must be in (0, 1]") + self.inference_engine = inference_engine + self.threshold = threshold + self.max_retries = max_retries + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def resolve( + self, + df: IntoFrameT, + target_col: str, + ) -> IntoFrameT: + """Canonicalize a string column and return the augmented DataFrame. + + A new column ``{target_col}_canonical`` is appended containing the + resolved canonical labels. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column with messy string values. + + Returns: + DataFrame of the same native type with an added canonical column. + + Raises: + ValueError: If *target_col* is not found in the DataFrame. + """ + df_nw = nw.from_native(df) # type: ignore[type-var] + if target_col not in df_nw.columns: + raise ValueError(f"Column '{target_col}' not found in DataFrame") + + unique_values = self._extract_unique_values(df_nw, target_col) + if not unique_values: + logger.warning( + "No valid values found in column. Returning original DataFrame." 
+ ) + return df_nw.with_columns( + nw.col(target_col).alias(f"{target_col}_canonical"), + ).to_native() # type: ignore[return-value] + + canonical_map = self._build_canonical_mapping(unique_values) + + logger.info( + f"[green]✓[/green] Resolved [bold]{len(unique_values)}[/bold] " + f"unique values into [bold]{len(set(canonical_map.values()))}[/bold] " + "canonical entities" + ) + + canonical_col = f"{target_col}_canonical" + result = df_nw.with_columns( + nw.col(target_col) + .cast(nw.String) + .replace_strict( + old=list(canonical_map.keys()), + new=list(canonical_map.values()), + default=nw.col(target_col).cast(nw.String), + ) + .alias(canonical_col), + ) + + return result.to_native() # type: ignore[return-value] + + # ------------------------------------------------------------------ + # Unique-value extraction + # ------------------------------------------------------------------ + + @staticmethod + def _extract_unique_values( + df_nw: nw.DataFrame[Any], + target_col: str, + ) -> list[str]: + """Return sorted unique non-empty string values from *target_col*. + + Args: + df_nw: Narwhals DataFrame. + target_col: Column to extract from. + + Returns: + Deduplicated list of non-empty strings. + """ + raw = df_nw.unique(subset=[target_col])[target_col].to_list() + valid: list[str] = [str(v) for v in raw if v is not None and str(v).strip()] + valid.sort() + return valid + + # ------------------------------------------------------------------ + # Canonical mapping via LLM + # ------------------------------------------------------------------ + + def _build_canonical_mapping( + self, + values: list[str], + ) -> dict[str, str]: + """Prompt the engine to produce a variation → canonical mapping. + + Args: + values: List of unique string values to canonicalize. + + Returns: + Dictionary mapping every input value to its canonical form. 
+ """ + prompt = ( + "You are a data-quality expert performing entity resolution.\n\n" + "Given this list of string values extracted from a dataset column:\n" + f"{json.dumps(values, ensure_ascii=False)}\n\n" + "Group strings that refer to the same real-world entity. " + "Two strings x and y should be grouped only when their semantic " + f"distance d(x, y) falls below the threshold ε = {self.threshold}.\n\n" + "For each group, choose the cleanest, most complete string as " + "the canonical label. Strings that do not closely match any " + "other string should map to themselves.\n\n" + "Return a JSON object with key 'mapping' whose value is a " + "dictionary mapping EVERY input string to its canonical form." + ) + + raw = self.inference_engine.generate(prompt, schema=_CanonicalMapping) + data = self._parse_mapping_response(raw) + mapping: dict[str, str] = data.get("mapping", {}) + + validated: dict[str, str] = {} + for val in values: + canonical = mapping.get(val) + if isinstance(canonical, str) and canonical.strip(): + validated[val] = canonical.strip() + else: + validated[val] = val + + return validated + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + @staticmethod + def _parse_mapping_response(raw: Any) -> dict[str, Any]: + """Best-effort parse of the LLM response into a dict. + + Handles raw dicts, valid JSON strings, and malformed JSON (via + ``repair_json``). + + Args: + raw: Raw output from the inference engine. + + Returns: + Parsed dict (may be empty on total failure). 
+ """ + if isinstance(raw, dict): + return raw + + try: + parsed: dict[str, Any] = json.loads(raw) # type: ignore[arg-type] + return parsed + except (json.JSONDecodeError, TypeError): + pass + + try: + repaired = repair_json(raw) + if isinstance(repaired, dict): + return repaired + parsed_repaired: dict[str, Any] = json.loads(repaired) # type: ignore[arg-type] + return parsed_repaired + except (json.JSONDecodeError, TypeError, Exception): + logger.warning( + "[yellow]⚠[/yellow] Could not parse canonical mapping " + "response. Falling back to identity mapping." + ) + return {} diff --git a/src/loclean/extraction/shredder.py b/src/loclean/extraction/shredder.py new file mode 100644 index 0000000..eabdcd1 --- /dev/null +++ b/src/loclean/extraction/shredder.py @@ -0,0 +1,491 @@ +"""Automated relational shredding for unstructured log columns. + +Parses a single column of deeply nested unstructured text into multiple +relational DataFrames by: + +1. Sampling representative log entries. +2. Prompting the Ollama engine to infer a relational schema + (functional dependencies :math:`X \\rightarrow Y`). +3. Generating and compiling a pure-Python ``extract_relations`` function. +4. Applying the function across the column and separating results into + per-table Narwhals DataFrames. 
+""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any, Callable + +import narwhals as nw +from pydantic import BaseModel, Field + +from loclean.extraction.json_repair import repair_json +from loclean.utils.cache_keys import compute_code_key +from loclean.utils.logging import configure_module_logger + +if TYPE_CHECKING: + from narwhals.typing import IntoFrameT + + from loclean.cache import LocleanCache + from loclean.inference.base import InferenceEngine + +logger = configure_module_logger(__name__, level=logging.INFO) + + +# ------------------------------------------------------------------ +# Pydantic schemas for LLM-inferred relational structure +# ------------------------------------------------------------------ + + +class _TableDef(BaseModel): + """Definition of a single relational table.""" + + name: str = Field(..., description="Table name") + columns: list[str] = Field(..., description="Column names for this table") + primary_key: str = Field(..., description="Primary key column name") + foreign_key: str | None = Field( + default=None, + description="Foreign key column referencing another table", + ) + + +class _RelationalSchema(BaseModel): + """Multi-table relational schema inferred from log data.""" + + tables: list[_TableDef] = Field( + ..., + min_length=2, + description="At least two related tables", + ) + + +class RelationalShredder: + """Shred unstructured log columns into relational DataFrames. + + Uses a two-phase LLM approach: + + 1. **Schema inference** — propose tables with PKs/FKs. + 2. **Code generation** — compile a pure-Python extraction function, + verify against samples, and repair on failure. + + Args: + inference_engine: Engine for generative requests. + sample_size: Number of log entries to sample. + max_retries: Repair budget for the compilation loop. 
+ """ + + def __init__( + self, + inference_engine: InferenceEngine, + sample_size: int = 30, + max_retries: int = 3, + timeout_s: float = 2.0, + cache: LocleanCache | None = None, + ) -> None: + if sample_size < 1: + raise ValueError("sample_size must be ≥ 1") + if max_retries < 1: + raise ValueError("max_retries must be ≥ 1") + if timeout_s <= 0: + raise ValueError("timeout_s must be > 0") + self.inference_engine = inference_engine + self.sample_size = sample_size + self.max_retries = max_retries + self.timeout_s = timeout_s + self.cache = cache + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def shred( + self, + df: IntoFrameT, + target_col: str, + ) -> dict[str, Any]: + """Parse a log column into multiple relational DataFrames. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + target_col: Column containing unstructured log text. + + Returns: + Dictionary mapping table names to native DataFrames, + or an empty dict if generation fails. 
+ """ + df_nw = nw.from_native(df) # type: ignore[type-var] + if target_col not in df_nw.columns: + raise ValueError(f"Column '{target_col}' not found in DataFrame") + + samples = self._sample_entries(df_nw, target_col) + if not samples: + raise ValueError(f"No valid entries in column '{target_col}'") + + schema = self._infer_schema(samples) + + all_columns = sorted(col for tbl in schema.tables for col in tbl.columns) + table_names = sorted(tbl.name for tbl in schema.tables) + cache_key = compute_code_key( + columns=all_columns, + dtypes=table_names, + target_col=target_col, + module_prefix="shredder", + ) + + cached_source = self.cache.get_code(cache_key) if self.cache else None + if cached_source is not None: + extract_fn = self._compile_function(cached_source) + ok, _ = self._verify_function(extract_fn, samples, schema, self.timeout_s) + if ok: + logger.info("[green]✓[/green] Cache hit — reusing compiled extractor") + results = self._apply_function( + df_nw, target_col, extract_fn, self.timeout_s + ) + native_ns = nw.get_native_namespace(df_nw) + return self._separate_tables(results, schema, native_ns) + + source = self._generate_extractor(schema, samples) + extract_fn = self._compile_function(source) + + ok, error = self._verify_function(extract_fn, samples, schema, self.timeout_s) + retries = 0 + while not ok and retries < self.max_retries: + source = self._repair_function(source, error, samples) + extract_fn = self._compile_function(source) + ok, error = self._verify_function( + extract_fn, samples, schema, self.timeout_s + ) + retries += 1 + + if not ok: + logger.warning( + f"[yellow]⚠[/yellow] Code generation failed after " + f"{self.max_retries} retries: {error} — returning empty result" + ) + return {} + + if self.cache: + self.cache.set_code(cache_key, source) + + results = self._apply_function(df_nw, target_col, extract_fn, self.timeout_s) + native_ns = nw.get_native_namespace(df_nw) + tables = self._separate_tables(results, schema, native_ns) + + 
table_summary = ", ".join(f"{name}({len(tbl)})" for name, tbl in tables.items()) + logger.info( + f"[green]✓[/green] Shredded into " + f"[bold]{len(tables)}[/bold] tables: {table_summary}" + ) + + return tables + + # ------------------------------------------------------------------ + # Sampling + # ------------------------------------------------------------------ + + def _sample_entries( + self, + df_nw: nw.DataFrame[Any], + target_col: str, + ) -> list[str]: + """Length-stratified sampling of non-empty log strings. + + Args: + df_nw: Narwhals DataFrame. + target_col: Column to sample from. + + Returns: + List of up to *sample_size* unique strings. + """ + raw = df_nw.unique(subset=[target_col])[target_col].to_list() + valid: list[str] = [str(v) for v in raw if v is not None and str(v).strip()] + if len(valid) <= self.sample_size: + valid.sort(key=len) + return valid + + valid.sort(key=len) + step = len(valid) / self.sample_size + return [valid[int(i * step)] for i in range(self.sample_size)] + + # ------------------------------------------------------------------ + # Phase 1: Schema inference + # ------------------------------------------------------------------ + + def _infer_schema(self, samples: list[str]) -> _RelationalSchema: + """Prompt the engine to propose a relational schema. + + Args: + samples: Representative log entries. + + Returns: + Parsed relational schema with at least two tables. + """ + prompt = ( + "You are a database architect analyzing log data.\n\n" + "Here are sample log entries:\n" + f"{json.dumps(samples[:10], ensure_ascii=False)}\n\n" + "Analyze the structure and propose a relational schema " + "in Third Normal Form (3NF). Identify functional " + "dependencies (X → Y) to separate concerns.\n\n" + "Return a JSON object with key 'tables' containing a list " + "of at least 2 table definitions. 
Each table must have:\n" + '- "name": table name\n' + '- "columns": list of column names\n' + '- "primary_key": primary key column\n' + '- "foreign_key": foreign key column (null for root table)' + ) + + raw = self.inference_engine.generate(prompt, schema=_RelationalSchema) + return self._parse_schema_response(raw) + + @staticmethod + def _parse_schema_response(raw: Any) -> _RelationalSchema: + """Best-effort parse of the schema inference response. + + Args: + raw: Raw LLM output. + + Returns: + Validated relational schema. + + Raises: + ValueError: If parsing fails completely. + """ + if isinstance(raw, _RelationalSchema): + return raw + + data: dict[str, Any] | None = None + + if isinstance(raw, dict): + data = raw + else: + text = str(raw) if not isinstance(raw, str) else raw + try: + data = json.loads(text) + except (json.JSONDecodeError, TypeError): + try: + repaired = repair_json(text) + if isinstance(repaired, dict): + data = repaired + else: + data = json.loads(repaired) # type: ignore[arg-type] + except Exception: + pass + + if data is None: + raise ValueError("Failed to parse relational schema") + + return _RelationalSchema(**data) + + # ------------------------------------------------------------------ + # Phase 2: Code generation + compilation + # ------------------------------------------------------------------ + + def _generate_extractor( + self, + schema: _RelationalSchema, + samples: list[str], + ) -> str: + """Prompt the engine to write an extract_relations function. + + Args: + schema: Inferred relational schema. + samples: Representative log entries. + + Returns: + Python source code string. 
+ """ + table_specs = json.dumps( + [t.model_dump() for t in schema.tables], + indent=2, + ) + prompt = ( + "You are an expert Python programmer.\n\n" + "Write a pure Python function with this exact signature:\n\n" + "def extract_relations(log: str) -> dict[str, dict]:\n\n" + "The function must parse a single log string and return a " + "dictionary where each key is a table name and each value " + "is a dictionary of column_name: extracted_value pairs.\n\n" + f"Target tables:\n{table_specs}\n\n" + "Sample log entries:\n" + f"{json.dumps(samples[:5], ensure_ascii=False)}\n\n" + "Rules:\n" + "- Use ONLY standard library modules (re, string, etc.)\n" + "- Wrap parsing logic in try/except blocks\n" + "- Return empty strings for fields that cannot be parsed\n" + "- Do NOT import any third-party libraries\n\n" + "Return ONLY the Python function code, no markdown fences." + ) + + raw = self.inference_engine.generate(prompt) + source = str(raw).strip() + if source.startswith("```"): + lines = source.split("\n") + lines = [line for line in lines if not line.strip().startswith("```")] + source = "\n".join(lines) + return source + + @staticmethod + def _compile_function( + source: str, + ) -> Callable[[str], dict[str, dict[str, Any]]]: + """Compile source code in a restricted sandbox. + + Args: + source: Python source containing ``extract_relations``. + + Returns: + The compiled function. + + Raises: + ValueError: If compilation fails or function not found. + """ + from loclean.utils.sandbox import compile_sandboxed + + return compile_sandboxed( + source, + "extract_relations", + ["re", "json", "datetime", "collections"], + ) + + def _verify_function( + self, + fn: Callable[[str], dict[str, dict[str, Any]]], + samples: list[str], + schema: _RelationalSchema, + timeout_s: float = 2.0, + ) -> tuple[bool, str]: + """Test the compiled function against sample entries with timeout. + + Args: + fn: Compiled extraction function. + samples: Log entries to test. 
+ schema: Expected table structure. + timeout_s: Maximum seconds per sample execution. + + Returns: + Tuple of (success, error_message). + """ + from loclean.utils.sandbox import run_with_timeout + + table_names = {t.name for t in schema.tables} + test_samples = samples[:5] + + for sample in test_samples: + result, error = run_with_timeout(fn, (sample,), timeout_s) + + if error: + return False, (f"Function failed for input {sample[:100]!r}: {error}") + + if not isinstance(result, dict): + return False, (f"Expected dict return, got {type(result).__name__}") + + missing = table_names - set(result.keys()) + if missing: + return False, f"Missing tables in output: {missing}" + + return True, "" + + def _repair_function( + self, + source: str, + error: str, + samples: list[str], + ) -> str: + """Ask the engine to fix a broken extraction function. + + Args: + source: Current source code. + error: Error message from verification. + samples: Log entries for context. + + Returns: + Repaired Python source code string. + """ + prompt = ( + "The following Python function has a bug.\n\n" + f"Source:\n{source}\n\n" + f"Error:\n{error}\n\n" + "Sample inputs:\n" + f"{json.dumps(samples[:3], ensure_ascii=False)}\n\n" + "Fix the function. Return ONLY the corrected Python code, " + "no markdown fences." 
+ ) + + raw = self.inference_engine.generate(prompt) + repaired = str(raw).strip() + if repaired.startswith("```"): + lines = repaired.split("\n") + lines = [line for line in lines if not line.strip().startswith("```")] + repaired = "\n".join(lines) + return repaired + + # ------------------------------------------------------------------ + # Phase 3: Full execution + separation + # ------------------------------------------------------------------ + + @staticmethod + def _apply_function( + df_nw: nw.DataFrame[Any], + target_col: str, + fn: Callable[[str], dict[str, dict[str, Any]]], + timeout_s: float = 2.0, + ) -> list[dict[str, dict[str, Any]]]: + """Apply the extraction function to every row with timeout. + + Args: + df_nw: Narwhals DataFrame. + target_col: Column with log text. + fn: Compiled extraction function. + timeout_s: Maximum seconds per row execution. + + Returns: + List of extraction results (one per row). + """ + from loclean.utils.sandbox import run_with_timeout + + values = df_nw[target_col].to_list() + results: list[dict[str, dict[str, Any]]] = [] + + for val in values: + text = str(val) if val is not None else "" + result, error = run_with_timeout(fn, (text,), timeout_s) + if error: + logger.debug(f"Row execution failed: {error}") + results.append({}) + else: + results.append(result if isinstance(result, dict) else {}) + + return results + + @staticmethod + def _separate_tables( + results: list[dict[str, dict[str, Any]]], + schema: _RelationalSchema, + native_ns: Any, + ) -> dict[str, Any]: + """Split extraction results into per-table DataFrames. + + Args: + results: List of extraction dicts from _apply_function. + schema: Relational schema defining table structure. + native_ns: Native namespace (e.g. polars module). + + Returns: + Dict mapping table names to native DataFrames. 
+ """ + tables: dict[str, Any] = {} + + for table_def in schema.tables: + col_data: dict[str, list[Any]] = {col: [] for col in table_def.columns} + + for row_result in results: + table_row = row_result.get(table_def.name, {}) + for col in table_def.columns: + col_data[col].append(table_row.get(col)) + + table_nw = nw.from_dict(col_data, backend=native_ns) + tables[table_def.name] = table_nw.to_native() + + return tables diff --git a/src/loclean/orchestration/__init__.py b/src/loclean/orchestration/__init__.py new file mode 100644 index 0000000..6ac4ce6 --- /dev/null +++ b/src/loclean/orchestration/__init__.py @@ -0,0 +1,25 @@ +"""Orchestration integration for workflow automation tools. + +Provides a lightweight JSON-in / JSON-out runner that can be invoked +as a subprocess by orchestration tools (n8n, Apache Airflow, etc.). +""" + +from loclean.orchestration.runner import ( + EXIT_INVALID_PAYLOAD, + EXIT_OK, + EXIT_PIPELINE_ERROR, + EXIT_SERIALIZATION_ERROR, + PipelinePayload, + main, + run_pipeline, +) + +__all__ = [ + "EXIT_INVALID_PAYLOAD", + "EXIT_OK", + "EXIT_PIPELINE_ERROR", + "EXIT_SERIALIZATION_ERROR", + "PipelinePayload", + "main", + "run_pipeline", +] diff --git a/src/loclean/orchestration/runner.py b/src/loclean/orchestration/runner.py new file mode 100644 index 0000000..4e4a45d --- /dev/null +++ b/src/loclean/orchestration/runner.py @@ -0,0 +1,159 @@ +"""Stdin/stdout orchestration runner for workflow automation. + +This module provides a lightweight execution script that: + +1. Reads a JSON payload from **stdin**. +2. Validates it against :class:`PipelinePayload`. +3. Instantiates the inference engine via :func:`load_config`. +4. Loads the data into a Narwhals-compatible DataFrame and executes + :func:`loclean.clean`. +5. Writes the augmented result as JSON to **stdout**. + +Exit codes are deterministic so that DAG managers (n8n, Airflow) can +branch downstream tasks accordingly. 
+ +Usage:: + + echo '{"data": [...], "target_col": "price"}' \\ + | python -m loclean.orchestration.runner +""" + +from __future__ import annotations + +import json +import sys +from typing import Any + +import polars as pl +from pydantic import BaseModel, Field, ValidationError + +from loclean import clean +from loclean.inference.config import load_config + +EXIT_OK: int = 0 +EXIT_INVALID_PAYLOAD: int = 1 +EXIT_PIPELINE_ERROR: int = 2 +EXIT_SERIALIZATION_ERROR: int = 3 + + +class PipelinePayload(BaseModel): + """Schema for the incoming JSON payload. + + Attributes: + data: Row-oriented records (list of dicts). + target_col: Column to pass to :func:`loclean.clean`. + instruction: Cleaning instruction forwarded to the LLM. + engine_config: Key-value overrides passed to :func:`load_config`. + batch_size: Number of unique values per processing batch. + """ + + data: list[dict[str, Any]] + target_col: str + instruction: str = Field( + default="Extract the numeric value and unit as-is.", + ) + engine_config: dict[str, Any] = Field(default_factory=dict) + batch_size: int = Field(default=50, ge=1) + + +def run_pipeline(payload: PipelinePayload) -> dict[str, Any]: + """Execute the cleaning pipeline for a validated payload. + + Args: + payload: Validated pipeline payload. + + Returns: + Dictionary with ``status``, ``data`` (list of row dicts), and + ``row_count``. + + Raises: + ValueError: If *target_col* is missing from the data. + RuntimeError: On unexpected engine or cleaning failures. 
+ """ + config = load_config(**payload.engine_config) + + df = pl.DataFrame(payload.data) + + result_df = clean( + df, + payload.target_col, + payload.instruction, + model=config.model, + host=config.host, + verbose=config.verbose, + batch_size=payload.batch_size, + ) + + rows: list[dict[str, Any]] = result_df.to_dicts() + + return { + "status": "ok", + "data": rows, + "row_count": len(rows), + } + + +def _error_response(code: int, message: str) -> dict[str, Any]: + """Build a structured error response dict. + + Args: + code: Exit code constant. + message: Human-readable error description. + + Returns: + Error dict with ``status``, ``code``, and ``message``. + """ + return {"status": "error", "code": code, "message": message} + + +def main() -> None: + """Entry point: read stdin → clean → write stdout.""" + try: + raw = sys.stdin.read() + except Exception as exc: + result = _error_response(EXIT_INVALID_PAYLOAD, f"stdin read error: {exc}") + sys.stdout.write(json.dumps(result)) + sys.exit(EXIT_INVALID_PAYLOAD) + return + + try: + data = json.loads(raw) + except json.JSONDecodeError as exc: + result = _error_response(EXIT_INVALID_PAYLOAD, f"Invalid JSON: {exc}") + sys.stdout.write(json.dumps(result)) + sys.exit(EXIT_INVALID_PAYLOAD) + return + + try: + payload = PipelinePayload(**data) + except ValidationError as exc: + result = _error_response( + EXIT_INVALID_PAYLOAD, f"Payload validation failed: {exc}" + ) + sys.stdout.write(json.dumps(result)) + sys.exit(EXIT_INVALID_PAYLOAD) + return + + try: + output = run_pipeline(payload) + except Exception as exc: + result = _error_response(EXIT_PIPELINE_ERROR, f"Pipeline error: {exc}") + sys.stdout.write(json.dumps(result)) + sys.exit(EXIT_PIPELINE_ERROR) + return + + try: + sys.stdout.write(json.dumps(output)) + except (TypeError, ValueError) as exc: + result = _error_response( + EXIT_SERIALIZATION_ERROR, f"Serialization error: {exc}" + ) + sys.stdout.write(json.dumps(result)) + sys.exit(EXIT_SERIALIZATION_ERROR) + return + 
+ sys.exit(EXIT_OK) + + +if __name__ == "__main__": + main() diff --git a/src/loclean/utils/cache_keys.py b/src/loclean/utils/cache_keys.py new file mode 100644 index 0000000..8297d54 --- /dev/null +++ b/src/loclean/utils/cache_keys.py @@ -0,0 +1,37 @@ +"""Deterministic cache-key generation for compiled extraction functions.""" + +from __future__ import annotations + +import hashlib + + +def compute_code_key( + *, + columns: list[str], + dtypes: list[str], + target_col: str, + module_prefix: str, +) -> str: + """Build a SHA256 key from structural metadata. + + Implements ``Key = H(module_prefix + target_col + + sorted(columns) + sorted(dtypes))``. + + Args: + columns: DataFrame column names. + dtypes: Corresponding dtype strings. + target_col: Name of the target / log column. + module_prefix: Module-level discriminator (e.g. ``"feature_discovery"`` + or ``"shredder"``). + + Returns: + Hex-encoded SHA256 digest. + """ + parts = [ + module_prefix, + target_col, + ",".join(sorted(columns)), + ",".join(sorted(dtypes)), + ] + payload = "::".join(parts) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() diff --git a/src/loclean/utils/sandbox.py b/src/loclean/utils/sandbox.py new file mode 100644 index 0000000..632020c --- /dev/null +++ b/src/loclean/utils/sandbox.py @@ -0,0 +1,165 @@ +"""Sandboxed execution utilities for LLM-generated code. + +Provides restricted ``exec`` compilation and wall-clock timeout +enforcement via ``concurrent.futures.ThreadPoolExecutor``. 
+""" + +from __future__ import annotations + +import importlib +import logging +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError as FuturesTimeoutError +from typing import Any, Callable, TypeVar + +from loclean.utils.logging import configure_module_logger + +logger = configure_module_logger(__name__, level=logging.INFO) + +T = TypeVar("T") + +_SAFE_BUILTINS: dict[str, Any] = { + "abs": abs, + "all": all, + "any": any, + "bool": bool, + "bytes": bytes, + "chr": chr, + "complex": complex, + "dict": dict, + "divmod": divmod, + "enumerate": enumerate, + "filter": filter, + "float": float, + "format": format, + "frozenset": frozenset, + "hasattr": hasattr, + "hash": hash, + "hex": hex, + "int": int, + "isinstance": isinstance, + "issubclass": issubclass, + "iter": iter, + "len": len, + "list": list, + "map": map, + "max": max, + "min": min, + "next": next, + "oct": oct, + "ord": ord, + "pow": pow, + "print": print, + "range": range, + "repr": repr, + "reversed": reversed, + "round": round, + "set": set, + "slice": slice, + "sorted": sorted, + "str": str, + "sum": sum, + "tuple": tuple, + "type": type, + "zip": zip, + "None": None, + "True": True, + "False": False, + "Exception": Exception, + "ValueError": ValueError, + "TypeError": TypeError, + "KeyError": KeyError, + "IndexError": IndexError, + "ZeroDivisionError": ZeroDivisionError, + "AttributeError": AttributeError, + "RuntimeError": RuntimeError, + "StopIteration": StopIteration, + "ArithmeticError": ArithmeticError, + "OverflowError": OverflowError, +} + + +class SandboxTimeoutError(RuntimeError): + """Raised when sandboxed execution exceeds the time limit.""" + + +def compile_sandboxed( + source: str, + fn_name: str, + allowed_modules: list[str] | None = None, +) -> Callable[..., Any]: + """Compile *source* in a restricted namespace and return *fn_name*. 
+ + The execution environment has: + + * ``__builtins__`` replaced by a curated safe subset (no ``open``, + ``exec``, ``eval``, ``__import__``, ``compile``, ``exit``, + ``quit``, ``input``, ``breakpoint``, ``globals``, ``locals``, + ``vars``, ``dir``). + * Only explicitly listed standard-library modules injected. + + Args: + source: Python source code string. + fn_name: Name of the function to extract from the namespace. + allowed_modules: Standard-library module names to inject + (e.g. ``["math", "re"]``). + + Returns: + The compiled callable. + + Raises: + ValueError: If compilation fails or *fn_name* is not defined. + """ + safe_globals: dict[str, Any] = {"__builtins__": _SAFE_BUILTINS.copy()} + + for mod_name in allowed_modules or []: + try: + safe_globals[mod_name] = importlib.import_module(mod_name) + except ImportError: + logger.warning( + f"[yellow]⚠[/yellow] Module '{mod_name}' not available, skipping" + ) + + try: + exec(source, safe_globals) # noqa: S102 + except Exception as exc: + raise ValueError(f"Compilation failed: {exc}") from exc + + fn = safe_globals.get(fn_name) + if fn is None or not callable(fn): + raise ValueError(f"Source does not define '{fn_name}'") + return fn # type: ignore[no-any-return] + + +def run_with_timeout( + fn: Callable[..., T], + args: tuple[Any, ...], + timeout_s: float = 2.0, +) -> tuple[T | None, str]: + """Execute *fn* with a wall-clock time limit. + + Uses a daemon-threaded pool so the interpreter can exit even + when a timed-out function is still running. + + Args: + fn: Callable to execute. + args: Positional arguments forwarded to *fn*. + timeout_s: Maximum seconds to wait. + + Returns: + ``(result, "")`` on success, or ``(None, error_message)`` on + timeout or exception. 
+ """ + pool = ThreadPoolExecutor(max_workers=1) + future = pool.submit(fn, *args) + try: + result = future.result(timeout=timeout_s) + return result, "" + except FuturesTimeoutError: + msg = f"Execution timed out after {timeout_s}s" + logger.warning(f"[yellow]⚠[/yellow] {msg}") + return None, msg + except Exception as exc: + return None, str(exc) + finally: + pool.shutdown(wait=False, cancel_futures=True) diff --git a/src/loclean/validation/__init__.py b/src/loclean/validation/__init__.py new file mode 100644 index 0000000..dcc788a --- /dev/null +++ b/src/loclean/validation/__init__.py @@ -0,0 +1,9 @@ +"""Explainable data quality gates. + +Evaluate structured data against natural-language constraints and produce +a compliance report with per-row reasoning. +""" + +from loclean.validation.quality_gate import QualityGate, QualityReport + +__all__ = ["QualityGate", "QualityReport"] diff --git a/src/loclean/validation/quality_gate.py b/src/loclean/validation/quality_gate.py new file mode 100644 index 0000000..38d7b2a --- /dev/null +++ b/src/loclean/validation/quality_gate.py @@ -0,0 +1,252 @@ +"""Quality gate evaluation via LLM-driven rule compliance checking. + +Checks each row of a DataFrame against a set of natural-language rules +using the local Ollama engine. Produces a :class:`QualityReport` with +compliance rate, per-failure reasoning, and programmatic dict output. 
+""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any + +import narwhals as nw +from pydantic import BaseModel, Field + +from loclean.extraction.json_repair import repair_json +from loclean.utils.logging import configure_module_logger + +if TYPE_CHECKING: + from narwhals.typing import IntoFrameT + + from loclean.inference.base import InferenceEngine + +logger = configure_module_logger(__name__, level=logging.INFO) + + +class _RowCompliance(BaseModel): + """LLM-constrained output for a single row evaluation.""" + + compliant: bool = Field( + ..., + description="Whether the row satisfies all specified rules", + ) + reasoning: str = Field( + ..., + description="Logical explanation for the compliance decision", + ) + + +class QualityReport(BaseModel): + """Aggregated compliance report across evaluated rows. + + Attributes: + total_rows: Number of rows evaluated. + passed_rows: Number of rows that passed all rules. + compliance_rate: Fraction of rows that passed (0.0–1.0). + failures: List of dicts, each containing ``row`` data and + ``reasoning`` for non-compliant rows. + """ + + total_rows: int + passed_rows: int + compliance_rate: float + failures: list[dict[str, Any]] = Field(default_factory=list) + + +class QualityGate: + """Evaluate data quality against natural-language rules. + + Prompts the Ollama engine to check each sampled row against the + provided rules, collecting boolean compliance flags and reasoning. + + Args: + inference_engine: Engine used for rule evaluation. + batch_size: Rows processed per LLM call batch. + sample_size: Maximum rows to evaluate (sampled if exceeded). 
+ """ + + def __init__( + self, + inference_engine: InferenceEngine, + batch_size: int = 20, + sample_size: int = 100, + ) -> None: + if batch_size < 1: + raise ValueError("batch_size must be ≥ 1") + if sample_size < 1: + raise ValueError("sample_size must be ≥ 1") + self.inference_engine = inference_engine + self.batch_size = batch_size + self.sample_size = sample_size + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def evaluate( + self, + df: IntoFrameT, + rules: list[str], + ) -> dict[str, Any]: + """Evaluate a DataFrame against natural-language rules. + + Args: + df: Input DataFrame (pandas, Polars, etc.). + rules: List of natural-language constraint strings. + + Returns: + Dictionary representation of :class:`QualityReport`. + + Raises: + ValueError: If *rules* is empty. + """ + if not rules: + raise ValueError("At least one rule must be provided") + + df_nw = nw.from_native(df) # type: ignore[type-var] + rows = self._sample_rows(df_nw) + + if not rows: + report = QualityReport( + total_rows=0, + passed_rows=0, + compliance_rate=1.0, + ) + return report.model_dump() + + passed = 0 + failures: list[dict[str, Any]] = [] + + for i in range(0, len(rows), self.batch_size): + batch = rows[i : i + self.batch_size] + for row in batch: + compliance = self._check_row(row, rules) + if compliance.compliant: + passed += 1 + else: + failures.append( + { + "row": row, + "reasoning": compliance.reasoning, + } + ) + + total = len(rows) + rate = passed / total if total > 0 else 1.0 + + report = QualityReport( + total_rows=total, + passed_rows=passed, + compliance_rate=round(rate, 4), + failures=failures, + ) + + logger.info( + f"[green]✓[/green] Quality gate: " + f"[bold]{passed}/{total}[/bold] rows compliant " + f"({report.compliance_rate:.1%})" + ) + + return report.model_dump() + + # ------------------------------------------------------------------ + # Row 
sampling + # ------------------------------------------------------------------ + + def _sample_rows(self, df_nw: nw.DataFrame[Any]) -> list[dict[str, Any]]: + """Sample up to *sample_size* rows from the DataFrame. + + Args: + df_nw: Narwhals DataFrame. + + Returns: + List of row dicts. + """ + all_rows: list[dict[str, Any]] = df_nw.rows(named=True) # type: ignore[assignment] + if len(all_rows) <= self.sample_size: + return all_rows + + step = len(all_rows) / self.sample_size + return [all_rows[int(i * step)] for i in range(self.sample_size)] + + # ------------------------------------------------------------------ + # Per-row compliance check + # ------------------------------------------------------------------ + + def _check_row( + self, + row: dict[str, Any], + rules: list[str], + ) -> _RowCompliance: + """Prompt the engine to evaluate a single row against rules. + + Args: + row: Row data as a dictionary. + rules: Natural-language constraint strings. + + Returns: + Parsed compliance result. + """ + rules_text = "\n".join(f" {i + 1}. {rule}" for i, rule in enumerate(rules)) + prompt = ( + "You are a data quality auditor.\n\n" + "Given this data row:\n" + f"{json.dumps(row, ensure_ascii=False, default=str)}\n\n" + "Evaluate whether it satisfies ALL of the following rules:\n" + f"{rules_text}\n\n" + "Return a JSON object with:\n" + '- "compliant": true/false\n' + '- "reasoning": a brief explanation of your decision' + ) + + raw = self.inference_engine.generate(prompt, schema=_RowCompliance) + return self._parse_compliance(raw) + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + @staticmethod + def _parse_compliance(raw: Any) -> _RowCompliance: + """Best-effort parse of the LLM response into a compliance result. + + Args: + raw: Raw output from the inference engine. 
+ + Returns: + Parsed ``_RowCompliance`` (defaults to non-compliant on + total parse failure). + """ + if isinstance(raw, dict): + try: + return _RowCompliance(**raw) + except Exception: + pass + + text = str(raw) if not isinstance(raw, str) else raw + + try: + parsed = json.loads(text) + return _RowCompliance(**parsed) + except (json.JSONDecodeError, TypeError, Exception): + pass + + try: + repaired = repair_json(text) + if isinstance(repaired, dict): + return _RowCompliance(**repaired) + parsed_r = json.loads(repaired) # type: ignore[arg-type] + return _RowCompliance(**parsed_r) + except Exception: + pass + + logger.warning( + "[yellow]⚠[/yellow] Could not parse compliance response. " + "Marking row as non-compliant." + ) + return _RowCompliance( + compliant=False, + reasoning="Failed to parse LLM response", + ) diff --git a/tests/unit/extraction/test_feature_discovery.py b/tests/unit/extraction/test_feature_discovery.py new file mode 100644 index 0000000..5024f28 --- /dev/null +++ b/tests/unit/extraction/test_feature_discovery.py @@ -0,0 +1,230 @@ +"""Test cases for FeatureDiscovery.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import polars as pl +import pytest + +from loclean.extraction.feature_discovery import FeatureDiscovery + +# ------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + +VALID_FEATURE_SOURCE = """ +def generate_features(row: dict) -> dict: + result = {} + try: + result["ratio_a_b"] = row.get("a", 0) / max(row.get("b", 1), 1) + except Exception: + result["ratio_a_b"] = None + try: + result["sum_a_b"] = row.get("a", 0) + row.get("b", 0) + except Exception: + result["sum_a_b"] = None + try: + val = row.get("a", 0) + result["log_a"] = math.log(val) if val and val > 0 else 0.0 + except Exception: + result["log_a"] = None + try: + result["product_a_b"] = row.get("a", 0) * row.get("b", 0) + 
except Exception: + result["product_a_b"] = None + try: + result["diff_a_b"] = row.get("a", 0) - row.get("b", 0) + except Exception: + result["diff_a_b"] = None + return result +""" + +SAMPLE_DF = pl.DataFrame( + { + "a": [1.0, 2.0, 3.0, 4.0], + "b": [10.0, 20.0, 30.0, 40.0], + "target": [0, 1, 0, 1], + } +) + + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def mock_engine() -> MagicMock: + engine = MagicMock() + engine.verbose = False + return engine + + +@pytest.fixture +def discoverer(mock_engine: MagicMock) -> FeatureDiscovery: + return FeatureDiscovery(inference_engine=mock_engine, n_features=5, max_retries=2) + + +# ------------------------------------------------------------------ +# __init__ validation +# ------------------------------------------------------------------ + + +class TestInit: + def test_rejects_zero_n_features(self, mock_engine: MagicMock) -> None: + with pytest.raises(ValueError, match="n_features"): + FeatureDiscovery(inference_engine=mock_engine, n_features=0) + + def test_rejects_zero_max_retries(self, mock_engine: MagicMock) -> None: + with pytest.raises(ValueError, match="max_retries"): + FeatureDiscovery(inference_engine=mock_engine, max_retries=0) + + def test_accepts_valid_params(self, mock_engine: MagicMock) -> None: + fd = FeatureDiscovery(inference_engine=mock_engine, n_features=3, max_retries=2) + assert fd.n_features == 3 + assert fd.max_retries == 2 + + +# ------------------------------------------------------------------ +# _extract_state +# ------------------------------------------------------------------ + + +class TestExtractState: + def test_extracts_columns_and_dtypes(self) -> None: + import narwhals as nw + + df = nw.from_native(SAMPLE_DF) + state = FeatureDiscovery._extract_state(df, "target") + assert "a" in state["columns"] + assert "b" in state["columns"] + assert "target" in 
state["columns"] + assert state["target_col"] == "target" + assert len(state["dtypes"]) == 3 + + def test_samples_rows(self) -> None: + import narwhals as nw + + df = nw.from_native(SAMPLE_DF) + state = FeatureDiscovery._extract_state(df, "target", sample_n=2) + assert len(state["sample_rows"]) == 2 + + def test_returns_all_rows_if_small(self) -> None: + import narwhals as nw + + df = nw.from_native(SAMPLE_DF) + state = FeatureDiscovery._extract_state(df, "target", sample_n=100) + assert len(state["sample_rows"]) == 4 + + +# ------------------------------------------------------------------ +# _compile_function +# ------------------------------------------------------------------ + + +class TestCompileFunction: + def test_valid_source(self) -> None: + fn = FeatureDiscovery._compile_function(VALID_FEATURE_SOURCE) + result = fn({"a": 2.0, "b": 10.0}) + assert "ratio_a_b" in result + assert "sum_a_b" in result + + def test_invalid_source_raises(self) -> None: + with pytest.raises(ValueError, match="Compilation failed"): + FeatureDiscovery._compile_function("def broken(") + + def test_missing_function_raises(self) -> None: + with pytest.raises(ValueError, match="does not define"): + FeatureDiscovery._compile_function("x = 1") + + +# ------------------------------------------------------------------ +# _verify_function +# ------------------------------------------------------------------ + + +class TestVerifyFunction: + def test_passes_valid_function(self) -> None: + fn = FeatureDiscovery._compile_function(VALID_FEATURE_SOURCE) + sample_rows = [ + {"a": 1.0, "b": 10.0, "target": 0}, + {"a": 2.0, "b": 20.0, "target": 1}, + ] + ok, error = FeatureDiscovery._verify_function(fn, sample_rows) + assert ok is True + assert error == "" + + def test_fails_on_exception(self) -> None: + def bad_fn(row: dict[str, Any]) -> dict[str, Any]: + raise RuntimeError("boom") + + ok, error = FeatureDiscovery._verify_function(bad_fn, [{"a": 1}]) + assert ok is False + assert "boom" in error + 
+ def test_fails_on_non_dict_return(self) -> None: + def bad_fn(row: dict[str, Any]) -> Any: + return [1, 2, 3] + + ok, error = FeatureDiscovery._verify_function(bad_fn, [{"a": 1}]) + assert ok is False + assert "dict" in error + + def test_fails_on_empty_return(self) -> None: + def empty_fn(row: dict[str, Any]) -> dict[str, Any]: + return {} + + ok, error = FeatureDiscovery._verify_function(empty_fn, [{"a": 1}]) + assert ok is False + assert "empty" in error + + +# ------------------------------------------------------------------ +# discover (end-to-end with mocks) +# ------------------------------------------------------------------ + + +class TestDiscover: + def test_happy_path( + self, discoverer: FeatureDiscovery, mock_engine: MagicMock + ) -> None: + mock_engine.generate.return_value = VALID_FEATURE_SOURCE + + result = discoverer.discover(SAMPLE_DF, "target") + assert len(result) == 4 + assert "ratio_a_b" in result.columns + assert "sum_a_b" in result.columns + assert "a" in result.columns + + def test_missing_column_raises(self, discoverer: FeatureDiscovery) -> None: + with pytest.raises(ValueError, match="not found"): + discoverer.discover(SAMPLE_DF, "missing") + + def test_repair_cycle( + self, discoverer: FeatureDiscovery, mock_engine: MagicMock + ) -> None: + bad_source = "def generate_features(row): raise Exception('bad')" + mock_engine.generate.side_effect = [ + bad_source, + VALID_FEATURE_SOURCE, + ] + + result = discoverer.discover(SAMPLE_DF, "target") + assert "ratio_a_b" in result.columns + assert mock_engine.generate.call_count == 2 + + def test_exhausted_retries_returns_original( + self, discoverer: FeatureDiscovery, mock_engine: MagicMock + ) -> None: + bad_source = "def generate_features(row): raise Exception('bad')" + mock_engine.generate.side_effect = [ + bad_source, + bad_source, + bad_source, + ] + + result = discoverer.discover(SAMPLE_DF, "target") + assert list(result.columns) == list(SAMPLE_DF.columns) + assert len(result) == 
len(SAMPLE_DF) diff --git a/tests/unit/extraction/test_oversampler.py b/tests/unit/extraction/test_oversampler.py new file mode 100644 index 0000000..b02a2da --- /dev/null +++ b/tests/unit/extraction/test_oversampler.py @@ -0,0 +1,250 @@ +"""Test cases for SemanticOversampler.""" + +from __future__ import annotations + +import json +from typing import Any +from unittest.mock import MagicMock + +import polars as pl +import pytest +from pydantic import BaseModel + +from loclean.extraction.oversampler import SemanticOversampler, _row_key + +# ------------------------------------------------------------------ +# Test schema +# ------------------------------------------------------------------ + + +class SampleRecord(BaseModel): + label: str + value: float + category: str + + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def mock_engine() -> MagicMock: + engine = MagicMock() + engine.verbose = False + return engine + + +@pytest.fixture +def sampler(mock_engine: MagicMock) -> SemanticOversampler: + return SemanticOversampler( + inference_engine=mock_engine, batch_size=5, max_retries=3 + ) + + +# ------------------------------------------------------------------ +# __init__ validation +# ------------------------------------------------------------------ + + +class TestInit: + def test_rejects_zero_batch_size(self, mock_engine: MagicMock) -> None: + with pytest.raises(ValueError, match="batch_size"): + SemanticOversampler(inference_engine=mock_engine, batch_size=0) + + def test_rejects_zero_max_retries(self, mock_engine: MagicMock) -> None: + with pytest.raises(ValueError, match="max_retries"): + SemanticOversampler(inference_engine=mock_engine, max_retries=0) + + def test_accepts_valid_params(self, mock_engine: MagicMock) -> None: + s = SemanticOversampler( + inference_engine=mock_engine, batch_size=3, max_retries=2 + ) + assert s.batch_size == 3 
+ assert s.max_retries == 2 + + +# ------------------------------------------------------------------ +# _sample_rows +# ------------------------------------------------------------------ + + +class TestSampleRows: + def test_filters_minority_class(self, sampler: SemanticOversampler) -> None: + import narwhals as nw + + df = nw.from_native( + pl.DataFrame( + { + "label": ["A", "B", "A", "B", "A"], + "value": [1.0, 2.0, 3.0, 4.0, 5.0], + } + ) + ) + rows = sampler._sample_rows(df, "label", "B") + assert len(rows) == 2 + assert all(r["label"] == "B" for r in rows) + + def test_returns_empty_for_missing_class( + self, sampler: SemanticOversampler + ) -> None: + import narwhals as nw + + df = nw.from_native(pl.DataFrame({"label": ["A", "A"]})) + rows = sampler._sample_rows(df, "label", "Z") + assert rows == [] + + def test_samples_when_exceeds_n(self, sampler: SemanticOversampler) -> None: + import narwhals as nw + + df = nw.from_native( + pl.DataFrame({"label": ["X"] * 100, "val": list(range(100))}) + ) + rows = sampler._sample_rows(df, "label", "X", n=5) + assert len(rows) == 5 + + +# ------------------------------------------------------------------ +# _parse_batch_response +# ------------------------------------------------------------------ + + +class TestParseBatchResponse: + def test_parses_dict_with_records(self) -> None: + raw = {"records": [{"label": "A", "value": 1.0, "category": "x"}]} + result = SemanticOversampler._parse_batch_response(raw) + assert len(result) == 1 + + def test_parses_json_string(self) -> None: + raw = json.dumps({"records": [{"a": 1}, {"a": 2}]}) + result = SemanticOversampler._parse_batch_response(raw) + assert len(result) == 2 + + def test_parses_raw_list(self) -> None: + raw = json.dumps([{"a": 1}]) + result = SemanticOversampler._parse_batch_response(raw) + assert len(result) == 1 + + def test_repairs_malformed(self) -> None: + raw = '{"records": [{"a": 1},]}' + result = SemanticOversampler._parse_batch_response(raw) + assert 
len(result) >= 1 + + def test_returns_empty_on_failure(self) -> None: + result = SemanticOversampler._parse_batch_response("completely broken") + assert result == [] + + +# ------------------------------------------------------------------ +# _row_key / deduplication +# ------------------------------------------------------------------ + + +class TestDeduplicate: + def test_row_key_deterministic(self) -> None: + r1 = {"a": 1, "b": "x"} + r2 = {"b": "x", "a": 1} + assert _row_key(r1) == _row_key(r2) + + def test_validate_and_filter_removes_dupes( + self, sampler: SemanticOversampler + ) -> None: + existing = {_row_key({"label": "A", "value": 1.0, "category": "x"})} + candidates = [ + {"label": "A", "value": 1.0, "category": "x"}, + {"label": "A", "value": 2.0, "category": "y"}, + ] + result = sampler._validate_and_filter(candidates, SampleRecord, existing, []) + assert len(result) == 1 + assert result[0]["value"] == 2.0 + + def test_validate_and_filter_rejects_invalid_schema( + self, sampler: SemanticOversampler + ) -> None: + candidates: list[dict[str, Any]] = [ + {"label": "A"}, + ] + result = sampler._validate_and_filter(candidates, SampleRecord, set(), []) + assert result == [] + + +# ------------------------------------------------------------------ +# _generate_batch +# ------------------------------------------------------------------ + + +class TestGenerateBatch: + def test_prompt_includes_schema_and_samples( + self, sampler: SemanticOversampler, mock_engine: MagicMock + ) -> None: + mock_engine.generate.return_value = json.dumps( + {"records": [{"label": "B", "value": 9.0, "category": "new"}]} + ) + sample = [{"label": "B", "value": 1.0, "category": "old"}] + sampler._generate_batch(sample, SampleRecord, 1) + + prompt = mock_engine.generate.call_args[0][0] + assert "label" in prompt + assert "value" in prompt + assert "old" in prompt + + +# ------------------------------------------------------------------ +# oversample (end-to-end with mocks) +# 
------------------------------------------------------------------ + + +class TestOversample: + def test_happy_path( + self, sampler: SemanticOversampler, mock_engine: MagicMock + ) -> None: + df = pl.DataFrame( + { + "label": ["A", "A", "A", "B"], + "value": [1.0, 2.0, 3.0, 10.0], + "category": ["x", "y", "z", "w"], + } + ) + + mock_engine.generate.return_value = json.dumps( + { + "records": [ + {"label": "B", "value": 20.0, "category": "v"}, + {"label": "B", "value": 30.0, "category": "u"}, + ] + } + ) + + result = sampler.oversample(df, "label", "B", n=2, schema=SampleRecord) + assert len(result) == 6 + assert result["label"].to_list().count("B") == 3 + + def test_missing_column_raises(self, sampler: SemanticOversampler) -> None: + df = pl.DataFrame({"a": [1]}) + with pytest.raises(ValueError, match="not found"): + sampler.oversample(df, "missing", "x", n=1, schema=SampleRecord) + + def test_empty_minority_raises(self, sampler: SemanticOversampler) -> None: + df = pl.DataFrame({"label": ["A", "A"]}) + with pytest.raises(ValueError, match="No rows found"): + sampler.oversample(df, "label", "Z", n=1, schema=SampleRecord) + + def test_dedup_triggers_retry( + self, sampler: SemanticOversampler, mock_engine: MagicMock + ) -> None: + df = pl.DataFrame( + { + "label": ["B"], + "value": [1.0], + "category": ["x"], + } + ) + + mock_engine.generate.side_effect = [ + json.dumps({"records": [{"label": "B", "value": 1.0, "category": "x"}]}), + json.dumps({"records": [{"label": "B", "value": 99.0, "category": "new"}]}), + ] + + result = sampler.oversample(df, "label", "B", n=1, schema=SampleRecord) + assert len(result) == 2 + assert mock_engine.generate.call_count == 2 diff --git a/tests/unit/extraction/test_resolver.py b/tests/unit/extraction/test_resolver.py new file mode 100644 index 0000000..bba7581 --- /dev/null +++ b/tests/unit/extraction/test_resolver.py @@ -0,0 +1,230 @@ +"""Test cases for EntityResolver.""" + +import json +from unittest.mock import MagicMock + 
+import polars as pl +import pytest + +from loclean.extraction.resolver import EntityResolver + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def mock_engine() -> MagicMock: + engine = MagicMock() + engine.verbose = False + return engine + + +@pytest.fixture +def resolver(mock_engine: MagicMock) -> EntityResolver: + return EntityResolver(inference_engine=mock_engine, threshold=0.8) + + +# ------------------------------------------------------------------ +# __init__ validation +# ------------------------------------------------------------------ + + +class TestInit: + def test_rejects_zero_threshold(self, mock_engine: MagicMock) -> None: + with pytest.raises(ValueError, match="threshold must be in"): + EntityResolver(inference_engine=mock_engine, threshold=0) + + def test_rejects_negative_threshold(self, mock_engine: MagicMock) -> None: + with pytest.raises(ValueError, match="threshold must be in"): + EntityResolver(inference_engine=mock_engine, threshold=-0.5) + + def test_rejects_threshold_above_one(self, mock_engine: MagicMock) -> None: + with pytest.raises(ValueError, match="threshold must be in"): + EntityResolver(inference_engine=mock_engine, threshold=1.5) + + def test_accepts_valid_threshold(self, mock_engine: MagicMock) -> None: + r = EntityResolver(inference_engine=mock_engine, threshold=0.5) + assert r.threshold == 0.5 + + def test_accepts_threshold_one(self, mock_engine: MagicMock) -> None: + r = EntityResolver(inference_engine=mock_engine, threshold=1.0) + assert r.threshold == 1.0 + + +# ------------------------------------------------------------------ +# _extract_unique_values +# ------------------------------------------------------------------ + + +class TestExtractUniqueValues: + def test_returns_unique_non_empty(self, resolver: EntityResolver) -> None: + import narwhals as nw + + df = nw.from_native(pl.DataFrame({"col": 
["a", "b", "a", "c", "b"]})) + result = resolver._extract_unique_values(df, "col") + assert sorted(result) == ["a", "b", "c"] + + def test_filters_none_and_whitespace(self, resolver: EntityResolver) -> None: + import narwhals as nw + + df = nw.from_native(pl.DataFrame({"col": ["a", None, "", " ", "b"]})) + result = resolver._extract_unique_values(df, "col") + assert sorted(result) == ["a", "b"] + + def test_empty_column(self, resolver: EntityResolver) -> None: + import narwhals as nw + + df = nw.from_native(pl.DataFrame({"col": [None, "", " "]})) + result = resolver._extract_unique_values(df, "col") + assert result == [] + + +# ------------------------------------------------------------------ +# _parse_mapping_response +# ------------------------------------------------------------------ + + +class TestParseMappingResponse: + def test_parses_dict(self) -> None: + raw = {"mapping": {"NYC": "New York City", "NY": "New York City"}} + result = EntityResolver._parse_mapping_response(raw) + assert result["mapping"]["NYC"] == "New York City" + + def test_parses_json_string(self) -> None: + raw = json.dumps({"mapping": {"a": "A", "b": "B"}}) + result = EntityResolver._parse_mapping_response(raw) + assert result["mapping"] == {"a": "A", "b": "B"} + + def test_repairs_malformed_json(self) -> None: + raw = '{"mapping": {"a": "A", "b": "B",}}' + result = EntityResolver._parse_mapping_response(raw) + assert "mapping" in result + + def test_returns_empty_on_total_failure(self) -> None: + result = EntityResolver._parse_mapping_response("completely broken") + assert result == {} + + +# ------------------------------------------------------------------ +# _build_canonical_mapping +# ------------------------------------------------------------------ + + +class TestBuildCanonicalMapping: + def test_happy_path(self, resolver: EntityResolver, mock_engine: MagicMock) -> None: + mock_engine.generate.return_value = json.dumps( + {"mapping": {"NYC": "New York City", "New York": "New York 
City"}} + ) + result = resolver._build_canonical_mapping(["NYC", "New York"]) + assert result["NYC"] == "New York City" + assert result["New York"] == "New York City" + + def test_unmapped_values_keep_original( + self, resolver: EntityResolver, mock_engine: MagicMock + ) -> None: + mock_engine.generate.return_value = json.dumps( + {"mapping": {"NYC": "New York City"}} + ) + result = resolver._build_canonical_mapping(["NYC", "Tokyo"]) + assert result["NYC"] == "New York City" + assert result["Tokyo"] == "Tokyo" + + def test_fallback_on_unparseable_response( + self, resolver: EntityResolver, mock_engine: MagicMock + ) -> None: + mock_engine.generate.return_value = "not json at all" + result = resolver._build_canonical_mapping(["a", "b"]) + assert result == {"a": "a", "b": "b"} + + def test_filters_empty_canonical_values( + self, resolver: EntityResolver, mock_engine: MagicMock + ) -> None: + mock_engine.generate.return_value = json.dumps( + {"mapping": {"a": "", "b": " ", "c": "Canon"}} + ) + result = resolver._build_canonical_mapping(["a", "b", "c"]) + assert result["a"] == "a" + assert result["b"] == "b" + assert result["c"] == "Canon" + + def test_prompt_includes_threshold( + self, resolver: EntityResolver, mock_engine: MagicMock + ) -> None: + mock_engine.generate.return_value = json.dumps({"mapping": {"x": "x"}}) + resolver._build_canonical_mapping(["x"]) + prompt = mock_engine.generate.call_args[0][0] + assert "0.8" in prompt + + +# ------------------------------------------------------------------ +# resolve (end-to-end with mocks) +# ------------------------------------------------------------------ + + +class TestResolve: + def test_adds_canonical_column( + self, resolver: EntityResolver, mock_engine: MagicMock + ) -> None: + df = pl.DataFrame({"city": ["NYC", "New York", "NYC", "LA"]}) + + mock_engine.generate.return_value = json.dumps( + { + "mapping": { + "LA": "Los Angeles", + "NYC": "New York City", + "New York": "New York City", + } + } + ) + + result 
= resolver.resolve(df, "city") + assert "city_canonical" in result.columns + canonical = result["city_canonical"].to_list() + assert canonical == [ + "New York City", + "New York City", + "New York City", + "Los Angeles", + ] + + def test_raises_on_missing_column(self, resolver: EntityResolver) -> None: + df = pl.DataFrame({"a": [1]}) + with pytest.raises(ValueError, match="Column 'b' not found"): + resolver.resolve(df, "b") + + def test_all_null_returns_identity(self, resolver: EntityResolver) -> None: + df = pl.DataFrame({"col": [None, None]}) + result = resolver.resolve(df, "col") + assert "col_canonical" in result.columns + + def test_preserves_original_column( + self, resolver: EntityResolver, mock_engine: MagicMock + ) -> None: + df = pl.DataFrame({"name": ["Alice", "Bob"]}) + mock_engine.generate.return_value = json.dumps( + {"mapping": {"Alice": "Alice", "Bob": "Bob"}} + ) + result = resolver.resolve(df, "name") + assert result["name"].to_list() == ["Alice", "Bob"] + assert "name_canonical" in result.columns + + def test_identity_mapping_when_no_groups( + self, resolver: EntityResolver, mock_engine: MagicMock + ) -> None: + df = pl.DataFrame({"item": ["apple", "banana", "cherry"]}) + mock_engine.generate.return_value = json.dumps( + { + "mapping": { + "apple": "apple", + "banana": "banana", + "cherry": "cherry", + } + } + ) + result = resolver.resolve(df, "item") + assert result["item_canonical"].to_list() == [ + "apple", + "banana", + "cherry", + ] diff --git a/tests/unit/extraction/test_shredder.py b/tests/unit/extraction/test_shredder.py new file mode 100644 index 0000000..c5f6130 --- /dev/null +++ b/tests/unit/extraction/test_shredder.py @@ -0,0 +1,299 @@ +"""Test cases for RelationalShredder.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import polars as pl +import pytest + +from loclean.extraction.shredder import RelationalShredder, _RelationalSchema, _TableDef + +# 
------------------------------------------------------------------ +# Helpers +# ------------------------------------------------------------------ + +SAMPLE_LOGS = [ + "2024-01-01 10:00:00 INFO server=web01 user=alice action=login ip=1.2.3.4", + "2024-01-01 10:01:00 ERROR server=web01 user=bob action=timeout code=504", + "2024-01-01 10:02:00 INFO server=db01 user=alice action=query rows=42", +] + +SAMPLE_SCHEMA = _RelationalSchema( + tables=[ + _TableDef( + name="events", + columns=["timestamp", "level", "server", "action"], + primary_key="timestamp", + foreign_key=None, + ), + _TableDef( + name="metadata", + columns=["timestamp", "user", "ip", "code", "rows"], + primary_key="timestamp", + foreign_key="timestamp", + ), + ] +) + +VALID_EXTRACT_SOURCE = """ +def extract_relations(log: str) -> dict[str, dict]: + parts = log.split() + result = { + "events": { + "timestamp": parts[0] + " " + parts[1] if len(parts) > 1 else "", + "level": parts[2] if len(parts) > 2 else "", + "server": "", + "action": "", + }, + "metadata": { + "timestamp": parts[0] + " " + parts[1] if len(parts) > 1 else "", + "user": "", + "ip": "", + "code": "", + "rows": "", + }, + } + try: + for part in parts[3:]: + if "=" in part: + k, v = part.split("=", 1) + if k == "server": + result["events"]["server"] = v + elif k == "action": + result["events"]["action"] = v + elif k in ("user", "ip", "code", "rows"): + result["metadata"][k] = v + except Exception: + pass + return result +""" + + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def mock_engine() -> MagicMock: + engine = MagicMock() + engine.verbose = False + return engine + + +@pytest.fixture +def shredder(mock_engine: MagicMock) -> RelationalShredder: + return RelationalShredder( + inference_engine=mock_engine, sample_size=10, max_retries=2 + ) + + +# ------------------------------------------------------------------ 
+# __init__ validation
+# ------------------------------------------------------------------
+
+
+class TestInit:
+    def test_rejects_zero_sample_size(self, mock_engine: MagicMock) -> None:
+        with pytest.raises(ValueError, match="sample_size"):
+            RelationalShredder(inference_engine=mock_engine, sample_size=0)
+
+    def test_rejects_zero_max_retries(self, mock_engine: MagicMock) -> None:
+        with pytest.raises(ValueError, match="max_retries"):
+            RelationalShredder(inference_engine=mock_engine, max_retries=0)
+
+    def test_accepts_valid_params(self, mock_engine: MagicMock) -> None:
+        s = RelationalShredder(
+            inference_engine=mock_engine,
+            sample_size=5,
+            max_retries=2,
+        )
+        assert s.sample_size == 5
+        assert s.max_retries == 2
+
+
+# ------------------------------------------------------------------
+# _sample_entries
+# ------------------------------------------------------------------
+
+
+class TestSampleEntries:
+    def test_extracts_non_empty(self, shredder: RelationalShredder) -> None:
+        import narwhals as nw
+
+        df = nw.from_native(pl.DataFrame({"log": SAMPLE_LOGS + ["", None]}))
+        entries = shredder._sample_entries(df, "log")
+        assert len(entries) == 3
+        assert all(e.strip() for e in entries)
+
+    def test_samples_when_exceeds_limit(self, mock_engine: MagicMock) -> None:
+        import narwhals as nw
+
+        s = RelationalShredder(inference_engine=mock_engine, sample_size=2)
+        df = nw.from_native(pl.DataFrame({"log": [f"entry_{i}" for i in range(100)]}))
+        entries = s._sample_entries(df, "log")
+        assert len(entries) == 2
+
+    def test_empty_column(self, shredder: RelationalShredder) -> None:
+        import narwhals as nw
+
+        df = nw.from_native(pl.DataFrame({"log": ["", None]}))
+        entries = shredder._sample_entries(df, "log")
+        assert entries == []
+
+
+# ------------------------------------------------------------------
+# _parse_schema_response
+# ------------------------------------------------------------------
+
+
+class TestParseSchemaResponse:
+    def test_parses_dict(self) -> None:
+ raw = SAMPLE_SCHEMA.model_dump() + result = RelationalShredder._parse_schema_response(raw) + assert len(result.tables) == 2 + + def test_parses_json_string(self) -> None: + raw = SAMPLE_SCHEMA.model_dump_json() + result = RelationalShredder._parse_schema_response(raw) + assert result.tables[0].name == "events" + + def test_raises_on_failure(self) -> None: + with pytest.raises(ValueError, match="Failed to parse"): + RelationalShredder._parse_schema_response("broken") + + +# ------------------------------------------------------------------ +# _compile_function +# ------------------------------------------------------------------ + + +class TestCompileFunction: + def test_valid_source(self) -> None: + fn = RelationalShredder._compile_function(VALID_EXTRACT_SOURCE) + result = fn(SAMPLE_LOGS[0]) + assert "events" in result + assert "metadata" in result + + def test_invalid_source_raises(self) -> None: + with pytest.raises(ValueError, match="Compilation failed"): + RelationalShredder._compile_function("def broken(") + + def test_missing_function_raises(self) -> None: + with pytest.raises(ValueError, match="does not define"): + RelationalShredder._compile_function("x = 1") + + +# ------------------------------------------------------------------ +# _verify_function +# ------------------------------------------------------------------ + + +class TestVerifyFunction: + def test_passes_valid_function(self, shredder: RelationalShredder) -> None: + fn = RelationalShredder._compile_function(VALID_EXTRACT_SOURCE) + ok, error = shredder._verify_function(fn, SAMPLE_LOGS, SAMPLE_SCHEMA) + assert ok is True + assert error == "" + + def test_fails_on_exception(self, shredder: RelationalShredder) -> None: + def bad_fn(log: str) -> dict[str, Any]: + raise RuntimeError("boom") + + ok, error = shredder._verify_function(bad_fn, SAMPLE_LOGS, SAMPLE_SCHEMA) + assert ok is False + assert "boom" in error + + def test_fails_on_missing_table(self, shredder: RelationalShredder) -> None: + def 
partial_fn(log: str) -> dict[str, dict[str, str]]: + return {"events": {"timestamp": "x"}} + + ok, error = shredder._verify_function(partial_fn, SAMPLE_LOGS, SAMPLE_SCHEMA) + assert ok is False + assert "Missing tables" in error + + +# ------------------------------------------------------------------ +# _separate_tables +# ------------------------------------------------------------------ + + +class TestSeparateTables: + def test_builds_multi_table(self) -> None: + fn = RelationalShredder._compile_function(VALID_EXTRACT_SOURCE) + results = [fn(log) for log in SAMPLE_LOGS] + + tables = RelationalShredder._separate_tables(results, SAMPLE_SCHEMA, pl) + assert "events" in tables + assert "metadata" in tables + assert len(tables["events"]) == 3 + assert len(tables["metadata"]) == 3 + + +# ------------------------------------------------------------------ +# shred (end-to-end with mocks) +# ------------------------------------------------------------------ + + +class TestShred: + def test_happy_path( + self, shredder: RelationalShredder, mock_engine: MagicMock + ) -> None: + df = pl.DataFrame({"log": SAMPLE_LOGS}) + + mock_engine.generate.side_effect = [ + SAMPLE_SCHEMA.model_dump(), + VALID_EXTRACT_SOURCE, + ] + + result = shredder.shred(df, "log") + assert isinstance(result, dict) + assert len(result) >= 2 + assert "events" in result + assert "metadata" in result + + def test_missing_column_raises(self, shredder: RelationalShredder) -> None: + df = pl.DataFrame({"a": [1]}) + with pytest.raises(ValueError, match="not found"): + shredder.shred(df, "missing") + + def test_empty_column_raises(self, shredder: RelationalShredder) -> None: + df = pl.DataFrame({"log": ["", None]}) + with pytest.raises(ValueError, match="No valid entries"): + shredder.shred(df, "log") + + def test_repair_cycle( + self, shredder: RelationalShredder, mock_engine: MagicMock + ) -> None: + df = pl.DataFrame({"log": SAMPLE_LOGS}) + + bad_source = "def extract_relations(log): raise Exception('bad')" 
+ + mock_engine.generate.side_effect = [ + SAMPLE_SCHEMA.model_dump(), + bad_source, + VALID_EXTRACT_SOURCE, + ] + + result = shredder.shred(df, "log") + assert "events" in result + assert mock_engine.generate.call_count == 3 + + def test_exhausted_retries_returns_empty( + self, shredder: RelationalShredder, mock_engine: MagicMock + ) -> None: + df = pl.DataFrame({"log": SAMPLE_LOGS}) + + bad_source = "def extract_relations(log): raise Exception('bad')" + + mock_engine.generate.side_effect = [ + SAMPLE_SCHEMA.model_dump(), + bad_source, + bad_source, + bad_source, + ] + + result = shredder.shred(df, "log") + assert result == {} diff --git a/tests/unit/orchestration/__init__.py b/tests/unit/orchestration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/orchestration/test_runner.py b/tests/unit/orchestration/test_runner.py new file mode 100644 index 0000000..aa7fd51 --- /dev/null +++ b/tests/unit/orchestration/test_runner.py @@ -0,0 +1,226 @@ +"""Test cases for the orchestration runner.""" + +from __future__ import annotations + +import json +from io import StringIO +from unittest.mock import MagicMock, patch + +import polars as pl +import pytest +from pydantic import ValidationError + +from loclean.orchestration.runner import ( + EXIT_INVALID_PAYLOAD, + EXIT_OK, + EXIT_PIPELINE_ERROR, + PipelinePayload, + _error_response, + main, + run_pipeline, +) + +# ------------------------------------------------------------------ +# PipelinePayload +# ------------------------------------------------------------------ + + +class TestPipelinePayload: + def test_valid_minimal(self) -> None: + p = PipelinePayload( + data=[{"price": "10 kg"}], + target_col="price", + ) + assert p.target_col == "price" + assert p.instruction == "Extract the numeric value and unit as-is." 
+ assert p.engine_config == {} + assert p.batch_size == 50 + + def test_valid_full(self) -> None: + p = PipelinePayload( + data=[{"x": "1"}], + target_col="x", + instruction="Custom instruction", + engine_config={"model": "llama3"}, + batch_size=10, + ) + assert p.instruction == "Custom instruction" + assert p.engine_config["model"] == "llama3" + assert p.batch_size == 10 + + def test_missing_data_raises(self) -> None: + with pytest.raises(ValidationError): + PipelinePayload(target_col="x") # type: ignore[call-arg] + + def test_missing_target_col_raises(self) -> None: + with pytest.raises(ValidationError): + PipelinePayload(data=[{"a": 1}]) # type: ignore[call-arg] + + def test_batch_size_must_be_positive(self) -> None: + with pytest.raises(ValidationError): + PipelinePayload( + data=[{"a": 1}], + target_col="a", + batch_size=0, + ) + + +# ------------------------------------------------------------------ +# _error_response +# ------------------------------------------------------------------ + + +class TestErrorResponse: + def test_structure(self) -> None: + r = _error_response(1, "bad input") + assert r == {"status": "error", "code": 1, "message": "bad input"} + + +# ------------------------------------------------------------------ +# run_pipeline +# ------------------------------------------------------------------ + + +class TestRunPipeline: + def test_happy_path(self) -> None: + payload = PipelinePayload( + data=[{"price": "10 kg"}, {"price": "20 lbs"}], + target_col="price", + ) + + fake_result = pl.DataFrame( + { + "price": ["10 kg", "20 lbs"], + "clean_value": [10.0, 20.0], + "clean_unit": ["kg", "lbs"], + "clean_reasoning": ["numeric", "numeric"], + } + ) + + with patch("loclean.orchestration.runner.clean", return_value=fake_result): + result = run_pipeline(payload) + + assert result["status"] == "ok" + assert result["row_count"] == 2 + assert len(result["data"]) == 2 + assert result["data"][0]["clean_value"] == 10.0 + assert 
result["data"][1]["clean_unit"] == "lbs"
+
+    def test_missing_column_raises(self) -> None:
+        payload = PipelinePayload(
+            data=[{"a": 1}],
+            target_col="missing_col",
+        )
+
+        with pytest.raises(ValueError, match="Column 'missing_col' not found"):
+            run_pipeline(payload)
+
+    def test_engine_config_forwarded(self) -> None:
+        payload = PipelinePayload(
+            data=[{"x": "1"}],
+            target_col="x",
+            engine_config={"model": "llama3", "host": "http://custom:11434"},
+        )
+
+        fake_result = pl.DataFrame(
+            {
+                "x": ["1"],
+                "clean_value": [1.0],
+                "clean_unit": [""],
+                "clean_reasoning": [""],
+            }
+        )
+
+        with (
+            patch("loclean.orchestration.runner.clean", return_value=fake_result),
+            patch("loclean.orchestration.runner.load_config") as mock_load,
+        ):
+            mock_load.return_value = MagicMock(
+                model="llama3", host="http://custom:11434", verbose=False
+            )
+            run_pipeline(payload)
+
+        mock_load.assert_called_once_with(
+            model="llama3", host="http://custom:11434"
+        )
+
+
+# ------------------------------------------------------------------
+# main (stdin/stdout integration)
+# ------------------------------------------------------------------
+
+
+def _run_main(stdin_data: str) -> tuple[str, int]:
+    """Helper: run main() with given stdin, capture stdout + exit code."""
+    exit_code = EXIT_OK
+    stdout_buf = StringIO()
+
+    def fake_exit(code: int) -> None:
+        nonlocal exit_code
+        exit_code = code; raise SystemExit(code)  # halt main() exactly like the real sys.exit
+
+    with (
+        patch("sys.stdin", StringIO(stdin_data)),
+        patch("sys.stdout", stdout_buf),
+        patch("sys.exit", side_effect=fake_exit),
+    ):
+        try:
+            main()
+        except SystemExit:
+            pass
+
+    return stdout_buf.getvalue(), exit_code
+
+
+class TestMain:
+    def test_invalid_json_exits_1(self) -> None:
+        output, code = _run_main("not json {{{")
+        assert code == EXIT_INVALID_PAYLOAD
+        parsed = json.loads(output)
+        assert parsed["status"] == "error"
+        assert parsed["code"] == EXIT_INVALID_PAYLOAD
+
+    def test_invalid_payload_schema_exits_1(self) -> None:
+        output, code = _run_main(json.dumps({"wrong_key": 
123})) + assert code == EXIT_INVALID_PAYLOAD + parsed = json.loads(output) + assert "validation" in parsed["message"].lower() + + def test_pipeline_error_exits_2(self) -> None: + payload = {"data": [{"x": "1"}], "target_col": "x"} + with patch( + "loclean.orchestration.runner.run_pipeline", + side_effect=RuntimeError("engine down"), + ): + output, code = _run_main(json.dumps(payload)) + + assert code == EXIT_PIPELINE_ERROR + parsed = json.loads(output) + assert parsed["code"] == EXIT_PIPELINE_ERROR + assert "engine down" in parsed["message"] + + def test_successful_round_trip(self) -> None: + payload = {"data": [{"col": "5 kg"}], "target_col": "col"} + expected_output = { + "status": "ok", + "data": [ + { + "col": "5 kg", + "clean_value": 5.0, + "clean_unit": "kg", + "clean_reasoning": "parsed", + } + ], + "row_count": 1, + } + + with patch( + "loclean.orchestration.runner.run_pipeline", + return_value=expected_output, + ): + output, code = _run_main(json.dumps(payload)) + + assert code == EXIT_OK + parsed = json.loads(output) + assert parsed["status"] == "ok" + assert parsed["row_count"] == 1 diff --git a/tests/unit/test_code_cache.py b/tests/unit/test_code_cache.py new file mode 100644 index 0000000..5376492 --- /dev/null +++ b/tests/unit/test_code_cache.py @@ -0,0 +1,249 @@ +"""Tests for code cache integration and hash key generation.""" + +from __future__ import annotations + +import tempfile +from pathlib import Path +from unittest.mock import MagicMock + +import polars as pl + +from loclean.cache import LocleanCache +from loclean.extraction.feature_discovery import FeatureDiscovery +from loclean.extraction.shredder import RelationalShredder, _RelationalSchema, _TableDef +from loclean.utils.cache_keys import compute_code_key + +# ------------------------------------------------------------------ +# compute_code_key +# ------------------------------------------------------------------ + + +class TestComputeCodeKey: + def test_deterministic(self) -> None: + k1 = 
compute_code_key( + columns=["a", "b"], + dtypes=["int", "float"], + target_col="t", + module_prefix="test", + ) + k2 = compute_code_key( + columns=["a", "b"], + dtypes=["int", "float"], + target_col="t", + module_prefix="test", + ) + assert k1 == k2 + + def test_order_invariant(self) -> None: + k1 = compute_code_key( + columns=["b", "a"], + dtypes=["float", "int"], + target_col="t", + module_prefix="test", + ) + k2 = compute_code_key( + columns=["a", "b"], + dtypes=["int", "float"], + target_col="t", + module_prefix="test", + ) + assert k1 == k2 + + def test_differs_by_module(self) -> None: + k1 = compute_code_key( + columns=["a"], + dtypes=["int"], + target_col="t", + module_prefix="feature_discovery", + ) + k2 = compute_code_key( + columns=["a"], + dtypes=["int"], + target_col="t", + module_prefix="shredder", + ) + assert k1 != k2 + + def test_differs_by_target(self) -> None: + k1 = compute_code_key( + columns=["a"], + dtypes=["int"], + target_col="x", + module_prefix="test", + ) + k2 = compute_code_key( + columns=["a"], + dtypes=["int"], + target_col="y", + module_prefix="test", + ) + assert k1 != k2 + + def test_returns_hex_sha256(self) -> None: + k = compute_code_key( + columns=["a"], + dtypes=["int"], + target_col="t", + module_prefix="test", + ) + assert len(k) == 64 + assert all(c in "0123456789abcdef" for c in k) + + +# ------------------------------------------------------------------ +# LocleanCache.get_code / .set_code +# ------------------------------------------------------------------ + + +class TestCodeCacheRoundtrip: + def test_miss_returns_none(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + cache = LocleanCache(cache_dir=Path(tmp)) + assert cache.get_code("unknown") is None + cache.close() + + def test_set_then_get(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + cache = LocleanCache(cache_dir=Path(tmp)) + cache.set_code("abc123", "def f(): return 1") + assert cache.get_code("abc123") == "def f(): return 1" + 
cache.close() + + def test_upsert_overwrites(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + cache = LocleanCache(cache_dir=Path(tmp)) + cache.set_code("key", "v1") + cache.set_code("key", "v2") + assert cache.get_code("key") == "v2" + cache.close() + + +# ------------------------------------------------------------------ +# FeatureDiscovery — cache integration & graceful fallback +# ------------------------------------------------------------------ + +VALID_FEATURE_SOURCE = """ +def generate_features(row: dict) -> dict: + result = {} + try: + result["sum_a_b"] = row.get("a", 0) + row.get("b", 0) + except Exception: + result["sum_a_b"] = None + return result +""" + +SAMPLE_DF = pl.DataFrame( + {"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0], "target": [0, 1, 0]} +) + + +class TestDiscoverCacheHit: + def test_skips_llm_on_hit(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + cache = LocleanCache(cache_dir=Path(tmp)) + engine = MagicMock() + engine.verbose = False + + key = compute_code_key( + columns=["a", "b", "target"], + dtypes=[str(SAMPLE_DF[c].dtype) for c in SAMPLE_DF.columns], + target_col="target", + module_prefix="feature_discovery", + ) + cache.set_code(key, VALID_FEATURE_SOURCE) + + fd = FeatureDiscovery( + inference_engine=engine, + n_features=1, + cache=cache, + ) + result = fd.discover(SAMPLE_DF, "target") + + engine.generate.assert_not_called() + assert "sum_a_b" in result.columns + cache.close() + + +class TestDiscoverCacheMissStores: + def test_stores_on_success(self) -> None: + with tempfile.TemporaryDirectory() as tmp: + cache = LocleanCache(cache_dir=Path(tmp)) + engine = MagicMock() + engine.verbose = False + engine.generate.return_value = VALID_FEATURE_SOURCE + + fd = FeatureDiscovery( + inference_engine=engine, + n_features=1, + cache=cache, + ) + result = fd.discover(SAMPLE_DF, "target") + + engine.generate.assert_called_once() + assert "sum_a_b" in result.columns + + key = compute_code_key( + columns=["a", "b", 
"target"], + dtypes=[str(SAMPLE_DF[c].dtype) for c in SAMPLE_DF.columns], + target_col="target", + module_prefix="feature_discovery", + ) + assert cache.get_code(key) is not None + cache.close() + + +class TestDiscoverGracefulFallback: + def test_returns_original_on_exhausted_retries(self) -> None: + engine = MagicMock() + engine.verbose = False + bad_source = "def generate_features(row): raise Exception('bad')" + engine.generate.return_value = bad_source + + fd = FeatureDiscovery( + inference_engine=engine, + n_features=1, + max_retries=2, + ) + result = fd.discover(SAMPLE_DF, "target") + assert list(result.columns) == list(SAMPLE_DF.columns) + assert len(result) == len(SAMPLE_DF) + + +# ------------------------------------------------------------------ +# RelationalShredder — graceful fallback +# ------------------------------------------------------------------ + + +SAMPLE_LOGS_DF = pl.DataFrame({"log": ["2024-01-01 INFO foo", "2024-01-02 WARN bar"]}) + +SAMPLE_SCHEMA = _RelationalSchema( + tables=[ + _TableDef( + name="events", columns=["ts", "level"], primary_key="ts", foreign_key=None + ), + _TableDef( + name="details", columns=["ts", "msg"], primary_key="ts", foreign_key="ts" + ), + ] +) + + +class TestShredGracefulFallback: + def test_returns_empty_on_exhausted_retries(self) -> None: + engine = MagicMock() + engine.verbose = False + bad_source = "def extract_relations(log): raise Exception('bad')" + engine.generate.side_effect = [ + SAMPLE_SCHEMA.model_dump(), + bad_source, + bad_source, + bad_source, + ] + + s = RelationalShredder( + inference_engine=engine, + sample_size=2, + max_retries=2, + ) + result = s.shred(SAMPLE_LOGS_DF, "log") + assert result == {} diff --git a/tests/unit/test_loclean_class.py b/tests/unit/test_loclean_class.py new file mode 100644 index 0000000..315d96e --- /dev/null +++ b/tests/unit/test_loclean_class.py @@ -0,0 +1,113 @@ +"""Tests for _resolve_engine helper and Loclean class wrapper methods.""" + +from __future__ import 
annotations + +from unittest.mock import MagicMock, patch + +import polars as pl +import pytest + +import loclean + +# ------------------------------------------------------------------ +# _resolve_engine +# ------------------------------------------------------------------ + + +class TestResolveEngine: + @patch("loclean.get_engine") + def test_defaults_return_singleton(self, mock_get: MagicMock) -> None: + sentinel = MagicMock() + mock_get.return_value = sentinel + result = loclean._resolve_engine() + mock_get.assert_called_once() + assert result is sentinel + + @patch("loclean.OllamaEngine") + def test_custom_args_create_new_client(self, mock_cls: MagicMock) -> None: + sentinel = MagicMock() + mock_cls.return_value = sentinel + result = loclean._resolve_engine(model="llama3", verbose=True) + mock_cls.assert_called_once_with(model="llama3", verbose=True) + assert result is sentinel + + @patch("loclean.OllamaEngine") + def test_extra_kwargs_forwarded(self, mock_cls: MagicMock) -> None: + loclean._resolve_engine(timeout=30) + mock_cls.assert_called_once_with(timeout=30) + + +# ------------------------------------------------------------------ +# Loclean wrapper methods +# ------------------------------------------------------------------ + + +SAMPLE_DF = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + + +class TestLocleanWrappers: + @pytest.fixture + def client(self) -> loclean.Loclean: + with patch("loclean.OllamaEngine") as mock_cls: + engine = MagicMock() + engine.verbose = False + mock_cls.return_value = engine + inst = loclean.Loclean(model="test") + inst.engine = engine + return inst + + @patch("loclean.NarwhalsEngine.process_column") + def test_clean_delegates( + self, mock_proc: MagicMock, client: loclean.Loclean + ) -> None: + mock_proc.return_value = SAMPLE_DF + client.clean(SAMPLE_DF, "a") + mock_proc.assert_called_once() + _, kwargs = mock_proc.call_args + assert kwargs.get("batch_size") == 50 + + @patch("loclean.extraction.resolver.EntityResolver.resolve") + 
def test_resolve_entities_delegates( + self, mock_resolve: MagicMock, client: loclean.Loclean + ) -> None: + mock_resolve.return_value = SAMPLE_DF + client.resolve_entities(SAMPLE_DF, "a") + mock_resolve.assert_called_once() + + @patch("loclean.extraction.oversampler.SemanticOversampler.oversample") + def test_oversample_delegates( + self, mock_os: MagicMock, client: loclean.Loclean + ) -> None: + mock_os.return_value = SAMPLE_DF + from pydantic import BaseModel + + class DummySchema(BaseModel): + a: int + b: int + + client.oversample(SAMPLE_DF, "a", 1, 5, DummySchema) + mock_os.assert_called_once() + + @patch("loclean.extraction.shredder.RelationalShredder.shred") + def test_shred_delegates( + self, mock_shred: MagicMock, client: loclean.Loclean + ) -> None: + mock_shred.return_value = {"t1": SAMPLE_DF} + client.shred_to_relations(SAMPLE_DF, "a") + mock_shred.assert_called_once() + + @patch("loclean.extraction.feature_discovery.FeatureDiscovery.discover") + def test_discover_features_delegates( + self, mock_disc: MagicMock, client: loclean.Loclean + ) -> None: + mock_disc.return_value = SAMPLE_DF + client.discover_features(SAMPLE_DF, "a") + mock_disc.assert_called_once() + + @patch("loclean.validation.quality_gate.QualityGate.evaluate") + def test_validate_quality_delegates( + self, mock_eval: MagicMock, client: loclean.Loclean + ) -> None: + mock_eval.return_value = {"total_rows": 2, "passed_rows": 2} + client.validate_quality(SAMPLE_DF, ["rule1"]) + mock_eval.assert_called_once() diff --git a/tests/unit/utils/__init__.py b/tests/unit/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/utils/test_sandbox.py b/tests/unit/utils/test_sandbox.py new file mode 100644 index 0000000..bb03e4b --- /dev/null +++ b/tests/unit/utils/test_sandbox.py @@ -0,0 +1,123 @@ +"""Tests for the sandboxed execution utilities.""" + +from __future__ import annotations + +import pytest + +from loclean.utils.sandbox import compile_sandboxed, 
run_with_timeout + +# ------------------------------------------------------------------ +# compile_sandboxed +# ------------------------------------------------------------------ + + +class TestCompileSandboxed: + def test_safe_builtins_present(self) -> None: + source = "def f():\n return len([1, 2, 3]), int('5'), str(10), dict(a=1)\n" + fn = compile_sandboxed(source, "f") + assert fn() == (3, 5, "10", {"a": 1}) + + def test_range_and_list_work(self) -> None: + source = "def f():\n return list(range(5))\n" + fn = compile_sandboxed(source, "f") + assert fn() == [0, 1, 2, 3, 4] + + def test_open_blocked(self) -> None: + source = "def f():\n return open('/etc/passwd')\n" + fn = compile_sandboxed(source, "f") + with pytest.raises(NameError): + fn() + + def test_exec_blocked(self) -> None: + source = "def f():\n exec('x = 1')\n" + fn = compile_sandboxed(source, "f") + with pytest.raises(NameError): + fn() + + def test_eval_blocked(self) -> None: + source = "def f():\n return eval('1 + 1')\n" + fn = compile_sandboxed(source, "f") + with pytest.raises(NameError): + fn() + + def test_import_blocked(self) -> None: + source = "def f():\n return __import__('os')\n" + fn = compile_sandboxed(source, "f") + with pytest.raises(NameError): + fn() + + def test_import_statement_blocked(self) -> None: + source = "import os\ndef f():\n return os.listdir('.')\n" + with pytest.raises(ValueError, match="Compilation failed"): + compile_sandboxed(source, "f") + + def test_allowed_modules_injected(self) -> None: + source = "def f(x):\n return math.log(x)\n" + fn = compile_sandboxed(source, "f", ["math"]) + assert fn(1.0) == pytest.approx(0.0) + + def test_missing_function_raises(self) -> None: + source = "def wrong_name():\n return 1\n" + with pytest.raises(ValueError, match="does not define 'target'"): + compile_sandboxed(source, "target") + + def test_syntax_error_raises(self) -> None: + source = "def f(\n" + with pytest.raises(ValueError, match="Compilation failed"): + 
compile_sandboxed(source, "f") + + def test_exception_types_available(self) -> None: + source = ( + "def f(x):\n" + " if x < 0:\n" + " raise ValueError('negative')\n" + " return x\n" + ) + fn = compile_sandboxed(source, "f") + assert fn(5) == 5 + with pytest.raises(ValueError, match="negative"): + fn(-1) + + +# ------------------------------------------------------------------ +# run_with_timeout +# ------------------------------------------------------------------ + + +class TestRunWithTimeout: + def test_success_returns_result(self) -> None: + def add(a: int, b: int) -> int: + return a + b + + result, error = run_with_timeout(add, (2, 3), timeout_s=1.0) + assert result == 5 + assert error == "" + + def test_timeout_returns_error(self) -> None: + import threading + + gate = threading.Event() + + def blocks() -> None: + gate.wait(timeout=10) + + result, error = run_with_timeout(blocks, (), timeout_s=0.1) + assert result is None + assert "timed out" in error + gate.set() + + def test_exception_returns_error(self) -> None: + def bad() -> None: + raise RuntimeError("boom") + + result, error = run_with_timeout(bad, (), timeout_s=1.0) + assert result is None + assert "boom" in error + + def test_result_none_is_valid(self) -> None: + def returns_none() -> None: + return None + + result, error = run_with_timeout(returns_none, (), timeout_s=1.0) + assert result is None + assert error == "" diff --git a/tests/unit/validation/__init__.py b/tests/unit/validation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/validation/test_quality_gate.py b/tests/unit/validation/test_quality_gate.py new file mode 100644 index 0000000..bf14a4d --- /dev/null +++ b/tests/unit/validation/test_quality_gate.py @@ -0,0 +1,201 @@ +"""Test cases for the QualityGate module.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import polars as pl +import pytest + +from loclean.validation.quality_gate import QualityGate, 
QualityReport, _RowCompliance + +# ------------------------------------------------------------------ +# Fixtures +# ------------------------------------------------------------------ + + +@pytest.fixture +def mock_engine() -> MagicMock: + engine = MagicMock() + engine.verbose = False + return engine + + +@pytest.fixture +def gate(mock_engine: MagicMock) -> QualityGate: + return QualityGate(inference_engine=mock_engine, batch_size=5, sample_size=50) + + +# ------------------------------------------------------------------ +# _RowCompliance schema +# ------------------------------------------------------------------ + + +class TestRowCompliance: + def test_valid(self) -> None: + rc = _RowCompliance(compliant=True, reasoning="All good") + assert rc.compliant is True + assert rc.reasoning == "All good" + + def test_non_compliant(self) -> None: + rc = _RowCompliance(compliant=False, reasoning="Missing value") + assert rc.compliant is False + + +# ------------------------------------------------------------------ +# QualityReport +# ------------------------------------------------------------------ + + +class TestQualityReport: + def test_full_compliance(self) -> None: + r = QualityReport(total_rows=10, passed_rows=10, compliance_rate=1.0) + assert r.compliance_rate == 1.0 + assert r.failures == [] + + def test_partial_compliance(self) -> None: + r = QualityReport( + total_rows=10, + passed_rows=7, + compliance_rate=0.7, + failures=[{"row": {"a": 1}, "reasoning": "bad"}], + ) + assert r.compliance_rate == 0.7 + assert len(r.failures) == 1 + + def test_model_dump(self) -> None: + r = QualityReport(total_rows=2, passed_rows=1, compliance_rate=0.5) + d = r.model_dump() + assert d["total_rows"] == 2 + assert d["compliance_rate"] == 0.5 + + +# ------------------------------------------------------------------ +# _parse_compliance +# ------------------------------------------------------------------ + + +class TestParseCompliance: + def test_parses_dict(self) -> None: + raw = 
{"compliant": True, "reasoning": "ok"} + result = QualityGate._parse_compliance(raw) + assert result.compliant is True + + def test_parses_json_string(self) -> None: + raw = json.dumps({"compliant": False, "reasoning": "fail"}) + result = QualityGate._parse_compliance(raw) + assert result.compliant is False + assert result.reasoning == "fail" + + def test_repairs_malformed_json(self) -> None: + raw = '{"compliant": true, "reasoning": "ok",}' + result = QualityGate._parse_compliance(raw) + assert result.compliant is True + + def test_defaults_to_non_compliant_on_failure(self) -> None: + result = QualityGate._parse_compliance("totally broken") + assert result.compliant is False + assert "Failed to parse" in result.reasoning + + +# ------------------------------------------------------------------ +# _check_row +# ------------------------------------------------------------------ + + +class TestCheckRow: + def test_compliant_row(self, gate: QualityGate, mock_engine: MagicMock) -> None: + mock_engine.generate.return_value = json.dumps( + {"compliant": True, "reasoning": "All rules satisfied"} + ) + result = gate._check_row( + {"price": 10, "currency": "USD"}, + ["Price must be positive", "Currency must be a valid ISO code"], + ) + assert result.compliant is True + + def test_non_compliant_row(self, gate: QualityGate, mock_engine: MagicMock) -> None: + mock_engine.generate.return_value = json.dumps( + {"compliant": False, "reasoning": "Price is negative"} + ) + result = gate._check_row({"price": -5}, ["Price must be positive"]) + assert result.compliant is False + assert "negative" in result.reasoning + + def test_prompt_contains_rules( + self, gate: QualityGate, mock_engine: MagicMock + ) -> None: + mock_engine.generate.return_value = json.dumps( + {"compliant": True, "reasoning": "ok"} + ) + gate._check_row({"a": 1}, ["Rule A", "Rule B"]) + prompt = mock_engine.generate.call_args[0][0] + assert "Rule A" in prompt + assert "Rule B" in prompt + + +# 
------------------------------------------------------------------ +# evaluate (end-to-end with mocks) +# ------------------------------------------------------------------ + + +class TestEvaluate: + def test_all_pass(self, gate: QualityGate, mock_engine: MagicMock) -> None: + df = pl.DataFrame({"val": [1, 2, 3]}) + mock_engine.generate.return_value = json.dumps( + {"compliant": True, "reasoning": "ok"} + ) + report = gate.evaluate(df, ["val must be positive"]) + assert report["total_rows"] == 3 + assert report["passed_rows"] == 3 + assert report["compliance_rate"] == 1.0 + assert report["failures"] == [] + + def test_partial_failure(self, gate: QualityGate, mock_engine: MagicMock) -> None: + df = pl.DataFrame({"val": [1, -1, 2]}) + + responses = [ + json.dumps({"compliant": True, "reasoning": "ok"}), + json.dumps({"compliant": False, "reasoning": "Negative value"}), + json.dumps({"compliant": True, "reasoning": "ok"}), + ] + mock_engine.generate.side_effect = responses + + report = gate.evaluate(df, ["val must be positive"]) + assert report["total_rows"] == 3 + assert report["passed_rows"] == 2 + assert len(report["failures"]) == 1 + assert "Negative" in report["failures"][0]["reasoning"] + + def test_empty_dataframe(self, gate: QualityGate, mock_engine: MagicMock) -> None: + df = pl.DataFrame({"val": []}) + report = gate.evaluate(df, ["some rule"]) + assert report["total_rows"] == 0 + assert report["compliance_rate"] == 1.0 + + def test_empty_rules_raises(self, gate: QualityGate) -> None: + df = pl.DataFrame({"val": [1]}) + with pytest.raises(ValueError, match="At least one rule"): + gate.evaluate(df, []) + + def test_init_rejects_invalid_batch_size(self, mock_engine: MagicMock) -> None: + with pytest.raises(ValueError, match="batch_size"): + QualityGate(inference_engine=mock_engine, batch_size=0) + + def test_init_rejects_invalid_sample_size(self, mock_engine: MagicMock) -> None: + with pytest.raises(ValueError, match="sample_size"): + 
QualityGate(inference_engine=mock_engine, sample_size=0) + + def test_sampling_limits_rows(self, mock_engine: MagicMock) -> None: + gate = QualityGate( + inference_engine=mock_engine, + batch_size=5, + sample_size=3, + ) + df = pl.DataFrame({"val": list(range(100))}) + mock_engine.generate.return_value = json.dumps( + {"compliant": True, "reasoning": "ok"} + ) + report = gate.evaluate(df, ["val must exist"]) + assert report["total_rows"] == 3