Commit 2a7f089

Merge pull request #11 from OpenTechBio/AddingMultiBenchmarkSupport
Adding multi benchmark support
2 parents 15fe8d0 + 85929c9

4 files changed: 108 additions & 31 deletions

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+# Don't import AutoMetric
+# from AutoMetric import AutoMetric
+import scanpy as sc
+import celltypist
+from celltypist import models
+from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection
+import scanpy.external as sce
+
+class CellTypingMetric(AutoMetric):
+    """
+    Computes cell typing with CellTypist, then evaluates the result using
+    the Benchmarker metrics from scib-metrics' benchmark module.
+    """
+    def metric(self, adata):
+        # scib_metrics Benchmarker
+        bm = Benchmarker(
+            adata,
+            batch_key="batch",
+            label_key="majority_voting",
+            bio_conservation_metrics=BioConservation(nmi_ari_cluster_labels_leiden=True),
+            batch_correction_metrics=None,
+            embedding_obsm_keys=["X_pca", "X_pca_harmony"],  # need to check the embedding exists -> if it doesn't, perform PCA first
+            n_jobs=6,
+        )
+        bm.prepare()
+        bm.benchmark()
+        bm.plot_results_table(min_max_scale=False)
+        return bm.get_results()
+
+CellTypingMetric().run(adata)
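
Note: the module deliberately leaves `AutoMetric` unimported; the tester executes benchmark modules in a sandbox where the base class and an `adata` object are expected to already be in scope. For orientation, a minimal sketch of the contract the subclass appears to rely on — only the `metric`/`run` entry points are visible in this commit, so the body below is an illustrative assumption, not the repository's actual AutoMetric.py:

# Hypothetical sketch of the AutoMetric base class this module subclasses.
# The metric()/run() names come from this diff; the implementation is assumed.
class AutoMetric:
    def metric(self, adata):
        """Subclasses compute and report their metric on the AnnData object."""
        raise NotImplementedError

    def run(self, adata):
        # Thin wrapper so each benchmark module can end with
        # `SomeMetric().run(adata)`, as CellTypingMetric does above.
        return self.metric(adata)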

benchmarking/core/io_helpers.py

Lines changed: 42 additions & 2 deletions
@@ -13,8 +13,6 @@
 import base64
 from datetime import datetime
 
-
-
 def extract_python_code(txt: str) -> Optional[str]:
     """Return the *first* fenced code block, or None if absent.
 
@@ -107,6 +105,48 @@ def collect_resources(console, sandbox_sources_dir) -> List[Tuple[Path, str]]:
         res.append((path, f"{sandbox_sources_dir}/{path.name}"))
     return res
 
+def load_bp_json(console) -> Path:
+    """
+    Try to find a blueprint JSON file in common locations.
+    If multiple are found, prompt the user to choose one or enter a path manually.
+    """
+    search_paths = [
+        Path.home() / "Olaf" / "benchmarking" / "agents",
+        Path.cwd() / "benchmarking" / "agents",
+        Path.cwd() / "agents",
+    ]
+
+    # Search for JSON files in the known paths
+    for path in search_paths:
+        if path.is_dir():
+            json_files = list(path.rglob("*.json"))
+            if json_files:
+                choices = [f.name for f in json_files]
+                choices.append("manual")
+
+                choice = Prompt.ask(
+                    "Select a blueprint JSON file or choose 'manual' to enter a path",
+                    choices=choices,
+                    default="system_blueprint.json",
+                )
+                if choice == "manual":
+                    break  # fall through to the manual-path section
+                selected = path / choice
+                if selected.exists():
+                    return selected
+
+    # Manual fallback
+    user_path = Prompt.ask(
+        "Please provide an absolute or relative path to the blueprint JSON",
+        default="~/system_blueprint.json",
+    )
+    bp = Path(user_path).expanduser()
+
+    if not bp.exists():
+        console.print(f"[red]Blueprint file not found at: {bp}[/red]")
+        sys.exit(1)
+
+    return bp
 
 def format_execute_response(resp: dict, output_dir) -> str:
     lines = ["Code execution result:"]
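
A quick way to exercise the new helper outside the testers — a sketch, not repository code, assuming `load_bp_json` is importable as the diffs below show and that callers pass in a rich `Console`:

# Hypothetical smoke test for load_bp_json's search-then-prompt behavior.
from rich.console import Console
from benchmarking.core.io_helpers import load_bp_json

console = Console()
bp = load_bp_json(console)  # prompts when more than one candidate JSON is found
console.print(f"Using blueprint: {bp}")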

benchmarking/prompt_testing/MultiAgentAutoTester.py

Lines changed: 2 additions & 1 deletion
@@ -60,6 +60,7 @@
     collect_resources,
     get_initial_prompt,
     format_execute_response,
+    load_bp_json
 )
 from benchmarking.core.sandbox_management import (
     init_docker,
@@ -153,7 +154,7 @@ def _save_benchmark_record(*, run_id: str, results: dict, meta: dict, code: str
 # ===========================================================================
 def load_agent_system() -> Tuple[AgentSystem, Agent, str]:
     """Load the agent system from a JSON blueprint."""
-    bp = Path(Prompt.ask("Blueprint JSON", default="system_blueprint.json")).expanduser()
+    bp = load_bp_json(console)
     if not bp.exists():
         console.print(f"[red]Blueprint {bp} not found.")
         sys.exit(1)
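
Worth noting: `load_bp_json` already prints an error and calls `sys.exit(1)` when no blueprint can be found, so the `if not bp.exists()` guard retained here can no longer fire; it survives as a defensive double-check rather than the primary validation.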

benchmarking/prompt_testing/MultiAgentTester.py

Lines changed: 34 additions & 28 deletions
@@ -68,7 +68,8 @@
     select_dataset,
     collect_resources,
     get_initial_prompt,
-    format_execute_response
+    format_execute_response,
+    load_bp_json
 )
 from benchmarking.core.sandbox_management import (
     init_docker,
@@ -115,10 +116,7 @@
 # ===========================================================================
 
 def load_agent_system() -> Tuple[AgentSystem, Agent, str]:
-    bp = Path(Prompt.ask("Blueprint JSON", default="system_blueprint.json")).expanduser()
-    if not bp.exists():
-        console.print(f"[red]Blueprint {bp} not found.")
-        sys.exit(1)
+    bp = load_bp_json(console)
     system = AgentSystem.load_from_json(str(bp))
     driver_name = Prompt.ask("Driver agent", choices=list(system.agents.keys()), default=list(system.agents)[0])
     driver = system.get_agent(driver_name)
@@ -150,7 +148,7 @@ def api_alive(url: str, tries: int = 10) -> bool:
 # 3 · Interactive loop
 # ===========================================================================
 
-def run(agent_system: AgentSystem, agent: Agent, roster_instr: str, dataset: Path, metadata: dict, resources: List[Tuple[Path, str]], benchmark_module: Optional[Path] = None):
+def run(agent_system: AgentSystem, agent: Agent, roster_instr: str, dataset: Path, metadata: dict, resources: List[Tuple[Path, str]], benchmark_modules: Optional[list[Path]] = None):
     mgr = _BackendManager()
     console.print(f"Launching sandbox ({backend})…")
 
@@ -245,7 +243,7 @@ def build_system(a: Agent) -> str:
         display(console, "user", feedback)
 
     def input_loop():
-        if benchmark_module:
+        if benchmark_modules:
            console.print("\n[bold]Next message (blank = continue, 'benchmark' to run benchmarks, 'exit' to quit):[/bold]")
         else:
             console.print("\n[bold]Next message (blank = continue, 'exit' to quit):[/bold]")
@@ -255,8 +253,9 @@ def input_loop():
             user_in = "exit"
         if user_in.lower() in {"exit", "quit"}:
             return "break"
-        if user_in.lower() == "benchmark" and benchmark_module:
-            run_benchmark(mgr, benchmark_module)
+        if user_in.lower() == "benchmark" and benchmark_modules:
+            for benchmark_module in benchmark_modules:
+                run_benchmark(mgr, benchmark_module)
             input_loop()  # Recurse to continue the loop after benchmarks
         if user_in:
             history.append({"role": "user", "content": user_in})
@@ -273,7 +272,7 @@ def input_loop():
 # 4 · Benchmarking
 # ===========================================================================
 
-def get_benchmark_module(console: Console, parent_dir: Path) -> Optional[Path]:
+def get_benchmark_modules(console: Console, parent_dir: Path) -> Optional[list[Path]]:
     """
     Prompts the user to select a benchmark module from the available ones.
     Returns the path to the selected module or None if no selection is made.
@@ -283,31 +282,38 @@ def get_benchmark_module(console: Console, parent_dir: Path) -> Optional[Path]:
         console.print("[red]No benchmarks directory found.[/red]")
         return None
 
-    modules = list(benchmark_dir.glob("*.py"))
+    module_names = list(benchmark_dir.glob("*.py"))
     # remove AutoMetric.py from modules (it is the base class)
-    modules = [m for m in modules if m.name != "AutoMetric.py"]
-    if not modules:
+    module_names = [m for m in module_names if m.name != "AutoMetric.py"]
+    if not module_names:
         console.print("[red]No benchmark modules found.[/red]")
         return None
 
     console.print("\n[bold]Available benchmark modules:[/bold]")
-    for i, mod in enumerate(modules, start=1):
+    for i, mod in enumerate(module_names, start=1):
         console.print(f"{i}. {mod.name}")
-
-    choice = Prompt.ask("Select a benchmark module by number (or press Enter to skip)", default="")
-    if not choice:
+    console.print(f"{len(module_names)+1}. Select All")
+    choices = Prompt.ask("Select benchmark modules by number (e.g. 1 2 3 or 1,2,3) (or press Enter to skip)", default="")
+    choices = re.split(r'[,\s]+', choices)  # user input must be separated by commas or spaces
+
+    if not choices or choices == ['']:
         return None
 
-    try:
-        index = int(choice) - 1
-        if 0 <= index < len(modules):
-            return modules[index]
-        else:
-            console.print("[red]Invalid selection.[/red]")
-            return None
-    except ValueError:
-        console.print("[red]Invalid input. Please enter a number.[/red]")
-        return None
+    modules = []
+    for choice in choices:
+        try:
+            index = int(choice) - 1
+            if index == len(module_names):  # handles the "Select All" case
+                return module_names
+            elif 0 <= index < len(module_names):
+                modules.append(module_names[index])
+            else:
+                console.print("[red]Invalid selection.[/red]")
+                return None
+        except ValueError:
+            console.print("[red]Invalid input. Please enter a number.[/red]")
            return None
+    return modules
 
 def run_benchmark(mgr, benchmark_module: str):
     """
@@ -377,9 +383,9 @@ def main():
 
     sys, drv, roster = load_agent_system()
     dp, meta = select_dataset(console, DATASETS_DIR)
-    benchmark_module = get_benchmark_module(console, PARENT_DIR)
+    benchmark_modules = get_benchmark_modules(console, PARENT_DIR)
     res = collect_resources(console, SANDBOX_RESOURCES_DIR)
-    run(sys, drv, roster, dp, meta, res, benchmark_module)
+    run(sys, drv, roster, dp, meta, res, benchmark_modules)
 
 
 if __name__ == "__main__":
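
The new multi-select accepts numbers separated by commas or whitespace, plus a trailing "Select All" menu entry (this presumes `re` is already imported at the top of MultiAgentTester.py, which the diff does not show). A standalone sketch of the same parsing logic, with a helper name of ours rather than the repository's:

import re
from typing import Optional

def parse_selection(raw: str, n_modules: int) -> Optional[list]:
    """Mirror the parsing added to get_benchmark_modules.

    Returns None for blank or invalid input, all indices when the user
    picks the "Select All" slot (n_modules + 1), otherwise the zero-based
    indices of the chosen modules.
    """
    tokens = re.split(r"[,\s]+", raw)
    if not tokens or tokens == [""]:
        return None
    picked = []
    for tok in tokens:
        try:
            index = int(tok) - 1
        except ValueError:              # non-numeric input aborts, as in the diff
            return None
        if index == n_modules:          # the "Select All" menu entry
            return list(range(n_modules))
        if not 0 <= index < n_modules:  # out-of-range selection aborts too
            return None
        picked.append(index)
    return picked

print(parse_selection("1,3", 4))  # [0, 2]
print(parse_selection("5", 4))    # [0, 1, 2, 3] via the "Select All" entry
print(parse_selection("", 4))     # None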
