Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 92 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,93 @@ results = refract.search(query, docs, router=MyRouter())

---

## Train a learned router from relevance feedback

Use your own judged queries to learn when each metric should matter more:

```python
from refract.routing import LearnedRouter

queries = [
"how to sort a list in Python",
"neural network architecture",
"vector similarity embedding",
]
relevance = {
0: {0, 16},
1: {1, 8, 15},
2: {3, 11, 19},
}

router = LearnedRouter(["cosine", "bm25", "mahalanobis", "euclidean"])
report = router.fit_from_relevance(
queries=queries,
corpus=docs,
relevance=relevance,
top_k=5,
)

print(report)
print(report.metric_quality)
```

`fit_from_relevance()` automatically:

- Builds query + space features
- Measures how well each metric ranks the relevant documents
- Converts those per-query metric scores into target routing weights
- Trains a small gating network to predict those weights later

---

## Use a trained router

```python
from refract.routing import LearnedRouter

router.save("learned_router.pkl")
trained_router = LearnedRouter.load("learned_router.pkl")

results = refract.search(
"how do I sort things in Python",
docs,
router=trained_router,
)
```

---

## Evaluate learning

You can evaluate the learned router directly, then benchmark it against heuristic routing:

```python
evaluation = trained_router.evaluate_from_relevance(
queries=queries,
corpus=docs,
relevance=relevance,
top_k=5,
)
print(evaluation.router_ndcg_at_k, evaluation.oracle_ndcg_at_k)
```

```python
from refract.benchmark import BenchmarkHarness, CustomDataset

dataset = CustomDataset(
name="my_eval",
queries=queries,
corpus=docs,
relevance=relevance,
)

harness = BenchmarkHarness()
heuristic = harness.run(dataset, compare_cosine_baseline=False)[0]
learned = harness.run(dataset, router=trained_router, compare_cosine_baseline=False)[0]
```

---

## Use as a RAG retrieval step

```python
Expand Down Expand Up @@ -240,7 +327,7 @@ for r in results:
| Mode | When to use | Training required |
|---|---|---|
| `HeuristicRouter` (default) | Always -- good out of the box | No |
| `LearnedRouter` | When you have relevance feedback data | Yes |
| `LearnedRouter` | When you have relevance feedback data and want adaptive routing | Yes |
| `CompositeRouter` | Blend multiple routers | Depends |
| `BaseRouter` subclass | Full custom control | You decide |

Expand Down Expand Up @@ -292,6 +379,8 @@ src/refract/
| [`custom_metric.py`](examples/custom_metric.py) | Plug in your own metric |
| [`compare_cosine.py`](examples/compare_cosine.py) | Side-by-side vs vanilla cosine |
| [`benchmark_demo.py`](examples/benchmark_demo.py) | Evaluation harness demo |
| [`train_learned_router.py`](examples/train_learned_router.py) | Train a learned router from judged queries |
| [`evaluate_learned_router.py`](examples/evaluate_learned_router.py) | Compare heuristic vs learned routing |
| [`vector_db_integration.py`](examples/vector_db_integration.py) | FAISS/Qdrant integration pattern |

---
Expand All @@ -304,8 +393,8 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for development setup and guidelines.

## Roadmap

- `0.1.0` -- Core API, heuristic router, cosine / euclidean / mahalanobis / BM25 **(you are here)**
- `0.2.0` -- Learned router with PyTorch gating network
- `0.1.0` -- Core API, heuristic router, cosine / euclidean / mahalanobis / BM25
- `0.2.0` -- Learned router with relevance-driven training and evaluation **(you are here)**
- `0.3.0` -- BEIR benchmark harness with published results
- `1.0.0` -- Stable API, comprehensive benchmarks, documentation site

Expand Down
75 changes: 75 additions & 0 deletions examples/evaluate_learned_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Evaluate heuristic vs learned routing on a simple train/test split.

Run with: python examples/evaluate_learned_router.py
"""

from __future__ import annotations

import json
from pathlib import Path

from refract.benchmark import BenchmarkHarness, CustomDataset
from refract.routing import LearnedRouter


def load_sample_dataset(
    sample_path: str | Path | None = None,
) -> tuple[list[str], list[str], dict[int, set[int]]]:
    """Load the corpus, queries, and relevance judgments from a JSON file.

    Args:
        sample_path: Location of the JSON payload. When omitted, the loader
            falls back to ``samples/mini_corpus.json`` located one directory
            above the folder that contains this script.

    Returns:
        A ``(corpus, queries, relevance)`` tuple. ``corpus`` is the payload's
        ``"documents"`` value, ``queries`` collects the ``"query"`` value of
        every entry in ``"sample_queries"``, and ``relevance`` maps each
        entry's position in ``"sample_queries"`` to the set built from its
        ``"expected_relevant"`` value.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If one of the keys named above is missing.
    """
    if sample_path is None:
        sample_path = Path(__file__).parent.parent / "samples" / "mini_corpus.json"

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale default encoding.
    with open(sample_path, encoding="utf-8") as f:
        payload = json.load(f)

    corpus = payload["documents"]
    sample_queries = payload["sample_queries"]
    queries = [item["query"] for item in sample_queries]
    relevance = {
        idx: set(item["expected_relevant"]) for idx, item in enumerate(sample_queries)
    }
    return corpus, queries, relevance


if __name__ == "__main__":
    # Script entry point: fit a LearnedRouter on the first three sample
    # queries, then run the benchmark harness on the remaining queries twice,
    # once without an explicit router and once with the fitted one.
    corpus, queries, relevance = load_sample_dataset()

    n_train = 3
    train_queries, test_queries = queries[:n_train], queries[n_train:]
    train_relevance = {pos: relevance[pos] for pos in range(n_train)}
    # Held-out judgments are re-keyed from zero so that each key matches the
    # query's position inside ``test_queries``.
    test_relevance = {
        pos: relevance[pos + n_train] for pos, _ in enumerate(test_queries)
    }

    router = LearnedRouter(["cosine", "bm25", "mahalanobis", "euclidean"])
    training_report = router.fit_from_relevance(
        train_queries,
        corpus,
        train_relevance,
        top_k=5,
    )

    test_dataset = CustomDataset(
        name="mini_corpus_test_split",
        queries=test_queries,
        corpus=corpus,
        relevance=test_relevance,
    )

    harness = BenchmarkHarness()
    # Only the first element of each run's return value is used below.
    heuristic_result = harness.run(test_dataset, compare_cosine_baseline=False)[0]
    learned_result = harness.run(
        test_dataset,
        router=router,
        compare_cosine_baseline=False,
    )[0]

    banner = "=" * 72
    print(banner)
    print("Training Report")
    print(banner)
    print(training_report)
    print()
    print(banner)
    print("Held-Out Evaluation")
    print(banner)
    print(f"{'Router':<12s} {'NDCG@10':>8s} {'Recall@10':>10s} {'MRR':>8s}")
    print("-" * 72)
    # One table row per run, in the same order the runs were executed.
    for label, outcome in (("heuristic", heuristic_result), ("learned", learned_result)):
        print(
            f"{label:<12s} "
            f"{outcome.ndcg_at_10:>8.3f} "
            f"{outcome.recall_at_10:>10.3f} "
            f"{outcome.mrr_score:>8.3f}"
        )
56 changes: 56 additions & 0 deletions examples/train_learned_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Train a learned router from relevance feedback and use it for search.

Run with: python examples/train_learned_router.py
"""

from __future__ import annotations

import json
from pathlib import Path
from tempfile import TemporaryDirectory

import refract
from refract.routing import LearnedRouter


def load_sample_dataset(
    sample_path: str | Path | None = None,
) -> tuple[list[str], list[str], dict[int, set[int]]]:
    """Read documents, queries, and relevance judgments from a JSON file.

    Args:
        sample_path: Path of the JSON payload to read. Defaults to
            ``samples/mini_corpus.json`` one directory above the folder that
            holds this script.

    Returns:
        ``(corpus, queries, relevance)``: the payload's ``"documents"`` value,
        the ``"query"`` value of each ``"sample_queries"`` entry, and a dict
        mapping each entry's index to the set of its ``"expected_relevant"``
        values.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
        KeyError: If a required key is absent from the payload.
    """
    if sample_path is None:
        sample_path = Path(__file__).parent.parent / "samples" / "mini_corpus.json"

    # An explicit encoding keeps decoding independent of the locale default.
    with open(sample_path, encoding="utf-8") as f:
        payload = json.load(f)

    corpus = payload["documents"]
    sample_queries = payload["sample_queries"]
    queries = [item["query"] for item in sample_queries]
    relevance = {
        idx: set(item["expected_relevant"]) for idx, item in enumerate(sample_queries)
    }
    return corpus, queries, relevance


if __name__ == "__main__":
    # Script entry point: fit a LearnedRouter on every sample query, print the
    # fit report, round-trip the router through save/load, then run a search
    # with the reloaded instance.
    docs, queries, relevance = load_sample_dataset()

    metric_names = ["cosine", "bm25", "mahalanobis", "euclidean"]
    router = LearnedRouter(metric_names)
    training_report = router.fit_from_relevance(queries, docs, relevance, top_k=5)

    divider = "=" * 72
    print(divider)
    print("Learned Router Training")
    print(divider)
    print(training_report)
    print("Per-metric quality:", training_report.metric_quality)
    print()

    with TemporaryDirectory() as scratch_dir:
        saved_path = str(Path(scratch_dir) / "learned_router.pkl")
        router.save(saved_path)
        loaded_router = LearnedRouter.load(saved_path)

        # NOTE(review): the search stays inside the temporary-directory scope
        # so the saved file still exists while the reloaded router is in use;
        # confirm against the original nesting.
        results = refract.search(
            "how can I sort a list in Python",
            docs,
            router=loaded_router,
            top_k=3,
        )

    print("Top results with the trained router:")
    for hit in results:
        print(f" {hit.score:.3f} {hit.text}")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dependencies = [
sentence-transformers = ["sentence-transformers>=2.6"]
openai = ["openai>=1.0"]
cohere = ["cohere>=5.0"]
learned = ["torch>=2.0"]
learned = []
benchmark = ["datasets>=2.14"]
dev = [
"pytest>=7.0",
Expand Down
3 changes: 3 additions & 0 deletions src/refract/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from refract.routing.base import BaseRouter
from refract.routing.composite import CompositeRouter
from refract.routing.heuristic import HeuristicRouter
from refract.routing.learned import LearnedRouter, LearnedRouterEvaluation
from refract.search import search, search_batch
from refract.types import (
MetricScore,
Expand All @@ -59,6 +60,8 @@
"CosineMetric",
"EuclideanMetric",
"HeuristicRouter",
"LearnedRouter",
"LearnedRouterEvaluation",
"MahalanobisMetric",
"MetricRegistry",
"MetricScore",
Expand Down
3 changes: 3 additions & 0 deletions src/refract/routing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from refract.routing.base import BaseRouter
from refract.routing.composite import CompositeRouter
from refract.routing.heuristic import HeuristicRouter
from refract.routing.learned import LearnedRouter, LearnedRouterEvaluation

__all__ = [
"BaseRouter",
"CompositeRouter",
"HeuristicRouter",
"LearnedRouter",
"LearnedRouterEvaluation",
]
Loading
Loading