diff --git a/README.md b/README.md index 2166f23..b8e23af 100644 --- a/README.md +++ b/README.md @@ -197,6 +197,93 @@ results = refract.search(query, docs, router=MyRouter()) --- +## Train a learned router from relevance feedback + +Use your own judged queries to learn when each metric should matter more: + +```python +from refract.routing import LearnedRouter + +queries = [ + "how to sort a list in Python", + "neural network architecture", + "vector similarity embedding", +] +relevance = { + 0: {0, 16}, + 1: {1, 8, 15}, + 2: {3, 11, 19}, +} + +router = LearnedRouter(["cosine", "bm25", "mahalanobis", "euclidean"]) +report = router.fit_from_relevance( + queries=queries, + corpus=docs, + relevance=relevance, + top_k=5, +) + +print(report) +print(report.metric_quality) +``` + +`fit_from_relevance()` automatically: + +- Builds query + space features +- Measures how well each metric ranks the relevant documents +- Converts those per-query metric scores into target routing weights +- Trains a small gating network to predict those weights later + +--- + +## Use a trained router + +```python +from refract.routing import LearnedRouter + +router.save("learned_router.pkl") +trained_router = LearnedRouter.load("learned_router.pkl") + +results = refract.search( + "how do I sort things in Python", + docs, + router=trained_router, +) +``` + +--- + +## Evaluate learning + +You can evaluate the learned router directly, then benchmark it against heuristic routing: + +```python +evaluation = trained_router.evaluate_from_relevance( + queries=queries, + corpus=docs, + relevance=relevance, + top_k=5, +) +print(evaluation.router_ndcg_at_k, evaluation.oracle_ndcg_at_k) +``` + +```python +from refract.benchmark import BenchmarkHarness, CustomDataset + +dataset = CustomDataset( + name="my_eval", + queries=queries, + corpus=docs, + relevance=relevance, +) + +harness = BenchmarkHarness() +heuristic = harness.run(dataset, compare_cosine_baseline=False)[0] +learned = harness.run(dataset, router=trained_router, compare_cosine_baseline=False)[0] +``` + +--- + ## Use as a RAG retrieval step ```python @@ -240,7 +327,7 @@ for r in results: | Mode | When to use | Training required | |---|---|---| | `HeuristicRouter` (default) | Always -- good out of the box | No | -| `LearnedRouter` | When you have relevance feedback data | Yes | +| `LearnedRouter` | When you have relevance feedback data and want adaptive routing | Yes | | `CompositeRouter` | Blend multiple routers | Depends | | `BaseRouter` subclass | Full custom control | You decide | @@ -292,6 +379,8 @@ src/refract/ | [`custom_metric.py`](examples/custom_metric.py) | Plug in your own metric | | [`compare_cosine.py`](examples/compare_cosine.py) | Side-by-side vs vanilla cosine | | [`benchmark_demo.py`](examples/benchmark_demo.py) | Evaluation harness demo | +| [`train_learned_router.py`](examples/train_learned_router.py) | Train a learned router from judged queries | +| [`evaluate_learned_router.py`](examples/evaluate_learned_router.py) | Compare heuristic vs learned routing | | [`vector_db_integration.py`](examples/vector_db_integration.py) | FAISS/Qdrant integration pattern | --- @@ -304,8 +393,8 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for development setup and guidelines. ## Roadmap -- `0.1.0` -- Core API, heuristic router, cosine / euclidean / mahalanobis / BM25 **(you are here)** -- `0.2.0` -- Learned router with PyTorch gating network +- `0.1.0` -- Core API, heuristic router, cosine / euclidean / mahalanobis / BM25 +- `0.2.0` -- Learned router with relevance-driven training and evaluation **(you are here)** - `0.3.0` -- BEIR benchmark harness with published results - `1.0.0` -- Stable API, comprehensive benchmarks, documentation site diff --git a/examples/evaluate_learned_router.py b/examples/evaluate_learned_router.py new file mode 100644 index 0000000..6a2bf99 --- /dev/null +++ b/examples/evaluate_learned_router.py @@ -0,0 +1,75 @@ +"""Evaluate heuristic vs learned routing on a simple train/test split. + +Run with: python examples/evaluate_learned_router.py +""" + +from __future__ import annotations + +import json +from pathlib import Path + +from refract.benchmark import BenchmarkHarness, CustomDataset +from refract.routing import LearnedRouter + + +def load_sample_dataset() -> tuple[list[str], list[str], dict[int, set[int]]]: + sample_path = Path(__file__).parent.parent / "samples" / "mini_corpus.json" + with open(sample_path) as f: + payload = json.load(f) + + corpus = payload["documents"] + queries = [item["query"] for item in payload["sample_queries"]] + relevance = { + idx: set(item["expected_relevant"]) for idx, item in enumerate(payload["sample_queries"]) + } + return corpus, queries, relevance + + +if __name__ == "__main__": + corpus, queries, relevance = load_sample_dataset() + train_queries = queries[:3] + train_relevance = {idx: relevance[idx] for idx in range(3)} + + test_queries = queries[3:] + test_relevance = {idx: relevance[idx + 3] for idx in range(len(test_queries))} + + router = LearnedRouter(["cosine", "bm25", "mahalanobis", "euclidean"]) + training_report = router.fit_from_relevance(train_queries, corpus, train_relevance, top_k=5) + + test_dataset = CustomDataset( + name="mini_corpus_test_split", + queries=test_queries, + corpus=corpus, + relevance=test_relevance, + ) + + harness = BenchmarkHarness() + heuristic_result = harness.run(test_dataset, compare_cosine_baseline=False)[0] + learned_result = harness.run( + test_dataset, + router=router, + compare_cosine_baseline=False, + )[0] + + print("=" * 72) + print("Training Report") + print("=" * 72) + print(training_report) + print() + print("=" * 72) + print("Held-Out Evaluation") + print("=" * 72) + print(f"{'Router':<12s} {'NDCG@10':>8s} {'Recall@10':>10s} {'MRR':>8s}") + print("-" * 72) + print( + f"{'heuristic':<12s} " + f"{heuristic_result.ndcg_at_10:>8.3f} " + f"{heuristic_result.recall_at_10:>10.3f} " + f"{heuristic_result.mrr_score:>8.3f}" + ) + print( + f"{'learned':<12s} " + f"{learned_result.ndcg_at_10:>8.3f} " + f"{learned_result.recall_at_10:>10.3f} " + f"{learned_result.mrr_score:>8.3f}" + ) diff --git a/examples/train_learned_router.py b/examples/train_learned_router.py new file mode 100644 index 0000000..c4cba4d --- /dev/null +++ b/examples/train_learned_router.py @@ -0,0 +1,56 @@ +"""Train a learned router from relevance feedback and use it for search. + +Run with: python examples/train_learned_router.py +""" + +from __future__ import annotations + +import json +from pathlib import Path +from tempfile import TemporaryDirectory + +import refract +from refract.routing import LearnedRouter + + +def load_sample_dataset() -> tuple[list[str], list[str], dict[int, set[int]]]: + sample_path = Path(__file__).parent.parent / "samples" / "mini_corpus.json" + with open(sample_path) as f: + payload = json.load(f) + + corpus = payload["documents"] + queries = [item["query"] for item in payload["sample_queries"]] + relevance = { + idx: set(item["expected_relevant"]) for idx, item in enumerate(payload["sample_queries"]) + } + return corpus, queries, relevance + + +if __name__ == "__main__": + docs, queries, relevance = load_sample_dataset() + + router = LearnedRouter(["cosine", "bm25", "mahalanobis", "euclidean"]) + evaluation = router.fit_from_relevance(queries, docs, relevance, top_k=5) + + print("=" * 72) + print("Learned Router Training") + print("=" * 72) + print(evaluation) + print("Per-metric quality:", evaluation.metric_quality) + print() + + with TemporaryDirectory() as tmp_dir: + router_path = Path(tmp_dir) / "learned_router.pkl" + router.save(str(router_path)) + loaded_router = LearnedRouter.load(str(router_path)) + + results = refract.search( + "how can I sort a list in Python", + docs, + router=loaded_router, + top_k=3, + ) + + print("Top results with the trained router:") + for result in results: + print(f" {result.score:.3f} {result.text}") diff --git a/pyproject.toml b/pyproject.toml index 0337474..c5cdeef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ sentence-transformers = ["sentence-transformers>=2.6"] openai = ["openai>=1.0"] cohere = ["cohere>=5.0"] -learned = ["torch>=2.0"] +learned = [] benchmark = ["datasets>=2.14"] dev = [ "pytest>=7.0", diff --git a/src/refract/__init__.py b/src/refract/__init__.py index 1f330fa..8980667 100644 --- a/src/refract/__init__.py +++ b/src/refract/__init__.py @@ -34,6 +34,7 @@ from refract.routing.base import BaseRouter from refract.routing.composite import CompositeRouter from refract.routing.heuristic import HeuristicRouter +from refract.routing.learned import LearnedRouter, LearnedRouterEvaluation from refract.search import search, search_batch from refract.types import ( MetricScore, @@ -59,6 +60,8 @@ "CosineMetric", "EuclideanMetric", "HeuristicRouter", + "LearnedRouter", + "LearnedRouterEvaluation", "MahalanobisMetric", "MetricRegistry", "MetricScore", diff --git a/src/refract/routing/__init__.py b/src/refract/routing/__init__.py index 93823aa..738f917 100644 --- a/src/refract/routing/__init__.py +++ b/src/refract/routing/__init__.py @@ -3,9 +3,12 @@ from refract.routing.base import BaseRouter from refract.routing.composite import CompositeRouter from refract.routing.heuristic import HeuristicRouter +from refract.routing.learned import LearnedRouter, LearnedRouterEvaluation __all__ = [ "BaseRouter", "CompositeRouter", "HeuristicRouter", + "LearnedRouter", + "LearnedRouterEvaluation", ] diff --git a/src/refract/routing/learned.py b/src/refract/routing/learned.py index 60bd4bb..dc94b0b 100644 --- a/src/refract/routing/learned.py +++ b/src/refract/routing/learned.py @@ -1,58 +1,39 @@ -"""Learned router — small gating network for metric weight prediction. - -Uses a lightweight MLP to predict metric weights from query and space -features. Requires PyTorch (optional dependency). -""" +"""Learned router — train metric routing from relevance feedback.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any +import dataclasses +import pickle +from typing import TYPE_CHECKING, cast import numpy as np - +from sklearn.neural_network import MLPRegressor +from sklearn.preprocessing import StandardScaler + +from refract.analysis.query_analyzer import analyze_query +from refract.analysis.space_analyzer import analyze_space +from refract.benchmark.eval_metrics import ndcg_at_k +from refract.metrics.bm25 import BM25Metric +from refract.metrics.cosine import CosineMetric from refract.routing.base import BaseRouter +from refract.search import _build_tfidf_vectors, _resolve_metrics if TYPE_CHECKING: - from refract.types import QueryProfile, SpaceProfile - + from collections.abc import Sequence -def _try_import_torch() -> Any: - """Lazily import torch with a helpful error message.""" - try: - import torch + from numpy.typing import NDArray - return torch - except ImportError: - raise ImportError( - "LearnedRouter requires PyTorch. Install it with:\n" - " pip install 'refract-search[learned]'\n" - " # or\n" - " pip install torch" - ) from None + from refract.embedders.base import BaseEmbedder + from refract.metrics.base import BaseMetric + from refract.types import QueryProfile, SpaceProfile def _profile_to_features(query_profile: QueryProfile, space_profile: SpaceProfile) -> list[float]: - """Convert profiles to a flat feature vector. - - Features: - - query: token_count, embedding_norm, entropy, type one-hot (4) - - space: n_candidates (log), variance, anisotropy (log), score_spread, density one-hot (3) - - Total: 13 features. - - Args: - query_profile: Query analysis result. - space_profile: Space analysis result. - - Returns: - List of 13 float features. - """ - # Query type one-hot + """Convert profiles into a stable feature vector.""" type_map = {"keyword": 0, "natural_language": 1, "code": 2, "structured": 3} type_idx = type_map.get(query_profile.query_type, 1) type_onehot = [1.0 if i == type_idx else 0.0 for i in range(4)] - # Density one-hot density_map = {"sparse": 0, "medium": 1, "dense": 2} density_idx = density_map.get(space_profile.density, 1) density_onehot = [1.0 if i == density_idx else 0.0 for i in range(3)] @@ -70,19 +51,47 @@ def _profile_to_features(query_profile: QueryProfile, space_profile: SpaceProfil ] -class LearnedRouter(BaseRouter): - """Learned metric weight router using a small MLP. +def _normalize_weights(weights: Sequence[float]) -> NDArray[np.float64]: + """Project a raw vector onto the probability simplex.""" + arr = cast("NDArray[np.float64]", np.asarray(weights, dtype=np.float64)) + arr = cast("NDArray[np.float64]", np.clip(arr, 0.0, None)) + total = float(arr.sum()) + if total < 1e-12: + return cast( + "NDArray[np.float64]", + np.asarray(np.ones_like(arr) / max(len(arr), 1), dtype=np.float64), + ) + return cast("NDArray[np.float64]", np.asarray(arr / total, dtype=np.float64)) + + +@dataclasses.dataclass(frozen=True) +class LearnedRouterEvaluation: + """Summary of learned-router quality on a labeled dataset.""" + + n_queries: int + top_k: int + weight_mae: float + router_ndcg_at_k: float + oracle_ndcg_at_k: float + metric_quality: dict[str, float] + + def __repr__(self) -> str: + return ( + "LearnedRouterEvaluation(" + f"n_queries={self.n_queries}, " + f"weight_mae={self.weight_mae:.3f}, " + f"router_ndcg@{self.top_k}={self.router_ndcg_at_k:.3f}, " + f"oracle_ndcg@{self.top_k}={self.oracle_ndcg_at_k:.3f})" + ) - A gating network ``f_θ(query_features, space_features) → weights`` - trained on relevance feedback data. The output is passed through - softmax to produce a valid weight distribution. - Requires PyTorch. Install with ``pip install 'refract-search[learned]'``. +class LearnedRouter(BaseRouter): + """Trainable router that predicts metric weights from query/space features. - Example: - >>> router = LearnedRouter(["cosine", "bm25", "mahalanobis"]) - >>> router.fit(query_profiles, space_profiles, relevance_labels) - >>> weights = router.route(query_profile, space_profile, metrics) + The router learns against per-query target weight distributions. In the + common case you do not need to construct those targets yourself: + ``fit_from_relevance()`` derives them directly from relevance judgments by + measuring how well each metric ranks the labeled documents. """ name = "learned" @@ -90,28 +99,22 @@ class LearnedRouter(BaseRouter): def __init__( self, metric_names: list[str], - hidden_size: int = 32, + hidden_size: int = 24, + random_state: int = 42, ) -> None: - """Initialize learned router. + if not metric_names: + raise ValueError("LearnedRouter requires at least one metric name.") - Args: - metric_names: List of metric names this router produces weights for. - hidden_size: Hidden layer size in the MLP. - """ - self.torch = _try_import_torch() self.metric_names = list(metric_names) self.hidden_size = hidden_size - - n_features = 13 # from _profile_to_features - n_outputs = len(metric_names) - - self._model = self.torch.nn.Sequential( - self.torch.nn.Linear(n_features, hidden_size), - self.torch.nn.ReLU(), - self.torch.nn.Dropout(0.1), - self.torch.nn.Linear(hidden_size, hidden_size), - self.torch.nn.ReLU(), - self.torch.nn.Linear(hidden_size, n_outputs), + self.random_state = random_state + self._scaler = StandardScaler() + self._model = MLPRegressor( + hidden_layer_sizes=(hidden_size, hidden_size), + activation="relu", + solver="lbfgs", + random_state=random_state, + max_iter=500, ) self._trained = False @@ -119,54 +122,145 @@ def fit( self, query_profiles: list[QueryProfile], space_profiles: list[SpaceProfile], - relevance_labels: list[list[int]], - epochs: int = 50, - learning_rate: float = 1e-3, - ) -> dict[str, list[float]]: - """Train the gating network on relevance feedback data. - - Args: - query_profiles: List of query profiles. - space_profiles: List of space profiles. - relevance_labels: For each query, list of relevant doc indices. - epochs: Number of training epochs. - learning_rate: Adam learning rate. - - Returns: - Training history with per-epoch loss values. - """ - torch = self.torch - - # Build feature matrix - features = [] - for qp, sp in zip(query_profiles, space_profiles): - features.append(_profile_to_features(qp, sp)) - - x_features = torch.tensor(features, dtype=torch.float32) - - # Target: uniform weights as starting point - # In practice, you'd derive optimal weights from relevance data - n = len(self.metric_names) - targets = torch.ones(len(features), n, dtype=torch.float32) / n - - optimizer = torch.optim.Adam(self._model.parameters(), lr=learning_rate) - criterion = torch.nn.KLDivLoss(reduction="batchmean") - - history: dict[str, list[float]] = {"loss": []} - self._model.train() - - for _epoch in range(epochs): - optimizer.zero_grad() - logits = self._model(x_features) - log_probs = torch.nn.functional.log_softmax(logits, dim=-1) - loss = criterion(log_probs, targets) - loss.backward() - optimizer.step() - history["loss"].append(float(loss.item())) + target_weights: list[dict[str, float]], + ) -> LearnedRouter: + """Fit the router from explicit target weight distributions.""" + if not query_profiles or not space_profiles or not target_weights: + raise ValueError( + "fit() requires non-empty query_profiles, space_profiles, and target_weights." + ) + if not (len(query_profiles) == len(space_profiles) == len(target_weights)): + raise ValueError("fit() inputs must have the same length.") + + feature_rows = [ + _profile_to_features(qp, sp) for qp, sp in zip(query_profiles, space_profiles) + ] + targets = np.vstack( + [ + _normalize_weights([target.get(metric, 0.0) for metric in self.metric_names]) + for target in target_weights + ] + ) + scaled = self._scaler.fit_transform(np.asarray(feature_rows, dtype=np.float64)) + self._model.fit(scaled, targets) self._trained = True - self._model.eval() - return history + return self + + def fit_from_relevance( + self, + queries: list[str] | np.ndarray, + corpus: Sequence[str] | np.ndarray, + relevance: dict[int, set[int]], + *, + embedder: BaseEmbedder | None = None, + top_k: int = 10, + target_temperature: float = 0.15, + ) -> LearnedRouterEvaluation: + """Fit the router directly from relevance judgments.""" + query_ids, query_profiles, space_profiles, targets, metric_quality, oracle_ndcg = ( + self._derive_training_targets( + queries=queries, + corpus=corpus, + relevance=relevance, + embedder=embedder, + top_k=top_k, + target_temperature=target_temperature, + ) + ) + self.fit(query_profiles, space_profiles, targets) + return self.evaluate_from_relevance( + queries=queries, + corpus=corpus, + relevance=relevance, + embedder=embedder, + top_k=top_k, + query_ids=query_ids, + expected_targets=targets, + metric_quality=metric_quality, + oracle_ndcg_at_k=oracle_ndcg, + ) + + def evaluate_from_relevance( + self, + queries: list[str] | np.ndarray, + corpus: Sequence[str] | np.ndarray, + relevance: dict[int, set[int]], + *, + embedder: BaseEmbedder | None = None, + top_k: int = 10, + query_ids: list[int] | None = None, + expected_targets: list[dict[str, float]] | None = None, + metric_quality: dict[str, float] | None = None, + oracle_ndcg_at_k: float | None = None, + ) -> LearnedRouterEvaluation: + """Evaluate a trained router on labeled data.""" + if not self._trained: + raise RuntimeError( + "LearnedRouter has not been trained. Call .fit() or .fit_from_relevance() first." + ) + + _, _, all_query_profiles, all_space_profiles, metric_scores = ( + self._prepare_relevance_problem( + queries=queries, + corpus=corpus, + embedder=embedder, + ) + ) + + if expected_targets is None or metric_quality is None or oracle_ndcg_at_k is None: + ( + query_ids, + query_profiles, + space_profiles, + expected_targets, + metric_quality, + oracle_ndcg_at_k, + ) = self._derive_training_targets( + queries=queries, + corpus=corpus, + relevance=relevance, + embedder=embedder, + top_k=top_k, + ) + else: + query_ids = query_ids or sorted(idx for idx, docs in relevance.items() if docs) + query_profiles = [all_query_profiles[idx] for idx in query_ids] + space_profiles = [all_space_profiles[idx] for idx in query_ids] + + errors: list[float] = [] + router_ndcgs: list[float] = [] + + for query_id, query_profile, space_profile, target in zip( + query_ids, query_profiles, space_profiles, expected_targets + ): + relevant = relevance.get(query_id, set()) + if not relevant: + continue + + predicted = self.route(query_profile, space_profile, self.metric_names) + predicted_arr = np.array([predicted.get(metric, 0.0) for metric in self.metric_names]) + target_arr = np.array([target.get(metric, 0.0) for metric in self.metric_names]) + errors.append(float(np.abs(predicted_arr - target_arr).mean())) + + weighted_scores = np.zeros_like(metric_scores[self.metric_names[0]][query_id]) + for metric_name, scores in metric_scores.items(): + weighted_scores = weighted_scores + ( + scores[query_id] * predicted.get(metric_name, 0.0) + ) + + ranking = np.argsort(-weighted_scores)[:top_k].tolist() + router_ndcgs.append(ndcg_at_k(ranking, relevant, top_k)) + + n_queries = len(errors) + return LearnedRouterEvaluation( + n_queries=n_queries, + top_k=top_k, + weight_mae=float(np.mean(errors)) if errors else 0.0, + router_ndcg_at_k=float(np.mean(router_ndcgs)) if router_ndcgs else 0.0, + oracle_ndcg_at_k=oracle_ndcg_at_k, + metric_quality=metric_quality, + ) def route( self, @@ -174,79 +268,257 @@ def route( space_profile: SpaceProfile, available_metrics: list[str], ) -> dict[str, float]: - """Predict metric weights using the trained model. - - Args: - query_profile: Analyzed query characteristics. - space_profile: Analyzed search space geometry. - available_metrics: List of available metric names. - - Returns: - Dictionary mapping metric name to weight (sum = 1.0). - - Raises: - RuntimeError: If the model has not been trained. - """ if not self._trained: raise RuntimeError( "LearnedRouter has not been trained. Call .fit() first, " - "or use HeuristicRouter (the default) which requires no training." + "or use HeuristicRouter when you do not have relevance data." ) - torch = self.torch - - features = _profile_to_features(query_profile, space_profile) - x = torch.tensor([features], dtype=torch.float32) - - with torch.no_grad(): - logits = self._model(x) - probs = torch.nn.functional.softmax(logits, dim=-1) - weights_arr = probs[0].numpy() - - # Map to available metrics - result: dict[str, float] = {} - for i, name in enumerate(self.metric_names): - if name in available_metrics: - result[name] = float(weights_arr[i]) - # Normalize - total = sum(result.values()) - if total > 1e-12: - result = {k: v / total for k, v in result.items()} - - return result + features = np.asarray( + [_profile_to_features(query_profile, space_profile)], dtype=np.float64 + ) + scaled = self._scaler.transform(features) + prediction = self._model.predict(scaled)[0] + normalized = _normalize_weights(prediction) + + weights = { + metric_name: float(normalized[idx]) + for idx, metric_name in enumerate(self.metric_names) + if metric_name in available_metrics + } + total = sum(weights.values()) + if total < 1e-12: + n = len(available_metrics) + return {metric: 1.0 / n for metric in available_metrics} if n > 0 else {} + return {metric: value / total for metric, value in weights.items()} def save(self, path: str) -> None: - """Save the trained model to disk. - - Args: - path: File path to save the model (e.g., "router.pt"). - """ - torch = self.torch - save_data = { - "model_state": self._model.state_dict(), - "metric_names": self.metric_names, - "hidden_size": self.hidden_size, - "trained": self._trained, - } - torch.save(save_data, path) + """Persist the trained router to disk.""" + with open(path, "wb") as f: + pickle.dump( + { + "metric_names": self.metric_names, + "hidden_size": self.hidden_size, + "random_state": self.random_state, + "scaler": self._scaler, + "model": self._model, + "trained": self._trained, + }, + f, + ) @classmethod def load(cls, path: str) -> LearnedRouter: - """Load a trained model from disk. - - Args: - path: File path to the saved model. + """Load a trained router from disk.""" + with open(path, "rb") as f: + data = pickle.load(f) - Returns: - A LearnedRouter with the loaded model. - """ - torch = _try_import_torch() - data = torch.load(path, weights_only=False) router = cls( - metric_names=data["metric_names"], - hidden_size=data["hidden_size"], + metric_names=list(data["metric_names"]), + hidden_size=int(data["hidden_size"]), + random_state=int(data["random_state"]), ) - router._model.load_state_dict(data["model_state"]) - router._trained = data["trained"] - router._model.eval() + router._scaler = data["scaler"] + router._model = data["model"] + router._trained = bool(data["trained"]) return router + + def _derive_training_targets( + self, + *, + queries: list[str] | np.ndarray, + corpus: Sequence[str] | np.ndarray, + relevance: dict[int, set[int]], + embedder: BaseEmbedder | None = None, + top_k: int = 10, + target_temperature: float = 0.15, + ) -> tuple[ + list[int], + list[QueryProfile], + list[SpaceProfile], + list[dict[str, float]], + dict[str, float], + float, + ]: + ( + _query_texts, + _corpus_texts, + query_profiles, + space_profiles, + metric_scores, + ) = self._prepare_relevance_problem( + queries=queries, + corpus=corpus, + embedder=embedder, + ) + + labeled_ids: list[int] = [] + labeled_profiles: list[QueryProfile] = [] + labeled_spaces: list[SpaceProfile] = [] + targets: list[dict[str, float]] = [] + oracle_ndcgs: list[float] = [] + quality_by_metric: dict[str, list[float]] = {name: [] for name in self.metric_names} + + for idx, query_profile in enumerate(query_profiles): + relevant = relevance.get(idx, set()) + if not relevant: + continue + labeled_ids.append(idx) + labeled_profiles.append(query_profile) + labeled_spaces.append(space_profiles[idx]) + + utilities: list[float] = [] + for metric_name in self.metric_names: + scores = metric_scores[metric_name][idx] + ranking = np.argsort(-scores)[:top_k].tolist() + utility = ndcg_at_k(ranking, relevant, top_k) + utilities.append(utility) + quality_by_metric[metric_name].append(utility) + + targets_arr = self._utilities_to_targets(utilities, temperature=target_temperature) + target = { + metric_name: float(targets_arr[m_idx]) + for m_idx, metric_name in enumerate(self.metric_names) + } + targets.append(target) + + oracle_scores = np.zeros_like(metric_scores[self.metric_names[0]][idx]) + for metric_name, weight in target.items(): + oracle_scores = oracle_scores + (metric_scores[metric_name][idx] * weight) + oracle_ranking = np.argsort(-oracle_scores)[:top_k].tolist() + oracle_ndcgs.append(ndcg_at_k(oracle_ranking, relevant, top_k)) + + metric_quality = { + metric_name: float(np.mean(values)) if values else 0.0 + for metric_name, values in quality_by_metric.items() + } + oracle_ndcg = float(np.mean(oracle_ndcgs)) if oracle_ndcgs else 0.0 + return labeled_ids, labeled_profiles, labeled_spaces, targets, metric_quality, oracle_ndcg + + def _prepare_relevance_problem( + self, + *, + queries: list[str] | np.ndarray, + corpus: Sequence[str] | np.ndarray, + embedder: BaseEmbedder | None = None, + ) -> tuple[ + list[str] | None, + list[str] | None, + list[QueryProfile], + list[SpaceProfile], + dict[str, np.ndarray], + ]: + query_texts: list[str] | None = None + query_vecs: np.ndarray + corpus_texts: list[str] | None = None + corpus_vecs: np.ndarray + + if isinstance(corpus, np.ndarray): + corpus_vecs = corpus.astype(np.float64) + else: + corpus_texts = list(corpus) + if embedder is not None: + corpus_vecs = embedder.embed(corpus_texts).astype(np.float64) + else: + if isinstance(queries, np.ndarray): + raise ValueError( + "Vector queries cannot be trained against a text corpus without an embedder." + ) + else: + all_texts = corpus_texts + list(queries) + all_vecs = _build_tfidf_vectors(all_texts) + corpus_vecs = all_vecs[: len(corpus_texts)] + query_vecs = all_vecs[len(corpus_texts) :] + + if isinstance(queries, np.ndarray): + query_vecs = queries.astype(np.float64) + else: + query_texts = list(queries) + if embedder is not None: + query_vecs = embedder.embed(query_texts).astype(np.float64) + elif corpus_texts is None: + raise ValueError( + "Text queries cannot be trained against a vector-only corpus without an embedder." + ) + + metric_instances = _resolve_metrics( + metrics_arg=[*self.metric_names], + corpus_texts=corpus_texts, + corpus_vecs=corpus_vecs, + ) + missing_metrics = [name for name in self.metric_names if name not in metric_instances] + if missing_metrics: + raise ValueError( + f"Unable to prepare metrics for {missing_metrics}. " + "Text metrics like BM25 require a text corpus." + ) + + query_profiles: list[QueryProfile] = [] + space_profiles: list[SpaceProfile] = [] + metric_scores: dict[str, list[np.ndarray]] = {name: [] for name in self.metric_names} + cosine_metric = CosineMetric() + + for idx in range(len(query_vecs)): + query_text = query_texts[idx] if query_texts is not None else None + query_vec = query_vecs[idx] + cosine_scores = cosine_metric.batch_score(query_vec, corpus_vecs) + + query_profile = analyze_query( + query_text=query_text, + query_vector=query_vec, + candidate_vectors=corpus_vecs, + candidate_scores=cosine_scores, + ) + space_profile = analyze_space(corpus_vecs, cosine_scores) + query_profiles.append(query_profile) + space_profiles.append(space_profile) + + for metric_name, metric in metric_instances.items(): + metric_scores[metric_name].append( + self._score_metric( + metric=metric, + query_text=query_text, + query_vector=query_vec, + corpus_vectors=corpus_vecs, + ) + ) + + stacked_scores = { + metric_name: np.stack(score_list, axis=0) + for metric_name, score_list in metric_scores.items() + } + return query_texts, corpus_texts, query_profiles, space_profiles, stacked_scores + + def _score_metric( + self, + *, + metric: BaseMetric, + query_text: str | None, + query_vector: np.ndarray, + corpus_vectors: np.ndarray, + ) -> np.ndarray: + if metric.is_text_metric and isinstance(metric, BM25Metric): + if query_text is None: + return np.zeros(len(corpus_vectors), dtype=np.float64) + return metric.batch_score_text(query_text) + return metric.batch_score(query_vector, corpus_vectors) + + def _utilities_to_targets( + self, + utilities: Sequence[float], + *, + temperature: float, + ) -> NDArray[np.float64]: + raw = cast("NDArray[np.float64]", np.asarray(utilities, dtype=np.float64)) + if np.all(raw <= 1e-12): + return cast( + "NDArray[np.float64]", + np.asarray( + np.ones(len(raw), dtype=np.float64) / max(len(raw), 1), dtype=np.float64 + ), + ) + + scaled = raw / max(temperature, 1e-6) + scaled = scaled - float(np.max(scaled)) + exp_values = np.exp(scaled) + return _normalize_weights(exp_values) diff --git a/tests/unit/test_learned_router.py b/tests/unit/test_learned_router.py new file mode 100644 index 0000000..b955b26 --- /dev/null +++ b/tests/unit/test_learned_router.py @@ -0,0 +1,101 @@ +"""Tests for the learned router.""" + +from __future__ import annotations + +import numpy as np + +import refract +from refract.routing.learned import LearnedRouter +from refract.types import QueryProfile, SpaceProfile + + +def _query_profile(query_type: str, token_count: int, entropy: float) -> QueryProfile: + return QueryProfile( + raw="sample query", + vector=np.ones(8, dtype=np.float64), + query_type=query_type, # type: ignore[arg-type] + token_count=token_count, + embedding_norm=float(np.sqrt(8.0)), + entropy=entropy, + ) + + +def _space_profile(density: str, variance: float, score_spread: float) -> SpaceProfile: + return SpaceProfile( + n_candidates=32, + embedding_dim=8, + variance=variance, + anisotropy=3.5, + density=density, # type: ignore[arg-type] + score_spread=score_spread, + ) + + +class TestLearnedRouter: + def test_fit_route_and_save_load(self, tmp_path) -> None: + router = LearnedRouter(["cosine", "bm25"]) + query_profiles = [ + _query_profile("keyword", 2, 0.6), + _query_profile("natural_language", 6, 2.4), + _query_profile("keyword", 3, 0.8), + _query_profile("natural_language", 8, 2.7), + ] + space_profiles = [ + _space_profile("sparse", 0.25, 0.20), + _space_profile("dense", 0.03, 0.01), + _space_profile("sparse", 0.22, 0.18), + _space_profile("dense", 0.04, 0.02), + ] + targets = [ + {"cosine": 0.1, "bm25": 0.9}, + {"cosine": 0.85, "bm25": 0.15}, + {"cosine": 0.2, "bm25": 0.8}, + {"cosine": 0.9, "bm25": 0.1}, + ] + + router.fit(query_profiles, space_profiles, targets) + keyword_weights = router.route(query_profiles[0], space_profiles[0], ["cosine", "bm25"]) + natural_weights = router.route(query_profiles[1], space_profiles[1], ["cosine", "bm25"]) + + assert abs(sum(keyword_weights.values()) - 1.0) < 1e-6 + assert keyword_weights["bm25"] > keyword_weights["cosine"] + assert natural_weights["cosine"] > natural_weights["bm25"] + + path = tmp_path / "router.pkl" + router.save(str(path)) + restored = LearnedRouter.load(str(path)) + restored_weights = restored.route(query_profiles[0], space_profiles[0], ["cosine", "bm25"]) + assert restored_weights["bm25"] > restored_weights["cosine"] + + def test_fit_from_relevance_and_use_in_search(self) -> None: + docs = [ + "Sort a Python list with the sorted built-in.", + "Python list comprehensions create lists concisely.", + "Neural networks learn with backpropagation.", + "Transformer architectures power modern deep learning.", + "JSON records store structured key value pairs.", + "Tabular data can be queried with structured filters.", + ] + queries = [ + "python sort list", + "neural network training", + '{"city": "Delhi"}', + ] + relevance = { + 0: {0, 1}, + 1: {2, 3}, + 2: {4, 5}, + } + + router = LearnedRouter(["cosine", "bm25", "euclidean"]) + evaluation = router.fit_from_relevance(queries, docs, relevance, top_k=3) + + assert evaluation.n_queries == 3 + assert 0.0 <= evaluation.weight_mae <= 1.0 + assert 0.0 <= evaluation.router_ndcg_at_k <= 1.0 + assert set(evaluation.metric_quality) == {"cosine", "bm25", "euclidean"} + + results = refract.search("sort a list in python", docs, router=router, top_k=2) + assert results[0].text is not None + assert "Python" in results[0].text + assert results[0].provenance.router_name == "learned"