diff --git a/benchmarking/.gitignore b/benchmarking/.gitignore new file mode 100644 index 0000000..f3a16c3 --- /dev/null +++ b/benchmarking/.gitignore @@ -0,0 +1,4 @@ +.env +__pycache__/ +.DS_store +outputs/ \ No newline at end of file diff --git a/benchmarking/Evaluator.py b/benchmarking/Evaluator.py new file mode 100644 index 0000000..7e84b3b --- /dev/null +++ b/benchmarking/Evaluator.py @@ -0,0 +1,307 @@ +import json +import os +import sys +import argparse +from pathlib import Path +from datetime import datetime +import re + +# --- Dependency Imports --- +try: + from dotenv import load_dotenv +except ImportError: + print("Error: python-dotenv library not found. Please install it: pip install python-dotenv", file=sys.stderr) + sys.exit(1) + +try: + from openai import OpenAI, APIError +except ImportError: + print("Error: openai library not found. Please install it: pip install openai", file=sys.stderr) + sys.exit(1) + +# Optional: Use rich for better formatting +try: + from rich.console import Console + from rich.prompt import Prompt, Confirm + from rich.panel import Panel + HAS_RICH = True + console = Console() +except ImportError: + HAS_RICH = False + console = None + # Simple print/input fallback if rich is not installed + class Console: + def print(self, *args, **kwargs): print(*args) + class Prompt: + @staticmethod + def ask(prompt, default=None): + p_text = f"{prompt} " + if default: p_text += f"[{default}] " + return input(p_text).strip() + class Confirm: + @staticmethod + def ask(prompt, default=False): + val = input(f"{prompt} [y/N] " if not default else f"{prompt} [Y/n] ").lower().strip() + if not val: return default + return val == 'y' + class Panel: + def __init__(self, content, title="", border_style=""): self.content=str(content); self.title=title + def __rich_console__(self, console, options): yield self.title; yield self.content + + +# --- Constants --- +SCRIPT_DIR = Path(__file__).parent.resolve() +DEFAULT_INPUT_DIR = SCRIPT_DIR / "outputs" +DEFAULT_OUTPUT_DIR = SCRIPT_DIR / "outputs" # Default to save back into input dir +ENV_FILE = SCRIPT_DIR / ".env" + +# --- Configuration Loading --- +load_dotenv(dotenv_path=ENV_FILE) +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +OPENAI_MODEL = "gpt-4o" # Or your preferred model for evaluation + +if not OPENAI_API_KEY: + if console: console.print(f"[bold red]Error:[/bold red] OPENAI_API_KEY not found in {ENV_FILE}.") + else: print(f"Error: OPENAI_API_KEY not found in {ENV_FILE}.") + sys.exit(1) + +try: + openai_client = OpenAI(api_key=OPENAI_API_KEY) + if console: console.print(f"OpenAI client initialized for model [cyan]{OPENAI_MODEL}[/cyan].") + else: print(f"OpenAI client initialized for model {OPENAI_MODEL}.") +except Exception as e: + if console: console.print(f"[bold red]Error initializing OpenAI client:[/bold red] {e}") + else: print(f"Error initializing OpenAI client: {e}") + sys.exit(1) + +# --- Helper Functions --- + +def format_conversation_for_eval(test_data): + """ Formats the conversation turns into a readable string for the evaluator prompt. """ + if not test_data or "turns" not in test_data: + return "[No conversation turns found]" + + formatted_lines = [] + for turn in test_data.get("turns", []): + role = turn.get("role", "unknown").upper() + content = turn.get("content", "[No content]") + + # Shorten system prompt for brevity in evaluation context if desired + if role == "SYSTEM": + # Extract key parts or just indicate system prompt presence + content = "[System Prompt Provided - see original log for details]" + # Or keep it: content = turn.get("content", "[No content]") + + # Format code execution results more clearly if they are part of user turn + if role == "USER" and content.startswith("Code execution result:"): + # Reformat slightly for clarity + content = content.replace("Code execution result:", "**CODE EXECUTION RESULT:**") + content = content.replace("--- STDOUT ---", "**STDOUT:**") + content = content.replace("--- STDERR ---", "**STDERR:**") + content = content.replace("--------------", "---") # Shorten separator + + formatted_lines.append(f"--- {role} ---") + formatted_lines.append(content) + formatted_lines.append("\n") # Add space between turns + + return "\n".join(formatted_lines) + + +def call_openai_evaluator(conversation_text, context): + """ Sends the formatted conversation to OpenAI for evaluation. """ + evaluator_prompt = f"""You are an expert evaluator assessing the performance of an AI assistant acting as a bioinformatician. +The assistant was given a task related to analyzing a single-cell transcriptomics dataset. +The expected performance level is that of an **entry-level post-graduate bioinformatician**. + +**Dataset Context:** +- Dataset File: {context.get('dataset_file', 'N/A')} +- Key Metadata: {json.dumps(context.get('dataset_metadata', {}), indent=1, default=str)} + +**Task Context:** +- Initial User Prompt: See the first USER prompt below. +- Max Code Attempts Allowed: {context.get('max_code_tries', 'N/A')} + +**Conversation Log:** +{conversation_text} + +**Evaluation Task:** +Based on the conversation log, evaluate the AI assistant's performance. Consider the following: +1. **Correctness:** Was the generated code correct and did it achieve the intended analysis steps? +2. **Efficiency:** Was the approach reasonable? Were there unnecessary steps? +3. **Interpretation:** Did the assistant correctly interpret the results of its code execution? +4. **Planning:** Did the assistant use its allowed code execution attempts effectively towards the goal? +5. **Clarity:** Was the assistant's text explanation clear and accurate? +6. **Overall Skill:** Does the performance align with an entry-level post-graduate bioinformatician? + +**Output Format:** +Please provide your evaluation strictly in the following JSON format ONLY. Do not include any other text before or after the JSON block: +{{ + "grade": , + "comments": "" +}} +""" + + if console: console.print(f"Sending evaluation request for context: {context.get('prompt_id', 'unknown')[:20]}...") + else: print(f"Sending evaluation request for context: {context.get('prompt_id', 'unknown')[:20]}...") + + try: + response = openai_client.chat.completions.create( + model=OPENAI_MODEL, + messages=[ + # Maybe a short system message for the evaluator role itself? + # {"role": "system", "content": "You are an expert evaluator."}, + {"role": "user", "content": evaluator_prompt} + ], + temperature=0.3, # Lower temperature for more deterministic evaluation + response_format={"type": "json_object"}, # Request JSON output + max_tokens=1000 # Adjust as needed for comments length + ) + eval_content = response.choices[0].message.content + if console: console.print("[green]Evaluation received from OpenAI.[/green]") + else: print("Evaluation received from OpenAI.") + + # Attempt to parse the JSON response + try: + eval_json = json.loads(eval_content) + # Validate expected keys + if "grade" in eval_json and "comments" in eval_json: + # Basic type check (can be more robust) + if isinstance(eval_json["grade"], int) and isinstance(eval_json["comments"], str): + return eval_json + else: + raise ValueError("Incorrect data types for 'grade' or 'comments'.") + else: + raise ValueError("Missing 'grade' or 'comments' key in JSON response.") + except (json.JSONDecodeError, ValueError) as e: + if console: console.print(f"[bold red]Error parsing evaluation JSON from OpenAI: {e}[/bold red]") + else: print(f"Error parsing evaluation JSON from OpenAI: {e}") + if console: console.print(f"Raw response content:\n{eval_content}") + else: print(f"Raw response content:\n{eval_content}") + # Return a structured error + return {"grade": -1, "comments": f"Error parsing OpenAI response: {e}\nRaw Content: {eval_content}"} + + except APIError as e: + if console: console.print(f"[bold red]OpenAI API Error during evaluation: {e}[/bold red]") + else: print(f"OpenAI API Error during evaluation: {e}") + return {"grade": -1, "comments": f"OpenAI API Error: {e}"} + except Exception as e: + if console: console.print(f"[bold red]Unexpected error during evaluation call: {e}[/bold red]") + else: print(f"Unexpected error during evaluation call: {e}") + import traceback + traceback.print_exc() + return {"grade": -1, "comments": f"Unexpected Error: {e}"} + + +def process_folder(input_dir_path, output_path): + """Finds JSON files, gets evaluations, and saves them.""" + evaluations = {} + json_files = list(input_dir_path.glob("*.json")) + + if not json_files: + if console: console.print(f"[yellow]No JSON files found in '{input_dir_path}'.[/yellow]") + else: print(f"No JSON files found in '{input_dir_path}'.") + return + + if console: console.print(f"Found {len(json_files)} JSON file(s) to evaluate.") + else: print(f"Found {len(json_files)} JSON file(s) to evaluate.") + + for json_file in json_files: + if console: console.print(f"\n--- Processing: [cyan]{json_file.name}[/cyan] ---") + else: print(f"\n--- Processing: {json_file.name} ---") + try: + with open(json_file, 'r', encoding='utf-8') as f: + results_data = json.load(f) + + # Process each test run within the file (assuming structure {test_id: test_data}) + file_evaluations = {} + for test_id, test_data in results_data.items(): + if not isinstance(test_data, dict) or "context" not in test_data or "turns" not in test_data: + if console: console.print(f"[yellow]Skipping invalid/incomplete data for test ID '{test_id}' in {json_file.name}.[/yellow]") + else: print(f"Skipping invalid/incomplete data for test ID '{test_id}' in {json_file.name}.") + continue + + conversation_text = format_conversation_for_eval(test_data) + context = test_data.get("context", {}) + evaluation = call_openai_evaluator(conversation_text, context) + file_evaluations[test_id] = evaluation # Store evaluation keyed by test_id + + # Store evaluations for this file, keyed by the original filename stem + evaluations[json_file.stem] = file_evaluations + + except json.JSONDecodeError: + if console: console.print(f"[red]Error decoding JSON from {json_file.name}. Skipping.[/red]") + else: print(f"Error decoding JSON from {json_file.name}. Skipping.") + except Exception as e: + if console: console.print(f"[red]Error processing file {json_file.name}: {e}[/red]") + else: print(f"Error processing file {json_file.name}: {e}") + + # --- Save Evaluations --- + if not evaluations: + if console: console.print("[yellow]No evaluations were generated.[/yellow]") + else: print("No evaluations were generated.") + return + + output_path = Path(output_path) # Ensure it's a Path object + + # Check if output is a directory or file + if output_path.suffix == ".json": + # Save all evaluations to a single specified file + output_filename = output_path + if console: console.print(f"\nSaving all evaluations to single file: [cyan]{output_filename}[/cyan]") + else: print(f"\nSaving all evaluations to single file: {output_filename}") + try: + output_path.parent.mkdir(parents=True, exist_ok=True) # Ensure parent dir exists + with open(output_filename, "w", encoding="utf-8") as f: + json.dump(evaluations, f, indent=2) + if console: console.print("[green]Evaluations saved successfully.[/green]") + else: print("Evaluations saved successfully.") + except Exception as e: + if console: console.print(f"[bold red]Error saving aggregated evaluations to {output_filename}:[/bold red] {e}") + else: print(f"Error saving aggregated evaluations to {output_filename}: {e}") + else: + # Save evaluations to individual files in the specified directory + output_dir = output_path + output_dir.mkdir(parents=True, exist_ok=True) # Ensure dir exists + if console: console.print(f"\nSaving evaluations to directory: [cyan]{output_dir}[/cyan]") + else: print(f"\nSaving evaluations to directory: {output_dir}") + for input_stem, file_evals in evaluations.items(): + output_filename = output_dir / f"{input_stem}_eval.json" + try: + with open(output_filename, "w", encoding="utf-8") as f: + json.dump(file_evals, f, indent=2) + if console: console.print(f" Saved: [green]{output_filename.name}[/green]") + else: print(f" Saved: {output_filename.name}") + except Exception as e: + if console: console.print(f" [red]Error saving evaluation for {input_stem}: {e}[/red]") + else: print(f" Error saving evaluation for {input_stem}: {e}") + + +def interactive_loop(): + """Handles the interactive user prompts.""" + if console: console.print("\n--- Agent Benchmark Evaluator ---") + else: print("\n--- Agent Benchmark Evaluator ---") + + # Get input directory + default_input = str(DEFAULT_INPUT_DIR.resolve()) + while True: + if console: input_dir_str = Prompt.ask("Enter path to input folder containing results JSONs", default=default_input) + else: input_dir_str = input(f"Enter path to input folder containing results JSONs [{default_input}]: ").strip() or default_input + + input_dir_path = Path(input_dir_str).resolve() + if input_dir_path.is_dir(): + break + else: + if console: console.print(f"[red]Error: Input path '{input_dir_path}' is not a valid directory.[/red]") + else: print(f"Error: Input path '{input_dir_path}' is not a valid directory.") + + # Get output path (directory or specific file) + default_output = str(input_dir_path) # Default output to input dir + if console: output_path_str = Prompt.ask("Enter output directory or specific .json filename for results", default=default_output) + else: output_path_str = input(f"Enter output directory or specific .json filename for results [{default_output}]: ").strip() or default_output + + process_folder(input_dir_path, output_path_str) + + +# --- Main Execution --- +if __name__ == "__main__": + interactive_loop() diff --git a/benchmarking/OneShotAgentTester.py b/benchmarking/OneShotAgentTester.py new file mode 100644 index 0000000..3084036 --- /dev/null +++ b/benchmarking/OneShotAgentTester.py @@ -0,0 +1,666 @@ +import argparse +import os +import sys +import json +import re +import shlex +import time +from pathlib import Path +import subprocess # Still needed for docker cp (for dataset copy) +import base64 # For decoding image data from API +from datetime import datetime # For timestamp in filename + +# --- Dependency Imports --- +try: + from dotenv import load_dotenv +except ImportError: + print("Error: python-dotenv library not found. Please install it: pip install python-dotenv", file=sys.stderr) + sys.exit(1) + +try: + from openai import OpenAI, APIError +except ImportError: + print("Error: openai library not found. Please install it: pip install openai", file=sys.stderr) + sys.exit(1) + +try: + import requests # For making HTTP requests to the FastAPI service +except ImportError: + print("Error: requests library not found. Please install it: pip install requests", file=sys.stderr) + sys.exit(1) + + +try: + # Assumes benchmarking_sandbox_management.py is in a 'sandbox' subdirectory + # We still need the manager for start/stop and the container name constant + sandbox_dir = os.path.join(os.path.dirname(__file__), 'sandbox') + sys.path.insert(0, sandbox_dir) + from benchmarking_sandbox_management import SandboxManager, CONTAINER_NAME as SANDBOX_CONTAINER_NAME, API_PORT_HOST +except ImportError as e: + print(f"Error: Could not import SandboxManager or constants from {sandbox_dir}.", file=sys.stderr) + print("Ensure benchmarking_sandbox_management.py (FastAPI version) is present in the 'sandbox' directory.", file=sys.stderr) + print(f"Details: {e}", file=sys.stderr) + sys.exit(1) +finally: + if 'benchmarking_sandbox_management' in sys.modules: + sys.path.pop(0) + + +# Optional: Use rich for better formatting +try: + from rich.console import Console + from rich.prompt import Prompt, Confirm + from rich.panel import Panel + from rich.syntax import Syntax + from rich.table import Table + from rich.markdown import Markdown # For potentially displaying markdown output + from rich.text import Text # For handling plain text better + HAS_RICH = True +except ImportError: + HAS_RICH = False + # Simple print/input fallback if rich is not installed + class Console: + def print(self, *args, **kwargs): print(*args) + class Prompt: + @staticmethod + def ask(prompt, choices=None, default=None): + p_text = f"{prompt} " + if choices: choices_str = '/'.join(choices); p_text += f"({choices_str}) " + if default: p_text += f"[{default}] " + return input(p_text).strip() + @staticmethod + def get_input(console, prompt, password=False): + return input(f"{prompt}: ") + class Confirm: + @staticmethod + def ask(prompt, default=False): + val = input(f"{prompt} [y/N] " if not default else f"{prompt} [Y/n] ").lower().strip() + if not val: return default + return val == 'y' + class Panel: + def __init__(self, content, title="", border_style=""): self.content=str(content); self.title=title # Ensure content is string + def __rich_console__(self, console, options): yield self.title; yield self.content + class Syntax: + def __init__(self, code, lexer, theme="", line_numbers=False): self.code = code; self.lexer = lexer + def __rich_console__(self, console, options): yield f"--- Code ({self.lexer}) ---\n{self.code}\n--- End Code ---" + class Table: + def __init__(self, title=""): self._title=title; self._rows=[]; self._columns=[] + def add_column(self, header, style="", justify="left", no_wrap=False): self._columns.append(header) + def add_row(self, *items): + if len(items) != len(self._columns): raise ValueError("Row items != columns") + self._rows.append(items) + def __rich_console__(self, console, options): + yield self._title; + if self._columns: + yield "\t".join(self._columns) + for row in self._rows: yield "\t".join(map(str, row)) + def print_table(self, console): + console.print(self._title) + if self._columns: + col_widths = [len(h) for h in self._columns] + for row in self._rows: + for i, item in enumerate(row): col_widths[i] = max(col_widths[i], len(str(item))) + header = " ".join(f"{h:<{w}}" for h, w in zip(self._columns, col_widths)) + separator = "-" * len(header) + console.print(header); console.print(separator) + for row in self._rows: + console.print(" ".join(f"{str(item):<{w}}" for item, w in zip(row, col_widths))) + # Dummy classes for rich elements not used directly but potentially in display logic + class Markdown: + def __init__(self, content): self.content = content + def __rich_console__(self, console, options): yield f"--- Markdown ---\n{self.content}\n--- End Markdown ---" + class Text: + def __init__(self, content): self.content = content + def __rich_console__(self, console, options): yield self.content + + +# --- Constants --- +SCRIPT_DIR = Path(__file__).parent.resolve() +DATASETS_DIR = SCRIPT_DIR / "datasets" +OUTPUTS_DIR = SCRIPT_DIR / "outputs" # Define output directory +ENV_FILE = SCRIPT_DIR / ".env" +SANDBOX_DATA_PATH = "/home/sandboxuser/data.h5ad" # Where data will be copied inside container +# URL for the FastAPI service running in the container (mapped to host) +API_BASE_URL = f"http://localhost:{API_PORT_HOST}" +EXECUTE_ENDPOINT = f"{API_BASE_URL}/execute" +STATUS_ENDPOINT = f"{API_BASE_URL}/status" + +# --- Configuration Loading --- +console = Console() +load_dotenv(dotenv_path=ENV_FILE) +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") + +if not OPENAI_API_KEY: + console.print(f"[bold red]Error:[/bold red] OPENAI_API_KEY not found in {ENV_FILE}.") + console.print("Please run the 'create_benchmarking_env.sh' script first.") + sys.exit(1) + +try: + openai_client = OpenAI(api_key=OPENAI_API_KEY) +except Exception as e: + console.print(f"[bold red]Error initializing OpenAI client:[/bold red] {e}") + sys.exit(1) + + +# --- Helper Functions --- +def extract_python_code(text): + """Extracts the first Python code block from text.""" + match = re.search(r"```python\s*([\s\S]+?)\s*```", text, re.MULTILINE) + if match: + return match.group(1).strip() + return None + +def display_message(role, content): + """Displays messages with nice formatting.""" + # Simplified display, as results are now processed separately + if role == "system": + console.print(Panel(content, title="SYSTEM PROMPT", border_style="dim blue")) + elif role == "user": + # Check if it's the special code execution result message + if content.startswith("Code execution result:\n"): + # This will now contain formatted output from the API call + console.print(Panel(content, title="CODE EXECUTION RESULT (Sent as User)", border_style="yellow")) + else: + console.print(Panel(content, title="USER (Input Prompt)", border_style="blue")) + elif role == "assistant": + code = extract_python_code(content) + if code: + text_part = re.sub(r"```python\s*([\s\S]+?)\s*```", "", content, count=1).strip() + if text_part: + console.print(Panel(text_part, title="ASSISTANT (Text)", border_style="green")) + if HAS_RICH: + console.print(Panel(Syntax(code, "python", theme="default", line_numbers=True), title="ASSISTANT (Code)", border_style="green")) + else: + console.print(f"--- ASSISTANT (Code) ---\n{code}\n--- End Code ---") + else: + console.print(Panel(content, title="ASSISTANT (Text Only)", border_style="green")) + else: + console.print(f"[bold]{role.upper()}:[/bold]\n{content}") + console.print("-" * 20) # Separator + +def format_api_response_for_llm(response_data): + """Formats the JSON response from the /execute endpoint into a string for the LLM.""" + output_lines = ["Code execution result:"] + final_status = response_data.get("final_status", "unknown") + outputs = response_data.get("outputs", []) + + stdout_lines = [] + stderr_lines = [] + error_info = None + display_items = [] # Store items for potential later display/saving + max_len = 1000 # Max length for stdout/stderr truncation + + for item in outputs: + output_type = item.get("type") + if output_type == "stream": + if item.get("name") == "stdout": + stdout_lines.append(item.get("text", "")) + elif item.get("name") == "stderr": + stderr_lines.append(item.get("text", "")) + elif output_type == "error": + error_info = item # Store the whole error dict + # Add error info to stderr for LLM visibility + stderr_lines.append(f"Error: {item.get('ename', 'UnknownError')}: {item.get('evalue', '')}\n") + stderr_lines.extend(line + '\n' for line in item.get('traceback', [])) + elif output_type == "display_data": + # Indicate that display data was generated + mime_types = list(item.get("data", {}).keys()) + display_items.append(item) # Store for later processing if needed + output_lines.append(f"[Display data generated: {', '.join(mime_types)}]") + # Optionally include plain text representation if available + if 'text/plain' in item.get('data', {}): + stdout_lines.append(item['data']['text/plain'] + '\n') + elif output_type == "execute_result": + # Append plain text representation to stdout + if 'text/plain' in item.get('data', {}): + stdout_lines.append(item['data']['text/plain'] + '\n') + + # Combine stdout + if stdout_lines: + output_lines.append("--- STDOUT ---") + full_stdout = "".join(stdout_lines) + if len(full_stdout) > max_len: + output_lines.append(full_stdout[:max_len] + "\n... (stdout truncated)") + else: + output_lines.append(full_stdout) + output_lines.append("--------------") + else: + output_lines.append("[No standard output]") + + # Combine stderr + if stderr_lines: + output_lines.append("--- STDERR ---") + full_stderr = "".join(stderr_lines) + # --- ADDED STDERR TRUNCATION --- + if len(full_stderr) > max_len: + output_lines.append(full_stderr[:max_len] + "\n... (stderr truncated)") + else: + output_lines.append(full_stderr) + # --- END STDERR TRUNCATION --- + output_lines.append("--------------") + # No need for an else block if stderr is empty, unlike stdout + + output_lines.append(f"Final Status: {final_status}") + + # Optionally save images/plots from display_items here + # Create outputs directory if it doesn't exist + OUTPUTS_DIR.mkdir(parents=True, exist_ok=True) + for i, item in enumerate(display_items): + if item['type'] == 'display_data': + for mime, b64_data in item.get('data', {}).items(): + if mime.startswith('image/'): + try: + image_data = base64.b64decode(b64_data) + ext = mime.split('/')[-1].split('+')[0] # Handle things like image/svg+xml + # Create a more descriptive filename + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = OUTPUTS_DIR / f"output_image_{timestamp}_{i}.{ext}" + with open(filename, "wb") as f: + f.write(image_data) + output_lines.append(f"[Saved image data ({mime}) to {filename}]") + console.print(f"[bold yellow]Saved image data ({mime}) to {filename}[/bold yellow]") + except Exception as e: + output_lines.append(f"[Error processing/saving display data {mime}: {e}]") + console.print(f"[red]Error processing/saving display data {mime}: {e}[/red]") + + return "\n".join(output_lines) + + +# --- Core Logic Functions --- + +def get_agent_prompts(): + """Gets agent prompt(s) based on user input method.""" + # (No changes needed here) + prompts = {} + while True: + console.print("\n[bold cyan]Select Agent Prompt Input Method:[/bold cyan]") + console.print(" 1. Paste prompt directly into the terminal.") + console.print(" 2. Provide path to a single .txt file.") + console.print(" 3. Provide path to a folder containing .txt prompt files.") + choice = Prompt.ask("Enter choice (1/2/3)", choices=["1", "2", "3"], default="1") + if choice == "1": + console.print("Paste your prompt below. Press Ctrl+D (Unix) or Ctrl+Z+Enter (Windows) when done:") + try: + prompt_text = sys.stdin.read().strip() + if prompt_text: prompts["pasted_prompt"] = prompt_text; return prompts + else: console.print("[yellow]No prompt pasted. Please try again.[/yellow]") + except EOFError: console.print("\n[yellow]No prompt pasted. Please try again.[/yellow]") + elif choice == "2": + file_path_str = Prompt.ask("Enter the path to the .txt prompt file") + file_path = Path(file_path_str).resolve() + if file_path.is_file() and file_path.suffix.lower() == ".txt": + try: + prompt_text = file_path.read_text(encoding='utf-8').strip() + if prompt_text: prompts[file_path.stem] = prompt_text; return prompts + else: console.print(f"[yellow]File '{file_path}' is empty.[/yellow]") + except Exception as e: console.print(f"[red]Error reading file '{file_path}': {e}[/red]") + else: console.print(f"[red]Invalid path or not a .txt file: '{file_path_str}'[/red]") + elif choice == "3": + folder_path_str = Prompt.ask("Enter the path to the folder containing .txt prompt files") + folder_path = Path(folder_path_str).resolve() + if folder_path.is_dir(): + txt_files = list(folder_path.glob("*.txt")) + if not txt_files: console.print(f"[yellow]No .txt files found in folder '{folder_path_str}'.[/yellow]"); continue + for file_path in txt_files: + try: + prompt_text = file_path.read_text(encoding='utf-8').strip() + if prompt_text: prompts[file_path.stem] = prompt_text + else: console.print(f"[yellow]Skipping empty file: '{file_path.name}'[/yellow]") + except Exception as e: console.print(f"[red]Error reading file '{file_path.name}': {e}[/red]") + if prompts: console.print(f"Found {len(prompts)} non-empty prompt files."); return prompts + else: console.print("[yellow]No valid, non-empty prompts found in the folder.[/yellow]") + else: console.print(f"[red]Invalid path or not a directory: '{folder_path_str}'[/red]") + + +def select_dataset(): + """Scans datasets directory and prompts user for selection.""" + # (No changes needed here) + if not DATASETS_DIR.is_dir(): + console.print(f"[bold red]Error:[/bold red] Datasets directory not found at '{DATASETS_DIR}'") + console.print("Please ensure datasets are downloaded using 'czi_browser.py download ...'") + return None, None + datasets = [] + for h5ad_path in DATASETS_DIR.glob("*.h5ad"): + json_path = h5ad_path.with_suffix(".json") + if json_path.is_file(): + try: + with open(json_path, 'r', encoding='utf-8') as f: metadata = json.load(f) + datasets.append({ "h5ad_path": h5ad_path, "json_path": json_path, "metadata": metadata, "display_name": metadata.get("dataset_title", h5ad_path.stem)}) + except Exception as e: console.print(f"[yellow]Warning: Could not load metadata for '{h5ad_path.name}': {e}[/yellow]") + else: console.print(f"[yellow]Warning: Missing metadata file for '{h5ad_path.name}'. Skipping.[/yellow]") + if not datasets: + console.print(f"[bold red]Error:[/bold red] No valid datasets found in '{DATASETS_DIR}'") + return None, None + console.print("\n[bold cyan]Available Datasets:[/bold cyan]") + table = Table(title="Select a Dataset") + table.add_column("Index", style="dim", justify="right") + table.add_column("Dataset Title / Filename", style="green") + table.add_column("Cell Count", style="magenta", justify="right") + table.add_column("Organism", style="blue") + for i, ds in enumerate(datasets): + meta = ds["metadata"]; cell_count = meta.get('cell_count', 'N/A'); organism = ", ".join(meta.get('organism', [])) if isinstance(meta.get('organism'), list) else meta.get('organism', 'N/A') + try: cell_count_str = f"{int(cell_count):,}" if cell_count != 'N/A' else 'N/A' + except (ValueError, TypeError): cell_count_str = str(cell_count) + table.add_row(str(i + 1), ds["display_name"], cell_count_str, organism) + if HAS_RICH: console.print(table) + else: table.print_table(console) + while True: + choice_str = Prompt.ask(f"Enter the index of the dataset to use (1-{len(datasets)})") + try: + choice_idx = int(choice_str) - 1 + if 0 <= choice_idx < len(datasets): + selected_ds = datasets[choice_idx] + console.print(f"Selected dataset: [green]{selected_ds['display_name']}[/green]") + return selected_ds["h5ad_path"], selected_ds["metadata"] + else: console.print(f"[red]Invalid index. Please enter a number between 1 and {len(datasets)}.[/red]") + except ValueError: console.print("[red]Invalid input. Please enter a number.[/red]") + +def get_code_tries(): + """Prompts user for the number of code execution attempts.""" + # (No changes needed here) + while True: + tries_str = Prompt.ask("Enter the maximum number of code execution attempts for the agent", default="5") + try: + tries = int(tries_str) + if tries > 0: return tries + else: console.print("[red]Number of tries must be positive.[/red]") + except ValueError: console.print("[red]Invalid input. Please enter an integer.[/red]") + +def check_api_status(max_retries=5, delay=2): + """Checks if the FastAPI service is responsive.""" + console.print(f"Checking API status at {STATUS_ENDPOINT}...") + for attempt in range(max_retries): + try: + response = requests.get(STATUS_ENDPOINT, timeout=5) # Short timeout for status check + response.raise_for_status() # Raise exception for bad status codes (4xx or 5xx) + data = response.json() + if data.get("status") == "ok": + console.print("[green]API service is responsive.[/green]") + return True + else: + console.print(f"[yellow]API status endpoint returned unexpected data: {data}[/yellow]") + except requests.exceptions.ConnectionError: + console.print(f"[yellow]API connection failed (attempt {attempt+1}/{max_retries}). Retrying in {delay}s...[/yellow]") + except requests.exceptions.Timeout: + console.print(f"[yellow]API status check timed out (attempt {attempt+1}/{max_retries}). Retrying in {delay}s...[/yellow]") + except requests.exceptions.RequestException as e: + console.print(f"[red]API status check error (attempt {attempt+1}/{max_retries}): {e}[/red]") + # Don't retry immediately on other request errors + break + time.sleep(delay) + console.print("[bold red]API service did not become responsive.[/bold red]") + return False + + +def run_agent_test(agent_prompt_id, agent_prompt, dataset_h5ad_path, dataset_metadata, max_code_tries): + """Runs a single agent test loop using the FastAPI kernel service.""" + console.print(f"\n[bold cyan]----- Starting Test: '{agent_prompt_id}' ----- [/bold cyan]") + console.print(f"Dataset: [green]{dataset_metadata.get('dataset_title', dataset_h5ad_path.stem)}[/green]") + console.print(f"Max Code Tries: [yellow]{max_code_tries}[/yellow]") + + sandbox_manager = None + conversation_history = [] + code_tries_left = max_code_tries + # Add metadata to the conversation start for saving context + initial_context = { + "prompt_id": agent_prompt_id, + "dataset_file": str(dataset_h5ad_path.name), + "dataset_metadata": dataset_metadata, + "max_code_tries": max_code_tries, + "start_time": datetime.now().isoformat() + } + # Store the raw API responses alongside the conversation turns + full_conversation_data = {"context": initial_context, "turns": []} + + + try: + # 1. Initialize Manager and Start Sandbox Container with API service + console.print("\nInitializing Sandbox Manager...") + sandbox_manager = SandboxManager() # Manager now just handles container lifecycle + console.print("Starting sandbox container with API service...") + if not sandbox_manager.start_container(): # start_container now waits briefly + console.print("[bold red]Failed to start sandbox container. Aborting test.[/bold red]") + return None # Return None if setup fails + + # 1b. Check if API is responsive + if not check_api_status(): + console.print("[bold red]API service failed to start or respond. Aborting test.[/bold red]") + # Attempt cleanup + sandbox_manager.stop_container(remove=True) + return None # Return None if setup fails + + # 2. Copy Dataset to Sandbox (Still needed) + console.print(f"Copying dataset '{dataset_h5ad_path.name}' to sandbox ({SANDBOX_DATA_PATH})...") + # Ensure container name constant is correct + copy_command = ['docker', 'cp', str(dataset_h5ad_path), f"{SANDBOX_CONTAINER_NAME}:{SANDBOX_DATA_PATH}"] + try: + # Use subprocess.run, check for errors + result = subprocess.run(copy_command, check=False, capture_output=True, text=True) + if result.returncode != 0: + console.print(f"[bold red]Error copying dataset to container:[/bold red]") + console.print(f"Command: {' '.join(copy_command)}") + console.print(f"Return Code: {result.returncode}") + console.print(f"Stderr: {result.stderr}") + console.print(f"Stdout: {result.stdout}") + # Decide if this is fatal + raise subprocess.CalledProcessError(result.returncode, copy_command, output=result.stdout, stderr=result.stderr) + else: + console.print("[green]Dataset copied successfully.[/green]") + except subprocess.CalledProcessError as e: + console.print(f"[bold red]Dataset copy failed. Aborting test.[/bold red]") + raise # Re-raise to be caught by outer try/except for cleanup + + # 3. Prepare Initial Agent Message + system_message_content = f"""You are an AI assistant tasked with analyzing a single-cell transcriptomics dataset. +Your goal is to characterize this dataset based on its metadata and by generating Python code to be executed. +The dataset file is located inside the execution environment at: {SANDBOX_DATA_PATH} +Standard libraries like pandas, numpy, scipy, scikit-learn, and anndata should be available. +Variables and imports **persist** between your code executions within this session. + +Dataset Metadata: +{json.dumps(dataset_metadata, indent=2)} + +You have a maximum of {max_code_tries} attempts to generate Python code blocks for execution. +When you want to execute code, enclose it **only** in a single triple-backtick block with the language specified as python, like this: +```python +# Your analysis code here. Imports and variables persist. +# Example: Load data in the first turn: +import anndata as ad +adata = ad.read_h5ad('{SANDBOX_DATA_PATH}') +print(adata.shape) + +# Example: Use adata in a later turn: +print(adata.obs['cell_type'].value_counts()) + +# Example: Generate a plot (it will be captured if possible) +# import matplotlib.pyplot as plt +# plt.figure() +# plt.scatter(adata.obsm['X_umap'][:,0], adata.obsm['X_umap'][:,1]) +# plt.title('UMAP Plot') +# plt.show() # Or savefig +``` +I will execute the code you provide and return the results (stdout, stderr, errors, and potentially image data). Use the results to inform your next step. +Focus on providing meaningful characterizations and insights based on the data and metadata. Plan your {max_code_tries} code executions wisely. Start by loading the data. + +While you can generate plots, please prioritize investigating via text as you do not have the ability to understand images. +""" + user_message_content = agent_prompt + + # Store initial messages for saving later + full_conversation_data["turns"].append({"role": "system", "content": system_message_content}) + full_conversation_data["turns"].append({"role": "user", "content": user_message_content}) + + # Prepare history for the API call (needs to be in the format OpenAI expects) + conversation_history = [ + {"role": "system", "content": system_message_content}, + {"role": "user", "content": user_message_content} + ] + display_message("system", system_message_content) + display_message("user", user_message_content) + + # 4. Agent Interaction Loop + while code_tries_left > 0: + console.print(f"\n[bold]Sending request to OpenAI... (Code tries left: {code_tries_left})[/bold]") + api_call_successful = False + response_data = None # To store API response for saving + try: + response = openai_client.chat.completions.create( + model="gpt-4o", # Or your preferred model + messages=conversation_history, + temperature=0.7, + ) + assistant_message = response.choices[0].message + assistant_content = assistant_message.content + api_call_successful = True # Mark OpenAI call as successful + + # Add assistant message to both histories + conversation_history.append({"role": "assistant", "content": assistant_content}) + full_conversation_data["turns"].append({"role": "assistant", "content": assistant_content}) + display_message("assistant", assistant_content) + + # 5. Check for and Execute Code via API + agent_code = extract_python_code(assistant_content) + if agent_code: + console.print(f"\n[bold cyan]Executing Code via API (Attempt {max_code_tries - code_tries_left + 1}/{max_code_tries})...[/bold cyan]") + code_tries_left -= 1 + user_feedback_content = "[Code execution failed or API unreachable]" # Default feedback + execution_api_response = None # Store raw API response + + try: + payload = {"code": agent_code, "timeout": 120} + headers = {"Content-Type": "application/json"} + api_response = requests.post(EXECUTE_ENDPOINT, json=payload, headers=headers, timeout=130) + api_response.raise_for_status() + execution_api_response = api_response.json() # Store successful response + user_feedback_content = format_api_response_for_llm(execution_api_response) + + except requests.exceptions.RequestException as e: + console.print(f"[bold red]API Request Error during execution: {e}[/bold red]") + error_detail = str(e) + if e.response is not None: + console.print(f"Response Status: {e.response.status_code}") + error_detail = e.response.text + try: # Try to get detail from JSON + detail_json = e.response.json().get("detail", error_detail) + error_detail = f"API Error {e.response.status_code}: {detail_json}" + except json.JSONDecodeError: + error_detail = f"API Error {e.response.status_code}: {e.response.text}" + user_feedback_content = f"Code execution result:\n[{error_detail}]" + # Store error info instead of successful response + execution_api_response = {"error": error_detail, "status_code": e.response.status_code if e.response else None} + # break # Decide if API errors should stop the loop + + # Append execution result back to conversation history for LLM + conversation_history.append({"role": "user", "content": user_feedback_content}) + # Store user feedback and API response in the full data log + full_conversation_data["turns"].append({ + "role": "user", + "content": user_feedback_content, + "api_response": execution_api_response # Add raw API response here + }) + display_message("user", user_feedback_content) # Display formatted results + + if code_tries_left == 0: + console.print("[bold yellow]Maximum code execution attempts reached.[/bold yellow]") + break + + else: # No code found in assistant response + console.print("[yellow]No code block found in assistant's response this turn.[/yellow]") + # Add a placeholder turn to keep track + full_conversation_data["turns"].append({"role": "user", "content": "[No code executed this turn]"}) + + + except APIError as e: + console.print(f"[bold red]OpenAI API Error:[/bold red] {e}") + if hasattr(e, 'body') and e.body: console.print(f"Error Body: {e.body}") + # Store error in results + full_conversation_data["error"] = f"OpenAI API Error: {e}" + break # Stop test on OpenAI error + except Exception as e: + console.print(f"[bold red]Error during agent interaction: {e}[/bold red]") + import traceback + traceback.print_exc() # Print traceback for unexpected errors + full_conversation_data["error"] = f"Agent Interaction Error: {e}\n{traceback.format_exc()}" + break # Stop test on other errors + + console.print(f"\n[bold cyan]----- Test Finished: '{agent_prompt_id}' ----- [/bold cyan]") + # Return the detailed conversation data including context and API responses + return full_conversation_data + + except Exception as e: + console.print(f"[bold red]An error occurred during test setup or execution for '{agent_prompt_id}':[/bold red] {e}") + import traceback + traceback.print_exc() + # Return error information if setup failed + return {"context": initial_context, "error": f"Setup/Execution Error: {e}\n{traceback.format_exc()}"} + finally: + # 6. Stop and Cleanup Sandbox + if sandbox_manager: + console.print("\nStopping sandbox container...") + if not sandbox_manager.stop_container(remove=True): + console.print("[yellow]Warning: Could not cleanly stop/remove sandbox container.[/yellow]") + +def main(): + parser = argparse.ArgumentParser(description="Run AI agent benchmarks against datasets in a sandbox (API Mode).") + parser.add_argument( + "--output-dir", type=str, default="outputs", + help="Directory to save results JSON file (default: outputs)" + ) + args = parser.parse_args() + + # Use Path object for output directory + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) # Create output dir if needed + + console.print("[bold blue]Welcome to the One-Shot Agent Tester (API Mode)![/bold blue]") + + agent_prompts = get_agent_prompts() + if not agent_prompts: console.print("[red]No agent prompts provided. Exiting.[/red]"); sys.exit(1) + + dataset_h5ad_path, dataset_metadata = select_dataset() + if not dataset_h5ad_path or not dataset_metadata: console.print("[red]No dataset selected or available. Exiting.[/red]"); sys.exit(1) + + max_code_tries = get_code_tries() + + # Dictionary to hold results for all prompts run in this session + all_results = {} + + for prompt_id, prompt_text in agent_prompts.items(): + # Run the test and get the detailed conversation data + test_result_data = run_agent_test( + prompt_id, + prompt_text, + dataset_h5ad_path, + dataset_metadata, + max_code_tries + ) + # Store the result under the prompt ID + all_results[prompt_id] = test_result_data + + if len(agent_prompts) > 1: + if not Confirm.ask(f"\nTest for '{prompt_id}' finished. Continue with the next agent prompt?", default=True): + console.print("[yellow]Aborting remaining tests.[/yellow]"); break + console.print("\n" + "="*40 + "\n"); time.sleep(1) # Separator and pause + + console.print("\n[bold blue]All specified agent tests have concluded.[/bold blue]") + + # --- Save all results to a single JSON file --- + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + # Include dataset name stem in filename for clarity + dataset_stem = dataset_h5ad_path.stem if dataset_h5ad_path else "unknown_dataset" + output_filename = output_dir / f"benchmark_results_{dataset_stem}_{timestamp}.json" + + console.print(f"Saving all results to [cyan]{output_filename}[/cyan]...") + try: + with open(output_filename, "w", encoding="utf-8") as f: + # Use default=str to handle potential non-serializable objects like Path + json.dump(all_results, f, indent=2, default=str) + console.print("[green]Results saved successfully.[/green]") + except TypeError as e: + console.print(f"[bold red]Error: Failed to serialize results to JSON:[/bold red] {e}") + console.print("Check if non-serializable objects (like Path) are in the results data.") + except Exception as e: + console.print(f"[bold red]Error saving results to {output_filename}:[/bold red] {e}") + +if __name__ == "__main__": + main() diff --git a/benchmarking/PromptEvolver.py b/benchmarking/PromptEvolver.py new file mode 100644 index 0000000..0328a96 --- /dev/null +++ b/benchmarking/PromptEvolver.py @@ -0,0 +1,691 @@ +import argparse +import os +import sys +import json +import re +import shlex +import time +from pathlib import Path +import subprocess # Still needed for docker cp (for dataset copy) +import base64 # For decoding image data from API +from datetime import datetime +import copy # For deep copying conversation history + +# --- Dependency Imports --- +try: + from dotenv import load_dotenv +except ImportError: + print("Error: python-dotenv library not found. Please install it: pip install python-dotenv", file=sys.stderr) + sys.exit(1) + +try: + from openai import OpenAI, APIError +except ImportError: + print("Error: openai library not found. Please install it: pip install openai", file=sys.stderr) + sys.exit(1) + +try: + import requests # For making HTTP requests to the FastAPI service +except ImportError: + print("Error: requests library not found. Please install it: pip install requests", file=sys.stderr) + sys.exit(1) + +# Assume sandbox manager is in a 'sandbox' subdirectory relative to this script +try: + sandbox_dir = os.path.join(os.path.dirname(__file__), 'sandbox') + sys.path.insert(0, sandbox_dir) + # Import manager and constants needed for running the sandbox + from benchmarking_sandbox_management import SandboxManager, CONTAINER_NAME as SANDBOX_CONTAINER_NAME, API_PORT_HOST +except ImportError as e: + print(f"Error: Could not import SandboxManager or constants from {sandbox_dir}.", file=sys.stderr) + print("Ensure benchmarking_sandbox_management.py (FastAPI version) is present in the 'sandbox' directory.", file=sys.stderr) + print(f"Details: {e}", file=sys.stderr) + sys.exit(1) +finally: + # Clean up sys.path modification + if 'benchmarking_sandbox_management' in sys.modules and sandbox_dir in sys.path: + # Check if the path is still the one we added before removing + if sys.path[0] == sandbox_dir: + sys.path.pop(0) + else: # If paths changed unexpectedly, search and remove + try: + sys.path.remove(sandbox_dir) + except ValueError: + pass # Path wasn't there + +# Optional: Use rich for better formatting +try: + from rich.console import Console + from rich.prompt import Prompt, Confirm + from rich.panel import Panel + from rich.syntax import Syntax + from rich.table import Table + HAS_RICH = True + console = Console() +except ImportError: + HAS_RICH = False + console = None + # Simple print/input fallback if rich is not installed + class Console: + def print(self, *args, **kwargs): print(*args) + class Prompt: + @staticmethod + def ask(prompt, default=None): + p_text = f"{prompt} " + if default: p_text += f"[{default}] " + return input(p_text).strip() + class Confirm: + @staticmethod + def ask(prompt, default=False): + val = input(f"{prompt} [y/N] " if not default else f"{prompt} [Y/n] ").lower().strip() + if not val: return default + return val == 'y' + class Panel: + def __init__(self, content, title="", border_style=""): self.content=str(content); self.title=title + def __rich_console__(self, console, options): yield self.title; yield self.content + class Syntax: + def __init__(self, code, lexer, theme="", line_numbers=False): self.code = code; self.lexer = lexer + def __rich_console__(self, console, options): yield f"--- Code ({self.lexer}) ---\n{self.code}\n--- End Code ---" + class Table: # Basic fallback Table class + def __init__(self, title=""): self._title=title; self._rows=[]; self._columns=[] + def add_column(self, header, style="", justify="left", no_wrap=False): self._columns.append(header) + def add_row(self, *items): + if len(items) != len(self._columns): raise ValueError("Row items != columns") + self._rows.append(items) + def __rich_console__(self, console, options): + yield self._title; + if self._columns: + yield "\t".join(self._columns) + for row in self._rows: yield "\t".join(map(str, row)) + def print_table(self, console): # Custom print method if rich not available + console.print(self._title) + if self._columns: + col_widths = [len(h) for h in self._columns] + for row in self._rows: + for i, item in enumerate(row): col_widths[i] = max(col_widths[i], len(str(item))) + header = " ".join(f"{h:<{w}}" for h, w in zip(self._columns, col_widths)) + separator = "-" * len(header) + console.print(header); console.print(separator) + for row in self._rows: + console.print(" ".join(f"{str(item):<{w}}" for item, w in zip(row, col_widths))) + + +# --- Constants --- +SCRIPT_DIR = Path(__file__).parent.resolve() +DATASETS_DIR = SCRIPT_DIR / "datasets" +OUTPUTS_DIR = SCRIPT_DIR / "outputs" # Default output directory for evolution logs +ENV_FILE = SCRIPT_DIR / ".env" +SANDBOX_DATA_PATH = "/home/sandboxuser/data.h5ad" # Path inside container +API_BASE_URL = f"http://localhost:{API_PORT_HOST}" +EXECUTE_ENDPOINT = f"{API_BASE_URL}/execute" +STATUS_ENDPOINT = f"{API_BASE_URL}/status" + +# --- Configuration Loading --- +load_dotenv(dotenv_path=ENV_FILE) +OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") +# Define models for different roles +AGENT_MODEL = "gpt-4o" # Model for the agent being tested +EVALUATOR_MODEL = "gpt-4o" # Model for evaluating the agent's performance +EVOLVER_MODEL = "gpt-4o" # Model for evolving the prompt + +if not OPENAI_API_KEY: + if console: console.print(f"[bold red]Error:[/bold red] OPENAI_API_KEY not found in {ENV_FILE}.") + else: print(f"Error: OPENAI_API_KEY not found in {ENV_FILE}.") + sys.exit(1) + +try: + openai_client = OpenAI(api_key=OPENAI_API_KEY) + if console: console.print(f"OpenAI client initialized.") + else: print(f"OpenAI client initialized.") +except Exception as e: + if console: console.print(f"[bold red]Error initializing OpenAI client:[/bold red] {e}") + else: print(f"Error initializing OpenAI client: {e}") + sys.exit(1) + +# --- Helper Functions (Adapted from Tester/Evaluator) --- + +def extract_python_code(text): + """Extracts the first Python code block from text.""" + if text is None: return None + match = re.search(r"```python\s*([\s\S]+?)\s*```", text, re.MULTILINE) + if match: + return match.group(1).strip() + return None + +def select_dataset(): + """Scans datasets directory and prompts user for selection.""" + # (Copied from OneShotAgentTester.py - requires DATASETS_DIR constant) + if not DATASETS_DIR.is_dir(): + console.print(f"[bold red]Error:[/bold red] Datasets directory not found at '{DATASETS_DIR}'") + return None, None + datasets = [] + for h5ad_path in DATASETS_DIR.glob("*.h5ad"): + json_path = h5ad_path.with_suffix(".json") + if json_path.is_file(): + try: + with open(json_path, 'r', encoding='utf-8') as f: metadata = json.load(f) + datasets.append({ "h5ad_path": h5ad_path, "json_path": json_path, "metadata": metadata, "display_name": metadata.get("dataset_title", h5ad_path.stem)}) + except Exception as e: console.print(f"[yellow]Warning: Could not load metadata for '{h5ad_path.name}': {e}[/yellow]") + else: console.print(f"[yellow]Warning: Missing metadata file for '{h5ad_path.name}'. Skipping.[/yellow]") + if not datasets: + console.print(f"[bold red]Error:[/bold red] No valid datasets found in '{DATASETS_DIR}'") + return None, None + console.print("\n[bold cyan]Available Datasets:[/bold cyan]") + table = Table(title="Select a Dataset") + table.add_column("Index", style="dim", justify="right") + table.add_column("Dataset Title / Filename", style="green") + table.add_column("Cell Count", style="magenta", justify="right") + table.add_column("Organism", style="blue") + for i, ds in enumerate(datasets): + meta = ds["metadata"]; cell_count = meta.get('cell_count', 'N/A'); organism = ", ".join(meta.get('organism', [])) if isinstance(meta.get('organism'), list) else meta.get('organism', 'N/A') + try: cell_count_str = f"{int(cell_count):,}" if cell_count != 'N/A' else 'N/A' + except (ValueError, TypeError): cell_count_str = str(cell_count) + table.add_row(str(i + 1), ds["display_name"], cell_count_str, organism) + if HAS_RICH: console.print(table) + else: table.print_table(console) + while True: + choice_str = Prompt.ask(f"Enter the index of the dataset to use (1-{len(datasets)})") + try: + choice_idx = int(choice_str) - 1 + if 0 <= choice_idx < len(datasets): + selected_ds = datasets[choice_idx] + console.print(f"Selected dataset: [green]{selected_ds['display_name']}[/green]") + return selected_ds["h5ad_path"], selected_ds["metadata"] + else: console.print(f"[red]Invalid index. Please enter a number between 1 and {len(datasets)}.[/red]") + except ValueError: console.print("[red]Invalid input. Please enter a number.[/red]") + +def check_api_status(max_retries=5, delay=2): + """Checks if the FastAPI service is responsive.""" + # (Copied from OneShotAgentTester.py) + console.print(f"Checking API status at {STATUS_ENDPOINT}...") + for attempt in range(max_retries): + try: + response = requests.get(STATUS_ENDPOINT, timeout=5) + response.raise_for_status() + data = response.json() + if data.get("status") == "ok": + console.print("[green]API service is responsive.[/green]") + return True + else: + console.print(f"[yellow]API status endpoint returned unexpected data: {data}[/yellow]") + except requests.exceptions.ConnectionError: + console.print(f"[yellow]API connection failed (attempt {attempt+1}/{max_retries}). Retrying in {delay}s...[/yellow]") + except requests.exceptions.Timeout: + console.print(f"[yellow]API status check timed out (attempt {attempt+1}/{max_retries}). Retrying in {delay}s...[/yellow]") + except requests.exceptions.RequestException as e: + console.print(f"[red]API status check error (attempt {attempt+1}/{max_retries}): {e}[/red]") + break + time.sleep(delay) + console.print("[bold red]API service did not become responsive.[/bold red]") + return False + +def format_api_response_for_llm(response_data): + """Formats the JSON response from the /execute endpoint into a string for the LLM.""" + # (Copied from OneShotAgentTester.py - simplified image handling) + output_lines = ["Code execution result:"] + final_status = response_data.get("final_status", "unknown") + outputs = response_data.get("outputs", []) + stdout_lines = [] + stderr_lines = [] + max_len = 1500 # Slightly shorter truncation for evaluator context + + for item in outputs: + output_type = item.get("type") + if output_type == "stream": + if item.get("name") == "stdout": stdout_lines.append(item.get("text", "")) + elif item.get("name") == "stderr": stderr_lines.append(item.get("text", "")) + elif output_type == "error": + stderr_lines.append(f"Error: {item.get('ename', 'UnknownError')}: {item.get('evalue', '')}\n") + stderr_lines.extend(line + '\n' for line in item.get('traceback', [])) + elif output_type == "display_data": + mime_types = list(item.get("data", {}).keys()) + output_lines.append(f"[Display data generated: {', '.join(mime_types)}]") + if 'text/plain' in item.get('data', {}): stdout_lines.append(item['data']['text/plain'] + '\n') + elif output_type == "execute_result": + if 'text/plain' in item.get('data', {}): stdout_lines.append(item['data']['text/plain'] + '\n') + + if stdout_lines: + output_lines.append("--- STDOUT ---") + full_stdout = "".join(stdout_lines) + output_lines.append(full_stdout[:max_len] + ("\n... (stdout truncated)" if len(full_stdout) > max_len else "")) + output_lines.append("--------------") + else: output_lines.append("[No standard output]") + + if stderr_lines: + output_lines.append("--- STDERR ---") + full_stderr = "".join(stderr_lines) + output_lines.append(full_stderr[:max_len] + ("\n... (stderr truncated)" if len(full_stderr) > max_len else "")) + output_lines.append("--------------") + + output_lines.append(f"Final Status: {final_status}") + return "\n".join(output_lines) + +# --- ADDED FUNCTION DEFINITION --- +def format_conversation_for_eval(test_data): + """ Formats the conversation turns into a readable string for the evaluator prompt. """ + if not test_data or "turns" not in test_data: + return "[No conversation turns found]" + + formatted_lines = [] + for i, turn in enumerate(test_data.get("turns", [])): + role = turn.get("role", "unknown").upper() + content = turn.get("content", "[No content]") + + # Shorten system prompt for brevity in evaluation context + if role == "SYSTEM": + content = "[System Prompt Provided - see original log for details]" + + # Format code execution results more clearly + if role == "USER" and content.startswith("Code execution result:"): + # Check if this is the *actual* result turn by looking at previous turn + # This avoids mislabeling the initial user prompt if it somehow contained the phrase + if i > 0 and test_data["turns"][i-1].get("role") == "assistant": + content = content.replace("Code execution result:", "**CODE EXECUTION RESULT:**") + content = content.replace("--- STDOUT ---", "**STDOUT:**") + content = content.replace("--- STDERR ---", "**STDERR:**") + content = content.replace("--------------", "---") # Shorten separator + else: + # Treat as regular user prompt if it wasn't preceded by assistant turn + role = "USER PROMPT (Initial)" + + + # Add role separator, handling potential consecutive roles if needed + formatted_lines.append(f"--- {role} ---") + formatted_lines.append(content) + formatted_lines.append("\n") # Add space between turns + + return "\n".join(formatted_lines) +# --- END ADDED FUNCTION DEFINITION --- + +def run_single_test_iteration(agent_prompt_id, agent_prompt, dataset_h5ad_path, dataset_metadata, max_code_tries=5): + """ + Runs one iteration using the agent prompt, adapted from OneShotAgentTester. + Returns the detailed conversation data including context and API responses. + """ + console.print(f"\n[magenta]--- Running Test Iteration for Prompt ID: '{agent_prompt_id}' ---[/magenta]") + sandbox_manager = None + full_conversation_data = {} # Initialize + + # Create initial context for this specific run + initial_context = { + "prompt_id": agent_prompt_id, + "dataset_file": str(dataset_h5ad_path.name), + "dataset_metadata": dataset_metadata, + "max_code_tries": max_code_tries, + "start_time": datetime.now().isoformat() + } + full_conversation_data = {"context": initial_context, "turns": []} + + try: + sandbox_manager = SandboxManager() + if not sandbox_manager.start_container(): + raise RuntimeError("Failed to start sandbox container.") + if not check_api_status(): + raise RuntimeError("API service failed to start or respond.") + + # Copy dataset - Ensure SANDBOX_DATA_PATH is defined globally + copy_command = ['docker', 'cp', str(dataset_h5ad_path), f"{SANDBOX_CONTAINER_NAME}:{SANDBOX_DATA_PATH}"] + result = subprocess.run(copy_command, check=False, capture_output=True, text=True) + if result.returncode != 0: + console.print(f"[bold red]Error copying dataset to container (Code: {result.returncode}):[/bold red]\n{result.stderr}") + raise subprocess.CalledProcessError(result.returncode, copy_command, output=result.stdout, stderr=result.stderr) + console.print("[green]Dataset copied successfully.[/green]") + + # Prepare conversation + system_message_content = f"""You are an AI assistant tasked with analyzing a single-cell transcriptomics dataset. +The dataset file is located inside the execution environment at: {SANDBOX_DATA_PATH} +Variables and imports persist between your code executions. + +Dataset Metadata: +{json.dumps(dataset_metadata, indent=2)} + +Max code attempts: {max_code_tries}. Generate Python code in ```python ... ``` blocks. +Prioritize text analysis over plots. Start by loading the data. +""" + user_message_content = agent_prompt # The prompt being tested + + full_conversation_data["turns"].append({"role": "system", "content": system_message_content}) + full_conversation_data["turns"].append({"role": "user", "content": user_message_content}) + conversation_history = [ + {"role": "system", "content": system_message_content}, + {"role": "user", "content": user_message_content} + ] + console.print(Panel(system_message_content, title="SYSTEM PROMPT (Iteration)", border_style="dim blue")) + console.print(Panel(user_message_content, title="USER PROMPT (Iteration)", border_style="blue")) + + # Agent Interaction Loop + code_tries_left = max_code_tries + while code_tries_left > 0: + console.print(f"\nSending request to Agent ({AGENT_MODEL})... (Tries left: {code_tries_left})") + response = openai_client.chat.completions.create( + model=AGENT_MODEL, messages=conversation_history, temperature=0.7, + ) + assistant_content = response.choices[0].message.content + conversation_history.append({"role": "assistant", "content": assistant_content}) + full_conversation_data["turns"].append({"role": "assistant", "content": assistant_content}) + console.print(Panel(assistant_content, title="ASSISTANT RESPONSE", border_style="green")) + + agent_code = extract_python_code(assistant_content) + if agent_code: + console.print(f"Executing Code via API (Attempt {max_code_tries - code_tries_left + 1}/{max_code_tries})...") + code_tries_left -= 1 + user_feedback_content = "[Code execution failed]" + execution_api_response = None + try: + payload = {"code": agent_code, "timeout": 120} + headers = {"Content-Type": "application/json"} + api_response = requests.post(EXECUTE_ENDPOINT, json=payload, headers=headers, timeout=130) + api_response.raise_for_status() + execution_api_response = api_response.json() + user_feedback_content = format_api_response_for_llm(execution_api_response) + except requests.exceptions.RequestException as e: + console.print(f"[bold red]API Request Error during execution: {e}[/bold red]") + error_detail = str(e) + status_code = None + if e.response is not None: + status_code = e.response.status_code + console.print(f"Response Status: {status_code}") + error_detail = e.response.text + try: # Try to get detail from JSON + detail_json = e.response.json().get("detail", error_detail) + error_detail = f"API Error {status_code}: {detail_json}" + except json.JSONDecodeError: + error_detail = f"API Error {status_code}: {e.response.text}" + user_feedback_content = f"Code execution result:\n[{error_detail}]" + execution_api_response = {"error": error_detail, "status_code": status_code} + + conversation_history.append({"role": "user", "content": user_feedback_content}) + full_conversation_data["turns"].append({ + "role": "user", "content": user_feedback_content, "api_response": execution_api_response + }) + console.print(Panel(user_feedback_content, title="CODE EXECUTION RESULT", border_style="yellow")) + + if code_tries_left == 0: + console.print("[yellow]Maximum code execution attempts reached.[/yellow]") + break + else: + console.print("[yellow]No code block found. Ending interaction loop for this iteration.[/yellow]") + full_conversation_data["turns"].append({"role": "user", "content": "[No code executed, agent finished or failed to provide code]"}) + break # Assume conversation ends if agent doesn't provide code + + return full_conversation_data # Return the captured data + + except Exception as e: + console.print(f"[bold red]Error during test iteration for '{agent_prompt_id}': {e}[/bold red]") + import traceback + traceback.print_exc() + full_conversation_data["error"] = f"Test Iteration Error: {e}\n{traceback.format_exc()}" + return full_conversation_data # Return data with error + finally: + if sandbox_manager: + console.print("Stopping sandbox container for iteration...") + sandbox_manager.stop_container(remove=True) + +def call_openai_evaluator(conversation_text, context): + """ Sends the formatted conversation to OpenAI for evaluation. """ + # (Copied from evaluator.py - requires EVALUATOR_MODEL constant) + evaluator_prompt = f"""You are an expert evaluator assessing the performance of an AI assistant acting as a bioinformatician. +The assistant was given a task related to analyzing a single-cell transcriptomics dataset. +The expected performance level is that of an **entry-level post-graduate bioinformatician**. + +**Dataset Context:** +- Dataset File: {context.get('dataset_file', 'N/A')} +- Key Metadata: {json.dumps(context.get('dataset_metadata', {}), indent=1, default=str)} + +**Task Context:** +- Initial User Prompt: See the first USER prompt below. +- Max Code Attempts Allowed: {context.get('max_code_tries', 'N/A')} + +**Conversation Log:** +{conversation_text} + +**Evaluation Task:** +Based on the conversation log, evaluate the AI assistant's performance. Consider the following: +1. Correctness: Was the generated code correct and did it achieve the intended analysis steps? +2. Efficiency: Was the approach reasonable? Were there unnecessary steps? +3. Interpretation: Did the assistant correctly interpret the results of its code execution? +4. Planning: Did the assistant use its allowed code execution attempts effectively towards the goal? +5. Clarity: Was the assistant's text explanation clear and accurate? +6. Overall Skill: Does the performance align with an entry-level post-graduate bioinformatician? + +**Output Format:** +Please provide your evaluation strictly in the following JSON format ONLY. Do not include any other text before or after the JSON block: +{{ + "grade": , + "comments": "" +}} +""" + console.print(f"Sending evaluation request ({EVALUATOR_MODEL})...") + try: + response = openai_client.chat.completions.create( + model=EVALUATOR_MODEL, + messages=[{"role": "user", "content": evaluator_prompt}], + temperature=0.3, + response_format={"type": "json_object"}, + max_tokens=1000 + ) + eval_content = response.choices[0].message.content + console.print("[green]Evaluation received.[/green]") + try: + eval_json = json.loads(eval_content) + if "grade" in eval_json and "comments" in eval_json and \ + isinstance(eval_json["grade"], int) and isinstance(eval_json["comments"], str): + return eval_json + else: raise ValueError("Invalid format in evaluation JSON.") + except (json.JSONDecodeError, ValueError) as e: + console.print(f"[bold red]Error parsing evaluation JSON: {e}[/bold red]") + console.print(f"Raw response:\n{eval_content}") + return {"grade": -1, "comments": f"Error parsing evaluation: {e}\nRaw: {eval_content}"} + except Exception as e: + console.print(f"[bold red]Error calling evaluation API: {e}[/bold red]") + return {"grade": -1, "comments": f"API Error: {e}"} + +def call_openai_evolver(objective, previous_prompt, conversation_text, evaluation): + """ Calls OpenAI to generate an improved prompt. """ + evolver_prompt = f"""You are an AI Prompt Engineer specializing in bioinformatics tasks. +Your goal is to refine a user prompt to improve the performance of another AI assistant on a specific objective, based on past performance. + +**Overall Objective:** +{objective} + +**Previous Prompt Attempt:** +``` +{previous_prompt} +``` + +**Resulting Conversation Log (summary):** +{conversation_text[:3000]}... (log truncated for brevity) + +**Evaluation of Previous Attempt:** +- Grade (0-100): {evaluation.get('grade', 'N/A')} +- Evaluator Comments: {evaluation.get('comments', 'N/A')} + +**Task:** +Based on the objective, the previous prompt, the conversation summary, and the evaluation feedback, generate a **new, improved prompt** for the AI assistant. +The new prompt should: +- Be clearer and more specific about the desired analysis steps and output. +- Address the weaknesses identified in the evaluator comments. +- Guide the assistant towards better correctness, efficiency, interpretation, and planning. +- Aim to help the assistant perform like an entry-level post-graduate bioinformatician. + +**Output Format:** +Please provide ONLY the refined prompt text itself. Do not include any explanations, greetings, or markdown formatting like backticks around the prompt. Just the raw text of the new prompt. +""" + console.print(f"Sending prompt evolution request ({EVOLVER_MODEL})...") + try: + response = openai_client.chat.completions.create( + model=EVOLVER_MODEL, + messages=[{"role": "user", "content": evolver_prompt}], + temperature=0.6, # Allow for some creativity in prompt generation + max_tokens=500 # Adjust based on expected prompt length + ) + new_prompt = response.choices[0].message.content.strip() + # Optional: Basic cleaning if the model adds quotes or markdown + new_prompt = re.sub(r"^```\s*|\s*```$", "", new_prompt).strip() + console.print("[green]Received evolved prompt.[/green]") + return new_prompt + except Exception as e: + console.print(f"[bold red]Error calling prompt evolver API: {e}[/bold red]") + return None # Return None on error, keep using previous prompt + + +# --- Main Evolution Loop --- +def main_evolution_loop(): + if console: console.print("\n--- Prompt Evolver ---") + else: print("\n--- Prompt Evolver ---") + + # 1. Get Inputs + objective = Prompt.ask("Enter the overall objective for the prompt") + while not objective: + objective = Prompt.ask("Objective cannot be empty. Please enter the objective") + + initial_prompt_path_str = Prompt.ask("Enter path to initial prompt .txt file (or paste directly if empty)") + initial_prompt = "" + if initial_prompt_path_str: + initial_prompt_path = Path(initial_prompt_path_str) + if initial_prompt_path.is_file(): + try: + initial_prompt = initial_prompt_path.read_text(encoding='utf-8').strip() + console.print(f"Loaded initial prompt from: [cyan]{initial_prompt_path}[/cyan]") + except Exception as e: + console.print(f"[red]Error reading prompt file '{initial_prompt_path}': {e}. Please paste prompt.[/red]") + initial_prompt = "" + else: + console.print(f"[yellow]Initial prompt file not found. Please paste prompt.[/yellow]") + + if not initial_prompt: + console.print("Paste your initial prompt below. Press Ctrl+D (Unix) or Ctrl+Z+Enter (Windows) when done:") + try: + initial_prompt = sys.stdin.read().strip() + if not initial_prompt: + console.print("[red]Error: Initial prompt cannot be empty.[/red]") + sys.exit(1) + except EOFError: + console.print("[red]\nError: No prompt pasted.[/red]") + sys.exit(1) + + dataset_h5ad_path, dataset_metadata = select_dataset() + if not dataset_h5ad_path: + console.print("[red]Dataset selection failed. Exiting.[/red]") + sys.exit(1) + + while True: + try: + iterations_str = Prompt.ask("Enter number of evolution iterations", default="3") + num_iterations = int(iterations_str) + if num_iterations > 0: break + else: console.print("[red]Number of iterations must be positive.[/red]") + except ValueError: console.print("[red]Invalid input. Please enter an integer.[/red]") + + default_output = str(OUTPUTS_DIR.resolve()) + output_dir_str = Prompt.ask("Enter output directory for evolution logs", default=default_output) + output_dir = Path(output_dir_str) + output_dir.mkdir(parents=True, exist_ok=True) + console.print(f"Evolution logs will be saved in: [cyan]{output_dir.resolve()}[/cyan]") + + # --- Evolution Process --- + current_prompt = initial_prompt + evolution_log = [] # List to store results of each iteration + + for i in range(num_iterations): + iteration_id = f"iteration_{i+1}" + console.print(f"\n[bold blue]===== Starting Evolution Iteration {i+1}/{num_iterations} =====[/bold blue]") + console.print(f"Current Prompt:\n---\n{current_prompt}\n---") + + # 2. Run Test with Current Prompt + test_data = run_single_test_iteration( + agent_prompt_id=iteration_id, + agent_prompt=current_prompt, + dataset_h5ad_path=dataset_h5ad_path, + dataset_metadata=dataset_metadata, + max_code_tries=5 # Or make this configurable + ) + + if test_data is None or test_data.get("error"): + console.print(f"[bold red]Test iteration {i+1} failed. Skipping evaluation and evolution.[/bold red]") + iteration_result = { + "iteration": i + 1, + "prompt": current_prompt, + "test_data": test_data, # Contains error info + "evaluation": None, + "evolved_prompt": None + } + evolution_log.append(iteration_result) + # Decide whether to stop or continue with the same prompt + if not Confirm.ask("Test failed. Continue evolution with the *same* prompt?", default=False): + break + else: + continue # Try the same prompt again next iteration + + # 3. Evaluate the Result + # Use the newly added function + conversation_text = format_conversation_for_eval(test_data) + evaluation = call_openai_evaluator(conversation_text, test_data.get("context", {})) + console.print(f"Evaluation Grade: {evaluation.get('grade', 'N/A')}") + console.print(f"Evaluation Comments:\n---\n{evaluation.get('comments', 'N/A')}\n---") + + # 4. Evolve the Prompt (unless it's the last iteration) + evolved_prompt = None + if i < num_iterations - 1: + evolved_prompt = call_openai_evolver( + objective=objective, + previous_prompt=current_prompt, + conversation_text=conversation_text, # Pass summary + evaluation=evaluation + ) + if evolved_prompt: + console.print(f"Evolved Prompt for Next Iteration:\n---\n{evolved_prompt}\n---") + else: + console.print("[yellow]Failed to generate evolved prompt. Reusing previous prompt for next iteration.[/yellow]") + else: + console.print("Last iteration reached. No further prompt evolution.") + + + # 5. Log Iteration Result + iteration_result = { + "iteration": i + 1, + "prompt": current_prompt, + "test_data": test_data, # Contains full conversation and context + "evaluation": evaluation, + "evolved_prompt": evolved_prompt # Will be None on last iteration or error + } + evolution_log.append(iteration_result) + + # Update prompt for next iteration if evolution was successful + if evolved_prompt: + current_prompt = evolved_prompt + else: + # If evolution failed, decide whether to continue with the same prompt + if i < num_iterations - 1 and not Confirm.ask("Prompt evolution failed. Continue with the *same* prompt?", default=True): + break + + + # --- Save Final Results --- + console.print("\n[bold blue]===== Evolution Complete =====[/bold blue]") + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + dataset_stem = dataset_h5ad_path.stem if dataset_h5ad_path else "unknown_dataset" + final_log_filename = output_dir / f"evolution_log_{dataset_stem}_{timestamp}.json" + final_prompt_filename = output_dir / f"final_prompt_{dataset_stem}_{timestamp}.txt" + + console.print(f"Saving evolution log to [cyan]{final_log_filename}[/cyan]...") + try: + with open(final_log_filename, "w", encoding="utf-8") as f: + json.dump(evolution_log, f, indent=2, default=str) + console.print("[green]Evolution log saved successfully.[/green]") + except Exception as e: + console.print(f"[bold red]Error saving evolution log: {e}[/bold red]") + + console.print(f"Saving final evolved prompt to [cyan]{final_prompt_filename}[/cyan]...") + try: + with open(final_prompt_filename, "w", encoding="utf-8") as f: + f.write(current_prompt) # Save the last used prompt + console.print("[green]Final prompt saved successfully.[/green]") + except Exception as e: + console.print(f"[bold red]Error saving final prompt: {e}[/bold red]") + + +# --- Main Execution --- +if __name__ == "__main__": + # Ensure output directory exists + OUTPUTS_DIR.mkdir(parents=True, exist_ok=True) + main_evolution_loop() diff --git a/benchmarking/README.md b/benchmarking/README.md new file mode 100644 index 0000000..47e38a2 --- /dev/null +++ b/benchmarking/README.md @@ -0,0 +1,166 @@ +# Benchmarking and Evolving Agent Prompts for Single-Cell Data Analysis + +**⚠️ Work in Progress:** This tooling is currently under development. Its primary goal is to facilitate rapid iteration, testing, evaluation, and evolution of LLM agent prompts for analyzing single-cell transcriptomics datasets using a secure code execution sandbox. + +## Overview + +This framework provides the necessary tools to: + +1. **Discover and Download Datasets:** Browse and fetch datasets (specifically from the CZI CELLxGENE Census) along with their metadata. +2. **Secure Code Execution:** Run Python code generated by an AI agent within an isolated Docker container (sandbox). The sandbox now runs a Jupyter kernel managed by a **FastAPI service** , providing a stable HTTP interface for code execution. +3. **Agent Interaction & Testing (`OneShotAgentTester.py`):** Orchestrate interactions between an AI agent (powered by OpenAI's API), a selected dataset, and the code execution sandbox (via the FastAPI service). Allows testing prompts with limited code execution attempts. +4. **Results Conversion (`output_to_notebook.py`):** Convert the detailed JSON logs from test runs into Jupyter Notebooks (`.ipynb`) for easier review and analysis reproduction. +5. **AI-Powered Evaluation (`evaluator.py`):** Use an LLM (like GPT-4o) to automatically evaluate the performance of the agent based on the conversation logs, assigning a grade and providing comments. +6. **Automated Prompt Evolution (`prompt_evolver.py`):** Iteratively refine an initial agent prompt based on an objective, test results, and AI evaluation feedback to automatically discover more effective prompts. + +## Components + +The framework consists of the following main components: + +* **`.env` / `make_benchmarking_env.sh`:** + * `make_benchmarking_env.sh`: An interactive script to securely prompt for and save your OpenAI API key. + * `.env`: The file (created by the script) storing the `OPENAI_KEY`. This file should be added to your `.gitignore`. +* **`tools/czi_browser.py`:** + * A CLI tool for listing CZI CELLxGENE Census versions and datasets. + * Allows downloading specific datasets (`.h5ad`) and metadata (`.json`) to the `datasets/` directory. +* **`sandbox/`:** Contains the code execution environment. + * `Dockerfile`: Defines the Docker image based on a Python base, adding necessary Python/system dependencies, Jupyter components, FastAPI, Uvicorn, and the application code. + * `requirements.txt`: Lists Python packages installed *inside* the sandbox container (e.g., `anndata`, `scanpy`, `matplotlib`). + * `kernel_api.py`: The FastAPI application running inside the container. It receives code execution requests via HTTP, interacts with a local Jupyter kernel using `jupyter_client`, captures results (stdout, stderr, errors, display data), and returns them as JSON. + * `start_kernel.py`: A simple script used internally by `start.sh` to launch the Jupyter kernel process with specific arguments (e.g., listening IP, ports). + * `start.sh`: The main startup script run by the container (managed by `tini`). It launches the Jupyter kernel in the background and then starts the Uvicorn server to run the `kernel_api.py` FastAPI app. + * `benchmarking_sandbox_management.py`: A Python script (with CLI and interactive modes) primarily used for building the sandbox image and manually starting/stopping the container (which runs the API service). Direct kernel interaction commands have been removed. +* **`datasets/`:** (Created by `czi_browser.py`) + * Stores downloaded `.h5ad` data files and `.json` metadata files. +* **`outputs/`:** (Created automatically) + * Default directory for storing JSON logs from `OneShotAgentTester.py` and `PromptEvolver.py`, evaluation results from `evaluator.py`, and potentially generated notebooks/images. +* **`OneShotAgentTester.py`:** + * Orchestrates a single test run for one or more prompts against a dataset. + * Starts the sandbox container (via `SandboxManager`). + * Copies the dataset into the running container. + * Checks if the internal API service is responsive. + * Manages the interaction loop with the OpenAI API (specified agent model). + * When the agent generates code, it sends the code to the sandbox's FastAPI `/execute` endpoint using the `requests` library. + * Formats the JSON response (stdout, stderr, errors, display data) from the API and feeds it back to the agent. + * Saves the full conversation log for the test run(s) to a JSON file in the `outputs/` directory. +* **`output_to_notebook.py`:** + * An interactive script that takes a results JSON file (from `OneShotAgentTester` or `PromptEvolver`) as input. + * Converts the conversation log, including code cells and their outputs (stdout, stderr, errors, display data), into a Jupyter Notebook (`.ipynb`) file. + * Saves the `.ipynb` file in the same directory as the input JSON. +* **`evaluator.py`:** + * An interactive script that processes results JSON files from a specified input directory (defaults to `outputs/`). + * For each test run in the JSON, it formats the conversation and sends it to an OpenAI model (specified evaluator model) with instructions to evaluate the agent's performance (0-100 grade and comments) based on defined criteria (e.g., correctness, efficiency, clarity). + * Saves the evaluations (grade and comments) to JSON files (either aggregated or individually) in a specified output location (defaults to the input directory). +* **`prompt_evolver.py`:** + * An orchestrator script for automatically refining prompts. + * Takes an initial prompt, an objective, a dataset, and the number of iterations. + * In each iteration: + * Runs the current prompt using the testing logic (`run_single_test_iteration`). + * Evaluates the result using the evaluation logic (`call_openai_evaluator`). + * Calls another OpenAI model (specified evolver model) to generate an improved prompt based on the objective, previous prompt, conversation summary, and evaluation feedback. + * Uses the evolved prompt for the next iteration. + * Saves a detailed log of the entire evolution process (prompts, test data, evaluations) and the final evolved prompt. +* **`requirements.txt`:** (Top-level) + * Lists Python packages required for the *host* scripts (`OneShotAgentTester.py`, `evaluator.py`, `prompt_evolver.py`, `czi_browser.py`, etc.). Key dependencies include `openai`, `python-dotenv`, `requests`, `docker`, `rich`, `nbformat`. + +## Setup + +1. **Prerequisites:** + * Python (3.10+ recommended) + * `pip` (Python package installer) + * Docker Desktop or Docker Engine (must be running) + * Git (for cloning the repository) +2. **Install Host Python Dependencies:** + * Create and activate a Python virtual environment (recommended): + ``` + python -m venv venv + source venv/bin/activate # Linux/macOS + # venv\Scripts\activate # Windows CMD + + ``` + * Install required packages for the host scripts: + ``` + pip install -r requirements.txt + + ``` +3. **Set OpenAI API Key:** + * Make the script executable: `chmod +x make_benchmarking_env.sh` + * Run the script and enter your key when prompted: `./make_benchmarking_env.sh` + * This creates the `.env` file. **Ensure `.env` is listed in your `.gitignore` file.** +4. **Prepare Sandbox Requirements:** + * Edit `sandbox/requirements.txt` to include all the additional Python packages needed *inside* the container for agent code execution (e.g., `pandas`, `numpy`, `scipy`, `scikit-learn`, `anndata`, `matplotlib`, `seaborn`). Ensure these are compatible with the base Python version in the `Dockerfile`. + +## Usage + +1. **Download a Dataset:** + * Use the `tools/czi_browser.py` script (run `python tools/czi_browser.py` for interactive mode) to find and download a dataset to the `datasets/` directory. +2. **Test a Prompt (`OneShotAgentTester.py`):** + * Run the script: `python OneShotAgentTester.py` + * Follow prompts to select the prompt source (paste, file, folder), dataset, and max code attempts. + * The script starts the sandbox, runs the test(s) by communicating with the internal API, and saves the results to a JSON file in `outputs/`. +3. **Convert Results to Notebook (`output_to_notebook.py`):** + * Run the script: `python output_to_notebook.py` + * Enter the path to a results JSON file (e.g., `outputs/benchmark_results_....json`). + * An `.ipynb` file will be generated in the same directory. +4. **Evaluate Results (`evaluator.py`):** + * Run the script: `python evaluator.py` + * Enter the path to the folder containing results JSON files (defaults to `outputs/`). + * Enter the desired output location for evaluation files. + * The script calls OpenAI to evaluate each test run and saves the grades/comments. +5. **Evolve a Prompt (`prompt_evolver.py`):** + * Run the script: `python prompt_evolver.py` + * Enter the overall objective for the prompt. + * Provide the initial prompt (paste or file path). + * Select the dataset. + * Enter the number of evolution iterations. + * Specify the output directory for logs. + * The script runs the test-evaluate-evolve loop and saves the full log and the final prompt. +6. **Manage Sandbox Manually (Optional):** + * Use `sandbox/benchmarking_sandbox_management.py` for basic container control: + * Build image: `python sandbox/benchmarking_sandbox_management.py build` + * Start container (API): `python sandbox/benchmarking_sandbox_management.py start` + * Check status: `python sandbox/benchmarking_sandbox_management.py status` + * View logs: `python sandbox/benchmarking_sandbox_management.py logs [N]` + * Stop container: `python sandbox/benchmarking_sandbox_management.py stop` + * Run interactively: `python sandbox/benchmarking_sandbox_management.py` + +## File Structure (Updated) + +``` +benchmarking/ +├── sandbox/ +│ ├── Dockerfile +│ ├── kernel_api.py # FastAPI application +│ ├── start_kernel.py # Script to launch kernel +│ ├── start.sh # Container startup script (kernel + API) +│ ├── requirements.txt # Requirements for INSIDE the container +│ └── benchmarking_sandbox_management.py # Simplified manager +│ +├── datasets/ # Created by czi_browser.py download +│ └── .h5ad +│ └── .json +│ └── ... +│ +├── outputs/ # Default location for results/logs/notebooks +│ └── benchmark_results_*.json +│ └── benchmark_results_*.ipynb +│ └── *_eval.json +│ └── evolution_log_*.json +│ └── final_prompt_*.txt +│ └── output_image_*.png +│ └── ... +│ +├── tools/ +│ └── czi_browser.py +│ +├── make_benchmarking_env.sh # Used to make the .env file +├── OneShotAgentTester.py # Runs agent tests via API +├── output_to_notebook.py # Converts results JSON to Notebook +├── evaluator.py # Evaluates test results using AI +├── prompt_evolver.py # Orchestrates prompt evolution loop +├── requirements.txt # Requirements for HOST scripts (this file) +└── README.md # This file +└── .env # Stores API key (add to .gitignore) +└── .gitignore # Should include .env, venv/, __pycache__, outputs/, datasets/ + +``` diff --git a/benchmarking/create_benchmark_env.sh b/benchmarking/create_benchmark_env.sh new file mode 100755 index 0000000..94bc64d --- /dev/null +++ b/benchmarking/create_benchmark_env.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# Get the directory where the script is located +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + +# Define the path for the .env file in the script's directory +ENV_FILE_PATH="${SCRIPT_DIR}/.env" + +echo "This script will create a .env file to store your OpenAI API key." +echo "The file will be saved in the script's directory: ${SCRIPT_DIR}" +echo "" # Add a blank line for spacing + +# Prompt the user for their OpenAI API key +# -p: Display the prompt string +# -s: Silent mode (do not echo input characters) - recommended for keys/passwords +# -r: Raw mode (backslashes are not treated as escape characters) +read -p "Please enter your OpenAI API key: " -s -r OPENAI_API_KEY +echo "" # Add a newline after the hidden input + +# Check if the key was entered +if [ -z "$OPENAI_API_KEY" ]; then + echo "Error: No API key entered. Exiting." + exit 1 +fi + +# Write the key to the .env file in the format OPENAI_KEY:key_value +# Overwrites the file if it already exists +echo "OPENAI_API_KEY=${OPENAI_API_KEY}" > "${ENV_FILE_PATH}" + +# Check if the file was created successfully +if [ $? -eq 0 ]; then + echo "" # Add a blank line + echo "Successfully saved the OpenAI API key to ${ENV_FILE_PATH}" + # Optionally, set permissions to be readable only by the user + chmod 600 "${ENV_FILE_PATH}" + echo "Set permissions for ${ENV_FILE_PATH} to read-only for the current user (600)." +else + echo "Error: Failed to write to ${ENV_FILE_PATH}. Please check permissions." + exit 1 +fi + +exit 0 diff --git a/benchmarking/datasets/spatial_transcriptomics_in_mouse_puck_191109_14.h5ad b/benchmarking/datasets/spatial_transcriptomics_in_mouse_puck_191109_14.h5ad new file mode 100644 index 0000000..fbab1e0 Binary files /dev/null and b/benchmarking/datasets/spatial_transcriptomics_in_mouse_puck_191109_14.h5ad differ diff --git a/benchmarking/datasets/spatial_transcriptomics_in_mouse_puck_191109_14.json b/benchmarking/datasets/spatial_transcriptomics_in_mouse_puck_191109_14.json new file mode 100644 index 0000000..dc2807a --- /dev/null +++ b/benchmarking/datasets/spatial_transcriptomics_in_mouse_puck_191109_14.json @@ -0,0 +1,13 @@ +{ + "soma_joinid": 5, + "citation": "Publication: https://doi.org/10.1016/j.isci.2022.104097 Dataset Version: https://datasets.cellxgene.cziscience.com/d5e8239f-a6a3-475a-b60d-d6991ad9a108.h5ad curated and distributed by CZ CELLxGENE Discover in Collection: https://cellxgene.cziscience.com/collections/8e880741-bf9a-4c8e-9227-934204631d2a", + "collection_id": "8e880741-bf9a-4c8e-9227-934204631d2a", + "collection_name": "High Resolution Slide-seqV2 Spatial Transcriptomics Enables Discovery of Disease-Specific Cell Neighborhoods and Pathways", + "collection_doi": "10.1016/j.isci.2022.104097", + "collection_doi_label": "Marshall et al. (2022) iScience", + "dataset_id": "530c9bff-c552-45b9-be77-0de605e6858b", + "dataset_version_id": "d5e8239f-a6a3-475a-b60d-d6991ad9a108", + "dataset_title": "Spatial transcriptomics in mouse: Puck_191109_14", + "dataset_h5ad_path": "530c9bff-c552-45b9-be77-0de605e6858b.h5ad", + "dataset_total_cell_count": 12351 +} \ No newline at end of file diff --git a/benchmarking/requirements.txt b/benchmarking/requirements.txt new file mode 100644 index 0000000..d908ddd --- /dev/null +++ b/benchmarking/requirements.txt @@ -0,0 +1,9 @@ +cellxgene-census +tiledbsoma +rich +numpy +docker +dotenv +openai +jupyter_client +nbformat \ No newline at end of file diff --git a/benchmarking/sample_prompt_library/Basic_scRNA_Agent.txt b/benchmarking/sample_prompt_library/Basic_scRNA_Agent.txt new file mode 100644 index 0000000..4f22eb9 --- /dev/null +++ b/benchmarking/sample_prompt_library/Basic_scRNA_Agent.txt @@ -0,0 +1,165 @@ +You are a highly skilled bioinformatics agent specializing in single-cell RNA-seq data analysis using Python. Your goal is to provide accurate, efficient, and clear analysis while adapting to different datasets and scenarios. You have access to a python code interpreter, so every code block you generate will be executed, and you'll receive feedback on its execution. The code will be executed on a python jupyter kernel and the kernel will remain active after execution retaining all variables in memory. Use the following framework for structured analysis with detailed code, outputs, and guidance to the user. + +**Primary Analysis Flow**: +For analyzing single-cell RNA-seq data using the `Scanpy` package, follow this structured framework: + +### 1. **Data Loading & Package Setup** + a. Load the provided dataset from the working directory. + b. Recognize common formats (e.g., 10X `.h5` or `mtx` files). If multiple samples are present, load them as a batch. + c. Use the following libraries and settings: + ```python + import scanpy as sc + import os + import pandas as pd + import matplotlib.pyplot as plt + import seaborn as sns + import numpy as np + from scipy.stats import median_abs_deviation as mad + import celltypist + from celltypist import models + import anndata as ad + + # Set verbosity and figure parameters + sc.settings.verbosity = 0 + sc.settings.set_figure_params(dpi=50, facecolor="white", frameon=False) + ``` + +### 2. **Initial Data Inspection** + a. **Summarize the dataset**: Provide the number of cells and genes for each sample. + b. **Plot initial cell and gene counts** for user reference: + ```python + fig, ax = plt.subplots(figsize=(10, 6)) + n_cells = [adata.n_obs for adata in adatas] + n_genes = [adata.n_vars for adata in adatas] + ax.bar(range(len(adatas)), n_cells, label='Cells') + ax.bar(range(len(adatas)), n_genes, label='Genes', align='edge') + ax.set_title('Cell and Gene Counts Before QC') + plt.show() + ``` + +### 3. **Quality Control (QC) Metrics** + a. Calculate mitochondrial content per cell and flag potential low-quality cells. + ```python + def calculate_mito_percentage(adata): + mito_genes = adata.var_names.str.contains('^MT-') + adata.obs['percent_mito'] = np.sum(adata[:, mito_genes].X, axis=1) / np.sum(adata.X, axis=1) + return adata + adatas = [calculate_mito_percentage(x) for x in adatas] + ``` + b. Visualize the key QC metrics: counts, genes, mitochondrial content: + ```python + for adata in adatas: + sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt']) + ``` + +### 4. **Pre-QC Analysis** + a. Perform normalization, feature selection, clustering, and UMAP projection: + ```python + for adata in adatas: + sc.pp.normalize_total(adata) + sc.pp.log1p(adata) + sc.pp.highly_variable_genes(adata, n_top_genes=2000) + sc.tl.pca(adata) + sc.pp.neighbors(adata, n_pcs=20) + sc.tl.umap(adata) + sc.tl.leiden(adata, resolution=0.5) + sc.pl.umap(adata, color=['leiden']) + ``` + b. Plot differential expression for the top 3 genes per cluster: + ```python + sc.tl.rank_genes_groups(adata, 'leiden', method='wilcoxon') + sc.pl.rank_genes_groups_dotplot(adata, n_genes=3) + ``` + +### 5. **Post-QC Filtering** + a. Apply filtering based on cell quality and mitochondrial content: + ```python + def filter_cells(adata): + sc.pp.filter_cells(adata, min_genes=200) + sc.pp.filter_genes(adata, min_cells=3) + return adata + adatas = [filter_cells(adata) for adata in adatas] + ``` + +### 6. **Reanalysis Post-QC** + a. Re-perform PCA, clustering, and UMAP after filtering: + ```python + for adata in adatas: + sc.tl.pca(adata) + sc.pp.neighbors(adata, n_pcs=20) + sc.tl.umap(adata) + sc.pl.umap(adata, color=['leiden']) + ``` + +### 7. **Cell Type Annotation** + a. Download and apply `Celltypist` models for automatic cell-type annotation: + ```python + models.download_models() + predictions = celltypist.annotate(adata, model='Developing_Mouse_Brain.pkl', majority_voting=True) + adata.obs['celltypes'] = predictions.cell_types + sc.pl.umap(adata, color='celltypes') + ``` + +### 8. **Batch Effect Correction** (if applicable) + a. If multiple samples are present, merge datasets and perform batch correction: + ```python + adata = ad.concat(adatas, label='sample', keys=['sample1', 'sample2']) + sc.pp.combat(adata, key='sample') + sc.pp.neighbors(adata) + sc.tl.umap(adata) + sc.pl.umap(adata, color=['sample', 'celltypes']) + ``` + +### 9. **Final Output and Saving** + a. Save the final integrated dataset in `.h5ad` format: + ```python + adata.write('path/to/final_output.h5ad') + ``` + +**Execution Instructions**: +1. Before proceeding with any step, confirm execution and results with the user. +2. Adjust or modify steps based on the user's input. +3. Output visualizations for the user to inspect results at each step (e.g., UMAP plots, differential expression). +4. Ensure appropriate feedback and quality checks (e.g., warnings, large deviations in mitochondrial content). + +**Customization**: +1. If the user provides specific thresholds or metrics for QC, adjust your methods accordingly. +2. Ensure adaptability to multiple formats (e.g., `.h5`, `.mtx`) and large datasets. +3. If batch correction is requested, use advanced methods (e.g., Harmony, scDREAMER) based on the scenario. + +The following dependencies are already installed and available in the Jupyter kernel: + +ansi2html==1.8.0 +scanpy==1.10.2 +scrublet +anndata==0.10.8 +celltypist==1.6.3 +leidenalg==0.10.2 +igraph==0.11.6 +networkx==3.2.1 +pynndescent==0.5.13 +numpy==1.26.4 +scipy==1.13.1 +pandas==2.2.2 +scikit-learn==1.5.1 +umap-learn==0.5.6 +statsmodels==0.14.2 +numba==0.60.0 +matplotlib==3.9.1 +seaborn==0.13.2 +h5py==3.11.0 +openpyxl==3.1.5 +PyPDF2 +tqdm==4.66.4 +psutil==6.0.0 +defusedxml==0.7.1 +requests==2.32.3 + +Whenever you need to run code on the terminal using a package that is not already install, first provide a corresponding Bash code block labeled ```bash``` with the installation commands for all dependencies utilized, if they are not already installed in the environment. Do this for each code snippet you generate, like so: +```bash +pip install +``` + +You can proceed with executing code that utilizes any of these packages without needing to install them. Don't install any additional packages + +Your objective is to guide the user through single-cell RNA-seq analysis, ensuring accuracy, reproducibility, and meaningful insights from the data. \ No newline at end of file diff --git a/benchmarking/sandbox/Dockerfile b/benchmarking/sandbox/Dockerfile new file mode 100644 index 0000000..c10e567 --- /dev/null +++ b/benchmarking/sandbox/Dockerfile @@ -0,0 +1,104 @@ +# Use official Python slim image based on Debian (adjust version if needed) +FROM python:3.11-slim + +# Set DEBIAN_FRONTEND to noninteractive to prevent interactive prompts +ENV DEBIAN_FRONTEND=noninteractive + +# --- Install System Dependencies --- +# Combine apt-get operations into a single layer to leverage caching. +# This layer rarely changes unless system dependencies are added/removed. +# Install tini, tzdata, build tools, C libraries, and utilities. +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + tini \ + tzdata \ + build-essential \ + pkg-config \ + libhdf5-dev \ + libsodium-dev \ + libzmq3-dev \ + gcc \ + g++ \ + sudo \ + curl \ + wget \ + git \ + vim \ + nano \ + unzip \ + zip \ + # Configure timezone + && ln -fs /usr/share/zoneinfo/Etc/UTC /etc/localtime \ + && dpkg-reconfigure --frontend noninteractive tzdata \ + # Clean up apt cache + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# --- Create Non-Root User & Group --- +# These arguments and user setup steps rarely change. +ARG NB_USER="sandboxuser" +ARG NB_UID=1001 +ARG NB_GID=1001 +ENV USER=${NB_USER} +ENV HOME=/home/${NB_USER} +# Add user's local bin to PATH early +ENV PATH=${HOME}/.local/bin:${PATH} + +# Create group, user, add to sudoers (run as root) +RUN groupadd -g ${NB_GID} ${NB_USER} && \ + useradd -m -s /bin/bash -u ${NB_UID} -g ${NB_GID} ${NB_USER} && \ + adduser ${NB_USER} sudo && \ + echo "${NB_USER} ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers + +# --- Install Python Dependencies --- +COPY ./requirements.txt /tmp/requirements.txt + +# Install Python packages. This layer is cached if requirements.txt hasn't changed. +# Run pip installs as the target user to ensure correct permissions and paths. +# Switch user and set working directory *before* pip install --user. +USER ${NB_USER} +WORKDIR ${HOME} + +RUN python -m pip install --no-cache-dir --upgrade pip --user && \ + python -m pip install --no-cache-dir --user \ + # Core Jupyter components (pin versions for stability) + ipython==8.12.0 \ + traitlets==5.9.0 \ + jupyter_client==8.3.0 \ + jupyter_core==5.3.1 \ + pyzmq==25.1.0 \ + tornado==6.3.2 \ + ipykernel==6.25.1 \ + # FastAPI dependencies + fastapi \ + uvicorn[standard] \ + python-multipart \ + # Install user requirements from the temporary location + -r /tmp/requirements.txt + +# --- Application Setup --- +# Copy application code and scripts AFTER dependencies are installed. +# Changes to these files will only invalidate the cache from this point. +COPY --chown=${NB_USER}:${NB_GID} ./kernel_api.py ${HOME}/kernel_api.py +COPY --chown=${NB_USER}:${NB_GID} ./start_kernel.py ${HOME}/start_kernel.py +COPY --chown=${NB_USER}:${NB_GID} ./start.sh ${HOME}/start.sh + +# Create user directories and make scripts executable in a single layer +RUN mkdir -p ${HOME}/.local/share/jupyter \ + ${HOME}/.ipython/profile_default/startup \ + ${HOME}/.ipython/profile_default/static && \ + chmod +x ${HOME}/start_kernel.py ${HOME}/start.sh + +# --- Runtime Configuration --- +# Expose the FastAPI port (informational) +EXPOSE 8000 + +# Set environment variable for kernel port (used by start_kernel.py) +ENV IPY_BASE_PORT=4000 + +# Use tini as the entrypoint; it will execute the CMD +# Ensure tini installed via apt is in the default PATH or use /usr/bin/tini +ENTRYPOINT ["/usr/bin/tini", "--"] + +# Set the default command to run the startup script from user's home +CMD ["/home/sandboxuser/start.sh"] diff --git a/benchmarking/sandbox/__pycache__/benchmarking_sandbox_management.cpython-311.pyc b/benchmarking/sandbox/__pycache__/benchmarking_sandbox_management.cpython-311.pyc new file mode 100644 index 0000000..05786f7 Binary files /dev/null and b/benchmarking/sandbox/__pycache__/benchmarking_sandbox_management.cpython-311.pyc differ diff --git a/benchmarking/sandbox/benchmarking_sandbox_management.py b/benchmarking/sandbox/benchmarking_sandbox_management.py new file mode 100644 index 0000000..1f8e624 --- /dev/null +++ b/benchmarking/sandbox/benchmarking_sandbox_management.py @@ -0,0 +1,523 @@ +# Import logging and sys first for configuration +import logging +import sys + +# --- Standard Library Imports --- +import argparse +import os +import time +import subprocess # Still needed for docker cp (if used elsewhere) +import shlex +import json +import io +import tempfile # May not be needed anymore + +# --- Third-Party Imports --- +try: + import docker +except ImportError: + logging.error("Error: docker library not found.") + logging.error("Please install it in your host environment: pip install docker") + sys.exit(1) + +# Optional: Use rich for better formatting if available +try: + from rich.console import Console + from rich.prompt import Prompt + HAS_RICH = True + console = Console() +except ImportError: + HAS_RICH = False + console = None + class Prompt: + @staticmethod + def ask(prompt, choices=None, default=None): + p_text = f"{prompt} " + if choices: p_text += f"({'/'.join(choices)}) " + if default: p_text += f"[{default}] " + return input(p_text).strip() + +# --- Configuration --- +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +DOCKERFILE_PATH = os.path.join(SCRIPT_DIR, 'Dockerfile') # Assumes Dockerfile is in the same dir +IMAGE_TAG = "benchmarking-sandbox:latest" +CONTAINER_NAME = "benchmarking_sandbox_instance" +# Port mapping for the FastAPI service inside the container +API_PORT_INSIDE = 8000 +API_PORT_HOST = 8000 # Host port to map to + +# Centralized message printing function using logging and optional rich console +def _print_message(message, style=None, is_error=False): + """Helper to print using rich console or standard logging.""" + level = logging.INFO # Default level + if is_error or (style and 'red' in style): + level = logging.ERROR + elif style and 'yellow' in style: + level = logging.WARNING + elif style and 'green' in style: + level = logging.INFO + elif style and ('dim' in style or 'blue' in style or 'cyan' in style): + # Keep these less critical styled messages as DEBUG if overall level is DEBUG + # If overall level is INFO, they won't show unless changed here. + # Let's make them INFO so they appear with INFO level logging. + level = logging.INFO # Changed from DEBUG to INFO + + # Log the message using Python's standard logging + # It will only appear if 'level' >= the level set in basicConfig + logging.log(level, message) + + # Additionally print to console using rich if available and not an error + # This provides the styling even if the log level is higher than the message level + if HAS_RICH and console and not is_error: + console.print(message, style=style if style else None) + elif not HAS_RICH and not is_error and level >= logging.INFO: + # Fallback print for non-error, non-debug messages when rich is unavailable + # Ensures INFO messages are printed if logging level is INFO or lower + print(message) + + +class SandboxManager: + """ + Manages the lifecycle of the benchmarking sandbox Docker container, + which now runs a kernel and a FastAPI service. + Uses logging for internal messages. + """ + def __init__(self): + self.client = None + self.container = None + try: + docker_host = os.environ.get("DOCKER_HOST") + if docker_host: + logging.info(f"Using DOCKER_HOST: {docker_host}") + self.client = docker.from_env() + self.client.ping() + logging.info("Docker client initialized successfully.") + except Exception as e: + logging.error(f"Error initializing Docker client: {e}", exc_info=True) + logging.error("Ensure Docker Desktop/Engine is running and DOCKER_HOST is set if needed.") + sys.exit(1) + + def _get_container_logs(self, tail=50): + """Retrieves recent logs from the managed container.""" + current_container = self._find_container() + if not current_container: + logging.warning("Attempted to get logs, but container '%s' not found.", CONTAINER_NAME) + if self.container: + logging.debug("Clearing stale internal container object.") + self.container = None + return "(Container not found or already removed)" + + target_container = current_container + try: + logs = target_container.logs(tail=tail).decode('utf-8', errors='ignore') + return logs + except Exception as log_e: + logging.error(f"Could not retrieve logs for container '{target_container.id}': {log_e}") + return f"(Could not retrieve logs: {log_e})" + + def _find_container(self): + """Finds the container by name, returns container object or None.""" + try: + container = self.client.containers.get(CONTAINER_NAME) + return container + except docker.errors.NotFound: + return None + except Exception as e: + logging.error(f"Error finding container '{CONTAINER_NAME}': {e}") + return None + + def build_image(self): + """Builds the Docker image from the Dockerfile.""" + _print_message(f"Building Docker image '[cyan]{IMAGE_TAG}[/cyan]' from [blue]{DOCKERFILE_PATH}[/blue]...", style="bold blue") + if not os.path.exists(DOCKERFILE_PATH): + _print_message(f"Error: Dockerfile not found at {DOCKERFILE_PATH}", style="bold red", is_error=True) + return False + try: + build_context = os.path.dirname(DOCKERFILE_PATH) + _print_message(f"Using build context: [blue]{build_context}[/blue]", style="blue") + stream = self.client.api.build( + path=build_context, + dockerfile=os.path.basename(DOCKERFILE_PATH), + tag=IMAGE_TAG, + rm=True, + decode=True + ) + last_status = None + for chunk in stream: + # Process build stream for logging/display + if 'stream' in chunk: + line = chunk['stream'].strip() + # Log/print build output lines only if they contain content + # Use _print_message with no style (defaults to INFO level) + if line: _print_message(line) # <-- CHANGED from logging.debug + elif 'errorDetail' in chunk: + error_msg = chunk['errorDetail']['message'] + _print_message(f"Build Error: {error_msg}", style="bold red", is_error=True) + return False + elif 'status' in chunk: + status = chunk['status'] + # Log/print status changes, but reduce noise from download/extract steps + if status != last_status and "Downloading" not in status and "Extracting" not in status: + # Use _print_message with dim style (logs as INFO) + _print_message(f"Status: {status}", style="dim") # <-- CHANGED from logging.debug + last_status = status + _print_message(f"Image '[cyan]{IMAGE_TAG}[/cyan]' built successfully.", style="green") + return True + except docker.errors.BuildError as e: + _print_message(f"Docker build failed: {e}", style="bold red", is_error=True) + for line in e.build_log: + if 'stream' in line: logging.error(f"Build Log: {line['stream'].strip()}") + return False + except Exception as e: + _print_message(f"An unexpected error occurred during build: {e}", style="bold red", is_error=True) + logging.exception("Build error details:") + return False + + def start_container(self, rebuild=False): + """Starts the Docker container with the FastAPI service.""" + # Handle rebuild request + if rebuild: + _print_message("Rebuild requested.") + existing_container = self._find_container() + if existing_container: + _print_message("Stopping existing container before rebuild...") + if not self.stop_container(remove=True, container_obj=existing_container): + _print_message("Failed to stop/remove existing container during rebuild. Aborting start.", style="red", is_error=True) + return False + if not self.build_image(): + # Error printed by build_image + return False + + # Ensure no old container instance is running + existing_container = self._find_container() + if existing_container: + _print_message(f"Found existing container '{CONTAINER_NAME}'. Stopping and removing it first...") + if not self.stop_container(remove=True, container_obj=existing_container): + _print_message(f"Failed to stop/remove existing container '{CONTAINER_NAME}'. Aborting start.", style="bold red", is_error=True) + return False + + # Start the container + _print_message(f"Starting container '[cyan]{CONTAINER_NAME}[/cyan]' with API service...", style="cyan") + try: + # Check if the image exists locally, build if not + try: + self.client.images.get(IMAGE_TAG) + except docker.errors.ImageNotFound: + _print_message(f"Image '[cyan]{IMAGE_TAG}[/cyan]' not found locally. Building...", style="yellow") + if not self.build_image(): + return False + + # Define port mapping for the FastAPI service + port_map = {f'{API_PORT_INSIDE}/tcp': API_PORT_HOST} + logging.info(f"Mapping container port {API_PORT_INSIDE} to host port {API_PORT_HOST}") + + # Define container run options (using default bridge network) + run_options = { + 'name': CONTAINER_NAME, + 'detach': True, + 'auto_remove': False, # Keep container for inspection on failure + 'ports': port_map, # Map the API port + } + logging.debug(f"Docker run options: {run_options}") + + # Run the container + self.container = self.client.containers.run( + IMAGE_TAG, + **run_options + ) + + _print_message(f"Container '[cyan]{CONTAINER_NAME}[/cyan]' started ([yellow]{self.container.short_id}[/yellow]).", style="cyan") + + # Wait briefly for the service inside to potentially start + wait_time = 5 # Seconds to wait for API/kernel startup + _print_message(f"Waiting {wait_time}s for services inside container to initialize...") + time.sleep(wait_time) + + # Basic check: Is the container still running? + # Use _find_container to check live status + current_container = self._find_container() + if not current_container or current_container.status != 'running': + status = current_container.status if current_container else 'not found' + _print_message(f"Container exited unexpectedly shortly after start (status: {status}).", style="bold red", is_error=True) + # Try to get logs even if stopped/gone + logs = self._get_container_logs() + _print_message("--- Container Logs ---", is_error=True) + _print_message(logs if logs else "(Could not retrieve logs)", is_error=True) + _print_message("----------------------", is_error=True) + self.container = None # Clear internal ref + return False + + _print_message(f"Container running. API should be accessible at http://localhost:{API_PORT_HOST}", style="green") + return True # Container started successfully + + except Exception as e: + _print_message(f"Error during container start: {e}", style="bold red", is_error=True) + logging.exception("Container start error details:") + # Ensure cleanup if self.container was assigned + current_container = self._find_container() + if current_container: + self.stop_container(remove=True, container_obj=current_container) + self.container = None + return False + + def stop_container(self, remove=False, container_obj=None): + """Stops the container and optionally removes it.""" + # Find the container to stop if not provided + container_to_stop = container_obj or self._find_container() + + if not container_to_stop: + _print_message(f"Container '{CONTAINER_NAME}' not found or already stopped/removed.", style="yellow") + if container_obj is None: + self.container = None + return True + + _print_message(f"Stopping container '[cyan]{CONTAINER_NAME}[/cyan]' ([yellow]{container_to_stop.short_id}[/yellow])...", style="cyan") + stopped = False + removed_flag = False + try: + container_to_stop.reload() + current_status = container_to_stop.status + if current_status == 'running': + _print_message("Sending stop signal to container...") + container_to_stop.stop(timeout=10) + time.sleep(1) + container_to_stop.reload() + if container_to_stop.status == 'exited': + _print_message("Container stopped successfully.", style="green") + stopped = True + else: + _print_message(f"Container status is '{container_to_stop.status}' after stop attempt. Trying force stop...", style="yellow") + container_to_stop.kill() + time.sleep(1) + container_to_stop.reload() + if container_to_stop.status == 'exited': + _print_message("Container force stopped.", style="green") + stopped = True + else: + _print_message(f"Container still '{container_to_stop.status}' after force stop.", style="red", is_error=True) + else: + _print_message(f"Container was not running (status: {current_status}).") + stopped = True + + if remove and stopped: + try: + _print_message("Removing container...") + container_to_stop.remove(force=True) + _print_message("Container removed.", style="green") + removed_flag = True + except docker.errors.NotFound: + _print_message("Container was already removed.", style="yellow") + removed_flag = True + except docker.errors.APIError as e: + _print_message(f"API error removing container: {e}", style="yellow", is_error=True) + try: + self.client.containers.get(container_to_stop.id) + removed_flag = False + except docker.errors.NotFound: + _print_message("Container appears removed despite API error.", style="yellow") + removed_flag = True + elif remove and not stopped: + _print_message("Remove requested, but container failed to stop. Attempting force remove...", style="yellow") + try: + container_to_stop.remove(force=True) + _print_message("Container force removed.", style="green") + removed_flag = True + except Exception as fe: + _print_message(f"Failed to force remove container: {fe}", style="red", is_error=True) + removed_flag = False + elif not remove: + removed_flag = True + + if self.container and self.container.id == container_to_stop.id: + self.container = None + + return stopped and removed_flag + + except Exception as e: + _print_message(f"Error stopping/removing container: {e}", style="bold red", is_error=True) + logging.exception("Stop/remove error details:") + if remove and container_to_stop: + try: + _print_message("Attempting force remove after error...", style="yellow") + container_to_stop.remove(force=True) + _print_message("Container force removed after error.", style="green") + except Exception as fe: + _print_message(f"Failed to force remove container after error: {fe}", style="red", is_error=True) + if self.container and container_to_stop and self.container.id == container_to_stop.id: + self.container = None + return False + + def get_status(self): + """Gets the status of the container.""" + container_status = "not found" + container = self._find_container() + if container: + try: + container.reload() + container_status = container.status + except Exception as e: + logging.warning(f"Error getting container status: {e}") + container_status = "unknown (error)" + + return f"Container: {container_status}, API Port (Host): {API_PORT_HOST}" + + +# --- Interactive Mode Functions --- +def print_interactive_help(): + """Prints help message for interactive mode using _print_message.""" + _print_message("\n[bold cyan]Available Commands:[/bold cyan]", style="bold cyan") + _print_message(" [green]build[/green] Build the Docker image.", style="green") + _print_message(" [green]start[/green] [--rebuild] Start container with API service.", style="green") + _print_message(" [green]stop[/green] Stop & remove container.", style="green") + _print_message(" [green]status[/green] Check container status.", style="green") + _print_message(" [green]logs[/green] [N] Show last N container logs (default 50).", style="green") + _print_message(" [green]help[/green] Show this help message.", style="green") + _print_message(" [green]exit[/green] Exit (prompts to stop container if running).", style="green") + _print_message("\nExample: [yellow]start --rebuild[/yellow]", style="yellow") + +def interactive_loop(manager): + """Runs the interactive command loop.""" + _print_message("[bold blue]Welcome to the Stateful Benchmarking Sandbox Manager (API Mode)![/bold blue]", style="bold blue") + print_interactive_help() + while True: + try: + raw_command = Prompt.ask("\nEnter command ('help' or 'exit')") + if not raw_command: continue + + try: + command_parts = shlex.split(raw_command) + except ValueError as e: + _print_message(f"Error parsing command: {e}", style="red", is_error=True) + continue + if not command_parts: continue + + command = command_parts[0].lower() + args = command_parts[1:] + + if command == "exit": + container_running = False + container_obj = manager._find_container() + if container_obj: + try: + container_obj.reload() + container_running = container_obj.status == 'running' + except Exception as e: + logging.warning(f"Could not check container status on exit: {e}") + if container_running: + should_stop_str = Prompt.ask(f"Container '{CONTAINER_NAME}' is running. Stop it?", choices=["y", "n"], default="y") + if should_stop_str.lower() == 'y': + _print_message("Stopping container on exit...") + manager.stop_container(remove=True, container_obj=container_obj) + break + elif command == "help": + print_interactive_help() + elif command == "build": + if len(args) == 0: manager.build_image() + else: _print_message("Usage: build", style="yellow") + elif command == "start": + rebuild = '--rebuild' in args + if all(a == '--rebuild' for a in args if a.startswith('--')) or len(args) == 0: + manager.start_container(rebuild=rebuild) + else: _print_message("Usage: start [--rebuild]", style="yellow") + elif command == "stop": + if len(args) == 0: manager.stop_container(remove=True) + else: _print_message("Usage: stop", style="yellow") + elif command == "status": + if len(args) == 0: _print_message(f"Status: {manager.get_status()}") + else: _print_message("Usage: status", style="yellow") + elif command == "logs": + tail_count = 50 + if len(args) == 1: + try: + tail_count = int(args[0]) + except ValueError: + _print_message("Usage: logs [number_of_lines]", style="yellow") + continue + elif len(args) > 1: + _print_message("Usage: logs [number_of_lines]", style="yellow") + continue + + logs = manager._get_container_logs(tail=tail_count) + _print_message(f"\n--- Last {tail_count} Container Logs ---") + # Use standard print for logs as they can be multiline and formatting is less critical + print(logs if logs else "(No logs retrieved or container not found)") + _print_message("-----------------------------") + else: + _print_message(f"Unknown command: '{command}'. Type 'help'.", style="red") + + except EOFError: + _print_message("\nEOF detected. Exiting.", style="yellow") + container_running = False; container_obj = manager._find_container() + if container_obj: + try: + container_obj.reload(); + container_running = (container_obj.status == 'running') + except Exception: pass + if container_running: + should_stop_str = Prompt.ask(f"Container '{CONTAINER_NAME}' is running. Stop it?", choices=["y", "n"], default="y") + if should_stop_str.lower() == 'y': manager.stop_container(remove=True, container_obj=container_obj) + break + except KeyboardInterrupt: + _print_message("\nInterrupted by user. Type 'exit'.", style="yellow") + except Exception as e: + _print_message(f"Unexpected error in interactive loop: {e}", style="bold red", is_error=True) + logging.exception("Interactive loop error details:") + + _print_message("Exiting sandbox manager.", style="bold blue") + + +# --- Main Entry Point --- +def main(): + try: + manager = SandboxManager() + except SystemExit: + sys.exit(1) + + if len(sys.argv) == 1: + interactive_loop(manager) + sys.exit(0) + + # --- Command-Line Argument Parsing (Simplified) --- + parser = argparse.ArgumentParser( + description="Manage the Stateful Benchmarking Sandbox (API Mode).", + formatter_class=argparse.RawTextHelpFormatter + ) + subparsers = parser.add_subparsers( + dest='command', + help='Action to perform', + required=True + ) + + parser_build = subparsers.add_parser('build', help='Build the Docker image.') + parser_build.set_defaults(func=lambda args, mgr: mgr.build_image()) + + parser_start = subparsers.add_parser('start', help='Start container with API service.') + parser_start.add_argument('--rebuild', action='store_true', help='Rebuild image first.') + parser_start.set_defaults(func=lambda args, mgr: mgr.start_container(rebuild=args.rebuild)) + + parser_stop = subparsers.add_parser('stop', help='Stop & remove container.') + parser_stop.set_defaults(func=lambda args, mgr: mgr.stop_container(remove=True)) + + parser_status = subparsers.add_parser('status', help='Check container status.') + parser_status.set_defaults(func=lambda args, mgr: _print_message(f"Status: {mgr.get_status()}")) + + parser_logs = subparsers.add_parser('logs', help='Show last N container logs.') + parser_logs.add_argument('n', type=int, nargs='?', default=50, help='Number of lines to show (default: 50)') + def show_logs(args, mgr): + logs = mgr._get_container_logs(tail=args.n) + _print_message(f"\n--- Last {args.n} Container Logs ---") + # Use standard print for logs + print(logs if logs else "(No logs retrieved or container not found)") + _print_message("-----------------------------") + return True # Assume success for showing logs + parser_logs.set_defaults(func=show_logs) + + args = parser.parse_args() + # Execute the function and store success status + success = args.func(args, manager) + # Exit with appropriate status code + sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() diff --git a/benchmarking/sandbox/kernel_api.py b/benchmarking/sandbox/kernel_api.py new file mode 100644 index 0000000..5b872dd --- /dev/null +++ b/benchmarking/sandbox/kernel_api.py @@ -0,0 +1,315 @@ +import logging +import sys +import os +import json +import asyncio +import base64 +import tempfile +from contextlib import asynccontextmanager +from queue import Empty +import time + +# --- FastAPI Imports --- +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field # Updated import for Field + +# --- Jupyter Client Imports --- +# Using AsyncClient for better compatibility with FastAPI +from jupyter_client.manager import AsyncKernelManager +# Corrected import path for AsyncKernelClient +from jupyter_client.asynchronous.client import AsyncKernelClient + +# --- Logging Setup --- +# Configure logging to see messages from this API and jupyter_client +logging.basicConfig( + level=logging.DEBUG, + stream=sys.stdout, + format='%(asctime)s - %(name)s - %(levelname)s - [FastAPI_Kernel] %(message)s', + force=True +) +log = logging.getLogger(__name__) + +# --- Global Variables --- +# Global kernel manager and client to potentially reuse connection (or manage lifecycle) +# Using lifespan events is generally preferred for managing resources like this. +kernel_manager: AsyncKernelManager | None = None +kernel_client: AsyncKernelClient | None = None +KERNEL_CONNECTION_FILE = "/home/sandboxuser/kernel-connection.json" # Path inside container + +# --- Pydantic Models --- +class CodeExecutionRequest(BaseModel): + """Request model for code execution.""" + code: str = Field(..., description="The Python code string to execute.") + timeout: int = Field(60, description="Execution timeout in seconds.") + +class StreamOutput(BaseModel): + """Model for stdout/stderr stream messages.""" + type: str = "stream" + name: str # 'stdout' or 'stderr' + text: str + +class ErrorOutput(BaseModel): + """Model for execution errors.""" + type: str = "error" + ename: str # Error name (e.g., 'ValueError') + evalue: str # Error value (message) + traceback: list[str] # List of traceback lines + +class DisplayDataOutput(BaseModel): + """Model for display_data messages (like images).""" + type: str = "display_data" + data: dict[str, str] # Mime-type -> Base64 encoded data (e.g., 'image/png': 'base64...') + metadata: dict = Field(default_factory=dict) + +class ExecuteResultOutput(BaseModel): + """Model for the result of the last expression.""" + type: str = "execute_result" + data: dict[str, str] # Mime-type -> String data (e.g., 'text/plain': 'result') + metadata: dict = Field(default_factory=dict) + +class ExecutionStatus(BaseModel): + """Model for kernel status updates.""" + type: str = "status" + execution_state: str # e.g., 'busy', 'idle' + +# Union type for different output possibilities +OutputType = StreamOutput | ErrorOutput | DisplayDataOutput | ExecuteResultOutput | ExecutionStatus + +class CodeExecutionResponse(BaseModel): + """Response model containing execution results.""" + outputs: list[OutputType] = Field(..., description="List of outputs from the kernel.") + final_status: str = Field("unknown", description="Final status from execute_reply ('ok', 'error', 'aborted').") + +# --- Helper Functions --- +async def get_kernel_client() -> AsyncKernelClient: + """ + Connects to the kernel using the connection file. + Raises FileNotFoundError or TimeoutError if connection fails. + """ + log.info(f"Attempting to connect to kernel using {KERNEL_CONNECTION_FILE}") + if not os.path.exists(KERNEL_CONNECTION_FILE): + log.error(f"Kernel connection file not found at {KERNEL_CONNECTION_FILE}") + raise FileNotFoundError("Kernel connection file not found.") + + # Create a client connected to the existing kernel + kc = AsyncKernelClient(connection_file=KERNEL_CONNECTION_FILE) + kc.load_connection_file() + + # Start channels - crucial for communication + # This method is synchronous in the async client, it starts background tasks/threads. + try: + log.debug("Starting kernel client channels (synchronous call)...") + kc.start_channels() # <-- REMOVED await asyncio.wait_for() + log.info("Kernel client channels started.") + except Exception as e: + log.error(f"Error starting kernel client channels: {e}", exc_info=True) + raise e # Re-raise other exceptions + + # Check if kernel is alive with timeout + try: + log.info("Waiting for kernel to be ready...") + # wait_for_ready IS awaitable + await asyncio.wait_for(kc.wait_for_ready(timeout=15.0), timeout=20.0) + log.info("Kernel is ready.") + return kc + except asyncio.TimeoutError: + log.error("Timeout waiting for kernel to become ready.") + # Attempt to stop channels if started + try: + # stop_channels might also be sync, handle potential errors + if kc.channels_running: + kc.stop_channels() + except Exception: + pass # Ignore errors during cleanup + raise TimeoutError("Timeout waiting for kernel readiness.") + except RuntimeError as e: + log.error(f"Kernel readiness check failed: {e}") + try: + if kc.channels_running: + kc.stop_channels() + except Exception: + pass + raise RuntimeError(f"Kernel readiness check failed: {e}") + except Exception as e: + log.error(f"Unexpected error during kernel readiness check: {e}", exc_info=True) + try: + if kc.channels_running: + kc.stop_channels() + except Exception: + pass + raise e + + +async def execute_code_on_kernel(kc: AsyncKernelClient, code: str, timeout: int) -> CodeExecutionResponse: + """ + Executes code using the provided async kernel client and gathers results. + """ + log.info(f"Executing code (timeout={timeout}s):\n---\n{code}\n---") + outputs = [] + final_status = "unknown" + + # Send execute request + msg_id = kc.execute(code=code, store_history=True) + log.debug(f"Execute request sent, msg_id: {msg_id}") + + # Process messages until idle or error + start_time = time.time() + execution_done = False + while time.time() - start_time < timeout: + try: + # Get message from IOPub channel with a short timeout + msg = await asyncio.wait_for(kc.get_iopub_msg(timeout=1.0), timeout=1.5) + msg_type = msg['header']['msg_type'] + content = msg['content'] + log.debug(f"Received IOPub message type: {msg_type}") + + if msg_type == 'status': + outputs.append(ExecutionStatus(execution_state=content['execution_state'])) + if content['execution_state'] == 'idle': + log.debug("Kernel reported idle status.") + execution_done = True + # Don't break immediately, wait for shell reply below + elif msg_type == 'stream': + outputs.append(StreamOutput(name=content['name'], text=content['text'])) + elif msg_type == 'display_data': + # Base64 encode binary data for JSON transfer + encoded_data = {} + for mime, data in content.get('data', {}).items(): + if isinstance(data, bytes): + encoded_data[mime] = base64.b64encode(data).decode('utf-8') + elif isinstance(data, str): # Assume text is already appropriate string + if mime.startswith('image/') or mime == 'text/html': + encoded_data[mime] = data # Keep as string + else: + pass + else: + log.warning(f"Unsupported data type '{type(data)}' in display_data for mime '{mime}'") + + if encoded_data: + outputs.append(DisplayDataOutput(data=encoded_data, metadata=content.get('metadata', {}))) + + elif msg_type == 'execute_result': + outputs.append(ExecuteResultOutput(data=content.get('data', {}), metadata=content.get('metadata', {}))) + elif msg_type == 'error': + outputs.append(ErrorOutput( + ename=content.get('ename', 'UnknownError'), + evalue=content.get('evalue', ''), + traceback=content.get('traceback', []) + )) + log.error(f"Kernel execution error: {content.get('ename')}") + execution_done = True # Error means execution finished + # Don't break immediately, wait for shell reply + + except asyncio.TimeoutError: + if execution_done: + log.debug("IOPub processing finished after kernel idle/error.") + break + else: + pass + except Empty: + log.debug("IOPub queue empty, continuing wait...") + if execution_done: break + except Exception as e: + log.error(f"Error processing IOPub message: {e}", exc_info=True) + outputs.append(ErrorOutput(ename="ClientError", evalue=f"Error reading IOPub: {e}", traceback=[])) + execution_done = True + break + + # After loop (timeout or kernel idle/error), get the shell reply + try: + shell_reply = await asyncio.wait_for(kc.get_shell_msg(timeout=5.0), timeout=6.0) + if shell_reply['parent_header'].get('msg_id') == msg_id: + final_status = shell_reply['content']['status'] + log.info(f"Execution final status from shell reply: {final_status}") + if final_status == 'error' and not any(o.type == 'error' for o in outputs): + outputs.append(ErrorOutput( + ename=shell_reply['content'].get('ename', 'ShellError'), + evalue=shell_reply['content'].get('evalue', 'Error reported by shell'), + traceback=shell_reply['content'].get('traceback', []) + )) + else: + log.warning(f"Received shell message {shell_reply.get('msg_id')} doesn't match request {msg_id}") + final_status = "mismatched_reply" + except asyncio.TimeoutError: + log.warning("Timeout waiting for shell reply.") + final_status = "timeout_shell_reply" + except Empty: + log.warning("Shell reply queue empty.") + final_status = "empty_shell_reply" + except Exception as e: + log.error(f"Error getting shell reply: {e}", exc_info=True) + final_status = "error_shell_reply" + + if not execution_done and time.time() - start_time >= timeout: + log.error("Execution timed out.") + if not any(o.type == 'error' and 'Timeout' in o.evalue for o in outputs): + outputs.append(ErrorOutput(ename="TimeoutError", evalue=f"Execution timed out after {timeout} seconds", traceback=[])) + final_status = "timeout" + + + return CodeExecutionResponse(outputs=outputs, final_status=final_status) + + +# --- FastAPI App --- +@asynccontextmanager +async def lifespan(app: FastAPI): + log.info("FastAPI application startup...") + yield + log.info("FastAPI application shutdown...") + +app = FastAPI(lifespan=lifespan, title="Jupyter Kernel Execution API") + +@app.get("/status", summary="Check API status") +async def get_status(): + """Returns the status of the API.""" + log.info("Status endpoint called.") + kernel_file_exists = os.path.exists(KERNEL_CONNECTION_FILE) + return JSONResponse(content={ + "status": "ok", + "kernel_connection_file_found": kernel_file_exists + }) + +@app.post("/execute", + response_model=CodeExecutionResponse, + summary="Execute Python code in the kernel") +async def execute_code_endpoint(payload: CodeExecutionRequest): + """ + Receives Python code, executes it using the Jupyter kernel, + and returns captured outputs (stdout, stderr, errors, display data). + """ + log.info(f"Received code execution request (timeout={payload.timeout}s).") + kc = None + try: + kc = await get_kernel_client() + response = await execute_code_on_kernel(kc, payload.code, payload.timeout) + log.info(f"Execution finished with final status: {response.final_status}") + return response + + except FileNotFoundError: + log.error("Kernel connection file missing.") + raise HTTPException(status_code=503, detail="Kernel connection file not found. Is the kernel running?") + except TimeoutError as e: + log.error(f"Timeout during kernel connection or execution: {e}") + raise HTTPException(status_code=504, detail=f"Timeout: {e}") + except RuntimeError as e: + log.error(f"Runtime error during kernel connection: {e}") + raise HTTPException(status_code=503, detail=f"Kernel connection runtime error: {e}") + except Exception as e: + log.error(f"Unexpected error during code execution: {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Internal server error: {e}") + finally: + # Ensure kernel client channels are stopped after each request + if kc and kc.channels_running: + log.debug("Stopping kernel client channels for this request.") + try: + # stop_channels is likely synchronous, call directly + kc.stop_channels() + except Exception as e: + log.warning(f"Error stopping kernel client channels for request: {e}") + +# --- Uvicorn Entry Point (for direct execution if needed) --- +if __name__ == "__main__": + import uvicorn + log.info("Starting Uvicorn server directly for debugging...") + uvicorn.run("kernel_api:app", host="0.0.0.0", port=8000, log_level="debug", reload=True) # Added reload for dev diff --git a/benchmarking/sandbox/requirements.txt b/benchmarking/sandbox/requirements.txt new file mode 100644 index 0000000..956c953 --- /dev/null +++ b/benchmarking/sandbox/requirements.txt @@ -0,0 +1,27 @@ +ansi2html==1.8.0 +scanpy==1.10.2 +scrublet +anndata==0.10.8 +celltypist==1.6.3 +leidenalg==0.10.2 +igraph==0.11.6 +networkx==3.2.1 +pynndescent==0.5.13 +numpy==1.26.4 +scipy==1.13.1 +pandas==2.2.2 +scikit-learn==1.5.1 +umap-learn==0.5.6 +statsmodels==0.14.2 +numba==0.60.0 +matplotlib==3.9.1 +seaborn==0.13.2 +h5py==3.11.0 +openpyxl==3.1.5 +PyPDF2 +tqdm==4.66.4 +psutil==6.0.0 +defusedxml==0.7.1 +requests==2.32.3 +jupyter +jupyter_client \ No newline at end of file diff --git a/benchmarking/sandbox/start.sh b/benchmarking/sandbox/start.sh new file mode 100644 index 0000000..71b413a --- /dev/null +++ b/benchmarking/sandbox/start.sh @@ -0,0 +1,47 @@ +#!/bin/bash + +# Exit immediately if a command exits with a non-zero status. +set -e + +# Use $HOME which is set to /home/sandboxuser in the Dockerfile +KERNEL_SCRIPT_PATH="$HOME/start_kernel.py" +CONNECTION_FILE_PATH="$HOME/kernel-connection.json" +KERNEL_LOG_PATH="/tmp/kernel.log" # Keep logs in /tmp + +echo "[start.sh] Starting Jupyter Kernel ($KERNEL_SCRIPT_PATH) in background..." +# Start the kernel using the script in the user's home directory. +# Redirect kernel output to a log file. +python "$KERNEL_SCRIPT_PATH" > "$KERNEL_LOG_PATH" 2>&1 & + +# Store the PID of the kernel process +KERNEL_PID=$! +echo "[start.sh] Kernel process started with PID: $KERNEL_PID" + +# Wait a few seconds to allow the kernel to initialize and write the connection file +echo "[start.sh] Waiting 5 seconds for kernel to initialize..." +sleep 5 + +# Check if the kernel process is still running +if ! kill -0 $KERNEL_PID > /dev/null 2>&1; then + echo "[start.sh] ERROR: Kernel process died shortly after starting. Check $KERNEL_LOG_PATH" + cat "$KERNEL_LOG_PATH" # Print kernel log on error + exit 1 +fi +echo "[start.sh] Kernel process appears to be running." + +# Check if connection file was created using the dynamic path +if [ ! -f "$CONNECTION_FILE_PATH" ]; then + echo "[start.sh] ERROR: Kernel connection file ($CONNECTION_FILE_PATH) was not created. Check $KERNEL_LOG_PATH" + cat "$KERNEL_LOG_PATH" # Print kernel log on error + exit 1 +fi +echo "[start.sh] Kernel connection file found at $CONNECTION_FILE_PATH." + + +echo "[start.sh] Starting FastAPI Uvicorn server..." +# Start the FastAPI application using Uvicorn. +# Assumes kernel_api.py is in the WORKDIR ($HOME). +# --host 0.0.0.0 makes it accessible from outside the container (host machine). +# --port 8000 is the standard port, ensure it's mapped in docker run/compose. +# Use exec to replace the shell process with uvicorn, allowing tini to manage it correctly. +exec uvicorn kernel_api:app --host 0.0.0.0 --port 8000 --log-level debug \ No newline at end of file diff --git a/benchmarking/sandbox/start_kernel.py b/benchmarking/sandbox/start_kernel.py new file mode 100644 index 0000000..dcfe01e --- /dev/null +++ b/benchmarking/sandbox/start_kernel.py @@ -0,0 +1,15 @@ +import os, os.path, sys + +base = int(os.environ.get("IPY_BASE_PORT", 4000)) +argv = [ + sys.executable, "-Xfrozen_modules=off", "-vv", "-m", "ipykernel_launcher", + "--ip=127.0.0.1", + "--log-level=DEBUG", + f"--shell={base + 0}", + f"--iopub={base + 1}", + f"--stdin={base + 2}", + f"--hb={base + 3}", + f"--control={base + 4}", + "-f", "/home/sandboxuser/kernel-connection.json", +] +os.execvp(argv[0], argv) \ No newline at end of file diff --git a/benchmarking/tools/czi_browser.py b/benchmarking/tools/czi_browser.py new file mode 100644 index 0000000..597ea09 --- /dev/null +++ b/benchmarking/tools/czi_browser.py @@ -0,0 +1,580 @@ +#!/usr/bin/env python +import argparse +import cellxgene_census +import sys +import math +import shlex # For parsing interactive commands safely +import os # For path operations and directory creation +import json # For saving metadata +import re # For sanitizing filenames + +try: + from rich.console import Console + from rich.table import Table + from rich.pretty import pprint + from rich.prompt import Prompt # For interactive prompts + HAS_RICH = True +except ImportError: + HAS_RICH = False + # Simple print/input fallback if rich is not installed + def pprint(obj): print(obj) + class Console: + def print(self, *args, **kwargs): print(*args) + class Table: + # Basic fallback Table class + def __init__(self, title=""): + self._title = title + self._rows = [] + self._columns = [] + self._styles = {} # Dummy style storage + def add_column(self, header, style=""): + self._columns.append(header) + self._styles[header] = style # Store style info even if unused + def add_row(self, *items): + # Ensure row has same number of items as columns + if len(items) != len(self._columns): + raise ValueError("Number of items in row does not match number of columns") + self._rows.append(items) + def __rich_console__(self, console, options): # Dummy method for rich compatibility + # Basic text rendering for fallback + yield self._title + yield "\t".join(self._columns) + for row in self._rows: yield "\t".join(map(str, row)) + def print_table(self, console): # Custom print method if rich not available + console.print(self._title) + if self._columns: # Only print header/rows if columns exist + col_widths = [len(h) for h in self._columns] + for row in self._rows: + for i, item in enumerate(row): + col_widths[i] = max(col_widths[i], len(str(item))) + + header_line = " ".join(f"{h:<{w}}" for h, w in zip(self._columns, col_widths)) + separator = "-" * len(header_line) + console.print(header_line) + console.print(separator) + for row in self._rows: + row_line = " ".join(f"{str(item):<{w}}" for item, w in zip(row, col_widths)) + console.print(row_line) + + class Prompt: + @staticmethod + def ask(prompt, choices=None, default=None): + p_text = f"{prompt} " + if choices: + choices_str = '/'.join(choices) + p_text += f"({choices_str}) " + if default: + p_text += f"[{default}] " + return input(p_text).strip() + +# --- Helper Functions --- + +def sanitize_filename(name): + """Removes invalid characters and replaces spaces for use in filenames.""" + # Remove characters that are not alphanumeric, underscore, or hyphen + name = re.sub(r'[^\w\-]+', '_', name) + # Replace multiple underscores with a single one + name = re.sub(r'_+', '_', name) + # Remove leading/trailing underscores + name = name.strip('_') + # Convert to lowercase + return name.lower() + +def ensure_datasets_dir_exists(base_dir="../datasets"): + """Checks if the target directory exists and creates it if not.""" + # Get the absolute path relative to the script location + script_dir = os.path.dirname(os.path.abspath(__file__)) + target_dir = os.path.abspath(os.path.join(script_dir, base_dir)) + + if not os.path.exists(target_dir): + print(f"Creating target directory: {target_dir}") + try: + os.makedirs(target_dir) + except OSError as e: + raise OSError(f"Failed to create directory {target_dir}: {e}") + elif not os.path.isdir(target_dir): + raise NotADirectoryError(f"Target path {target_dir} exists but is not a directory.") + return target_dir + + +# --- Core Data Fetching Functions --- + +def get_census_versions_data(): + """Fetches available CELLxGENE Census versions data.""" + try: + census_versions = cellxgene_census.get_census_version_directory() + versions_list = [] + # Prioritize 'stable', then 'latest', then sort others reverse chronologically + sorted_versions = sorted( + census_versions.keys(), + key=lambda v: ('0' if v == 'stable' else '1' if v == 'latest' else '2') + v, + reverse=True # Puts stable/latest effectively first, then sorts dates reverse + ) + + for version in sorted_versions: + description = census_versions[version] + release_date = "N/A" + try: + # Avoid fetching description again if already present + release_date = description.get("release_date") + if not release_date: + details = cellxgene_census.get_census_version_description(version) + release_date = details.get("release_date", "N/A") + except Exception: + pass # Ignore if details can't be fetched + versions_list.append({ + "version": version, + "description": description.get('description', description.get('uri', 'N/A')), + "release_date": release_date + }) + return versions_list + except Exception as e: + raise RuntimeError(f"Error listing versions: {e}") + +def fetch_source_datasets_data(census_version): + """Fetches source datasets DataFrame for a specific Census version.""" + console = Console() + console.print(f"Fetching source datasets info for Census version: [cyan]{census_version}[/cyan]...") + try: + # Check if version is valid before opening (optional, but good practice) + available_versions = cellxgene_census.get_census_version_directory() + if census_version not in available_versions: + console.print(f"[bold red]Error:[/bold red] Census version '{census_version}' not found.") + # Attempt to list versions to help user + try: + versions_data = get_census_versions_data() + console.print("Available versions:") + for v in versions_data: + console.print(f" - {v['version']} ({v.get('release_date', 'N/A')})") + except Exception: + console.print("(Could not fetch list of available versions)") + return None + + # Inform user about specific date mapping if using 'stable'/'latest' + try: + version_description = cellxgene_census.get_census_version_description(census_version) + actual_version = version_description.get("release_date", census_version) + if census_version in ["stable", "latest"] and actual_version != census_version: + console.print(f"The \"{census_version}\" release is currently [bold green]{actual_version}[/bold green]. Specify 'census_version=\"{actual_version}\"' in future calls to open_soma() to ensure data consistency.") + except Exception: + console.print(f"[yellow]Warning: Could not verify exact date for '{census_version}'. Proceeding...[/yellow]") + + + with cellxgene_census.open_soma(census_version=census_version) as census: + if "census_info" not in census or "datasets" not in census["census_info"]: + raise RuntimeError("Census object structure unexpected: 'census_info' or 'datasets' missing.") + + datasets_df = census["census_info"]["datasets"].read().concat().to_pandas() + if datasets_df.empty: + console.print(f"No source dataset information found for version {census_version}.") + return datasets_df # Return empty DataFrame + return datasets_df + except Exception as e: + raise RuntimeError(f"Error fetching datasets for version {census_version}: {e}") + + +def get_dataset_metadata_data(census_version, dataset_id): + """Fetches metadata dictionary for a specific source dataset.""" + console = Console() + console.print(f"Fetching metadata for dataset [cyan]{dataset_id}[/cyan] in Census version: [cyan]{census_version}[/cyan]...") + try: + # Reuse fetch_source_datasets_data which includes version check + datasets_df = fetch_source_datasets_data(census_version) + if datasets_df is None: # Check if fetch failed (e.g., invalid version) + raise ValueError(f"Could not retrieve dataset list for version {census_version}.") + if datasets_df.empty: # Check if fetch succeeded but returned empty + raise ValueError(f"No datasets found for version {census_version}, cannot fetch metadata.") + + dataset_metadata = datasets_df[datasets_df['dataset_id'] == dataset_id] + + if dataset_metadata.empty: + raise ValueError(f"Dataset ID '{dataset_id}' not found in Census version '{census_version}'.") + + return dataset_metadata.iloc[0].to_dict() + except Exception as e: + # Catch specific errors if needed, otherwise re-raise or wrap + raise RuntimeError(f"Error fetching metadata for dataset {dataset_id}: {e}") + + +# --- Download Function --- + +def download_dataset(console, census_version, dataset_id): + """Downloads the H5AD file and saves metadata JSON for a dataset.""" + try: + # 1. Ensure target directory exists + target_dir = ensure_datasets_dir_exists() + console.print(f"Target directory: [blue]{target_dir}[/blue]") + + # 2. Fetch metadata first to get the title and verify dataset exists + metadata = get_dataset_metadata_data(census_version, dataset_id) # Handles errors + dataset_title = metadata.get('dataset_title', f'dataset_{dataset_id}') # Fallback title + + # 3. Generate filenames + base_filename = sanitize_filename(dataset_title) + if not base_filename: # Handle cases where title sanitizes to empty string + base_filename = f"dataset_{dataset_id}" + h5ad_filename = f"{base_filename}.h5ad" + json_filename = f"{base_filename}.json" + h5ad_filepath = os.path.join(target_dir, h5ad_filename) + json_filepath = os.path.join(target_dir, json_filename) + + console.print(f"Preparing to download dataset:") + console.print(f" ID: [cyan]{dataset_id}[/cyan]") + console.print(f" Title: [green]{dataset_title}[/green]") + console.print(f" Version: [cyan]{census_version}[/cyan]") + console.print(f" Output H5AD: [blue]{h5ad_filepath}[/blue]") + console.print(f" Output JSON: [blue]{json_filepath}[/blue]") + + # Check if files already exist (optional, add overwrite flag later if needed) + if os.path.exists(h5ad_filepath) or os.path.exists(json_filepath): + console.print("[yellow]Warning: One or both output files already exist. Skipping download.[/yellow]") + console.print("[yellow] (Delete existing files or implement an --overwrite flag to replace.)[/yellow]") + return # Or prompt user, or add an overwrite flag + + # 4. Download H5AD + console.print(f"Downloading H5AD file...") + cellxgene_census.download_source_h5ad( + dataset_id=dataset_id, + to_path=h5ad_filepath, + census_version=census_version + ) + console.print("[bold green]H5AD Download complete.[/bold green]") + + # 5. Save Metadata JSON + console.print("Saving metadata JSON file...") + try: + with open(json_filepath, 'w', encoding='utf-8') as f: + # Convert numpy types to standard Python types if necessary + def convert_types(obj): + if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, + np.int16, np.int32, np.int64, np.uint8, + np.uint16, np.uint32, np.uint64)): + return int(obj) + elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): + return float(obj) + elif isinstance(obj, (np.ndarray,)): # Handle arrays if needed + return obj.tolist() # Or other representation + elif isinstance(obj, (np.bool_)): + return bool(obj) + elif isinstance(obj, (np.void)): # Handle complex types if they appear + return None # Or suitable representation + return obj + + # Import numpy locally for type checking if needed + import numpy as np + json.dump(metadata, f, indent=4, default=convert_types, ensure_ascii=False) + console.print("[bold green]Metadata JSON saved successfully.[/bold green]") + except Exception as json_e: + console.print(f"[bold red]Error saving metadata JSON:[/bold red] {json_e}") + # Decide if we should clean up the downloaded H5AD file + # try: + # os.remove(h5ad_filepath) + # console.print(f"[yellow]Cleaned up partially downloaded H5AD file.[/yellow]") + # except OSError: + # pass + + except (ValueError, RuntimeError, OSError, NotADirectoryError, Exception) as e: + console.print(f"[bold red]Download failed:[/bold red] {e}") + # Potentially add more specific error handling based on exception type + +# --- Display and Interaction Functions --- + +def display_versions_list(console): + """Displays available versions.""" + try: + versions_data = get_census_versions_data() + if not versions_data: + console.print("[yellow]No Census versions found.[/yellow]") + return + + table = Table(title="Available CELLxGENE Census Versions") + table.add_column("Version Tag", style="cyan", justify="right") + table.add_column("Release Date", style="green") + table.add_column("Description / URL", style="magenta") + + + for v_data in versions_data: + table.add_row(v_data["version"], v_data["release_date"], v_data["description"]) + + if HAS_RICH: + console.print(table) + else: + table.print_table(console) # Use fallback print + except Exception as e: + console.print(f"[bold red]Error displaying versions:[/bold red] {e}") + + +def display_paginated_datasets(console, census_version, limit=None, page_size=5): + """Fetches and displays datasets with pagination.""" + try: + datasets_df = fetch_source_datasets_data(census_version) + if datasets_df is None: # Error handled in fetch + return + if datasets_df.empty: # Message handled in fetch + return + + if limit is not None and limit > 0: + datasets_df = datasets_df.head(limit) + total_items_in_view = len(datasets_df) # Number we are actually paging through + if total_items_in_view == 0: + console.print(f"No datasets found matching the criteria within the limit of {limit}.") + return + else: + total_items_in_view = len(datasets_df) + limit = total_items_in_view # Set limit for display consistency + + if total_items_in_view == 0: + console.print(f"No datasets found for version {census_version}.") + return + + total_pages = math.ceil(total_items_in_view / page_size) + current_page = 1 + + while True: + start_index = (current_page - 1) * page_size + end_index = start_index + page_size + page_df = datasets_df.iloc[start_index:end_index] + + if page_df.empty and current_page > 1: # Handle reaching end with partial page + console.print("[yellow]No more datasets to display.[/yellow]") + break + elif page_df.empty: # Only happens if total_items_in_view was 0 initially + console.print("[yellow]No datasets to display.[/yellow]") + break + + range_end = min(end_index, total_items_in_view) + table = Table(title=f"Source Datasets in Census {census_version} (Showing {start_index+1}-{range_end} of {total_items_in_view})") + table.add_column("Dataset ID", style="cyan", no_wrap=True) + table.add_column("Collection Name", style="magenta", overflow="fold") + table.add_column("Dataset Title", style="green", overflow="fold") + table.add_column("Cell Count", style="yellow", justify="right") + + for _, row in page_df.iterrows(): + # Safely format cell_count, handling potential None or non-numeric types + cell_count = row.get('cell_count') + cell_count_str = 'N/A' + if cell_count is not None: + try: + cell_count_str = f"{int(cell_count):,}" + except (ValueError, TypeError): + cell_count_str = str(cell_count) # Fallback to string if not int-convertible + + table.add_row( + row.get('dataset_id', 'N/A'), + row.get('collection_name', 'N/A'), + row.get('dataset_title', 'N/A'), + cell_count_str + ) + + console.print(f"\n--- Page {current_page} of {total_pages} ---") + if HAS_RICH: + console.print(table) + else: + table.print_table(console) + + if total_pages <= 1: + break # No more pages + + choices = [] + prompt_text = "Action" + if current_page > 1: choices.append("P") + if current_page < total_pages: choices.append("N") + choices.append("Q") + + prompt_parts = [] + if "P" in choices: prompt_parts.append("[P]revious") + if "N" in choices: prompt_parts.append("[N]ext") + prompt_parts.append("[Q]uit listing") + prompt_text = ", ".join(prompt_parts) + "?" + + + default_action = "Q" + if current_page < total_pages: default_action = "N" + elif current_page > 1: default_action = "P" + + + action = Prompt.ask( + prompt_text, + choices=choices, + default=default_action + ).upper() + + if action == "N" and current_page < total_pages: + current_page += 1 + elif action == "P" and current_page > 1: + current_page -= 1 + elif action == "Q": + break + else: + console.print("[yellow]Invalid choice.[/yellow]") + + except Exception as e: + console.print(f"[bold red]Error displaying datasets:[/bold red] {e}") + +def display_dataset_metadata(console, census_version, dataset_id): + """Displays metadata for a specific dataset.""" + try: + metadata_dict = get_dataset_metadata_data(census_version, dataset_id) + console.print(f"\nMetadata for Dataset: [bold green]{dataset_id}[/bold green]") + pprint(metadata_dict) # Use rich's pprint or fallback print + except Exception as e: + console.print(f"[bold red]Error displaying metadata:[/bold red] {e}") + + +def print_interactive_help(console): + """Prints help message for interactive mode.""" + console.print("\n[bold cyan]Available Commands:[/bold cyan]") + console.print(" [green]list_versions[/green] List available CELLxGENE Census versions.") + console.print(" [green]list_datasets[/green] [limit] List source datasets (paginated).") + console.print(" : stable, latest, or YYYY-MM-DD") + console.print(" [limit] (optional): Total number of datasets to fetch.") + console.print(" [green]show_metadata[/green] Show metadata for a specific dataset.") + console.print(" [green]download[/green] Download dataset H5AD and metadata JSON.") + console.print(" [green]help[/green] Show this help message.") + console.print(" [green]exit[/green] Exit the interactive browser.") + console.print("\nExample: [yellow]download stable [/yellow]") + + +def interactive_loop(): + """Runs the interactive command loop.""" + console = Console() + console.print("[bold blue]Welcome to the Interactive CZI CELLxGENE Census Browser![/bold blue]") + print_interactive_help(console) + + while True: + try: + if HAS_RICH: + raw_command = Prompt.ask("\nEnter command (\'help\' or \'exit\')") + else: + raw_command = input("\nEnter command ('help' or 'exit'): ").strip() + + if not raw_command: + continue + + try: + command_parts = shlex.split(raw_command) + except ValueError as e: + console.print(f"[red]Error parsing command (check quotes?): {e}[/red]") + continue + + if not command_parts: continue + + command = command_parts[0].lower() + args = command_parts[1:] + + if command == "exit": + break + elif command == "help": + print_interactive_help(console) + elif command == "list_versions": + if len(args) == 0: + display_versions_list(console) + else: + console.print("[yellow]Usage: list_versions[/yellow]") + elif command == "list_datasets": + version = args[0] if len(args) > 0 else None + limit = None + if len(args) > 1: + try: + limit = int(args[1]) + if limit <= 0: + console.print("[red]Limit must be a positive integer.[/red]") + continue + except ValueError: + console.print(f"[red]Invalid limit '{args[1]}'. Must be an integer.[/red]") + continue + if version: + display_paginated_datasets(console, version, limit=limit, page_size=5) + else: + console.print("[yellow]Usage: list_datasets [limit][/yellow]") + elif command == "show_metadata": + version = args[0] if len(args) > 0 else None + dataset_id = args[1] if len(args) > 1 else None + if version and dataset_id: + display_dataset_metadata(console, version, dataset_id) + else: + console.print("[yellow]Usage: show_metadata [/yellow]") + elif command == "download": + version = args[0] if len(args) > 0 else None + dataset_id = args[1] if len(args) > 1 else None + if version and dataset_id: + download_dataset(console, version, dataset_id) + else: + console.print("[yellow]Usage: download [/yellow]") + else: + console.print(f"[red]Unknown command: '{command}'. Type 'help' for options.[/red]") + + except EOFError: + console.print("\n[yellow]EOF detected. Exiting.[/yellow]") + break + except KeyboardInterrupt: + console.print("\n[yellow]Interrupted by user. Type 'exit' to quit.[/yellow]") + except Exception as e: + console.print(f"[bold red]An unexpected error occurred in the interactive loop:[/bold red] {e}") + + + console.print("[bold blue]Exiting browser. Goodbye![/bold blue]") + + +def main(): + # Check if running interactively (no arguments other than script name) + if len(sys.argv) == 1: + interactive_loop() + sys.exit(0) + + # --- Original argparse logic for non-interactive mode --- + parser = argparse.ArgumentParser( + description="CZI CELLxGENE Census Browser CLI. Run without arguments for interactive mode.", + formatter_class=argparse.RawTextHelpFormatter # Keep help text formatting + ) + subparsers = parser.add_subparsers(dest='command', help='Available commands (run without arguments for interactive mode)') + + # Subparser for listing census versions + parser_list_versions = subparsers.add_parser('list-versions', help='List available CELLxGENE Census versions') + parser_list_versions.set_defaults(func=lambda args: display_versions_list(Console())) + + # Subparser for listing datasets within a version + parser_list_datasets = subparsers.add_parser('list-datasets', help='List source datasets within a specific Census version (paginated)') + parser_list_datasets.add_argument('--version', required=True, help='Census version tag (e.g., "stable", "latest", "YYYY-MM-DD")') + parser_list_datasets.add_argument('--limit', type=int, default=None, help='Maximum number of datasets to fetch and paginate through') + parser_list_datasets.add_argument('--page-size', type=int, default=5, help='Number of datasets to show per page (default: 5)') + parser_list_datasets.set_defaults(func=lambda args: display_paginated_datasets(Console(), args.version, args.limit, args.page_size)) + + # Subparser for showing metadata for a specific dataset + parser_show_metadata = subparsers.add_parser('show-metadata', help='Show metadata for a specific source dataset') + parser_show_metadata.add_argument('--version', required=True, help='Census version tag') + parser_show_metadata.add_argument('--dataset-id', required=True, help='The dataset_id') + parser_show_metadata.set_defaults(func=lambda args: display_dataset_metadata(Console(), args.version, args.dataset_id)) + + # Subparser for downloading a dataset + parser_download = subparsers.add_parser('download', help='Download dataset H5AD and metadata JSON') + parser_download.add_argument('--version', required=True, help='Census version tag') + parser_download.add_argument('--dataset-id', required=True, help='The dataset_id to download') + parser_download.set_defaults(func=lambda args: download_dataset(Console(), args.version, args.dataset_id)) + + + # Allow showing help if no subcommand is given when args are present + if len(sys.argv) > 1 and sys.argv[1] not in ['list-versions', 'list-datasets', 'show-metadata', 'download', '-h', '--help']: + args = parser.parse_args(sys.argv[1:2]) # Parse just the first potential command + else: + args = parser.parse_args() + + if hasattr(args, 'func'): + try: + args.func(args) + except Exception as e: + Console().print(f"[bold red]Command failed:[/bold red] {e}") + sys.exit(1) + else: + if len(sys.argv) > 1: + parser.print_help() + + +if __name__ == "__main__": + # Need numpy for JSON conversion of metadata types + try: + import numpy as np + except ImportError: + print("Error: The 'numpy' package is required for saving metadata. Please install it (`pip install numpy`).") + sys.exit(1) + main() diff --git a/benchmarking/tools/output_to_notebook.py b/benchmarking/tools/output_to_notebook.py new file mode 100644 index 0000000..cd4e1a3 --- /dev/null +++ b/benchmarking/tools/output_to_notebook.py @@ -0,0 +1,286 @@ +import json +import nbformat +import re +import base64 +import sys +from pathlib import Path +from datetime import datetime + +# --- Configuration --- +# Default directory to look for input files and save output files +DEFAULT_DIR = Path("./outputs") + +# --- Helper Functions --- +def extract_python_code(text): + """Extracts the first Python code block from text.""" + # Handle potential None input + if text is None: + return None, None + # Regex to find code block and preceding/succeeding text + # It captures text before, the code itself, and text after + match = re.search(r"(.*?)```python\s*([\s\S]+?)\s*```(.*)", text, re.DOTALL) + if match: + text_before = match.group(1).strip() + code = match.group(2).strip() + text_after = match.group(3).strip() + # Combine non-code text parts + non_code_text = (text_before + "\n\n" + text_after).strip() + return non_code_text, code + else: + # No code block found, return all text as non-code + return text.strip(), None + +def create_markdown_cell(source): + """Creates a Markdown cell for the notebook.""" + # Ignore empty source strings + if not source or not source.strip(): + return None + return nbformat.v4.new_markdown_cell(source=source) + +def create_code_cell(code, execution_count=None): + """Creates a Code cell for the notebook.""" + if not code or not code.strip(): + return None + return nbformat.v4.new_code_cell(source=code, execution_count=execution_count) + +def format_outputs_for_notebook(api_outputs): + """Converts the list of outputs from the API response into notebook cell outputs.""" + notebook_outputs = [] + if not api_outputs: + return notebook_outputs + + for item in api_outputs: + output_type = item.get("type") + + if output_type == "stream": + notebook_outputs.append(nbformat.v4.new_output( + output_type="stream", + name=item.get("name", "stdout"), # Default to stdout if name missing + text=item.get("text", "") + )) + elif output_type == "error": + notebook_outputs.append(nbformat.v4.new_output( + output_type="error", + ename=item.get("ename", "Error"), + evalue=item.get("evalue", ""), + traceback=item.get("traceback", []) + )) + elif output_type == "execute_result": + # Prefer text/plain, but include others if available + data = item.get("data", {}) + if data: # Only add if data exists + notebook_outputs.append(nbformat.v4.new_output( + output_type="execute_result", + data=data, # Pass the whole data dict + metadata=item.get("metadata", {}), + execution_count=None # Typically not set on individual outputs + )) + elif output_type == "display_data": + # Handle potential base64 encoded image data + data = item.get("data", {}) + processed_data = {} + metadata = item.get("metadata", {}) # Include metadata for renderers + + for mime, content in data.items(): + # Keep non-image data as is (e.g., text/html, text/plain) + if not mime.startswith("image/"): + processed_data[mime] = content + else: + # Assume image data might be base64 encoded string + # Notebook format expects base64 string directly for images + if isinstance(content, str): + # Basic check if it looks like base64, otherwise skip/warn + try: + # Test decode just to validate format, don't store decoded + base64.b64decode(content) + processed_data[mime] = content # Store the original base64 string + except (TypeError, ValueError): + print(f"Warning: Skipping display_data for mime '{mime}' - content is string but not valid base64.", file=sys.stderr) + else: + print(f"Warning: Skipping display_data for mime '{mime}' - unexpected data type '{type(content)}'.", file=sys.stderr) + + if processed_data: # Only add if data exists and was processed + notebook_outputs.append(nbformat.v4.new_output( + output_type="display_data", + data=processed_data, + metadata=metadata + )) + # Ignore 'status' type messages for cell output + elif output_type == "status": + pass + else: + print(f"Warning: Unknown output type '{output_type}' encountered.", file=sys.stderr) + + return notebook_outputs + +def create_notebook_cells(test_id, test_data): + """ + Generates a list of notebook cells from a single test run's data. + """ + cells = [] + execution_counter = 1 # Track execution count for code cells + + # --- Add Context Cell --- + context = test_data.get("context", {}) + context_md = f"# Test Run: {test_id}\n\n" + context_md += f"**Dataset:** {context.get('dataset_file', 'N/A')}\n" + context_md += f"**Max Code Tries:** {context.get('max_code_tries', 'N/A')}\n" + context_md += f"**Start Time:** {context.get('start_time', 'N/A')}\n\n" + if context.get("dataset_metadata"): + context_md += "## Dataset Metadata\n\n" + context_md += "```json\n" + context_md += json.dumps(context.get("dataset_metadata"), indent=2) + context_md += "\n```\n" + if context.get("error"): # Add setup error if present + context_md += f"\n**SETUP/EXECUTION ERROR:**\n```\n{context.get('error')}\n```\n" + + cells.append(nbformat.v4.new_markdown_cell(context_md)) + + # --- Process Conversation Turns --- + turns = test_data.get("turns", []) + i = 0 + while i < len(turns): + turn = turns[i] + role = turn.get("role") + content = turn.get("content") + + if role == "system": + # Could add system prompt as a collapsed cell, or skip + # cells.append(create_markdown_cell(f"**SYSTEM PROMPT:**\n\n{content}")) + pass # Often skipped in generated notebooks + elif role == "user": + # Check if this is a result message or an initial prompt + if content and content.startswith("Code execution result:\n"): + # This turn contains results, handled by the preceding assistant code cell + pass # Skip placing result directly, it's an output of the code cell + else: + # This is an initial user prompt or follow-up question + md_cell = create_markdown_cell(f"**USER PROMPT:**\n\n{content}") + if md_cell: cells.append(md_cell) + elif role == "assistant": + text_part, code_part = extract_python_code(content) + + # Add markdown cell for the text explanation + md_cell = create_markdown_cell(f"**ASSISTANT:**\n\n{text_part}") + if md_cell: cells.append(md_cell) + + # Add code cell if code exists + if code_part: + code_cell = create_code_cell(code_part, execution_count=execution_counter) + # Look ahead for the corresponding user result turn + if i + 1 < len(turns) and turns[i+1].get("role") == "user" and turns[i+1].get("content", "").startswith("Code execution result:"): + api_response = turns[i+1].get("api_response", {}) + api_outputs = api_response.get("outputs", []) + code_cell.outputs = format_outputs_for_notebook(api_outputs) + # Increment execution counter only if code was executed + execution_counter += 1 + # Skip the next turn since we've processed it as output + i += 1 + else: + # Code was generated but no result followed? Add empty output. + code_cell.outputs = [] + print(f"Warning: Assistant generated code but no result message followed turn {i}.", file=sys.stderr) + + cells.append(code_cell) + + i += 1 # Move to the next turn + + return cells + + +def convert_json_to_ipynb(json_path_str): + """Loads the results JSON and converts it into a Jupyter Notebook.""" + json_path = Path(json_path_str) + + if not json_path.is_file(): + print(f"Error: Input file not found at '{json_path}'", file=sys.stderr) + return + + # Determine output path + output_path = json_path.with_suffix(".ipynb") + print(f"Input JSON: {json_path}") + print(f"Output Notebook: {output_path}") + + try: + with open(json_path, 'r', encoding='utf-8') as f: + all_results = json.load(f) + except json.JSONDecodeError as e: + print(f"Error: Failed to parse JSON file '{json_path}': {e}", file=sys.stderr) + return + except Exception as e: + print(f"Error reading file '{json_path}': {e}", file=sys.stderr) + return + + # Create a new notebook object + notebook = nbformat.v4.new_notebook() + all_cells = [] + + # Iterate through each test run in the results file + for test_id, test_data in all_results.items(): + if not isinstance(test_data, dict): + print(f"Warning: Skipping invalid data structure for test_id '{test_id}'. Expected a dictionary.", file=sys.stderr) + continue + print(f"Processing test run: {test_id}...") + test_cells = create_notebook_cells(test_id, test_data) + all_cells.extend(test_cells) + + notebook['cells'] = all_cells + + # Write the notebook to the output file + try: + with open(output_path, 'w', encoding='utf-8') as f: + nbformat.write(notebook, f) + print(f"Successfully converted results to '{output_path}'") + except Exception as e: + print(f"Error writing notebook file '{output_path}': {e}", file=sys.stderr) + + +def interactive_loop(): + """Handles the interactive user prompts.""" + print("\n--- JSON to Jupyter Notebook Converter ---") + print(f"Searches for JSON files in: {DEFAULT_DIR.resolve()}") + print("Enter the full path to a results JSON file,") + print("or just the filename if it's in the default directory.") + print("Enter 'q' or press Enter to quit.") + + while True: + try: + input_path_str = input("\nEnter JSON file path/name (or q to quit): ").strip() + + if not input_path_str or input_path_str.lower() == 'q': + print("Exiting.") + break + + input_path = Path(input_path_str) + + # If only filename is given, prepend the default directory + if not input_path.is_absolute() and not input_path.exists(): + potential_path = DEFAULT_DIR / input_path + if potential_path.exists() and potential_path.is_file(): + input_path = potential_path + else: + # Try adding .json suffix if missing + potential_path_json = DEFAULT_DIR / input_path.with_suffix(".json") + if potential_path_json.exists() and potential_path_json.is_file(): + input_path = potential_path_json + # Also check original input path with added suffix + elif input_path.with_suffix(".json").exists() and input_path.with_suffix(".json").is_file(): + input_path = input_path.with_suffix(".json") + + + if input_path.is_file() and input_path.suffix.lower() == ".json": + convert_json_to_ipynb(input_path) + else: + print(f"Error: File not found or not a JSON file: '{input_path_str}'") + print(f"(Checked paths: '{input_path}')") + + except Exception as e: + print(f"An unexpected error occurred: {e}", file=sys.stderr) + import traceback + traceback.print_exc() # Print full traceback for debugging + +# --- Main Execution --- +if __name__ == "__main__": + # Create default output directory if it doesn't exist + DEFAULT_DIR.mkdir(parents=True, exist_ok=True) + interactive_loop()