Skip to content

Commit f950653

Browse files
Adding retrieval augmented generation (#12)
* Add RAG module under benchmarking * Imported RAG - first edit to MultiAgentTester * changes to files rag * RAG model working draft * improved format * HTTPclient change - unfixed * more prelim changes to the rag model * initial working prototype revised * added in functions.json and embeddings.json * extracted from scib lib * a better working draft of the rag * more changes to rag * small change * finally no errors! * more changes made * added new extractor class * trying to get scanpy to work * scanpy works * working proto for scanpy and scib-metrics * just added type annotations * trying to resolve json error * decent working ver before signing off * Deleted embeddings.json and functions.json * remove json * making sure file is safe * fixed issue by changing to jsonl * experimenting around with wikipedia lib * fixeS? * added umap with new prompts * fixed umap * reverted back to the technique with urls * fixed error * added some visualization in umap and heatmap * added fixes to rag file * file finalized for 2day * added variations to chunking * created a series of images to test * varying wiki +description contents * diff sizes of wiki page * more variations * trying with a bigger embedding model * trying wiki api * working implementation -not helpful * trying to aggresively remove stuff from the wiki result * trying to aggresively remove more * fixing bugs * trying to make a new ver work * small syntax error * draft * switched out wiki lib for beautiful soup extraction + switched model * working version yayy * rag file * working version of the rag class system * new file for user purposes! 
* changes to user and skeleton ver * added a new function to extract wiki content * attempt to fix * fixes and clean up * fixes and clean ups ongoing * file improvisations with request failsafe * resolved critical errors * testing * new fixes * removed embedding and functions file * fixing shennanigans * moved files * changes to file in type annotations * made more fixes to rag class by adding 1 function for the entire pipeline * took back some fixes that introduced errors * changed text to string type in func_Def * regex for improvements to extracting html * introduced more aggressive regex for cleaning func def * added support for dict objects * type annotations changed + dict incorporated * errors fixed * quick fixes to rag file * more changes for error correction * deleted folder from wrong location * moved locations for rag folder * improper folder placement * moved locations * making changes to the runner.py file system * moved rag * rag + changes to runner * one working version of implementation of rag - using agents * dylan's proposed version of the implementation * attempts at rag implementation * Added rag support to agent system * query function from database by function signature search * changed function definition search to function signature search * fixed rag implementation * may have fixed import, need to consult and fix file * potentially error fix * fixed imports * working with dylans location of rag folder * userrag file deemed unnecessary and deleted * change to file names and locations - more clean up, fixed imports * user rag file changed * Update system_blueprint.json * changes to imports and file structures * finally fixed import situation * tested code * trivial errors * trivial errors * Update RetrievalAugmentedGeneration.py * Fixed rag implementation * small UX fixes * added in new embeddings * embeddings and functions file created - however the search results from wikipedia are not accurate * fixed embedding file structure * need to fix 
wikipedia * restructured embeddings.jsonl to signature and embedding + added in query relevance * Synced New Embeddings and Functions --------- Co-authored-by: djriffle <djriffle1@gmail.com>
1 parent 25c9f29 commit f950653

14 files changed

Lines changed: 712 additions & 20 deletions

cli/extra_tools/RetrievalAugmentedEmbedder.py

Lines changed: 334 additions & 0 deletions
Large diffs are not rendered by default.

cli/extra_tools/embeddings.jsonl

Lines changed: 44 additions & 0 deletions
Large diffs are not rendered by default.

cli/extra_tools/functions.jsonl

Lines changed: 44 additions & 0 deletions
Large diffs are not rendered by default.

cli/olaf/pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ dependencies = [
2222
"jupyter-client", # NOTE: PyPI name has a hyphen
2323
"nbformat",
2424
"typer",
25-
"platformdirs"
25+
"platformdirs",
26+
"sentence_transformers",
27+
"tf_keras"
2628
]
2729

2830
# If you want a command like `olaf …`

cli/olaf/src/olaf/agents/AgentSystem.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# 2. The package-internal directory (for default samples), found relative to this file
1313
PACKAGE_CODE_SAMPLES_DIR = Path(__file__).resolve().parent.parent / "code_samples"
1414

15-
1615
class Command:
1716
"""Represents a command an agent can issue to a neighboring agent."""
1817
def __init__(self, name: str, target_agent: str, description: str):
@@ -23,19 +22,18 @@ def __init__(self, name: str, target_agent: str, description: str):
2322
def __repr__(self) -> str:
2423
return (f"Command(name='{self.name}', target='{self.target_agent}', "
2524
f"desc='{self.description[:30]}...')")
26-
27-
2825
class Agent:
2926
"""Represents a single agent in the system."""
30-
def __init__(self, name: str, prompt: str, commands: Dict[str, Command], code_samples: Dict[str, str]):
27+
def __init__(self, name: str, prompt: str, commands: Dict[str, Command], code_samples: Dict[str, str], is_rag_enabled: bool = False):
3128
self.name = name
3229
self.prompt = prompt
3330
self.commands = commands
3431
self.code_samples = code_samples
32+
self.is_rag_enabled = is_rag_enabled
3533

3634
def __repr__(self) -> str:
3735
sample_keys = list(self.code_samples.keys())
38-
return f"Agent(name='{self.name}', commands={list(self.commands.keys())}, samples={sample_keys})"
36+
return f"Agent(name='{self.name}', commands={list(self.commands.keys())}, samples={sample_keys}, rag_enabled={self.is_rag_enabled})"
3937

4038
def get_full_prompt(self, global_policy=None) -> str:
4139
"""Constructs the full prompt including the global policy and command descriptions."""
@@ -53,8 +51,14 @@ def get_full_prompt(self, global_policy=None) -> str:
5351
full_prompt += f"\n - Target Agent: {command.target_agent}"
5452
full_prompt += "\n\n**YOU MUST USE THESE EXACT COMMANDS TO DELEGATE TASKS. NO OTHER FORMATTING OR COMMANDS ARE ALLOWED.**"
5553

54+
if self.is_rag_enabled:
55+
full_prompt += "\n\nYou can query your specialized knowledge base for more context with the following command:"
56+
full_prompt += f"\n- Command: `query_rag_<function>`"
57+
full_prompt += f"\n - Description: Retrieves relevant information about a specific <function> from your knowledge base. Replace <function> with a concise, descriptive search query (e.g., function names, task you are trying to complete)."
58+
full_prompt += f"\n - Example: `query_rag_<scvi model setup>`"
59+
5660
if self.code_samples:
57-
full_prompt += "\n - Code Samples Available:"
61+
full_prompt += "\n\n - Code Samples Available:"
5862
for sample_name in self.code_samples.keys():
5963
full_prompt += f"\n - `{sample_name}`"
6064

@@ -101,7 +105,6 @@ def load_from_json(cls, file_path: str) -> 'AgentSystem':
101105
user_path = USER_CODE_SAMPLES_DIR / filename
102106
package_path = PACKAGE_CODE_SAMPLES_DIR / filename
103107

104-
# Default to package path, but overwrite if user path exists
105108
path_to_load = None
106109
source_label = ""
107110
if user_path.exists():
@@ -120,11 +123,15 @@ def load_from_json(cls, file_path: str) -> 'AgentSystem':
120123
else:
121124
print(f" ❌ WARNING: Code sample file '{filename}' not found in any location.")
122125

126+
rag_config = agent_data.get("rag", {})
127+
is_rag_enabled = rag_config.get("enabled", False)
128+
123129
agent = Agent(
124130
name=agent_name,
125131
prompt=agent_data['prompt'],
126132
commands=commands,
127-
code_samples=loaded_samples
133+
code_samples=loaded_samples,
134+
is_rag_enabled=is_rag_enabled
128135
)
129136
agents[agent_name] = agent
130137

cli/olaf/src/olaf/agents/create_agent_system.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import os
33
from typing import Dict, Any
44
from pathlib import Path
5-
from platformdirs import PlatformDirs # pip install platformdirs
5+
from platformdirs import PlatformDirs
66
import tempfile
77

88
APP_NAME = "olaf"
9-
APP_AUTHOR = "OpenTechBio" # or your org
9+
APP_AUTHOR = "OpenTechBio"
1010
dirs = PlatformDirs(APP_NAME, APP_AUTHOR)
1111

1212
# Root for user-specific OLAF files. Precedence: env -> platformdirs.
@@ -72,9 +72,24 @@ def define_agents() -> Dict[str, Dict[str, Any]]:
7272
if agent_name in agents:
7373
print(f"{Colors.FAIL}Agent '{agent_name}' already exists. Please use a unique name.{Colors.ENDC}")
7474
continue
75+
7576
prompt = input(f"{Colors.WARNING}Enter the system prompt for '{Colors.OKCYAN}{agent_name}{Colors.WARNING}': {Colors.ENDC}").strip()
76-
agents[agent_name] = {"prompt": prompt, "neighbors": {}, "code_samples": []}
77-
print(f"{Colors.OKGREEN}Agent '{Colors.OKCYAN}{agent_name}{Colors.OKGREEN}' added successfully.{Colors.ENDC}")
77+
78+
# --- New RAG Configuration Section ---
79+
rag_enabled_input = input(f"{Colors.WARNING}Enable Retrieval-Augmented Generation (RAG) for '{Colors.OKCYAN}{agent_name}{Colors.WARNING}'? (y/n): {Colors.ENDC}").strip().lower()
80+
is_rag_enabled = rag_enabled_input == 'y'
81+
82+
# Add the new 'rag' key to the agent's data structure
83+
agents[agent_name] = {
84+
"prompt": prompt,
85+
"neighbors": {},
86+
"code_samples": [],
87+
"rag": {"enabled": is_rag_enabled}
88+
}
89+
90+
rag_status = f"{Colors.OKGREEN}enabled" if is_rag_enabled else f"{Colors.FAIL}disabled"
91+
print(f"{Colors.OKGREEN}Agent '{Colors.OKCYAN}{agent_name}{Colors.OKGREEN}' added successfully (RAG: {rag_status}{Colors.OKGREEN}).{Colors.ENDC}")
92+
7893
print(f"\n{Colors.OKBLUE}--- All Agents Defined ---{Colors.ENDC}")
7994
for name in agents:
8095
print(f"- {Colors.OKCYAN}{name}{Colors.ENDC}")
@@ -162,7 +177,7 @@ def _atomic_write_json(obj: Any, path: Path) -> None:
162177
with tempfile.NamedTemporaryFile("w", delete=False, dir=str(path.parent), prefix=path.stem, suffix=".tmp") as tmp:
163178
json.dump(obj, tmp, indent=2)
164179
tmp_path = Path(tmp.name)
165-
tmp_path.replace(path) # atomic on POSIX; safe on Windows
180+
tmp_path.replace(path)
166181

167182
def save_configuration(global_policy: str, agents_config: Dict[str, Any], output_dir: str) -> None:
168183
if not agents_config:

cli/olaf/src/olaf/agents/integration_system.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
"global_policy": "Always be concise, professional, and helpful. Do not refuse to answer a request unless it is harmful.",
33
"agents": {
44
"master_agent": {
5+
"rag": {
6+
"enabled": true
7+
},
58
"prompt": "You are the master agent. Analyze every user request and delegate the task to the appropriate expert: the general coder for standard single-cell analysis or the integration expert for batch correction and data integration tasks. Respond ONLY with a delegation command.",
69
"neighbors": {
710
"delegate_to_general": {
@@ -16,6 +19,9 @@
1619
},
1720
"general_coder": {
1821
"prompt": "You are the *general scRNA-seq coder*. You handle standard single-cell analysis tasks like data loading, QC, filtering, normalization, and basic plotting using scanpy. You are not an expert in data integration.\n\nExample of a task you would perform:\n```python\nimport scanpy as sc\n\n# Assume 'adata' is a loaded AnnData object\n# Basic QC and filtering\nsc.pp.filter_cells(adata, min_genes=200)\nsc.pp.filter_genes(adata, min_cells=3)\nadata.var['mt'] = adata.var_names.str.startswith('MT-')\nsc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], inplace=True)\n\n# Normalize and find highly variable genes\nsc.pp.normalize_total(adata, target_sum=1e4)\nsc.pp.log1p(adata)\nsc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)\n\n# Run PCA\nsc.tl.pca(adata, svd_solver='arpack')\n\nprint('Standard analysis complete. PCA is in adata.obsm[\"X_pca\"].')\n```",
22+
"rag": {
23+
"enabled": true
24+
},
1925
"neighbors": {
2026
"delegate_to_master": {
2127
"target_agent": "master_agent",
@@ -29,6 +35,9 @@
2935
},
3036
"integration_expert": {
3137
"prompt": "You are the *integration expert*. You specialize in combining multiple single-cell datasets and correcting for batch effects using scvi-tools.\n\nExample of a task you would perform:\n```python\nimport scvi\nimport scanpy as sc\n\n# Assume 'adata' is loaded and preprocessed with a 'batch' column\n# Find highly variable genes across batches for integration\nsc.pp.highly_variable_genes(\n adata,\n n_top_genes=2000,\n subset=True,\n layer='counts',\n flavor='seurat_v3',\n batch_key='batch'\n)\n\n# Set up the AnnData object for the scVI model\nscvi.model.SCVI.setup_anndata(adata, layer='counts', batch_key='batch')\n\n# Create and train the scVI model\nmodel = scvi.model.SCVI(adata, n_layers=2, n_latent=30)\nmodel.train()\n\n# Store the integrated latent representation in the AnnData object\nadata.obsm['X_scVI'] = model.get_latent_representation()\n\nprint('Integration complete. Integrated embedding is in adata.obsm[\"X_scVI\"].')\n``` you remeber to wrap your code in triple backticks and python. Please only include one code block per response. Remeber to keep responses short and to the point.",
38+
"rag": {
39+
"enabled": true
40+
},
3241
"neighbors": {
3342
"delegate_to_master": {
3443
"target_agent": "master_agent",

cli/olaf/src/olaf/agents/system_blueprint.json

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
"agents": {
44
"master_agent": {
55
"prompt": "You are the master agent. Your primary role is to analyze incoming user requests and delegate them to the appropriate specialist agent. You do not perform tasks yourself.",
6+
"rag": {
7+
"enabled": true
8+
},
69
"neighbors": {
710
"delegate_to_coder": {
811
"target_agent": "coder_agent",
@@ -16,14 +19,20 @@
1619
},
1720
"coder_agent": {
1821
"prompt": "You are a specialist single cell RNA coder agent. Your job is to write high-quality, executable code based on the user's request. You do not delegate tasks. The machine you run on has write disabled. You should never save to disk or modify files. Prioritize small step responses and avoid large code dumps.",
22+
"rag": {
23+
"enabled": true
24+
},
1925
"neighbors": {},
2026
"code_samples": [
2127
"load_adata.py"
2228
]
2329
},
2430
"research_agent": {
2531
"prompt": "You are a specialist research agent. You fulfill user requests by finding and synthesizing information from reliable sources. You do not write code or delegate tasks.",
32+
"rag": {
33+
"enabled": true
34+
},
2635
"neighbors": {}
2736
}
2837
}
29-
}
38+
}

cli/olaf/src/olaf/execution/runner.py

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from olaf.config import OLAF_HOME
1818
from olaf.agents.AgentSystem import Agent, AgentSystem
1919
from olaf.core.io_helpers import display, extract_python_code, format_execute_response
20+
from olaf.rag.RetrievalAugmentedGeneration import RetrievalAugmentedGeneration
2021
except ImportError as e:
2122
print(f"Failed to import a required OLAF module: {e}", file=sys.stderr)
2223
sys.exit(1)
@@ -39,6 +40,9 @@ def exec_code(self, code: str, timeout: int) -> dict:
3940
_OUTPUTS_DIR = OLAF_HOME / "runs"
4041
_SNIPPET_DIR = _OUTPUTS_DIR / "snippets"
4142
_LEDGER_PATH = _OUTPUTS_DIR / f"benchmark_history_{datetime.utcnow().strftime('%Y%m%d-%H%M%S')}.jsonl"
43+
_RAG_RE = re.compile(r"query_rag_<([^>]+)>")
44+
RAG = RetrievalAugmentedGeneration()
45+
4246

4347
def _init_paths():
4448
"""Ensure output directories exist before writing."""
@@ -51,6 +55,11 @@ def detect_delegation(msg: str) -> Optional[str]:
5155
m = _DELEG_RE.search(msg)
5256
return f"delegate_to_{m.group(1)}" if m else None
5357

58+
def detect_rag(msg: str) -> Optional[str]:
    """Extract the search text of a ``query_rag_<...>`` command found in *msg*.

    Returns the text captured by the first group of ``_RAG_RE`` (the part
    between the angle brackets), or ``None`` when the pattern does not occur.
    """
    found = _RAG_RE.search(msg)
    if found is None:
        return None
    return found.group(1)
62+
5463
def _dump_code_snippet(run_id: str, code: str) -> str:
5564
"""Write <run_id>.py under outputs/snippets/ and return the relative path."""
5665
snippet_path = _SNIPPET_DIR / f"{run_id}.py"
@@ -69,7 +78,8 @@ def _save_benchmark_record(*, run_id: str, results: dict, meta: dict, code: str
6978
record["code_path"] = _dump_code_snippet(run_id, code)
7079
with _LEDGER_PATH.open("a") as fh:
7180
fh.write(json.dumps(record) + "\n")
72-
81+
82+
7383
# --- Core Runner Functions ---
7484
def run_benchmark(
7585
console: Console,
@@ -187,7 +197,21 @@ def run_agent_session(
187197
break
188198

189199
history.append({"role": "assistant", "content": msg})
190-
display(console, f"assistant ({current_agent.name})", msg)
200+
display(console, f"assistant ({current_agent.name})", msg)
201+
202+
# --- RAG handling ---
203+
query_from_re = detect_rag(msg)
204+
if query_from_re and current_agent.is_rag_enabled:
205+
console.print(f"[yellow]🔍 Triggering RAG query: {query_from_re}[/yellow]")
206+
retrieved_docs = RAG.query(query_from_re)
207+
if retrieved_docs:
208+
console.print(f"[green] RAG query successful. [/green]")
209+
feedback = retrieved_docs
210+
console.print(feedback)
211+
history.append({"role": "system", "content": feedback})
212+
else:
213+
console.print(f"[red] RAG query unsuccessful. [/red]")
214+
191215

192216
cmd = detect_delegation(msg)
193217
if cmd and cmd in current_agent.commands:
@@ -211,8 +235,39 @@ def run_agent_session(
211235
console.print("[cyan]Executing code in sandbox…[/cyan]")
212236
exec_result = sandbox_manager.exec_code(code, timeout=300)
213237
feedback = format_execute_response(exec_result, _OUTPUTS_DIR)
214-
history.append({"role": "user", "content": feedback})
215-
display(console, "user", feedback)
238+
history.append({"role": "assistant", "content": feedback})
239+
display(console, "assistant", feedback)
240+
241+
stderr = exec_result.get('stderr', '')
242+
if stderr and current_agent.is_rag_enabled:
243+
func_error_patterns = [
244+
r"missing \d+ required positional argument", # TypeError: missing argument
245+
r"NameError: name '(\w+)' is not defined", # NameError
246+
r"AttributeError: '.*' object has no attribute '(\w+)'", # missing attribute
247+
r"got an unexpected keyword argument" # wrong keyword argument
248+
]
249+
function_name = ""
250+
retrieved_docs = ""
251+
if any(re.search(pat, stderr) for pat in func_error_patterns):
252+
lines = stderr.strip().splitlines()
253+
if len(lines) >= 2:
254+
code_line = lines[-2].strip() # second-to-last line: code that failed
255+
match = re.search(r'(\w+)\s*\(', code_line)
256+
if match:
257+
function_name = match.group(1)
258+
259+
if function_name:
260+
retrieved_docs = RAG.retrieve_function(function_name)
261+
console.print(f"[yellow]🔍 Missing function detected: {function_name}, function database search...[/yellow]")
262+
if retrieved_docs:
263+
console.print(f"[green] Query successful - Function signature found. [/green]")
264+
feedback += f"\n {function_name} produced an error. The correct function signature for {function_name} is:\n{retrieved_docs}"
265+
console.print(feedback)
266+
history.append({"role": "system", "content": feedback})
267+
continue
268+
else:
269+
print(f"RAG Error Query unsuccessful - Function signature does not exist in the current database.")
270+
216271

217272
if is_auto:
218273
if benchmark_modules:
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import json
2+
import sys
3+
from pathlib import Path
4+
from typing import List, Dict, Optional
5+
from contextlib import redirect_stdout, redirect_stderr
6+
7+
import os
8+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
9+
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
10+
# ── Dependencies ─────────────────────────────────────────────
11+
try:
12+
import re
13+
from sentence_transformers import SentenceTransformer
14+
from rich.console import Console
15+
import matplotlib.pyplot as plt
16+
import numpy as np
17+
18+
except ImportError as e:
19+
print(f"Missing dependency: {e}", file=sys.stderr)
20+
sys.exit(1)
21+
22+
# ── Paths and Constants ─────────────────────────────────────────────
23+
console = Console()
24+
25+
RAG_DIR = Path(__file__).resolve().parent.parent / "rag"
26+
EMBEDDING_FILE = RAG_DIR / "embeddings.jsonl"
27+
FUNCTIONS_FILE = RAG_DIR / "functions.jsonl"
28+
29+
class _LazyEmbeddingModel:
    """Descriptor that builds the sentence-transformer on first access.

    Keeps ``RetrievalAugmentedGeneration.model`` / ``self.model`` usable as
    before while avoiding the model load as a side effect of importing the
    module.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self._model = None

    def __get__(self, instance, owner=None):
        if self._model is None:
            self._model = SentenceTransformer(self.model_name)
        return self._model


class RetrievalAugmentedGeneration():
    """Look up function signatures by substring or by embedding similarity.

    Embeddings are read from ``EMBEDDING_FILE`` and function records from
    ``FUNCTIONS_FILE``; the i-th embedding is treated as belonging to the
    i-th function record.
    """

    # Loaded lazily (on first attribute access) and shared by all instances.
    model = _LazyEmbeddingModel("Qwen/Qwen3-Embedding-0.6B")

    def __init__(self):
        self.embeddings = self.load_embeddings()
        self.functions = self.load_functions()
        # Every text passed to query(), in call order.
        self.queries = []

    def load_embeddings(self) -> List[np.ndarray]:
        """Read one embedding vector per non-blank line of ``EMBEDDING_FILE``.

        Returns an empty list (after logging) when the file is missing or is
        not valid JSONL.
        """
        try:
            with open(EMBEDDING_FILE, "r", encoding="utf-8") as f:
                vectors = []
                for line in f:
                    if not line.strip():
                        continue
                    record = json.loads(line)
                    # NOTE(review): the file may store objects with an
                    # "embedding" field instead of bare arrays - confirm
                    # against the embedder; bare arrays are handled as before.
                    if isinstance(record, dict) and "embedding" in record:
                        record = record["embedding"]
                    vectors.append(np.array(record))
                return vectors
        except FileNotFoundError:
            console.log("[red]Embeddings file not found.")
            return []
        except json.JSONDecodeError:
            console.log("[red]Embeddings file is not valid JSONL.")
            return []

    def load_functions(self) -> List[Dict[str, str]]:
        """Read one JSON record per non-blank line of ``FUNCTIONS_FILE``.

        Returns an empty list (after logging) when the file is missing or is
        not valid JSONL.
        """
        try:
            with open(FUNCTIONS_FILE, "r", encoding="utf-8") as f:
                return [json.loads(line) for line in f if line.strip()]
        except FileNotFoundError:
            console.log("[red]Functions file not found.")
            return []
        except json.JSONDecodeError:
            console.log("[red]Functions file is not valid JSONL.")
            return []

    @staticmethod
    def cosine_similarity(A: np.ndarray, B: List[np.ndarray]) -> List[float]:
        """Return the cosine similarity between *A* and each vector in *B*.

        A pair involving a zero-norm vector scores ``0.0`` instead of NaN, so
        it can never be selected by ``np.argmax`` over a real match.
        """
        norm_a = np.linalg.norm(A)
        sims = []
        for emb in B:
            denom = norm_a * np.linalg.norm(emb)
            sims.append(float(np.dot(A, emb) / denom) if denom else 0.0)
        return sims

    def retrieve_function(self, name: str) -> Optional[str]:
        """Return the first stored signature that contains *name*.

        Returns ``None`` when *name* is empty (it would match everything) or
        when no record's ``"signature"`` contains it. Records without a
        ``"signature"`` key are skipped.
        """
        if not name:
            return None
        for function in self.functions:
            signature = function.get("signature")
            if signature and name in signature:
                return signature
        return None

    def query(self, text_query: str) -> Optional[str]:
        """Return the signature whose embedding is most similar to *text_query*.

        The query is recorded in ``self.queries``. Returns ``None`` when no
        embeddings are loaded or when the best-matching embedding has no
        corresponding function record.
        """
        self.queries.append(text_query)
        if not self.embeddings:
            console.log("[yellow]No embeddings to compare.")
            return None
        query_embedding = self.model.encode([text_query])[0]
        sims = self.cosine_similarity(query_embedding, self.embeddings)
        idx = int(np.argmax(sims))
        if idx >= len(self.functions):
            # Embeddings and function records are out of sync (e.g. the
            # functions file failed to load); avoid an IndexError.
            console.log("[red]No function record for the best-matching embedding.")
            return None
        return self.functions[idx].get("signature")
79+
80+
# ──────Implementation──────────────────────────────────────────────────────────
81+
82+
if __name__ == "__main__":
83+
rag = RetrievalAugmentedGeneration()
84+
print(rag.query("What is pca"))

0 commit comments

Comments
 (0)