From 97114a42bdd4d7dda060f5984a2f3a3ae3164e28 Mon Sep 17 00:00:00 2001 From: cemde Date: Fri, 5 Dec 2025 18:31:45 +0000 Subject: [PATCH 01/25] original plan --- PLAN.md | 1295 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1295 insertions(+) create mode 100644 PLAN.md diff --git a/PLAN.md b/PLAN.md new file mode 100644 index 0000000..eae60fd --- /dev/null +++ b/PLAN.md @@ -0,0 +1,1295 @@ +# Parallel Task Execution, Timeout Handling, and Task Queue Design + +This document proposes a unified design for three interconnected features that fundamentally improve MASEval's task execution architecture: + +1. **Parallel Processing** - Concurrent task execution via asyncio or threading +2. **Timeout Handling** - Per-task timeout with graceful failure recording +3. **Task Queue** - Callback-driven task scheduling for adaptive testing + +All three features directly impact the `Benchmark.run()` task loop and should be designed together. + +--- + +## Table of Contents + +1. [Current Architecture Analysis](#current-architecture-analysis) +2. [Key Architectural Changes](#key-architectural-changes) +3. [Feature 1: Parallel Processing](#feature-1-parallel-processing) +4. [Feature 2: Timeout Handling & TaskProtocol](#feature-2-timeout-handling--taskprotocol) +5. [Feature 3: TaskQueue](#feature-3-taskqueue) +6. [Unified Design Proposal](#unified-design-proposal) +7. [Implementation Phases](#implementation-phases) +8. [Risks and Mitigations](#risks-and-mitigations) + +--- + +## Current Architecture Analysis + +### The Run Loop (`benchmark.py` lines 990-1330) + +The current execution model is strictly sequential: + +```python +def run(self, tasks: ...): + for task_idx, (task, agent_data) in enumerate(zip(self.tasks, agent_data_list)): + for repeat_idx in range(self.n_task_repeats): + # Setup + environment = self.setup_environment(agent_data, task) + # ... more setup + + # Execute + final_answers = self.execution_loop(agents_to_run, task, environment, user) + + # Evaluate + eval_results = self.evaluate(...) + + # Store + self.reports.append(report) +``` + +### Key Observations + +1. **Sequential by Design**: No parallelism, no timeouts, no queue abstraction +2. **Callback System**: Already has lifecycle hooks (`on_task_start`, `on_task_repeat_end`, etc.) but callbacks cannot influence task ordering +3. **Component Registry**: Per-task-repetition component tracking with `register()` / `clear_registry()` +4. **Error Handling**: Comprehensive status enum (`TaskExecutionStatus`) with graceful failure paths +5. **Agent Adapters**: Framework-specific adapters (smolagents, langgraph) that may or may not be async-native +6. **Model Adapters**: API clients that are inherently I/O-bound + +### Critical Dependencies for Concurrency + +| Component | Thread-Safety | Async-Native | Notes | +| --------------------------- | -------------- | ------------ | -------------------------------- | +| `Benchmark.reports` | ❌ List append | N/A | Needs synchronization | +| `Benchmark._trace_registry` | ❌ Dict | N/A | Per-task, but needs isolation | +| `CallbackHandler` | ❌ | N/A | Callbacks may not be thread-safe | +| `SmolAgentAdapter` | ✅ (stateless) | ❌ | Uses sync `agent.run()` | +| `LangGraphAgentAdapter` | ✅ (stateless) | ⚠️ Partial | LangGraph has `ainvoke()` | +| `GoogleGenAIModelAdapter` | ✅ | ⚠️ Partial | Google client has async methods | + +--- + +## Key Architectural Changes + +This section summarizes the major architectural decisions made during planning. + +### 1. Extract `ComponentRegistry` from `Benchmark` + +**Problem**: The component registry logic (~150 lines) is mixed with benchmark orchestration. Adding thread-local handling will make it worse. + +**Solution**: Extract into a dedicated `ComponentRegistry` class in `maseval/core/registry.py`. + +```python +# maseval/core/registry.py + +import threading +from typing import Dict, Any, Optional +from datetime import datetime + +from .tracing import TraceableMixin +from .config import ConfigurableMixin + + +class ComponentRegistry: + """Thread-safe registry for tracking components during task execution. + + Each thread gets its own isolated registry state, enabling parallel + task execution without cross-contamination. The registry tracks both + Traceable and Configurable components for comprehensive data collection. + + Usage: + registry = ComponentRegistry() + + # Register components (thread-local) + registry.register("agents", "orchestrator", agent_adapter) + registry.register("environment", "env", environment) + + # Collect data + traces = registry.collect_traces() + configs = registry.collect_configs() + + # Clear for next task + registry.clear() + """ + + def __init__(self, benchmark_config: Optional[Dict[str, Any]] = None): + """Initialize the registry. + + Args: + benchmark_config: Benchmark-level configuration to include in + collect_configs() output. This is shared (not thread-local). + """ + self._local = threading.local() + self._benchmark_config = benchmark_config or {} + + # --- Thread-local state properties --- + + @property + def _trace_registry(self) -> Dict[str, TraceableMixin]: + if not hasattr(self._local, 'trace_registry'): + self._local.trace_registry = {} + return self._local.trace_registry + + @property + def _component_id_map(self) -> Dict[int, str]: + if not hasattr(self._local, 'component_id_map'): + self._local.component_id_map = {} + return self._local.component_id_map + + @property + def _config_registry(self) -> Dict[str, ConfigurableMixin]: + if not hasattr(self._local, 'config_registry'): + self._local.config_registry = {} + return self._local.config_registry + + @property + def _config_component_id_map(self) -> Dict[int, str]: + if not hasattr(self._local, 'config_component_id_map'): + self._local.config_component_id_map = {} + return self._local.config_component_id_map + + # --- Public API --- + + def register(self, category: str, name: str, component: TraceableMixin) -> TraceableMixin: + """Register a component for trace and config collection. + + Args: + category: Component category (e.g., "agents", "models", "environment") + name: Unique identifier within the category + component: Component instance (must be TraceableMixin) + + Returns: + The component (for chaining) + + Raises: + ValueError: If component already registered under a different key + """ + component_id = id(component) + key = f"{category}:{name}" + + # Check for duplicate registration under different key + if component_id in self._component_id_map: + existing_key = self._component_id_map[component_id] + if existing_key != key: + raise ValueError( + f"Component already registered as '{existing_key}', " + f"cannot re-register as '{key}'." + ) + return component # Idempotent + + # Register for tracing + self._trace_registry[key] = component + self._component_id_map[component_id] = key + + # Also register for config if supported + if isinstance(component, ConfigurableMixin): + self._config_registry[key] = component + self._config_component_id_map[component_id] = key + + return component + + def clear(self) -> None: + """Clear all registrations for the current thread.""" + self._trace_registry.clear() + self._component_id_map.clear() + self._config_registry.clear() + self._config_component_id_map.clear() + + def collect_traces(self) -> Dict[str, Any]: + """Collect execution traces from all registered components.""" + traces: Dict[str, Any] = { + "metadata": { + "timestamp": datetime.now().isoformat(), + "thread_id": threading.current_thread().ident, + "total_components": len(self._trace_registry), + }, + "agents": {}, + "models": {}, + "tools": {}, + "simulators": {}, + "callbacks": {}, + "environment": None, + "user": None, + "other": {}, + } + + for key, component in self._trace_registry.items(): + category, comp_name = key.split(":", 1) + try: + component_traces = component.gather_traces() + if "name" not in component_traces: + component_traces["name"] = comp_name + + if category == "environment": + traces["environment"] = component_traces + elif category == "user": + traces["user"] = component_traces + else: + if category not in traces: + traces[category] = {} + traces[category][comp_name] = component_traces + except Exception as e: + error_info = {"error": str(e), "error_type": type(e).__name__} + if category in ("environment", "user"): + traces[category] = error_info + else: + if category not in traces: + traces[category] = {} + traces[category][comp_name] = error_info + + return traces + + def collect_configs(self) -> Dict[str, Any]: + """Collect configuration from all registered components.""" + configs: Dict[str, Any] = { + "metadata": { + "timestamp": datetime.now().isoformat(), + "thread_id": threading.current_thread().ident, + "total_components": len(self._config_registry), + }, + "agents": {}, + "models": {}, + "tools": {}, + "simulators": {}, + "callbacks": {}, + "environment": None, + "user": None, + "other": {}, + "benchmark": self._benchmark_config, + } + + for key, component in self._config_registry.items(): + category, comp_name = key.split(":", 1) + try: + component_config = component.gather_config() + if "name" not in component_config: + component_config["name"] = comp_name + + if category == "environment": + configs["environment"] = component_config + elif category == "user": + configs["user"] = component_config + else: + if category not in configs: + configs[category] = {} + configs[category][comp_name] = component_config + except Exception as e: + error_info = {"error": str(e), "error_type": type(e).__name__} + if category in ("environment", "user"): + configs[category] = error_info + else: + if category not in configs: + configs[category] = {} + configs[category][comp_name] = error_info + + return configs +``` + +**Benchmark integration** (delegation pattern): + +```python +class Benchmark: + def __init__(self, ...): + # ... + self._registry = ComponentRegistry( + benchmark_config=gather_benchmark_config() + ) + + def register(self, category: str, name: str, component: TraceableMixin) -> TraceableMixin: + """Register a component. Delegates to internal registry.""" + return self._registry.register(category, name, component) + + def clear_registry(self) -> None: + """Clear registry after task repetition.""" + self._registry.clear() + + def collect_all_traces(self) -> Dict[str, Any]: + """Collect traces. Delegates to internal registry.""" + return self._registry.collect_traces() + + def collect_all_configs(self) -> Dict[str, Any]: + """Collect configs. Delegates to internal registry.""" + return self._registry.collect_configs() +``` + +**Benefits**: + +- Single Responsibility: Benchmark orchestrates, Registry tracks components +- Testability: Registry can be unit tested in isolation +- Clarity: Thread-local complexity encapsulated in one place +- Zero API changes: Users still call `benchmark.register(...)` + +### 2. Threading over asyncio + +**Decision**: Use `ThreadPoolExecutor` for parallel task execution. + +**Rationale**: + +- No user code changes required (async would require rewriting `run_agents()`) +- Works with all agent frameworks (smolagents is sync-only) +- Same I/O concurrency benefits for LLM API calls +- Future-proof: Python's GIL removal will make threading even more powerful + +### 3. MASEval-Managed Callback Thread Safety + +**Decision**: MASEval serializes all callback invocations with an internal lock. + +**Rationale**: + +- Users don't need to think about thread safety +- Negligible performance cost (callbacks are fast) +- Prevents subtle race condition bugs + +```python +class Benchmark: + def __init__(self, ...): + self._callback_lock = threading.Lock() + + def _invoke_callbacks(self, method_name: str, *args, **kwargs): + with self._callback_lock: + for cb in self.callbacks: + getattr(cb, method_name)(*args, **kwargs) +``` + +### 4. Cooperative Timeout with Hard Backstop + +**Decision**: Use cooperative checkpoint-based timeout with a hard timeout fallback. + +**Rationale**: + +- Cross-platform (signal-based only works on Unix) +- Works in threads (signals only work in main thread) +- Clean interruption at defined checkpoints +- Hard timeout as last resort for misbehaving code + +**Limitation**: Python threads cannot be forcibly killed. Timeout is "best effort." + +--- + +## Feature 1: Parallel Processing + +### Decision: Threading with `ThreadPoolExecutor` + +We use **threading** (not asyncio) for parallel task execution. + +#### Why Threading Over asyncio + +| Consideration | Threading | asyncio | +| ----------------------- | ----------------------------- | ------------------------------------ | +| User code changes | None | Must rewrite `run_agents()` as async | +| Agent framework support | All (smolagents is sync-only) | Only async-native frameworks | +| API signature | Unchanged | Breaking (`async def run()`) | +| Mental model | Familiar to most developers | Requires async expertise | +| Future GIL removal | Benefits automatically | No additional benefit | + +**asyncio would require**: + +- All user-implemented methods become `async def` +- Wrapper code for sync agent frameworks (smolagents) +- Breaking API changes throughout + +**Threading provides**: + +- Zero user code changes +- Works with all agent frameworks today +- Same I/O concurrency benefits (LLM API calls) +- Future-proof: when Python removes the GIL, threading will gain true parallelism + +#### Implementation: `ThreadPoolExecutor` + +```python +from concurrent.futures import ThreadPoolExecutor, as_completed + +def run(self, tasks, max_workers: int = 1): # max_workers=1 = sequential (default) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {} + for task, agent_data in zip(self.tasks, agent_data_list): + future = executor.submit(self._run_single_task, task, agent_data) + futures[future] = task + + for future in as_completed(futures): + report = future.result() + self._append_report_safe(report) # Thread-safe +``` + +**Key design points**: + +1. **Backwards Compatible**: `max_workers=1` maintains current sequential behavior +2. **Framework Agnostic**: Works with sync agent frameworks (smolagents) +3. **I/O Parallelism**: Multiple LLM API calls can happen concurrently +4. **Opt-in**: Users explicitly enable parallelism + +#### Thread-Local Component Registry + +The component registry is already cleared after each task repetition. For parallel execution, we make it thread-local so concurrent tasks don't share registries: + +```python +import threading + +class Benchmark: + def __init__(self, ...): + self._local = threading.local() + + @property + def _trace_registry(self): + if not hasattr(self._local, 'trace_registry'): + self._local.trace_registry = {} + return self._local.trace_registry + + @property + def _component_id_map(self): + if not hasattr(self._local, 'component_id_map'): + self._local.component_id_map = {} + return self._local.component_id_map + + # Same pattern for _config_registry, _config_component_id_map +``` + +This is the correct design because: + +- Each task repetition runs in one thread +- Registries are ephemeral (cleared after each repetition via `clear_registry()`) +- No cross-task state sharing is intended + +#### Thread-Safe Report Collection + +```python +import threading + +class Benchmark: + def __init__(self, ...): + self._reports_lock = threading.Lock() + + def _append_report_safe(self, report): + with self._reports_lock: + self.reports.append(report) +``` + +#### Thread-Safe Callback Invocation + +MASEval serializes all callback invocations internally, so **users don't need to implement thread-safe callbacks**: + +```python +class Benchmark: + def __init__(self, ...): + self._callback_lock = threading.Lock() + + def _invoke_callbacks(self, method_name: str, *args, **kwargs): + """Invoke a callback method on all registered callbacks (thread-safe).""" + with self._callback_lock: + for cb in self.callbacks: + getattr(cb, method_name)(*args, **kwargs) +``` + +**User impact**: None. Users write callbacks exactly as they do today: + +```python +class MyCallback(BenchmarkCallback): + def __init__(self): + self.count = 0 # No lock needed! + + def on_task_repeat_end(self, benchmark, report): + self.count += 1 # Safe because MASEval serializes calls +``` + +This approach: + +- Eliminates thread-safety burden on users +- Has negligible performance cost (callbacks are fast) +- Prevents an entire class of subtle bugs + +```` + +--- + +## Feature 2: Timeout Handling & TaskProtocol + +### Design Goal + +Enable per-task timeout configuration, capturing partial results on timeout. + +### The `TaskProtocol` Concept + +A `TaskProtocol` dataclass defines task-level execution parameters. It's attached to `Task` but describes how MASEval should run the task, not task content. + +```python +from dataclasses import dataclass, field +from typing import Optional +from enum import Enum + + +class TimeoutAction(Enum): + """What to do when a timeout occurs.""" + SKIP = "skip" # Mark as timed out, continue to next task + RETRY = "retry" # Retry once with same timeout + EXTEND = "extend" # Double timeout and retry + + +@dataclass +class TaskProtocol: + """Configuration for how MASEval executes a task. + + This is a data container for execution parameters, separate from + task content (query, environment_data, etc.). It controls the + interface between the task and MASEval's execution engine. + + Attributes: + timeout_seconds: Maximum execution time for this task. None means no timeout. + timeout_action: Action to take when timeout occurs. + max_retries: Maximum retry attempts for transient failures (not timeouts). + priority: Execution priority (higher = sooner). Used by adaptive task queues. + tags: Arbitrary tags for filtering or grouping tasks. + """ + timeout_seconds: Optional[float] = None + timeout_action: TimeoutAction = TimeoutAction.SKIP + max_retries: int = 0 + priority: int = 0 + tags: dict = field(default_factory=dict) +```` + +### Attaching Protocol to Task + +```python +@dataclass +class Task: + query: str + id: UUID = field(default_factory=uuid4) + environment_data: Dict[str, Any] = field(default_factory=dict) + evaluation_data: Dict[str, Any] = field(default_factory=dict) + metadata: Dict[str, Any] = field(default_factory=dict) + + # New: execution protocol + protocol: TaskProtocol = field(default_factory=TaskProtocol) +``` + +### Timeout Implementation Strategies + +#### Strategy A: `concurrent.futures` with Timeout + +```python +from concurrent.futures import ThreadPoolExecutor, TimeoutError + +def _run_task_with_timeout(self, task, agent_data, timeout: Optional[float]): + """Run a single task with optional timeout.""" + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(self._run_single_task_inner, task, agent_data) + try: + return future.result(timeout=timeout) + except TimeoutError: + # Attempt to cancel (may not stop running code) + future.cancel() + return self._create_timeout_report(task) +``` + +**Problem**: `future.cancel()` doesn't actually stop running Python code. The task continues executing in the background. + +#### Strategy B: `multiprocessing` with Termination + +```python +from multiprocessing import Process, Queue + +def _run_task_with_timeout(self, task, agent_data, timeout): + result_queue = Queue() + process = Process(target=self._run_in_process, args=(task, agent_data, result_queue)) + process.start() + process.join(timeout=timeout) + + if process.is_alive(): + process.terminate() # Actually kills the task + return self._create_timeout_report(task) + + return result_queue.get() +``` + +**Problem**: Process isolation means no shared state. Components can't be registered, traces can't be collected incrementally. + +#### Strategy C: Signal-Based Timeout (Unix only) + +```python +import signal + +class TimeoutException(Exception): + pass + +def _run_task_with_timeout(self, task, agent_data, timeout): + def handler(signum, frame): + raise TimeoutException() + + old_handler = signal.signal(signal.SIGALRM, handler) + signal.alarm(int(timeout)) + try: + return self._run_single_task_inner(task, agent_data) + except TimeoutException: + return self._create_timeout_report(task) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) +``` + +**Problem**: Only works on Unix. Doesn't work in threads (signal only works in main thread). + +#### Strategy D: Cooperative Timeout with Checkpoints (Recommended) + +The cleanest approach that works cross-platform and with threads is **cooperative timeout checking**: + +```python +import time +import threading + +class TaskContext: + """Execution context passed to user code for timeout checking.""" + + def __init__(self, deadline: Optional[float] = None): + self._deadline = deadline + self._start_time = time.monotonic() + + @property + def elapsed(self) -> float: + return time.monotonic() - self._start_time + + @property + def remaining(self) -> Optional[float]: + if self._deadline is None: + return None + return max(0, self._deadline - self.elapsed) + + @property + def is_expired(self) -> bool: + return self._deadline is not None and self.elapsed >= self._deadline + + def check_timeout(self): + """Raise TimeoutError if deadline exceeded. Call at checkpoints.""" + if self.is_expired: + raise TaskTimeoutError(f"Task exceeded {self._deadline}s deadline") +``` + +**Usage in `run_agents()`**: + +```python +def run_agents(self, agents, task, environment, query, context: TaskContext) -> Any: + for step in range(self.max_steps): + context.check_timeout() # Cooperative checkpoint + result = agents[0].run(query) + # ... +``` + +**Hybrid with Hard Timeout**: Combine cooperative checking with a hard timeout fallback: + +```python +def _run_task_with_timeout(self, task, agent_data, timeout): + context = TaskContext(deadline=timeout) + + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(self._run_single_task_inner, task, agent_data, context) + try: + # Hard timeout as backstop + return future.result(timeout=timeout + 5) # Grace period + except TimeoutError: + return self._create_timeout_report(task, partial_traces=context.collected_traces) +``` + +### New Exception Type + +```python +class TaskTimeoutError(MASEvalError): + """Task execution exceeded configured timeout. + + This is classified as TASK_TIMEOUT in benchmark results, separate from + other error types. Timeout is neither agent's fault nor infrastructure's + fault—it's a resource constraint. + """ + + def __init__(self, message: str, elapsed: float, timeout: float, partial_traces: Optional[Dict] = None): + super().__init__(message, component="timeout") + self.elapsed = elapsed + self.timeout = timeout + self.partial_traces = partial_traces or {} +``` + +### New Status + +```python +class TaskExecutionStatus(Enum): + SUCCESS = "success" + AGENT_ERROR = "agent_error" + ENVIRONMENT_ERROR = "environment_error" + USER_ERROR = "user_error" + UNKNOWN_EXECUTION_ERROR = "unknown_execution_error" + EVALUATION_FAILED = "evaluation_failed" + SETUP_FAILED = "setup_failed" + TASK_TIMEOUT = "task_timeout" # NEW +``` + +--- + +## Feature 3: TaskQueue + +### Design Goal + +Replace the static `for task in tasks` loop with a queue abstraction that enables: + +1. Dynamic task ordering +2. Callback-driven scheduling (adaptive testing) +3. Priority-based execution +4. Conditional task skipping + +### The `TaskQueue` Interface + +```python +from abc import ABC, abstractmethod +from typing import Iterator, Optional, Tuple + +class TaskQueue(ABC): + """Abstract base for task scheduling strategies.""" + + @abstractmethod + def __iter__(self) -> Iterator[Tuple[Task, Dict]]: + """Yield (task, agent_data) pairs in execution order.""" + pass + + def on_task_complete(self, task: Task, report: Dict) -> None: + """Called after each task completes. Override for adaptive behavior.""" + pass + + def should_continue(self) -> bool: + """Whether to continue processing. Default: True until queue exhausted.""" + return True + + +class SequentialQueue(TaskQueue): + """Default: execute tasks in order (current behavior).""" + + def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict]): + self._tasks = list(zip(tasks, agent_data_list)) + self._index = 0 + + def __iter__(self): + for task, agent_data in self._tasks: + yield task, agent_data + + +class PriorityQueue(TaskQueue): + """Execute tasks by priority (from TaskProtocol.priority).""" + + def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict]): + paired = list(zip(tasks, agent_data_list)) + # Sort by priority descending + self._tasks = sorted(paired, key=lambda x: x[0].protocol.priority, reverse=True) + + def __iter__(self): + for task, agent_data in self._tasks: + yield task, agent_data + + +class AdaptiveQueue(TaskQueue): + """Adaptive testing: adjust task order based on results. + + Example: Item Response Theory (IRT) based testing that estimates + agent difficulty and selects optimally informative tasks. + """ + + def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict]): + self._pending = list(zip(tasks, agent_data_list)) + self._completed = [] + self._agent_ability_estimate = 0.0 + + def __iter__(self): + while self._pending and self.should_continue(): + # Select next task based on estimated ability + next_task = self._select_next_task() + if next_task: + yield next_task + + def _select_next_task(self) -> Optional[Tuple[Task, Dict]]: + """Select task that maximizes information gain.""" + if not self._pending: + return None + + # IRT-based selection (simplified) + best_idx = 0 + best_info = 0 + + for idx, (task, _) in enumerate(self._pending): + difficulty = task.metadata.get("difficulty", 0.5) + # Fisher information at current ability estimate + info = self._fisher_information(difficulty, self._agent_ability_estimate) + if info > best_info: + best_info = info + best_idx = idx + + return self._pending.pop(best_idx) + + def on_task_complete(self, task: Task, report: Dict) -> None: + """Update ability estimate based on task result.""" + self._completed.append((task, report)) + self._update_ability_estimate() + + def _update_ability_estimate(self): + """Bayesian update of ability estimate.""" + # Implementation depends on IRT model + pass + + def should_continue(self) -> bool: + """Stop when estimate is precise enough.""" + return len(self._completed) < 50 # Example stopping rule +``` + +### Integration with `Benchmark.run()` + +```python +def run( + self, + tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]], + queue: Optional[TaskQueue] = None, + max_workers: int = 1, +) -> List[Dict[str, Any]]: + # Normalize tasks + task_collection = self._normalize_tasks(tasks) + agent_data_list = self._normalize_agent_data(task_collection) + + # Create queue (default: sequential) + if queue is None: + queue = SequentialQueue(task_collection, agent_data_list) + + # Callbacks + for cb in self.callbacks: + cb.on_run_start(self) + + # Execute via queue + if max_workers == 1: + self._run_sequential(queue) + else: + self._run_parallel(queue, max_workers) + + # Callbacks + for cb in self.callbacks: + cb.on_run_end(self, self.reports) + + return self.reports + +def _run_sequential(self, queue: TaskQueue): + for task, agent_data in queue: + for repeat_idx in range(self.n_task_repeats): + report = self._execute_single_repetition(task, agent_data, repeat_idx) + self.reports.append(report) + queue.on_task_complete(task, report) + + if not queue.should_continue(): + return +``` + +### Callback Integration for Adaptive Testing + +The existing `BenchmarkCallback` can be extended: + +```python +class BenchmarkCallback(ABC, TraceableMixin): + # ... existing methods ... + + def on_task_selected(self, benchmark: "Benchmark", task: "Task", queue: "TaskQueue"): + """Called when TaskQueue selects the next task to run.""" + pass + + def on_queue_decision(self, benchmark: "Benchmark", queue: "TaskQueue", should_continue: bool): + """Called when TaskQueue makes a continue/stop decision.""" + pass +``` + +--- + +## Unified Design Proposal + +### Architecture Overview + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Benchmark.run() │ +├─────────────────────────────────────────────────────────────────┤ +│ ┌───────────────────┐ │ +│ │ TaskQueue │ ← Adaptive/Priority/Sequential │ +│ │ (iterator) │ │ +│ └────────┬──────────┘ │ +│ │ yields (Task, agent_data) │ +│ ▼ │ +│ ┌────────────────────────────────────────────────────────────┐│ +│ │ ThreadPoolExecutor (max_workers) ││ +│ │ ┌──────────────────────────────────────────────────────┐ ││ +│ │ │ Worker Thread 1 │ ││ +│ │ │ ┌─────────────────────────────────────────────────┐ │ ││ +│ │ │ │ TaskContext (deadline, checkpoints) │ │ ││ +│ │ │ │ ┌─────────────────────────────────────────────┐ │ │ ││ +│ │ │ │ │ setup → execution_loop → evaluate │ │ │ ││ +│ │ │ │ │ (Task.protocol.timeout_seconds) │ │ │ ││ +│ │ │ │ └─────────────────────────────────────────────┘ │ │ ││ +│ │ │ └─────────────────────────────────────────────────┘ │ ││ +│ │ └──────────────────────────────────────────────────────┘ ││ +│ │ ┌──────────────────────────────────────────────────────┐ ││ +│ │ │ Worker Thread 2 ... │ ││ +│ │ └──────────────────────────────────────────────────────┘ ││ +│ └────────────────────────────────────────────────────────────┘│ +│ │ │ +│ ▼ reports │ +│ ┌───────────────────┐ │ +│ │ Thread-Safe │ │ +│ │ Report Collector │ │ +│ └───────────────────┘ │ +└─────────────────────────────────────────────────────────────────┘ +``` + +### Complete `run()` Implementation + +```python +def run( + self, + tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]], + queue: Optional[TaskQueue] = None, + max_workers: int = 1, +) -> List[Dict[str, Any]]: + """Run benchmark with parallel processing, timeouts, and adaptive scheduling. + + Args: + tasks: Tasks to execute. + queue: Task scheduling strategy. Default: SequentialQueue. + max_workers: Maximum parallel task executions. Default: 1 (sequential). + + Returns: + List of report dictionaries. + """ + # Normalize inputs + self.tasks = self._normalize_tasks(tasks) + agent_data_list = self._normalize_agent_data() + + # Create queue + if queue is None: + queue = SequentialQueue(self.tasks, agent_data_list) + + # Clear reports + self.reports = [] + self._reports_lock = threading.Lock() + + # Run start callbacks + for cb in self.callbacks: + cb.on_run_start(self) + + # Execute + if max_workers == 1: + self._run_sequential(queue) + else: + self._run_parallel(queue, max_workers) + + # Run end callbacks + for cb in self.callbacks: + cb.on_run_end(self, self.reports) + + return self.reports + +def _run_sequential(self, queue: TaskQueue): + """Sequential execution with timeout support.""" + for task, agent_data in queue: + for cb in self.callbacks: + cb.on_task_start(self, task) + + for repeat_idx in range(self.n_task_repeats): + report = self._execute_task_repetition(task, agent_data, repeat_idx) + self._append_report_safe(report) + queue.on_task_complete(task, report) + + for cb in self.callbacks: + cb.on_task_end(self, task, self._last_report_for_task(task)) + + if not queue.should_continue(): + break + +def _run_parallel(self, queue: TaskQueue, max_workers: int): + """Parallel execution with timeout support.""" + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {} + + # Submit initial batch + for task, agent_data in queue: + for repeat_idx in range(self.n_task_repeats): + future = executor.submit( + self._execute_task_repetition, + task, agent_data, repeat_idx + ) + futures[future] = (task, repeat_idx) + + if len(futures) >= max_workers * 2: + break # Don't over-submit + + # Process completions and submit more + while futures: + done, _ = wait(futures, return_when=FIRST_COMPLETED) + + for future in done: + task, repeat_idx = futures.pop(future) + try: + report = future.result() + except Exception as e: + report = self._create_error_report(task, repeat_idx, e) + + self._append_report_safe(report) + queue.on_task_complete(task, report) + + # Callbacks serialized internally (thread-safe for users) + self._invoke_callbacks('on_task_repeat_end', self, report) + + if not queue.should_continue(): + # Cancel remaining futures + for f in futures: + f.cancel() + return + + # Submit more work + try: + task, agent_data = next(iter(queue)) + for repeat_idx in range(self.n_task_repeats): + future = executor.submit( + self._execute_task_repetition, + task, agent_data, repeat_idx + ) + futures[future] = (task, repeat_idx) + except StopIteration: + pass + +def _execute_task_repetition( + self, + task: Task, + agent_data: Dict[str, Any], + repeat_idx: int, +) -> Dict[str, Any]: + """Execute a single task repetition with timeout handling.""" + timeout = task.protocol.timeout_seconds + context = TaskContext(deadline=timeout) + + # Thread-local registry for this execution + local_registry = {} + + try: + # Setup + environment = self.setup_environment(agent_data, task) + user = self.setup_user(agent_data, environment, task) + agents_to_run, agents_dict = self.setup_agents(agent_data, environment, task, user) + evaluators = self.setup_evaluators(environment, task, agents_to_run, user) + + # Register components (thread-local) + local_registry.update(self._register_components(environment, user, agents_dict)) + + # Execute with timeout checking + context.check_timeout() + final_answer = self.execution_loop(agents_to_run, task, environment, user, context) + + # Collect traces + traces = self._collect_traces(local_registry) + configs = self._collect_configs(local_registry) + + # Evaluate + context.check_timeout() + eval_results = self.evaluate(evaluators, agents_dict, final_answer, traces) + + return { + "task_id": str(task.id), + "repeat_idx": repeat_idx, + "status": TaskExecutionStatus.SUCCESS.value, + "traces": traces, + "config": configs, + "eval": eval_results, + } + + except TaskTimeoutError as e: + return { + "task_id": str(task.id), + "repeat_idx": repeat_idx, + "status": TaskExecutionStatus.TASK_TIMEOUT.value, + "traces": e.partial_traces, + "config": {}, + "eval": None, + "error": { + "error_type": "TaskTimeoutError", + "error_message": str(e), + "elapsed": e.elapsed, + "timeout": e.timeout, + }, + } + except AgentError as e: + # ... existing error handling + pass +``` + +--- + +## Implementation Phases + +### Phase 0: Extract ComponentRegistry (Low Risk, Do First) + +**Scope**: Extract registry logic from `Benchmark` into dedicated `ComponentRegistry` class. + +**Files Modified**: + +- `maseval/core/registry.py` (new): `ComponentRegistry` with thread-local storage +- `maseval/core/benchmark.py`: Replace inline registry with delegation to `ComponentRegistry` +- `maseval/core/__init__.py`: Export `ComponentRegistry` + +**Effort**: ~1-2 days + +**Breaking Changes**: None (public API unchanged, internal refactoring only) + +**Why first**: This refactoring is needed for clean parallel execution. Doing it first: + +- Isolates the thread-local complexity +- Makes subsequent phases simpler +- Can be tested and merged independently + +### Phase 1: TaskProtocol & Timeout (Low Risk) + +**Scope**: Add `TaskProtocol` dataclass, integrate cooperative timeout. + +**Files Modified**: + +- `maseval/core/task.py`: Add `TaskProtocol`, attach to `Task` +- `maseval/core/exceptions.py`: Add `TaskTimeoutError` +- `maseval/core/benchmark.py`: Add `TaskContext`, timeout checking in execution + +**Effort**: ~2-3 days + +**Breaking Changes**: None (new optional field with defaults) + +### Phase 2: TaskQueue Abstraction (Medium Risk) + +**Scope**: Extract task iteration into `TaskQueue`, maintain sequential default. + +**Files Modified**: + +- `maseval/core/queue.py` (new): `TaskQueue`, `SequentialQueue`, `PriorityQueue` +- `maseval/core/benchmark.py`: Refactor `run()` to use queue + +**Effort**: ~3-4 days + +**Breaking Changes**: None (signature changes are additive) + +### Phase 3: Parallel Execution (Higher Risk) + +**Scope**: Add `max_workers` parameter, thread-safe report collection, callback locking. + +**Files Modified**: + +- `maseval/core/benchmark.py`: Add `_run_parallel()`, `_invoke_callbacks()`, `_append_report_safe()` + +**Effort**: ~4-5 days + +**Breaking Changes**: None. MASEval handles all thread safety internally. + +**Note**: Requires Phase 0 (ComponentRegistry) to be complete. + +### Phase 4: AdaptiveQueue (Collaborator-Driven) + +**Scope**: Implement `AdaptiveQueue` for IRT-based adaptive testing. + +**Files Modified**: + +- `maseval/core/queue.py`: Add `AdaptiveQueue` base or concrete implementation +- `maseval/core/callback.py`: Add `on_task_selected`, `on_queue_decision` (if needed) + +**Effort**: ~3-4 days (depends on algorithm complexity) + +**Breaking Changes**: None + +**Note**: This phase will be driven by collaborator implementing their adaptive sampling paper. MASEval provides the `TaskQueue` interface; they implement the selection algorithm. + +--- + +## Risks and Mitigations + +### Risk 1: Thread Safety Bugs + +**Mitigation**: + +- Thread-local storage for per-task registries (already ephemeral per-repetition) +- Lock for shared report list +- Lock for callback invocations (users don't need to think about this) +- Default to `max_workers=1` for backwards compatibility +- Comprehensive tests with race condition detection + +### Risk 2: Framework Incompatibility + +**Mitigation**: + +- Test with all supported frameworks (smolagents, langgraph, llamaindex) +- Document that user's `run_agents()` should not rely on shared mutable benchmark state +- All current adapters are stateless per-invocation (already safe) + +### Risk 3: Timeout Incomplete Cleanup + +**Mitigation**: + +- Cooperative timeout with checkpoints (clean interruption points) +- Hard timeout as backstop—logs warning but continues gracefully +- Document that timed-out tasks may leave external resources (API connections) in undefined state +- Timeout is "best effort"—we cannot forcibly kill Python threads + +### Risk 4: Callback Ordering in Parallel Mode + +**Mitigation**: + +- In parallel mode, `on_task_repeat_end` order is non-deterministic +- Document this behavior clearly +- Callbacks are still serialized (never concurrent), just out-of-order + +### Risk 5: Memory Pressure with Many Workers + +**Mitigation**: + +- Default `max_workers=1` +- Document memory implications +- Consider `max_workers="auto"` that uses `os.cpu_count()` + +--- + +## Summary + +### Implementation Order + +| Phase | Feature | Risk | Effort | Dependencies | +| ----- | ---------------------------- | ------ | -------- | ------------ | +| 0 | ComponentRegistry extraction | Low | 1-2 days | None | +| 1 | TaskProtocol & Timeout | Low | 2-3 days | None | +| 2 | TaskQueue abstraction | Medium | 3-4 days | None | +| 3 | Parallel Execution | Higher | 4-5 days | Phase 0 | +| 4 | AdaptiveQueue | Medium | 3-4 days | Phase 2 | + +### Feature Summary + +| Feature | Approach | Breaking Changes | +| ------------------- | --------------------------------------------- | ---------------- | +| ComponentRegistry | Extracted class with thread-local state | None | +| Parallel Processing | `ThreadPoolExecutor` with `max_workers` param | None | +| Timeout Handling | Cooperative checkpoints + hard backstop | None | +| TaskQueue | Iterator abstraction with `on_task_complete` | None | +| Callback Safety | MASEval serializes with internal lock | None | + +### Key Design Decisions + +1. **Extract ComponentRegistry**: Separate concerns. Registry manages thread-local component tracking. Benchmark orchestrates execution. Enables clean parallel implementation. + +2. **Threading over asyncio**: No user code changes required. Works with all agent frameworks (including sync-only smolagents). Future-proof for Python's GIL removal. + +3. **MASEval-managed callback safety**: All callback invocations are serialized with a lock. Users never need to think about thread safety in their callbacks. + +4. **Cooperative timeout**: Cross-platform, works in threads, clean interruption at defined checkpoints. Hard timeout as backstop for misbehaving code (best-effort only—Python threads cannot be killed). + +5. **AdaptiveQueue for collaborator**: The `TaskQueue` interface enables a collaborator to implement their adaptive sampling paper. MASEval provides the hooks; they implement the algorithm. + +### What's NOT Changing + +- **Public API**: All existing methods work unchanged +- **User-implemented methods**: `run_agents()`, `setup_environment()`, etc. stay sync +- **Callback interface**: Users write callbacks exactly as today +- **Default behavior**: `max_workers=1` maintains sequential execution + +The unified design maintains **full backwards compatibility** while enabling: + +- **Faster benchmarks** through parallelism +- **Resource-bounded execution** through timeouts +- **Intelligent task selection** through adaptive queues + +All features share the same execution model refactor, making them natural to implement together. From 4104cb1122ac958eb61e4414a7c60914a19bbeca Mon Sep 17 00:00:00 2001 From: cemde Date: Fri, 5 Dec 2025 19:49:50 +0000 Subject: [PATCH 02/25] initial implementation --- CHANGELOG.md | 20 + maseval/__init__.py | 17 +- maseval/core/benchmark.py | 857 ++++++++++-------- maseval/core/context.py | 122 +++ maseval/core/exceptions.py | 64 ++ maseval/core/queue.py | 221 +++++ maseval/core/registry.py | 243 +++++ maseval/core/task.py | 34 + .../test_automatic_registration.py | 18 +- .../test_benchmark_lifecycle.py | 10 +- 10 files changed, 1214 insertions(+), 392 deletions(-) create mode 100644 maseval/core/context.py create mode 100644 maseval/core/queue.py create mode 100644 maseval/core/registry.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3850a69..13412ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,28 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +**Parallel Execution** + +- Added parallel task execution with `max_workers` parameter in `Benchmark.run()` using `ThreadPoolExecutor` (PR: #PR_NUMBER_PLACEHOLDER) +- Added `ComponentRegistry` class for thread-safe component registration with thread-local storage (PR: #PR_NUMBER_PLACEHOLDER) +- Added `TaskContext` for cooperative timeout checking with `check_timeout()`, `elapsed`, `remaining`, and `is_expired` properties (PR: #PR_NUMBER_PLACEHOLDER) +- Added `TaskProtocol` dataclass with `timeout_seconds`, `timeout_action`, `max_retries`, `priority`, and `tags` fields for task-level execution control (PR: #PR_NUMBER_PLACEHOLDER) +- Added `TimeoutAction` enum (`SKIP`, `RETRY`, `RAISE`) for configurable timeout behavior (PR: #PR_NUMBER_PLACEHOLDER) +- Added `TaskTimeoutError` exception with `elapsed`, `timeout`, and `partial_traces` attributes (PR: #PR_NUMBER_PLACEHOLDER) +- Added `TASK_TIMEOUT` to `TaskExecutionStatus` enum for timeout classification (PR: #PR_NUMBER_PLACEHOLDER) + +**Task Queue Abstraction** + +- Added `TaskQueue` abstract base class with iterator interface for flexible task scheduling (PR: #PR_NUMBER_PLACEHOLDER) +- Added `SequentialQueue` for simple FIFO task ordering (PR: #PR_NUMBER_PLACEHOLDER) +- Added `PriorityQueue` for priority-based task scheduling using `TaskProtocol.priority` (PR: #PR_NUMBER_PLACEHOLDER) +- Added `AdaptiveQueue` placeholder for future feedback-based scheduling (PR: #PR_NUMBER_PLACEHOLDER) + ### Changed +- Refactored `Benchmark` to delegate registry operations to `ComponentRegistry` class (PR: #PR_NUMBER_PLACEHOLDER) +- `Benchmark.run()` now accepts optional `queue` parameter for custom task scheduling (PR: #PR_NUMBER_PLACEHOLDER) + ### Fixed ### Removed diff --git a/maseval/__init__.py b/maseval/__init__.py index 11ea20d..1f4e831 100644 --- a/maseval/__init__.py +++ b/maseval/__init__.py @@ -8,7 +8,7 @@ Benchmarks sit in the `maseval.benchmark` submodule. """ -from .core.task import Task, TaskCollection +from .core.task import Task, TaskCollection, TaskProtocol, TimeoutAction from .core.environment import Environment from .core.agent import AgentAdapter from .core.benchmark import Benchmark, TaskExecutionStatus @@ -27,11 +27,15 @@ from .core.evaluator import Evaluator from .core.history import MessageHistory, ToolInvocationHistory from .core.tracing import TraceableMixin +from .core.registry import ComponentRegistry +from .core.context import TaskContext +from .core.queue import TaskQueue, SequentialQueue, PriorityQueue, AdaptiveQueue from .core.exceptions import ( MASEvalError, AgentError, EnvironmentError, UserError, + TaskTimeoutError, validate_argument_type, validate_required_arguments, validate_no_extra_arguments, @@ -42,6 +46,8 @@ # Tasks "Task", "TaskCollection", + "TaskProtocol", + "TimeoutAction", # Core abstractions "Environment", "AgentAdapter", @@ -69,11 +75,20 @@ "ToolInvocationHistory", "ModelAdapter", "TraceableMixin", + # Registry and execution context + "ComponentRegistry", + "TaskContext", + # Task queues + "TaskQueue", + "SequentialQueue", + "PriorityQueue", + "AdaptiveQueue", # Exceptions and validation "MASEvalError", "AgentError", "EnvironmentError", "UserError", + "TaskTimeoutError", "validate_argument_type", "validate_required_arguments", "validate_no_extra_arguments", diff --git a/maseval/core/benchmark.py b/maseval/core/benchmark.py index 14f6432..3780f99 100644 --- a/maseval/core/benchmark.py +++ b/maseval/core/benchmark.py @@ -2,8 +2,10 @@ from typing import Any, Dict, List, Iterable, Optional, Sequence, Tuple, Union, cast from datetime import datetime import threading +from concurrent.futures import ThreadPoolExecutor, as_completed from enum import Enum import warnings +import traceback from .evaluator import Evaluator from .task import Task, TaskCollection @@ -15,13 +17,16 @@ from .user import User from .tracing import TraceableMixin from .config import ConfigurableMixin +from .registry import ComponentRegistry +from .queue import TaskQueue, SequentialQueue +from .context import TaskContext from .utils.system_info import gather_benchmark_config from .callbacks.progress_bar import ( ProgressBarCallback, TqdmProgressBarCallback, RichProgressBarCallback, ) -from .exceptions import AgentError, EnvironmentError, UserError +from .exceptions import AgentError, EnvironmentError, UserError, TaskTimeoutError class TaskExecutionStatus(Enum): @@ -39,13 +44,14 @@ class TaskExecutionStatus(Enum): AGENT_ERROR: Agent violated contract at a boundary (agent's fault, counts against score). ENVIRONMENT_ERROR: Environment/tool infrastructure failed (not agent's fault, exclude from scoring). USER_ERROR: User simulator failed (not agent's fault, exclude from scoring). + TASK_TIMEOUT: Task execution exceeded configured timeout (resource constraint). UNKNOWN_EXECUTION_ERROR: Unclassified execution error (e.g., agent framework internal failure). EVALUATION_FAILED: Task executed but evaluation raised an exception. SETUP_FAILED: Setup phase (environment, agents, evaluators) raised an exception. Scoring Guidance: - Include in agent score: SUCCESS, AGENT_ERROR - - Exclude from agent score: ENVIRONMENT_ERROR, USER_ERROR, UNKNOWN_EXECUTION_ERROR + - Exclude from agent score: ENVIRONMENT_ERROR, USER_ERROR, TASK_TIMEOUT, UNKNOWN_EXECUTION_ERROR - Handle separately: EVALUATION_FAILED, SETUP_FAILED """ @@ -53,6 +59,7 @@ class TaskExecutionStatus(Enum): AGENT_ERROR = "agent_error" ENVIRONMENT_ERROR = "environment_error" USER_ERROR = "user_error" + TASK_TIMEOUT = "task_timeout" UNKNOWN_EXECUTION_ERROR = "unknown_execution_error" EVALUATION_FAILED = "evaluation_failed" SETUP_FAILED = "setup_failed" @@ -241,21 +248,20 @@ def __init__( self.fail_on_evaluation_error = fail_on_evaluation_error self.fail_on_setup_error = fail_on_setup_error - # Registry for Traceable components (cleared after each task repetition) - self._trace_registry: Dict[str, TraceableMixin] = {} - self._component_id_map: Dict[int, str] = {} # Maps id(component) -> registry key + # Gather benchmark-level configuration (system, git, packages, etc.) + self.benchmark_level_config = gather_benchmark_config() + + # Thread-safe component registry (replaces inline registry dicts) + self._registry = ComponentRegistry(benchmark_config=self.benchmark_level_config) - # Registry for Configurable components (cleared after each task repetition) - self._config_registry: Dict[str, ConfigurableMixin] = {} - self._config_component_id_map: Dict[int, str] = {} # Maps id(component) -> registry key + # Thread safety locks for parallel execution + self._reports_lock = threading.Lock() + self._callback_lock = threading.Lock() # Persistent benchmark-level reports (stored across all task repetitions) # Each report contains both traces and configs for a single task repetition self.reports: List[Dict[str, Any]] = [] - # Gather benchmark-level configuration (system, git, packages, etc.) - self.benchmark_level_config = gather_benchmark_config() - def register(self, category: str, name: str, component: TraceableMixin) -> TraceableMixin: """Register a component for comprehensive trace and configuration collection. @@ -304,35 +310,7 @@ def setup_agents(self, agent_data, environment, task, user): `collect_all_traces()` and `collect_all_configs()` which are called internally by the `run()` method. """ - # Check if this component is already registered for traces - component_id = id(component) - if component_id in self._component_id_map: - existing_key = self._component_id_map[component_id] - existing_category, existing_name = existing_key.split(":", 1) - new_key = f"{category}:{name}" - - if existing_key == new_key: - # Same component, same name - silently accept (idempotent) - return component - else: - raise ValueError( - f"Component is already registered as '{existing_key}' and cannot be " - f"re-registered as '{new_key}'. Note: Environments, users, and agents " - f"returned from setup methods are automatically registered." - ) - - key = f"{category}:{name}" - - # Register for trace collection - self._trace_registry[key] = component - self._component_id_map[component_id] = key - - # Also register for configuration collection if component supports it - if isinstance(component, ConfigurableMixin): - self._config_registry[key] = component - self._config_component_id_map[component_id] = key - - return component + return self._registry.register(category, name, component) def clear_registry(self) -> None: """Clear the component registry after a task repetition completes. @@ -341,10 +319,7 @@ def clear_registry(self) -> None: to ensure components are not carried over between repetitions. The reports list persists across all repetitions for aggregated analysis. """ - self._trace_registry.clear() - self._component_id_map.clear() - self._config_registry.clear() - self._config_component_id_map.clear() + self._registry.clear() def collect_all_traces(self) -> Dict[str, Any]: """Collect execution traces from all registered components for the current task repetition. @@ -389,61 +364,7 @@ def collect_all_traces(self) -> Dict[str, Any]: The collected traces are passed to the evaluator's `evaluate()` method and stored in `benchmark.reports` for later analysis. """ - traces: Dict[str, Any] = { - "metadata": { - "timestamp": datetime.now().isoformat(), - "thread_id": threading.current_thread().ident, - "total_components": len(self._trace_registry), - }, - "agents": {}, - "models": {}, - "tools": {}, - "simulators": {}, - "callbacks": {}, - "environment": None, - "user": None, - "other": {}, - } - - for key, component in self._trace_registry.items(): - category, comp_name = key.split(":", 1) - - try: - component_traces = component.gather_traces() - - # Inject name from registry if component doesn't have it - # Is this intervention obfuscating the mechnaisms too much? - if "name" not in component_traces: - component_traces["name"] = comp_name - - # Handle environment and user as direct values (not nested in dict) - if category == "environment": - traces["environment"] = component_traces - elif category == "user": - traces["user"] = component_traces - else: - # Ensure category exists in traces - if category not in traces: - traces[category] = {} - traces[category][comp_name] = component_traces - except Exception as e: - # Gracefully handle tracing errors - error_info = { - "error": f"Failed to gather traces: {e}", - "error_type": type(e).__name__, - "component_type": type(component).__name__, - } - - if category == "environment": - traces["environment"] = error_info - elif category == "user": - traces["user"] = error_info - else: - if category not in traces: - traces[category] = {} - traces[category][comp_name] = error_info - - return traces + return self._registry.collect_traces() def collect_all_configs(self) -> Dict[str, Any]: """Collect configuration from all registered components for the current task repetition. @@ -490,62 +411,33 @@ def collect_all_configs(self) -> Dict[str, Any]: The collected configs are available in the results for reproducibility analysis. """ - configs: Dict[str, Any] = { - "metadata": { - "timestamp": datetime.now().isoformat(), - "thread_id": threading.current_thread().ident, - "total_components": len(self._config_registry), - }, - "agents": {}, - "models": {}, - "tools": {}, - "simulators": {}, - "callbacks": {}, - "environment": None, - "user": None, - "other": {}, - "benchmark": self.benchmark_level_config, # Include benchmark-level config - } + return self._registry.collect_configs() - for key, component in self._config_registry.items(): - category, comp_name = key.split(":", 1) + def _invoke_callbacks(self, method_name: str, *args, **kwargs) -> None: + """Invoke a callback method on all registered callbacks (thread-safe). - try: - component_config = component.gather_config() - - # Inject name from registry if component doesn't have it - # Is this intervention obfuscating the mechnaisms too much? - if "name" not in component_config: - component_config["name"] = comp_name - - # Handle environment and user as direct values (not nested in dict) - if category == "environment": - configs["environment"] = component_config - elif category == "user": - configs["user"] = component_config - else: - # Ensure category exists in configs - if category not in configs: - configs[category] = {} - configs[category][comp_name] = component_config - except Exception as e: - # Gracefully handle config gathering errors - error_info = { - "error": f"Failed to gather config: {e}", - "error_type": type(e).__name__, - "component_type": type(component).__name__, - } + This method serializes all callback invocations using an internal lock, + so users don't need to implement thread-safe callbacks. - if category == "environment": - configs["environment"] = error_info - elif category == "user": - configs["user"] = error_info - else: - if category not in configs: - configs[category] = {} - configs[category][comp_name] = error_info + Args: + method_name: Name of the callback method to invoke (e.g., "on_task_start"). + *args: Positional arguments to pass to the callback method. + **kwargs: Keyword arguments to pass to the callback method. + """ + with self._callback_lock: + for cb in self.callbacks: + method = getattr(cb, method_name, None) + if method is not None: + method(*args, **kwargs) - return configs + def _append_report_safe(self, report: Dict[str, Any]) -> None: + """Append a report to the reports list (thread-safe). + + Args: + report: The report dictionary to append. + """ + with self._reports_lock: + self.reports.append(report) def add_callback(self, callback: BenchmarkCallback) -> None: """Register a callback handler to monitor benchmark execution. @@ -992,12 +884,427 @@ def __init__(self, ...): return final_answer - def run(self, tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]]) -> List[Dict[str, Any]]: + def _execute_task_repetition( + self, + task: Task, + agent_data: Dict[str, Any], + repeat_idx: int, + ) -> Dict[str, Any]: + """Execute a single task repetition with timeout handling. + + This method encapsulates the complete execution of one task repetition, + including setup, execution, trace collection, and evaluation. It is + designed to be called from both sequential and parallel execution paths. + + Args: + task: The task to execute. + agent_data: Agent configuration for this task. + repeat_idx: Repetition index (0 to n_task_repeats-1). + + Returns: + Report dictionary containing execution results. + """ + # Initialize status and error tracking + execution_status = TaskExecutionStatus.SUCCESS + error_info: Optional[Dict[str, Any]] = None + final_answers: Any = None + eval_results: Any = None + execution_traces: Dict[str, Any] = {} + execution_configs: Dict[str, Any] = {} + evaluators: Sequence[Evaluator] = [] + agents_dict: Dict[str, AgentAdapter] = {} + + # Create execution context with optional timeout + timeout = task.protocol.timeout_seconds + context = TaskContext(deadline=timeout) + + try: + # 1. Setup + environment = self.setup_environment(agent_data, task) + user = self.setup_user(agent_data, environment, task) + if user is None and self.max_invocations > 1: + # Warn if multi-turn is enabled but no user to drive interaction + warnings.warn( + f"max_invocations={self.max_invocations} > 1 but no user simulator provided. " + f"Falling back to single-turn execution for task {task.id}." + ) + agents_to_run, agents_dict = self.setup_agents(agent_data, environment, task, user) + evaluators = self.setup_evaluators(environment, task, agents_to_run, user) + + # Auto-register components returned from setup methods + # Environment + if environment is not None and isinstance(environment, TraceableMixin): + self.register("environment", "env", environment) + + # User + if user is not None and isinstance(user, TraceableMixin): + self.register("user", "user", user) + + # Agents (use their names from agents_dict) + for agent_name, agent in agents_dict.items(): + if isinstance(agent, TraceableMixin): + self.register("agents", agent_name, agent) + + # Check timeout after setup + context.check_timeout() + + except TaskTimeoutError as e: + # Timeout during setup + execution_status = TaskExecutionStatus.TASK_TIMEOUT + error_info = { + "error_type": "TaskTimeoutError", + "error_message": str(e), + "elapsed": e.elapsed, + "timeout": e.timeout, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + # Create a minimal report for this timeout + report = { + "task_id": str(task.id), + "repeat_idx": repeat_idx, + "status": execution_status.value, + "error": error_info, + "traces": e.partial_traces, + "config": {}, + "eval": None, + } + self.clear_registry() + return report + + except Exception as e: + # Setup failed - record error + execution_status = TaskExecutionStatus.SETUP_FAILED + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + # Create a minimal report for this failed setup + report = { + "task_id": str(task.id), + "repeat_idx": repeat_idx, + "status": execution_status.value, + "error": error_info, + "traces": {}, + "config": {}, + "eval": None, + } + self.clear_registry() + + if self.fail_on_setup_error: + raise + + return report + + # 2. Execute agent system with optional user interaction loop + try: + # Check timeout before execution + context.check_timeout() + final_answers = self.execution_loop(agents_to_run, task, environment, user) + except TaskTimeoutError as e: + # Task timed out during execution + execution_status = TaskExecutionStatus.TASK_TIMEOUT + error_info = { + "error_type": "TaskTimeoutError", + "error_message": str(e), + "elapsed": e.elapsed, + "timeout": e.timeout, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + final_answers = None + except AgentError as e: + # Agent violated contract at boundary (agent's fault) + execution_status = TaskExecutionStatus.AGENT_ERROR + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "component": e.component, + "details": e.details, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if self.fail_on_task_error: + self.clear_registry() + raise + + final_answers = None + except EnvironmentError as e: + # Environment/tool infrastructure failed (not agent's fault) + execution_status = TaskExecutionStatus.ENVIRONMENT_ERROR + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "component": e.component, + "details": e.details, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if self.fail_on_task_error: + self.clear_registry() + raise + + final_answers = None + except UserError as e: + # User simulator failed (not agent's fault) + execution_status = TaskExecutionStatus.USER_ERROR + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "component": e.component, + "details": e.details, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if self.fail_on_task_error: + self.clear_registry() + raise + + final_answers = None + except Exception as e: + # Unclassified error (e.g., agent framework internal failure) + execution_status = TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if self.fail_on_task_error: + self.clear_registry() + raise + + final_answers = None + + # 3. Collect traces and configs (always attempt this) + try: + execution_configs = self.collect_all_configs() + execution_traces = self.collect_all_traces() + # Store in context for potential timeout errors + context.set_collected_traces(execution_traces) + except Exception as e: + # If trace/config collection fails, record it but continue + execution_configs = { + "error": f"Failed to collect configs: {e}", + "error_type": type(e).__name__, + } + execution_traces = { + "error": f"Failed to collect traces: {e}", + "error_type": type(e).__name__, + } + + # 4. Evaluate (skip if task execution failed) + if execution_status == TaskExecutionStatus.SUCCESS: + try: + # Check timeout before evaluation + context.check_timeout() + eval_results = self.evaluate(evaluators, agents_dict, final_answers, execution_traces) + except TaskTimeoutError as e: + execution_status = TaskExecutionStatus.TASK_TIMEOUT + error_info = { + "error_type": "TaskTimeoutError", + "error_message": str(e), + "elapsed": e.elapsed, + "timeout": e.timeout, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + eval_results = None + except Exception as e: + execution_status = TaskExecutionStatus.EVALUATION_FAILED + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if self.fail_on_evaluation_error: + self.clear_registry() + raise + + eval_results = None + else: + # Task execution failed, so skip evaluation + eval_results = None + + # 5. Build report + report: Dict[str, Any] = { + "task_id": str(task.id), + "repeat_idx": repeat_idx, + "status": execution_status.value, + "traces": execution_traces, + "config": execution_configs, + "eval": eval_results, + } + + # Add error info if present + if error_info is not None: + report["error"] = error_info + + # Clear registry after task repetition completes + self.clear_registry() + + return report + + def _run_sequential( + self, + queue: TaskQueue, + ) -> None: + """Execute tasks sequentially with optional timeout support. + + Args: + queue: Task queue providing task ordering. + """ + for task, agent_data in queue: + # Callbacks at the start of each task + self._invoke_callbacks("on_task_start", self, task) + + for repeat_idx in range(self.n_task_repeats): + self._invoke_callbacks("on_task_repeat_start", self, task, repeat_idx) + + report = self._execute_task_repetition(task, agent_data, repeat_idx) + self._append_report_safe(report) + queue.on_task_complete(task, report) + + self._invoke_callbacks("on_task_repeat_end", self, report) + + if not queue.should_continue(): + return + + # Callbacks at the end of each task + task_reports = [r for r in self.reports if r["task_id"] == str(task.id)] + last_report = task_reports[-1] if task_reports else {} + self._invoke_callbacks("on_task_end", self, task, last_report) + + if not queue.should_continue(): + return + + def _run_parallel( + self, + queue: TaskQueue, + max_workers: int, + ) -> None: + """Execute tasks in parallel with thread pool. + + Args: + queue: Task queue providing task ordering. + max_workers: Maximum number of concurrent workers. + """ + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures: Dict[Any, Tuple[Task, int]] = {} + task_repeat_counts: Dict[str, int] = {} # Track submitted repeats per task + + def submit_task_repeats(task: Task, agent_data: Dict[str, Any]) -> None: + """Submit all repeats for a task.""" + task_id = str(task.id) + task_repeat_counts[task_id] = 0 + + self._invoke_callbacks("on_task_start", self, task) + + for repeat_idx in range(self.n_task_repeats): + self._invoke_callbacks("on_task_repeat_start", self, task, repeat_idx) + + future = executor.submit( + self._execute_task_repetition, + task, + agent_data, + repeat_idx, + ) + futures[future] = (task, repeat_idx) + task_repeat_counts[task_id] += 1 + + # Submit initial batch from queue + submitted_tasks: List[Task] = [] + for task, agent_data in queue: + submit_task_repeats(task, agent_data) + submitted_tasks.append(task) + + # Limit initial submission to avoid over-committing + if len(futures) >= max_workers * 2: + break + + # Process completions + completed_task_ids: set = set() + queue_iter = iter(queue) + queue_exhausted = len(submitted_tasks) >= len(list(queue)) # Approximate check + + while futures: + # Wait for at least one completion + done_futures = [] + for future in list(futures.keys()): + if future.done(): + done_futures.append(future) + + if not done_futures: + # No futures done yet, wait a bit + import time + + time.sleep(0.01) + continue + + for future in done_futures: + task, repeat_idx = futures.pop(future) + task_id = str(task.id) + + try: + report = future.result() + except Exception as e: + # Create error report for unexpected failures + report = { + "task_id": task_id, + "repeat_idx": repeat_idx, + "status": TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR.value, + "error": { + "error_type": type(e).__name__, + "error_message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + }, + "traces": {}, + "config": {}, + "eval": None, + } + + self._append_report_safe(report) + queue.on_task_complete(task, report) + + self._invoke_callbacks("on_task_repeat_end", self, report) + + # Check if all repeats for this task are done + task_reports = [r for r in self.reports if r["task_id"] == task_id] + if len(task_reports) >= self.n_task_repeats and task_id not in completed_task_ids: + completed_task_ids.add(task_id) + last_report = task_reports[-1] if task_reports else {} + self._invoke_callbacks("on_task_end", self, task, last_report) + + if not queue.should_continue(): + # Cancel remaining futures + for f in futures: + f.cancel() + return + + # Submit more work if queue not exhausted + if not queue_exhausted and len(futures) < max_workers: + try: + task, agent_data = next(queue_iter) + submit_task_repeats(task, agent_data) + submitted_tasks.append(task) + except StopIteration: + queue_exhausted = True + + def run( + self, + tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]], + queue: Optional[TaskQueue] = None, + max_workers: int = 1, + ) -> List[Dict[str, Any]]: """Initialize and execute the complete benchmark loop across all tasks. Args: tasks: Collection of tasks to execute. Can be a single Task, TaskCollection, list of Task objects, or list of dicts that will be converted to Tasks. + queue: Optional task queue for custom scheduling. If None, uses SequentialQueue. + max_workers: Maximum number of parallel task executions. Default 1 (sequential). + Set higher for I/O-bound workloads (e.g., LLM API calls). Returns: List of report dictionaries, one per task repetition. Each report contains: @@ -1074,6 +1381,14 @@ def run(self, tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]]) - print(f"Task {report['task_id']}, Repeat {report['repeat_idx']}: {report['eval']}") print(f"Config: {report['config']}") print(f"Traces: {report['traces']}") + + # Parallel execution with 4 workers + reports = benchmark.run(tasks=tasks, max_workers=4) + + # Custom queue for priority-based execution + from maseval.core.queue import PriorityQueue + queue = PriorityQueue(tasks, agent_data_list) + reports = benchmark.run(tasks=tasks, queue=queue) ``` """ # Normalize tasks into a TaskCollection @@ -1103,235 +1418,22 @@ def run(self, tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]]) - # Clear reports from previous run() calls to prevent accumulation self.reports = [] - # Callbacks at the start of the run - for cb in self.callbacks: - cb.on_run_start(self) - - for task_idx, (task, agent_data) in enumerate(zip(self.tasks, agent_data_list)): - # Callbacks at the start of each task - for cb in self.callbacks: - cb.on_task_start(self, task) - - for repeat_idx in range(self.n_task_repeats): - for cb in self.callbacks: - cb.on_task_repeat_start(self, task, repeat_idx) - - # Initialize status and error tracking - execution_status = TaskExecutionStatus.SUCCESS - error_info: Optional[Dict[str, Any]] = None - final_answers: Any = None - eval_results: Any = None - execution_traces: Dict[str, Any] = {} - execution_configs: Dict[str, Any] = {} - - try: - # 1. Setup - environment = self.setup_environment(agent_data, task) - user = self.setup_user(agent_data, environment, task) - if user is None and self.max_invocations > 1: - # Warn if multi-turn is enabled but no user to drive interaction - warnings.warn( - f"max_invocations={self.max_invocations} > 1 but no user simulator provided. " - f"Falling back to single-turn execution for task {task.id}." - ) - agents_to_run, agents_dict = self.setup_agents(agent_data, environment, task, user) - evaluators = self.setup_evaluators(environment, task, agents_to_run, user) - - # Auto-register components returned from setup methods - # Environment - if environment is not None and isinstance(environment, TraceableMixin): - self.register("environment", "env", environment) - - # User - if user is not None and isinstance(user, TraceableMixin): - self.register("user", "user", user) - - # Agents (use their names from agents_dict) - for agent_name, agent in agents_dict.items(): - if isinstance(agent, TraceableMixin): - self.register("agents", agent_name, agent) - - except Exception as e: - # Setup failed - record error and optionally re-raise - execution_status = TaskExecutionStatus.SETUP_FAILED - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - # Create a minimal report for this failed setup - report = { - "task_id": str(task.id), - "repeat_idx": repeat_idx, - "status": execution_status.value, - "error": error_info, - "traces": {}, - "config": {}, - "eval": None, - } - self.reports.append(report) - - for cb in self.callbacks: - cb.on_task_repeat_end(self, report) - - # Clear registry before potentially re-raising - self.clear_registry() - - if self.fail_on_setup_error: - raise - - # Continue to next task repetition - continue + # Create queue if not provided + if queue is None: + queue = SequentialQueue(self.tasks, agent_data_list) - # 2. Execute agent system with optional user interaction loop - try: - final_answers = self.execution_loop(agents_to_run, task, environment, user) - except AgentError as e: - # Agent violated contract at boundary (agent's fault) - execution_status = TaskExecutionStatus.AGENT_ERROR - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "component": e.component, - "details": e.details, - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - if self.fail_on_task_error: - # Clear registry before re-raising - self.clear_registry() - raise - - # Continue with trace collection even if task failed - final_answers = None - except EnvironmentError as e: - # Environment/tool infrastructure failed (not agent's fault) - execution_status = TaskExecutionStatus.ENVIRONMENT_ERROR - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "component": e.component, - "details": e.details, - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - if self.fail_on_task_error: - # Clear registry before re-raising - self.clear_registry() - raise - - # Continue with trace collection even if task failed - final_answers = None - except UserError as e: - # User simulator failed (not agent's fault) - execution_status = TaskExecutionStatus.USER_ERROR - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "component": e.component, - "details": e.details, - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - if self.fail_on_task_error: - # Clear registry before re-raising - self.clear_registry() - raise - - # Continue with trace collection even if task failed - final_answers = None - except Exception as e: - # Unclassified error (e.g., agent framework internal failure) - execution_status = TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - if self.fail_on_task_error: - # Clear registry before re-raising - self.clear_registry() - raise - - # Continue with trace collection even if task failed - final_answers = None - - # # Callbacks before evaluation - # for cb in self.callbacks: - # cb.on_before_evaluation(self, task, agent_output) - - # 3. Collect traces and configs (always attempt this) - try: - execution_configs = self.collect_all_configs() - execution_traces = self.collect_all_traces() - except Exception as e: - # If trace/config collection fails, record it but continue - execution_configs = { - "error": f"Failed to collect configs: {e}", - "error_type": type(e).__name__, - } - execution_traces = { - "error": f"Failed to collect traces: {e}", - "error_type": type(e).__name__, - } - - # 4. Evaluate (skip if task execution failed, unless we want partial evaluation) - if execution_status == TaskExecutionStatus.SUCCESS: - try: - eval_results = self.evaluate(evaluators, agents_dict, final_answers, execution_traces) - except Exception as e: - execution_status = TaskExecutionStatus.EVALUATION_FAILED - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - if self.fail_on_evaluation_error: - # Clear registry before re-raising - self.clear_registry() - raise - - # Set eval_results to None on failure - eval_results = None - else: - # Task execution failed, so skip evaluation - eval_results = None - - # 5. Store results with status and error info - report = { - "task_id": str(task.id), - "repeat_idx": repeat_idx, - "status": execution_status.value, - "traces": execution_traces, - "config": execution_configs, - "eval": eval_results, - } - - # Add error info if present - if error_info is not None: - report["error"] = error_info - - self.reports.append(report) - - for cb in self.callbacks: - cb.on_task_repeat_end(self, report) - - # Clear registry after task repetition completes - self.clear_registry() + # Callbacks at the start of the run + self._invoke_callbacks("on_run_start", self) - # Callbacks at the end of each task - # Pass the last report for this task to the callback - task_reports = [r for r in self.reports if r["task_id"] == str(task.id)] - last_report = task_reports[-1] if task_reports else {} - for cb in self.callbacks: - cb.on_task_end(self, task, last_report) + # Execute based on max_workers + if max_workers == 1: + self._run_sequential(queue) + else: + self._run_parallel(queue, max_workers) # Callbacks at the end of the run - for cb in self.callbacks: - cb.on_run_end(self, self.reports) + self._invoke_callbacks("on_run_end", self, self.reports) + return self.reports def get_failed_tasks( @@ -1411,6 +1513,7 @@ def get_failed_tasks( TaskExecutionStatus.AGENT_ERROR.value, TaskExecutionStatus.ENVIRONMENT_ERROR.value, TaskExecutionStatus.USER_ERROR.value, + TaskExecutionStatus.TASK_TIMEOUT.value, TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR.value, TaskExecutionStatus.EVALUATION_FAILED.value, TaskExecutionStatus.SETUP_FAILED.value, diff --git a/maseval/core/context.py b/maseval/core/context.py new file mode 100644 index 0000000..5265d40 --- /dev/null +++ b/maseval/core/context.py @@ -0,0 +1,122 @@ +"""Execution context for task timeout handling. + +This module provides the TaskContext class for cooperative timeout checking +during task execution. The context tracks elapsed time and enables checkpoints +where tasks can gracefully exit if the deadline has passed. +""" + +import time +from typing import Any, Dict, Optional + +from .exceptions import TaskTimeoutError + + +class TaskContext: + """Execution context for cooperative timeout checking. + + TaskContext provides a mechanism for tasks to voluntarily check for timeout + conditions at defined checkpoints. This enables clean interruption without + forcibly killing threads (which Python doesn't support well). + + The context is created with an optional deadline and tracks elapsed time. + Tasks should call `check_timeout()` at natural checkpoints (e.g., between + agent steps, after LLM calls) to allow graceful termination. + + Attributes: + deadline: Maximum execution time in seconds, or None for no timeout. + collected_traces: Traces collected before timeout (for partial results). + + Usage: + context = TaskContext(deadline=60.0) + + for step in range(max_steps): + context.check_timeout() # Raises TaskTimeoutError if expired + result = agent.run(query) + # ... process result + + # Access timing info + print(f"Elapsed: {context.elapsed}s") + print(f"Remaining: {context.remaining}s") + + Thread Safety: + TaskContext instances are not thread-safe. Each thread/task should + have its own context instance. + """ + + def __init__(self, deadline: Optional[float] = None): + """Initialize execution context. + + Args: + deadline: Maximum execution time in seconds. If None, no timeout + checking is performed and check_timeout() is a no-op. + """ + self._deadline = deadline + self._start_time = time.monotonic() + self.collected_traces: Dict[str, Any] = {} + + @property + def deadline(self) -> Optional[float]: + """Maximum execution time in seconds, or None if no timeout.""" + return self._deadline + + @property + def elapsed(self) -> float: + """Time elapsed since context creation in seconds.""" + return time.monotonic() - self._start_time + + @property + def remaining(self) -> Optional[float]: + """Time remaining before deadline in seconds, or None if no deadline. + + Returns 0 if deadline has passed. + """ + if self._deadline is None: + return None + return max(0.0, self._deadline - self.elapsed) + + @property + def is_expired(self) -> bool: + """Whether the deadline has passed. + + Returns False if no deadline is set. + """ + if self._deadline is None: + return False + return self.elapsed >= self._deadline + + def check_timeout(self) -> None: + """Check if deadline has passed and raise TaskTimeoutError if so. + + This method should be called at natural checkpoints during task + execution (e.g., between agent steps, after LLM calls). If the + deadline has passed, it raises TaskTimeoutError with timing info. + + If no deadline is set, this is a no-op. + + Raises: + TaskTimeoutError: If the deadline has passed. + + Usage: + context = TaskContext(deadline=30.0) + + for step in range(max_steps): + context.check_timeout() # May raise + result = agent.run(query) + """ + if self.is_expired: + raise TaskTimeoutError( + f"Task exceeded {self._deadline}s deadline after {self.elapsed:.2f}s", + component="timeout_check", + elapsed=self.elapsed, + timeout=self._deadline or 0.0, + partial_traces=self.collected_traces, + ) + + def set_collected_traces(self, traces: Dict[str, Any]) -> None: + """Store traces collected during execution for inclusion in timeout errors. + + Args: + traces: Traces to store. These will be included in TaskTimeoutError + if a timeout occurs. + """ + self.collected_traces = traces diff --git a/maseval/core/exceptions.py b/maseval/core/exceptions.py index 77fc8c0..d4f6ba1 100644 --- a/maseval/core/exceptions.py +++ b/maseval/core/exceptions.py @@ -216,6 +216,70 @@ class UserError(MASEvalError): pass +class TaskTimeoutError(MASEvalError): + """Task execution exceeded configured timeout. + + This is classified as TASK_TIMEOUT in benchmark results, separate from + other error types. Timeout is neither agent's fault nor infrastructure's + fault—it's a resource constraint. + + When to raise: + - Task execution time exceeds TaskProtocol.timeout_seconds + - Cooperative timeout check detects deadline has passed + - Hard timeout backstop triggers + + Attributes: + elapsed: Time elapsed before timeout was detected. + timeout: The configured timeout value in seconds. + partial_traces: Any traces collected before timeout occurred. + + Examples: + ```python + # Cooperative timeout at checkpoint + raise TaskTimeoutError( + "Task exceeded 60s deadline", + component="execution_loop", + elapsed=62.5, + timeout=60.0 + ) + + # Hard timeout with partial traces + raise TaskTimeoutError( + "Task exceeded 120s hard deadline", + component="timeout_backstop", + elapsed=125.0, + timeout=120.0, + partial_traces={"agents": {"main": {"messages": [...]}}} + ) + ``` + """ + + def __init__( + self, + message: str, + *, + component: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, + elapsed: float = 0.0, + timeout: float = 0.0, + partial_traces: Optional[Dict[str, Any]] = None, + ): + """Initialize TaskTimeoutError. + + Args: + message: Human-readable error description. + component: Name of the component that raised the error. + details: Additional structured information about the error. + elapsed: Time elapsed before timeout was detected. + timeout: The configured timeout value in seconds. + partial_traces: Any traces collected before timeout occurred. + """ + super().__init__(message, component=component, details=details) + self.elapsed = elapsed + self.timeout = timeout + self.partial_traces = partial_traces or {} + + # ============================================================================= # Convenience functions for tool implementers # ============================================================================= diff --git a/maseval/core/queue.py b/maseval/core/queue.py new file mode 100644 index 0000000..2c415e2 --- /dev/null +++ b/maseval/core/queue.py @@ -0,0 +1,221 @@ +"""Task queue abstraction for flexible task scheduling. + +This module provides the TaskQueue abstract base class and concrete implementations +for different task scheduling strategies. The queue abstraction replaces the static +`for task in tasks` loop with a dynamic scheduling system that enables: + +1. Dynamic task ordering +2. Callback-driven scheduling (adaptive testing) +3. Priority-based execution +4. Conditional task skipping +""" + +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterator, List, Optional, Tuple + +from .task import Task, TaskCollection + + +class TaskQueue(ABC): + """Abstract base for task scheduling strategies. + + TaskQueue provides an iterator interface for task execution with hooks + for adaptive behavior based on task results. Concrete implementations + can reorder tasks, skip tasks, or terminate early based on execution + outcomes. + + The queue yields (Task, agent_data) tuples for execution. After each + task completes, `on_task_complete()` is called with the result, allowing + the queue to adapt its scheduling strategy. + + Usage: + queue = SequentialQueue(tasks, agent_data_list) + + for task, agent_data in queue: + report = execute_task(task, agent_data) + queue.on_task_complete(task, report) + + if not queue.should_continue(): + break + """ + + @abstractmethod + def __iter__(self) -> Iterator[Tuple[Task, Dict[str, Any]]]: + """Yield (task, agent_data) pairs in execution order. + + Returns: + Iterator yielding tuples of (Task, agent_data dict). + """ + pass + + def on_task_complete(self, task: Task, report: Dict[str, Any]) -> None: + """Called after each task completes. + + Override this method for adaptive scheduling behavior that responds + to task execution results (e.g., updating ability estimates, adjusting + priorities, or marking related tasks for skipping). + + Args: + task: The task that just completed. + report: The execution report containing status, traces, and eval results. + """ + pass + + def should_continue(self) -> bool: + """Whether to continue processing tasks. + + Default implementation returns True until the queue is exhausted. + Override for early termination conditions (e.g., confidence threshold + reached, maximum tasks processed, or error limit exceeded). + + Returns: + True to continue processing, False to stop. + """ + return True + + +class SequentialQueue(TaskQueue): + """Execute tasks in their original order (default behavior). + + This queue maintains the current sequential execution model, processing + tasks in the order they appear in the task collection. It's the default + queue used when no explicit queue is provided. + + Attributes: + tasks: List of (Task, agent_data) pairs. + """ + + def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict[str, Any]]): + """Initialize sequential queue. + + Args: + tasks: Collection of tasks to execute. + agent_data_list: List of agent configuration dicts, one per task. + """ + self._tasks: List[Tuple[Task, Dict[str, Any]]] = list(zip(tasks, agent_data_list)) + self._index = 0 + + def __iter__(self) -> Iterator[Tuple[Task, Dict[str, Any]]]: + """Yield tasks in original order.""" + for task, agent_data in self._tasks: + yield task, agent_data + + +class PriorityQueue(TaskQueue): + """Execute tasks by priority (from TaskProtocol.priority). + + Tasks with higher priority values are executed first. Tasks with equal + priority maintain their relative order from the original collection. + + This queue is useful when some tasks are more important or time-sensitive + than others, or when you want to process easier tasks first to get quick + feedback. + + Attributes: + tasks: List of (Task, agent_data) pairs sorted by priority. + """ + + def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict[str, Any]]): + """Initialize priority queue. + + Args: + tasks: Collection of tasks to execute. + agent_data_list: List of agent configuration dicts, one per task. + """ + paired = list(zip(tasks, agent_data_list)) + # Sort by priority descending (higher priority first) + # Use enumerate to maintain stable sort for equal priorities + self._tasks: List[Tuple[Task, Dict[str, Any]]] = sorted(paired, key=lambda x: x[0].protocol.priority, reverse=True) + + def __iter__(self) -> Iterator[Tuple[Task, Dict[str, Any]]]: + """Yield tasks in priority order.""" + for task, agent_data in self._tasks: + yield task, agent_data + + +class AdaptiveQueue(TaskQueue): + """Base class for adaptive task scheduling. + + Adaptive queues adjust task order based on execution results. This is + useful for techniques like Item Response Theory (IRT) based testing, + where task selection optimizes for information gain about agent ability. + + Subclasses should override `_select_next_task()` to implement their + selection algorithm, and `_update_state()` to update internal state + after each task completion. + + Attributes: + pending: Tasks not yet executed. + completed: Tasks that have been executed with their reports. + """ + + def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict[str, Any]]): + """Initialize adaptive queue. + + Args: + tasks: Collection of tasks to execute. + agent_data_list: List of agent configuration dicts, one per task. + """ + self._pending: List[Tuple[Task, Dict[str, Any]]] = list(zip(tasks, agent_data_list)) + self._completed: List[Tuple[Task, Dict[str, Any]]] = [] + self._stop_flag = False + + def __iter__(self) -> Iterator[Tuple[Task, Dict[str, Any]]]: + """Yield tasks selected by the adaptive algorithm.""" + while self._pending and not self._stop_flag: + next_item = self._select_next_task() + if next_item is not None: + yield next_item + else: + break + + def on_task_complete(self, task: Task, report: Dict[str, Any]) -> None: + """Update state based on task result. + + Args: + task: The task that just completed. + report: The execution report. + """ + # Find and move task from pending to completed + for i, (t, agent_data) in enumerate(self._pending): + if t.id == task.id: + self._completed.append(self._pending.pop(i)) + break + + # Update adaptive state + self._update_state(task, report) + + def should_continue(self) -> bool: + """Check if we should continue based on stopping criteria.""" + return not self._stop_flag and len(self._pending) > 0 + + def stop(self) -> None: + """Signal that no more tasks should be processed.""" + self._stop_flag = True + + def _select_next_task(self) -> Optional[Tuple[Task, Dict[str, Any]]]: + """Select the next task to execute. + + Override this method to implement custom selection algorithms + (e.g., IRT-based selection, uncertainty sampling, etc.). + + Default implementation returns tasks in order (first remaining task). + + Returns: + The next (Task, agent_data) pair, or None if no suitable task. + """ + if not self._pending: + return None + return self._pending[0] + + def _update_state(self, task: Task, report: Dict[str, Any]) -> None: + """Update internal state after task completion. + + Override this method to update ability estimates, difficulty models, + or other state used by `_select_next_task()`. + + Args: + task: The task that just completed. + report: The execution report containing status and eval results. + """ + pass diff --git a/maseval/core/registry.py b/maseval/core/registry.py new file mode 100644 index 0000000..ec4e38a --- /dev/null +++ b/maseval/core/registry.py @@ -0,0 +1,243 @@ +"""Thread-safe component registry for task execution. + +This module provides the ComponentRegistry class that tracks components +(agents, models, tools, etc.) during task execution. It uses thread-local +storage to enable parallel task execution without cross-contamination. +""" + +import threading +from typing import Any, Dict, Optional +from datetime import datetime + +from .tracing import TraceableMixin +from .config import ConfigurableMixin + + +class ComponentRegistry: + """Thread-safe registry for tracking components during task execution. + + Each thread gets its own isolated registry state, enabling parallel + task execution without cross-contamination. The registry tracks both + Traceable and Configurable components for comprehensive data collection. + + Usage: + registry = ComponentRegistry() + + # Register components (thread-local) + registry.register("agents", "orchestrator", agent_adapter) + registry.register("environment", "env", environment) + + # Collect data + traces = registry.collect_traces() + configs = registry.collect_configs() + + # Clear for next task + registry.clear() + """ + + def __init__(self, benchmark_config: Optional[Dict[str, Any]] = None): + """Initialize the registry. + + Args: + benchmark_config: Benchmark-level configuration to include in + collect_configs() output. This is shared (not thread-local). + """ + self._local = threading.local() + self._benchmark_config = benchmark_config or {} + + # --- Thread-local state properties --- + + @property + def _trace_registry(self) -> Dict[str, TraceableMixin]: + if not hasattr(self._local, "trace_registry"): + self._local.trace_registry = {} + return self._local.trace_registry + + @property + def _component_id_map(self) -> Dict[int, str]: + if not hasattr(self._local, "component_id_map"): + self._local.component_id_map = {} + return self._local.component_id_map + + @property + def _config_registry(self) -> Dict[str, ConfigurableMixin]: + if not hasattr(self._local, "config_registry"): + self._local.config_registry = {} + return self._local.config_registry + + @property + def _config_component_id_map(self) -> Dict[int, str]: + if not hasattr(self._local, "config_component_id_map"): + self._local.config_component_id_map = {} + return self._local.config_component_id_map + + # --- Public API --- + + def register(self, category: str, name: str, component: TraceableMixin) -> TraceableMixin: + """Register a component for trace and config collection. + + Args: + category: Component category (e.g., "agents", "models", "environment") + name: Unique identifier within the category + component: Component instance (must be TraceableMixin) + + Returns: + The component (for chaining) + + Raises: + ValueError: If component already registered under a different key + """ + component_id = id(component) + key = f"{category}:{name}" + + # Check for duplicate registration under different key + if component_id in self._component_id_map: + existing_key = self._component_id_map[component_id] + if existing_key != key: + raise ValueError( + f"Component is already registered as '{existing_key}' and cannot be " + f"re-registered as '{key}'. Note: Environments, users, and agents " + f"returned from setup methods are automatically registered." + ) + return component # Idempotent + + # Register for tracing + self._trace_registry[key] = component + self._component_id_map[component_id] = key + + # Also register for config if supported + if isinstance(component, ConfigurableMixin): + self._config_registry[key] = component + self._config_component_id_map[component_id] = key + + return component + + def clear(self) -> None: + """Clear all registrations for the current thread.""" + self._trace_registry.clear() + self._component_id_map.clear() + self._config_registry.clear() + self._config_component_id_map.clear() + + def collect_traces(self) -> Dict[str, Any]: + """Collect execution traces from all registered components.""" + traces: Dict[str, Any] = { + "metadata": { + "timestamp": datetime.now().isoformat(), + "thread_id": threading.current_thread().ident, + "total_components": len(self._trace_registry), + }, + "agents": {}, + "models": {}, + "tools": {}, + "simulators": {}, + "callbacks": {}, + "environment": None, + "user": None, + "other": {}, + } + + for key, component in self._trace_registry.items(): + category, comp_name = key.split(":", 1) + + try: + component_traces = component.gather_traces() + + # Inject name from registry if component doesn't have it + if "name" not in component_traces: + component_traces["name"] = comp_name + + # Handle environment and user as direct values (not nested in dict) + if category == "environment": + traces["environment"] = component_traces + elif category == "user": + traces["user"] = component_traces + else: + # Ensure category exists in traces + if category not in traces: + traces[category] = {} + traces[category][comp_name] = component_traces + except Exception as e: + # Gracefully handle tracing errors + error_info = { + "error": f"Failed to gather traces: {e}", + "error_type": type(e).__name__, + "component_type": type(component).__name__, + } + + if category == "environment": + traces["environment"] = error_info + elif category == "user": + traces["user"] = error_info + else: + if category not in traces: + traces[category] = {} + traces[category][comp_name] = error_info + + return traces + + def collect_configs(self) -> Dict[str, Any]: + """Collect configuration from all registered components.""" + configs: Dict[str, Any] = { + "metadata": { + "timestamp": datetime.now().isoformat(), + "thread_id": threading.current_thread().ident, + "total_components": len(self._config_registry), + }, + "agents": {}, + "models": {}, + "tools": {}, + "simulators": {}, + "callbacks": {}, + "environment": None, + "user": None, + "other": {}, + "benchmark": self._benchmark_config, + } + + for key, component in self._config_registry.items(): + category, comp_name = key.split(":", 1) + + try: + component_config = component.gather_config() + + # Inject name from registry if component doesn't have it + if "name" not in component_config: + component_config["name"] = comp_name + + # Handle environment and user as direct values (not nested in dict) + if category == "environment": + configs["environment"] = component_config + elif category == "user": + configs["user"] = component_config + else: + # Ensure category exists in configs + if category not in configs: + configs[category] = {} + configs[category][comp_name] = component_config + except Exception as e: + # Gracefully handle config gathering errors + error_info = { + "error": f"Failed to gather config: {e}", + "error_type": type(e).__name__, + "component_type": type(component).__name__, + } + + if category == "environment": + configs["environment"] = error_info + elif category == "user": + configs["user"] = error_info + else: + if category not in configs: + configs[category] = {} + configs[category][comp_name] = error_info + + return configs + + def update_benchmark_config(self, benchmark_config: Dict[str, Any]) -> None: + """Update the benchmark-level configuration. + + Args: + benchmark_config: New benchmark configuration dict. + """ + self._benchmark_config = benchmark_config diff --git a/maseval/core/task.py b/maseval/core/task.py index 462bbe5..31584aa 100644 --- a/maseval/core/task.py +++ b/maseval/core/task.py @@ -5,6 +5,38 @@ from typing import Iterable, List, Union, Iterator, Optional import json from pathlib import Path +from enum import Enum + + +class TimeoutAction(Enum): + """Action to take when a task timeout occurs.""" + + SKIP = "skip" # Mark as timed out, continue to next task + RETRY = "retry" # Retry once with same timeout + EXTEND = "extend" # Double timeout and retry + + +@dataclass +class TaskProtocol: + """Configuration for how MASEval executes a task. + + This is a data container for execution parameters, separate from + task content (query, environment_data, etc.). It controls the + interface between the task and MASEval's execution engine. + + Attributes: + timeout_seconds: Maximum execution time for this task. None means no timeout. + timeout_action: Action to take when timeout occurs. + max_retries: Maximum retry attempts for transient failures (not timeouts). + priority: Execution priority (higher = sooner). Used by adaptive task queues. + tags: Arbitrary tags for filtering or grouping tasks. + """ + + timeout_seconds: Optional[float] = None + timeout_action: TimeoutAction = TimeoutAction.SKIP + max_retries: int = 0 + priority: int = 0 + tags: Dict[str, Any] = field(default_factory=dict) @dataclass @@ -17,6 +49,7 @@ class Task: environment_data: A dictionary of data needed to set up the environment for the task. evaluation_data: A dictionary of data needed to evaluate the agent's performance on the task. metadata: A dictionary for any additional metadata about the task. + protocol: Execution protocol controlling timeout, retries, priority, etc. """ query: str @@ -25,6 +58,7 @@ class Task: user_data: Dict[str, Any] = field(default_factory=dict) evaluation_data: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict) + protocol: TaskProtocol = field(default_factory=TaskProtocol) class TaskCollection(Sequence): diff --git a/tests/test_core/test_benchmark/test_automatic_registration.py b/tests/test_core/test_benchmark/test_automatic_registration.py index e1c55aa..3f96a77 100644 --- a/tests/test_core/test_benchmark/test_automatic_registration.py +++ b/tests/test_core/test_benchmark/test_automatic_registration.py @@ -26,7 +26,7 @@ def test_automatic_agent_registration(): benchmark = DummyBenchmark(agent_data=agent_data) # Before run, registry should be empty - assert len(benchmark._trace_registry) == 0 + assert len(benchmark._registry._trace_registry) == 0 # Run one step to trigger setup for task, agent_data in zip(tasks, [agent_data]): @@ -46,8 +46,8 @@ def test_automatic_agent_registration(): break # Only test first task # Check that components were registered - assert "environment:env" in benchmark._trace_registry - assert "agents:test_agent" in benchmark._trace_registry + assert "environment:env" in benchmark._registry._trace_registry + assert "agents:test_agent" in benchmark._registry._trace_registry @pytest.mark.core @@ -117,8 +117,8 @@ def test_manual_registration_for_models(): benchmark.register("models", "my_model", model) # Verify it was registered - assert "models:my_model" in benchmark._trace_registry - assert benchmark._trace_registry["models:my_model"] is model + assert "models:my_model" in benchmark._registry._trace_registry + assert benchmark._registry._trace_registry["models:my_model"] is model @pytest.mark.core @@ -137,8 +137,8 @@ def test_component_id_tracking(): benchmark.register("models", "test_model", model) # Verify ID tracking - assert id(model) in benchmark._component_id_map - assert benchmark._component_id_map[id(model)] == "models:test_model" + assert id(model) in benchmark._registry._component_id_map + assert benchmark._registry._component_id_map[id(model)] == "models:test_model" @pytest.mark.core @@ -162,8 +162,8 @@ def test_registry_cleared_after_repetition(): benchmark.run(tasks) # After run completes, registry should be empty (cleared after last repetition) - assert len(benchmark._trace_registry) == 0 - assert len(benchmark._component_id_map) == 0 + assert len(benchmark._registry._trace_registry) == 0 + assert len(benchmark._registry._component_id_map) == 0 # But reports should contain entries for all task repetitions assert len(benchmark.reports) == 4 # 2 tasks * 2 repeats diff --git a/tests/test_core/test_benchmark/test_benchmark_lifecycle.py b/tests/test_core/test_benchmark/test_benchmark_lifecycle.py index 8a15fca..e35a96d 100644 --- a/tests/test_core/test_benchmark/test_benchmark_lifecycle.py +++ b/tests/test_core/test_benchmark/test_benchmark_lifecycle.py @@ -171,7 +171,7 @@ def on_task_repeat_start(self, benchmark, task, repeat_idx): # At start of new repeat, registry should be empty (except for callbacks) if repeat_idx > 0: # After first repeat, registry should have been cleared - registry_sizes.append(len(benchmark._trace_registry)) + registry_sizes.append(len(benchmark._registry._trace_registry)) benchmark = DummyBenchmark( agent_data={"model": "test"}, @@ -191,14 +191,14 @@ def test_benchmark_registry_cleared_after_task(self): benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=1) # Before run, registry should be empty - assert len(benchmark._trace_registry) == 0 - assert len(benchmark._config_registry) == 0 + assert len(benchmark._registry._trace_registry) == 0 + assert len(benchmark._registry._config_registry) == 0 benchmark.run(tasks) # After run completes, registry should be cleared - assert len(benchmark._trace_registry) == 0 - assert len(benchmark._config_registry) == 0 + assert len(benchmark._registry._trace_registry) == 0 + assert len(benchmark._registry._config_registry) == 0 def test_benchmark_reports_structure(self): """Test that benchmark reports have the correct structure.""" From 233d1a675c462ddff78c082638517bdd98484832 Mon Sep 17 00:00:00 2001 From: cemde Date: Fri, 5 Dec 2025 19:55:24 +0000 Subject: [PATCH 03/25] added test plan --- SUMMARY.md | 103 +++++++++++++++ TEST_PLAN.md | 359 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 462 insertions(+) create mode 100644 SUMMARY.md create mode 100644 TEST_PLAN.md diff --git a/SUMMARY.md b/SUMMARY.md new file mode 100644 index 0000000..8799a7a --- /dev/null +++ b/SUMMARY.md @@ -0,0 +1,103 @@ +# Implementation Summary: Parallel Task Execution Engine + +This document summarizes the implementation of parallel task execution, timeout handling, and task queue abstraction for MASEval, as specified in PLAN.md. + +## Phase 0: ComponentRegistry Extraction + +**New File:** `maseval/core/registry.py` + +- Created `ComponentRegistry` class that manages component registration for tracing and configuration collection +- Uses `threading.local()` for thread-local storage, enabling parallel task execution without cross-contamination between threads +- Each thread gets isolated registry state: `_trace_registry`, `_component_id_map`, `_config_registry`, `_config_component_id_map` +- Methods: `register()`, `clear()`, `collect_traces()`, `collect_configs()` +- Refactored `Benchmark` class to delegate all registry operations to `self._registry: ComponentRegistry` + +## Phase 1: TaskProtocol & Timeout Infrastructure + +**Modified:** `maseval/core/task.py` + +- Added `TimeoutAction` enum with values `SKIP`, `RETRY`, `RAISE` for configurable timeout behavior +- Added `TaskProtocol` dataclass with fields: + - `timeout_seconds: Optional[float]` - per-task timeout limit + - `timeout_action: TimeoutAction` - what to do on timeout (default: SKIP) + - `max_retries: int` - retry count for failed tasks (default: 0) + - `priority: int` - scheduling priority (default: 0, higher = more important) + - `tags: Dict[str, Any]` - arbitrary metadata for filtering/grouping +- Added `protocol: TaskProtocol` field to `Task` dataclass + +**New File:** `maseval/core/context.py` + +- Created `TaskContext` class for cooperative timeout checking +- Properties: `elapsed` (time since start), `remaining` (time until deadline), `is_expired` (bool) +- Method: `check_timeout()` raises `TaskTimeoutError` if deadline exceeded +- Designed for checkpoint-based timeout checking in user code + +**Modified:** `maseval/core/exceptions.py` + +- Added `TaskTimeoutError(MASEvalError)` with attributes: + - `elapsed: float` - how long the task ran + - `timeout: float` - the configured timeout limit + - `partial_traces: Optional[Dict]` - any traces collected before timeout + +**Modified:** `maseval/core/benchmark.py` + +- Added `TASK_TIMEOUT` to `TaskExecutionStatus` enum + +## Phase 2: TaskQueue Abstraction + +**New File:** `maseval/core/queue.py` + +- Created `TaskQueue` abstract base class with iterator interface (`__iter__`, `__next__`) +- Supports both `Task` and `TaskCollection` inputs, with automatic expansion +- Handles `n_task_repeats` by yielding `(task, repeat_idx)` tuples + +**Implementations:** + +1. `SequentialQueue` - Simple FIFO ordering, iterates tasks in input order +2. `PriorityQueue` - Uses `TaskProtocol.priority` for scheduling (higher priority first) +3. `AdaptiveQueue` - Placeholder for future feedback-based scheduling (currently falls back to sequential) + +## Phase 3: Parallel Execution + +**Modified:** `maseval/core/benchmark.py` + +- Added `max_workers: int = 1` parameter to `Benchmark.run()` for controlling parallelism +- Added `queue: Optional[TaskQueue] = None` parameter for custom scheduling (defaults to `SequentialQueue`) +- Added thread-safety mechanisms: + - `self._reports_lock: threading.Lock` for safe report collection from multiple threads + - `self._callback_lock: threading.Lock` for serialized callback invocation +- New methods: + - `_invoke_callbacks(method_name, *args, **kwargs)` - thread-safe callback invocation + - `_append_report_safe(report)` - thread-safe report collection + - `_execute_task_repetition(task, repeat_idx, context)` - single task execution with timeout support + - `_run_sequential(queue)` - sequential execution (backward compatible) + - `_run_parallel(queue, max_workers)` - parallel execution using `ThreadPoolExecutor` + +**Backward Compatibility:** + +- `max_workers=1` (default) uses `_run_sequential()`, preserving existing behavior +- `max_workers>1` uses `_run_parallel()` with thread pool + +## Phase 4: AdaptiveQueue (Placeholder) + +- `AdaptiveQueue` class created as placeholder for collaborator implementation +- Intended for feedback-based scheduling that reorders remaining tasks based on execution results +- Currently falls back to sequential iteration + +## Updated Exports + +**Modified:** `maseval/__init__.py` + +New public exports: + +- `TaskProtocol`, `TimeoutAction` - task execution configuration +- `ComponentRegistry` - thread-safe component registration +- `TaskContext` - timeout checking context +- `TaskQueue`, `SequentialQueue`, `PriorityQueue`, `AdaptiveQueue` - scheduling abstractions +- `TaskTimeoutError` - timeout exception + +## Test Updates + +- Updated 2 test files that accessed internal registry attributes (`_trace_registry`, `_component_id_map`, `_config_registry`) +- Changed to access through `benchmark._registry._trace_registry` pattern +- All 666 tests pass diff --git a/TEST_PLAN.md b/TEST_PLAN.md new file mode 100644 index 0000000..f4a1ddd --- /dev/null +++ b/TEST_PLAN.md @@ -0,0 +1,359 @@ +# Test Plan: Parallel Task Execution Engine + +This document outlines the testing strategy for the parallel execution implementation. It covers new tests to add, existing tests to adapt, and tests that can be removed. + +--- + +## 1. New Tests to Add + +### 1.1 ComponentRegistry Tests (`tests/test_core/test_registry.py`) + +**Thread Safety Tests:** + +- `test_registry_thread_isolation` - Verify that registrations in one thread don't appear in another thread +- `test_registry_concurrent_registration` - Multiple threads registering components simultaneously without data races +- `test_registry_concurrent_collect_traces` - Multiple threads calling `collect_traces()` simultaneously +- `test_registry_concurrent_collect_configs` - Multiple threads calling `collect_configs()` simultaneously +- `test_registry_clear_only_affects_current_thread` - Calling `clear()` in one thread doesn't affect other threads + +**Basic Functionality Tests:** + +- `test_registry_register_traceable_component` - Component registered for tracing +- `test_registry_register_configurable_component` - Component also registered in config registry +- `test_registry_duplicate_key_idempotent` - Same component, same key is idempotent +- `test_registry_duplicate_component_different_key_raises` - Same component, different key raises ValueError +- `test_registry_collect_traces_structure` - Verify trace output structure +- `test_registry_collect_configs_structure` - Verify config output structure +- `test_registry_benchmark_config_included` - benchmark_config passed to constructor appears in configs + +### 1.2 TaskContext Tests (`tests/test_core/test_context.py`) + +**Timeout Behavior Tests:** + +- `test_context_no_timeout` - Context without deadline never expires +- `test_context_with_timeout_not_expired` - Context before deadline shows remaining time +- `test_context_with_timeout_expired` - Context after deadline shows is_expired=True +- `test_context_check_timeout_raises_on_expiry` - `check_timeout()` raises TaskTimeoutError when expired +- `test_context_check_timeout_with_partial_traces` - TaskTimeoutError includes partial traces +- `test_context_elapsed_increases` - `elapsed` property increases over time +- `test_context_remaining_decreases` - `remaining` property decreases over time + +### 1.3 TaskQueue Tests (`tests/test_core/test_queue.py`) + +**SequentialQueue Tests:** + +- `test_sequential_queue_order_preserved` - Tasks yielded in original order +- `test_sequential_queue_iteration_complete` - All tasks yielded exactly once +- `test_sequential_queue_empty_collection` - Empty collection yields nothing +- `test_sequential_queue_single_task` - Single task handled correctly + +**PriorityQueue Tests:** + +- `test_priority_queue_high_priority_first` - Higher priority tasks come first +- `test_priority_queue_stable_sort` - Equal priority maintains original order +- `test_priority_queue_default_priority_zero` - Tasks without explicit priority treated as 0 +- `test_priority_queue_negative_priority` - Negative priorities handled correctly + +**AdaptiveQueue Tests:** + +- `test_adaptive_queue_on_task_complete_updates_state` - Completed tasks moved to completed list +- `test_adaptive_queue_stop_terminates_iteration` - Calling `stop()` ends iteration early +- `test_adaptive_queue_should_continue_false_after_stop` - `should_continue()` returns False after stop + +### 1.4 TaskProtocol Tests (`tests/test_core/test_task_protocol.py`) + +- `test_task_protocol_defaults` - Default values: timeout=None, action=SKIP, retries=0, priority=0, tags={} +- `test_task_has_protocol_field` - Task dataclass has protocol field +- `test_task_protocol_custom_values` - Custom protocol values preserved +- `test_timeout_action_enum_values` - TimeoutAction has SKIP, RETRY, RAISE + +### 1.5 Parallel Execution Tests (`tests/test_core/test_benchmark/test_parallel_execution.py`) + +**Basic Parallel Execution:** + +- `test_parallel_execution_basic` - `max_workers>1` runs tasks in parallel +- `test_parallel_execution_same_results_as_sequential` - Parallel produces same reports as sequential +- `test_parallel_execution_max_workers_respected` - No more than max_workers concurrent threads +- `test_parallel_execution_single_worker_uses_sequential` - `max_workers=1` uses `_run_sequential` + +**Thread Safety - Report Collection:** + +- `test_parallel_reports_thread_safe` - Reports from parallel tasks all collected correctly +- `test_parallel_report_count_matches_task_count` - Number of reports equals tasks × repeats +- `test_parallel_report_order_independent` - Report content correct regardless of completion order + +**Thread Safety - Callbacks:** + +- `test_parallel_callbacks_serialized` - Callbacks invoked with lock (no concurrent callback execution) +- `test_parallel_callback_data_integrity` - Callback receives correct task/report data +- `test_parallel_callbacks_all_events_fire` - All lifecycle callbacks fire for each task +- `test_parallel_callback_exception_isolated` - Exception in one callback doesn't affect other tasks + +**Thread Safety - Registry:** + +- `test_parallel_registry_isolation` - Each task gets isolated registry state +- `test_parallel_traces_not_cross_contaminated` - Traces from task A don't appear in task B's report +- `test_parallel_configs_not_cross_contaminated` - Configs from task A don't appear in task B's report + +**Race Condition Tests:** + +- `test_parallel_concurrent_setup` - Multiple tasks calling setup methods simultaneously +- `test_parallel_concurrent_evaluation` - Multiple tasks being evaluated simultaneously +- `test_parallel_slow_fast_task_ordering` - Slow task in worker doesn't block fast task reports +- `test_parallel_error_in_one_task_doesnt_affect_others` - One task failing doesn't corrupt other tasks + +### 1.6 Timeout Handling Tests (`tests/test_core/test_benchmark/test_timeout_handling.py`) + +- `test_timeout_task_marked_as_timeout_status` - Timed out task has `TASK_TIMEOUT` status +- `test_timeout_partial_traces_collected` - Traces collected up to timeout point included in report +- `test_timeout_action_skip_continues_to_next` - SKIP action moves to next task +- `test_timeout_action_retry_retries_task` - RETRY action re-executes (up to max_retries) +- `test_timeout_action_raise_propagates` - RAISE action raises TaskTimeoutError +- `test_timeout_cooperative_checkpoint` - Tasks checking `context.check_timeout()` respect timeout + +### 1.7 Queue Integration Tests (`tests/test_core/test_benchmark/test_queue_integration.py`) + +- `test_run_with_custom_queue` - `benchmark.run(tasks, queue=custom_queue)` uses provided queue +- `test_run_default_queue_is_sequential` - No queue specified uses SequentialQueue +- `test_priority_queue_integration` - PriorityQueue orders execution correctly in real benchmark +- `test_queue_on_task_complete_called` - Queue's `on_task_complete` called after each task +- `test_queue_should_continue_checked` - Queue's `should_continue` checked after each task + +### 1.8 TaskTimeoutError Tests (`tests/test_core/test_exceptions.py` - extend existing) + +- `test_task_timeout_error_attributes` - Has elapsed, timeout, partial_traces attributes +- `test_task_timeout_error_message` - Message includes timeout and elapsed time +- `test_task_timeout_error_is_maseval_error` - Inherits from MASEvalError + +--- + +## 2. Existing Tests to Adapt + +### 2.1 Tests Already Adapted (completed) + +- `tests/test_core/test_benchmark/test_automatic_registration.py` + - Changed `benchmark._trace_registry` → `benchmark._registry._trace_registry` + - Changed `benchmark._component_id_map` → `benchmark._registry._component_id_map` +- `tests/test_core/test_benchmark/test_benchmark_lifecycle.py` + - Changed `benchmark._trace_registry` → `benchmark._registry._trace_registry` + - Changed `benchmark._config_registry` → `benchmark._registry._config_registry` + +### 2.2 Tests That May Need Adaptation + +**Callback Tests (`test_callback_orchestration.py`):** + +- Review `test_callback_errors_dont_break_execution` - Ensure behavior consistent with parallel mode +- Consider adding parallel variant of each callback order test + +**Lifecycle Tests (`test_benchmark_lifecycle.py`):** + +- `test_benchmark_lifecycle_hooks_order` - Verify order still guaranteed in sequential mode +- Add note/variant about callback order in parallel mode (order within task preserved, between tasks not) + +**Exception Tests (`test_exceptions.py`):** + +- Extend classification tests to include `TaskTimeoutError` → `TASK_TIMEOUT` mapping + +**Config Collection Tests (`test_config_collection.py`):** + +- Verify config collection works correctly in parallel mode +- `test_config_different_per_repetition` - May need thread-awareness verification + +--- + +## 3. Tests That Can Be Removed + +### 3.1 No Tests to Remove + +The implementation maintains backward compatibility (`max_workers=1` default), so all existing tests remain valid. No tests are obsoleted by this change. + +### 3.2 Tests That Could Be Consolidated (Optional Cleanup) + +- Some registry-related tests in `test_automatic_registration.py` and `test_benchmark_lifecycle.py` overlap in testing registry clearing. Consider consolidating into a single registry test file. + +--- + +## 4. Test Categories and Markers + +### New Pytest Markers to Consider + +```python +# conftest.py additions +pytest.mark.parallel # Tests specific to parallel execution +pytest.mark.thread_safety # Tests for race conditions and thread safety +pytest.mark.timeout # Tests for timeout handling +pytest.mark.queue # Tests for task queue abstraction +``` + +### Marker Usage + +```python +@pytest.mark.core +@pytest.mark.parallel +def test_parallel_execution_basic(): + ... + +@pytest.mark.core +@pytest.mark.thread_safety +def test_parallel_registry_isolation(): + ... +``` + +--- + +## 5. Test Infrastructure Needs + +### 5.1 New Test Fixtures + +```python +# conftest.py additions + +@pytest.fixture +def slow_benchmark(): + """Benchmark that takes configurable time per task (for parallel testing).""" + class SlowBenchmark(DummyBenchmark): + def __init__(self, delay_seconds=0.1, **kwargs): + super().__init__(**kwargs) + self.delay = delay_seconds + + def run_agents(self, agents, task, environment, query): + import time + time.sleep(self.delay) + return super().run_agents(agents, task, environment, query) + + return SlowBenchmark + +@pytest.fixture +def thread_tracking_callback(): + """Callback that records which thread each event fires on.""" + import threading + + class ThreadTracker(BenchmarkCallback): + def __init__(self): + self.thread_ids = [] + + def on_task_repeat_start(self, benchmark, task, repeat_idx): + self.thread_ids.append(threading.current_thread().ident) + + return ThreadTracker +``` + +### 5.2 Helper Functions + +```python +def run_parallel_and_sequential(benchmark, tasks): + """Run same benchmark both ways and compare reports.""" + import copy + + seq_benchmark = copy.deepcopy(benchmark) + par_benchmark = copy.deepcopy(benchmark) + + seq_reports = seq_benchmark.run(tasks, max_workers=1) + par_reports = par_benchmark.run(tasks, max_workers=4) + + return seq_reports, par_reports + +def verify_no_cross_contamination(reports): + """Check that traces in each report only contain that task's data.""" + for report in reports: + task_id = report['task_id'] + for key, trace in report['traces'].get('agents', {}).items(): + # Verify trace belongs to this task + assert task_id in str(trace) or 'task_id' not in trace +``` + +--- + +## 6. Priority Order for Implementation + +### High Priority (Core Functionality) + +1. `test_parallel_execution.py` - Basic parallel execution verification +2. `test_registry.py` - Thread isolation is critical for correctness +3. `test_timeout_handling.py` - Timeout is a key new feature + +### Medium Priority (Integration) + +4. `test_queue.py` - Queue abstraction tests +5. `test_queue_integration.py` - Queue + Benchmark integration +6. `test_context.py` - TaskContext functionality + +### Lower Priority (Edge Cases) + +7. `test_task_protocol.py` - Simple dataclass tests +8. Extended race condition tests +9. Performance/stress tests + +--- + +## 7. Notes for Test Implementation + +### Thread Safety Testing Patterns + +```python +import threading +import time +from concurrent.futures import ThreadPoolExecutor + +def test_concurrent_operation(): + """Pattern for testing concurrent operations.""" + results = [] + errors = [] + barrier = threading.Barrier(4) # Synchronize thread start + + def worker(worker_id): + try: + barrier.wait() # All threads start together + # Perform operation + result = do_something() + results.append((worker_id, result)) + except Exception as e: + errors.append((worker_id, e)) + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(worker, i) for i in range(4)] + for f in futures: + f.result() # Wait for completion + + assert len(errors) == 0, f"Errors occurred: {errors}" + assert len(results) == 4 +``` + +### Timing Considerations + +- Use `time.sleep()` sparingly in tests +- Consider mocking time for deterministic timeout tests +- Use threading barriers for synchronization points +- Allow tolerance in timing assertions (e.g., ±10ms) + +### Isolation Verification + +```python +def test_registry_isolation(): + """Verify thread-local storage works correctly.""" + registry = ComponentRegistry() + results = {} + + def worker(worker_id): + # Each thread should see empty registry initially + assert len(registry._trace_registry) == 0 + + # Register unique component + registry.register("test", f"comp_{worker_id}", MockComponent()) + + # Only our component should be visible + assert len(registry._trace_registry) == 1 + assert f"test:comp_{worker_id}" in registry._trace_registry + + results[worker_id] = list(registry._trace_registry.keys()) + + # Run in parallel + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(worker, i) for i in range(4)] + for f in futures: + f.result() + + # Verify isolation + for worker_id, keys in results.items(): + assert keys == [f"test:comp_{worker_id}"] +``` From ca890e22e8b5fb2408600718fa86587d0c4acbfc Mon Sep 17 00:00:00 2001 From: cemde Date: Fri, 5 Dec 2025 21:06:17 +0000 Subject: [PATCH 04/25] implemented tests --- IDEAS.md | 1 + .../test_benchmark/test_parallel_execution.py | 411 ++++++++++++++++++ tests/test_core/test_context.py | 134 ++++++ tests/test_core/test_exceptions.py | 55 +++ tests/test_core/test_queue.py | 266 ++++++++++++ tests/test_core/test_registry.py | 258 +++++++++++ tests/test_core/test_task_protocol.py | 110 +++++ 7 files changed, 1235 insertions(+) create mode 100644 IDEAS.md create mode 100644 tests/test_core/test_benchmark/test_parallel_execution.py create mode 100644 tests/test_core/test_context.py create mode 100644 tests/test_core/test_queue.py create mode 100644 tests/test_core/test_registry.py create mode 100644 tests/test_core/test_task_protocol.py diff --git a/IDEAS.md b/IDEAS.md new file mode 100644 index 0000000..57d143f --- /dev/null +++ b/IDEAS.md @@ -0,0 +1 @@ +- Guide explaining that Dataset aprpox equal to Queue + Task Collection diff --git a/tests/test_core/test_benchmark/test_parallel_execution.py b/tests/test_core/test_benchmark/test_parallel_execution.py new file mode 100644 index 0000000..4c6e280 --- /dev/null +++ b/tests/test_core/test_benchmark/test_parallel_execution.py @@ -0,0 +1,411 @@ +"""Tests for parallel task execution in Benchmark. + +These tests verify that parallel execution with max_workers > 1 works correctly, +including thread safety, report collection, and callback serialization. +""" + +import pytest +import threading +import time +from typing import List, Tuple, Optional + +from maseval import ( + BenchmarkCallback, + Task, + TaskCollection, + TaskExecutionStatus, +) +from conftest import DummyBenchmark + + +# ==================== Test Fixtures ==================== + + +class SlowBenchmark(DummyBenchmark): + """Benchmark that introduces configurable delays per task.""" + + def __init__(self, delay_seconds: float = 0.05, **kwargs): + super().__init__(**kwargs) + self.delay = delay_seconds + self.execution_times: List[Tuple[str, float, float]] = [] # (task_id, start, end) + self._timing_lock = threading.Lock() + + def run_agents(self, agents, task, environment, query): + start = time.time() + time.sleep(self.delay) + result = super().run_agents(agents, task, environment, query) + end = time.time() + + with self._timing_lock: + self.execution_times.append((str(task.id), start, end)) + + return result + + +class ThreadTrackingCallback(BenchmarkCallback): + """Callback that records which thread each event fires on.""" + + def __init__(self): + self.thread_ids: List[Tuple[str, Optional[int]]] = [] + self._lock = threading.Lock() + + def on_task_repeat_start(self, benchmark, task, repeat_idx): + with self._lock: + self.thread_ids.append(("repeat_start", threading.current_thread().ident)) + + def on_task_repeat_end(self, benchmark, report): + with self._lock: + self.thread_ids.append(("repeat_end", threading.current_thread().ident)) + + +class OrderTrackingCallback(BenchmarkCallback): + """Callback that records the order of callback invocations.""" + + def __init__(self): + self.invocations: List[str] = [] + self._lock = threading.Lock() + + def on_run_start(self, benchmark): + with self._lock: + self.invocations.append("run_start") + + def on_task_start(self, benchmark, task): + with self._lock: + self.invocations.append(f"task_start:{task.query}") + + def on_task_repeat_start(self, benchmark, task, repeat_idx): + with self._lock: + self.invocations.append(f"repeat_start:{task.query}:{repeat_idx}") + + def on_task_repeat_end(self, benchmark, report): + with self._lock: + self.invocations.append(f"repeat_end:{report['task_id'][:8]}") + + def on_task_end(self, benchmark, task, result): + with self._lock: + self.invocations.append(f"task_end:{task.query}") + + def on_run_end(self, benchmark, results): + with self._lock: + self.invocations.append("run_end") + + +@pytest.fixture +def parallel_tasks(): + """Create tasks for parallel execution testing.""" + return TaskCollection.from_list([{"query": f"Task {i}", "environment_data": {"index": i}} for i in range(5)]) + + +# ==================== Basic Parallel Execution Tests ==================== + + +@pytest.mark.core +class TestParallelExecutionBasics: + """Tests for basic parallel execution functionality.""" + + def test_parallel_execution_completes(self, parallel_tasks): + """Verify parallel execution completes all tasks.""" + benchmark = DummyBenchmark(agent_data={"model": "test"}) + + reports = benchmark.run(parallel_tasks, max_workers=3) + + assert len(reports) == 5 + + def test_parallel_produces_same_report_count(self, parallel_tasks): + """Parallel and sequential should produce same number of reports.""" + benchmark_seq = DummyBenchmark(agent_data={"model": "test"}) + benchmark_par = DummyBenchmark(agent_data={"model": "test"}) + + reports_seq = benchmark_seq.run(parallel_tasks, max_workers=1) + reports_par = benchmark_par.run(parallel_tasks, max_workers=3) + + assert len(reports_seq) == len(reports_par) + + def test_parallel_reports_have_correct_structure(self, parallel_tasks): + """Verify parallel reports have expected fields.""" + benchmark = DummyBenchmark(agent_data={"model": "test"}) + + reports = benchmark.run(parallel_tasks, max_workers=2) + + for report in reports: + assert "task_id" in report + assert "repeat_idx" in report + assert "status" in report + assert "traces" in report + assert "config" in report + assert "eval" in report + + def test_single_worker_uses_sequential(self, parallel_tasks): + """max_workers=1 should behave identically to sequential.""" + callback = OrderTrackingCallback() + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[callback], + ) + + benchmark.run(parallel_tasks, max_workers=1) + + # Verify ordering is strictly sequential (task_start before all repeat_starts) + assert callback.invocations[0] == "run_start" + assert callback.invocations[1] == "task_start:Task 0" + assert callback.invocations[-1] == "run_end" + + def test_parallel_with_repetitions(self): + """Verify parallel execution with n_task_repeats > 1.""" + tasks = TaskCollection.from_list( + [ + {"query": "T1", "environment_data": {}}, + {"query": "T2", "environment_data": {}}, + ] + ) + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + n_task_repeats=3, + ) + + reports = benchmark.run(tasks, max_workers=2) + + assert len(reports) == 6 # 2 tasks × 3 repeats + + # Verify repeat indices + repeat_indices = [r["repeat_idx"] for r in reports] + assert set(repeat_indices) == {0, 1, 2} + + +# ==================== Thread Safety Tests ==================== + + +@pytest.mark.core +class TestParallelThreadSafety: + """Tests for thread safety in parallel execution.""" + + def test_reports_all_collected(self, parallel_tasks): + """All reports should be collected regardless of completion order.""" + benchmark = SlowBenchmark( + agent_data={"model": "test"}, + delay_seconds=0.02, + ) + + reports = benchmark.run(parallel_tasks, max_workers=4) + + assert len(reports) == 5 + task_ids = {r["task_id"] for r in reports} + assert len(task_ids) == 5 + + def test_traces_not_cross_contaminated(self, parallel_tasks): + """Traces from one task should not appear in another's report.""" + benchmark = DummyBenchmark(agent_data={"model": "test"}) + + reports = benchmark.run(parallel_tasks, max_workers=3) + + for report in reports: + # Each report should have its own traces + assert report["traces"] is not None + assert "metadata" in report["traces"] + + def test_callbacks_receive_correct_data(self): + """Callbacks should receive correct task/report data in parallel.""" + tasks = TaskCollection.from_list([{"query": f"Query_{i}", "environment_data": {"idx": i}} for i in range(3)]) + + received_data = [] + lock = threading.Lock() + + class DataCapturingCallback(BenchmarkCallback): + def on_task_repeat_end(self, benchmark, report): + with lock: + received_data.append( + { + "task_id": report["task_id"], + "status": report["status"], + } + ) + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[DataCapturingCallback()], + ) + + benchmark.run(tasks, max_workers=2) + + assert len(received_data) == 3 + statuses = {d["status"] for d in received_data} + assert statuses == {"success"} + + def test_callback_exception_propagates(self, parallel_tasks): + """Callback exceptions propagate (current behavior).""" + call_count = [0] + + class FailingCallback(BenchmarkCallback): + def on_task_repeat_end(self, benchmark, report): + call_count[0] += 1 + if call_count[0] == 2: + raise RuntimeError("Intentional failure") + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[FailingCallback()], + ) + + # Current behavior: callback exceptions propagate + with pytest.raises(RuntimeError, match="Intentional failure"): + benchmark.run(parallel_tasks, max_workers=2) + + +# ==================== Concurrency Verification Tests ==================== + + +@pytest.mark.core +class TestParallelConcurrency: + """Tests verifying actual concurrent execution.""" + + def test_parallel_faster_than_sequential(self): + """Parallel execution should be faster for I/O-bound tasks.""" + tasks = TaskCollection.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(4)]) + delay = 0.05 + + # Sequential timing + benchmark_seq = SlowBenchmark(agent_data={"model": "test"}, delay_seconds=delay) + start_seq = time.time() + benchmark_seq.run(tasks, max_workers=1) + time_seq = time.time() - start_seq + + # Parallel timing + benchmark_par = SlowBenchmark(agent_data={"model": "test"}, delay_seconds=delay) + start_par = time.time() + benchmark_par.run(tasks, max_workers=4) + time_par = time.time() - start_par + + # Parallel should be significantly faster (at least 2x) + assert time_par < time_seq * 0.7 + + def test_execution_overlaps(self): + """Task executions should overlap in parallel mode.""" + tasks = TaskCollection.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(3)]) + + benchmark = SlowBenchmark( + agent_data={"model": "test"}, + delay_seconds=0.05, + ) + + benchmark.run(tasks, max_workers=3) + + # Check for overlapping execution times + times = benchmark.execution_times + assert len(times) == 3 + + # At least one pair should overlap + overlaps = 0 + for i in range(len(times)): + for j in range(i + 1, len(times)): + _, start_i, end_i = times[i] + _, start_j, end_j = times[j] + # Check if intervals overlap + if start_i < end_j and start_j < end_i: + overlaps += 1 + + assert overlaps > 0, "Expected overlapping execution in parallel mode" + + +# ==================== Error Handling Tests ==================== + + +@pytest.mark.core +class TestParallelErrorHandling: + """Tests for error handling in parallel execution.""" + + def test_error_in_one_task_doesnt_stop_others(self): + """One task failing should not prevent other tasks from completing.""" + + class FailingBenchmark(DummyBenchmark): + def run_agents(self, agents, task, environment, query): + if "fail" in query.lower(): + raise RuntimeError("Intentional failure") + return super().run_agents(agents, task, environment, query) + + tasks = TaskCollection.from_list( + [ + {"query": "Normal 1", "environment_data": {}}, + {"query": "FAIL task", "environment_data": {}}, + {"query": "Normal 2", "environment_data": {}}, + ] + ) + + benchmark = FailingBenchmark(agent_data={"model": "test"}) + reports = benchmark.run(tasks, max_workers=2) + + assert len(reports) == 3 + + statuses = {r["status"] for r in reports} + assert TaskExecutionStatus.SUCCESS.value in statuses + assert TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR.value in statuses + + def test_all_tasks_get_reports_even_with_failures(self): + """Every task should produce a report even if some fail.""" + + class HalfFailingBenchmark(DummyBenchmark): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._call_count = 0 + self._lock = threading.Lock() + + def run_agents(self, agents, task, environment, query): + with self._lock: + self._call_count += 1 + should_fail = self._call_count % 2 == 0 + + if should_fail: + raise ValueError("Every other task fails") + return super().run_agents(agents, task, environment, query) + + tasks = TaskCollection.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(4)]) + + benchmark = HalfFailingBenchmark(agent_data={"model": "test"}) + reports = benchmark.run(tasks, max_workers=2) + + assert len(reports) == 4 + + +# ==================== Queue Integration Tests ==================== + + +@pytest.mark.core +class TestParallelQueueIntegration: + """Tests for queue integration with parallel execution.""" + + def test_custom_queue_respected(self, parallel_tasks): + """Custom queue ordering should be respected.""" + from maseval.core.queue import PriorityQueue + + # Create tasks with priorities + prioritized_tasks = TaskCollection( + [ + Task( + query=f"P{p}", + environment_data={}, + protocol=__import__("maseval.core.task", fromlist=["TaskProtocol"]).TaskProtocol(priority=p), + ) + for p in [1, 5, 3, 2, 4] + ] + ) + + agent_data_list = [{"model": "test"}] * 5 + queue = PriorityQueue(prioritized_tasks, agent_data_list) + + # Track execution order + execution_order = [] + lock = threading.Lock() + + class OrderTracker(BenchmarkCallback): + def on_task_repeat_start(self, benchmark, task, repeat_idx): + with lock: + execution_order.append(task.query) + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[OrderTracker()], + ) + + # With max_workers=1, order should be strictly by priority + benchmark.run(prioritized_tasks, queue=queue, max_workers=1) + + assert execution_order == ["P5", "P4", "P3", "P2", "P1"] diff --git a/tests/test_core/test_context.py b/tests/test_core/test_context.py new file mode 100644 index 0000000..d50e467 --- /dev/null +++ b/tests/test_core/test_context.py @@ -0,0 +1,134 @@ +"""Tests for TaskContext timeout handling. + +These tests verify that TaskContext correctly tracks time, checks deadlines, +and raises TaskTimeoutError when appropriate. +""" + +import pytest +import time + +from maseval.core.context import TaskContext +from maseval.core.exceptions import TaskTimeoutError + + +@pytest.mark.core +class TestTaskContextBasics: + """Tests for basic TaskContext functionality.""" + + def test_context_no_timeout_never_expires(self): + """Context without deadline should never expire.""" + context = TaskContext(deadline=None) + + assert context.deadline is None + assert context.remaining is None + assert context.is_expired is False + + # check_timeout should be no-op + context.check_timeout() # Should not raise + + def test_context_with_timeout_not_expired(self): + """Context before deadline should show remaining time.""" + context = TaskContext(deadline=10.0) + + assert context.deadline == 10.0 + assert context.is_expired is False + assert context.remaining is not None and context.remaining > 9.0 + assert context.elapsed < 1.0 # Just started + + def test_context_elapsed_increases(self): + """Elapsed property should increase over time.""" + context = TaskContext(deadline=10.0) + + elapsed1 = context.elapsed + time.sleep(0.05) + elapsed2 = context.elapsed + + assert elapsed2 > elapsed1 + + def test_context_remaining_decreases(self): + """Remaining property should decrease over time.""" + context = TaskContext(deadline=10.0) + + remaining1 = context.remaining + time.sleep(0.05) + remaining2 = context.remaining + + assert remaining1 is not None and remaining2 is not None + assert remaining2 < remaining1 + + +@pytest.mark.core +class TestTaskContextTimeout: + """Tests for TaskContext timeout behavior.""" + + def test_context_is_expired_after_deadline(self): + """Context after deadline should show is_expired=True.""" + context = TaskContext(deadline=0.01) # Very short deadline + + time.sleep(0.02) # Wait past deadline + + assert context.is_expired is True + assert context.remaining == 0.0 + + def test_check_timeout_raises_on_expiry(self): + """check_timeout() should raise TaskTimeoutError when expired.""" + context = TaskContext(deadline=0.01) + + time.sleep(0.02) + + with pytest.raises(TaskTimeoutError) as exc_info: + context.check_timeout() + + assert exc_info.value.timeout == 0.01 + assert exc_info.value.elapsed >= 0.01 + + def test_check_timeout_includes_partial_traces(self): + """TaskTimeoutError should include partial traces if set.""" + context = TaskContext(deadline=0.01) + partial_traces = {"agents": {"agent1": {"steps": 3}}} + context.set_collected_traces(partial_traces) + + time.sleep(0.02) + + with pytest.raises(TaskTimeoutError) as exc_info: + context.check_timeout() + + assert exc_info.value.partial_traces == partial_traces + + def test_check_timeout_no_traces_if_not_set(self): + """TaskTimeoutError should have empty traces if not set.""" + context = TaskContext(deadline=0.01) + + time.sleep(0.02) + + with pytest.raises(TaskTimeoutError) as exc_info: + context.check_timeout() + + assert exc_info.value.partial_traces == {} + + def test_check_timeout_does_not_raise_before_deadline(self): + """check_timeout() should not raise before deadline.""" + context = TaskContext(deadline=10.0) + + # Should not raise + for _ in range(10): + context.check_timeout() + + def test_set_collected_traces_updates_context(self): + """set_collected_traces should store traces for later use.""" + context = TaskContext(deadline=10.0) + + traces = {"test": "data"} + context.set_collected_traces(traces) + + assert context.collected_traces == traces + + def test_context_timing_accuracy(self): + """Elapsed time should be reasonably accurate.""" + context = TaskContext(deadline=10.0) + + time.sleep(0.1) + elapsed = context.elapsed + + # Allow 50ms tolerance + assert 0.05 < elapsed < 0.2 diff --git a/tests/test_core/test_exceptions.py b/tests/test_core/test_exceptions.py index 0f66eea..f8b5cf1 100644 --- a/tests/test_core/test_exceptions.py +++ b/tests/test_core/test_exceptions.py @@ -163,6 +163,61 @@ def setup_agents(self, agent_data, environment, task, user): assert error["details"]["actual"] == "str" +@pytest.mark.core +class TestTaskTimeoutError: + """Tests for TaskTimeoutError exception.""" + + def test_timeout_error_attributes(self): + """TaskTimeoutError should have elapsed, timeout, partial_traces attributes.""" + from maseval import TaskTimeoutError + + error = TaskTimeoutError( + "Task exceeded 60s deadline", + component="execution_loop", + elapsed=62.5, + timeout=60.0, + partial_traces={"agents": {"agent1": {"steps": 3}}}, + ) + + assert error.elapsed == 62.5 + assert error.timeout == 60.0 + assert error.partial_traces == {"agents": {"agent1": {"steps": 3}}} + + def test_timeout_error_message(self): + """TaskTimeoutError message should include timing info.""" + from maseval import TaskTimeoutError + + error = TaskTimeoutError( + "Task exceeded 60s deadline after 62.5s", + component="timeout_check", + elapsed=62.5, + timeout=60.0, + ) + + assert "60s" in str(error) + assert "62.5s" in str(error) + + def test_timeout_error_inherits_from_maseval_error(self): + """TaskTimeoutError should inherit from MASEvalError.""" + from maseval import TaskTimeoutError + from maseval.core.exceptions import MASEvalError + + error = TaskTimeoutError("timeout", elapsed=1.0, timeout=0.5) + + assert isinstance(error, MASEvalError) + assert isinstance(error, Exception) + + def test_timeout_error_defaults(self): + """TaskTimeoutError should have sensible defaults.""" + from maseval import TaskTimeoutError + + error = TaskTimeoutError("timeout") + + assert error.elapsed == 0.0 + assert error.timeout == 0.0 + assert error.partial_traces == {} + + class TestAgentErrorSuggestion: """Tests for AgentError suggestion field.""" diff --git a/tests/test_core/test_queue.py b/tests/test_core/test_queue.py new file mode 100644 index 0000000..2ded59f --- /dev/null +++ b/tests/test_core/test_queue.py @@ -0,0 +1,266 @@ +"""Tests for TaskQueue implementations. + +These tests verify that SequentialQueue, PriorityQueue, and AdaptiveQueue +correctly order and iterate over tasks. +""" + +import pytest +from typing import Any, Dict, List + +from maseval import Task, TaskCollection +from maseval.core.task import TaskProtocol +from maseval.core.queue import SequentialQueue, PriorityQueue, AdaptiveQueue + + +# ==================== Fixtures ==================== + + +@pytest.fixture +def task_collection_with_priorities() -> TaskCollection: + """Create tasks with different priorities.""" + tasks = [] + for i, priority in enumerate([0, 5, 2, 8, 1]): + task = Task( + query=f"Query {i}", + environment_data={"index": i}, + protocol=TaskProtocol(priority=priority), + ) + tasks.append(task) + return TaskCollection(tasks) + + +@pytest.fixture +def agent_data_list() -> List[Dict[str, Any]]: + """Agent data list matching 5 tasks.""" + return [{"id": i} for i in range(5)] + + +@pytest.fixture +def simple_task_collection() -> TaskCollection: + """Simple task collection for basic tests.""" + return TaskCollection.from_list( + [ + {"query": "Q1", "environment_data": {}}, + {"query": "Q2", "environment_data": {}}, + {"query": "Q3", "environment_data": {}}, + ] + ) + + +@pytest.fixture +def simple_agent_data() -> List[Dict[str, Any]]: + """Agent data matching simple_task_collection.""" + return [{"model": "test"}] * 3 + + +# ==================== SequentialQueue Tests ==================== + + +@pytest.mark.core +class TestSequentialQueue: + """Tests for SequentialQueue ordering.""" + + def test_order_preserved(self, simple_task_collection, simple_agent_data): + """Tasks should be yielded in original order.""" + queue = SequentialQueue(simple_task_collection, simple_agent_data) + + queries = [task.query for task, _ in queue] + + assert queries == ["Q1", "Q2", "Q3"] + + def test_all_tasks_yielded(self, simple_task_collection, simple_agent_data): + """All tasks should be yielded exactly once.""" + queue = SequentialQueue(simple_task_collection, simple_agent_data) + + count = sum(1 for _ in queue) + + assert count == 3 + + def test_empty_collection(self): + """Empty collection should yield nothing.""" + queue = SequentialQueue(TaskCollection([]), []) + + items = list(queue) + + assert items == [] + + def test_single_task(self): + """Single task should be handled correctly.""" + tasks = TaskCollection.from_list([{"query": "Only one"}]) + queue = SequentialQueue(tasks, [{"model": "test"}]) + + items = list(queue) + + assert len(items) == 1 + assert items[0][0].query == "Only one" + + def test_agent_data_paired_correctly(self, simple_task_collection): + """Agent data should be paired with correct task.""" + agent_data = [{"id": 1}, {"id": 2}, {"id": 3}] + queue = SequentialQueue(simple_task_collection, agent_data) + + pairs = list(queue) + + assert pairs[0][1]["id"] == 1 + assert pairs[1][1]["id"] == 2 + assert pairs[2][1]["id"] == 3 + + +# ==================== PriorityQueue Tests ==================== + + +@pytest.mark.core +class TestPriorityQueue: + """Tests for PriorityQueue priority ordering.""" + + def test_high_priority_first(self, task_collection_with_priorities, agent_data_list): + """Higher priority tasks should come first.""" + queue = PriorityQueue(task_collection_with_priorities, agent_data_list) + + priorities = [task.protocol.priority for task, _ in queue] + + assert priorities == [8, 5, 2, 1, 0] + + def test_stable_sort_for_equal_priorities(self): + """Tasks with equal priority should maintain original order.""" + tasks = TaskCollection( + [ + Task(query="First", environment_data={}, protocol=TaskProtocol(priority=5)), + Task(query="Second", environment_data={}, protocol=TaskProtocol(priority=5)), + Task(query="Third", environment_data={}, protocol=TaskProtocol(priority=5)), + ] + ) + agent_data = [{}, {}, {}] + queue = PriorityQueue(tasks, agent_data) + + queries = [task.query for task, _ in queue] + + # Python's sort is stable, so original order should be preserved + assert queries == ["First", "Second", "Third"] + + def test_default_priority_zero(self, simple_task_collection, simple_agent_data): + """Tasks without explicit priority should have priority 0.""" + queue = PriorityQueue(simple_task_collection, simple_agent_data) + + for task, _ in queue: + assert task.protocol.priority == 0 + + def test_negative_priority(self): + """Negative priorities should be handled correctly.""" + tasks = TaskCollection( + [ + Task(query="Low", environment_data={}, protocol=TaskProtocol(priority=-5)), + Task(query="Normal", environment_data={}, protocol=TaskProtocol(priority=0)), + Task(query="High", environment_data={}, protocol=TaskProtocol(priority=5)), + ] + ) + queue = PriorityQueue(tasks, [{}, {}, {}]) + + queries = [task.query for task, _ in queue] + + assert queries == ["High", "Normal", "Low"] + + def test_agent_data_follows_priority(self, task_collection_with_priorities, agent_data_list): + """Agent data should follow task after priority sort.""" + queue = PriorityQueue(task_collection_with_priorities, agent_data_list) + + pairs = list(queue) + + # Task with priority 8 was at index 3 + assert pairs[0][1]["id"] == 3 + # Task with priority 5 was at index 1 + assert pairs[1][1]["id"] == 1 + + +# ==================== AdaptiveQueue Tests ==================== + + +@pytest.mark.core +class TestAdaptiveQueue: + """Tests for AdaptiveQueue adaptive behavior.""" + + def test_basic_iteration_with_completion(self, simple_task_collection, simple_agent_data): + """AdaptiveQueue should yield all tasks when on_task_complete is called.""" + queue = AdaptiveQueue(simple_task_collection, simple_agent_data) + + count = 0 + for task, agent_data in queue: + count += 1 + # Must call on_task_complete to progress to next task + queue.on_task_complete(task, {"status": "success"}) + + assert count == 3 + + def test_on_task_complete_moves_to_completed(self, simple_task_collection, simple_agent_data): + """on_task_complete should move task to completed list.""" + queue = AdaptiveQueue(simple_task_collection, simple_agent_data) + task, _ = next(iter(queue)) + + assert len(queue._completed) == 0 + + queue.on_task_complete(task, {"status": "success"}) + + assert len(queue._completed) == 1 + assert queue._completed[0][0].id == task.id + + def test_stop_terminates_iteration(self, simple_task_collection, simple_agent_data): + """Calling stop() should end iteration early.""" + queue = AdaptiveQueue(simple_task_collection, simple_agent_data) + + items = [] + for task, agent_data in queue: + items.append(task) + queue.stop() # Stop immediately after first yield + + assert len(items) == 1 + + def test_should_continue_false_after_stop(self, simple_task_collection, simple_agent_data): + """should_continue() should return False after stop().""" + queue = AdaptiveQueue(simple_task_collection, simple_agent_data) + + assert queue.should_continue() is True + + queue.stop() + + assert queue.should_continue() is False + + def test_should_continue_false_when_empty(self): + """should_continue() should return False when no pending tasks.""" + queue = AdaptiveQueue(TaskCollection([]), []) + + assert queue.should_continue() is False + + def test_pending_decreases_after_completion(self, simple_task_collection, simple_agent_data): + """Pending list should shrink as tasks complete.""" + queue = AdaptiveQueue(simple_task_collection, simple_agent_data) + + assert len(queue._pending) == 3 + + task, _ = next(iter(queue)) + queue.on_task_complete(task, {"status": "success"}) + + assert len(queue._pending) == 2 + assert len(queue._completed) == 1 + + +# ==================== Queue Integration Tests ==================== + + +@pytest.mark.core +class TestQueueCallbacks: + """Tests for queue callback mechanisms.""" + + def test_on_task_complete_called(self, simple_task_collection, simple_agent_data): + """on_task_complete should be callable without error.""" + queue = SequentialQueue(simple_task_collection, simple_agent_data) + + for task, _ in queue: + # SequentialQueue's on_task_complete is a no-op, but should not raise + queue.on_task_complete(task, {"status": "success"}) + + def test_should_continue_always_true_for_sequential(self, simple_task_collection, simple_agent_data): + """SequentialQueue should always return True for should_continue.""" + queue = SequentialQueue(simple_task_collection, simple_agent_data) + + for task, _ in queue: + assert queue.should_continue() is True diff --git a/tests/test_core/test_registry.py b/tests/test_core/test_registry.py new file mode 100644 index 0000000..925a5f1 --- /dev/null +++ b/tests/test_core/test_registry.py @@ -0,0 +1,258 @@ +"""Tests for ComponentRegistry thread safety and functionality. + +These tests verify that ComponentRegistry correctly isolates state between +threads and provides proper trace/config collection. +""" + +import pytest +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict + +from maseval.core.registry import ComponentRegistry +from maseval.core.tracing import TraceableMixin +from maseval.core.config import ConfigurableMixin + + +# ==================== Test Components ==================== + + +class MockTraceableComponent(TraceableMixin): + """Component that implements TraceableMixin for testing.""" + + def __init__(self, name: str, trace_data: Dict[str, Any] = None): + super().__init__() + self._name = name + self._trace_data = trace_data or {"component": name} + + def gather_traces(self) -> Dict[str, Any]: + return { + "name": self._name, + **self._trace_data, + } + + +class MockConfigurableComponent(TraceableMixin, ConfigurableMixin): + """Component that implements both TraceableMixin and ConfigurableMixin.""" + + def __init__(self, name: str, config: Dict[str, Any] = None): + TraceableMixin.__init__(self) + ConfigurableMixin.__init__(self) + self._name = name + self._config = config or {"setting": "default"} + + def gather_traces(self) -> Dict[str, Any]: + return {"name": self._name, "traced": True} + + def gather_config(self) -> Dict[str, Any]: + return {"name": self._name, **self._config} + + +# ==================== Basic Functionality Tests ==================== + + +@pytest.mark.core +class TestComponentRegistryBasics: + """Tests for basic ComponentRegistry functionality.""" + + def test_register_traceable_component(self): + """Verify component registered for tracing.""" + registry = ComponentRegistry() + component = MockTraceableComponent("test") + + result = registry.register("agents", "my_agent", component) + + assert result is component + assert "agents:my_agent" in registry._trace_registry + assert registry._trace_registry["agents:my_agent"] is component + + def test_register_configurable_component(self): + """Verify configurable component registered in both registries.""" + registry = ComponentRegistry() + component = MockConfigurableComponent("test") + + registry.register("models", "my_model", component) + + assert "models:my_model" in registry._trace_registry + assert "models:my_model" in registry._config_registry + + def test_duplicate_key_idempotent(self): + """Same component, same key should be idempotent.""" + registry = ComponentRegistry() + component = MockTraceableComponent("test") + + registry.register("agents", "agent1", component) + registry.register("agents", "agent1", component) # Same key, no error + + assert len(registry._trace_registry) == 1 + + def test_duplicate_component_different_key_raises(self): + """Same component with different key should raise ValueError.""" + registry = ComponentRegistry() + component = MockTraceableComponent("test") + + registry.register("agents", "name1", component) + + with pytest.raises(ValueError) as exc_info: + registry.register("agents", "name2", component) + + assert "already registered" in str(exc_info.value) + assert "agents:name1" in str(exc_info.value) + + def test_clear_removes_all_registrations(self): + """Clear should remove all registrations.""" + registry = ComponentRegistry() + registry.register("agents", "a1", MockTraceableComponent("a1")) + registry.register("models", "m1", MockConfigurableComponent("m1")) + + registry.clear() + + assert len(registry._trace_registry) == 0 + assert len(registry._config_registry) == 0 + assert len(registry._component_id_map) == 0 + + def test_collect_traces_structure(self): + """Verify trace output has expected structure.""" + registry = ComponentRegistry() + agent = MockTraceableComponent("agent1", {"steps": 5}) + registry.register("agents", "agent1", agent) + + traces = registry.collect_traces() + + assert "metadata" in traces + assert "agents" in traces + assert "agent1" in traces["agents"] + assert traces["agents"]["agent1"]["steps"] == 5 + + def test_collect_configs_structure(self): + """Verify config output has expected structure.""" + registry = ComponentRegistry(benchmark_config={"name": "test_benchmark"}) + model = MockConfigurableComponent("model1", {"temperature": 0.7}) + registry.register("models", "model1", model) + + configs = registry.collect_configs() + + assert "metadata" in configs + assert "benchmark" in configs + assert configs["benchmark"]["name"] == "test_benchmark" + assert "models" in configs + assert "model1" in configs["models"] + + +# ==================== Thread Safety Tests ==================== + + +@pytest.mark.core +class TestComponentRegistryThreadSafety: + """Tests for ComponentRegistry thread isolation.""" + + def test_registry_thread_isolation(self): + """Verify registrations in one thread don't appear in another.""" + registry = ComponentRegistry() + results = {} + barrier = threading.Barrier(2) + + def worker(worker_id: int): + barrier.wait() # Synchronize start + + # Register unique component + component = MockTraceableComponent(f"comp_{worker_id}") + registry.register("agents", f"agent_{worker_id}", component) + + # Record what this thread sees + results[worker_id] = list(registry._trace_registry.keys()) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Each thread should only see its own component + assert results[0] == ["agents:agent_0"] + assert results[1] == ["agents:agent_1"] + + def test_clear_only_affects_current_thread(self): + """Clearing in one thread shouldn't affect another.""" + registry = ComponentRegistry() + thread1_sees_after_clear = [] + thread2_sees_after_clear = [] + barrier = threading.Barrier(2) + sync_point = threading.Barrier(2) + + def thread1_worker(): + registry.register("agents", "t1_agent", MockTraceableComponent("t1")) + barrier.wait() # Both threads have registered + sync_point.wait() # Wait for thread 2 to check + + registry.clear() + thread1_sees_after_clear.extend(list(registry._trace_registry.keys())) + + def thread2_worker(): + registry.register("agents", "t2_agent", MockTraceableComponent("t2")) + barrier.wait() # Both threads have registered + sync_point.wait() # Sync before thread 1 clears + + # Wait a bit for thread 1 to clear + time.sleep(0.05) + thread2_sees_after_clear.extend(list(registry._trace_registry.keys())) + + t1 = threading.Thread(target=thread1_worker) + t2 = threading.Thread(target=thread2_worker) + t1.start() + t2.start() + t1.join() + t2.join() + + # Thread 1 cleared its own registry + assert thread1_sees_after_clear == [] + # Thread 2 still has its component + assert thread2_sees_after_clear == ["agents:t2_agent"] + + def test_concurrent_registration_no_race(self): + """Multiple threads registering simultaneously without races.""" + registry = ComponentRegistry() + errors = [] + barrier = threading.Barrier(4) + + def worker(worker_id: int): + try: + barrier.wait() + for i in range(10): + component = MockTraceableComponent(f"comp_{worker_id}_{i}") + registry.register("agents", f"agent_{worker_id}_{i}", component) + except Exception as e: + errors.append((worker_id, e)) + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(worker, i) for i in range(4)] + for f in futures: + f.result() + + assert len(errors) == 0, f"Errors occurred: {errors}" + + def test_concurrent_collect_traces(self): + """Multiple threads collecting traces simultaneously.""" + registry = ComponentRegistry() + results = {} + barrier = threading.Barrier(4) + + def worker(worker_id: int): + # Each thread registers its own component + registry.register("agents", f"agent_{worker_id}", MockTraceableComponent(f"agent_{worker_id}")) + barrier.wait() + + # All threads collect simultaneously + traces = registry.collect_traces() + results[worker_id] = traces + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(worker, i) for i in range(4)] + for f in futures: + f.result() + + # Each thread should see only its own agent + for worker_id, traces in results.items(): + assert f"agent_{worker_id}" in traces["agents"] + assert len(traces["agents"]) == 1 diff --git a/tests/test_core/test_task_protocol.py b/tests/test_core/test_task_protocol.py new file mode 100644 index 0000000..f1c7a5f --- /dev/null +++ b/tests/test_core/test_task_protocol.py @@ -0,0 +1,110 @@ +"""Tests for TaskProtocol and TimeoutAction. + +These tests verify that TaskProtocol correctly configures task execution +parameters and that TimeoutAction enum values are correct. +""" + +import pytest +from maseval import Task, TaskCollection +from maseval.core.task import TaskProtocol, TimeoutAction + + +@pytest.mark.core +class TestTimeoutAction: + """Tests for TimeoutAction enum.""" + + def test_enum_values(self): + """TimeoutAction should have expected values.""" + assert TimeoutAction.SKIP.value == "skip" + assert TimeoutAction.RETRY.value == "retry" + assert TimeoutAction.EXTEND.value == "extend" + + def test_enum_members(self): + """TimeoutAction should have expected members.""" + members = list(TimeoutAction) + assert len(members) == 3 + assert TimeoutAction.SKIP in members + assert TimeoutAction.RETRY in members + assert TimeoutAction.EXTEND in members + + +@pytest.mark.core +class TestTaskProtocol: + """Tests for TaskProtocol dataclass.""" + + def test_default_values(self): + """TaskProtocol should have sensible defaults.""" + protocol = TaskProtocol() + + assert protocol.timeout_seconds is None + assert protocol.timeout_action == TimeoutAction.SKIP + assert protocol.max_retries == 0 + assert protocol.priority == 0 + assert protocol.tags == {} + + def test_custom_values(self): + """TaskProtocol should accept custom values.""" + protocol = TaskProtocol( + timeout_seconds=60.0, + timeout_action=TimeoutAction.RETRY, + max_retries=3, + priority=10, + tags={"category": "hard", "group": "A"}, + ) + + assert protocol.timeout_seconds == 60.0 + assert protocol.timeout_action == TimeoutAction.RETRY + assert protocol.max_retries == 3 + assert protocol.priority == 10 + assert protocol.tags == {"category": "hard", "group": "A"} + + def test_tags_isolation(self): + """Tags dict should be independent per instance.""" + p1 = TaskProtocol() + p2 = TaskProtocol() + + p1.tags["key"] = "value" + + assert "key" not in p2.tags + + +@pytest.mark.core +class TestTaskWithProtocol: + """Tests for Task with TaskProtocol integration.""" + + def test_task_has_protocol_field(self): + """Task dataclass should have protocol field.""" + task = Task(query="Test", environment_data={}) + + assert hasattr(task, "protocol") + assert isinstance(task.protocol, TaskProtocol) + + def test_task_default_protocol(self): + """Task should have default protocol if not specified.""" + task = Task(query="Test") + + assert task.protocol.timeout_seconds is None + assert task.protocol.priority == 0 + + def test_task_custom_protocol(self): + """Task should accept custom protocol.""" + protocol = TaskProtocol( + timeout_seconds=30.0, + priority=5, + ) + task = Task(query="Test", protocol=protocol) + + assert task.protocol.timeout_seconds == 30.0 + assert task.protocol.priority == 5 + + def test_task_collection_preserves_protocol(self): + """TaskCollection should preserve protocol on tasks.""" + task1 = Task(query="Q1", protocol=TaskProtocol(priority=1)) + task2 = Task(query="Q2", protocol=TaskProtocol(priority=2)) + tasks = TaskCollection([task1, task2]) + + first_task: Task = tasks[0] # type: ignore[assignment] + second_task: Task = tasks[1] # type: ignore[assignment] + + assert first_task.protocol.priority == 1 + assert second_task.protocol.priority == 2 From d99d5a7a2220a0272762063c6ed84d28ee420285 Mon Sep 17 00:00:00 2001 From: cemde Date: Sat, 6 Dec 2025 00:44:00 +0000 Subject: [PATCH 05/25] removed task collection --- docs/getting-started/quickstart.md | 2 +- docs/reference/task.md | 4 +- .../five_a_day_benchmark.ipynb | 8 +- .../five_a_day_benchmark.py | 8 +- examples/introduction/tutorial.ipynb | 16 +- maseval/__init__.py | 21 +- maseval/benchmark/macs/data_loader.py | 16 +- maseval/core/benchmark.py | 163 +++++--- maseval/core/queue.py | 221 ----------- maseval/core/task.py | 348 ++++++++++++++++-- tests/conftest.py | 10 +- tests/test_benchmarks/test_macs/conftest.py | 6 +- .../test_macs/test_macs_integration.py | 10 +- .../test_automatic_registration.py | 6 +- .../test_benchmark_lifecycle.py | 44 +-- .../test_callback_orchestration.py | 14 +- .../test_benchmark/test_config_collection.py | 24 +- .../test_benchmark/test_execution_loop.py | 8 +- .../test_benchmark/test_parallel_execution.py | 41 +-- .../test_progress_bar_integration.py | 12 +- .../test_benchmark/test_trace_collection.py | 20 +- tests/test_core/test_evaluator.py | 14 +- tests/test_core/test_exceptions.py | 14 +- tests/test_core/test_queue.py | 342 +++++++++++------ tests/test_core/test_task_collection.py | 191 ---------- tests/test_core/test_task_protocol.py | 8 +- 26 files changed, 804 insertions(+), 767 deletions(-) delete mode 100644 maseval/core/queue.py delete mode 100644 tests/test_core/test_task_collection.py diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index d5ec921..b9c7669 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -117,7 +117,7 @@ Once implemented, run your benchmark: ```python # Define your tasks -tasks = TaskCollection([Task(query="...", expected="..."), ...]) +tasks = TaskQueue([Task(query="..."), ...]) # Configure your agents (e.g., model parameters, tool settings) agent_config = {"model": "gpt-4", "temperature": 0.7} diff --git a/docs/reference/task.md b/docs/reference/task.md index ee2ccd8..0997396 100644 --- a/docs/reference/task.md +++ b/docs/reference/task.md @@ -1,9 +1,9 @@ # Task -Tasks define individual benchmark scenarios including inputs, expected outputs, and any metadata needed for evaluation. TaskCollections group related tasks together. +Tasks define individual benchmark scenarios including inputs, expected outputs, and any metadata needed for evaluation. TaskQueues group related tasks together. [:material-github: View source](https://github.com/parameterlab/maseval/blob/main/maseval/core/task.py){ .md-source-file } ::: maseval.core.task.Task -::: maseval.core.task.TaskCollection +::: maseval.core.task.TaskQueue diff --git a/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb b/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb index 903aab9..a67e8f2 100644 --- a/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb +++ b/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb @@ -124,7 +124,7 @@ "from smolagents import ToolCallingAgent, LiteLLMModel, FinalAnswerTool\n", "\n", "# MASEval core components\n", - "from maseval import Benchmark, Environment, Task, TaskCollection, AgentAdapter, Evaluator, ModelAdapter\n", + "from maseval import Benchmark, Environment, Task, TaskQueue, AgentAdapter, Evaluator, ModelAdapter\n", "from maseval.interface.agents.smolagents import SmolAgentAdapter\n", "\n", "# Import evaluators module (dynamically loaded later)\n", @@ -139,7 +139,7 @@ " limit: int | None = None,\n", " seed: int | None = None,\n", " task_indices: list[int] | None = None,\n", - ") -> tuple[TaskCollection, list[Dict[str, Any]]]:\n", + ") -> tuple[TaskQueue, list[Dict[str, Any]]]:\n", " \"\"\"Load tasks and agent configurations.\n", "\n", " Args:\n", @@ -152,7 +152,7 @@ " task_indices: Optional list of task indices to load (e.g., [0, 2, 4])\n", "\n", " Returns:\n", - " Tuple of (TaskCollection, list of agent configs)\n", + " Tuple of (TaskQueue, list of agent configs)\n", " \"\"\"\n", " data_dir = Path(\"examples/five_a_day_benchmark/data\")\n", "\n", @@ -199,7 +199,7 @@ "\n", " configs_data.append(config)\n", "\n", - " return TaskCollection(tasks_data), configs_data" + " return TaskQueue(tasks_data), configs_data" ] }, { diff --git a/examples/five_a_day_benchmark/five_a_day_benchmark.py b/examples/five_a_day_benchmark/five_a_day_benchmark.py index 68a154b..bc0a986 100644 --- a/examples/five_a_day_benchmark/five_a_day_benchmark.py +++ b/examples/five_a_day_benchmark/five_a_day_benchmark.py @@ -26,7 +26,7 @@ from utils import derive_seed, sanitize_name # type: ignore[unresolved-import] -from maseval import Benchmark, Environment, Evaluator, Task, TaskCollection, AgentAdapter, ModelAdapter +from maseval import Benchmark, Environment, Evaluator, Task, TaskQueue, AgentAdapter, ModelAdapter from maseval.core.callbacks.result_logger import FileResultLogger # Import tool implementations @@ -825,7 +825,7 @@ def load_benchmark_data( limit: Optional[int] = None, specific_task: Optional[int] = None, seed: Optional[int] = None, -) -> tuple[TaskCollection, List[Dict[str, Any]]]: +) -> tuple[TaskQueue, List[Dict[str, Any]]]: """Load tasks and agent configurations with validation. Args: @@ -838,7 +838,7 @@ def load_benchmark_data( seed: Base random seed for reproducibility (None for non-deterministic) Returns: - Tuple of (TaskCollection, agent_configs_list) + Tuple of (TaskQueue, agent_configs_list) """ if limit is not None and specific_task is not None: raise ValueError("Cannot specify both limit and specific_task") @@ -896,7 +896,7 @@ def load_benchmark_data( print(f"Loaded {len(tasks_data)} tasks and {len(configs_data)} agent configs\n") - return TaskCollection(tasks_data), configs_data + return TaskQueue(tasks_data), configs_data # ============================================================================ diff --git a/examples/introduction/tutorial.ipynb b/examples/introduction/tutorial.ipynb index 367291f..afbc50d 100644 --- a/examples/introduction/tutorial.ipynb +++ b/examples/introduction/tutorial.ipynb @@ -330,7 +330,7 @@ "metadata": {}, "outputs": [], "source": [ - "from maseval import Benchmark, Environment, Evaluator, Task, TaskCollection\n", + "from maseval import Benchmark, Environment, Evaluator, Task, TaskQueue\n", "from maseval.interface.agents.smolagents import SmolAgentAdapter\n", "\n", "print(\"MASEval components imported successfully!\")" @@ -634,13 +634,13 @@ "metadata": {}, "outputs": [], "source": [ - "# Create benchmark instance with agent configuration\n", - "agent_data = {\"model_id\": \"gemini/gemini-2.5-flash\", \"temperature\": 0.7}\n", - "\n", - "benchmark = SimpleBenchmark(agent_data=agent_data, progress_bar=False)\n", - "\n", - "# Create task collection\n", - "tasks = TaskCollection([task])\n", + "\"# Create benchmark instance with agent configuration\\n\",\n", + " \"agent_data = {\\\"model_id\\\": \\\"gemini/gemini-2.5-flash\\\", \\\"temperature\\\": 0.7}\\n\",\n", + " \"\\n\",\n", + " \"benchmark = SimpleBenchmark(agent_data=agent_data, progress_bar=False)\\n\",\n", + " \"\\n\",\n", + " \"# Create task queue\\n\",\n", + " \"tasks = TaskQueue([task])\\n\",\n", "\n", "# Run the benchmark\n", "print(\"Running benchmark...\\n\")\n", diff --git a/maseval/__init__.py b/maseval/__init__.py index 1f4e831..b3cf788 100644 --- a/maseval/__init__.py +++ b/maseval/__init__.py @@ -8,7 +8,17 @@ Benchmarks sit in the `maseval.benchmark` submodule. """ -from .core.task import Task, TaskCollection, TaskProtocol, TimeoutAction +from .core.task import ( + Task, + TaskProtocol, + TimeoutAction, + # Task queue classes + BaseTaskQueue, + TaskQueue, + SequentialTaskQueue, + PriorityTaskQueue, + AdaptiveTaskQueue, +) from .core.environment import Environment from .core.agent import AgentAdapter from .core.benchmark import Benchmark, TaskExecutionStatus @@ -29,7 +39,6 @@ from .core.tracing import TraceableMixin from .core.registry import ComponentRegistry from .core.context import TaskContext -from .core.queue import TaskQueue, SequentialQueue, PriorityQueue, AdaptiveQueue from .core.exceptions import ( MASEvalError, AgentError, @@ -45,7 +54,6 @@ __all__ = [ # Tasks "Task", - "TaskCollection", "TaskProtocol", "TimeoutAction", # Core abstractions @@ -79,10 +87,11 @@ "ComponentRegistry", "TaskContext", # Task queues + "BaseTaskQueue", "TaskQueue", - "SequentialQueue", - "PriorityQueue", - "AdaptiveQueue", + "SequentialTaskQueue", + "PriorityTaskQueue", + "AdaptiveTaskQueue", # Exceptions and validation "MASEvalError", "AgentError", diff --git a/maseval/benchmark/macs/data_loader.py b/maseval/benchmark/macs/data_loader.py index fa1d418..94ec1bb 100644 --- a/maseval/benchmark/macs/data_loader.py +++ b/maseval/benchmark/macs/data_loader.py @@ -15,7 +15,7 @@ from urllib.error import HTTPError, URLError from urllib.request import urlopen -from maseval import Task, TaskCollection +from maseval import Task, TaskQueue # ============================================================================= @@ -422,7 +422,7 @@ def load_tasks( domain: str, data_dir: Optional[Path] = None, limit: Optional[int] = None, -) -> TaskCollection: +) -> TaskQueue: """Load tasks for a MACS domain. Args: @@ -432,7 +432,7 @@ def load_tasks( limit: Maximum number of tasks to load Returns: - TaskCollection containing Task objects + TaskQueue containing Task objects Raises: ValueError: If domain is not valid @@ -465,7 +465,7 @@ def load_tasks( ) ) - return TaskCollection(tasks) + return TaskQueue(tasks) def load_agent_config( @@ -503,12 +503,12 @@ def load_agent_config( def configure_model_ids( - tasks: Union[TaskCollection, List[Task]], + tasks: Union[TaskQueue, List[Task]], *, tool_model_id: Optional[str] = None, user_model_id: Optional[str] = None, evaluator_model_id: Optional[str] = None, -) -> Union[TaskCollection, List[Task]]: +) -> Union[TaskQueue, List[Task]]: """Configure model IDs for benchmark components in task data. This helper merges runtime model configuration into task data structures, @@ -519,13 +519,13 @@ def configure_model_ids( task-specific overrides in the original data to take precedence. Args: - tasks: TaskCollection or list of Tasks to configure + tasks: TaskQueue or list of Tasks to configure tool_model_id: Model ID for tool simulators (stored in environment_data) user_model_id: Model ID for user simulator (stored in user_data) evaluator_model_id: Model ID for evaluators (stored in evaluation_data) Returns: - The same collection (mutated in place for convenience) + The same queue or list (mutated in place for convenience) Example: ```python diff --git a/maseval/core/benchmark.py b/maseval/core/benchmark.py index 3780f99..289a6e2 100644 --- a/maseval/core/benchmark.py +++ b/maseval/core/benchmark.py @@ -1,14 +1,13 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Iterable, Optional, Sequence, Tuple, Union, cast -from datetime import datetime import threading -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ThreadPoolExecutor from enum import Enum import warnings import traceback from .evaluator import Evaluator -from .task import Task, TaskCollection +from .task import Task, BaseTaskQueue, SequentialTaskQueue from .environment import Environment from .agent import AgentAdapter from .model import ModelAdapter @@ -16,9 +15,7 @@ from .callback import BenchmarkCallback from .user import User from .tracing import TraceableMixin -from .config import ConfigurableMixin from .registry import ComponentRegistry -from .queue import TaskQueue, SequentialQueue from .context import TaskContext from .utils.system_info import gather_benchmark_config from .callbacks.progress_bar import ( @@ -216,8 +213,8 @@ def __init__( # Store agent_data as-is (will be normalized in run()) self.agent_data = agent_data - # Initialize tasks to empty collection (will be set in run()) - self.tasks = TaskCollection([]) + # Initialize tasks to empty queue (will be set in run()) + self.tasks: BaseTaskQueue = SequentialTaskQueue([]) self.callback_handler = CallbackHandler() self.callbacks = callbacks or [] @@ -1148,14 +1145,18 @@ def _execute_task_repetition( def _run_sequential( self, - queue: TaskQueue, + queue: BaseTaskQueue, + agent_data_lookup: Dict[str, Dict[str, Any]], ) -> None: """Execute tasks sequentially with optional timeout support. Args: queue: Task queue providing task ordering. + agent_data_lookup: Mapping from task_id to agent_data configuration. """ - for task, agent_data in queue: + for task in queue: + agent_data = agent_data_lookup[str(task.id)] + # Callbacks at the start of each task self._invoke_callbacks("on_task_start", self, task) @@ -1181,23 +1182,26 @@ def _run_sequential( def _run_parallel( self, - queue: TaskQueue, + queue: BaseTaskQueue, + agent_data_lookup: Dict[str, Dict[str, Any]], max_workers: int, ) -> None: """Execute tasks in parallel with thread pool. Args: queue: Task queue providing task ordering. + agent_data_lookup: Mapping from task_id to agent_data configuration. max_workers: Maximum number of concurrent workers. """ with ThreadPoolExecutor(max_workers=max_workers) as executor: futures: Dict[Any, Tuple[Task, int]] = {} task_repeat_counts: Dict[str, int] = {} # Track submitted repeats per task - def submit_task_repeats(task: Task, agent_data: Dict[str, Any]) -> None: + def submit_task_repeats(task: Task) -> None: """Submit all repeats for a task.""" task_id = str(task.id) task_repeat_counts[task_id] = 0 + agent_data = agent_data_lookup[task_id] self._invoke_callbacks("on_task_start", self, task) @@ -1215,8 +1219,8 @@ def submit_task_repeats(task: Task, agent_data: Dict[str, Any]) -> None: # Submit initial batch from queue submitted_tasks: List[Task] = [] - for task, agent_data in queue: - submit_task_repeats(task, agent_data) + for task in queue: + submit_task_repeats(task) submitted_tasks.append(task) # Limit initial submission to avoid over-committing @@ -1226,7 +1230,7 @@ def submit_task_repeats(task: Task, agent_data: Dict[str, Any]) -> None: # Process completions completed_task_ids: set = set() queue_iter = iter(queue) - queue_exhausted = len(submitted_tasks) >= len(list(queue)) # Approximate check + queue_exhausted = len(submitted_tasks) >= len(queue) while futures: # Wait for at least one completion @@ -1285,24 +1289,28 @@ def submit_task_repeats(task: Task, agent_data: Dict[str, Any]) -> None: # Submit more work if queue not exhausted if not queue_exhausted and len(futures) < max_workers: try: - task, agent_data = next(queue_iter) - submit_task_repeats(task, agent_data) + task = next(queue_iter) + submit_task_repeats(task) submitted_tasks.append(task) except StopIteration: queue_exhausted = True def run( self, - tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]], - queue: Optional[TaskQueue] = None, + tasks: Union[Task, BaseTaskQueue, Iterable[Union[Task, dict]]], max_workers: int = 1, ) -> List[Dict[str, Any]]: """Initialize and execute the complete benchmark loop across all tasks. Args: - tasks: Collection of tasks to execute. Can be a single Task, TaskCollection, - list of Task objects, or list of dicts that will be converted to Tasks. - queue: Optional task queue for custom scheduling. If None, uses SequentialQueue. + tasks: Task source for execution. Can be: + - A single Task object + - A BaseTaskQueue (SequentialTaskQueue, PriorityTaskQueue, or custom AdaptiveTaskQueue) + - An iterable of Task objects or dicts that will be converted to Tasks + + When a BaseTaskQueue is provided, it controls the task ordering. AdaptiveTaskQueue + subclasses are automatically registered as callbacks to receive task completion + notifications. max_workers: Maximum number of parallel task executions. Default 1 (sequential). Set higher for I/O-bound workloads (e.g., LLM API calls). @@ -1385,62 +1393,95 @@ def run( # Parallel execution with 4 workers reports = benchmark.run(tasks=tasks, max_workers=4) - # Custom queue for priority-based execution - from maseval.core.queue import PriorityQueue - queue = PriorityQueue(tasks, agent_data_list) - reports = benchmark.run(tasks=tasks, queue=queue) + # Priority-based execution + from maseval.core.task import PriorityTaskQueue + for task in tasks: + task.protocol.priority = compute_priority(task) + queue = PriorityTaskQueue(tasks) + reports = benchmark.run(tasks=queue) + + # Adaptive queue (auto-registered as callback) + queue = MyAdaptiveTaskQueue(tasks) + reports = benchmark.run(tasks=queue) # queue receives on_task_complete callbacks ``` """ - # Normalize tasks into a TaskCollection + # Normalize tasks into a queue + queue: BaseTaskQueue if isinstance(tasks, Task): - # Single task - self.tasks = TaskCollection([tasks]) - elif isinstance(tasks, TaskCollection): - self.tasks = tasks + # Single task - wrap in SequentialTaskQueue + queue = SequentialTaskQueue([tasks]) + elif isinstance(tasks, BaseTaskQueue): + # Already a queue - use directly + queue = tasks else: - # Iterable of tasks or dicts - self.tasks = TaskCollection.from_list(list(tasks)) + # Iterable of tasks or dicts - wrap in SequentialTaskQueue + queue = SequentialTaskQueue.from_list(list(tasks)) - # Normalize agent_data into a list matching the number of tasks - if isinstance(self.agent_data, dict): - # Single config for all tasks - agent_data_list: List[Dict[str, Any]] = [cast(Dict[str, Any], self.agent_data) for _ in range(len(self.tasks))] - else: - # Task-specific configs - agent_data_list = list(self.agent_data) + # Store tasks reference for get_failed_tasks() compatibility + self.tasks = queue - if len(agent_data_list) != len(self.tasks): - raise ValueError( - f"`agent_data` must either be a single dict or an iterable matching the number of tasks. " - f"Got {len(agent_data_list)} agent configs for {len(self.tasks)} tasks." - ) + # Build agent_data lookup (task_id -> agent_data) + agent_data_lookup = self._build_agent_data_lookup(queue) # Clear reports from previous run() calls to prevent accumulation self.reports = [] - # Create queue if not provided - if queue is None: - queue = SequentialQueue(self.tasks, agent_data_list) + # Auto-register queue as callback if it's a BenchmarkCallback (e.g., AdaptiveTaskQueue) + queue_was_added_as_callback = False + if isinstance(queue, BenchmarkCallback) and queue not in self.callbacks: + self.callbacks.append(queue) + queue_was_added_as_callback = True - # Callbacks at the start of the run - self._invoke_callbacks("on_run_start", self) + try: + # Callbacks at the start of the run + self._invoke_callbacks("on_run_start", self) + + # Execute based on max_workers + if max_workers == 1: + self._run_sequential(queue, agent_data_lookup) + else: + self._run_parallel(queue, agent_data_lookup, max_workers) + + # Callbacks at the end of the run + self._invoke_callbacks("on_run_end", self, self.reports) + finally: + # Remove queue from callbacks if we added it + if queue_was_added_as_callback: + self.callbacks.remove(queue) - # Execute based on max_workers - if max_workers == 1: - self._run_sequential(queue) - else: - self._run_parallel(queue, max_workers) + return self.reports - # Callbacks at the end of the run - self._invoke_callbacks("on_run_end", self, self.reports) + def _build_agent_data_lookup(self, tasks: BaseTaskQueue) -> Dict[str, Dict[str, Any]]: + """Build a mapping from task_id to agent_data configuration. - return self.reports + Args: + tasks: The task queue containing all tasks. + + Returns: + Dict mapping task_id (string) to agent_data configuration. + + Raises: + ValueError: If agent_data is a list but doesn't match the number of tasks. + """ + if isinstance(self.agent_data, dict): + # Single config - replicate for all tasks + return {str(task.id): cast(Dict[str, Any], self.agent_data) for task in tasks} + + # List of configs - pair by position + agent_data_list = list(self.agent_data) + if len(agent_data_list) != len(tasks): + raise ValueError( + f"`agent_data` must either be a single dict or an iterable matching the number of tasks. " + f"Got {len(agent_data_list)} agent configs for {len(tasks)} tasks." + ) + + return {str(task.id): agent_data_list[i] for i, task in enumerate(tasks)} def get_failed_tasks( self, status_filter: Optional[Union[TaskExecutionStatus, List[TaskExecutionStatus]]] = None, reports: Optional[List[Dict[str, Any]]] = None, - ) -> TaskCollection: + ) -> SequentialTaskQueue: """Get tasks that failed during benchmark execution. This method retrieves failed tasks based on their execution status, useful for @@ -1458,7 +1499,7 @@ def get_failed_tasks( run() call. This allows analyzing externally stored or modified reports. Returns: - TaskCollection containing the failed tasks. Empty if no failures match the filter. + SequentialTaskQueue containing the failed tasks. Empty if no failures match the filter. Raises: RuntimeError: If reports is None and run() has not been executed yet. @@ -1531,6 +1572,6 @@ def get_failed_tasks( if report["status"] in filter_values: failed_task_ids.add(report["task_id"]) - # Build TaskCollection from original tasks that failed + # Build queue from original tasks that failed failed_tasks = [task for task in self.tasks if str(task.id) in failed_task_ids] - return TaskCollection(failed_tasks) + return SequentialTaskQueue(failed_tasks) diff --git a/maseval/core/queue.py b/maseval/core/queue.py deleted file mode 100644 index 2c415e2..0000000 --- a/maseval/core/queue.py +++ /dev/null @@ -1,221 +0,0 @@ -"""Task queue abstraction for flexible task scheduling. - -This module provides the TaskQueue abstract base class and concrete implementations -for different task scheduling strategies. The queue abstraction replaces the static -`for task in tasks` loop with a dynamic scheduling system that enables: - -1. Dynamic task ordering -2. Callback-driven scheduling (adaptive testing) -3. Priority-based execution -4. Conditional task skipping -""" - -from abc import ABC, abstractmethod -from typing import Any, Dict, Iterator, List, Optional, Tuple - -from .task import Task, TaskCollection - - -class TaskQueue(ABC): - """Abstract base for task scheduling strategies. - - TaskQueue provides an iterator interface for task execution with hooks - for adaptive behavior based on task results. Concrete implementations - can reorder tasks, skip tasks, or terminate early based on execution - outcomes. - - The queue yields (Task, agent_data) tuples for execution. After each - task completes, `on_task_complete()` is called with the result, allowing - the queue to adapt its scheduling strategy. - - Usage: - queue = SequentialQueue(tasks, agent_data_list) - - for task, agent_data in queue: - report = execute_task(task, agent_data) - queue.on_task_complete(task, report) - - if not queue.should_continue(): - break - """ - - @abstractmethod - def __iter__(self) -> Iterator[Tuple[Task, Dict[str, Any]]]: - """Yield (task, agent_data) pairs in execution order. - - Returns: - Iterator yielding tuples of (Task, agent_data dict). - """ - pass - - def on_task_complete(self, task: Task, report: Dict[str, Any]) -> None: - """Called after each task completes. - - Override this method for adaptive scheduling behavior that responds - to task execution results (e.g., updating ability estimates, adjusting - priorities, or marking related tasks for skipping). - - Args: - task: The task that just completed. - report: The execution report containing status, traces, and eval results. - """ - pass - - def should_continue(self) -> bool: - """Whether to continue processing tasks. - - Default implementation returns True until the queue is exhausted. - Override for early termination conditions (e.g., confidence threshold - reached, maximum tasks processed, or error limit exceeded). - - Returns: - True to continue processing, False to stop. - """ - return True - - -class SequentialQueue(TaskQueue): - """Execute tasks in their original order (default behavior). - - This queue maintains the current sequential execution model, processing - tasks in the order they appear in the task collection. It's the default - queue used when no explicit queue is provided. - - Attributes: - tasks: List of (Task, agent_data) pairs. - """ - - def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict[str, Any]]): - """Initialize sequential queue. - - Args: - tasks: Collection of tasks to execute. - agent_data_list: List of agent configuration dicts, one per task. - """ - self._tasks: List[Tuple[Task, Dict[str, Any]]] = list(zip(tasks, agent_data_list)) - self._index = 0 - - def __iter__(self) -> Iterator[Tuple[Task, Dict[str, Any]]]: - """Yield tasks in original order.""" - for task, agent_data in self._tasks: - yield task, agent_data - - -class PriorityQueue(TaskQueue): - """Execute tasks by priority (from TaskProtocol.priority). - - Tasks with higher priority values are executed first. Tasks with equal - priority maintain their relative order from the original collection. - - This queue is useful when some tasks are more important or time-sensitive - than others, or when you want to process easier tasks first to get quick - feedback. - - Attributes: - tasks: List of (Task, agent_data) pairs sorted by priority. - """ - - def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict[str, Any]]): - """Initialize priority queue. - - Args: - tasks: Collection of tasks to execute. - agent_data_list: List of agent configuration dicts, one per task. - """ - paired = list(zip(tasks, agent_data_list)) - # Sort by priority descending (higher priority first) - # Use enumerate to maintain stable sort for equal priorities - self._tasks: List[Tuple[Task, Dict[str, Any]]] = sorted(paired, key=lambda x: x[0].protocol.priority, reverse=True) - - def __iter__(self) -> Iterator[Tuple[Task, Dict[str, Any]]]: - """Yield tasks in priority order.""" - for task, agent_data in self._tasks: - yield task, agent_data - - -class AdaptiveQueue(TaskQueue): - """Base class for adaptive task scheduling. - - Adaptive queues adjust task order based on execution results. This is - useful for techniques like Item Response Theory (IRT) based testing, - where task selection optimizes for information gain about agent ability. - - Subclasses should override `_select_next_task()` to implement their - selection algorithm, and `_update_state()` to update internal state - after each task completion. - - Attributes: - pending: Tasks not yet executed. - completed: Tasks that have been executed with their reports. - """ - - def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict[str, Any]]): - """Initialize adaptive queue. - - Args: - tasks: Collection of tasks to execute. - agent_data_list: List of agent configuration dicts, one per task. - """ - self._pending: List[Tuple[Task, Dict[str, Any]]] = list(zip(tasks, agent_data_list)) - self._completed: List[Tuple[Task, Dict[str, Any]]] = [] - self._stop_flag = False - - def __iter__(self) -> Iterator[Tuple[Task, Dict[str, Any]]]: - """Yield tasks selected by the adaptive algorithm.""" - while self._pending and not self._stop_flag: - next_item = self._select_next_task() - if next_item is not None: - yield next_item - else: - break - - def on_task_complete(self, task: Task, report: Dict[str, Any]) -> None: - """Update state based on task result. - - Args: - task: The task that just completed. - report: The execution report. - """ - # Find and move task from pending to completed - for i, (t, agent_data) in enumerate(self._pending): - if t.id == task.id: - self._completed.append(self._pending.pop(i)) - break - - # Update adaptive state - self._update_state(task, report) - - def should_continue(self) -> bool: - """Check if we should continue based on stopping criteria.""" - return not self._stop_flag and len(self._pending) > 0 - - def stop(self) -> None: - """Signal that no more tasks should be processed.""" - self._stop_flag = True - - def _select_next_task(self) -> Optional[Tuple[Task, Dict[str, Any]]]: - """Select the next task to execute. - - Override this method to implement custom selection algorithms - (e.g., IRT-based selection, uncertainty sampling, etc.). - - Default implementation returns tasks in order (first remaining task). - - Returns: - The next (Task, agent_data) pair, or None if no suitable task. - """ - if not self._pending: - return None - return self._pending[0] - - def _update_state(self, task: Task, report: Dict[str, Any]) -> None: - """Update internal state after task completion. - - Override this method to update ability estimates, difficulty models, - or other state used by `_select_next_task()`. - - Args: - task: The task that just completed. - report: The execution report containing status and eval results. - """ - pass diff --git a/maseval/core/task.py b/maseval/core/task.py index 31584aa..8648fa9 100644 --- a/maseval/core/task.py +++ b/maseval/core/task.py @@ -1,5 +1,6 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict +from typing import Any, Dict, Tuple from uuid import UUID, uuid4 from collections.abc import Sequence from typing import Iterable, List, Union, Iterator, Optional @@ -61,61 +62,142 @@ class Task: protocol: TaskProtocol = field(default_factory=TaskProtocol) -class TaskCollection(Sequence): - """A lightweight, sequence-like container for `Task` objects. +# ============================================================================= +# Task Queue Classes +# ============================================================================= - Usage: - - Construct from an iterable of `Task` or dicts: `TaskCollection.from_list(data)` - - Load from an examples-style JSON: `TaskCollection.from_json_file("examples/data.json")` - The collection is immutable from the Sequence API perspective (supports indexing and slicing), - but provides `append`/`extend` helpers for convenience when building programmatically. +class BaseTaskQueue(ABC, Sequence): + """Abstract base class for task scheduling strategies. + + BaseTaskQueue provides a sequence-like interface for task execution with hooks + for adaptive behavior based on task results. Concrete implementations can + reorder tasks, skip tasks, or terminate early based on execution outcomes. + + The queue yields Task objects for execution. After each task completes, + ``on_task_complete()`` is called with the result, allowing the queue to + adapt its scheduling strategy. + + Subclasses must implement ``__iter__`` to define the iteration order. + + Attributes: + _tasks: Internal list of tasks. + + Example: + ```python + queue = SequentialTaskQueue(tasks) + + for task in queue: + report = execute_task(task) + queue.on_task_complete(task, report) + + if not queue.should_continue(): + break + ``` """ - def __init__(self, tasks: Optional[Iterable[Task]] = None) -> None: - """Initialize the TaskCollection. + def __init__(self, tasks: Iterable[Task]) -> None: + """Initialize the task queue. Args: - tasks: An optional iterable of `Task` objects to initialize the collection. + tasks: An iterable of Task objects to schedule. """ - # TODO for any element in the iterable that is not a Task, convert it to a Task - self._tasks: List[Task] = list(tasks) if tasks is not None else [] + self._tasks: List[Task] = list(tasks) def __len__(self) -> int: + """Return the total number of tasks in the queue.""" return len(self._tasks) - def __getitem__(self, idx): - # Return a Task for int index, or a new TaskCollection for slices (pythonic behaviour) + def __getitem__(self, idx: Union[int, slice]) -> Union[Task, "BaseTaskQueue"]: + """Get a task by index or a slice of tasks. + + Args: + idx: Integer index or slice object. + + Returns: + A single Task for integer index, or a new queue instance for slices. + """ if isinstance(idx, slice): - return TaskCollection(self._tasks[idx]) + # Return a new instance of the same type with sliced tasks + return self.__class__(self._tasks[idx]) return self._tasks[idx] + @abstractmethod def __iter__(self) -> Iterator[Task]: - return iter(self._tasks) + """Yield tasks in the scheduled execution order. + + Returns: + Iterator yielding Task objects. + """ + pass + + def on_task_complete(self, task: Task, report: Dict[str, Any]) -> None: + """Called after each task completes. + + Override this method for adaptive scheduling behavior that responds + to task execution results (e.g., updating ability estimates, adjusting + priorities, or marking related tasks for skipping). + + Args: + task: The task that just completed. + report: The execution report containing status, traces, and eval results. + """ + pass + + def should_continue(self) -> bool: + """Whether to continue processing tasks. - def __repr__(self) -> str: # pragma: no cover - trivial - return f"TaskCollection({len(self._tasks)} tasks)" + Default implementation returns True. Override for early termination + conditions (e.g., confidence threshold reached, maximum tasks processed, + or error limit exceeded). + + Returns: + True to continue processing, False to stop. + """ + return True - # Convenience mutators def append(self, task: Task) -> None: + """Add a task to the end of the queue. + + Args: + task: The task to append. + """ self._tasks.append(task) def extend(self, tasks: Iterable[Task]) -> None: + """Add multiple tasks to the end of the queue. + + Args: + tasks: An iterable of tasks to append. + """ self._tasks.extend(tasks) def to_list(self) -> List[Task]: + """Return a copy of the internal task list. + + Returns: + List of all tasks in the queue. + """ return list(self._tasks) - # Factories @classmethod - def from_list(cls, data: Iterable[Union[Task, dict]]) -> "TaskCollection": + def from_list(cls, data: Iterable[Union[Task, dict]]) -> "BaseTaskQueue": + """Create a queue from an iterable of Tasks or dicts. + + Args: + data: An iterable of Task objects or dicts that can be converted to Tasks. + + Returns: + A new queue instance containing the tasks. + + Raises: + TypeError: If an item is neither a Task nor a dict. + """ tasks: List[Task] = [] for item in data: if isinstance(item, Task): tasks.append(item) elif isinstance(item, dict): - # Expect a dict that can be turned into a Task - # Accept both full Task kwargs or the lightweight example format if "query" in item: query = item["query"] tasks.append( @@ -127,7 +209,6 @@ def from_list(cls, data: Iterable[Union[Task, dict]]) -> "TaskCollection": ) ) else: - # Attempt to map common example keys query = item.get("question") or item.get("prompt") or item.get("query") or "" environment_data = ( item.get("environment_data") or {"text_content": item.get("text")} @@ -148,29 +229,234 @@ def from_list(cls, data: Iterable[Union[Task, dict]]) -> "TaskCollection": ) ) else: - raise TypeError("TaskCollection.from_list expects Task or dict entries") + raise TypeError(f"{cls.__name__}.from_list expects Task or dict entries") return cls(tasks) @classmethod - def from_json_file(cls, path: Union[str, Path], *, limit: Optional[int] = None) -> "TaskCollection": + def from_json_file(cls, path: Union[str, Path], *, limit: Optional[int] = None) -> "BaseTaskQueue": """Load tasks from a JSON file. - This helper understands the example file format used in `examples/data.json` where the - top-level object has a `data` list and optional `metadata`. + This helper understands the example file format used in ``examples/data.json`` + where the top-level object has a ``data`` list and optional ``metadata``. Args: path: Path to the JSON file. limit: Optional limit to the number of tasks to load. + + Returns: + A new queue instance containing the loaded tasks. """ p = Path(path) with p.open("r", encoding="utf-8") as fh: payload = json.load(fh) - # Support both the wrapped `{ "data": [...] }` format and a plain list items = payload.get("data") if isinstance(payload, dict) and "data" in payload else payload if limit is not None: items = items[:limit] - # Convert each item to a Task using the same heuristics as from_list return cls.from_list(items) + + +class SequentialTaskQueue(BaseTaskQueue): + """Execute tasks in their original order. + + This queue maintains the current sequential execution model, processing + tasks in the order they appear in the input iterable. It's the default + queue used when no explicit queue is provided. + + Example: + ```python + queue = SequentialTaskQueue(tasks) + for task in queue: + result = execute(task) + ``` + """ + + def __iter__(self) -> Iterator[Task]: + """Yield tasks in original order.""" + return iter(self._tasks) + + +class PriorityTaskQueue(BaseTaskQueue): + """Execute tasks ordered by priority. + + Tasks are sorted by ``task.protocol.priority`` at construction time. + Higher priority values are executed first by default. Tasks with equal + priority maintain their relative order from the original input (stable sort). + + This queue uses ``task.protocol.priority`` as the sole source of priority. + Pre-compute priority values and assign them to tasks before creating the queue. + + Args: + tasks: An iterable of Task objects to schedule. + reverse: If True (default), higher priority values execute first. + If False, lower priority values execute first. + + Example: + ```python + # Assign priorities based on your criteria + for task in tasks: + task.protocol.priority = compute_priority(task) + + # Create queue (higher priority first) + queue = PriorityTaskQueue(tasks) + + # Or lower priority first + queue = PriorityTaskQueue(tasks, reverse=False) + ``` + """ + + def __init__(self, tasks: Iterable[Task], reverse: bool = True) -> None: + """Initialize priority queue with sorted tasks. + + Args: + tasks: An iterable of Task objects to schedule. + reverse: If True (default), higher priority values execute first. + """ + task_list = list(tasks) + # Stable sort by priority + sorted_tasks = sorted(task_list, key=lambda t: t.protocol.priority, reverse=reverse) + super().__init__(sorted_tasks) + + def __iter__(self) -> Iterator[Task]: + """Yield tasks in priority order.""" + return iter(self._tasks) + + +class AdaptiveTaskQueue(BaseTaskQueue, ABC): + """Abstract base class for adaptive task scheduling. + + AdaptiveTaskQueue enables dynamic task ordering based on execution results. + It integrates with the benchmark callback system to receive notifications + after each task completes, allowing the queue to update internal state and + adjust the execution order. + + Subclasses must implement: + - ``_select_next_task()``: Choose the next task to execute + - ``_update_state()``: Update internal model after task completion + + The queue maintains: + - ``_remaining``: Tasks not yet executed + - ``_completed``: Completed tasks paired with their reports + - ``_stop_flag``: Flag to signal early termination + + When used with ``Benchmark.run()``, the queue is automatically registered + as a callback if it implements the ``BenchmarkCallback`` interface. + + Example: + ```python + class IRTTaskQueue(AdaptiveTaskQueue): + '''Item Response Theory-based adaptive testing.''' + + def __init__(self, tasks: Iterable[Task]): + super().__init__(tasks) + self._ability_estimate = 0.0 + + def _select_next_task(self) -> Optional[Task]: + if not self._remaining: + return None + # Select task with difficulty closest to current ability estimate + return min( + self._remaining, + key=lambda t: abs(t.protocol.priority - self._ability_estimate) + ) + + def _update_state(self, task: Task, report: Dict[str, Any]) -> None: + # Update ability estimate based on task result + correct = report.get("eval", [{}])[0].get("correct", False) + difficulty = task.protocol.priority + self._ability_estimate += 0.5 if correct else -0.5 + + queue = IRTTaskQueue(tasks) + results = benchmark.run(queue) # Auto-registered as callback + ``` + """ + + def __init__(self, tasks: Iterable[Task]) -> None: + """Initialize adaptive queue. + + Args: + tasks: An iterable of Task objects to schedule. + """ + super().__init__(tasks) + self._remaining: List[Task] = list(self._tasks) + self._completed: List[Tuple[Task, Dict[str, Any]]] = [] + self._stop_flag: bool = False + + def __iter__(self) -> Iterator[Task]: + """Yield tasks selected by the adaptive algorithm. + + Continues until ``_select_next_task()`` returns None, ``_remaining`` + is empty, or ``_stop_flag`` is set. + """ + while self._remaining and not self._stop_flag: + next_task = self._select_next_task() + if next_task is not None: + yield next_task + else: + break + + def on_task_complete(self, task: Task, report: Dict[str, Any]) -> None: + """Update state based on task result. + + Moves the task from ``_remaining`` to ``_completed`` and calls + ``_update_state()`` to let the subclass update its internal model. + + Args: + task: The task that just completed. + report: The execution report. + """ + # Find and move task from remaining to completed + for i, t in enumerate(self._remaining): + if t.id == task.id: + self._completed.append((self._remaining.pop(i), report)) + break + + # Let subclass update its state + self._update_state(task, report) + + def should_continue(self) -> bool: + """Check if we should continue based on stopping criteria. + + Returns: + True if stop flag is not set and tasks remain, False otherwise. + """ + return not self._stop_flag and len(self._remaining) > 0 + + def stop(self) -> None: + """Signal that no more tasks should be processed. + + Call this from ``_update_state()`` or ``_select_next_task()`` to + trigger early termination (e.g., when confidence threshold is reached). + """ + self._stop_flag = True + + @abstractmethod + def _select_next_task(self) -> Optional[Task]: + """Select the next task to execute. + + Implement this method to define your adaptive selection algorithm + (e.g., IRT-based selection, uncertainty sampling, bandit algorithms). + + Returns: + The next Task to execute, or None if no suitable task is available. + """ + pass + + @abstractmethod + def _update_state(self, task: Task, report: Dict[str, Any]) -> None: + """Update internal state after task completion. + + Implement this method to update ability estimates, difficulty models, + or other state used by ``_select_next_task()``. + + Args: + task: The task that just completed. + report: The execution report containing status and eval results. + """ + pass + + +# Alias for the default queue type +TaskQueue = SequentialTaskQueue diff --git a/tests/conftest.py b/tests/conftest.py index 75b147c..090a675 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ Environment, User, Task, - TaskCollection, + TaskQueue, Evaluator, MessageHistory, ) @@ -306,9 +306,9 @@ def dummy_task(): @pytest.fixture -def dummy_task_collection(): +def dummy_task_queue(): """Create a collection of dummy tasks.""" - return TaskCollection.from_list( + return TaskQueue.from_list( [ {"query": "Query 1", "environment_data": {"task": 1}}, {"query": "Query 2", "environment_data": {"task": 2}}, @@ -318,14 +318,14 @@ def dummy_task_collection(): @pytest.fixture -def simple_benchmark(dummy_task_collection): +def simple_benchmark(dummy_task_queue): """Create a simple benchmark instance with tasks. Returns: tuple: (benchmark, tasks) - Call as benchmark.run(tasks) """ benchmark = DummyBenchmark(agent_data={"model": "test"}) - return benchmark, dummy_task_collection + return benchmark, dummy_task_queue @pytest.fixture diff --git a/tests/test_benchmarks/test_macs/conftest.py b/tests/test_benchmarks/test_macs/conftest.py index 1b0ff35..8a061ad 100644 --- a/tests/test_benchmarks/test_macs/conftest.py +++ b/tests/test_benchmarks/test_macs/conftest.py @@ -23,7 +23,7 @@ from unittest.mock import MagicMock from conftest import DummyModelAdapter -from maseval import AgentAdapter, Task, User, MessageHistory, TaskCollection +from maseval import AgentAdapter, Task, User, MessageHistory, TaskQueue from maseval.benchmark.macs import MACSBenchmark, MACSEnvironment @@ -415,9 +415,9 @@ def travel_task(): @pytest.fixture -def macs_task_collection(sample_task, travel_task): +def macs_task_queue(sample_task, travel_task): """Collection of MACS tasks for benchmark.run() tests.""" - return TaskCollection.from_list([sample_task, travel_task]) + return TaskQueue.from_list([sample_task, travel_task]) # ============================================================================= diff --git a/tests/test_benchmarks/test_macs/test_macs_integration.py b/tests/test_benchmarks/test_macs/test_macs_integration.py index 961b376..3bb38a2 100644 --- a/tests/test_benchmarks/test_macs/test_macs_integration.py +++ b/tests/test_benchmarks/test_macs/test_macs_integration.py @@ -182,7 +182,7 @@ def test_environment_handles_empty_tool_specs(self, macs_model_factory): @pytest.mark.benchmark class TestEndToEndPipeline: - """End-to-end tests that call benchmark.run() with TaskCollection. + """End-to-end tests that call benchmark.run() with TaskQueue. These tests verify the complete MACS benchmark pipeline by actually calling benchmark.run(). More granular integration tests are in @@ -212,8 +212,8 @@ def test_run_single_task_complete_pipeline(self, sample_agent_data, travel_task) assert "config" in report assert "eval" in report - def test_run_multiple_tasks(self, sample_agent_data, macs_task_collection): - """Run benchmark with multiple tasks via TaskCollection.""" + def test_run_multiple_tasks(self, sample_agent_data, macs_task_queue): + """Run benchmark with multiple tasks via TaskQueue.""" model = DummyModelAdapter( responses=[ '{"text": "User response", "details": {}}', @@ -223,9 +223,9 @@ def test_run_multiple_tasks(self, sample_agent_data, macs_task_collection): ) benchmark = ConcreteMACSBenchmark(sample_agent_data, model) - reports = benchmark.run(macs_task_collection) + reports = benchmark.run(macs_task_queue) - assert len(reports) == len(macs_task_collection) + assert len(reports) == len(macs_task_queue) for report in reports: assert report["status"] == "success" assert "eval" in report diff --git a/tests/test_core/test_benchmark/test_automatic_registration.py b/tests/test_core/test_benchmark/test_automatic_registration.py index 3f96a77..486b973 100644 --- a/tests/test_core/test_benchmark/test_automatic_registration.py +++ b/tests/test_core/test_benchmark/test_automatic_registration.py @@ -8,7 +8,7 @@ """ import pytest -from maseval import TaskCollection, TraceableMixin +from maseval import TaskQueue, TraceableMixin from conftest import DummyBenchmark, DummyModelAdapter, DummyAgentAdapter, DummyAgent @@ -20,7 +20,7 @@ def test_automatic_agent_registration(): returned from setup_agents() for trace and config collection without requiring manual register() calls. """ - tasks = TaskCollection.from_list([{"query": "test", "id": "1", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "test", "id": "1", "environment_data": {}}]) agent_data = {} benchmark = DummyBenchmark(agent_data=agent_data) @@ -148,7 +148,7 @@ def test_registry_cleared_after_repetition(): Verifies that after each task iteration completes, the registry is reset to allow fresh components for the next iteration while preserving reports. """ - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "test1", "id": "1", "environment_data": {}}, {"query": "test2", "id": "2", "environment_data": {}}, diff --git a/tests/test_core/test_benchmark/test_benchmark_lifecycle.py b/tests/test_core/test_benchmark/test_benchmark_lifecycle.py index e35a96d..f7b6536 100644 --- a/tests/test_core/test_benchmark/test_benchmark_lifecycle.py +++ b/tests/test_core/test_benchmark/test_benchmark_lifecycle.py @@ -6,7 +6,7 @@ """ import pytest -from maseval import TaskCollection +from maseval import TaskQueue @pytest.mark.core @@ -21,7 +21,7 @@ def test_benchmark_complete_run_single_task(self, simple_benchmark): reports = benchmark.run(tasks) # Verify we got one report - assert len(reports) == 3 # 3 tasks in dummy_task_collection + assert len(reports) == 3 # 3 tasks in dummy_task_queue # Verify report structure report = reports[0] @@ -43,7 +43,7 @@ def test_benchmark_complete_run_multiple_tasks(self): """Test that a benchmark handles multiple tasks correctly.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task 1", "environment_data": {}}, {"query": "Task 2", "environment_data": {}}, @@ -69,7 +69,7 @@ def test_benchmark_task_repetitions(self): """Test that task repetitions work correctly.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=3) reports = benchmark.run(tasks) @@ -116,7 +116,7 @@ def on_task_end(self, benchmark, task, result): def on_run_end(self, benchmark, results): invocations.append("on_run_end") - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task1", "environment_data": {}}, {"query": "Task2", "environment_data": {}}, @@ -156,7 +156,7 @@ def test_benchmark_component_cleanup_between_repeats(self): from conftest import DummyBenchmark from maseval import BenchmarkCallback - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) # Track registry size after each repetition registry_sizes = [] @@ -187,7 +187,7 @@ def test_benchmark_registry_cleared_after_task(self): """Test that registry is properly cleared after each task repetition.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=1) # Before run, registry should be empty @@ -204,7 +204,7 @@ def test_benchmark_reports_structure(self): """Test that benchmark reports have the correct structure.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -239,7 +239,7 @@ def test_benchmark_agent_data_per_task(self): """Test that different agent_data can be provided per task.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task1", "environment_data": {}}, {"query": "Task2", "environment_data": {}}, @@ -267,7 +267,7 @@ def test_benchmark_invalid_agent_data_length(self): """Test that providing wrong number of agent_data items raises error.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task1", "environment_data": {}}, {"query": "Task2", "environment_data": {}}, @@ -321,7 +321,7 @@ def setup_agents(self, agent_data, environment, task, user): agent_adapter = FailingAgentAdapter(agent, "failing_agent") return [agent_adapter], {"failing_agent": agent_adapter} - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = TaskFailureBenchmark( agent_data={"model": "test"}, fail_on_task_error=False, @@ -359,7 +359,7 @@ def setup_agents(self, agent_data, environment, task, user): agent_adapter = FailingAgentAdapter(agent, "failing_agent") return [agent_adapter], {"failing_agent": agent_adapter} - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = TaskFailureBenchmark( agent_data={"model": "test"}, fail_on_task_error=True, @@ -384,7 +384,7 @@ class EvaluationFailureBenchmark(DummyBenchmark): def setup_evaluators(self, environment, task, agents, user): return [FailingEvaluator(task, environment, user)] - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = EvaluationFailureBenchmark( agent_data={"model": "test"}, fail_on_evaluation_error=False, @@ -416,7 +416,7 @@ class EvaluationFailureBenchmark(DummyBenchmark): def setup_evaluators(self, environment, task, agents, user): return [FailingEvaluator(task, environment, user)] - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = EvaluationFailureBenchmark( agent_data={"model": "test"}, fail_on_evaluation_error=True, @@ -434,7 +434,7 @@ class SetupFailureBenchmark(DummyBenchmark): def setup_environment(self, agent_data, task): raise RuntimeError("Environment setup failed!") - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = SetupFailureBenchmark( agent_data={"model": "test"}, fail_on_setup_error=False, @@ -458,7 +458,7 @@ class SetupFailureBenchmark(DummyBenchmark): def setup_environment(self, agent_data, task): raise RuntimeError("Environment setup failed!") - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = SetupFailureBenchmark( agent_data={"model": "test"}, fail_on_setup_error=True, @@ -498,7 +498,7 @@ def setup_agents(self, agent_data, environment, task, user): self.task_counter += 1 return [agent_adapter], {agent_adapter.name: agent_adapter} - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task 1", "environment_data": {}}, {"query": "Task 2", "environment_data": {}}, @@ -563,7 +563,7 @@ def test_successful_task_has_success_status(self): from maseval import TaskExecutionStatus from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -597,7 +597,7 @@ def test_multiple_run_calls_no_side_effects(self): benchmark = DummyBenchmark(agent_data={"model": "test"}) # First run with 3 tasks - tasks1 = TaskCollection.from_list( + tasks1 = TaskQueue.from_list( [ {"query": "Task 1", "environment_data": {}}, {"query": "Task 2", "environment_data": {}}, @@ -610,7 +610,7 @@ def test_multiple_run_calls_no_side_effects(self): assert len(benchmark.reports) == 3 # Second run with 2 different tasks - tasks2 = TaskCollection.from_list( + tasks2 = TaskQueue.from_list( [ {"query": "Task A", "environment_data": {}}, {"query": "Task B", "environment_data": {}}, @@ -629,7 +629,7 @@ def test_multiple_run_calls_no_side_effects(self): # Third run - retry pattern (simulating failed tasks) # Use one task from tasks1 - retry_tasks = TaskCollection([list(tasks1)[0]]) + retry_tasks = TaskQueue([list(tasks1)[0]]) reports3 = benchmark.run(tasks=retry_tasks) assert len(reports3) == 1 assert len(benchmark.reports) == 1 @@ -671,7 +671,7 @@ def setup_agents(self, agent_data, environment, task, user): self.task_counter += 1 return [agent_adapter], {agent_adapter.name: agent_adapter} - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task 1", "environment_data": {}}, {"query": "Task 2", "environment_data": {}}, diff --git a/tests/test_core/test_benchmark/test_callback_orchestration.py b/tests/test_core/test_benchmark/test_callback_orchestration.py index 63a3028..8de8b17 100644 --- a/tests/test_core/test_benchmark/test_callback_orchestration.py +++ b/tests/test_core/test_benchmark/test_callback_orchestration.py @@ -5,7 +5,7 @@ """ import pytest -from maseval import BenchmarkCallback, TaskCollection +from maseval import BenchmarkCallback, TaskQueue @pytest.mark.core @@ -37,7 +37,7 @@ def on_task_end(self, benchmark, task, result): def on_run_end(self, benchmark, results): order.append("run_end") - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark( agent_data={"model": "test"}, n_task_repeats=2, @@ -79,7 +79,7 @@ def on_run_start(self, benchmark): def on_run_end(self, benchmark, results): callback2_calls.append("end") - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}, callbacks=[Callback1(), Callback2()]) benchmark.run(tasks) @@ -104,7 +104,7 @@ def on_run_start(self, benchmark): def on_run_end(self, benchmark, results): successful_calls.append("end") - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark( agent_data={"model": "test"}, callbacks=[FailingCallback(), SuccessfulCallback()], @@ -136,7 +136,7 @@ def on_task_repeat_start(self, benchmark, task, repeat_idx): nonlocal repeat_count repeat_count += 1 - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task1", "environment_data": {}}, {"query": "Task2", "environment_data": {}}, @@ -172,7 +172,7 @@ def on_task_repeat_end(self, benchmark, report): def on_run_end(self, benchmark, results): contexts["results_count"] = len(results) - tasks = TaskCollection.from_list([{"query": "TestQuery", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "TestQuery", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}, callbacks=[ContextCapturingCallback()]) benchmark.run(tasks) @@ -195,7 +195,7 @@ def on_run_start(self, benchmark): captured_state["n_tasks"] = len(benchmark.tasks) captured_state["n_repeats"] = benchmark.n_task_repeats - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Q1", "environment_data": {}}, {"query": "Q2", "environment_data": {}}, diff --git a/tests/test_core/test_benchmark/test_config_collection.py b/tests/test_core/test_benchmark/test_config_collection.py index e1a4ad3..775c39d 100644 --- a/tests/test_core/test_benchmark/test_config_collection.py +++ b/tests/test_core/test_benchmark/test_config_collection.py @@ -5,7 +5,7 @@ """ import pytest -from maseval import TaskCollection +from maseval import TaskQueue @pytest.mark.core @@ -16,7 +16,7 @@ def test_config_collected_from_all_components(self): """Test that configs are collected from all registered components.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -48,7 +48,7 @@ def test_config_includes_benchmark_level_info(self): """Test that benchmark-level configuration is included.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -67,7 +67,7 @@ def test_config_includes_system_info(self): """Test that system information is captured.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -81,7 +81,7 @@ def test_config_includes_git_info(self): """Test that git information is captured when available.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -98,7 +98,7 @@ def test_config_includes_package_versions(self): """Test that installed package versions are captured.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -114,7 +114,7 @@ def test_config_structure_matches_spec(self): """Test that config structure matches expected specification.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -166,7 +166,7 @@ def setup_agents(self, agent_data, environment, task, user): agent_adapter = FailingConfigAdapter(agent, "failing_agent") return [agent_adapter], {"failing_agent": agent_adapter} # type: ignore[return-value] - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = TestBenchmark(agent_data={"model": "test"}) # Should complete without raising, with error info in config @@ -184,7 +184,7 @@ def test_config_json_serializable(self): import json from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {"key": "value"}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {"key": "value"}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -205,7 +205,7 @@ def test_config_contains_timestamps(self): """Test that all config components include timestamps.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -229,7 +229,7 @@ def test_config_includes_component_types(self): """Test that all configs include component type information.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -249,7 +249,7 @@ def test_config_different_per_repetition(self): """Test that each repetition has its own config snapshot.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=3) reports = benchmark.run(tasks) diff --git a/tests/test_core/test_benchmark/test_execution_loop.py b/tests/test_core/test_benchmark/test_execution_loop.py index d6629fc..10b895a 100644 --- a/tests/test_core/test_benchmark/test_execution_loop.py +++ b/tests/test_core/test_benchmark/test_execution_loop.py @@ -11,7 +11,7 @@ from typing import Any, List, Optional, Tuple import warnings -from maseval import Benchmark, Task, TaskCollection, User +from maseval import Benchmark, Task, TaskQueue, User # ============================================================================= @@ -350,7 +350,7 @@ def test_warning_max_invocations_without_user(self): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - benchmark.run(TaskCollection([task])) + benchmark.run(TaskQueue([task])) # Check for warning about max_invocations without user warning_messages = [str(warning.message) for warning in w] @@ -381,7 +381,7 @@ def test_run_with_user_uses_execution_loop(self, dummy_model): benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=user) - benchmark.run(TaskCollection([task])) + benchmark.run(TaskQueue([task])) # Verify run_agents was called with user's initial prompt assert len(benchmark.run_agents_calls) == 1 @@ -407,7 +407,7 @@ def test_complete_traces_with_user(self, dummy_model): max_invocations=2, ) - reports = benchmark.run(TaskCollection([task])) + reports = benchmark.run(TaskQueue([task])) # Check that user traces are in the report assert len(reports) == 1 diff --git a/tests/test_core/test_benchmark/test_parallel_execution.py b/tests/test_core/test_benchmark/test_parallel_execution.py index 4c6e280..186e908 100644 --- a/tests/test_core/test_benchmark/test_parallel_execution.py +++ b/tests/test_core/test_benchmark/test_parallel_execution.py @@ -12,7 +12,7 @@ from maseval import ( BenchmarkCallback, Task, - TaskCollection, + TaskQueue, TaskExecutionStatus, ) from conftest import DummyBenchmark @@ -93,7 +93,7 @@ def on_run_end(self, benchmark, results): @pytest.fixture def parallel_tasks(): """Create tasks for parallel execution testing.""" - return TaskCollection.from_list([{"query": f"Task {i}", "environment_data": {"index": i}} for i in range(5)]) + return TaskQueue.from_list([{"query": f"Task {i}", "environment_data": {"index": i}} for i in range(5)]) # ==================== Basic Parallel Execution Tests ==================== @@ -152,7 +152,7 @@ def test_single_worker_uses_sequential(self, parallel_tasks): def test_parallel_with_repetitions(self): """Verify parallel execution with n_task_repeats > 1.""" - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "T1", "environment_data": {}}, {"query": "T2", "environment_data": {}}, @@ -205,7 +205,7 @@ def test_traces_not_cross_contaminated(self, parallel_tasks): def test_callbacks_receive_correct_data(self): """Callbacks should receive correct task/report data in parallel.""" - tasks = TaskCollection.from_list([{"query": f"Query_{i}", "environment_data": {"idx": i}} for i in range(3)]) + tasks = TaskQueue.from_list([{"query": f"Query_{i}", "environment_data": {"idx": i}} for i in range(3)]) received_data = [] lock = threading.Lock() @@ -260,7 +260,7 @@ class TestParallelConcurrency: def test_parallel_faster_than_sequential(self): """Parallel execution should be faster for I/O-bound tasks.""" - tasks = TaskCollection.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(4)]) + tasks = TaskQueue.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(4)]) delay = 0.05 # Sequential timing @@ -280,7 +280,7 @@ def test_parallel_faster_than_sequential(self): def test_execution_overlaps(self): """Task executions should overlap in parallel mode.""" - tasks = TaskCollection.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(3)]) + tasks = TaskQueue.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(3)]) benchmark = SlowBenchmark( agent_data={"model": "test"}, @@ -322,7 +322,7 @@ def run_agents(self, agents, task, environment, query): raise RuntimeError("Intentional failure") return super().run_agents(agents, task, environment, query) - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Normal 1", "environment_data": {}}, {"query": "FAIL task", "environment_data": {}}, @@ -357,7 +357,7 @@ def run_agents(self, agents, task, environment, query): raise ValueError("Every other task fails") return super().run_agents(agents, task, environment, query) - tasks = TaskCollection.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(4)]) + tasks = TaskQueue.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(4)]) benchmark = HalfFailingBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks, max_workers=2) @@ -374,22 +374,19 @@ class TestParallelQueueIntegration: def test_custom_queue_respected(self, parallel_tasks): """Custom queue ordering should be respected.""" - from maseval.core.queue import PriorityQueue + from maseval.core.task import PriorityTaskQueue, TaskProtocol # Create tasks with priorities - prioritized_tasks = TaskCollection( - [ - Task( - query=f"P{p}", - environment_data={}, - protocol=__import__("maseval.core.task", fromlist=["TaskProtocol"]).TaskProtocol(priority=p), - ) - for p in [1, 5, 3, 2, 4] - ] - ) + prioritized_tasks = [ + Task( + query=f"P{p}", + environment_data={}, + protocol=TaskProtocol(priority=p), + ) + for p in [1, 5, 3, 2, 4] + ] - agent_data_list = [{"model": "test"}] * 5 - queue = PriorityQueue(prioritized_tasks, agent_data_list) + queue = PriorityTaskQueue(prioritized_tasks) # Track execution order execution_order = [] @@ -406,6 +403,6 @@ def on_task_repeat_start(self, benchmark, task, repeat_idx): ) # With max_workers=1, order should be strictly by priority - benchmark.run(prioritized_tasks, queue=queue, max_workers=1) + benchmark.run(queue, max_workers=1) assert execution_order == ["P5", "P4", "P3", "P2", "P1"] diff --git a/tests/test_core/test_benchmark/test_progress_bar_integration.py b/tests/test_core/test_benchmark/test_progress_bar_integration.py index 138371b..4814a14 100644 --- a/tests/test_core/test_benchmark/test_progress_bar_integration.py +++ b/tests/test_core/test_benchmark/test_progress_bar_integration.py @@ -15,7 +15,7 @@ from conftest import DummyBenchmark # noqa: E402 -from maseval.core.task import Task, TaskCollection # noqa: E402 +from maseval.core.task import Task, TaskQueue # noqa: E402 from maseval.core.callbacks.progress_bar import ( # noqa: E402 TqdmProgressBarCallback, RichProgressBarCallback, @@ -25,7 +25,7 @@ @pytest.mark.core def test_benchmark_with_default_progress_bar(): """Test that benchmark attaches tqdm progress bar by default.""" - tasks = TaskCollection([Task(query="What is 2+2?")]) + tasks = TaskQueue([Task(query="What is 2+2?")]) # Default should have progress bar benchmark = DummyBenchmark(agent_data={"model": "test"}) @@ -43,7 +43,7 @@ def test_benchmark_with_default_progress_bar(): @pytest.mark.core def test_benchmark_with_disabled_progress_bar(): """Test that progress bar can be disabled.""" - tasks = TaskCollection([Task(query="What is 2+2?")]) + tasks = TaskQueue([Task(query="What is 2+2?")]) benchmark = DummyBenchmark(agent_data={"model": "test"}, progress_bar=False) @@ -58,7 +58,7 @@ def test_benchmark_with_disabled_progress_bar(): @pytest.mark.core def test_benchmark_with_rich_progress_bar(): """Test that rich progress bar can be specified.""" - tasks = TaskCollection([Task(query="What is 2+2?")]) + tasks = TaskQueue([Task(query="What is 2+2?")]) benchmark = DummyBenchmark(agent_data={"model": "test"}, progress_bar="rich") @@ -73,7 +73,7 @@ def test_benchmark_with_rich_progress_bar(): @pytest.mark.core def test_benchmark_with_custom_progress_bar(): """Test that custom progress bar callback prevents default from being added.""" - tasks = TaskCollection([Task(query="What is 2+2?")]) + tasks = TaskQueue([Task(query="What is 2+2?")]) # User provides their own progress bar custom_pbar = TqdmProgressBarCallback(desc="Custom Progress") @@ -96,7 +96,7 @@ def test_benchmark_with_custom_progress_bar(): @pytest.mark.core def test_benchmark_with_multiple_tasks_and_repeats(): """Test progress bar with multiple tasks and repeats.""" - tasks = TaskCollection([Task(query=f"Task {i}") for i in range(3)]) + tasks = TaskQueue([Task(query=f"Task {i}") for i in range(3)]) benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=2, progress_bar=True) diff --git a/tests/test_core/test_benchmark/test_trace_collection.py b/tests/test_core/test_benchmark/test_trace_collection.py index 2b801a7..0632e57 100644 --- a/tests/test_core/test_benchmark/test_trace_collection.py +++ b/tests/test_core/test_benchmark/test_trace_collection.py @@ -5,7 +5,7 @@ """ import pytest -from maseval import TaskCollection +from maseval import TaskQueue @pytest.mark.core @@ -16,7 +16,7 @@ def test_traces_collected_from_all_components(self): """Test that traces are collected from all registered components.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -49,7 +49,7 @@ def test_traces_include_message_histories(self): """Test that agent traces include complete message histories.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -89,7 +89,7 @@ def setup_agents(self, agent_data, environment, task, user): agent_adapter = FailingAgentAdapter(agent, "failing_agent") return [agent_adapter], {"failing_agent": agent_adapter} # type: ignore[return-value] - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = TestBenchmark(agent_data={"model": "test"}) # Should complete without raising, with error info in traces @@ -130,7 +130,7 @@ def setup_agents(self, agent_data, environment, task, user): self.register("models", "test_model", model) return [agent_adapter], {"test_agent": agent_adapter} # type: ignore[return-value] - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = TestBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -151,7 +151,7 @@ def test_environment_traces_tool_invocations(self): """Test that Environment traces include tool information.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -217,7 +217,7 @@ def gather_traces(self): } callback = CustomCallback() - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}, callbacks=[callback]) # Register callback for tracing @@ -238,7 +238,7 @@ def test_traces_json_serializable(self): import json from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {"key": "value"}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {"key": "value"}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -259,7 +259,7 @@ def test_traces_contain_timestamps(self): """Test that all trace components include timestamps.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) @@ -280,7 +280,7 @@ def test_traces_include_component_types(self): """Test that all traces include component type information.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) diff --git a/tests/test_core/test_evaluator.py b/tests/test_core/test_evaluator.py index 5bf2d6a..31ff817 100644 --- a/tests/test_core/test_evaluator.py +++ b/tests/test_core/test_evaluator.py @@ -4,7 +4,7 @@ """ import pytest -from maseval import TaskCollection +from maseval import TaskQueue @pytest.mark.core @@ -35,7 +35,7 @@ class TestBenchmark(DummyBenchmark): def setup_evaluators(self, environment, task, agents, user): return [TracingEvaluator(task, environment, user)] - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = TestBenchmark(agent_data={"model": "test"}) benchmark.run(tasks) @@ -57,7 +57,7 @@ def evaluate(self, evaluators, agents, final_answer, traces): assert "test_agent" in agents return super().evaluate(evaluators, agents, final_answer, traces) - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = TestBenchmark(agent_data={"model": "test"}) benchmark.run(tasks) @@ -73,7 +73,7 @@ def evaluate(self, evaluators, agents, final_answer, traces): received_answers.append(final_answer) return super().evaluate(evaluators, agents, final_answer, traces) - tasks = TaskCollection.from_list([{"query": "My test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "My test query", "environment_data": {}}]) benchmark = TestBenchmark(agent_data={"model": "test"}) benchmark.run(tasks) @@ -92,7 +92,7 @@ def evaluate(self, evaluators, agents, final_answer, traces): received_traces.append(traces) return super().evaluate(evaluators, agents, final_answer, traces) - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = TestBenchmark(agent_data={"model": "test"}) benchmark.run(tasks) @@ -138,7 +138,7 @@ def setup_evaluators(self, environment, task, agents, user): Evaluator2(task, environment, user), ] - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = TestBenchmark(agent_data={"model": "test"}) benchmark.run(tasks) @@ -150,7 +150,7 @@ def test_evaluator_results_in_report(self): """Test that evaluator results appear in the final report.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark(agent_data={"model": "test"}) reports = benchmark.run(tasks) diff --git a/tests/test_core/test_exceptions.py b/tests/test_core/test_exceptions.py index f8b5cf1..3354221 100644 --- a/tests/test_core/test_exceptions.py +++ b/tests/test_core/test_exceptions.py @@ -8,7 +8,7 @@ import pytest from maseval import ( - TaskCollection, + TaskQueue, TaskExecutionStatus, AgentError, EnvironmentError, @@ -41,7 +41,7 @@ def setup_agents(self, agent_data, environment, task, user): adapter = AgentErrorAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = AgentErrorBenchmark(agent_data={}) reports = benchmark.run(tasks) @@ -68,7 +68,7 @@ def setup_agents(self, agent_data, environment, task, user): adapter = EnvironmentErrorAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = EnvironmentErrorBenchmark(agent_data={}) reports = benchmark.run(tasks) @@ -95,7 +95,7 @@ def setup_agents(self, agent_data, environment, task, user): adapter = UserErrorAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = UserErrorBenchmark(agent_data={}) reports = benchmark.run(tasks) @@ -122,7 +122,7 @@ def setup_agents(self, agent_data, environment, task, user): adapter = GenericErrorAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = GenericErrorBenchmark(agent_data={}) reports = benchmark.run(tasks) @@ -152,7 +152,7 @@ def setup_agents(self, agent_data, environment, task, user): adapter = DetailedAgentErrorAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DetailedAgentErrorBenchmark(agent_data={}) reports = benchmark.run(tasks) @@ -401,7 +401,7 @@ def run(self, query: str) -> str: adapter = DummyAgentAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task 1", "environment_data": {}}, {"query": "Task 2", "environment_data": {}}, diff --git a/tests/test_core/test_queue.py b/tests/test_core/test_queue.py index 2ded59f..558c891 100644 --- a/tests/test_core/test_queue.py +++ b/tests/test_core/test_queue.py @@ -1,22 +1,28 @@ """Tests for TaskQueue implementations. -These tests verify that SequentialQueue, PriorityQueue, and AdaptiveQueue +These tests verify that SequentialTaskQueue, PriorityTaskQueue, and AdaptiveTaskQueue correctly order and iterate over tasks. """ import pytest -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional -from maseval import Task, TaskCollection -from maseval.core.task import TaskProtocol -from maseval.core.queue import SequentialQueue, PriorityQueue, AdaptiveQueue +from maseval import Task +from maseval.core.task import ( + TaskProtocol, + SequentialTaskQueue, + PriorityTaskQueue, + AdaptiveTaskQueue, + TaskQueue, + BaseTaskQueue, +) # ==================== Fixtures ==================== @pytest.fixture -def task_collection_with_priorities() -> TaskCollection: +def tasks_with_priorities() -> List[Task]: """Create tasks with different priorities.""" tasks = [] for i, priority in enumerate([0, 5, 2, 8, 1]): @@ -26,51 +32,163 @@ def task_collection_with_priorities() -> TaskCollection: protocol=TaskProtocol(priority=priority), ) tasks.append(task) - return TaskCollection(tasks) + return tasks @pytest.fixture -def agent_data_list() -> List[Dict[str, Any]]: - """Agent data list matching 5 tasks.""" - return [{"id": i} for i in range(5)] +def simple_tasks() -> List[Task]: + """Simple task list for basic tests.""" + return [ + Task(query="Q1", environment_data={}), + Task(query="Q2", environment_data={}), + Task(query="Q3", environment_data={}), + ] -@pytest.fixture -def simple_task_collection() -> TaskCollection: - """Simple task collection for basic tests.""" - return TaskCollection.from_list( - [ - {"query": "Q1", "environment_data": {}}, - {"query": "Q2", "environment_data": {}}, - {"query": "Q3", "environment_data": {}}, +# ==================== BaseTaskQueue Tests ==================== + + +@pytest.mark.core +class TestBaseTaskQueue: + """Tests for BaseTaskQueue common functionality.""" + + def test_taskqueue_is_alias_for_sequential(self): + """TaskQueue should be an alias for SequentialTaskQueue.""" + assert TaskQueue is SequentialTaskQueue + + def test_sequence_protocol(self, simple_tasks): + """Queue should implement Sequence protocol.""" + queue = SequentialTaskQueue(simple_tasks) + + # __len__ + assert len(queue) == 3 + + # __getitem__ with int + assert queue[0].query == "Q1" + assert queue[1].query == "Q2" + assert queue[-1].query == "Q3" + + # __getitem__ with slice + sliced = queue[1:] + assert isinstance(sliced, BaseTaskQueue) + assert len(sliced) == 2 + + def test_append_and_extend(self, simple_tasks): + """Queue should support append and extend.""" + queue = SequentialTaskQueue(simple_tasks[:2]) + assert len(queue) == 2 + + queue.append(simple_tasks[2]) + assert len(queue) == 3 + + queue.extend([Task(query="Q4"), Task(query="Q5")]) + assert len(queue) == 5 + + def test_to_list(self, simple_tasks): + """to_list() should return a copy of internal list.""" + queue = SequentialTaskQueue(simple_tasks) + + result = queue.to_list() + + assert result == simple_tasks + assert result is not queue._tasks # Should be a copy + + def test_from_list_with_tasks(self, simple_tasks): + """from_list should accept Task objects.""" + queue = SequentialTaskQueue.from_list(simple_tasks) + + assert len(queue) == 3 + assert queue[0].query == "Q1" + + def test_from_list_with_dicts(self): + """from_list should accept dicts and convert to Tasks.""" + data = [ + {"query": "Dict 1"}, + {"query": "Dict 2", "environment_data": {"key": "value"}}, ] - ) + queue = SequentialTaskQueue.from_list(data) + assert len(queue) == 2 + assert queue[0].query == "Dict 1" + assert queue[1].environment_data == {"key": "value"} -@pytest.fixture -def simple_agent_data() -> List[Dict[str, Any]]: - """Agent data matching simple_task_collection.""" - return [{"model": "test"}] * 3 + def test_from_list_type_error(self): + """from_list should raise TypeError for invalid items.""" + with pytest.raises(TypeError, match="expects Task or dict"): + SequentialTaskQueue.from_list(["not a task"]) + + def test_from_json_file(self, tmp_path): + """from_json_file should load tasks from JSON file.""" + import json + + data = { + "data": [ + {"query": "Task 1", "environment_data": {}}, + {"query": "Task 2", "environment_data": {}}, + ] + } + + file_path = tmp_path / "tasks.json" + with open(file_path, "w") as f: + json.dump(data, f) + queue = SequentialTaskQueue.from_json_file(file_path) -# ==================== SequentialQueue Tests ==================== + assert len(queue) == 2 + assert queue[0].query == "Task 1" + assert queue[1].query == "Task 2" + + def test_from_json_file_with_limit(self, tmp_path): + """from_json_file should respect limit parameter.""" + import json + + data = {"data": [{"query": f"Task {i}"} for i in range(10)]} + + file_path = tmp_path / "tasks.json" + with open(file_path, "w") as f: + json.dump(data, f) + + queue = SequentialTaskQueue.from_json_file(file_path, limit=5) + + assert len(queue) == 5 + assert queue[4].query == "Task 4" + + def test_from_list_field_mapping(self): + """from_list should map alternative field names.""" + # Test question -> query mapping and short_answer -> evaluation_data + queue = SequentialTaskQueue.from_list([{"question": "What is 2+2?", "short_answer": "4"}]) + + task = queue[0] + assert task.query == "What is 2+2?" + assert task.evaluation_data == {"short_answer": "4"} + + def test_repr(self, simple_tasks): + """Queue should have informative repr.""" + queue = SequentialTaskQueue(simple_tasks) + + repr_str = repr(queue) + # Should mention queue type and task count + assert "SequentialTaskQueue" in repr_str or "TaskQueue" in repr_str or "3" in repr_str + + +# ==================== SequentialTaskQueue Tests ==================== @pytest.mark.core -class TestSequentialQueue: - """Tests for SequentialQueue ordering.""" +class TestSequentialTaskQueue: + """Tests for SequentialTaskQueue ordering.""" - def test_order_preserved(self, simple_task_collection, simple_agent_data): + def test_order_preserved(self, simple_tasks): """Tasks should be yielded in original order.""" - queue = SequentialQueue(simple_task_collection, simple_agent_data) + queue = SequentialTaskQueue(simple_tasks) - queries = [task.query for task, _ in queue] + queries = [task.query for task in queue] assert queries == ["Q1", "Q2", "Q3"] - def test_all_tasks_yielded(self, simple_task_collection, simple_agent_data): + def test_all_tasks_yielded(self, simple_tasks): """All tasks should be yielded exactly once.""" - queue = SequentialQueue(simple_task_collection, simple_agent_data) + queue = SequentialTaskQueue(simple_tasks) count = sum(1 for _ in queue) @@ -78,7 +196,7 @@ def test_all_tasks_yielded(self, simple_task_collection, simple_agent_data): def test_empty_collection(self): """Empty collection should yield nothing.""" - queue = SequentialQueue(TaskCollection([]), []) + queue = SequentialTaskQueue([]) items = list(queue) @@ -86,115 +204,113 @@ def test_empty_collection(self): def test_single_task(self): """Single task should be handled correctly.""" - tasks = TaskCollection.from_list([{"query": "Only one"}]) - queue = SequentialQueue(tasks, [{"model": "test"}]) + queue = SequentialTaskQueue([Task(query="Only one")]) items = list(queue) assert len(items) == 1 - assert items[0][0].query == "Only one" + assert items[0].query == "Only one" - def test_agent_data_paired_correctly(self, simple_task_collection): - """Agent data should be paired with correct task.""" - agent_data = [{"id": 1}, {"id": 2}, {"id": 3}] - queue = SequentialQueue(simple_task_collection, agent_data) - pairs = list(queue) +# ==================== PriorityTaskQueue Tests ==================== - assert pairs[0][1]["id"] == 1 - assert pairs[1][1]["id"] == 2 - assert pairs[2][1]["id"] == 3 +@pytest.mark.core +class TestPriorityTaskQueue: + """Tests for PriorityTaskQueue priority ordering.""" -# ==================== PriorityQueue Tests ==================== + def test_high_priority_first(self, tasks_with_priorities): + """Higher priority tasks should come first (default).""" + queue = PriorityTaskQueue(tasks_with_priorities) + priorities = [task.protocol.priority for task in queue] -@pytest.mark.core -class TestPriorityQueue: - """Tests for PriorityQueue priority ordering.""" + assert priorities == [8, 5, 2, 1, 0] - def test_high_priority_first(self, task_collection_with_priorities, agent_data_list): - """Higher priority tasks should come first.""" - queue = PriorityQueue(task_collection_with_priorities, agent_data_list) + def test_low_priority_first_with_reverse_false(self, tasks_with_priorities): + """Lower priority tasks should come first when reverse=False.""" + queue = PriorityTaskQueue(tasks_with_priorities, reverse=False) - priorities = [task.protocol.priority for task, _ in queue] + priorities = [task.protocol.priority for task in queue] - assert priorities == [8, 5, 2, 1, 0] + assert priorities == [0, 1, 2, 5, 8] def test_stable_sort_for_equal_priorities(self): """Tasks with equal priority should maintain original order.""" - tasks = TaskCollection( - [ - Task(query="First", environment_data={}, protocol=TaskProtocol(priority=5)), - Task(query="Second", environment_data={}, protocol=TaskProtocol(priority=5)), - Task(query="Third", environment_data={}, protocol=TaskProtocol(priority=5)), - ] - ) - agent_data = [{}, {}, {}] - queue = PriorityQueue(tasks, agent_data) + tasks = [ + Task(query="First", environment_data={}, protocol=TaskProtocol(priority=5)), + Task(query="Second", environment_data={}, protocol=TaskProtocol(priority=5)), + Task(query="Third", environment_data={}, protocol=TaskProtocol(priority=5)), + ] + queue = PriorityTaskQueue(tasks) - queries = [task.query for task, _ in queue] + queries = [task.query for task in queue] # Python's sort is stable, so original order should be preserved assert queries == ["First", "Second", "Third"] - def test_default_priority_zero(self, simple_task_collection, simple_agent_data): + def test_default_priority_zero(self, simple_tasks): """Tasks without explicit priority should have priority 0.""" - queue = PriorityQueue(simple_task_collection, simple_agent_data) + queue = PriorityTaskQueue(simple_tasks) - for task, _ in queue: + for task in queue: assert task.protocol.priority == 0 def test_negative_priority(self): """Negative priorities should be handled correctly.""" - tasks = TaskCollection( - [ - Task(query="Low", environment_data={}, protocol=TaskProtocol(priority=-5)), - Task(query="Normal", environment_data={}, protocol=TaskProtocol(priority=0)), - Task(query="High", environment_data={}, protocol=TaskProtocol(priority=5)), - ] - ) - queue = PriorityQueue(tasks, [{}, {}, {}]) + tasks = [ + Task(query="Low", environment_data={}, protocol=TaskProtocol(priority=-5)), + Task(query="Normal", environment_data={}, protocol=TaskProtocol(priority=0)), + Task(query="High", environment_data={}, protocol=TaskProtocol(priority=5)), + ] + queue = PriorityTaskQueue(tasks) - queries = [task.query for task, _ in queue] + queries = [task.query for task in queue] assert queries == ["High", "Normal", "Low"] - def test_agent_data_follows_priority(self, task_collection_with_priorities, agent_data_list): - """Agent data should follow task after priority sort.""" - queue = PriorityQueue(task_collection_with_priorities, agent_data_list) - pairs = list(queue) +# ==================== AdaptiveTaskQueue Tests ==================== + + +class ConcreteAdaptiveQueue(AdaptiveTaskQueue): + """Concrete implementation of AdaptiveTaskQueue for testing.""" - # Task with priority 8 was at index 3 - assert pairs[0][1]["id"] == 3 - # Task with priority 5 was at index 1 - assert pairs[1][1]["id"] == 1 + def __init__(self, tasks): + super().__init__(tasks) + self._selection_order: List[int] = [] # Track selection indices + def _select_next_task(self) -> Optional[Task]: + """Select tasks in order (simple FIFO).""" + if not self._remaining: + return None + return self._remaining[0] -# ==================== AdaptiveQueue Tests ==================== + def _update_state(self, task: Task, report: Dict[str, Any]) -> None: + """Track update calls.""" + pass @pytest.mark.core -class TestAdaptiveQueue: - """Tests for AdaptiveQueue adaptive behavior.""" +class TestAdaptiveTaskQueue: + """Tests for AdaptiveTaskQueue adaptive behavior.""" - def test_basic_iteration_with_completion(self, simple_task_collection, simple_agent_data): - """AdaptiveQueue should yield all tasks when on_task_complete is called.""" - queue = AdaptiveQueue(simple_task_collection, simple_agent_data) + def test_basic_iteration_with_completion(self, simple_tasks): + """AdaptiveTaskQueue should yield all tasks when on_task_complete is called.""" + queue = ConcreteAdaptiveQueue(simple_tasks) count = 0 - for task, agent_data in queue: + for task in queue: count += 1 # Must call on_task_complete to progress to next task queue.on_task_complete(task, {"status": "success"}) assert count == 3 - def test_on_task_complete_moves_to_completed(self, simple_task_collection, simple_agent_data): + def test_on_task_complete_moves_to_completed(self, simple_tasks): """on_task_complete should move task to completed list.""" - queue = AdaptiveQueue(simple_task_collection, simple_agent_data) - task, _ = next(iter(queue)) + queue = ConcreteAdaptiveQueue(simple_tasks) + task = next(iter(queue)) assert len(queue._completed) == 0 @@ -203,20 +319,20 @@ def test_on_task_complete_moves_to_completed(self, simple_task_collection, simpl assert len(queue._completed) == 1 assert queue._completed[0][0].id == task.id - def test_stop_terminates_iteration(self, simple_task_collection, simple_agent_data): + def test_stop_terminates_iteration(self, simple_tasks): """Calling stop() should end iteration early.""" - queue = AdaptiveQueue(simple_task_collection, simple_agent_data) + queue = ConcreteAdaptiveQueue(simple_tasks) items = [] - for task, agent_data in queue: + for task in queue: items.append(task) queue.stop() # Stop immediately after first yield assert len(items) == 1 - def test_should_continue_false_after_stop(self, simple_task_collection, simple_agent_data): + def test_should_continue_false_after_stop(self, simple_tasks): """should_continue() should return False after stop().""" - queue = AdaptiveQueue(simple_task_collection, simple_agent_data) + queue = ConcreteAdaptiveQueue(simple_tasks) assert queue.should_continue() is True @@ -226,41 +342,41 @@ def test_should_continue_false_after_stop(self, simple_task_collection, simple_a def test_should_continue_false_when_empty(self): """should_continue() should return False when no pending tasks.""" - queue = AdaptiveQueue(TaskCollection([]), []) + queue = ConcreteAdaptiveQueue([]) assert queue.should_continue() is False - def test_pending_decreases_after_completion(self, simple_task_collection, simple_agent_data): - """Pending list should shrink as tasks complete.""" - queue = AdaptiveQueue(simple_task_collection, simple_agent_data) + def test_remaining_decreases_after_completion(self, simple_tasks): + """Remaining list should shrink as tasks complete.""" + queue = ConcreteAdaptiveQueue(simple_tasks) - assert len(queue._pending) == 3 + assert len(queue._remaining) == 3 - task, _ = next(iter(queue)) + task = next(iter(queue)) queue.on_task_complete(task, {"status": "success"}) - assert len(queue._pending) == 2 + assert len(queue._remaining) == 2 assert len(queue._completed) == 1 -# ==================== Queue Integration Tests ==================== +# ==================== Queue Callback Tests ==================== @pytest.mark.core class TestQueueCallbacks: """Tests for queue callback mechanisms.""" - def test_on_task_complete_called(self, simple_task_collection, simple_agent_data): + def test_on_task_complete_callable_without_error(self, simple_tasks): """on_task_complete should be callable without error.""" - queue = SequentialQueue(simple_task_collection, simple_agent_data) + queue = SequentialTaskQueue(simple_tasks) - for task, _ in queue: - # SequentialQueue's on_task_complete is a no-op, but should not raise + for task in queue: + # SequentialTaskQueue's on_task_complete is a no-op, but should not raise queue.on_task_complete(task, {"status": "success"}) - def test_should_continue_always_true_for_sequential(self, simple_task_collection, simple_agent_data): - """SequentialQueue should always return True for should_continue.""" - queue = SequentialQueue(simple_task_collection, simple_agent_data) + def test_should_continue_always_true_for_sequential(self, simple_tasks): + """SequentialTaskQueue should always return True for should_continue.""" + queue = SequentialTaskQueue(simple_tasks) - for task, _ in queue: + for task in queue: assert queue.should_continue() is True diff --git a/tests/test_core/test_task_collection.py b/tests/test_core/test_task_collection.py deleted file mode 100644 index 3ffc9ce..0000000 --- a/tests/test_core/test_task_collection.py +++ /dev/null @@ -1,191 +0,0 @@ -"""Test TaskCollection functionality. - -These tests verify that TaskCollection behaves like a sequence and correctly -loads tasks from various formats. -""" - -import pytest -from maseval import Task, TaskCollection -from pathlib import Path -import json -import tempfile - - -@pytest.mark.core -class TestTaskCollection: - """Tests for TaskCollection interface and factories.""" - - def test_task_collection_from_list(self): - """Test creating TaskCollection from a list of dicts.""" - data = [ - {"query": "Q1", "environment_data": {"e": 1}}, - {"query": "Q2", "environment_data": {"e": 2}}, - ] - - collection = TaskCollection.from_list(data) - - assert len(collection) == 2 - assert collection[0].query == "Q1" - assert collection[1].query == "Q2" - - def test_task_collection_from_json_file(self): - """Test loading TaskCollection from JSON file.""" - with tempfile.TemporaryDirectory() as tmpdir: - # Create test JSON file - data = { - "data": [ - {"query": "Task 1", "environment_data": {}}, - {"query": "Task 2", "environment_data": {}}, - ] - } - - file_path = Path(tmpdir) / "tasks.json" - with open(file_path, "w") as f: - json.dump(data, f) - - # Load from file - collection = TaskCollection.from_json_file(file_path) - - assert len(collection) == 2 - assert collection[0].query == "Task 1" - assert collection[1].query == "Task 2" - - def test_task_collection_sequence_interface(self): - """Test that TaskCollection implements Sequence interface.""" - collection = TaskCollection.from_list( - [ - {"query": "Q1"}, - {"query": "Q2"}, - {"query": "Q3"}, - ] - ) - - # Test len - assert len(collection) == 3 - - # Test iteration - queries = [task.query for task in collection] - assert queries == ["Q1", "Q2", "Q3"] - - # Test indexing - assert collection[0].query == "Q1" - assert collection[-1].query == "Q3" - - def test_task_collection_slicing(self): - """Test that TaskCollection supports slicing.""" - collection = TaskCollection.from_list([{"query": f"Q{i}"} for i in range(10)]) - - # Test slice - subset = collection[2:5] - assert isinstance(subset, TaskCollection) - assert len(subset) == 3 - assert subset[0].query == "Q2" - assert subset[2].query == "Q4" - - # Test slice from start - start = collection[:3] - assert len(start) == 3 - - # Test slice to end - end = collection[7:] - assert len(end) == 3 - - def test_task_collection_iteration(self): - """Test iterating over TaskCollection.""" - data = [{"query": f"Q{i}"} for i in range(5)] - collection = TaskCollection.from_list(data) - - queries = [] - for task in collection: - queries.append(task.query) - - assert queries == ["Q0", "Q1", "Q2", "Q3", "Q4"] - - def test_task_dict_conversion(self): - """Test that dict items are converted to Task objects.""" - collection = TaskCollection.from_list( - [ - { - "query": "Test", - "environment_data": {"key": "value"}, - "evaluation_data": {"expected": "result"}, - "metadata": {"difficulty": "easy"}, - } - ] - ) - - task = collection[0] - assert isinstance(task, Task) - assert task.query == "Test" - assert task.environment_data == {"key": "value"} - assert task.evaluation_data == {"expected": "result"} - assert task.metadata == {"difficulty": "easy"} - - def test_task_field_mapping(self): - """Test that alternative field names are mapped correctly.""" - # Test question -> query mapping - collection = TaskCollection.from_list([{"question": "What is 2+2?", "short_answer": "4"}]) - - task = collection[0] - assert task.query == "What is 2+2?" - assert task.evaluation_data == {"short_answer": "4"} - - def test_task_collection_append(self): - """Test appending tasks to collection.""" - collection = TaskCollection() - assert len(collection) == 0 - - task = Task(query="Test") - collection.append(task) - - assert len(collection) == 1 - assert collection[0].query == "Test" - - def test_task_collection_extend(self): - """Test extending collection with multiple tasks.""" - collection = TaskCollection() - - new_tasks = [ - Task(query="Q1"), - Task(query="Q2"), - Task(query="Q3"), - ] - - collection.extend(new_tasks) - - assert len(collection) == 3 - assert collection[2].query == "Q3" - - def test_task_collection_to_list(self): - """Test converting TaskCollection to list.""" - data = [{"query": f"Q{i}"} for i in range(3)] - collection = TaskCollection.from_list(data) - - task_list = collection.to_list() - - assert isinstance(task_list, list) - assert len(task_list) == 3 - assert all(isinstance(t, Task) for t in task_list) - - def test_task_collection_from_json_with_limit(self): - """Test loading with a limit on number of tasks.""" - with tempfile.TemporaryDirectory() as tmpdir: - data = {"data": [{"query": f"Task {i}"} for i in range(10)]} - - file_path = Path(tmpdir) / "tasks.json" - with open(file_path, "w") as f: - json.dump(data, f) - - # Load only first 5 - collection = TaskCollection.from_json_file(file_path, limit=5) - - assert len(collection) == 5 - assert collection[4].query == "Task 4" - - def test_task_collection_repr(self): - """Test string representation of TaskCollection.""" - collection = TaskCollection.from_list([{"query": "Q1"}, {"query": "Q2"}]) - - repr_str = repr(collection) - assert "TaskCollection" in repr_str - assert "2" in repr_str # Should mention number of tasks diff --git a/tests/test_core/test_task_protocol.py b/tests/test_core/test_task_protocol.py index f1c7a5f..28a7e24 100644 --- a/tests/test_core/test_task_protocol.py +++ b/tests/test_core/test_task_protocol.py @@ -5,7 +5,7 @@ """ import pytest -from maseval import Task, TaskCollection +from maseval import Task, TaskQueue from maseval.core.task import TaskProtocol, TimeoutAction @@ -97,11 +97,11 @@ def test_task_custom_protocol(self): assert task.protocol.timeout_seconds == 30.0 assert task.protocol.priority == 5 - def test_task_collection_preserves_protocol(self): - """TaskCollection should preserve protocol on tasks.""" + def test_task_queue_preserves_protocol(self): + """TaskQueue should preserve protocol on tasks.""" task1 = Task(query="Q1", protocol=TaskProtocol(priority=1)) task2 = Task(query="Q2", protocol=TaskProtocol(priority=2)) - tasks = TaskCollection([task1, task2]) + tasks = TaskQueue([task1, task2]) first_task: Task = tasks[0] # type: ignore[assignment] second_task: Task = tasks[1] # type: ignore[assignment] From 692d6d2549c2da88fa4ec18d651a7cd33ce81adf Mon Sep 17 00:00:00 2001 From: cemde Date: Sat, 6 Dec 2025 00:55:21 +0000 Subject: [PATCH 06/25] fixed typing issue --- examples/introduction/tutorial.ipynb | 26 +++++++++++++------------- maseval/core/task.py | 8 +++++++- tests/test_core/test_context.py | 11 +++++++---- tests/test_core/test_registry.py | 6 +++--- 4 files changed, 30 insertions(+), 21 deletions(-) diff --git a/examples/introduction/tutorial.ipynb b/examples/introduction/tutorial.ipynb index afbc50d..6d51244 100644 --- a/examples/introduction/tutorial.ipynb +++ b/examples/introduction/tutorial.ipynb @@ -9,13 +9,13 @@ "\n", "[![Open Notebook on GitHub](https://img.shields.io/badge/Open%20Notebook%20on-GitHub-blue?logo=github)](https://github.com/parameterlab/MASEval/blob/main/examples/introduction/tutorial.ipynb)\n", "\n", - "This notebook is available as a Jupyter notebook — clone the repo and run it yourself!\n", + "This notebook is available as a Jupyter notebook \u2014 clone the repo and run it yourself!\n", "\n", "## What You'll Learn\n", "\n", - "- **Build your first agent** — Create tools and agents with smolagents\n", - "- **Run a minimal benchmark** — One task, one agent, end-to-end\n", - "- **Understand the core abstractions** — Tasks, Environments, Evaluators working together\n", + "- **Build your first agent** \u2014 Create tools and agents with smolagents\n", + "- **Run a minimal benchmark** \u2014 One task, one agent, end-to-end\n", + "- **Understand the core abstractions** \u2014 Tasks, Environments, Evaluators working together\n", "\n", "\n", "This tutorial first introduces [`smolagents`](https://huggingface.co/docs/smolagents/en/index) as introduction to agents. Then it provides a super small single task benchmark." @@ -634,13 +634,13 @@ "metadata": {}, "outputs": [], "source": [ - "\"# Create benchmark instance with agent configuration\\n\",\n", - " \"agent_data = {\\\"model_id\\\": \\\"gemini/gemini-2.5-flash\\\", \\\"temperature\\\": 0.7}\\n\",\n", - " \"\\n\",\n", - " \"benchmark = SimpleBenchmark(agent_data=agent_data, progress_bar=False)\\n\",\n", - " \"\\n\",\n", - " \"# Create task queue\\n\",\n", - " \"tasks = TaskQueue([task])\\n\",\n", + "# Create benchmark instance with agent configuration\n", + "agent_data = {\"model_id\": \"gemini/gemini-2.5-flash\", \"temperature\": 0.7}\n", + "\n", + "benchmark = SimpleBenchmark(agent_data=agent_data, progress_bar=False)\n", + "\n", + "# Create task queue\n", + "tasks = TaskQueue([task])\n", "\n", "# Run the benchmark\n", "print(\"Running benchmark...\\n\")\n", @@ -715,7 +715,7 @@ "\n", "## Next Steps\n", "\n", - "1. **Try the Five-A-Day Benchmark notebook** — A production-ready example with multi-agent systems and diverse evaluators\n", + "1. **Try the Five-A-Day Benchmark notebook** \u2014 A production-ready example with multi-agent systems and diverse evaluators\n", "2. Create your own custom evaluators for your specific use case\n", "3. Experiment with different agent frameworks (LangGraph, LlamaIndex)\n", "4. Add callbacks for logging and tracing\n", @@ -745,4 +745,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/maseval/core/task.py b/maseval/core/task.py index 8648fa9..dff1d1b 100644 --- a/maseval/core/task.py +++ b/maseval/core/task.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, overload from uuid import UUID, uuid4 from collections.abc import Sequence from typing import Iterable, List, Union, Iterator, Optional @@ -108,6 +108,12 @@ def __len__(self) -> int: """Return the total number of tasks in the queue.""" return len(self._tasks) + @overload + def __getitem__(self, idx: int) -> Task: ... + + @overload + def __getitem__(self, idx: slice) -> "BaseTaskQueue": ... + def __getitem__(self, idx: Union[int, slice]) -> Union[Task, "BaseTaskQueue"]: """Get a task by index or a slice of tasks. diff --git a/tests/test_core/test_context.py b/tests/test_core/test_context.py index d50e467..2f39170 100644 --- a/tests/test_core/test_context.py +++ b/tests/test_core/test_context.py @@ -79,8 +79,9 @@ def test_check_timeout_raises_on_expiry(self): with pytest.raises(TaskTimeoutError) as exc_info: context.check_timeout() - assert exc_info.value.timeout == 0.01 - assert exc_info.value.elapsed >= 0.01 + error: TaskTimeoutError = exc_info.value # type: ignore[assignment] + assert error.timeout == 0.01 + assert error.elapsed >= 0.01 def test_check_timeout_includes_partial_traces(self): """TaskTimeoutError should include partial traces if set.""" @@ -93,7 +94,8 @@ def test_check_timeout_includes_partial_traces(self): with pytest.raises(TaskTimeoutError) as exc_info: context.check_timeout() - assert exc_info.value.partial_traces == partial_traces + error: TaskTimeoutError = exc_info.value # type: ignore[assignment] + assert error.partial_traces == partial_traces def test_check_timeout_no_traces_if_not_set(self): """TaskTimeoutError should have empty traces if not set.""" @@ -104,7 +106,8 @@ def test_check_timeout_no_traces_if_not_set(self): with pytest.raises(TaskTimeoutError) as exc_info: context.check_timeout() - assert exc_info.value.partial_traces == {} + error: TaskTimeoutError = exc_info.value # type: ignore[assignment] + assert error.partial_traces == {} def test_check_timeout_does_not_raise_before_deadline(self): """check_timeout() should not raise before deadline.""" diff --git a/tests/test_core/test_registry.py b/tests/test_core/test_registry.py index 925a5f1..5940877 100644 --- a/tests/test_core/test_registry.py +++ b/tests/test_core/test_registry.py @@ -8,7 +8,7 @@ import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict +from typing import Any, Dict, Optional from maseval.core.registry import ComponentRegistry from maseval.core.tracing import TraceableMixin @@ -21,7 +21,7 @@ class MockTraceableComponent(TraceableMixin): """Component that implements TraceableMixin for testing.""" - def __init__(self, name: str, trace_data: Dict[str, Any] = None): + def __init__(self, name: str, trace_data: Optional[Dict[str, Any]] = None): super().__init__() self._name = name self._trace_data = trace_data or {"component": name} @@ -36,7 +36,7 @@ def gather_traces(self) -> Dict[str, Any]: class MockConfigurableComponent(TraceableMixin, ConfigurableMixin): """Component that implements both TraceableMixin and ConfigurableMixin.""" - def __init__(self, name: str, config: Dict[str, Any] = None): + def __init__(self, name: str, config: Optional[Dict[str, Any]] = None): TraceableMixin.__init__(self) ConfigurableMixin.__init__(self) self._name = name From 282019785392d9a1a63b25e9651e256c595e4f95 Mon Sep 17 00:00:00 2001 From: cemde Date: Sat, 6 Dec 2025 09:41:06 +0000 Subject: [PATCH 07/25] simplified queue --- maseval/core/benchmark.py | 14 ---- maseval/core/callback.py | 2 +- maseval/core/task.py | 120 ++++++++++++++++------------------ tests/test_core/test_queue.py | 45 ++++++------- 4 files changed, 76 insertions(+), 105 deletions(-) diff --git a/maseval/core/benchmark.py b/maseval/core/benchmark.py index 289a6e2..02a3210 100644 --- a/maseval/core/benchmark.py +++ b/maseval/core/benchmark.py @@ -1165,21 +1165,14 @@ def _run_sequential( report = self._execute_task_repetition(task, agent_data, repeat_idx) self._append_report_safe(report) - queue.on_task_complete(task, report) self._invoke_callbacks("on_task_repeat_end", self, report) - if not queue.should_continue(): - return - # Callbacks at the end of each task task_reports = [r for r in self.reports if r["task_id"] == str(task.id)] last_report = task_reports[-1] if task_reports else {} self._invoke_callbacks("on_task_end", self, task, last_report) - if not queue.should_continue(): - return - def _run_parallel( self, queue: BaseTaskQueue, @@ -1269,7 +1262,6 @@ def submit_task_repeats(task: Task) -> None: } self._append_report_safe(report) - queue.on_task_complete(task, report) self._invoke_callbacks("on_task_repeat_end", self, report) @@ -1280,12 +1272,6 @@ def submit_task_repeats(task: Task) -> None: last_report = task_reports[-1] if task_reports else {} self._invoke_callbacks("on_task_end", self, task, last_report) - if not queue.should_continue(): - # Cancel remaining futures - for f in futures: - f.cancel() - return - # Submit more work if queue not exhausted if not queue_exhausted and len(futures) < max_workers: try: diff --git a/maseval/core/callback.py b/maseval/core/callback.py index 80c68dc..f352045 100644 --- a/maseval/core/callback.py +++ b/maseval/core/callback.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, List, TYPE_CHECKING +from typing import Any, Dict, List, Optional, TYPE_CHECKING from .tracing import TraceableMixin diff --git a/maseval/core/task.py b/maseval/core/task.py index dff1d1b..3660140 100644 --- a/maseval/core/task.py +++ b/maseval/core/task.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, Tuple, overload +from typing import Any, Dict, Tuple, overload, TYPE_CHECKING from uuid import UUID, uuid4 from collections.abc import Sequence from typing import Iterable, List, Union, Iterator, Optional @@ -8,6 +8,12 @@ from pathlib import Path from enum import Enum +if TYPE_CHECKING: + from .benchmark import Benchmark + +# Import BenchmarkCallback at runtime to enable inheritance +from .callback import BenchmarkCallback + class TimeoutAction(Enum): """Action to take when a task timeout occurs.""" @@ -70,15 +76,13 @@ class Task: class BaseTaskQueue(ABC, Sequence): """Abstract base class for task scheduling strategies. - BaseTaskQueue provides a sequence-like interface for task execution with hooks - for adaptive behavior based on task results. Concrete implementations can - reorder tasks, skip tasks, or terminate early based on execution outcomes. - - The queue yields Task objects for execution. After each task completes, - ``on_task_complete()`` is called with the result, allowing the queue to - adapt its scheduling strategy. + BaseTaskQueue provides a sequence-like interface for task execution. + Concrete implementations can reorder tasks, skip tasks, or terminate + early based on execution outcomes. Subclasses must implement ``__iter__`` to define the iteration order. + For adaptive behavior based on task results, use ``AdaptiveTaskQueue`` + which integrates with the benchmark callback system. Attributes: _tasks: Internal list of tasks. @@ -89,10 +93,7 @@ class BaseTaskQueue(ABC, Sequence): for task in queue: report = execute_task(task) - queue.on_task_complete(task, report) - - if not queue.should_continue(): - break + # Iterator handles termination automatically ``` """ @@ -137,31 +138,6 @@ def __iter__(self) -> Iterator[Task]: """ pass - def on_task_complete(self, task: Task, report: Dict[str, Any]) -> None: - """Called after each task completes. - - Override this method for adaptive scheduling behavior that responds - to task execution results (e.g., updating ability estimates, adjusting - priorities, or marking related tasks for skipping). - - Args: - task: The task that just completed. - report: The execution report containing status, traces, and eval results. - """ - pass - - def should_continue(self) -> bool: - """Whether to continue processing tasks. - - Default implementation returns True. Override for early termination - conditions (e.g., confidence threshold reached, maximum tasks processed, - or error limit exceeded). - - Returns: - True to continue processing, False to stop. - """ - return True - def append(self, task: Task) -> None: """Add a task to the end of the queue. @@ -330,13 +306,13 @@ def __iter__(self) -> Iterator[Task]: return iter(self._tasks) -class AdaptiveTaskQueue(BaseTaskQueue, ABC): +class AdaptiveTaskQueue(BaseTaskQueue, BenchmarkCallback, ABC): """Abstract base class for adaptive task scheduling. AdaptiveTaskQueue enables dynamic task ordering based on execution results. - It integrates with the benchmark callback system to receive notifications - after each task completes, allowing the queue to update internal state and - adjust the execution order. + It inherits from BenchmarkCallback to receive notifications after each task + completes, allowing the queue to update internal state and adjust the execution + order. Subclasses must implement: - ``_select_next_task()``: Choose the next task to execute @@ -348,7 +324,7 @@ class AdaptiveTaskQueue(BaseTaskQueue, ABC): - ``_stop_flag``: Flag to signal early termination When used with ``Benchmark.run()``, the queue is automatically registered - as a callback if it implements the ``BenchmarkCallback`` interface. + as a callback and receives ``on_task_repeat_end()`` notifications. Example: ```python @@ -360,9 +336,8 @@ def __init__(self, tasks: Iterable[Task]): self._ability_estimate = 0.0 def _select_next_task(self) -> Optional[Task]: - if not self._remaining: - return None # Select task with difficulty closest to current ability estimate + # (No need to check if _remaining is empty - guaranteed by base class) return min( self._remaining, key=lambda t: abs(t.protocol.priority - self._ability_estimate) @@ -395,40 +370,51 @@ def __iter__(self) -> Iterator[Task]: Continues until ``_select_next_task()`` returns None, ``_remaining`` is empty, or ``_stop_flag`` is set. + + Note: ``_select_next_task()`` is only called when ``_remaining`` is + non-empty, so implementers don't need to check for empty list. """ - while self._remaining and not self._stop_flag: + while not self._stop_flag and self._remaining: next_task = self._select_next_task() if next_task is not None: yield next_task else: break - def on_task_complete(self, task: Task, report: Dict[str, Any]) -> None: - """Update state based on task result. + def on_task_repeat_end(self, benchmark: "Benchmark", report: Dict[str, Any]) -> None: + """BenchmarkCallback hook called after each task repetition completes. - Moves the task from ``_remaining`` to ``_completed`` and calls - ``_update_state()`` to let the subclass update its internal model. + This method extracts the task from the report, moves it from + ``_remaining`` to ``_completed``, and calls ``_update_state()`` + to let the subclass update its adaptive model. Args: - task: The task that just completed. - report: The execution report. + benchmark: The benchmark instance (unused in this implementation). + report: The execution report containing task_id and results. """ - # Find and move task from remaining to completed + # Extract task from report + task_id_str = report.get("task_id") + if task_id_str is None: + return + + # Find the task in remaining list + task = None for i, t in enumerate(self._remaining): - if t.id == task.id: - self._completed.append((self._remaining.pop(i), report)) + if str(t.id) == task_id_str: + task = self._remaining.pop(i) + self._completed.append((task, report)) break - # Let subclass update its state - self._update_state(task, report) - - def should_continue(self) -> bool: - """Check if we should continue based on stopping criteria. + # If not found in remaining, check completed (for n_task_repeats > 1) + if task is None: + for t, _ in self._completed: + if str(t.id) == task_id_str: + task = t + break - Returns: - True if stop flag is not set and tasks remain, False otherwise. - """ - return not self._stop_flag and len(self._remaining) > 0 + # Update subclass state + if task is not None: + self._update_state(task, report) def stop(self) -> None: """Signal that no more tasks should be processed. @@ -445,8 +431,14 @@ def _select_next_task(self) -> Optional[Task]: Implement this method to define your adaptive selection algorithm (e.g., IRT-based selection, uncertainty sampling, bandit algorithms). + **Guaranteed precondition**: This method is only called when + ``self._remaining`` is non-empty, so you don't need to check for + an empty list. You can safely assume at least one task is available. + Returns: - The next Task to execute, or None if no suitable task is available. + The next Task to execute from ``self._remaining``, or None to + signal early termination (e.g., if no suitable task meets your + selection criteria). """ pass diff --git a/tests/test_core/test_queue.py b/tests/test_core/test_queue.py index 558c891..cf0d9f0 100644 --- a/tests/test_core/test_queue.py +++ b/tests/test_core/test_queue.py @@ -296,25 +296,25 @@ class TestAdaptiveTaskQueue: """Tests for AdaptiveTaskQueue adaptive behavior.""" def test_basic_iteration_with_completion(self, simple_tasks): - """AdaptiveTaskQueue should yield all tasks when on_task_complete is called.""" + """AdaptiveTaskQueue should yield all tasks when on_task_repeat_end is called.""" queue = ConcreteAdaptiveQueue(simple_tasks) count = 0 for task in queue: count += 1 - # Must call on_task_complete to progress to next task - queue.on_task_complete(task, {"status": "success"}) + # Simulate callback from benchmark + queue.on_task_repeat_end(None, {"task_id": str(task.id), "status": "success"}) # type: ignore[arg-type] assert count == 3 - def test_on_task_complete_moves_to_completed(self, simple_tasks): - """on_task_complete should move task to completed list.""" + def test_on_task_repeat_end_moves_to_completed(self, simple_tasks): + """on_task_repeat_end should move task to completed list.""" queue = ConcreteAdaptiveQueue(simple_tasks) task = next(iter(queue)) assert len(queue._completed) == 0 - queue.on_task_complete(task, {"status": "success"}) + queue.on_task_repeat_end(None, {"task_id": str(task.id), "status": "success"}) # type: ignore[arg-type] assert len(queue._completed) == 1 assert queue._completed[0][0].id == task.id @@ -330,21 +330,22 @@ def test_stop_terminates_iteration(self, simple_tasks): assert len(items) == 1 - def test_should_continue_false_after_stop(self, simple_tasks): - """should_continue() should return False after stop().""" + def test_stop_sets_flag(self, simple_tasks): + """stop() should set the internal stop flag.""" queue = ConcreteAdaptiveQueue(simple_tasks) - assert queue.should_continue() is True + assert queue._stop_flag is False queue.stop() - assert queue.should_continue() is False + assert queue._stop_flag is True - def test_should_continue_false_when_empty(self): - """should_continue() should return False when no pending tasks.""" + def test_iterator_stops_when_empty(self): + """Iterator should stop when no pending tasks.""" queue = ConcreteAdaptiveQueue([]) - assert queue.should_continue() is False + tasks_yielded = list(queue) + assert len(tasks_yielded) == 0 def test_remaining_decreases_after_completion(self, simple_tasks): """Remaining list should shrink as tasks complete.""" @@ -353,7 +354,7 @@ def test_remaining_decreases_after_completion(self, simple_tasks): assert len(queue._remaining) == 3 task = next(iter(queue)) - queue.on_task_complete(task, {"status": "success"}) + queue.on_task_repeat_end(None, {"task_id": str(task.id), "status": "success"}) # type: ignore[arg-type] assert len(queue._remaining) == 2 assert len(queue._completed) == 1 @@ -366,17 +367,9 @@ def test_remaining_decreases_after_completion(self, simple_tasks): class TestQueueCallbacks: """Tests for queue callback mechanisms.""" - def test_on_task_complete_callable_without_error(self, simple_tasks): - """on_task_complete should be callable without error.""" + def test_sequential_queue_iterates_all_tasks(self, simple_tasks): + """SequentialTaskQueue should iterate through all tasks.""" queue = SequentialTaskQueue(simple_tasks) - for task in queue: - # SequentialTaskQueue's on_task_complete is a no-op, but should not raise - queue.on_task_complete(task, {"status": "success"}) - - def test_should_continue_always_true_for_sequential(self, simple_tasks): - """SequentialTaskQueue should always return True for should_continue.""" - queue = SequentialTaskQueue(simple_tasks) - - for task in queue: - assert queue.should_continue() is True + tasks_yielded = list(queue) + assert len(tasks_yielded) == len(simple_tasks) From b719c5a3c993893e371f8594556027d6f11040dd Mon Sep 17 00:00:00 2001 From: cemde Date: Sat, 6 Dec 2025 09:46:37 +0000 Subject: [PATCH 08/25] type hinting fixes --- maseval/core/task.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/maseval/core/task.py b/maseval/core/task.py index 3660140..83c6166 100644 --- a/maseval/core/task.py +++ b/maseval/core/task.py @@ -310,15 +310,21 @@ class AdaptiveTaskQueue(BaseTaskQueue, BenchmarkCallback, ABC): """Abstract base class for adaptive task scheduling. AdaptiveTaskQueue enables dynamic task ordering based on execution results. - It inherits from BenchmarkCallback to receive notifications after each task - completes, allowing the queue to update internal state and adjust the execution - order. + It inherits from BenchmarkCallback to integrate with the benchmark's callback + system, creating a clean bidirectional communication model: + + - **Benchmark → Queue**: Via iterator protocol (``for task in queue``) + - **Queue → Benchmark**: Via callback (``on_task_repeat_end()``) + + The queue automatically moves completed tasks from ``_remaining`` to + ``_completed`` and calls ``_update_state()`` to let subclasses adapt their + scheduling strategy based on task results. Subclasses must implement: - ``_select_next_task()``: Choose the next task to execute - ``_update_state()``: Update internal model after task completion - The queue maintains: + Internal state: - ``_remaining``: Tasks not yet executed - ``_completed``: Completed tasks paired with their reports - ``_stop_flag``: Flag to signal early termination @@ -421,6 +427,10 @@ def stop(self) -> None: Call this from ``_update_state()`` or ``_select_next_task()`` to trigger early termination (e.g., when confidence threshold is reached). + + The ``_stop_flag`` is checked in ``__iter__``, which will stop yielding + tasks and naturally terminate the benchmark's iteration loop via Python's + iterator protocol. """ self._stop_flag = True From ea4b00c52e45bdd8899f8baa7b04f6d0b8cd387d Mon Sep 17 00:00:00 2001 From: cemde Date: Sat, 6 Dec 2025 10:19:43 +0000 Subject: [PATCH 09/25] fixed tests --- maseval/core/benchmark.py | 54 ++- .../test_callback_error_handling.py | 429 ++++++++++++++++++ .../test_benchmark/test_parallel_execution.py | 21 +- 3 files changed, 489 insertions(+), 15 deletions(-) create mode 100644 tests/test_core/test_benchmark/test_callback_error_handling.py diff --git a/maseval/core/benchmark.py b/maseval/core/benchmark.py index 02a3210..23392a1 100644 --- a/maseval/core/benchmark.py +++ b/maseval/core/benchmark.py @@ -5,6 +5,7 @@ from enum import Enum import warnings import traceback +import logging from .evaluator import Evaluator from .task import Task, BaseTaskQueue, SequentialTaskQueue @@ -410,22 +411,53 @@ def collect_all_configs(self) -> Dict[str, Any]: """ return self._registry.collect_configs() - def _invoke_callbacks(self, method_name: str, *args, **kwargs) -> None: + def _invoke_callbacks( + self, method_name: str, *args, suppress_errors: bool = True, **kwargs + ) -> List[Exception]: """Invoke a callback method on all registered callbacks (thread-safe). This method serializes all callback invocations using an internal lock, so users don't need to implement thread-safe callbacks. + Callback errors are caught and logged by default to prevent one failing + callback from disrupting the entire benchmark run. This is especially + important in parallel execution where callback failures could otherwise + cause difficult-to-debug issues. + Args: method_name: Name of the callback method to invoke (e.g., "on_task_start"). *args: Positional arguments to pass to the callback method. + suppress_errors: If True (default), catch and log callback errors instead + of propagating them. If False, first callback error will be raised. **kwargs: Keyword arguments to pass to the callback method. + + Returns: + List of exceptions that occurred during callback invocation (empty if none). + + Raises: + Exception: First callback exception if suppress_errors=False. """ + errors: List[Exception] = [] + logger = logging.getLogger(__name__) + with self._callback_lock: for cb in self.callbacks: method = getattr(cb, method_name, None) if method is not None: - method(*args, **kwargs) + try: + method(*args, **kwargs) + except Exception as e: + if not suppress_errors: + raise + + # Log error with full context + logger.error( + f"Callback {cb.__class__.__name__}.{method_name}() failed: {e}", + exc_info=True, + ) + errors.append(e) + + return errors def _append_report_safe(self, report: Dict[str, Any]) -> None: """Append a report to the reports list (thread-safe). @@ -1212,18 +1244,20 @@ def submit_task_repeats(task: Task) -> None: # Submit initial batch from queue submitted_tasks: List[Task] = [] - for task in queue: - submit_task_repeats(task) - submitted_tasks.append(task) + queue_iter = iter(queue) # Create iterator once + queue_exhausted = False - # Limit initial submission to avoid over-committing - if len(futures) >= max_workers * 2: - break + # Submit initial batch + try: + while len(futures) < max_workers * 2: + task = next(queue_iter) + submit_task_repeats(task) + submitted_tasks.append(task) + except StopIteration: + queue_exhausted = True # Process completions completed_task_ids: set = set() - queue_iter = iter(queue) - queue_exhausted = len(submitted_tasks) >= len(queue) while futures: # Wait for at least one completion diff --git a/tests/test_core/test_benchmark/test_callback_error_handling.py b/tests/test_core/test_benchmark/test_callback_error_handling.py new file mode 100644 index 0000000..3fb0d68 --- /dev/null +++ b/tests/test_core/test_benchmark/test_callback_error_handling.py @@ -0,0 +1,429 @@ +"""Tests for callback error handling in benchmark execution. + +These tests verify that callback exceptions are properly isolated and logged, +preventing one failing callback from disrupting the entire benchmark run. +""" + +import pytest +import logging +from typing import List + +from maseval import ( + BenchmarkCallback, + Task, + TaskQueue, +) +from conftest import DummyBenchmark + + +# ==================== Test Fixtures ==================== + + +class FailingCallback(BenchmarkCallback): + """Callback that raises an exception on specific methods.""" + + def __init__(self, fail_on: str = "on_task_start"): + self.fail_on = fail_on + self.call_count = 0 + + def on_task_start(self, benchmark, task): + if self.fail_on == "on_task_start": + raise RuntimeError("Intentional failure in on_task_start") + + def on_task_repeat_start(self, benchmark, task, repeat_idx): + if self.fail_on == "on_task_repeat_start": + raise ValueError("Intentional failure in on_task_repeat_start") + + def on_task_repeat_end(self, benchmark, report): + self.call_count += 1 + if self.fail_on == "on_task_repeat_end": + raise TypeError("Intentional failure in on_task_repeat_end") + + def on_task_end(self, benchmark, task, result): + if self.fail_on == "on_task_end": + raise KeyError("Intentional failure in on_task_end") + + +class TrackingCallback(BenchmarkCallback): + """Callback that tracks which methods were called.""" + + def __init__(self): + self.calls: List[str] = [] + + def on_run_start(self, benchmark): + self.calls.append("run_start") + + def on_task_start(self, benchmark, task): + self.calls.append(f"task_start:{task.query}") + + def on_task_repeat_start(self, benchmark, task, repeat_idx): + self.calls.append(f"repeat_start:{task.query}:{repeat_idx}") + + def on_task_repeat_end(self, benchmark, report): + self.calls.append(f"repeat_end:{report['task_id'][:8]}") + + def on_task_end(self, benchmark, task, result): + self.calls.append(f"task_end:{task.query}") + + def on_run_end(self, benchmark, results): + self.calls.append("run_end") + + +@pytest.fixture +def simple_tasks(): + """Create simple tasks for testing.""" + return TaskQueue.from_list([ + {"query": "Task 1", "environment_data": {}}, + {"query": "Task 2", "environment_data": {}}, + ]) + + +# ==================== Error Suppression Tests ==================== + + +@pytest.mark.core +class TestCallbackErrorSuppression: + """Tests for callback error suppression in sequential execution.""" + + def test_failing_callback_does_not_stop_execution(self, simple_tasks, caplog): + """A failing callback should not prevent benchmark from completing.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_start") + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb], + ) + + # Should complete despite callback failure + reports = benchmark.run(simple_tasks) + + assert len(reports) == 2 + assert all(r["status"] == "success" for r in reports) + + # Error should be logged + assert "Callback" in caplog.text + assert "on_task_start" in caplog.text + assert "Intentional failure" in caplog.text + + def test_multiple_callbacks_one_fails_others_continue(self, simple_tasks, caplog): + """Other callbacks should continue even if one fails.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_repeat_end") + tracking_cb = TrackingCallback() + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb, tracking_cb], + ) + + reports = benchmark.run(simple_tasks) + + # Execution completes + assert len(reports) == 2 + + # Tracking callback still received all events + assert "run_start" in tracking_cb.calls + assert "task_start:Task 1" in tracking_cb.calls + assert "task_start:Task 2" in tracking_cb.calls + assert "run_end" in tracking_cb.calls + + # Error logged + assert "on_task_repeat_end" in caplog.text + + def test_callback_fails_on_every_task(self, simple_tasks, caplog): + """Execution continues even if callback fails on every task.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_repeat_end") + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb], + ) + + reports = benchmark.run(simple_tasks) + + # All tasks complete + assert len(reports) == 2 + assert all(r["status"] == "success" for r in reports) + + # Callback was attempted for each task (even though it failed) + assert failing_cb.call_count == 2 + + def test_callback_error_in_on_run_start(self, simple_tasks, caplog): + """Benchmark continues if callback fails in on_run_start.""" + caplog.set_level(logging.ERROR) + + class RunStartFailer(BenchmarkCallback): + def on_run_start(self, benchmark): + raise RuntimeError("Failed at run start") + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[RunStartFailer()], + ) + + reports = benchmark.run(simple_tasks) + + assert len(reports) == 2 + assert "on_run_start" in caplog.text + + def test_callback_error_in_on_run_end(self, simple_tasks, caplog): + """Benchmark completes and logs error if on_run_end fails.""" + caplog.set_level(logging.ERROR) + + class RunEndFailer(BenchmarkCallback): + def on_run_end(self, benchmark, results): + raise RuntimeError("Failed at run end") + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[RunEndFailer()], + ) + + reports = benchmark.run(simple_tasks) + + # Reports are generated (run_end happens after report collection) + assert len(reports) == 2 + assert "on_run_end" in caplog.text + + +# ==================== Parallel Execution Error Handling Tests ==================== + + +@pytest.mark.core +class TestCallbackErrorHandlingParallel: + """Tests for callback error handling in parallel execution.""" + + def test_callback_error_in_parallel_execution(self, simple_tasks, caplog): + """Callback errors in parallel execution should not crash workers.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_repeat_end") + tracking_cb = TrackingCallback() + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb, tracking_cb], + ) + + # Run in parallel + reports = benchmark.run(simple_tasks, max_workers=2) + + # All tasks complete + assert len(reports) == 2 + assert all(r["status"] == "success" for r in reports) + + # Tracking callback still received events + assert len(tracking_cb.calls) > 0 + + # Errors logged + assert "on_task_repeat_end" in caplog.text + + def test_multiple_parallel_tasks_with_failing_callback(self, caplog): + """Callback failures should not interfere across parallel workers.""" + caplog.set_level(logging.ERROR) + + tasks = TaskQueue.from_list([ + {"query": f"Task {i}", "environment_data": {}} + for i in range(5) + ]) + + failing_cb = FailingCallback(fail_on="on_task_repeat_start") + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb], + ) + + reports = benchmark.run(tasks, max_workers=3) + + # All 5 tasks complete despite callback failures + assert len(reports) == 5 + assert all(r["status"] == "success" for r in reports) + + # Multiple errors logged (one per task) + error_count = caplog.text.count("on_task_repeat_start") + assert error_count >= 5 + + +# ==================== Error Return Value Tests ==================== + + +@pytest.mark.core +class TestCallbackErrorReturnValues: + """Tests for _invoke_callbacks error return values.""" + + def test_invoke_callbacks_returns_empty_list_on_success(self, simple_tasks): + """_invoke_callbacks should return empty list when no errors occur.""" + tracking_cb = TrackingCallback() + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[tracking_cb], + ) + + # Manually invoke callbacks to test return value + errors = benchmark._invoke_callbacks("on_run_start", benchmark) + + assert errors == [] + assert "run_start" in tracking_cb.calls + + def test_invoke_callbacks_returns_error_list_on_failure(self, simple_tasks): + """_invoke_callbacks should return list of exceptions when callbacks fail.""" + failing_cb1 = FailingCallback(fail_on="on_task_start") + failing_cb2 = FailingCallback(fail_on="on_task_start") + tracking_cb = TrackingCallback() + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb1, tracking_cb, failing_cb2], + ) + + task = Task(query="Test", environment_data={}) + errors = benchmark._invoke_callbacks("on_task_start", benchmark, task) + + # Two errors returned (from two failing callbacks) + assert len(errors) == 2 + assert all(isinstance(e, RuntimeError) for e in errors) + + # Tracking callback still ran + assert "task_start:Test" in tracking_cb.calls + + def test_invoke_callbacks_with_suppress_false_raises(self, simple_tasks): + """With suppress_errors=False, first exception should be raised.""" + failing_cb = FailingCallback(fail_on="on_task_start") + tracking_cb = TrackingCallback() + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb, tracking_cb], + ) + + task = Task(query="Test", environment_data={}) + + with pytest.raises(RuntimeError, match="Intentional failure"): + benchmark._invoke_callbacks( + "on_task_start", + benchmark, + task, + suppress_errors=False, + ) + + # Tracking callback was not called (execution stopped at first error) + assert len(tracking_cb.calls) == 0 + + +# ==================== Different Exception Types Tests ==================== + + +@pytest.mark.core +class TestCallbackExceptionTypes: + """Tests for handling different exception types from callbacks.""" + + def test_value_error_in_callback(self, simple_tasks, caplog): + """ValueError from callback should be caught and logged.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_repeat_start") + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb], + ) + + reports = benchmark.run(simple_tasks) + + assert len(reports) == 2 + assert "ValueError" in caplog.text + + def test_type_error_in_callback(self, simple_tasks, caplog): + """TypeError from callback should be caught and logged.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_repeat_end") + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb], + ) + + reports = benchmark.run(simple_tasks) + + assert len(reports) == 2 + assert "TypeError" in caplog.text + + def test_key_error_in_callback(self, simple_tasks, caplog): + """KeyError from callback should be caught and logged.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_end") + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb], + ) + + reports = benchmark.run(simple_tasks) + + assert len(reports) == 2 + assert "KeyError" in caplog.text + + +# ==================== Integration Tests ==================== + + +@pytest.mark.core +class TestCallbackErrorHandlingIntegration: + """Integration tests for callback error handling with real scenarios.""" + + def test_failing_callback_with_repeats(self, caplog): + """Callback errors should be handled correctly with n_task_repeats > 1.""" + caplog.set_level(logging.ERROR) + + tasks = TaskQueue.from_list([{"query": "Task", "environment_data": {}}]) + + failing_cb = FailingCallback(fail_on="on_task_repeat_end") + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + n_task_repeats=3, + callbacks=[failing_cb], + ) + + reports = benchmark.run(tasks) + + # 3 reports (one per repeat) + assert len(reports) == 3 + + # Callback attempted 3 times + assert failing_cb.call_count == 3 + + # 3 errors logged + error_count = caplog.text.count("on_task_repeat_end") + assert error_count >= 3 + + def test_mixed_callbacks_some_fail_some_succeed(self, simple_tasks, caplog): + """Mixed scenario: some callbacks fail, others succeed.""" + caplog.set_level(logging.ERROR) + + failing_cb1 = FailingCallback(fail_on="on_task_start") + tracking_cb1 = TrackingCallback() + failing_cb2 = FailingCallback(fail_on="on_task_repeat_end") + tracking_cb2 = TrackingCallback() + + benchmark = DummyBenchmark( + agent_data={"model": "test"}, + callbacks=[failing_cb1, tracking_cb1, failing_cb2, tracking_cb2], + ) + + reports = benchmark.run(simple_tasks) + + # Execution completes + assert len(reports) == 2 + + # Both tracking callbacks received all events + assert len(tracking_cb1.calls) > 0 + assert len(tracking_cb2.calls) > 0 + assert tracking_cb1.calls == tracking_cb2.calls + + # Both types of errors logged + assert "on_task_start" in caplog.text + assert "on_task_repeat_end" in caplog.text diff --git a/tests/test_core/test_benchmark/test_parallel_execution.py b/tests/test_core/test_benchmark/test_parallel_execution.py index 186e908..0e72da3 100644 --- a/tests/test_core/test_benchmark/test_parallel_execution.py +++ b/tests/test_core/test_benchmark/test_parallel_execution.py @@ -231,8 +231,14 @@ def on_task_repeat_end(self, benchmark, report): statuses = {d["status"] for d in received_data} assert statuses == {"success"} - def test_callback_exception_propagates(self, parallel_tasks): - """Callback exceptions propagate (current behavior).""" + def test_callback_exceptions_suppressed_by_default(self): + """Callback exceptions are suppressed by default to prevent disruption.""" + # Create fresh tasks for this test + tasks = TaskQueue.from_list([ + {"query": f"Task {i}", "environment_data": {}} + for i in range(5) + ]) + call_count = [0] class FailingCallback(BenchmarkCallback): @@ -246,9 +252,14 @@ def on_task_repeat_end(self, benchmark, report): callbacks=[FailingCallback()], ) - # Current behavior: callback exceptions propagate - with pytest.raises(RuntimeError, match="Intentional failure"): - benchmark.run(parallel_tasks, max_workers=2) + # New behavior: callback exceptions are suppressed by default + # This prevents one failing callback from disrupting parallel execution + reports = benchmark.run(tasks, max_workers=2) + + # Execution completes despite callback failure + assert len(reports) == 5 + # Callback was called multiple times (not stopped at failure) + assert call_count[0] >= 2 # ==================== Concurrency Verification Tests ==================== From 0d933e5cc7448a8b3cf43dfc76f587c90eb5eadb Mon Sep 17 00:00:00 2001 From: cemde Date: Sat, 6 Dec 2025 12:17:52 +0000 Subject: [PATCH 10/25] updated tests --- maseval/benchmark/macs/macs.py | 8 +- maseval/core/benchmark.py | 123 +++++++++--------- tests/conftest.py | 9 +- tests/test_benchmarks/test_macs/conftest.py | 7 +- .../test_macs/test_macs_benchmark.py | 44 +++---- .../test_macs/test_macs_integration.py | 20 +-- .../test_automatic_registration.py | 14 +- .../test_benchmark_lifecycle.py | 82 ++++++------ .../test_callback_error_handling.py | 54 +++----- .../test_callback_orchestration.py | 20 ++- .../test_benchmark/test_config_collection.py | 44 +++---- .../test_benchmark/test_execution_loop.py | 30 ++--- .../test_benchmark/test_parallel_execution.py | 63 ++++----- .../test_progress_bar_integration.py | 21 ++- .../test_benchmark/test_trace_collection.py | 36 ++--- tests/test_core/test_evaluator.py | 24 ++-- tests/test_core/test_exceptions.py | 24 ++-- 17 files changed, 289 insertions(+), 334 deletions(-) diff --git a/maseval/benchmark/macs/macs.py b/maseval/benchmark/macs/macs.py index abda818..5ab834d 100644 --- a/maseval/benchmark/macs/macs.py +++ b/maseval/benchmark/macs/macs.py @@ -36,8 +36,8 @@ def get_model_adapter(self, model_id, **kwargs): return adapter # Run - benchmark = MyMACSBenchmark(agent_data=agent_config) - results = benchmark.run(tasks) + benchmark = MyMACSBenchmark() + results = benchmark.run(tasks, agent_data=agent_config) """ import json @@ -691,7 +691,6 @@ class MACSBenchmark(Benchmark): def __init__( self, - agent_data: Dict[str, Any], callbacks: Optional[List[Any]] = None, n_task_repeats: int = 1, max_invocations: int = 5, @@ -700,12 +699,11 @@ def __init__( """Initialize benchmark. Args: - agent_data: Agent configuration from load_agent_config(). callbacks: Benchmark callbacks n_task_repeats: Repetitions per task max_invocations: Maximum agent-user interaction rounds (default: 5 per MACS paper) """ - super().__init__(agent_data, callbacks, n_task_repeats, max_invocations, **kwargs) + super().__init__(callbacks=callbacks, n_task_repeats=n_task_repeats, max_invocations=max_invocations, **kwargs) def _get_tool_model_id(self, task: Task) -> str: """Get tool simulator model ID from task.environment_data. diff --git a/maseval/core/benchmark.py b/maseval/core/benchmark.py index 23392a1..ea21be8 100644 --- a/maseval/core/benchmark.py +++ b/maseval/core/benchmark.py @@ -99,21 +99,26 @@ def run_agents(self, agents, task, environment, query): # ... implement other abstract methods # Run the benchmark - benchmark = MyBenchmark(agent_data=config) - reports = benchmark.run(tasks=my_tasks) + config = {"model": "gpt-4", "temperature": 0.7} + benchmark = MyBenchmark() + reports = benchmark.run(tasks=my_tasks, agent_data=config) # Retry failed tasks elegantly (graceful task failure handling by default) failed_tasks = benchmark.get_failed_tasks() if len(failed_tasks) > 0: - retry_reports = benchmark.run(tasks=failed_tasks) + retry_reports = benchmark.run(tasks=failed_tasks, agent_data=config) + + # Parallel execution for I/O-bound workloads + benchmark = MyBenchmark(max_workers=4) + reports = benchmark.run(tasks=my_tasks, agent_data=config) # Or use strict mode for debugging (fail fast) benchmark = MyBenchmark( - agent_data=config, fail_on_task_error=True, fail_on_evaluation_error=True, fail_on_setup_error=True ) + reports = benchmark.run(tasks=my_tasks, agent_data=config) ``` The framework handles task iteration, repetitions for statistical robustness, callback @@ -123,21 +128,18 @@ def run_agents(self, agents, task, environment, query): def __init__( self, - agent_data: Dict[str, Any] | Iterable[Dict[str, Any]], callbacks: Optional[List[BenchmarkCallback]] = None, n_task_repeats: int = 1, max_invocations: int = 1, + max_workers: int = 1, fail_on_setup_error: bool = False, fail_on_task_error: bool = False, fail_on_evaluation_error: bool = False, progress_bar: bool | str = True, ): - """Initialize a benchmark with agent configurations. + """Initialize a benchmark with execution configuration. Args: - agent_data: Configuration for agents. Either a single dict applied to all tasks, or - an iterable of dicts with one configuration per task. Agent data typically includes - model parameters, agent architecture details, and tool specifications. callbacks: Optional list of callback handlers for monitoring execution, tracing messages, or collecting custom metrics during the benchmark run. n_task_repeats: Number of times to repeat each task. Useful for measuring variance in @@ -146,6 +148,9 @@ def __init__( For simple benchmarks, the default (1) means agents run once per task. For interactive benchmarks with user feedback loops, set higher (e.g., 5 for MACS) to allow multiple agent-user interaction rounds. + max_workers: Maximum number of parallel task executions. Default 1 (sequential). + Set higher for I/O-bound workloads (e.g., LLM API calls). This controls the + ThreadPoolExecutor worker count for concurrent task processing. fail_on_setup_error: If True, raise exceptions when setup fails (environment, agents, evaluators). If False (default), catch exceptions during setup and record them in the report with status SETUP_FAILED. This allows the benchmark to continue running remaining tasks even if setup fails. @@ -168,52 +173,32 @@ def __init__( ValueError: If n_task_repeats is less than 1. How to use: - Provide either a single agent configuration for all tasks, or task-specific configurations: + Configure execution settings at initialization: ```python - # Single config for all tasks - benchmark = MyBenchmark(agent_data={"model": "gpt-4", "temperature": 0.7}) + # Sequential execution (default) + benchmark = MyBenchmark() - # Task-specific configs (will be validated in run() based on task count) - benchmark = MyBenchmark( - agent_data=[ - {"model": "gpt-4", "config": "easy"}, - {"model": "gpt-4", "config": "hard"} - ] - ) - - # Enable failure-safe execution (default behavior) - benchmark = MyBenchmark( - agent_data=config, - fail_on_task_error=False, # Continue on task failures - fail_on_evaluation_error=False # Continue on evaluation failures - ) + # Parallel execution for faster I/O-bound workloads + benchmark = MyBenchmark(max_workers=4) # Strict mode - fail fast on any error (useful for debugging) benchmark = MyBenchmark( - agent_data=config, fail_on_task_error=True, fail_on_evaluation_error=True, fail_on_setup_error=True ) - # Progress bar configuration (automatically adds a callback) - benchmark = MyBenchmark(agent_data=config) # Default: adds TqdmProgressBarCallback - benchmark = MyBenchmark(agent_data=config, progress_bar=True) # Explicit: TqdmProgressBarCallback - benchmark = MyBenchmark(agent_data=config, progress_bar="tqdm") # Same as True - benchmark = MyBenchmark(agent_data=config, progress_bar="rich") # Uses RichProgressBarCallback - benchmark = MyBenchmark(agent_data=config, progress_bar=False) # No automatic callback + # Progress bar configuration + benchmark = MyBenchmark() # Default: adds TqdmProgressBarCallback + benchmark = MyBenchmark(progress_bar=True) # Explicit: TqdmProgressBarCallback + benchmark = MyBenchmark(progress_bar="rich") # Uses RichProgressBarCallback + benchmark = MyBenchmark(progress_bar=False) # No automatic callback - # Progress bar configuration (manually add a callback) - benchmark = MyBenchmark( - agent_data=config, - callbacks=[MyCustomProgressBarCallback()] # User-defined progress bar - ) + # Custom callbacks + benchmark = MyBenchmark(callbacks=[MyCustomProgressBarCallback()]) ``` """ - # Store agent_data as-is (will be normalized in run()) - self.agent_data = agent_data - # Initialize tasks to empty queue (will be set in run()) self.tasks: BaseTaskQueue = SequentialTaskQueue([]) @@ -238,8 +223,9 @@ def __init__( if self.n_task_repeats < 1: raise ValueError("n_task_repeats must be at least 1") - # Execution loop configuration + # Execution configuration self.max_invocations = max_invocations + self.max_workers = max_workers # Failure handling configuration self.fail_on_task_error = fail_on_task_error @@ -411,9 +397,7 @@ def collect_all_configs(self) -> Dict[str, Any]: """ return self._registry.collect_configs() - def _invoke_callbacks( - self, method_name: str, *args, suppress_errors: bool = True, **kwargs - ) -> List[Exception]: + def _invoke_callbacks(self, method_name: str, *args, suppress_errors: bool = True, **kwargs) -> List[Exception]: """Invoke a callback method on all registered callbacks (thread-safe). This method serializes all callback invocations using an internal lock, @@ -1318,7 +1302,7 @@ def submit_task_repeats(task: Task) -> None: def run( self, tasks: Union[Task, BaseTaskQueue, Iterable[Union[Task, dict]]], - max_workers: int = 1, + agent_data: Dict[str, Any] | Iterable[Dict[str, Any]], ) -> List[Dict[str, Any]]: """Initialize and execute the complete benchmark loop across all tasks. @@ -1331,8 +1315,9 @@ def run( When a BaseTaskQueue is provided, it controls the task ordering. AdaptiveTaskQueue subclasses are automatically registered as callbacks to receive task completion notifications. - max_workers: Maximum number of parallel task executions. Default 1 (sequential). - Set higher for I/O-bound workloads (e.g., LLM API calls). + agent_data: Configuration for agents. Either a single dict applied to all tasks, or + an iterable of dicts with one configuration per task. Agent data typically includes + model parameters, agent architecture details, and tool specifications. Returns: List of report dictionaries, one per task repetition. Each report contains: @@ -1401,8 +1386,8 @@ def run( ```python # Typical usage - benchmark = MyBenchmark(agent_data=config) - reports = benchmark.run(tasks=tasks) + benchmark = MyBenchmark() + reports = benchmark.run(tasks=tasks, agent_data=config) # Analyze results for report in reports: @@ -1411,14 +1396,27 @@ def run( print(f"Traces: {report['traces']}") # Parallel execution with 4 workers - reports = benchmark.run(tasks=tasks, max_workers=4) + benchmark = MyBenchmark(max_workers=4) + reports = benchmark.run(tasks=tasks, agent_data=config) + + # Single agent config for all tasks + reports = benchmark.run(tasks=tasks, agent_data={"model": "gpt-4"}) + + # Task-specific agent configs (must match task count) + reports = benchmark.run( + tasks=tasks, + agent_data=[ + {"model": "gpt-4", "difficulty": "easy"}, + {"model": "gpt-4", "difficulty": "hard"}, + ] + ) # Priority-based execution from maseval.core.task import PriorityTaskQueue for task in tasks: task.protocol.priority = compute_priority(task) queue = PriorityTaskQueue(tasks) - reports = benchmark.run(tasks=queue) + reports = benchmark.run(tasks=queue, agent_data=config) # Adaptive queue (auto-registered as callback) queue = MyAdaptiveTaskQueue(tasks) @@ -1441,7 +1439,7 @@ def run( self.tasks = queue # Build agent_data lookup (task_id -> agent_data) - agent_data_lookup = self._build_agent_data_lookup(queue) + agent_data_lookup = self._build_agent_data_lookup(queue, agent_data) # Clear reports from previous run() calls to prevent accumulation self.reports = [] @@ -1457,10 +1455,10 @@ def run( self._invoke_callbacks("on_run_start", self) # Execute based on max_workers - if max_workers == 1: + if self.max_workers == 1: self._run_sequential(queue, agent_data_lookup) else: - self._run_parallel(queue, agent_data_lookup, max_workers) + self._run_parallel(queue, agent_data_lookup, self.max_workers) # Callbacks at the end of the run self._invoke_callbacks("on_run_end", self, self.reports) @@ -1471,11 +1469,14 @@ def run( return self.reports - def _build_agent_data_lookup(self, tasks: BaseTaskQueue) -> Dict[str, Dict[str, Any]]: + def _build_agent_data_lookup( + self, tasks: BaseTaskQueue, agent_data: Dict[str, Any] | Iterable[Dict[str, Any]] + ) -> Dict[str, Dict[str, Any]]: """Build a mapping from task_id to agent_data configuration. Args: tasks: The task queue containing all tasks. + agent_data: Agent configuration(s) to map to tasks. Returns: Dict mapping task_id (string) to agent_data configuration. @@ -1483,12 +1484,12 @@ def _build_agent_data_lookup(self, tasks: BaseTaskQueue) -> Dict[str, Dict[str, Raises: ValueError: If agent_data is a list but doesn't match the number of tasks. """ - if isinstance(self.agent_data, dict): + if isinstance(agent_data, dict): # Single config - replicate for all tasks - return {str(task.id): cast(Dict[str, Any], self.agent_data) for task in tasks} + return {str(task.id): cast(Dict[str, Any], agent_data) for task in tasks} # List of configs - pair by position - agent_data_list = list(self.agent_data) + agent_data_list = list(agent_data) if len(agent_data_list) != len(tasks): raise ValueError( f"`agent_data` must either be a single dict or an iterable matching the number of tasks. " @@ -1527,8 +1528,8 @@ def get_failed_tasks( How to use: ```python # Run benchmark - benchmark = MyBenchmark(agent_data=config) - reports = benchmark.run(tasks=tasks) + benchmark = MyBenchmark() + reports = benchmark.run(tasks=tasks, agent_data=config) # Get all failed tasks (from internal state) failed = benchmark.get_failed_tasks() diff --git a/tests/conftest.py b/tests/conftest.py index 090a675..8fdb137 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -319,13 +319,14 @@ def dummy_task_queue(): @pytest.fixture def simple_benchmark(dummy_task_queue): - """Create a simple benchmark instance with tasks. + """Create a simple benchmark instance with tasks and agent_data. Returns: - tuple: (benchmark, tasks) - Call as benchmark.run(tasks) + tuple: (benchmark, tasks, agent_data) - Call as benchmark.run(tasks, agent_data=agent_data) """ - benchmark = DummyBenchmark(agent_data={"model": "test"}) - return benchmark, dummy_task_queue + benchmark = DummyBenchmark() + agent_data = {"model": "test"} + return benchmark, dummy_task_queue, agent_data @pytest.fixture diff --git a/tests/test_benchmarks/test_macs/conftest.py b/tests/test_benchmarks/test_macs/conftest.py index 8a061ad..081706b 100644 --- a/tests/test_benchmarks/test_macs/conftest.py +++ b/tests/test_benchmarks/test_macs/conftest.py @@ -97,14 +97,12 @@ class ConcreteMACSBenchmark(MACSBenchmark): def __init__( self, - agent_data: Dict[str, Any], model_factory: Optional[Any] = None, **kwargs: Any, ): """Initialize with optional model factory. Args: - agent_data: Agent configuration model_factory: Either a callable that takes a model name and returns a ModelAdapter, or a single ModelAdapter instance (for convenience in simple tests). If not provided, creates DummyModelAdapter instances. @@ -121,7 +119,7 @@ def __init__( else: # Single model instance - create a factory that always returns it self._model_factory = lambda model_name: model_factory - super().__init__(agent_data, **kwargs) + super().__init__(**kwargs) def get_model_adapter(self, model_id: str, **kwargs): """Create a model adapter for the given component. @@ -589,5 +587,6 @@ def macs_benchmark(sample_agent_data, dummy_model): """Create a MACS benchmark with dummy model for testing. Uses dummy_model from parent conftest.py. + Returns tuple (benchmark, agent_data) for use with run(). """ - return ConcreteMACSBenchmark(sample_agent_data, dummy_model) + return ConcreteMACSBenchmark(dummy_model), sample_agent_data diff --git a/tests/test_benchmarks/test_macs/test_macs_benchmark.py b/tests/test_benchmarks/test_macs/test_macs_benchmark.py index eeceddc..56e3896 100644 --- a/tests/test_benchmarks/test_macs/test_macs_benchmark.py +++ b/tests/test_benchmarks/test_macs/test_macs_benchmark.py @@ -27,11 +27,11 @@ class TestMACSBenchmarkSetup: """Tests for MACSBenchmark initialization and setup methods.""" def test_init_configures_benchmark(self, macs_model, sample_agent_data): - """Benchmark initializes with agent_data and optional params.""" + """Benchmark initializes with optional params.""" callbacks = [MagicMock()] - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model, callbacks=callbacks, n_task_repeats=3) + benchmark = ConcreteMACSBenchmark(macs_model, callbacks=callbacks, n_task_repeats=3) - assert benchmark.agent_data == sample_agent_data + # agent_data is now passed to run(), not __init__ assert benchmark.callbacks == callbacks assert benchmark.n_task_repeats == 3 @@ -41,13 +41,13 @@ def test_macs_default_max_invocations_is_five(self, macs_model, sample_agent_dat This is a MACS-specific default that differs from the base class default of 1. The MACS paper specifies up to 5 agent-user interaction rounds. """ - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) assert benchmark.max_invocations == 5 def test_setup_environment_creates_macs_environment(self, macs_model, sample_agent_data, sample_task): """setup_environment returns MACSEnvironment with tools.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) @@ -56,7 +56,7 @@ def test_setup_environment_creates_macs_environment(self, macs_model, sample_age def test_setup_user_creates_macs_user(self, macs_model, sample_agent_data, sample_task): """setup_user returns MACSUser with scenario from task.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) user = benchmark.setup_user(sample_agent_data, env, sample_task) @@ -66,7 +66,7 @@ def test_setup_user_creates_macs_user(self, macs_model, sample_agent_data, sampl def test_setup_user_handles_no_scenario(self, macs_model, sample_agent_data, sample_task_no_scenario): """Handles missing scenario gracefully.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task_no_scenario) user = benchmark.setup_user(sample_agent_data, env, sample_task_no_scenario) @@ -75,7 +75,7 @@ def test_setup_user_handles_no_scenario(self, macs_model, sample_agent_data, sam def test_setup_evaluators_creates_user_and_system(self, macs_model, sample_agent_data, sample_task): """Creates both user and system evaluators.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents = [MACSAgentAdapter()] @@ -114,7 +114,7 @@ class TestRunAgents: def test_run_agents_executes_agents_with_query(self, macs_model, sample_agent_data, sample_task): """Agents are executed with the query parameter.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents_list, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) @@ -134,7 +134,7 @@ def test_run_agents_uses_query_parameter_not_task_query(self, macs_model, sample This is critical for multi-turn interaction where the query changes between invocations (e.g., user's response becomes the next query). """ - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents_list, _ = benchmark.setup_agents(sample_agent_data, env, sample_task, None) @@ -150,7 +150,7 @@ def test_run_agents_uses_query_parameter_not_task_query(self, macs_model, sample def test_run_agents_returns_answer(self, macs_model, sample_agent_data, sample_task): """Returns final answer(s) as MessageHistory.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents_list, _ = benchmark.setup_agents(sample_agent_data, env, sample_task, None) @@ -164,7 +164,7 @@ def test_run_agents_returns_answer(self, macs_model, sample_agent_data, sample_t def test_run_agents_single_agent(self, macs_model, sample_agent_data, sample_task): """Single agent returns MessageHistory.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents_list, _ = benchmark.setup_agents(sample_agent_data, env, sample_task, None) @@ -176,9 +176,9 @@ def test_run_agents_multiple_agents(self, macs_model, sample_agent_data, sample_ """Multiple agents return list of answers.""" class MultiAgentBenchmark(MACSBenchmark): - def __init__(self, agent_data, model_factory, **kwargs): + def __init__(self, model_factory, **kwargs): self._model_factory = model_factory if callable(model_factory) else lambda _: model_factory - super().__init__(agent_data, **kwargs) + super().__init__(**kwargs) def get_model_adapter(self, model_id: str, **kwargs): return self._model_factory(model_id) @@ -194,7 +194,7 @@ def setup_agents( agent2: AgentAdapter = MACSAgentAdapter("agent2") return [agent1, agent2], {"agent1": agent1, "agent2": agent2} - benchmark = MultiAgentBenchmark(sample_agent_data, macs_model) + benchmark = MultiAgentBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents_list, _ = benchmark.setup_agents(sample_agent_data, env, sample_task, None) @@ -221,7 +221,7 @@ def test_evaluate_calls_both_evaluators(self, sample_agent_data, sample_task): '[{"assertion": "System assertion", "answer": "TRUE", "evidence": "OK"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) _, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) evaluators = benchmark.setup_evaluators(env, sample_task, list(agents_dict.values()), None) @@ -251,7 +251,7 @@ def test_evaluate_returns_aggregated_metrics(self, sample_agent_data, sample_tas '[{"assertion": "B", "answer": "TRUE", "evidence": "OK"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) _, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) evaluators = benchmark.setup_evaluators(env, sample_task, list(agents_dict.values()), None) @@ -281,7 +281,7 @@ def test_evaluate_overall_gsr(self, sample_agent_data, sample_task): '[{"assertion": "B", "answer": "FALSE", "evidence": "Fail"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) _, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) evaluators = benchmark.setup_evaluators(env, sample_task, list(agents_dict.values()), None) @@ -305,7 +305,7 @@ def test_evaluate_supervisor_gsr(self, sample_agent_data, sample_task): '[{"assertion": "B", "answer": "FALSE", "evidence": "Fail"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) _, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) evaluators = benchmark.setup_evaluators(env, sample_task, list(agents_dict.values()), None) @@ -327,7 +327,7 @@ def test_evaluate_combined_report(self, sample_agent_data, sample_task): '[{"assertion": "System B", "answer": "TRUE", "evidence": "OK"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) _, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) evaluators = benchmark.setup_evaluators(env, sample_task, list(agents_dict.values()), None) @@ -612,7 +612,7 @@ def test_full_task_execution(self, sample_agent_data, sample_task): '[{"assertion": "Database updated", "answer": "TRUE", "evidence": "Updated"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) # Setup phase env = benchmark.setup_environment(sample_agent_data, sample_task) @@ -646,7 +646,7 @@ def test_full_task_execution(self, sample_agent_data, sample_task): def test_benchmark_with_real_environment(self, sample_agent_data, sample_task): """Test with real MACSEnvironment tool creation.""" model = DummyModelAdapter(responses=['{"text": "Default response", "details": {}}']) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) diff --git a/tests/test_benchmarks/test_macs/test_macs_integration.py b/tests/test_benchmarks/test_macs/test_macs_integration.py index 3bb38a2..247eedc 100644 --- a/tests/test_benchmarks/test_macs/test_macs_integration.py +++ b/tests/test_benchmarks/test_macs/test_macs_integration.py @@ -44,7 +44,7 @@ def test_complete_task_lifecycle(self, sample_agent_data, travel_task): json.dumps([{"assertion": "Tool called", "answer": "TRUE", "evidence": "OK"}]), ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) # Setup phase env = benchmark.setup_environment(sample_agent_data, travel_task) @@ -111,7 +111,7 @@ def test_loaded_task_works_with_environment(self, macs_model, sample_agent_data) metadata={"scenario": "Travel booking scenario", "task_id": "task-000001"}, ) - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, task) assert "search" in env.tools @@ -199,8 +199,8 @@ def test_run_single_task_complete_pipeline(self, sample_agent_data, travel_task) ] ) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) - reports = benchmark.run([travel_task]) + benchmark = ConcreteMACSBenchmark(model) + reports = benchmark.run([travel_task], agent_data=sample_agent_data) # Verify complete report structure assert len(reports) == 1 @@ -222,8 +222,8 @@ def test_run_multiple_tasks(self, sample_agent_data, macs_task_queue): ] ) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) - reports = benchmark.run(macs_task_queue) + benchmark = ConcreteMACSBenchmark(model) + reports = benchmark.run(macs_task_queue, agent_data=sample_agent_data) assert len(reports) == len(macs_task_queue) for report in reports: @@ -241,8 +241,8 @@ def test_run_with_task_repeats(self, sample_agent_data, sample_task): ) n_repeats = 3 - benchmark = ConcreteMACSBenchmark(sample_agent_data, model, n_task_repeats=n_repeats) - reports = benchmark.run([sample_task]) + benchmark = ConcreteMACSBenchmark(model, n_task_repeats=n_repeats) + reports = benchmark.run([sample_task], agent_data=sample_agent_data) assert len(reports) == n_repeats for i, report in enumerate(reports): @@ -284,8 +284,8 @@ def on_run_end(self, benchmark, results): ) callback = TrackingCallback() - benchmark = ConcreteMACSBenchmark(sample_agent_data, model, callbacks=[callback]) - benchmark.run([sample_task]) + benchmark = ConcreteMACSBenchmark(model, callbacks=[callback]) + benchmark.run([sample_task], agent_data=sample_agent_data) # Verify callback sequence expected_order = ["run_start", "task_start", "repeat_start_0", "repeat_end_0", "task_end", "run_end"] diff --git a/tests/test_core/test_benchmark/test_automatic_registration.py b/tests/test_core/test_benchmark/test_automatic_registration.py index 486b973..4d51a32 100644 --- a/tests/test_core/test_benchmark/test_automatic_registration.py +++ b/tests/test_core/test_benchmark/test_automatic_registration.py @@ -23,7 +23,7 @@ def test_automatic_agent_registration(): tasks = TaskQueue.from_list([{"query": "test", "id": "1", "environment_data": {}}]) agent_data = {} - benchmark = DummyBenchmark(agent_data=agent_data) + benchmark = DummyBenchmark() # Before run, registry should be empty assert len(benchmark._registry._trace_registry) == 0 @@ -57,7 +57,7 @@ def test_duplicate_registration_detection(): Verifies that the ID-based tracking system detects when a component instance is registered multiple times with different names, preventing data confusion. """ - benchmark = DummyBenchmark(agent_data={}) + benchmark = DummyBenchmark() # Create a component model = DummyModelAdapter() @@ -84,7 +84,7 @@ def test_duplicate_registration_helpful_message(): Verifies that error message includes both the existing registration name and the attempted new name, plus mentions automatic registration. """ - benchmark = DummyBenchmark(agent_data={}) + benchmark = DummyBenchmark() # Create and register an agent agent = DummyAgent() @@ -108,7 +108,7 @@ def test_manual_registration_for_models(): Verifies that models are not automatically registered (unlike agents, environments, and users), requiring explicit register() calls. """ - benchmark = DummyBenchmark(agent_data={}) + benchmark = DummyBenchmark() # Create a model model = DummyModelAdapter() @@ -128,7 +128,7 @@ def test_component_id_tracking(): Verifies that benchmark maintains a Python id() to name mapping for detecting duplicate registrations of the same component instance. """ - benchmark = DummyBenchmark(agent_data={}) + benchmark = DummyBenchmark() # Create a component model = DummyModelAdapter() @@ -156,10 +156,10 @@ def test_registry_cleared_after_repetition(): ) agent_data = {} - benchmark = DummyBenchmark(agent_data=agent_data, n_task_repeats=2) + benchmark = DummyBenchmark(n_task_repeats=2) # Run the benchmark - benchmark.run(tasks) + benchmark.run(tasks, agent_data=agent_data) # After run completes, registry should be empty (cleared after last repetition) assert len(benchmark._registry._trace_registry) == 0 diff --git a/tests/test_core/test_benchmark/test_benchmark_lifecycle.py b/tests/test_core/test_benchmark/test_benchmark_lifecycle.py index f7b6536..654d077 100644 --- a/tests/test_core/test_benchmark/test_benchmark_lifecycle.py +++ b/tests/test_core/test_benchmark/test_benchmark_lifecycle.py @@ -15,10 +15,10 @@ class TestBenchmarkLifecycle: def test_benchmark_complete_run_single_task(self, simple_benchmark): """Test that a benchmark completes successfully with a single task.""" - benchmark, tasks = simple_benchmark + benchmark, tasks, agent_data = simple_benchmark # Run benchmark - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data=agent_data) # Verify we got one report assert len(reports) == 3 # 3 tasks in dummy_task_queue @@ -50,9 +50,9 @@ def test_benchmark_complete_run_multiple_tasks(self): {"query": "Task 3", "environment_data": {}}, ] ) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Should have 3 reports (one per task) assert len(reports) == 3 @@ -70,9 +70,9 @@ def test_benchmark_task_repetitions(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=3) + benchmark = DummyBenchmark(n_task_repeats=3) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Should have 3 reports (one per repetition) assert len(reports) == 3 @@ -123,12 +123,11 @@ def on_run_end(self, benchmark, results): ] ) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=2, callbacks=[OrderTrackingCallback()], ) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # Verify order expected = [ @@ -174,11 +173,10 @@ def on_task_repeat_start(self, benchmark, task, repeat_idx): registry_sizes.append(len(benchmark._registry._trace_registry)) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=2, callbacks=[RegistryTracker()], ) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # After first repetition completes and second starts, registry should be cleared # Note: This test verifies cleanup happens between repeats @@ -188,13 +186,13 @@ def test_benchmark_registry_cleared_after_task(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=1) + benchmark = DummyBenchmark(n_task_repeats=1) # Before run, registry should be empty assert len(benchmark._registry._trace_registry) == 0 assert len(benchmark._registry._config_registry) == 0 - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # After run completes, registry should be cleared assert len(benchmark._registry._trace_registry) == 0 @@ -205,9 +203,9 @@ def test_benchmark_reports_structure(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 report = reports[0] @@ -252,9 +250,9 @@ def test_benchmark_agent_data_per_task(self): {"model": "model-2", "temp": 0.9}, ] - benchmark = DummyBenchmark(agent_data=agent_data_list) + benchmark = DummyBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data=agent_data_list) # Verify each task received its specific agent_data assert len(benchmark.setup_agents_calls) == 2 @@ -281,15 +279,15 @@ def test_benchmark_invalid_agent_data_length(self): ValueError, match="must either be a single dict or an iterable matching the number of tasks", ): - benchmark = DummyBenchmark(agent_data=agent_data_list) - benchmark.run(tasks) + benchmark = DummyBenchmark() + benchmark.run(tasks, agent_data=agent_data_list) def test_benchmark_n_task_repeats_validation(self): """Test that n_task_repeats must be at least 1.""" from conftest import DummyBenchmark with pytest.raises(ValueError, match="n_task_repeats must be at least 1"): - DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=0) + DummyBenchmark(n_task_repeats=0) @pytest.mark.core @@ -323,11 +321,10 @@ def setup_agents(self, agent_data, environment, task, user): tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = TaskFailureBenchmark( - agent_data={"model": "test"}, fail_on_task_error=False, ) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 report = reports[0] @@ -361,12 +358,11 @@ def setup_agents(self, agent_data, environment, task, user): tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = TaskFailureBenchmark( - agent_data={"model": "test"}, fail_on_task_error=True, ) with pytest.raises(RuntimeError, match="Agent execution failed!"): - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) def test_evaluation_failure_graceful(self): """Test that evaluation failures are caught and recorded when fail_on_evaluation_error=False.""" @@ -386,11 +382,10 @@ def setup_evaluators(self, environment, task, agents, user): tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = EvaluationFailureBenchmark( - agent_data={"model": "test"}, fail_on_evaluation_error=False, ) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 report = reports[0] @@ -418,12 +413,11 @@ def setup_evaluators(self, environment, task, agents, user): tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = EvaluationFailureBenchmark( - agent_data={"model": "test"}, fail_on_evaluation_error=True, ) with pytest.raises(ValueError, match="Evaluation failed!"): - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) def test_setup_failure_graceful(self): """Test that setup failures are caught and recorded when fail_on_setup_error=False.""" @@ -436,11 +430,10 @@ def setup_environment(self, agent_data, task): tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = SetupFailureBenchmark( - agent_data={"model": "test"}, fail_on_setup_error=False, ) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 report = reports[0] @@ -460,12 +453,11 @@ def setup_environment(self, agent_data, task): tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = SetupFailureBenchmark( - agent_data={"model": "test"}, fail_on_setup_error=True, ) with pytest.raises(RuntimeError, match="Environment setup failed!"): - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) def test_get_failed_tasks(self): """Test get_failed_tasks() method.""" @@ -505,8 +497,8 @@ def setup_agents(self, agent_data, environment, task, user): {"query": "Task 3", "environment_data": {}}, ] ) - benchmark = MixedBenchmark(agent_data={"model": "test"}) - reports = benchmark.run(tasks) + benchmark = MixedBenchmark() + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Test using internal state failed = benchmark.get_failed_tasks() @@ -553,7 +545,7 @@ def test_get_failed_tasks_before_run_raises(self): """Test that get_failed_tasks() raises if called before run().""" from conftest import DummyBenchmark - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() with pytest.raises(RuntimeError, match="must be called after run"): benchmark.get_failed_tasks() @@ -564,9 +556,9 @@ def test_successful_task_has_success_status(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 assert reports[0]["status"] == TaskExecutionStatus.SUCCESS.value @@ -577,7 +569,7 @@ def test_default_failure_flags(self): """Test that failure flags default to False (graceful handling).""" from conftest import DummyBenchmark - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() assert benchmark.fail_on_setup_error is False assert benchmark.fail_on_task_error is False @@ -594,7 +586,7 @@ def test_multiple_run_calls_no_side_effects(self): from conftest import DummyBenchmark # Create benchmark - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() # First run with 3 tasks tasks1 = TaskQueue.from_list( @@ -605,7 +597,7 @@ def test_multiple_run_calls_no_side_effects(self): ] ) - reports1 = benchmark.run(tasks=tasks1) + reports1 = benchmark.run(tasks=tasks1, agent_data={"model": "test"}) assert len(reports1) == 3 assert len(benchmark.reports) == 3 @@ -617,7 +609,7 @@ def test_multiple_run_calls_no_side_effects(self): ] ) - reports2 = benchmark.run(tasks=tasks2) + reports2 = benchmark.run(tasks=tasks2, agent_data={"model": "test"}) assert len(reports2) == 2 # Verify reports were cleared from first run assert len(benchmark.reports) == 2 @@ -630,12 +622,12 @@ def test_multiple_run_calls_no_side_effects(self): # Third run - retry pattern (simulating failed tasks) # Use one task from tasks1 retry_tasks = TaskQueue([list(tasks1)[0]]) - reports3 = benchmark.run(tasks=retry_tasks) + reports3 = benchmark.run(tasks=retry_tasks, agent_data={"model": "test"}) assert len(reports3) == 1 assert len(benchmark.reports) == 1 def test_retry_failed_tasks_pattern(self): - """Test the intended use case: benchmark.run(benchmark.get_failed_tasks()). + """Test the intended use case: benchmark.run(benchmark.get_failed_tasks(, agent_data={"model": "test"})). This verifies that failed tasks can be retried by passing them back to run(). This includes returning tasks that failed using the correct format that run() expects. @@ -679,10 +671,10 @@ def setup_agents(self, agent_data, environment, task, user): ] ) - benchmark = ConditionalFailureBenchmark(agent_data={"model": "test"}) + benchmark = ConditionalFailureBenchmark() # First run - one task will fail - reports = benchmark.run(tasks=tasks) + reports = benchmark.run(tasks=tasks, agent_data={"model": "test"}) assert len(reports) == 3 # Get failed tasks - should have 1 failure @@ -693,7 +685,7 @@ def setup_agents(self, agent_data, environment, task, user): # Retry the failed tasks (simulate fixing the issue) benchmark.fail_on_first_run = False benchmark.task_counter = 0 # Reset counter - retry_reports = benchmark.run(tasks=failed) + retry_reports = benchmark.run(tasks=failed, agent_data={"model": "test"}) # Should have 1 report for the retried task assert len(retry_reports) == 1 diff --git a/tests/test_core/test_benchmark/test_callback_error_handling.py b/tests/test_core/test_benchmark/test_callback_error_handling.py index 3fb0d68..8d90623 100644 --- a/tests/test_core/test_benchmark/test_callback_error_handling.py +++ b/tests/test_core/test_benchmark/test_callback_error_handling.py @@ -72,10 +72,12 @@ def on_run_end(self, benchmark, results): @pytest.fixture def simple_tasks(): """Create simple tasks for testing.""" - return TaskQueue.from_list([ - {"query": "Task 1", "environment_data": {}}, - {"query": "Task 2", "environment_data": {}}, - ]) + return TaskQueue.from_list( + [ + {"query": "Task 1", "environment_data": {}}, + {"query": "Task 2", "environment_data": {}}, + ] + ) # ==================== Error Suppression Tests ==================== @@ -91,12 +93,11 @@ def test_failing_callback_does_not_stop_execution(self, simple_tasks, caplog): failing_cb = FailingCallback(fail_on="on_task_start") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb], ) # Should complete despite callback failure - reports = benchmark.run(simple_tasks) + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) assert len(reports) == 2 assert all(r["status"] == "success" for r in reports) @@ -114,11 +115,10 @@ def test_multiple_callbacks_one_fails_others_continue(self, simple_tasks, caplog tracking_cb = TrackingCallback() benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb, tracking_cb], ) - reports = benchmark.run(simple_tasks) + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) # Execution completes assert len(reports) == 2 @@ -138,11 +138,10 @@ def test_callback_fails_on_every_task(self, simple_tasks, caplog): failing_cb = FailingCallback(fail_on="on_task_repeat_end") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb], ) - reports = benchmark.run(simple_tasks) + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) # All tasks complete assert len(reports) == 2 @@ -160,11 +159,10 @@ def on_run_start(self, benchmark): raise RuntimeError("Failed at run start") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[RunStartFailer()], ) - reports = benchmark.run(simple_tasks) + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) assert len(reports) == 2 assert "on_run_start" in caplog.text @@ -178,11 +176,10 @@ def on_run_end(self, benchmark, results): raise RuntimeError("Failed at run end") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[RunEndFailer()], ) - reports = benchmark.run(simple_tasks) + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) # Reports are generated (run_end happens after report collection) assert len(reports) == 2 @@ -204,12 +201,11 @@ def test_callback_error_in_parallel_execution(self, simple_tasks, caplog): tracking_cb = TrackingCallback() benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb, tracking_cb], ) # Run in parallel - reports = benchmark.run(simple_tasks, max_workers=2) + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) # All tasks complete assert len(reports) == 2 @@ -225,19 +221,15 @@ def test_multiple_parallel_tasks_with_failing_callback(self, caplog): """Callback failures should not interfere across parallel workers.""" caplog.set_level(logging.ERROR) - tasks = TaskQueue.from_list([ - {"query": f"Task {i}", "environment_data": {}} - for i in range(5) - ]) + tasks = TaskQueue.from_list([{"query": f"Task {i}", "environment_data": {}} for i in range(5)]) failing_cb = FailingCallback(fail_on="on_task_repeat_start") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb], ) - reports = benchmark.run(tasks, max_workers=3) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # All 5 tasks complete despite callback failures assert len(reports) == 5 @@ -259,7 +251,6 @@ def test_invoke_callbacks_returns_empty_list_on_success(self, simple_tasks): """_invoke_callbacks should return empty list when no errors occur.""" tracking_cb = TrackingCallback() benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[tracking_cb], ) @@ -276,7 +267,6 @@ def test_invoke_callbacks_returns_error_list_on_failure(self, simple_tasks): tracking_cb = TrackingCallback() benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb1, tracking_cb, failing_cb2], ) @@ -296,7 +286,6 @@ def test_invoke_callbacks_with_suppress_false_raises(self, simple_tasks): tracking_cb = TrackingCallback() benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb, tracking_cb], ) @@ -327,11 +316,10 @@ def test_value_error_in_callback(self, simple_tasks, caplog): failing_cb = FailingCallback(fail_on="on_task_repeat_start") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb], ) - reports = benchmark.run(simple_tasks) + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) assert len(reports) == 2 assert "ValueError" in caplog.text @@ -342,11 +330,10 @@ def test_type_error_in_callback(self, simple_tasks, caplog): failing_cb = FailingCallback(fail_on="on_task_repeat_end") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb], ) - reports = benchmark.run(simple_tasks) + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) assert len(reports) == 2 assert "TypeError" in caplog.text @@ -357,11 +344,10 @@ def test_key_error_in_callback(self, simple_tasks, caplog): failing_cb = FailingCallback(fail_on="on_task_end") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb], ) - reports = benchmark.run(simple_tasks) + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) assert len(reports) == 2 assert "KeyError" in caplog.text @@ -383,12 +369,11 @@ def test_failing_callback_with_repeats(self, caplog): failing_cb = FailingCallback(fail_on="on_task_repeat_end") benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=3, callbacks=[failing_cb], ) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # 3 reports (one per repeat) assert len(reports) == 3 @@ -410,11 +395,10 @@ def test_mixed_callbacks_some_fail_some_succeed(self, simple_tasks, caplog): tracking_cb2 = TrackingCallback() benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[failing_cb1, tracking_cb1, failing_cb2, tracking_cb2], ) - reports = benchmark.run(simple_tasks) + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) # Execution completes assert len(reports) == 2 diff --git a/tests/test_core/test_benchmark/test_callback_orchestration.py b/tests/test_core/test_benchmark/test_callback_orchestration.py index 8de8b17..f32652b 100644 --- a/tests/test_core/test_benchmark/test_callback_orchestration.py +++ b/tests/test_core/test_benchmark/test_callback_orchestration.py @@ -39,12 +39,11 @@ def on_run_end(self, benchmark, results): tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=2, callbacks=[OrderedCallback()], ) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) expected = [ "run_start", @@ -80,9 +79,9 @@ def on_run_end(self, benchmark, results): callback2_calls.append("end") tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, callbacks=[Callback1(), Callback2()]) + benchmark = DummyBenchmark(callbacks=[Callback1(), Callback2()]) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert callback1_calls == ["start", "end"] assert callback2_calls == ["start", "end"] @@ -106,14 +105,13 @@ def on_run_end(self, benchmark, results): tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[FailingCallback(), SuccessfulCallback()], ) # Note: Current implementation may not catch callback errors # This test documents the expected behavior try: - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # If callbacks are isolated, successful callback should work assert "start" in successful_calls or "end" in successful_calls except RuntimeError: @@ -143,12 +141,11 @@ def on_task_repeat_start(self, benchmark, task, repeat_idx): ] ) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=3, callbacks=[CountingCallback()], ) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # 2 tasks, each called once assert task_count == 2 @@ -173,9 +170,9 @@ def on_run_end(self, benchmark, results): contexts["results_count"] = len(results) tasks = TaskQueue.from_list([{"query": "TestQuery", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, callbacks=[ContextCapturingCallback()]) + benchmark = DummyBenchmark(callbacks=[ContextCapturingCallback()]) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # Verify contexts captured correctly assert contexts["task_query"] == "TestQuery" @@ -202,12 +199,11 @@ def on_run_start(self, benchmark): ] ) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=2, callbacks=[StateAccessingCallback()], ) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert captured_state["n_tasks"] == 2 assert captured_state["n_repeats"] == 2 diff --git a/tests/test_core/test_benchmark/test_config_collection.py b/tests/test_core/test_benchmark/test_config_collection.py index 775c39d..22bac9a 100644 --- a/tests/test_core/test_benchmark/test_config_collection.py +++ b/tests/test_core/test_benchmark/test_config_collection.py @@ -17,9 +17,9 @@ def test_config_collected_from_all_components(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Verify config has the expected structure @@ -49,9 +49,9 @@ def test_config_includes_benchmark_level_info(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Verify benchmark-level config exists @@ -68,9 +68,9 @@ def test_config_includes_system_info(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] system_info = config["benchmark"]["system"] @@ -82,9 +82,9 @@ def test_config_includes_git_info(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Git info may not be available in all environments @@ -99,9 +99,9 @@ def test_config_includes_package_versions(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Should capture package information @@ -115,9 +115,9 @@ def test_config_structure_matches_spec(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Top-level keys @@ -167,10 +167,10 @@ def setup_agents(self, agent_data, environment, task, user): return [agent_adapter], {"failing_agent": agent_adapter} # type: ignore[return-value] tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + benchmark = TestBenchmark() # Should complete without raising, with error info in config - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Verify config collection handled the error @@ -185,9 +185,9 @@ def test_config_json_serializable(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {"key": "value"}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Should be able to serialize to JSON @@ -206,9 +206,9 @@ def test_config_contains_timestamps(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Check metadata timestamp @@ -230,9 +230,9 @@ def test_config_includes_component_types(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Check agent type @@ -250,9 +250,9 @@ def test_config_different_per_repetition(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=3) + benchmark = DummyBenchmark(n_task_repeats=3) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Should have 3 reports assert len(reports) == 3 diff --git a/tests/test_core/test_benchmark/test_execution_loop.py b/tests/test_core/test_benchmark/test_execution_loop.py index 10b895a..9df47e6 100644 --- a/tests/test_core/test_benchmark/test_execution_loop.py +++ b/tests/test_core/test_benchmark/test_execution_loop.py @@ -72,7 +72,7 @@ class TestExecutionLoopNoUser: def test_uses_task_query_without_user(self, dummy_model): """Uses task.query when no user present.""" task = Task(query="What is the weather?", environment_data={}) - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=None) + benchmark = ExecutionLoopBenchmark(return_user=None) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, None) @@ -86,7 +86,7 @@ def test_uses_task_query_without_user(self, dummy_model): def test_single_invocation_without_user(self, dummy_model): """Single agent run without user (default max_invocations=1).""" task = Task(query="Query", environment_data={}) - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=None) + benchmark = ExecutionLoopBenchmark(return_user=None) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, None) @@ -98,7 +98,7 @@ def test_single_invocation_without_user(self, dummy_model): def test_returns_final_answer(self, dummy_model): """Returns final answer from agent.""" task = Task(query="Test query", environment_data={}) - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=None) + benchmark = ExecutionLoopBenchmark(return_user=None) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, None) @@ -129,7 +129,7 @@ def test_uses_user_initial_query(self, dummy_model): initial_query="User's initial message", max_turns=5, ) - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=user) + benchmark = ExecutionLoopBenchmark(return_user=user) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, user) @@ -149,7 +149,7 @@ def test_uses_get_initial_query_if_no_initial_query(self, dummy_model): # No initial_query, so messages is empty user.simulator.return_value = "LLM generated initial query" - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=user) + benchmark = ExecutionLoopBenchmark(return_user=user) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, user) @@ -175,7 +175,6 @@ def test_multi_turn_interaction(self, dummy_model): user.simulator.side_effect = ["Turn 1 response", "Turn 2 response", "Turn 3 response"] benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=user, max_invocations=3, ) @@ -208,7 +207,6 @@ def test_stops_when_user_done_via_max_turns(self, dummy_model): user.simulator.side_effect = ["Response 1", "Response 2", "Response 3"] benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=user, max_invocations=5, # Would allow 5, but user stops at 3 turns ) @@ -239,7 +237,6 @@ def test_stops_when_user_done_via_stop_token(self, dummy_model): user.simulator.side_effect = ["Continue please", "Thanks! "] benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=user, max_invocations=5, ) @@ -265,7 +262,7 @@ def test_final_answer_in_user_messages(self, dummy_model): ) user.simulator.return_value = "Thanks" - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=user) + benchmark = ExecutionLoopBenchmark(return_user=user) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, user) @@ -294,7 +291,6 @@ def test_user_response_becomes_next_query(self, dummy_model): user.simulator.side_effect = ["User reply 1", "User reply 2", "User reply 3"] benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=user, max_invocations=3, # Will stop after 3 due to max_invocations ) @@ -331,26 +327,25 @@ class TestMaxInvocations: def test_default_max_invocations_is_one(self): """Default is single invocation.""" - benchmark = ExecutionLoopBenchmark(agent_data={}) + benchmark = ExecutionLoopBenchmark() assert benchmark.max_invocations == 1 def test_custom_max_invocations(self): """Custom max_invocations is stored.""" - benchmark = ExecutionLoopBenchmark(agent_data={}, max_invocations=5) + benchmark = ExecutionLoopBenchmark(max_invocations=5) assert benchmark.max_invocations == 5 def test_warning_max_invocations_without_user(self): """Warning issued when max_invocations > 1 but no user.""" task = Task(query="Test", environment_data={}) benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=None, max_invocations=5, ) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - benchmark.run(TaskQueue([task])) + benchmark.run(TaskQueue([task]), agent_data={}) # Check for warning about max_invocations without user warning_messages = [str(warning.message) for warning in w] @@ -379,9 +374,9 @@ def test_run_with_user_uses_execution_loop(self, dummy_model): ) user.simulator.return_value = "Done" - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=user) + benchmark = ExecutionLoopBenchmark(return_user=user) - benchmark.run(TaskQueue([task])) + benchmark.run(TaskQueue([task]), agent_data={}) # Verify run_agents was called with user's initial prompt assert len(benchmark.run_agents_calls) == 1 @@ -402,12 +397,11 @@ def test_complete_traces_with_user(self, dummy_model): user.simulator.side_effect = ["Reply 1", "Reply 2"] benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=user, max_invocations=2, ) - reports = benchmark.run(TaskQueue([task])) + reports = benchmark.run(TaskQueue([task]), agent_data={}) # Check that user traces are in the report assert len(reports) == 1 diff --git a/tests/test_core/test_benchmark/test_parallel_execution.py b/tests/test_core/test_benchmark/test_parallel_execution.py index 0e72da3..416d1c5 100644 --- a/tests/test_core/test_benchmark/test_parallel_execution.py +++ b/tests/test_core/test_benchmark/test_parallel_execution.py @@ -105,27 +105,27 @@ class TestParallelExecutionBasics: def test_parallel_execution_completes(self, parallel_tasks): """Verify parallel execution completes all tasks.""" - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(parallel_tasks, max_workers=3) + reports = benchmark.run(parallel_tasks, agent_data={"model": "test"}) assert len(reports) == 5 def test_parallel_produces_same_report_count(self, parallel_tasks): """Parallel and sequential should produce same number of reports.""" - benchmark_seq = DummyBenchmark(agent_data={"model": "test"}) - benchmark_par = DummyBenchmark(agent_data={"model": "test"}) + benchmark_seq = DummyBenchmark() + benchmark_par = DummyBenchmark() - reports_seq = benchmark_seq.run(parallel_tasks, max_workers=1) - reports_par = benchmark_par.run(parallel_tasks, max_workers=3) + reports_seq = benchmark_seq.run(parallel_tasks, agent_data={"model": "test"}) + reports_par = benchmark_par.run(parallel_tasks, agent_data={"model": "test"}) assert len(reports_seq) == len(reports_par) def test_parallel_reports_have_correct_structure(self, parallel_tasks): """Verify parallel reports have expected fields.""" - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark(max_workers=3) - reports = benchmark.run(parallel_tasks, max_workers=2) + reports = benchmark.run(parallel_tasks, agent_data={"model": "test"}) for report in reports: assert "task_id" in report @@ -139,11 +139,10 @@ def test_single_worker_uses_sequential(self, parallel_tasks): """max_workers=1 should behave identically to sequential.""" callback = OrderTrackingCallback() benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[callback], ) - benchmark.run(parallel_tasks, max_workers=1) + benchmark.run(parallel_tasks, agent_data={"model": "test"}) # Verify ordering is strictly sequential (task_start before all repeat_starts) assert callback.invocations[0] == "run_start" @@ -159,11 +158,10 @@ def test_parallel_with_repetitions(self): ] ) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=3, ) - reports = benchmark.run(tasks, max_workers=2) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 6 # 2 tasks × 3 repeats @@ -182,11 +180,10 @@ class TestParallelThreadSafety: def test_reports_all_collected(self, parallel_tasks): """All reports should be collected regardless of completion order.""" benchmark = SlowBenchmark( - agent_data={"model": "test"}, delay_seconds=0.02, ) - reports = benchmark.run(parallel_tasks, max_workers=4) + reports = benchmark.run(parallel_tasks, agent_data={"model": "test"}) assert len(reports) == 5 task_ids = {r["task_id"] for r in reports} @@ -194,9 +191,9 @@ def test_reports_all_collected(self, parallel_tasks): def test_traces_not_cross_contaminated(self, parallel_tasks): """Traces from one task should not appear in another's report.""" - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark(max_workers=4) - reports = benchmark.run(parallel_tasks, max_workers=3) + reports = benchmark.run(parallel_tasks, agent_data={"model": "test"}) for report in reports: # Each report should have its own traces @@ -221,11 +218,10 @@ def on_task_repeat_end(self, benchmark, report): ) benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[DataCapturingCallback()], ) - benchmark.run(tasks, max_workers=2) + benchmark.run(tasks, agent_data={"model": "test"}) assert len(received_data) == 3 statuses = {d["status"] for d in received_data} @@ -234,10 +230,7 @@ def on_task_repeat_end(self, benchmark, report): def test_callback_exceptions_suppressed_by_default(self): """Callback exceptions are suppressed by default to prevent disruption.""" # Create fresh tasks for this test - tasks = TaskQueue.from_list([ - {"query": f"Task {i}", "environment_data": {}} - for i in range(5) - ]) + tasks = TaskQueue.from_list([{"query": f"Task {i}", "environment_data": {}} for i in range(5)]) call_count = [0] @@ -248,13 +241,12 @@ def on_task_repeat_end(self, benchmark, report): raise RuntimeError("Intentional failure") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[FailingCallback()], ) # New behavior: callback exceptions are suppressed by default # This prevents one failing callback from disrupting parallel execution - reports = benchmark.run(tasks, max_workers=2) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Execution completes despite callback failure assert len(reports) == 5 @@ -275,15 +267,15 @@ def test_parallel_faster_than_sequential(self): delay = 0.05 # Sequential timing - benchmark_seq = SlowBenchmark(agent_data={"model": "test"}, delay_seconds=delay) + benchmark_seq = SlowBenchmark(delay_seconds=delay) start_seq = time.time() - benchmark_seq.run(tasks, max_workers=1) + benchmark_seq.run(tasks, agent_data={"model": "test"}) time_seq = time.time() - start_seq # Parallel timing - benchmark_par = SlowBenchmark(agent_data={"model": "test"}, delay_seconds=delay) + benchmark_par = SlowBenchmark(delay_seconds=delay, max_workers=4) start_par = time.time() - benchmark_par.run(tasks, max_workers=4) + benchmark_par.run(tasks, agent_data={"model": "test"}) time_par = time.time() - start_par # Parallel should be significantly faster (at least 2x) @@ -294,11 +286,11 @@ def test_execution_overlaps(self): tasks = TaskQueue.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(3)]) benchmark = SlowBenchmark( - agent_data={"model": "test"}, delay_seconds=0.05, + max_workers=3, ) - benchmark.run(tasks, max_workers=3) + benchmark.run(tasks, agent_data={"model": "test"}) # Check for overlapping execution times times = benchmark.execution_times @@ -341,8 +333,8 @@ def run_agents(self, agents, task, environment, query): ] ) - benchmark = FailingBenchmark(agent_data={"model": "test"}) - reports = benchmark.run(tasks, max_workers=2) + benchmark = FailingBenchmark() + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 3 @@ -370,8 +362,8 @@ def run_agents(self, agents, task, environment, query): tasks = TaskQueue.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(4)]) - benchmark = HalfFailingBenchmark(agent_data={"model": "test"}) - reports = benchmark.run(tasks, max_workers=2) + benchmark = HalfFailingBenchmark() + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 4 @@ -409,11 +401,10 @@ def on_task_repeat_start(self, benchmark, task, repeat_idx): execution_order.append(task.query) benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[OrderTracker()], ) # With max_workers=1, order should be strictly by priority - benchmark.run(queue, max_workers=1) + benchmark.run(queue, agent_data={"model": "test"}) assert execution_order == ["P5", "P4", "P3", "P2", "P1"] diff --git a/tests/test_core/test_benchmark/test_progress_bar_integration.py b/tests/test_core/test_benchmark/test_progress_bar_integration.py index 4814a14..ab6805e 100644 --- a/tests/test_core/test_benchmark/test_progress_bar_integration.py +++ b/tests/test_core/test_benchmark/test_progress_bar_integration.py @@ -28,14 +28,14 @@ def test_benchmark_with_default_progress_bar(): tasks = TaskQueue([Task(query="What is 2+2?")]) # Default should have progress bar - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() # Check that a progress bar callback was added progress_bars = [cb for cb in benchmark.callbacks if isinstance(cb, TqdmProgressBarCallback)] assert len(progress_bars) == 1 # Run and verify it works - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 assert reports[0]["status"] == "success" @@ -45,13 +45,13 @@ def test_benchmark_with_disabled_progress_bar(): """Test that progress bar can be disabled.""" tasks = TaskQueue([Task(query="What is 2+2?")]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, progress_bar=False) + benchmark = DummyBenchmark(progress_bar=False) # Should have no progress bar callbacks progress_bars = [cb for cb in benchmark.callbacks if isinstance(cb, (TqdmProgressBarCallback, RichProgressBarCallback))] assert len(progress_bars) == 0 - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 @@ -60,13 +60,13 @@ def test_benchmark_with_rich_progress_bar(): """Test that rich progress bar can be specified.""" tasks = TaskQueue([Task(query="What is 2+2?")]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, progress_bar="rich") + benchmark = DummyBenchmark(progress_bar="rich") # Should have a rich progress bar progress_bars = [cb for cb in benchmark.callbacks if isinstance(cb, RichProgressBarCallback)] assert len(progress_bars) == 1 - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 @@ -79,7 +79,6 @@ def test_benchmark_with_custom_progress_bar(): custom_pbar = TqdmProgressBarCallback(desc="Custom Progress") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[custom_pbar], progress_bar=True, # Should be ignored ) @@ -89,7 +88,7 @@ def test_benchmark_with_custom_progress_bar(): assert len(progress_bars) == 1 assert progress_bars[0].desc == "Custom Progress" - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 @@ -98,14 +97,14 @@ def test_benchmark_with_multiple_tasks_and_repeats(): """Test progress bar with multiple tasks and repeats.""" tasks = TaskQueue([Task(query=f"Task {i}") for i in range(3)]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=2, progress_bar=True) + benchmark = DummyBenchmark(n_task_repeats=2, progress_bar=True) # Get the progress bar callback progress_bars = [cb for cb in benchmark.callbacks if isinstance(cb, TqdmProgressBarCallback)] assert len(progress_bars) == 1 pbar = progress_bars[0] - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Should have 3 tasks * 2 repeats = 6 reports assert len(reports) == 6 @@ -119,4 +118,4 @@ def test_benchmark_with_multiple_tasks_and_repeats(): def test_invalid_progress_bar_value(): """Test that invalid progress bar value raises error.""" with pytest.raises(ValueError, match="Invalid progress_bar value"): - DummyBenchmark(agent_data={"model": "test"}, progress_bar="invalid") + DummyBenchmark(progress_bar="invalid") diff --git a/tests/test_core/test_benchmark/test_trace_collection.py b/tests/test_core/test_benchmark/test_trace_collection.py index 0632e57..f1cbf3f 100644 --- a/tests/test_core/test_benchmark/test_trace_collection.py +++ b/tests/test_core/test_benchmark/test_trace_collection.py @@ -17,9 +17,9 @@ def test_traces_collected_from_all_components(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Verify traces have the expected structure @@ -50,9 +50,9 @@ def test_traces_include_message_histories(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Get agent trace @@ -90,10 +90,10 @@ def setup_agents(self, agent_data, environment, task, user): return [agent_adapter], {"failing_agent": agent_adapter} # type: ignore[return-value] tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + benchmark = TestBenchmark() # Should complete without raising, with error info in traces - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Verify trace collection handled the error @@ -131,9 +131,9 @@ def setup_agents(self, agent_data, environment, task, user): return [agent_adapter], {"test_agent": agent_adapter} # type: ignore[return-value] tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + benchmark = TestBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Verify model traces @@ -152,9 +152,9 @@ def test_environment_traces_tool_invocations(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Verify environment traces include tools @@ -218,12 +218,12 @@ def gather_traces(self): callback = CustomCallback() tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, callbacks=[callback]) + benchmark = DummyBenchmark(callbacks=[callback]) # Register callback for tracing benchmark.register("callbacks", "custom_callback", callback) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Verify callback traces @@ -239,9 +239,9 @@ def test_traces_json_serializable(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {"key": "value"}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Should be able to serialize to JSON @@ -260,9 +260,9 @@ def test_traces_contain_timestamps(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Check metadata timestamp @@ -281,9 +281,9 @@ def test_traces_include_component_types(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Check agent type diff --git a/tests/test_core/test_evaluator.py b/tests/test_core/test_evaluator.py index 31ff817..c81acc0 100644 --- a/tests/test_core/test_evaluator.py +++ b/tests/test_core/test_evaluator.py @@ -36,9 +36,9 @@ def setup_evaluators(self, environment, task, agents, user): return [TracingEvaluator(task, environment, user)] tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + benchmark = TestBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert len(received_traces) == 1 assert isinstance(received_traces[0], dict) @@ -58,9 +58,9 @@ def evaluate(self, evaluators, agents, final_answer, traces): return super().evaluate(evaluators, agents, final_answer, traces) tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + benchmark = TestBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) def test_evaluator_receives_final_answer(self): """Test that evaluate() receives the final answer from agents.""" @@ -74,9 +74,9 @@ def evaluate(self, evaluators, agents, final_answer, traces): return super().evaluate(evaluators, agents, final_answer, traces) tasks = TaskQueue.from_list([{"query": "My test query", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + benchmark = TestBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert len(received_answers) == 1 assert "Response to: My test query" in received_answers[0] @@ -93,9 +93,9 @@ def evaluate(self, evaluators, agents, final_answer, traces): return super().evaluate(evaluators, agents, final_answer, traces) tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + benchmark = TestBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert len(received_traces) == 1 traces = received_traces[0] @@ -139,9 +139,9 @@ def setup_evaluators(self, environment, task, agents, user): ] tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + benchmark = TestBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert call_counts["eval1"] == 1 assert call_counts["eval2"] == 1 @@ -151,9 +151,9 @@ def test_evaluator_results_in_report(self): from conftest import DummyBenchmark tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 report = reports[0] diff --git a/tests/test_core/test_exceptions.py b/tests/test_core/test_exceptions.py index 3354221..415ecd0 100644 --- a/tests/test_core/test_exceptions.py +++ b/tests/test_core/test_exceptions.py @@ -42,8 +42,8 @@ def setup_agents(self, agent_data, environment, task, user): return [adapter], {"agent": adapter} tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = AgentErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + benchmark = AgentErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) assert len(reports) == 1 assert reports[0]["status"] == TaskExecutionStatus.AGENT_ERROR.value @@ -69,8 +69,8 @@ def setup_agents(self, agent_data, environment, task, user): return [adapter], {"agent": adapter} tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = EnvironmentErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + benchmark = EnvironmentErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) assert len(reports) == 1 assert reports[0]["status"] == TaskExecutionStatus.ENVIRONMENT_ERROR.value @@ -96,8 +96,8 @@ def setup_agents(self, agent_data, environment, task, user): return [adapter], {"agent": adapter} tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = UserErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + benchmark = UserErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) assert len(reports) == 1 assert reports[0]["status"] == TaskExecutionStatus.USER_ERROR.value @@ -123,8 +123,8 @@ def setup_agents(self, agent_data, environment, task, user): return [adapter], {"agent": adapter} tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = GenericErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + benchmark = GenericErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) assert len(reports) == 1 assert reports[0]["status"] == TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR.value @@ -153,8 +153,8 @@ def setup_agents(self, agent_data, environment, task, user): return [adapter], {"agent": adapter} tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DetailedAgentErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + benchmark = DetailedAgentErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) assert len(reports) == 1 error = reports[0]["error"] @@ -409,8 +409,8 @@ def run(self, query: str) -> str: ] ) - benchmark = MixedErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + benchmark = MixedErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) # Should have 1 success, 1 agent error, 1 env error statuses = [r["status"] for r in reports] From 898d975cd4be11c7632f61815a3c00a36e63c35e Mon Sep 17 00:00:00 2001 From: cemde Date: Sat, 6 Dec 2025 12:20:40 +0000 Subject: [PATCH 11/25] fixed linting error --- maseval/core/callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maseval/core/callback.py b/maseval/core/callback.py index f352045..80c68dc 100644 --- a/maseval/core/callback.py +++ b/maseval/core/callback.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, List, TYPE_CHECKING from .tracing import TraceableMixin From 68851b4cb22cd759fe6de0725119968a33325e21 Mon Sep 17 00:00:00 2001 From: cemde Date: Sun, 21 Dec 2025 11:21:42 +0000 Subject: [PATCH 12/25] removed note files --- PLAN.md | 1295 -------------------------------------------------- SUMMARY.md | 103 ---- TEST_PLAN.md | 359 -------------- 3 files changed, 1757 deletions(-) delete mode 100644 PLAN.md delete mode 100644 SUMMARY.md delete mode 100644 TEST_PLAN.md diff --git a/PLAN.md b/PLAN.md deleted file mode 100644 index eae60fd..0000000 --- a/PLAN.md +++ /dev/null @@ -1,1295 +0,0 @@ -# Parallel Task Execution, Timeout Handling, and Task Queue Design - -This document proposes a unified design for three interconnected features that fundamentally improve MASEval's task execution architecture: - -1. **Parallel Processing** - Concurrent task execution via asyncio or threading -2. **Timeout Handling** - Per-task timeout with graceful failure recording -3. **Task Queue** - Callback-driven task scheduling for adaptive testing - -All three features directly impact the `Benchmark.run()` task loop and should be designed together. - ---- - -## Table of Contents - -1. [Current Architecture Analysis](#current-architecture-analysis) -2. [Key Architectural Changes](#key-architectural-changes) -3. [Feature 1: Parallel Processing](#feature-1-parallel-processing) -4. [Feature 2: Timeout Handling & TaskProtocol](#feature-2-timeout-handling--taskprotocol) -5. [Feature 3: TaskQueue](#feature-3-taskqueue) -6. [Unified Design Proposal](#unified-design-proposal) -7. [Implementation Phases](#implementation-phases) -8. [Risks and Mitigations](#risks-and-mitigations) - ---- - -## Current Architecture Analysis - -### The Run Loop (`benchmark.py` lines 990-1330) - -The current execution model is strictly sequential: - -```python -def run(self, tasks: ...): - for task_idx, (task, agent_data) in enumerate(zip(self.tasks, agent_data_list)): - for repeat_idx in range(self.n_task_repeats): - # Setup - environment = self.setup_environment(agent_data, task) - # ... more setup - - # Execute - final_answers = self.execution_loop(agents_to_run, task, environment, user) - - # Evaluate - eval_results = self.evaluate(...) - - # Store - self.reports.append(report) -``` - -### Key Observations - -1. **Sequential by Design**: No parallelism, no timeouts, no queue abstraction -2. **Callback System**: Already has lifecycle hooks (`on_task_start`, `on_task_repeat_end`, etc.) but callbacks cannot influence task ordering -3. **Component Registry**: Per-task-repetition component tracking with `register()` / `clear_registry()` -4. **Error Handling**: Comprehensive status enum (`TaskExecutionStatus`) with graceful failure paths -5. **Agent Adapters**: Framework-specific adapters (smolagents, langgraph) that may or may not be async-native -6. **Model Adapters**: API clients that are inherently I/O-bound - -### Critical Dependencies for Concurrency - -| Component | Thread-Safety | Async-Native | Notes | -| --------------------------- | -------------- | ------------ | -------------------------------- | -| `Benchmark.reports` | ❌ List append | N/A | Needs synchronization | -| `Benchmark._trace_registry` | ❌ Dict | N/A | Per-task, but needs isolation | -| `CallbackHandler` | ❌ | N/A | Callbacks may not be thread-safe | -| `SmolAgentAdapter` | ✅ (stateless) | ❌ | Uses sync `agent.run()` | -| `LangGraphAgentAdapter` | ✅ (stateless) | ⚠️ Partial | LangGraph has `ainvoke()` | -| `GoogleGenAIModelAdapter` | ✅ | ⚠️ Partial | Google client has async methods | - ---- - -## Key Architectural Changes - -This section summarizes the major architectural decisions made during planning. - -### 1. Extract `ComponentRegistry` from `Benchmark` - -**Problem**: The component registry logic (~150 lines) is mixed with benchmark orchestration. Adding thread-local handling will make it worse. - -**Solution**: Extract into a dedicated `ComponentRegistry` class in `maseval/core/registry.py`. - -```python -# maseval/core/registry.py - -import threading -from typing import Dict, Any, Optional -from datetime import datetime - -from .tracing import TraceableMixin -from .config import ConfigurableMixin - - -class ComponentRegistry: - """Thread-safe registry for tracking components during task execution. - - Each thread gets its own isolated registry state, enabling parallel - task execution without cross-contamination. The registry tracks both - Traceable and Configurable components for comprehensive data collection. - - Usage: - registry = ComponentRegistry() - - # Register components (thread-local) - registry.register("agents", "orchestrator", agent_adapter) - registry.register("environment", "env", environment) - - # Collect data - traces = registry.collect_traces() - configs = registry.collect_configs() - - # Clear for next task - registry.clear() - """ - - def __init__(self, benchmark_config: Optional[Dict[str, Any]] = None): - """Initialize the registry. - - Args: - benchmark_config: Benchmark-level configuration to include in - collect_configs() output. This is shared (not thread-local). - """ - self._local = threading.local() - self._benchmark_config = benchmark_config or {} - - # --- Thread-local state properties --- - - @property - def _trace_registry(self) -> Dict[str, TraceableMixin]: - if not hasattr(self._local, 'trace_registry'): - self._local.trace_registry = {} - return self._local.trace_registry - - @property - def _component_id_map(self) -> Dict[int, str]: - if not hasattr(self._local, 'component_id_map'): - self._local.component_id_map = {} - return self._local.component_id_map - - @property - def _config_registry(self) -> Dict[str, ConfigurableMixin]: - if not hasattr(self._local, 'config_registry'): - self._local.config_registry = {} - return self._local.config_registry - - @property - def _config_component_id_map(self) -> Dict[int, str]: - if not hasattr(self._local, 'config_component_id_map'): - self._local.config_component_id_map = {} - return self._local.config_component_id_map - - # --- Public API --- - - def register(self, category: str, name: str, component: TraceableMixin) -> TraceableMixin: - """Register a component for trace and config collection. - - Args: - category: Component category (e.g., "agents", "models", "environment") - name: Unique identifier within the category - component: Component instance (must be TraceableMixin) - - Returns: - The component (for chaining) - - Raises: - ValueError: If component already registered under a different key - """ - component_id = id(component) - key = f"{category}:{name}" - - # Check for duplicate registration under different key - if component_id in self._component_id_map: - existing_key = self._component_id_map[component_id] - if existing_key != key: - raise ValueError( - f"Component already registered as '{existing_key}', " - f"cannot re-register as '{key}'." - ) - return component # Idempotent - - # Register for tracing - self._trace_registry[key] = component - self._component_id_map[component_id] = key - - # Also register for config if supported - if isinstance(component, ConfigurableMixin): - self._config_registry[key] = component - self._config_component_id_map[component_id] = key - - return component - - def clear(self) -> None: - """Clear all registrations for the current thread.""" - self._trace_registry.clear() - self._component_id_map.clear() - self._config_registry.clear() - self._config_component_id_map.clear() - - def collect_traces(self) -> Dict[str, Any]: - """Collect execution traces from all registered components.""" - traces: Dict[str, Any] = { - "metadata": { - "timestamp": datetime.now().isoformat(), - "thread_id": threading.current_thread().ident, - "total_components": len(self._trace_registry), - }, - "agents": {}, - "models": {}, - "tools": {}, - "simulators": {}, - "callbacks": {}, - "environment": None, - "user": None, - "other": {}, - } - - for key, component in self._trace_registry.items(): - category, comp_name = key.split(":", 1) - try: - component_traces = component.gather_traces() - if "name" not in component_traces: - component_traces["name"] = comp_name - - if category == "environment": - traces["environment"] = component_traces - elif category == "user": - traces["user"] = component_traces - else: - if category not in traces: - traces[category] = {} - traces[category][comp_name] = component_traces - except Exception as e: - error_info = {"error": str(e), "error_type": type(e).__name__} - if category in ("environment", "user"): - traces[category] = error_info - else: - if category not in traces: - traces[category] = {} - traces[category][comp_name] = error_info - - return traces - - def collect_configs(self) -> Dict[str, Any]: - """Collect configuration from all registered components.""" - configs: Dict[str, Any] = { - "metadata": { - "timestamp": datetime.now().isoformat(), - "thread_id": threading.current_thread().ident, - "total_components": len(self._config_registry), - }, - "agents": {}, - "models": {}, - "tools": {}, - "simulators": {}, - "callbacks": {}, - "environment": None, - "user": None, - "other": {}, - "benchmark": self._benchmark_config, - } - - for key, component in self._config_registry.items(): - category, comp_name = key.split(":", 1) - try: - component_config = component.gather_config() - if "name" not in component_config: - component_config["name"] = comp_name - - if category == "environment": - configs["environment"] = component_config - elif category == "user": - configs["user"] = component_config - else: - if category not in configs: - configs[category] = {} - configs[category][comp_name] = component_config - except Exception as e: - error_info = {"error": str(e), "error_type": type(e).__name__} - if category in ("environment", "user"): - configs[category] = error_info - else: - if category not in configs: - configs[category] = {} - configs[category][comp_name] = error_info - - return configs -``` - -**Benchmark integration** (delegation pattern): - -```python -class Benchmark: - def __init__(self, ...): - # ... - self._registry = ComponentRegistry( - benchmark_config=gather_benchmark_config() - ) - - def register(self, category: str, name: str, component: TraceableMixin) -> TraceableMixin: - """Register a component. Delegates to internal registry.""" - return self._registry.register(category, name, component) - - def clear_registry(self) -> None: - """Clear registry after task repetition.""" - self._registry.clear() - - def collect_all_traces(self) -> Dict[str, Any]: - """Collect traces. Delegates to internal registry.""" - return self._registry.collect_traces() - - def collect_all_configs(self) -> Dict[str, Any]: - """Collect configs. Delegates to internal registry.""" - return self._registry.collect_configs() -``` - -**Benefits**: - -- Single Responsibility: Benchmark orchestrates, Registry tracks components -- Testability: Registry can be unit tested in isolation -- Clarity: Thread-local complexity encapsulated in one place -- Zero API changes: Users still call `benchmark.register(...)` - -### 2. Threading over asyncio - -**Decision**: Use `ThreadPoolExecutor` for parallel task execution. - -**Rationale**: - -- No user code changes required (async would require rewriting `run_agents()`) -- Works with all agent frameworks (smolagents is sync-only) -- Same I/O concurrency benefits for LLM API calls -- Future-proof: Python's GIL removal will make threading even more powerful - -### 3. MASEval-Managed Callback Thread Safety - -**Decision**: MASEval serializes all callback invocations with an internal lock. - -**Rationale**: - -- Users don't need to think about thread safety -- Negligible performance cost (callbacks are fast) -- Prevents subtle race condition bugs - -```python -class Benchmark: - def __init__(self, ...): - self._callback_lock = threading.Lock() - - def _invoke_callbacks(self, method_name: str, *args, **kwargs): - with self._callback_lock: - for cb in self.callbacks: - getattr(cb, method_name)(*args, **kwargs) -``` - -### 4. Cooperative Timeout with Hard Backstop - -**Decision**: Use cooperative checkpoint-based timeout with a hard timeout fallback. - -**Rationale**: - -- Cross-platform (signal-based only works on Unix) -- Works in threads (signals only work in main thread) -- Clean interruption at defined checkpoints -- Hard timeout as last resort for misbehaving code - -**Limitation**: Python threads cannot be forcibly killed. Timeout is "best effort." - ---- - -## Feature 1: Parallel Processing - -### Decision: Threading with `ThreadPoolExecutor` - -We use **threading** (not asyncio) for parallel task execution. - -#### Why Threading Over asyncio - -| Consideration | Threading | asyncio | -| ----------------------- | ----------------------------- | ------------------------------------ | -| User code changes | None | Must rewrite `run_agents()` as async | -| Agent framework support | All (smolagents is sync-only) | Only async-native frameworks | -| API signature | Unchanged | Breaking (`async def run()`) | -| Mental model | Familiar to most developers | Requires async expertise | -| Future GIL removal | Benefits automatically | No additional benefit | - -**asyncio would require**: - -- All user-implemented methods become `async def` -- Wrapper code for sync agent frameworks (smolagents) -- Breaking API changes throughout - -**Threading provides**: - -- Zero user code changes -- Works with all agent frameworks today -- Same I/O concurrency benefits (LLM API calls) -- Future-proof: when Python removes the GIL, threading will gain true parallelism - -#### Implementation: `ThreadPoolExecutor` - -```python -from concurrent.futures import ThreadPoolExecutor, as_completed - -def run(self, tasks, max_workers: int = 1): # max_workers=1 = sequential (default) - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = {} - for task, agent_data in zip(self.tasks, agent_data_list): - future = executor.submit(self._run_single_task, task, agent_data) - futures[future] = task - - for future in as_completed(futures): - report = future.result() - self._append_report_safe(report) # Thread-safe -``` - -**Key design points**: - -1. **Backwards Compatible**: `max_workers=1` maintains current sequential behavior -2. **Framework Agnostic**: Works with sync agent frameworks (smolagents) -3. **I/O Parallelism**: Multiple LLM API calls can happen concurrently -4. **Opt-in**: Users explicitly enable parallelism - -#### Thread-Local Component Registry - -The component registry is already cleared after each task repetition. For parallel execution, we make it thread-local so concurrent tasks don't share registries: - -```python -import threading - -class Benchmark: - def __init__(self, ...): - self._local = threading.local() - - @property - def _trace_registry(self): - if not hasattr(self._local, 'trace_registry'): - self._local.trace_registry = {} - return self._local.trace_registry - - @property - def _component_id_map(self): - if not hasattr(self._local, 'component_id_map'): - self._local.component_id_map = {} - return self._local.component_id_map - - # Same pattern for _config_registry, _config_component_id_map -``` - -This is the correct design because: - -- Each task repetition runs in one thread -- Registries are ephemeral (cleared after each repetition via `clear_registry()`) -- No cross-task state sharing is intended - -#### Thread-Safe Report Collection - -```python -import threading - -class Benchmark: - def __init__(self, ...): - self._reports_lock = threading.Lock() - - def _append_report_safe(self, report): - with self._reports_lock: - self.reports.append(report) -``` - -#### Thread-Safe Callback Invocation - -MASEval serializes all callback invocations internally, so **users don't need to implement thread-safe callbacks**: - -```python -class Benchmark: - def __init__(self, ...): - self._callback_lock = threading.Lock() - - def _invoke_callbacks(self, method_name: str, *args, **kwargs): - """Invoke a callback method on all registered callbacks (thread-safe).""" - with self._callback_lock: - for cb in self.callbacks: - getattr(cb, method_name)(*args, **kwargs) -``` - -**User impact**: None. Users write callbacks exactly as they do today: - -```python -class MyCallback(BenchmarkCallback): - def __init__(self): - self.count = 0 # No lock needed! - - def on_task_repeat_end(self, benchmark, report): - self.count += 1 # Safe because MASEval serializes calls -``` - -This approach: - -- Eliminates thread-safety burden on users -- Has negligible performance cost (callbacks are fast) -- Prevents an entire class of subtle bugs - -```` - ---- - -## Feature 2: Timeout Handling & TaskProtocol - -### Design Goal - -Enable per-task timeout configuration, capturing partial results on timeout. - -### The `TaskProtocol` Concept - -A `TaskProtocol` dataclass defines task-level execution parameters. It's attached to `Task` but describes how MASEval should run the task, not task content. - -```python -from dataclasses import dataclass, field -from typing import Optional -from enum import Enum - - -class TimeoutAction(Enum): - """What to do when a timeout occurs.""" - SKIP = "skip" # Mark as timed out, continue to next task - RETRY = "retry" # Retry once with same timeout - EXTEND = "extend" # Double timeout and retry - - -@dataclass -class TaskProtocol: - """Configuration for how MASEval executes a task. - - This is a data container for execution parameters, separate from - task content (query, environment_data, etc.). It controls the - interface between the task and MASEval's execution engine. - - Attributes: - timeout_seconds: Maximum execution time for this task. None means no timeout. - timeout_action: Action to take when timeout occurs. - max_retries: Maximum retry attempts for transient failures (not timeouts). - priority: Execution priority (higher = sooner). Used by adaptive task queues. - tags: Arbitrary tags for filtering or grouping tasks. - """ - timeout_seconds: Optional[float] = None - timeout_action: TimeoutAction = TimeoutAction.SKIP - max_retries: int = 0 - priority: int = 0 - tags: dict = field(default_factory=dict) -```` - -### Attaching Protocol to Task - -```python -@dataclass -class Task: - query: str - id: UUID = field(default_factory=uuid4) - environment_data: Dict[str, Any] = field(default_factory=dict) - evaluation_data: Dict[str, Any] = field(default_factory=dict) - metadata: Dict[str, Any] = field(default_factory=dict) - - # New: execution protocol - protocol: TaskProtocol = field(default_factory=TaskProtocol) -``` - -### Timeout Implementation Strategies - -#### Strategy A: `concurrent.futures` with Timeout - -```python -from concurrent.futures import ThreadPoolExecutor, TimeoutError - -def _run_task_with_timeout(self, task, agent_data, timeout: Optional[float]): - """Run a single task with optional timeout.""" - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(self._run_single_task_inner, task, agent_data) - try: - return future.result(timeout=timeout) - except TimeoutError: - # Attempt to cancel (may not stop running code) - future.cancel() - return self._create_timeout_report(task) -``` - -**Problem**: `future.cancel()` doesn't actually stop running Python code. The task continues executing in the background. - -#### Strategy B: `multiprocessing` with Termination - -```python -from multiprocessing import Process, Queue - -def _run_task_with_timeout(self, task, agent_data, timeout): - result_queue = Queue() - process = Process(target=self._run_in_process, args=(task, agent_data, result_queue)) - process.start() - process.join(timeout=timeout) - - if process.is_alive(): - process.terminate() # Actually kills the task - return self._create_timeout_report(task) - - return result_queue.get() -``` - -**Problem**: Process isolation means no shared state. Components can't be registered, traces can't be collected incrementally. - -#### Strategy C: Signal-Based Timeout (Unix only) - -```python -import signal - -class TimeoutException(Exception): - pass - -def _run_task_with_timeout(self, task, agent_data, timeout): - def handler(signum, frame): - raise TimeoutException() - - old_handler = signal.signal(signal.SIGALRM, handler) - signal.alarm(int(timeout)) - try: - return self._run_single_task_inner(task, agent_data) - except TimeoutException: - return self._create_timeout_report(task) - finally: - signal.alarm(0) - signal.signal(signal.SIGALRM, old_handler) -``` - -**Problem**: Only works on Unix. Doesn't work in threads (signal only works in main thread). - -#### Strategy D: Cooperative Timeout with Checkpoints (Recommended) - -The cleanest approach that works cross-platform and with threads is **cooperative timeout checking**: - -```python -import time -import threading - -class TaskContext: - """Execution context passed to user code for timeout checking.""" - - def __init__(self, deadline: Optional[float] = None): - self._deadline = deadline - self._start_time = time.monotonic() - - @property - def elapsed(self) -> float: - return time.monotonic() - self._start_time - - @property - def remaining(self) -> Optional[float]: - if self._deadline is None: - return None - return max(0, self._deadline - self.elapsed) - - @property - def is_expired(self) -> bool: - return self._deadline is not None and self.elapsed >= self._deadline - - def check_timeout(self): - """Raise TimeoutError if deadline exceeded. Call at checkpoints.""" - if self.is_expired: - raise TaskTimeoutError(f"Task exceeded {self._deadline}s deadline") -``` - -**Usage in `run_agents()`**: - -```python -def run_agents(self, agents, task, environment, query, context: TaskContext) -> Any: - for step in range(self.max_steps): - context.check_timeout() # Cooperative checkpoint - result = agents[0].run(query) - # ... -``` - -**Hybrid with Hard Timeout**: Combine cooperative checking with a hard timeout fallback: - -```python -def _run_task_with_timeout(self, task, agent_data, timeout): - context = TaskContext(deadline=timeout) - - with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(self._run_single_task_inner, task, agent_data, context) - try: - # Hard timeout as backstop - return future.result(timeout=timeout + 5) # Grace period - except TimeoutError: - return self._create_timeout_report(task, partial_traces=context.collected_traces) -``` - -### New Exception Type - -```python -class TaskTimeoutError(MASEvalError): - """Task execution exceeded configured timeout. - - This is classified as TASK_TIMEOUT in benchmark results, separate from - other error types. Timeout is neither agent's fault nor infrastructure's - fault—it's a resource constraint. - """ - - def __init__(self, message: str, elapsed: float, timeout: float, partial_traces: Optional[Dict] = None): - super().__init__(message, component="timeout") - self.elapsed = elapsed - self.timeout = timeout - self.partial_traces = partial_traces or {} -``` - -### New Status - -```python -class TaskExecutionStatus(Enum): - SUCCESS = "success" - AGENT_ERROR = "agent_error" - ENVIRONMENT_ERROR = "environment_error" - USER_ERROR = "user_error" - UNKNOWN_EXECUTION_ERROR = "unknown_execution_error" - EVALUATION_FAILED = "evaluation_failed" - SETUP_FAILED = "setup_failed" - TASK_TIMEOUT = "task_timeout" # NEW -``` - ---- - -## Feature 3: TaskQueue - -### Design Goal - -Replace the static `for task in tasks` loop with a queue abstraction that enables: - -1. Dynamic task ordering -2. Callback-driven scheduling (adaptive testing) -3. Priority-based execution -4. Conditional task skipping - -### The `TaskQueue` Interface - -```python -from abc import ABC, abstractmethod -from typing import Iterator, Optional, Tuple - -class TaskQueue(ABC): - """Abstract base for task scheduling strategies.""" - - @abstractmethod - def __iter__(self) -> Iterator[Tuple[Task, Dict]]: - """Yield (task, agent_data) pairs in execution order.""" - pass - - def on_task_complete(self, task: Task, report: Dict) -> None: - """Called after each task completes. Override for adaptive behavior.""" - pass - - def should_continue(self) -> bool: - """Whether to continue processing. Default: True until queue exhausted.""" - return True - - -class SequentialQueue(TaskQueue): - """Default: execute tasks in order (current behavior).""" - - def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict]): - self._tasks = list(zip(tasks, agent_data_list)) - self._index = 0 - - def __iter__(self): - for task, agent_data in self._tasks: - yield task, agent_data - - -class PriorityQueue(TaskQueue): - """Execute tasks by priority (from TaskProtocol.priority).""" - - def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict]): - paired = list(zip(tasks, agent_data_list)) - # Sort by priority descending - self._tasks = sorted(paired, key=lambda x: x[0].protocol.priority, reverse=True) - - def __iter__(self): - for task, agent_data in self._tasks: - yield task, agent_data - - -class AdaptiveQueue(TaskQueue): - """Adaptive testing: adjust task order based on results. - - Example: Item Response Theory (IRT) based testing that estimates - agent difficulty and selects optimally informative tasks. - """ - - def __init__(self, tasks: TaskCollection, agent_data_list: List[Dict]): - self._pending = list(zip(tasks, agent_data_list)) - self._completed = [] - self._agent_ability_estimate = 0.0 - - def __iter__(self): - while self._pending and self.should_continue(): - # Select next task based on estimated ability - next_task = self._select_next_task() - if next_task: - yield next_task - - def _select_next_task(self) -> Optional[Tuple[Task, Dict]]: - """Select task that maximizes information gain.""" - if not self._pending: - return None - - # IRT-based selection (simplified) - best_idx = 0 - best_info = 0 - - for idx, (task, _) in enumerate(self._pending): - difficulty = task.metadata.get("difficulty", 0.5) - # Fisher information at current ability estimate - info = self._fisher_information(difficulty, self._agent_ability_estimate) - if info > best_info: - best_info = info - best_idx = idx - - return self._pending.pop(best_idx) - - def on_task_complete(self, task: Task, report: Dict) -> None: - """Update ability estimate based on task result.""" - self._completed.append((task, report)) - self._update_ability_estimate() - - def _update_ability_estimate(self): - """Bayesian update of ability estimate.""" - # Implementation depends on IRT model - pass - - def should_continue(self) -> bool: - """Stop when estimate is precise enough.""" - return len(self._completed) < 50 # Example stopping rule -``` - -### Integration with `Benchmark.run()` - -```python -def run( - self, - tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]], - queue: Optional[TaskQueue] = None, - max_workers: int = 1, -) -> List[Dict[str, Any]]: - # Normalize tasks - task_collection = self._normalize_tasks(tasks) - agent_data_list = self._normalize_agent_data(task_collection) - - # Create queue (default: sequential) - if queue is None: - queue = SequentialQueue(task_collection, agent_data_list) - - # Callbacks - for cb in self.callbacks: - cb.on_run_start(self) - - # Execute via queue - if max_workers == 1: - self._run_sequential(queue) - else: - self._run_parallel(queue, max_workers) - - # Callbacks - for cb in self.callbacks: - cb.on_run_end(self, self.reports) - - return self.reports - -def _run_sequential(self, queue: TaskQueue): - for task, agent_data in queue: - for repeat_idx in range(self.n_task_repeats): - report = self._execute_single_repetition(task, agent_data, repeat_idx) - self.reports.append(report) - queue.on_task_complete(task, report) - - if not queue.should_continue(): - return -``` - -### Callback Integration for Adaptive Testing - -The existing `BenchmarkCallback` can be extended: - -```python -class BenchmarkCallback(ABC, TraceableMixin): - # ... existing methods ... - - def on_task_selected(self, benchmark: "Benchmark", task: "Task", queue: "TaskQueue"): - """Called when TaskQueue selects the next task to run.""" - pass - - def on_queue_decision(self, benchmark: "Benchmark", queue: "TaskQueue", should_continue: bool): - """Called when TaskQueue makes a continue/stop decision.""" - pass -``` - ---- - -## Unified Design Proposal - -### Architecture Overview - -``` -┌─────────────────────────────────────────────────────────────────┐ -│ Benchmark.run() │ -├─────────────────────────────────────────────────────────────────┤ -│ ┌───────────────────┐ │ -│ │ TaskQueue │ ← Adaptive/Priority/Sequential │ -│ │ (iterator) │ │ -│ └────────┬──────────┘ │ -│ │ yields (Task, agent_data) │ -│ ▼ │ -│ ┌────────────────────────────────────────────────────────────┐│ -│ │ ThreadPoolExecutor (max_workers) ││ -│ │ ┌──────────────────────────────────────────────────────┐ ││ -│ │ │ Worker Thread 1 │ ││ -│ │ │ ┌─────────────────────────────────────────────────┐ │ ││ -│ │ │ │ TaskContext (deadline, checkpoints) │ │ ││ -│ │ │ │ ┌─────────────────────────────────────────────┐ │ │ ││ -│ │ │ │ │ setup → execution_loop → evaluate │ │ │ ││ -│ │ │ │ │ (Task.protocol.timeout_seconds) │ │ │ ││ -│ │ │ │ └─────────────────────────────────────────────┘ │ │ ││ -│ │ │ └─────────────────────────────────────────────────┘ │ ││ -│ │ └──────────────────────────────────────────────────────┘ ││ -│ │ ┌──────────────────────────────────────────────────────┐ ││ -│ │ │ Worker Thread 2 ... │ ││ -│ │ └──────────────────────────────────────────────────────┘ ││ -│ └────────────────────────────────────────────────────────────┘│ -│ │ │ -│ ▼ reports │ -│ ┌───────────────────┐ │ -│ │ Thread-Safe │ │ -│ │ Report Collector │ │ -│ └───────────────────┘ │ -└─────────────────────────────────────────────────────────────────┘ -``` - -### Complete `run()` Implementation - -```python -def run( - self, - tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]], - queue: Optional[TaskQueue] = None, - max_workers: int = 1, -) -> List[Dict[str, Any]]: - """Run benchmark with parallel processing, timeouts, and adaptive scheduling. - - Args: - tasks: Tasks to execute. - queue: Task scheduling strategy. Default: SequentialQueue. - max_workers: Maximum parallel task executions. Default: 1 (sequential). - - Returns: - List of report dictionaries. - """ - # Normalize inputs - self.tasks = self._normalize_tasks(tasks) - agent_data_list = self._normalize_agent_data() - - # Create queue - if queue is None: - queue = SequentialQueue(self.tasks, agent_data_list) - - # Clear reports - self.reports = [] - self._reports_lock = threading.Lock() - - # Run start callbacks - for cb in self.callbacks: - cb.on_run_start(self) - - # Execute - if max_workers == 1: - self._run_sequential(queue) - else: - self._run_parallel(queue, max_workers) - - # Run end callbacks - for cb in self.callbacks: - cb.on_run_end(self, self.reports) - - return self.reports - -def _run_sequential(self, queue: TaskQueue): - """Sequential execution with timeout support.""" - for task, agent_data in queue: - for cb in self.callbacks: - cb.on_task_start(self, task) - - for repeat_idx in range(self.n_task_repeats): - report = self._execute_task_repetition(task, agent_data, repeat_idx) - self._append_report_safe(report) - queue.on_task_complete(task, report) - - for cb in self.callbacks: - cb.on_task_end(self, task, self._last_report_for_task(task)) - - if not queue.should_continue(): - break - -def _run_parallel(self, queue: TaskQueue, max_workers: int): - """Parallel execution with timeout support.""" - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = {} - - # Submit initial batch - for task, agent_data in queue: - for repeat_idx in range(self.n_task_repeats): - future = executor.submit( - self._execute_task_repetition, - task, agent_data, repeat_idx - ) - futures[future] = (task, repeat_idx) - - if len(futures) >= max_workers * 2: - break # Don't over-submit - - # Process completions and submit more - while futures: - done, _ = wait(futures, return_when=FIRST_COMPLETED) - - for future in done: - task, repeat_idx = futures.pop(future) - try: - report = future.result() - except Exception as e: - report = self._create_error_report(task, repeat_idx, e) - - self._append_report_safe(report) - queue.on_task_complete(task, report) - - # Callbacks serialized internally (thread-safe for users) - self._invoke_callbacks('on_task_repeat_end', self, report) - - if not queue.should_continue(): - # Cancel remaining futures - for f in futures: - f.cancel() - return - - # Submit more work - try: - task, agent_data = next(iter(queue)) - for repeat_idx in range(self.n_task_repeats): - future = executor.submit( - self._execute_task_repetition, - task, agent_data, repeat_idx - ) - futures[future] = (task, repeat_idx) - except StopIteration: - pass - -def _execute_task_repetition( - self, - task: Task, - agent_data: Dict[str, Any], - repeat_idx: int, -) -> Dict[str, Any]: - """Execute a single task repetition with timeout handling.""" - timeout = task.protocol.timeout_seconds - context = TaskContext(deadline=timeout) - - # Thread-local registry for this execution - local_registry = {} - - try: - # Setup - environment = self.setup_environment(agent_data, task) - user = self.setup_user(agent_data, environment, task) - agents_to_run, agents_dict = self.setup_agents(agent_data, environment, task, user) - evaluators = self.setup_evaluators(environment, task, agents_to_run, user) - - # Register components (thread-local) - local_registry.update(self._register_components(environment, user, agents_dict)) - - # Execute with timeout checking - context.check_timeout() - final_answer = self.execution_loop(agents_to_run, task, environment, user, context) - - # Collect traces - traces = self._collect_traces(local_registry) - configs = self._collect_configs(local_registry) - - # Evaluate - context.check_timeout() - eval_results = self.evaluate(evaluators, agents_dict, final_answer, traces) - - return { - "task_id": str(task.id), - "repeat_idx": repeat_idx, - "status": TaskExecutionStatus.SUCCESS.value, - "traces": traces, - "config": configs, - "eval": eval_results, - } - - except TaskTimeoutError as e: - return { - "task_id": str(task.id), - "repeat_idx": repeat_idx, - "status": TaskExecutionStatus.TASK_TIMEOUT.value, - "traces": e.partial_traces, - "config": {}, - "eval": None, - "error": { - "error_type": "TaskTimeoutError", - "error_message": str(e), - "elapsed": e.elapsed, - "timeout": e.timeout, - }, - } - except AgentError as e: - # ... existing error handling - pass -``` - ---- - -## Implementation Phases - -### Phase 0: Extract ComponentRegistry (Low Risk, Do First) - -**Scope**: Extract registry logic from `Benchmark` into dedicated `ComponentRegistry` class. - -**Files Modified**: - -- `maseval/core/registry.py` (new): `ComponentRegistry` with thread-local storage -- `maseval/core/benchmark.py`: Replace inline registry with delegation to `ComponentRegistry` -- `maseval/core/__init__.py`: Export `ComponentRegistry` - -**Effort**: ~1-2 days - -**Breaking Changes**: None (public API unchanged, internal refactoring only) - -**Why first**: This refactoring is needed for clean parallel execution. Doing it first: - -- Isolates the thread-local complexity -- Makes subsequent phases simpler -- Can be tested and merged independently - -### Phase 1: TaskProtocol & Timeout (Low Risk) - -**Scope**: Add `TaskProtocol` dataclass, integrate cooperative timeout. - -**Files Modified**: - -- `maseval/core/task.py`: Add `TaskProtocol`, attach to `Task` -- `maseval/core/exceptions.py`: Add `TaskTimeoutError` -- `maseval/core/benchmark.py`: Add `TaskContext`, timeout checking in execution - -**Effort**: ~2-3 days - -**Breaking Changes**: None (new optional field with defaults) - -### Phase 2: TaskQueue Abstraction (Medium Risk) - -**Scope**: Extract task iteration into `TaskQueue`, maintain sequential default. - -**Files Modified**: - -- `maseval/core/queue.py` (new): `TaskQueue`, `SequentialQueue`, `PriorityQueue` -- `maseval/core/benchmark.py`: Refactor `run()` to use queue - -**Effort**: ~3-4 days - -**Breaking Changes**: None (signature changes are additive) - -### Phase 3: Parallel Execution (Higher Risk) - -**Scope**: Add `max_workers` parameter, thread-safe report collection, callback locking. - -**Files Modified**: - -- `maseval/core/benchmark.py`: Add `_run_parallel()`, `_invoke_callbacks()`, `_append_report_safe()` - -**Effort**: ~4-5 days - -**Breaking Changes**: None. MASEval handles all thread safety internally. - -**Note**: Requires Phase 0 (ComponentRegistry) to be complete. - -### Phase 4: AdaptiveQueue (Collaborator-Driven) - -**Scope**: Implement `AdaptiveQueue` for IRT-based adaptive testing. - -**Files Modified**: - -- `maseval/core/queue.py`: Add `AdaptiveQueue` base or concrete implementation -- `maseval/core/callback.py`: Add `on_task_selected`, `on_queue_decision` (if needed) - -**Effort**: ~3-4 days (depends on algorithm complexity) - -**Breaking Changes**: None - -**Note**: This phase will be driven by collaborator implementing their adaptive sampling paper. MASEval provides the `TaskQueue` interface; they implement the selection algorithm. - ---- - -## Risks and Mitigations - -### Risk 1: Thread Safety Bugs - -**Mitigation**: - -- Thread-local storage for per-task registries (already ephemeral per-repetition) -- Lock for shared report list -- Lock for callback invocations (users don't need to think about this) -- Default to `max_workers=1` for backwards compatibility -- Comprehensive tests with race condition detection - -### Risk 2: Framework Incompatibility - -**Mitigation**: - -- Test with all supported frameworks (smolagents, langgraph, llamaindex) -- Document that user's `run_agents()` should not rely on shared mutable benchmark state -- All current adapters are stateless per-invocation (already safe) - -### Risk 3: Timeout Incomplete Cleanup - -**Mitigation**: - -- Cooperative timeout with checkpoints (clean interruption points) -- Hard timeout as backstop—logs warning but continues gracefully -- Document that timed-out tasks may leave external resources (API connections) in undefined state -- Timeout is "best effort"—we cannot forcibly kill Python threads - -### Risk 4: Callback Ordering in Parallel Mode - -**Mitigation**: - -- In parallel mode, `on_task_repeat_end` order is non-deterministic -- Document this behavior clearly -- Callbacks are still serialized (never concurrent), just out-of-order - -### Risk 5: Memory Pressure with Many Workers - -**Mitigation**: - -- Default `max_workers=1` -- Document memory implications -- Consider `max_workers="auto"` that uses `os.cpu_count()` - ---- - -## Summary - -### Implementation Order - -| Phase | Feature | Risk | Effort | Dependencies | -| ----- | ---------------------------- | ------ | -------- | ------------ | -| 0 | ComponentRegistry extraction | Low | 1-2 days | None | -| 1 | TaskProtocol & Timeout | Low | 2-3 days | None | -| 2 | TaskQueue abstraction | Medium | 3-4 days | None | -| 3 | Parallel Execution | Higher | 4-5 days | Phase 0 | -| 4 | AdaptiveQueue | Medium | 3-4 days | Phase 2 | - -### Feature Summary - -| Feature | Approach | Breaking Changes | -| ------------------- | --------------------------------------------- | ---------------- | -| ComponentRegistry | Extracted class with thread-local state | None | -| Parallel Processing | `ThreadPoolExecutor` with `max_workers` param | None | -| Timeout Handling | Cooperative checkpoints + hard backstop | None | -| TaskQueue | Iterator abstraction with `on_task_complete` | None | -| Callback Safety | MASEval serializes with internal lock | None | - -### Key Design Decisions - -1. **Extract ComponentRegistry**: Separate concerns. Registry manages thread-local component tracking. Benchmark orchestrates execution. Enables clean parallel implementation. - -2. **Threading over asyncio**: No user code changes required. Works with all agent frameworks (including sync-only smolagents). Future-proof for Python's GIL removal. - -3. **MASEval-managed callback safety**: All callback invocations are serialized with a lock. Users never need to think about thread safety in their callbacks. - -4. **Cooperative timeout**: Cross-platform, works in threads, clean interruption at defined checkpoints. Hard timeout as backstop for misbehaving code (best-effort only—Python threads cannot be killed). - -5. **AdaptiveQueue for collaborator**: The `TaskQueue` interface enables a collaborator to implement their adaptive sampling paper. MASEval provides the hooks; they implement the algorithm. - -### What's NOT Changing - -- **Public API**: All existing methods work unchanged -- **User-implemented methods**: `run_agents()`, `setup_environment()`, etc. stay sync -- **Callback interface**: Users write callbacks exactly as today -- **Default behavior**: `max_workers=1` maintains sequential execution - -The unified design maintains **full backwards compatibility** while enabling: - -- **Faster benchmarks** through parallelism -- **Resource-bounded execution** through timeouts -- **Intelligent task selection** through adaptive queues - -All features share the same execution model refactor, making them natural to implement together. diff --git a/SUMMARY.md b/SUMMARY.md deleted file mode 100644 index 8799a7a..0000000 --- a/SUMMARY.md +++ /dev/null @@ -1,103 +0,0 @@ -# Implementation Summary: Parallel Task Execution Engine - -This document summarizes the implementation of parallel task execution, timeout handling, and task queue abstraction for MASEval, as specified in PLAN.md. - -## Phase 0: ComponentRegistry Extraction - -**New File:** `maseval/core/registry.py` - -- Created `ComponentRegistry` class that manages component registration for tracing and configuration collection -- Uses `threading.local()` for thread-local storage, enabling parallel task execution without cross-contamination between threads -- Each thread gets isolated registry state: `_trace_registry`, `_component_id_map`, `_config_registry`, `_config_component_id_map` -- Methods: `register()`, `clear()`, `collect_traces()`, `collect_configs()` -- Refactored `Benchmark` class to delegate all registry operations to `self._registry: ComponentRegistry` - -## Phase 1: TaskProtocol & Timeout Infrastructure - -**Modified:** `maseval/core/task.py` - -- Added `TimeoutAction` enum with values `SKIP`, `RETRY`, `RAISE` for configurable timeout behavior -- Added `TaskProtocol` dataclass with fields: - - `timeout_seconds: Optional[float]` - per-task timeout limit - - `timeout_action: TimeoutAction` - what to do on timeout (default: SKIP) - - `max_retries: int` - retry count for failed tasks (default: 0) - - `priority: int` - scheduling priority (default: 0, higher = more important) - - `tags: Dict[str, Any]` - arbitrary metadata for filtering/grouping -- Added `protocol: TaskProtocol` field to `Task` dataclass - -**New File:** `maseval/core/context.py` - -- Created `TaskContext` class for cooperative timeout checking -- Properties: `elapsed` (time since start), `remaining` (time until deadline), `is_expired` (bool) -- Method: `check_timeout()` raises `TaskTimeoutError` if deadline exceeded -- Designed for checkpoint-based timeout checking in user code - -**Modified:** `maseval/core/exceptions.py` - -- Added `TaskTimeoutError(MASEvalError)` with attributes: - - `elapsed: float` - how long the task ran - - `timeout: float` - the configured timeout limit - - `partial_traces: Optional[Dict]` - any traces collected before timeout - -**Modified:** `maseval/core/benchmark.py` - -- Added `TASK_TIMEOUT` to `TaskExecutionStatus` enum - -## Phase 2: TaskQueue Abstraction - -**New File:** `maseval/core/queue.py` - -- Created `TaskQueue` abstract base class with iterator interface (`__iter__`, `__next__`) -- Supports both `Task` and `TaskCollection` inputs, with automatic expansion -- Handles `n_task_repeats` by yielding `(task, repeat_idx)` tuples - -**Implementations:** - -1. `SequentialQueue` - Simple FIFO ordering, iterates tasks in input order -2. `PriorityQueue` - Uses `TaskProtocol.priority` for scheduling (higher priority first) -3. `AdaptiveQueue` - Placeholder for future feedback-based scheduling (currently falls back to sequential) - -## Phase 3: Parallel Execution - -**Modified:** `maseval/core/benchmark.py` - -- Added `max_workers: int = 1` parameter to `Benchmark.run()` for controlling parallelism -- Added `queue: Optional[TaskQueue] = None` parameter for custom scheduling (defaults to `SequentialQueue`) -- Added thread-safety mechanisms: - - `self._reports_lock: threading.Lock` for safe report collection from multiple threads - - `self._callback_lock: threading.Lock` for serialized callback invocation -- New methods: - - `_invoke_callbacks(method_name, *args, **kwargs)` - thread-safe callback invocation - - `_append_report_safe(report)` - thread-safe report collection - - `_execute_task_repetition(task, repeat_idx, context)` - single task execution with timeout support - - `_run_sequential(queue)` - sequential execution (backward compatible) - - `_run_parallel(queue, max_workers)` - parallel execution using `ThreadPoolExecutor` - -**Backward Compatibility:** - -- `max_workers=1` (default) uses `_run_sequential()`, preserving existing behavior -- `max_workers>1` uses `_run_parallel()` with thread pool - -## Phase 4: AdaptiveQueue (Placeholder) - -- `AdaptiveQueue` class created as placeholder for collaborator implementation -- Intended for feedback-based scheduling that reorders remaining tasks based on execution results -- Currently falls back to sequential iteration - -## Updated Exports - -**Modified:** `maseval/__init__.py` - -New public exports: - -- `TaskProtocol`, `TimeoutAction` - task execution configuration -- `ComponentRegistry` - thread-safe component registration -- `TaskContext` - timeout checking context -- `TaskQueue`, `SequentialQueue`, `PriorityQueue`, `AdaptiveQueue` - scheduling abstractions -- `TaskTimeoutError` - timeout exception - -## Test Updates - -- Updated 2 test files that accessed internal registry attributes (`_trace_registry`, `_component_id_map`, `_config_registry`) -- Changed to access through `benchmark._registry._trace_registry` pattern -- All 666 tests pass diff --git a/TEST_PLAN.md b/TEST_PLAN.md deleted file mode 100644 index f4a1ddd..0000000 --- a/TEST_PLAN.md +++ /dev/null @@ -1,359 +0,0 @@ -# Test Plan: Parallel Task Execution Engine - -This document outlines the testing strategy for the parallel execution implementation. It covers new tests to add, existing tests to adapt, and tests that can be removed. - ---- - -## 1. New Tests to Add - -### 1.1 ComponentRegistry Tests (`tests/test_core/test_registry.py`) - -**Thread Safety Tests:** - -- `test_registry_thread_isolation` - Verify that registrations in one thread don't appear in another thread -- `test_registry_concurrent_registration` - Multiple threads registering components simultaneously without data races -- `test_registry_concurrent_collect_traces` - Multiple threads calling `collect_traces()` simultaneously -- `test_registry_concurrent_collect_configs` - Multiple threads calling `collect_configs()` simultaneously -- `test_registry_clear_only_affects_current_thread` - Calling `clear()` in one thread doesn't affect other threads - -**Basic Functionality Tests:** - -- `test_registry_register_traceable_component` - Component registered for tracing -- `test_registry_register_configurable_component` - Component also registered in config registry -- `test_registry_duplicate_key_idempotent` - Same component, same key is idempotent -- `test_registry_duplicate_component_different_key_raises` - Same component, different key raises ValueError -- `test_registry_collect_traces_structure` - Verify trace output structure -- `test_registry_collect_configs_structure` - Verify config output structure -- `test_registry_benchmark_config_included` - benchmark_config passed to constructor appears in configs - -### 1.2 TaskContext Tests (`tests/test_core/test_context.py`) - -**Timeout Behavior Tests:** - -- `test_context_no_timeout` - Context without deadline never expires -- `test_context_with_timeout_not_expired` - Context before deadline shows remaining time -- `test_context_with_timeout_expired` - Context after deadline shows is_expired=True -- `test_context_check_timeout_raises_on_expiry` - `check_timeout()` raises TaskTimeoutError when expired -- `test_context_check_timeout_with_partial_traces` - TaskTimeoutError includes partial traces -- `test_context_elapsed_increases` - `elapsed` property increases over time -- `test_context_remaining_decreases` - `remaining` property decreases over time - -### 1.3 TaskQueue Tests (`tests/test_core/test_queue.py`) - -**SequentialQueue Tests:** - -- `test_sequential_queue_order_preserved` - Tasks yielded in original order -- `test_sequential_queue_iteration_complete` - All tasks yielded exactly once -- `test_sequential_queue_empty_collection` - Empty collection yields nothing -- `test_sequential_queue_single_task` - Single task handled correctly - -**PriorityQueue Tests:** - -- `test_priority_queue_high_priority_first` - Higher priority tasks come first -- `test_priority_queue_stable_sort` - Equal priority maintains original order -- `test_priority_queue_default_priority_zero` - Tasks without explicit priority treated as 0 -- `test_priority_queue_negative_priority` - Negative priorities handled correctly - -**AdaptiveQueue Tests:** - -- `test_adaptive_queue_on_task_complete_updates_state` - Completed tasks moved to completed list -- `test_adaptive_queue_stop_terminates_iteration` - Calling `stop()` ends iteration early -- `test_adaptive_queue_should_continue_false_after_stop` - `should_continue()` returns False after stop - -### 1.4 TaskProtocol Tests (`tests/test_core/test_task_protocol.py`) - -- `test_task_protocol_defaults` - Default values: timeout=None, action=SKIP, retries=0, priority=0, tags={} -- `test_task_has_protocol_field` - Task dataclass has protocol field -- `test_task_protocol_custom_values` - Custom protocol values preserved -- `test_timeout_action_enum_values` - TimeoutAction has SKIP, RETRY, RAISE - -### 1.5 Parallel Execution Tests (`tests/test_core/test_benchmark/test_parallel_execution.py`) - -**Basic Parallel Execution:** - -- `test_parallel_execution_basic` - `max_workers>1` runs tasks in parallel -- `test_parallel_execution_same_results_as_sequential` - Parallel produces same reports as sequential -- `test_parallel_execution_max_workers_respected` - No more than max_workers concurrent threads -- `test_parallel_execution_single_worker_uses_sequential` - `max_workers=1` uses `_run_sequential` - -**Thread Safety - Report Collection:** - -- `test_parallel_reports_thread_safe` - Reports from parallel tasks all collected correctly -- `test_parallel_report_count_matches_task_count` - Number of reports equals tasks × repeats -- `test_parallel_report_order_independent` - Report content correct regardless of completion order - -**Thread Safety - Callbacks:** - -- `test_parallel_callbacks_serialized` - Callbacks invoked with lock (no concurrent callback execution) -- `test_parallel_callback_data_integrity` - Callback receives correct task/report data -- `test_parallel_callbacks_all_events_fire` - All lifecycle callbacks fire for each task -- `test_parallel_callback_exception_isolated` - Exception in one callback doesn't affect other tasks - -**Thread Safety - Registry:** - -- `test_parallel_registry_isolation` - Each task gets isolated registry state -- `test_parallel_traces_not_cross_contaminated` - Traces from task A don't appear in task B's report -- `test_parallel_configs_not_cross_contaminated` - Configs from task A don't appear in task B's report - -**Race Condition Tests:** - -- `test_parallel_concurrent_setup` - Multiple tasks calling setup methods simultaneously -- `test_parallel_concurrent_evaluation` - Multiple tasks being evaluated simultaneously -- `test_parallel_slow_fast_task_ordering` - Slow task in worker doesn't block fast task reports -- `test_parallel_error_in_one_task_doesnt_affect_others` - One task failing doesn't corrupt other tasks - -### 1.6 Timeout Handling Tests (`tests/test_core/test_benchmark/test_timeout_handling.py`) - -- `test_timeout_task_marked_as_timeout_status` - Timed out task has `TASK_TIMEOUT` status -- `test_timeout_partial_traces_collected` - Traces collected up to timeout point included in report -- `test_timeout_action_skip_continues_to_next` - SKIP action moves to next task -- `test_timeout_action_retry_retries_task` - RETRY action re-executes (up to max_retries) -- `test_timeout_action_raise_propagates` - RAISE action raises TaskTimeoutError -- `test_timeout_cooperative_checkpoint` - Tasks checking `context.check_timeout()` respect timeout - -### 1.7 Queue Integration Tests (`tests/test_core/test_benchmark/test_queue_integration.py`) - -- `test_run_with_custom_queue` - `benchmark.run(tasks, queue=custom_queue)` uses provided queue -- `test_run_default_queue_is_sequential` - No queue specified uses SequentialQueue -- `test_priority_queue_integration` - PriorityQueue orders execution correctly in real benchmark -- `test_queue_on_task_complete_called` - Queue's `on_task_complete` called after each task -- `test_queue_should_continue_checked` - Queue's `should_continue` checked after each task - -### 1.8 TaskTimeoutError Tests (`tests/test_core/test_exceptions.py` - extend existing) - -- `test_task_timeout_error_attributes` - Has elapsed, timeout, partial_traces attributes -- `test_task_timeout_error_message` - Message includes timeout and elapsed time -- `test_task_timeout_error_is_maseval_error` - Inherits from MASEvalError - ---- - -## 2. Existing Tests to Adapt - -### 2.1 Tests Already Adapted (completed) - -- `tests/test_core/test_benchmark/test_automatic_registration.py` - - Changed `benchmark._trace_registry` → `benchmark._registry._trace_registry` - - Changed `benchmark._component_id_map` → `benchmark._registry._component_id_map` -- `tests/test_core/test_benchmark/test_benchmark_lifecycle.py` - - Changed `benchmark._trace_registry` → `benchmark._registry._trace_registry` - - Changed `benchmark._config_registry` → `benchmark._registry._config_registry` - -### 2.2 Tests That May Need Adaptation - -**Callback Tests (`test_callback_orchestration.py`):** - -- Review `test_callback_errors_dont_break_execution` - Ensure behavior consistent with parallel mode -- Consider adding parallel variant of each callback order test - -**Lifecycle Tests (`test_benchmark_lifecycle.py`):** - -- `test_benchmark_lifecycle_hooks_order` - Verify order still guaranteed in sequential mode -- Add note/variant about callback order in parallel mode (order within task preserved, between tasks not) - -**Exception Tests (`test_exceptions.py`):** - -- Extend classification tests to include `TaskTimeoutError` → `TASK_TIMEOUT` mapping - -**Config Collection Tests (`test_config_collection.py`):** - -- Verify config collection works correctly in parallel mode -- `test_config_different_per_repetition` - May need thread-awareness verification - ---- - -## 3. Tests That Can Be Removed - -### 3.1 No Tests to Remove - -The implementation maintains backward compatibility (`max_workers=1` default), so all existing tests remain valid. No tests are obsoleted by this change. - -### 3.2 Tests That Could Be Consolidated (Optional Cleanup) - -- Some registry-related tests in `test_automatic_registration.py` and `test_benchmark_lifecycle.py` overlap in testing registry clearing. Consider consolidating into a single registry test file. - ---- - -## 4. Test Categories and Markers - -### New Pytest Markers to Consider - -```python -# conftest.py additions -pytest.mark.parallel # Tests specific to parallel execution -pytest.mark.thread_safety # Tests for race conditions and thread safety -pytest.mark.timeout # Tests for timeout handling -pytest.mark.queue # Tests for task queue abstraction -``` - -### Marker Usage - -```python -@pytest.mark.core -@pytest.mark.parallel -def test_parallel_execution_basic(): - ... - -@pytest.mark.core -@pytest.mark.thread_safety -def test_parallel_registry_isolation(): - ... -``` - ---- - -## 5. Test Infrastructure Needs - -### 5.1 New Test Fixtures - -```python -# conftest.py additions - -@pytest.fixture -def slow_benchmark(): - """Benchmark that takes configurable time per task (for parallel testing).""" - class SlowBenchmark(DummyBenchmark): - def __init__(self, delay_seconds=0.1, **kwargs): - super().__init__(**kwargs) - self.delay = delay_seconds - - def run_agents(self, agents, task, environment, query): - import time - time.sleep(self.delay) - return super().run_agents(agents, task, environment, query) - - return SlowBenchmark - -@pytest.fixture -def thread_tracking_callback(): - """Callback that records which thread each event fires on.""" - import threading - - class ThreadTracker(BenchmarkCallback): - def __init__(self): - self.thread_ids = [] - - def on_task_repeat_start(self, benchmark, task, repeat_idx): - self.thread_ids.append(threading.current_thread().ident) - - return ThreadTracker -``` - -### 5.2 Helper Functions - -```python -def run_parallel_and_sequential(benchmark, tasks): - """Run same benchmark both ways and compare reports.""" - import copy - - seq_benchmark = copy.deepcopy(benchmark) - par_benchmark = copy.deepcopy(benchmark) - - seq_reports = seq_benchmark.run(tasks, max_workers=1) - par_reports = par_benchmark.run(tasks, max_workers=4) - - return seq_reports, par_reports - -def verify_no_cross_contamination(reports): - """Check that traces in each report only contain that task's data.""" - for report in reports: - task_id = report['task_id'] - for key, trace in report['traces'].get('agents', {}).items(): - # Verify trace belongs to this task - assert task_id in str(trace) or 'task_id' not in trace -``` - ---- - -## 6. Priority Order for Implementation - -### High Priority (Core Functionality) - -1. `test_parallel_execution.py` - Basic parallel execution verification -2. `test_registry.py` - Thread isolation is critical for correctness -3. `test_timeout_handling.py` - Timeout is a key new feature - -### Medium Priority (Integration) - -4. `test_queue.py` - Queue abstraction tests -5. `test_queue_integration.py` - Queue + Benchmark integration -6. `test_context.py` - TaskContext functionality - -### Lower Priority (Edge Cases) - -7. `test_task_protocol.py` - Simple dataclass tests -8. Extended race condition tests -9. Performance/stress tests - ---- - -## 7. Notes for Test Implementation - -### Thread Safety Testing Patterns - -```python -import threading -import time -from concurrent.futures import ThreadPoolExecutor - -def test_concurrent_operation(): - """Pattern for testing concurrent operations.""" - results = [] - errors = [] - barrier = threading.Barrier(4) # Synchronize thread start - - def worker(worker_id): - try: - barrier.wait() # All threads start together - # Perform operation - result = do_something() - results.append((worker_id, result)) - except Exception as e: - errors.append((worker_id, e)) - - with ThreadPoolExecutor(max_workers=4) as executor: - futures = [executor.submit(worker, i) for i in range(4)] - for f in futures: - f.result() # Wait for completion - - assert len(errors) == 0, f"Errors occurred: {errors}" - assert len(results) == 4 -``` - -### Timing Considerations - -- Use `time.sleep()` sparingly in tests -- Consider mocking time for deterministic timeout tests -- Use threading barriers for synchronization points -- Allow tolerance in timing assertions (e.g., ±10ms) - -### Isolation Verification - -```python -def test_registry_isolation(): - """Verify thread-local storage works correctly.""" - registry = ComponentRegistry() - results = {} - - def worker(worker_id): - # Each thread should see empty registry initially - assert len(registry._trace_registry) == 0 - - # Register unique component - registry.register("test", f"comp_{worker_id}", MockComponent()) - - # Only our component should be visible - assert len(registry._trace_registry) == 1 - assert f"test:comp_{worker_id}" in registry._trace_registry - - results[worker_id] = list(registry._trace_registry.keys()) - - # Run in parallel - with ThreadPoolExecutor(max_workers=4) as executor: - futures = [executor.submit(worker, i) for i in range(4)] - for f in futures: - f.result() - - # Verify isolation - for worker_id, keys in results.items(): - assert keys == [f"test:comp_{worker_id}"] -``` From c45c75e9a2d9d2786b4163d7b3327cb5525b9d0a Mon Sep 17 00:00:00 2001 From: cemde Date: Sun, 21 Dec 2025 11:25:49 +0000 Subject: [PATCH 13/25] [skip ci] updated uv lock --- uv.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/uv.lock b/uv.lock index c085e95..085ed4b 100644 --- a/uv.lock +++ b/uv.lock @@ -2517,12 +2517,12 @@ dev = [ { name = "ty", specifier = ">=0.0.5" }, ] docs = [ - { name = "mkdocs", specifier = ">=1.5" }, + { name = "mkdocs", specifier = ">=1.6" }, { name = "mkdocs-git-revision-date-localized-plugin", specifier = ">=1.5.0" }, { name = "mkdocs-jupyter", specifier = ">=0.24.0" }, - { name = "mkdocs-material", specifier = ">=9.0.0" }, - { name = "mkdocstrings", extras = ["python"], specifier = ">=0.17.0" }, - { name = "pymdown-extensions", specifier = ">=9.0.0" }, + { name = "mkdocs-material", specifier = ">=9.7.0" }, + { name = "mkdocstrings", extras = ["python"], specifier = ">=1.0.0" }, + { name = "pymdown-extensions", specifier = ">=10.0.0" }, ] [[package]] From b5c0978072bcccd6e09c72e93755db303f63bc43 Mon Sep 17 00:00:00 2001 From: cemde Date: Sun, 21 Dec 2025 12:04:18 +0000 Subject: [PATCH 14/25] fixed type checking issues. --- maseval/benchmark/macs/data_loader.py | 66 +++++++++++-------- maseval/benchmark/macs/macs.py | 20 +++--- maseval/core/benchmark.py | 10 +-- maseval/core/simulator.py | 2 +- maseval/core/task.py | 7 +- tests/test_benchmarks/test_macs/conftest.py | 2 +- .../test_macs/test_data_loader.py | 17 ++--- .../test_macs/test_macs_benchmark.py | 2 +- .../test_macs/test_macs_integration.py | 2 +- .../test_agent_adapter_contract.py | 2 +- .../test_contract/test_collection_contract.py | 2 +- .../test_benchmark/test_execution_loop.py | 16 ++--- tests/test_core/test_llm_simulator.py | 5 +- tests/test_core/test_queue.py | 2 +- tests/test_core/test_user.py | 28 ++++---- .../test_langgraph_integration.py | 8 +-- 16 files changed, 100 insertions(+), 91 deletions(-) diff --git a/maseval/benchmark/macs/data_loader.py b/maseval/benchmark/macs/data_loader.py index 94ec1bb..1c80a31 100644 --- a/maseval/benchmark/macs/data_loader.py +++ b/maseval/benchmark/macs/data_loader.py @@ -27,24 +27,23 @@ # AWS Multi-Agent Collaboration Scenarios benchmark data # Source: https://github.com/aws-samples/multiagent-collab-scenario-benchmark -URLS = { - "data": { - "software": { - "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/software/agents.json", - "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/software/scenarios_30.json", - }, - "travel": { - "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/travel/agents.json", - "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/travel/scenarios_30.json", - }, - "mortgage": { - "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/mortgage/agents.json", - "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/mortgage/scenarios_30.json", - }, +DATA_URLS: Dict[str, Dict[str, str]] = { + "software": { + "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/software/agents.json", + "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/software/scenarios_30.json", }, - "evaluation": { - "prompt_templates": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/src/prompt_templates.py", + "travel": { + "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/travel/agents.json", + "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/travel/scenarios_30.json", }, + "mortgage": { + "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/mortgage/agents.json", + "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/mortgage/scenarios_30.json", + }, +} + +EVALUATION_URLS: Dict[str, str] = { + "prompt_templates": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/src/prompt_templates.py", } @@ -90,16 +89,16 @@ def download_original_data( data_dir = Path(data_dir) if data_dir else DEFAULT_DATA_DIR original_dir = data_dir / "original" - domains = [domain] if domain else list(URLS["data"].keys()) + domains = [domain] if domain else list(DATA_URLS.keys()) for d in domains: - if d not in URLS["data"]: + if d not in DATA_URLS: raise ValueError(f"Unknown domain: {d}") domain_dir = original_dir / d domain_dir.mkdir(parents=True, exist_ok=True) - for name, url in URLS["data"][d].items(): + for name, url in DATA_URLS[d].items(): content = download_json(url) out_path = domain_dir / f"{name}.json" with out_path.open("w") as f: @@ -132,7 +131,7 @@ def download_prompt_templates( templates_dir = data_dir.parent / "prompt_templates" templates_dir.mkdir(parents=True, exist_ok=True) - url = URLS["evaluation"]["prompt_templates"] + url = EVALUATION_URLS["prompt_templates"] text = download_file(url) # Parse Python file to extract prompt constants @@ -239,12 +238,15 @@ def _dedupe_tools_by_name(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: return deduped -def _create_tools_list(agents_obj: object) -> List[Dict[str, Any]]: +def _create_tools_list(agents_obj: Union[Dict[str, Any], List[Any]]) -> List[Dict[str, Any]]: """Extract and deduplicate tools from agents data.""" tools: List[Dict[str, Any]] = [] - if isinstance(agents_obj, dict) and isinstance(agents_obj.get("agents"), list): - agents_list = agents_obj["agents"] + agents_list: List[Any] + if isinstance(agents_obj, dict): + agents_list = agents_obj.get("agents", []) + if not isinstance(agents_list, list): + return tools elif isinstance(agents_obj, list): agents_list = agents_obj else: @@ -260,7 +262,7 @@ def _create_tools_list(agents_obj: object) -> List[Dict[str, Any]]: return _dedupe_tools_by_name(tools) -def _create_agents_list(agents_obj: object) -> Dict[str, Any]: +def _create_agents_list(agents_obj: Union[Dict[str, Any], List[Any]]) -> Dict[str, Any]: """Create agents config with tool names only (not full tool dicts).""" def _process_agent(agent: Dict[str, Any]) -> Dict[str, Any]: @@ -269,8 +271,11 @@ def _process_agent(agent: Dict[str, Any]) -> Dict[str, Any]: a_copy["tools"] = tool_names return a_copy - if isinstance(agents_obj, dict) and isinstance(agents_obj.get("agents"), list): - processed = [_process_agent(a) for a in agents_obj["agents"] if isinstance(a, dict)] + if isinstance(agents_obj, dict): + agents_list = agents_obj.get("agents") + if not isinstance(agents_list, list): + return {} + processed = [_process_agent(a) for a in agents_list if isinstance(a, dict)] out: Dict[str, Any] = {"agents": processed} if "primary_agent_id" in agents_obj: out["primary_agent_id"] = agents_obj["primary_agent_id"] @@ -281,12 +286,15 @@ def _process_agent(agent: Dict[str, Any]) -> Dict[str, Any]: return {} -def _create_tasks_list(scenarios_obj: object, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +def _create_tasks_list(scenarios_obj: Union[Dict[str, Any], List[Any]], tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Convert scenarios to task format with sequential IDs.""" tasks: List[Dict[str, Any]] = [] - if isinstance(scenarios_obj, dict) and isinstance(scenarios_obj.get("scenarios"), list): - scenarios_list = scenarios_obj["scenarios"] + scenarios_list: List[Any] + if isinstance(scenarios_obj, dict): + scenarios_list = scenarios_obj.get("scenarios", []) + if not isinstance(scenarios_list, list): + return tasks elif isinstance(scenarios_obj, list): scenarios_list = scenarios_obj else: diff --git a/maseval/benchmark/macs/macs.py b/maseval/benchmark/macs/macs.py index 5ab834d..ab8c053 100644 --- a/maseval/benchmark/macs/macs.py +++ b/maseval/benchmark/macs/macs.py @@ -798,10 +798,10 @@ def tool_model_factory(tool_name: str) -> ModelAdapter: model_factory=tool_model_factory, ) - def setup_user( + def setup_user( # ty: ignore[invalid-method-override] self, agent_data: Dict[str, Any], - environment: MACSEnvironment, # type: ignore[override] + environment: MACSEnvironment, task: Task, ) -> MACSUser: """Create MACS user simulator. @@ -834,10 +834,10 @@ def setup_user( ) @abstractmethod - def setup_agents( + def setup_agents( # ty: ignore[invalid-method-override] self, agent_data: Dict[str, Any], - environment: MACSEnvironment, # type: ignore[override] + environment: MACSEnvironment, task: Task, user: Optional[User], ) -> Tuple[List[AgentAdapter], Dict[str, AgentAdapter]]: @@ -854,9 +854,9 @@ def setup_agents( """ pass - def setup_evaluators( + def setup_evaluators( # ty: ignore[invalid-method-override] self, - environment: MACSEnvironment, # type: ignore[override] + environment: MACSEnvironment, task: Task, agents: Sequence[AgentAdapter], user: Optional[User], @@ -886,11 +886,11 @@ def setup_evaluators( ), ] - def run_agents( + def run_agents( # ty: ignore[invalid-method-override] self, agents: Sequence[AgentAdapter], task: Task, - environment: MACSEnvironment, # type: ignore[override] + environment: MACSEnvironment, query: str = "", ) -> Any: """Execute agents and return final answer.""" @@ -926,7 +926,9 @@ def evaluate( user_result = results[0] if results else {"gsr": 0.0, "partial_gsr": 0.0, "report": []} system_result = results[1] if len(results) > 1 else {"gsr": 0.0, "partial_gsr": 0.0, "report": []} - combined_report = user_result.get("report", []) + system_result.get("report", []) + user_report: List[Dict[str, Any]] = user_result.get("report", []) # type: ignore[assignment] + system_report: List[Dict[str, Any]] = system_result.get("report", []) # type: ignore[assignment] + combined_report = user_report + system_report # Compute overall metrics per AWS paper overall_gsr = 1.0 if (user_result.get("gsr", 0.0) == 1.0 and system_result.get("gsr", 0.0) == 1.0) else 0.0 diff --git a/maseval/core/benchmark.py b/maseval/core/benchmark.py index ea21be8..8ef2063 100644 --- a/maseval/core/benchmark.py +++ b/maseval/core/benchmark.py @@ -1445,10 +1445,10 @@ def run( self.reports = [] # Auto-register queue as callback if it's a BenchmarkCallback (e.g., AdaptiveTaskQueue) - queue_was_added_as_callback = False + queue_as_callback: Optional[BenchmarkCallback] = None if isinstance(queue, BenchmarkCallback) and queue not in self.callbacks: - self.callbacks.append(queue) - queue_was_added_as_callback = True + queue_as_callback = queue + self.callbacks.append(queue_as_callback) try: # Callbacks at the start of the run @@ -1464,8 +1464,8 @@ def run( self._invoke_callbacks("on_run_end", self, self.reports) finally: # Remove queue from callbacks if we added it - if queue_was_added_as_callback: - self.callbacks.remove(queue) + if queue_as_callback is not None: + self.callbacks.remove(queue_as_callback) return self.reports diff --git a/maseval/core/simulator.py b/maseval/core/simulator.py index edeed8b..1124707 100644 --- a/maseval/core/simulator.py +++ b/maseval/core/simulator.py @@ -491,7 +491,7 @@ def _create_error( component="user_simulator", ) - def __call__( + def __call__( # ty: ignore[invalid-method-override] self, conversation_history: List[Dict[str, str]], generation_params: Optional[Dict[str, Any]] = None, diff --git a/maseval/core/task.py b/maseval/core/task.py index 83c6166..a9c0aaa 100644 --- a/maseval/core/task.py +++ b/maseval/core/task.py @@ -115,7 +115,9 @@ def __getitem__(self, idx: int) -> Task: ... @overload def __getitem__(self, idx: slice) -> "BaseTaskQueue": ... - def __getitem__(self, idx: Union[int, slice]) -> Union[Task, "BaseTaskQueue"]: + def __getitem__( # ty: ignore[invalid-method-override] + self, idx: Union[int, slice] + ) -> Union[Task, "BaseTaskQueue"]: """Get a task by index or a slice of tasks. Args: @@ -191,7 +193,8 @@ def from_list(cls, data: Iterable[Union[Task, dict]]) -> "BaseTaskQueue": ) ) else: - query = item.get("question") or item.get("prompt") or item.get("query") or "" + query_val = item.get("question") or item.get("prompt") or item.get("query") or "" + query = str(query_val) if query_val else "" environment_data = ( item.get("environment_data") or {"text_content": item.get("text")} if item.get("text") diff --git a/tests/test_benchmarks/test_macs/conftest.py b/tests/test_benchmarks/test_macs/conftest.py index 081706b..f19fbca 100644 --- a/tests/test_benchmarks/test_macs/conftest.py +++ b/tests/test_benchmarks/test_macs/conftest.py @@ -144,7 +144,7 @@ def get_model_adapter(self, model_id: str, **kwargs): return adapter - def setup_agents( + def setup_agents( # ty: ignore[invalid-method-override] self, agent_data: Dict[str, Any], environment: MACSEnvironment, diff --git a/tests/test_benchmarks/test_macs/test_data_loader.py b/tests/test_benchmarks/test_macs/test_data_loader.py index fdbfbbe..96c0f1b 100644 --- a/tests/test_benchmarks/test_macs/test_data_loader.py +++ b/tests/test_benchmarks/test_macs/test_data_loader.py @@ -12,7 +12,8 @@ from maseval.benchmark.macs.data_loader import ( DEFAULT_DATA_DIR, VALID_DOMAINS, - URLS, + DATA_URLS, + EVALUATION_URLS, download_file, download_json, download_original_data, @@ -175,7 +176,6 @@ def test_empty_input(self): """Empty or invalid input returns empty list.""" assert _create_tools_list({}) == [] assert _create_tools_list([]) == [] - assert _create_tools_list(None) == [] @pytest.mark.benchmark @@ -569,16 +569,13 @@ def mock_download_file(url: str, timeout=15): assert len(config["agents"]) == 2 def test_urls_structure(self): - """Verify URLS constant has expected structure.""" - assert "data" in URLS - assert "evaluation" in URLS - + """Verify URL constants have expected structure.""" for domain in VALID_DOMAINS: - assert domain in URLS["data"] - assert "agents" in URLS["data"][domain] - assert "scenarios" in URLS["data"][domain] + assert domain in DATA_URLS + assert "agents" in DATA_URLS[domain] + assert "scenarios" in DATA_URLS[domain] - assert "prompt_templates" in URLS["evaluation"] + assert "prompt_templates" in EVALUATION_URLS # ============================================================================= diff --git a/tests/test_benchmarks/test_macs/test_macs_benchmark.py b/tests/test_benchmarks/test_macs/test_macs_benchmark.py index 56e3896..f8ba1d6 100644 --- a/tests/test_benchmarks/test_macs/test_macs_benchmark.py +++ b/tests/test_benchmarks/test_macs/test_macs_benchmark.py @@ -183,7 +183,7 @@ def __init__(self, model_factory, **kwargs): def get_model_adapter(self, model_id: str, **kwargs): return self._model_factory(model_id) - def setup_agents( + def setup_agents( # ty: ignore[invalid-method-override] self, agent_data: Dict[str, Any], environment: MACSEnvironment, diff --git a/tests/test_benchmarks/test_macs/test_macs_integration.py b/tests/test_benchmarks/test_macs/test_macs_integration.py index 247eedc..dac538c 100644 --- a/tests/test_benchmarks/test_macs/test_macs_integration.py +++ b/tests/test_benchmarks/test_macs/test_macs_integration.py @@ -141,7 +141,7 @@ def test_loaded_agent_config_works_with_environment(self, macs_model_factory): # Get tools for agent from config agent_spec = agent_config["agents"][0] - agent_tools = env.get_tools_for_agent(agent_spec) + agent_tools = env.get_tools_for_agent(agent_spec) # type: ignore[arg-type] assert "action1" in agent_tools diff --git a/tests/test_contract/test_agent_adapter_contract.py b/tests/test_contract/test_agent_adapter_contract.py index 810dc56..6521977 100644 --- a/tests/test_contract/test_agent_adapter_contract.py +++ b/tests/test_contract/test_agent_adapter_contract.py @@ -119,7 +119,7 @@ def agent_node(state: State) -> State: return {"messages": messages + [AIMessage(content=response)]} # Build graph - graph = StateGraph(State) + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", agent_node) graph.set_entry_point("agent") graph.add_edge("agent", END) diff --git a/tests/test_contract/test_collection_contract.py b/tests/test_contract/test_collection_contract.py index 6e77b8b..497a689 100644 --- a/tests/test_contract/test_collection_contract.py +++ b/tests/test_contract/test_collection_contract.py @@ -151,7 +151,7 @@ class State(TypedDict): def agent_node(state: State) -> State: return {"messages": state["messages"] + [AIMessage(content="Response")]} - graph = StateGraph(State) + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", agent_node) graph.set_entry_point("agent") graph.add_edge("agent", END) diff --git a/tests/test_core/test_benchmark/test_execution_loop.py b/tests/test_core/test_benchmark/test_execution_loop.py index 9df47e6..040efc7 100644 --- a/tests/test_core/test_benchmark/test_execution_loop.py +++ b/tests/test_core/test_benchmark/test_execution_loop.py @@ -147,7 +147,7 @@ def test_uses_get_initial_query_if_no_initial_query(self, dummy_model): task = Task(query="Task query", environment_data={}) user = DummyUser(name="test", model=dummy_model, max_turns=5) # No initial_query, so messages is empty - user.simulator.return_value = "LLM generated initial query" + user.simulator.return_value = "LLM generated initial query" # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark(return_user=user) @@ -172,7 +172,7 @@ def test_multi_turn_interaction(self, dummy_model): max_turns=5, ) # User responds with different messages each turn - user.simulator.side_effect = ["Turn 1 response", "Turn 2 response", "Turn 3 response"] + user.simulator.side_effect = ["Turn 1 response", "Turn 2 response", "Turn 3 response"] # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark( return_user=user, @@ -204,7 +204,7 @@ def test_stops_when_user_done_via_max_turns(self, dummy_model): initial_query="Start", # Counts as turn 1 max_turns=3, # User done after 3 user messages ) - user.simulator.side_effect = ["Response 1", "Response 2", "Response 3"] + user.simulator.side_effect = ["Response 1", "Response 2", "Response 3"] # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark( return_user=user, @@ -234,7 +234,7 @@ def test_stops_when_user_done_via_stop_token(self, dummy_model): early_stopping_condition="goals are met", ) # User stops on second response - user.simulator.side_effect = ["Continue please", "Thanks! "] + user.simulator.side_effect = ["Continue please", "Thanks! "] # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark( return_user=user, @@ -260,7 +260,7 @@ def test_final_answer_in_user_messages(self, dummy_model): initial_query="Help me", max_turns=2, # Allow initial + one response ) - user.simulator.return_value = "Thanks" + user.simulator.return_value = "Thanks" # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark(return_user=user) @@ -288,7 +288,7 @@ def test_user_response_becomes_next_query(self, dummy_model): max_turns=4, # Allow 4 turns total ) # User responses for turn 2, 3, 4 - user.simulator.side_effect = ["User reply 1", "User reply 2", "User reply 3"] + user.simulator.side_effect = ["User reply 1", "User reply 2", "User reply 3"] # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark( return_user=user, @@ -372,7 +372,7 @@ def test_run_with_user_uses_execution_loop(self, dummy_model): initial_query="User query", max_turns=1, ) - user.simulator.return_value = "Done" + user.simulator.return_value = "Done" # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark(return_user=user) @@ -394,7 +394,7 @@ def test_complete_traces_with_user(self, dummy_model): initial_query="Hello", # Turn 1 max_turns=3, # Allow 3 user messages total ) - user.simulator.side_effect = ["Reply 1", "Reply 2"] + user.simulator.side_effect = ["Reply 1", "Reply 2"] # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark( return_user=user, diff --git a/tests/test_core/test_llm_simulator.py b/tests/test_core/test_llm_simulator.py index 6def723..e775ecf 100644 --- a/tests/test_core/test_llm_simulator.py +++ b/tests/test_core/test_llm_simulator.py @@ -3,7 +3,6 @@ These tests verify that LLMSimulator retry logic and tracing work correctly. """ -from typing import cast import pytest from maseval.core.simulator import ( @@ -69,7 +68,7 @@ def test_llm_simulator_parsing_error_retry(self, dummy_model): simulator(actual_inputs={"param": "test"}) # Verify exception details - err = cast(ToolSimulatorError, exc_info.value) + err = exc_info.value assert err.attempts == 3 assert err.last_error is not None assert len(err.logs) == 3 # All 3 attempts in exception logs @@ -94,7 +93,7 @@ def test_llm_simulator_max_attempts_respected(self, dummy_model): simulator(actual_inputs={"param": "test"}) # Should stop after 2 attempts, not continue to 10 - err = cast(ToolSimulatorError, exc_info.value) + err = exc_info.value assert len(simulator.logs) == 2 assert err.attempts == 2 diff --git a/tests/test_core/test_queue.py b/tests/test_core/test_queue.py index cf0d9f0..5c6011b 100644 --- a/tests/test_core/test_queue.py +++ b/tests/test_core/test_queue.py @@ -115,7 +115,7 @@ def test_from_list_with_dicts(self): def test_from_list_type_error(self): """from_list should raise TypeError for invalid items.""" with pytest.raises(TypeError, match="expects Task or dict"): - SequentialTaskQueue.from_list(["not a task"]) + SequentialTaskQueue.from_list(["not a task"]) # type: ignore[arg-type] # intentional def test_from_json_file(self, tmp_path): """from_json_file should load tasks from JSON file.""" diff --git a/tests/test_core/test_user.py b/tests/test_core/test_user.py index 1db0ab1..f9d1aab 100644 --- a/tests/test_core/test_user.py +++ b/tests/test_core/test_user.py @@ -217,7 +217,7 @@ def test_stop_token_detection_sets_stopped(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "Thanks! " + user.simulator.return_value = "Thanks! " # type: ignore[union-attr] # mock user.simulate_response("Here's your answer") @@ -234,7 +234,7 @@ def test_stop_token_removed_from_response(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "Perfect, thanks! " + user.simulator.return_value = "Perfect, thanks! " # type: ignore[union-attr] # mock response = user.simulate_response("Booking confirmed!") @@ -252,7 +252,7 @@ def test_is_done_true_after_stop_token(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "Done " + user.simulator.return_value = "Done " # type: ignore[union-attr] # mock user.simulate_response("Result") @@ -269,7 +269,7 @@ def test_stop_token_case_insensitive(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "Thanks! " # lowercase + user.simulator.return_value = "Thanks! " # type: ignore[union-attr] # mock (lowercase) user.simulate_response("Answer") @@ -286,7 +286,7 @@ def test_fallback_message_when_only_stop_token(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "" + user.simulator.return_value = "" # type: ignore[union-attr] # mock response = user.simulate_response("Done!") @@ -304,7 +304,7 @@ def test_stop_token_response_counts_as_turn(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "Thank you, all is clear " + user.simulator.return_value = "Thank you, all is clear " # type: ignore[union-attr] # mock initial_turn_count = user._turn_count user.simulate_response("Here is your result") @@ -352,19 +352,19 @@ def test_get_initial_query_generates_message(self, dummy_model): from conftest import DummyUser user = DummyUser(name="test", model=dummy_model) - user.simulator.return_value = "I want to book a hotel" + user.simulator.return_value = "I want to book a hotel" # type: ignore[union-attr] # mock query = user.get_initial_query() assert query == "I want to book a hotel" - user.simulator.assert_called_once() + user.simulator.assert_called_once() # type: ignore[union-attr] # mock def test_get_initial_query_adds_to_messages(self, dummy_model): """Generated query is added to message history.""" from conftest import DummyUser user = DummyUser(name="test", model=dummy_model) - user.simulator.return_value = "Help me please" + user.simulator.return_value = "Help me please" # type: ignore[union-attr] # mock user.get_initial_query() @@ -390,7 +390,7 @@ def test_get_initial_query_counts_as_turn(self, dummy_model): from conftest import DummyUser user = DummyUser(name="test", model=dummy_model, max_turns=3) - user.simulator.return_value = "Initial query" + user.simulator.return_value = "Initial query" # type: ignore[union-attr] # mock user.get_initial_query() @@ -424,7 +424,7 @@ def test_assistant_message_recorded(self, dummy_model): from conftest import DummyUser user = DummyUser(name="test", model=dummy_model, max_turns=3) - user.simulator.return_value = "User reply" + user.simulator.return_value = "User reply" # type: ignore[union-attr] # mock user.simulate_response("Agent says hello") @@ -438,7 +438,7 @@ def test_user_response_recorded(self, dummy_model): from conftest import DummyUser user = DummyUser(name="test", model=dummy_model, max_turns=3) - user.simulator.return_value = "Thanks for the help" + user.simulator.return_value = "Thanks for the help" # type: ignore[union-attr] # mock user.simulate_response("Here's your answer") @@ -455,7 +455,7 @@ def test_full_conversation_tracked(self, dummy_model): initial_query="I need a flight", max_turns=3, ) - user.simulator.side_effect = ["Monday works", "Yes, book it"] + user.simulator.side_effect = ["Monday works", "Yes, book it"] # type: ignore[union-attr] # mock # Two agent-user exchanges user.simulate_response("When do you want to travel?") @@ -481,7 +481,7 @@ def test_gather_traces_includes_all_messages(self, dummy_model): initial_query="Hello", max_turns=2, ) - user.simulator.return_value = "Got it" + user.simulator.return_value = "Got it" # type: ignore[union-attr] # mock user.simulate_response("Agent response") diff --git a/tests/test_interface/test_agent_integration/test_langgraph_integration.py b/tests/test_interface/test_agent_integration/test_langgraph_integration.py index 5d5d646..1e1b3fa 100644 --- a/tests/test_interface/test_agent_integration/test_langgraph_integration.py +++ b/tests/test_interface/test_agent_integration/test_langgraph_integration.py @@ -74,7 +74,7 @@ def agent_node(state: State) -> State: ) return {"messages": messages + [response]} - graph = StateGraph(State) + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", agent_node) graph.set_entry_point("agent") graph.add_edge("agent", END) @@ -143,7 +143,7 @@ def agent_node(state: State) -> State: response = AIMessage(content="Response") return {"messages": messages + [response]} - graph = StateGraph(State) + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", agent_node) graph.set_entry_point("agent") graph.add_edge("agent", END) @@ -177,7 +177,7 @@ class State(TypedDict): def failing_node(state: State) -> State: raise ValueError("Intentional test error") - graph = StateGraph(State) + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", failing_node) graph.set_entry_point("agent") graph.add_edge("agent", END) @@ -222,7 +222,7 @@ def agent_node(state: State) -> State: response = AIMessage(content="Test response") return {"messages": messages + [response]} - graph = StateGraph(State) + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", agent_node) graph.set_entry_point("agent") graph.add_edge("agent", END) From 854e4eb272418ddfc4623ed5a02d20965a9ddb80 Mon Sep 17 00:00:00 2001 From: cemde Date: Sun, 21 Dec 2025 12:29:30 +0000 Subject: [PATCH 15/25] fixed examples --- .../five_a_day_benchmark.ipynb | 14 ++-------- .../five_a_day_benchmark.py | 3 +- examples/introduction/tutorial.ipynb | 28 ++++--------------- examples/macs_benchmark/macs_benchmark.py | 3 +- 4 files changed, 10 insertions(+), 38 deletions(-) diff --git a/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb b/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb index a67e8f2..96e3451 100644 --- a/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb +++ b/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb @@ -744,17 +744,7 @@ "id": "3764c0be", "metadata": {}, "outputs": [], - "source": [ - "# Create and run benchmark (will take approx. 2 min)\n", - "benchmark = FiveADayBenchmark(\n", - " agent_data=agent_configs,\n", - " fail_on_setup_error=True,\n", - " fail_on_task_error=True,\n", - " fail_on_evaluation_error=True,\n", - ")\n", - "\n", - "results = benchmark.run(tasks=tasks)" - ] + "source": "# Create and run benchmark (will take approx. 2 min)\nbenchmark = FiveADayBenchmark(\n fail_on_setup_error=True,\n fail_on_task_error=True,\n fail_on_evaluation_error=True,\n)\n\nresults = benchmark.run(tasks=tasks, agent_data=agent_configs)" }, { "cell_type": "markdown", @@ -898,4 +888,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/five_a_day_benchmark/five_a_day_benchmark.py b/examples/five_a_day_benchmark/five_a_day_benchmark.py index bc0a986..c4fa440 100644 --- a/examples/five_a_day_benchmark/five_a_day_benchmark.py +++ b/examples/five_a_day_benchmark/five_a_day_benchmark.py @@ -934,13 +934,12 @@ def load_benchmark_data( ) benchmark = FiveADayBenchmark( - agent_data=agent_configs, callbacks=[logger], fail_on_setup_error=True, fail_on_task_error=True, fail_on_evaluation_error=True, ) - results = benchmark.run(tasks=tasks) + results = benchmark.run(tasks=tasks, agent_data=agent_configs) print("\n--- Benchmark Complete ---") print(f"Total tasks: {len(tasks)}") diff --git a/examples/introduction/tutorial.ipynb b/examples/introduction/tutorial.ipynb index 6d51244..1fa9561 100644 --- a/examples/introduction/tutorial.ipynb +++ b/examples/introduction/tutorial.ipynb @@ -9,13 +9,13 @@ "\n", "[![Open Notebook on GitHub](https://img.shields.io/badge/Open%20Notebook%20on-GitHub-blue?logo=github)](https://github.com/parameterlab/MASEval/blob/main/examples/introduction/tutorial.ipynb)\n", "\n", - "This notebook is available as a Jupyter notebook \u2014 clone the repo and run it yourself!\n", + "This notebook is available as a Jupyter notebook — clone the repo and run it yourself!\n", "\n", "## What You'll Learn\n", "\n", - "- **Build your first agent** \u2014 Create tools and agents with smolagents\n", - "- **Run a minimal benchmark** \u2014 One task, one agent, end-to-end\n", - "- **Understand the core abstractions** \u2014 Tasks, Environments, Evaluators working together\n", + "- **Build your first agent** — Create tools and agents with smolagents\n", + "- **Run a minimal benchmark** — One task, one agent, end-to-end\n", + "- **Understand the core abstractions** — Tasks, Environments, Evaluators working together\n", "\n", "\n", "This tutorial first introduces [`smolagents`](https://huggingface.co/docs/smolagents/en/index) as introduction to agents. Then it provides a super small single task benchmark." @@ -633,23 +633,7 @@ "id": "b3ee60a7", "metadata": {}, "outputs": [], - "source": [ - "# Create benchmark instance with agent configuration\n", - "agent_data = {\"model_id\": \"gemini/gemini-2.5-flash\", \"temperature\": 0.7}\n", - "\n", - "benchmark = SimpleBenchmark(agent_data=agent_data, progress_bar=False)\n", - "\n", - "# Create task queue\n", - "tasks = TaskQueue([task])\n", - "\n", - "# Run the benchmark\n", - "print(\"Running benchmark...\\n\")\n", - "reports = benchmark.run(tasks=tasks)\n", - "\n", - "print(\"\\n\" + \"=\" * 60)\n", - "print(\"BENCHMARK COMPLETE\")\n", - "print(\"=\" * 60)" - ] + "source": "# Create benchmark instance\nagent_data = {\"model_id\": \"gemini/gemini-2.5-flash\", \"temperature\": 0.7}\n\nbenchmark = SimpleBenchmark(progress_bar=False)\n\n# Create task queue\ntasks = TaskQueue([task])\n\n# Run the benchmark\nprint(\"Running benchmark...\\n\")\nreports = benchmark.run(tasks=tasks, agent_data=agent_data)\n\nprint(\"\\n\" + \"=\" * 60)\nprint(\"BENCHMARK COMPLETE\")\nprint(\"=\" * 60)" }, { "cell_type": "markdown", @@ -715,7 +699,7 @@ "\n", "## Next Steps\n", "\n", - "1. **Try the Five-A-Day Benchmark notebook** \u2014 A production-ready example with multi-agent systems and diverse evaluators\n", + "1. **Try the Five-A-Day Benchmark notebook** — A production-ready example with multi-agent systems and diverse evaluators\n", "2. Create your own custom evaluators for your specific use case\n", "3. Experiment with different agent frameworks (LangGraph, LlamaIndex)\n", "4. Add callbacks for logging and tracing\n", diff --git a/examples/macs_benchmark/macs_benchmark.py b/examples/macs_benchmark/macs_benchmark.py index d06ce50..7ea7903 100644 --- a/examples/macs_benchmark/macs_benchmark.py +++ b/examples/macs_benchmark/macs_benchmark.py @@ -737,7 +737,6 @@ def run_benchmark( # Get benchmark class and instantiate BenchmarkClass = get_benchmark_class(framework) benchmark = BenchmarkClass( - agent_data=agent_config, callbacks=[logger], n_task_repeats=n_task_repeats, fail_on_setup_error=True, @@ -747,7 +746,7 @@ def run_benchmark( # Run benchmark print(f"\nRunning {framework} benchmark on {domain} domain...") - results = benchmark.run(tasks=tasks) + results = benchmark.run(tasks=tasks, agent_data=agent_config) # Compute summary metrics summary = compute_benchmark_metrics(results) From 223e19cf0b70dce052f71851f7b27179330fa27f Mon Sep 17 00:00:00 2001 From: cemde Date: Sun, 21 Dec 2025 12:30:54 +0000 Subject: [PATCH 16/25] fixed formatting --- tests/test_core/test_llm_simulator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_core/test_llm_simulator.py b/tests/test_core/test_llm_simulator.py index e775ecf..0454e9c 100644 --- a/tests/test_core/test_llm_simulator.py +++ b/tests/test_core/test_llm_simulator.py @@ -3,7 +3,6 @@ These tests verify that LLMSimulator retry logic and tracing work correctly. """ - import pytest from maseval.core.simulator import ( ToolLLMSimulator, From bffa951508c5122d47e774ac24fef2ae5b020d62 Mon Sep 17 00:00:00 2001 From: cemde Date: Sun, 21 Dec 2025 12:52:10 +0000 Subject: [PATCH 17/25] [skip ci] improved docstring --- maseval/core/task.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/maseval/core/task.py b/maseval/core/task.py index a9c0aaa..8a0d8b2 100644 --- a/maseval/core/task.py +++ b/maseval/core/task.py @@ -56,7 +56,11 @@ class Task: environment_data: A dictionary of data needed to set up the environment for the task. evaluation_data: A dictionary of data needed to evaluate the agent's performance on the task. metadata: A dictionary for any additional metadata about the task. - protocol: Execution protocol controlling timeout, retries, priority, etc. + protocol: Execution protocol controlling timeout, retries, priority, and other runtime + parameters. It provides fine-grained control over how MASEval runs the task. The + protocol serves purely as a communication channel between the task instance and + MASEval's execution engine; it does not impose any intrinsic semantics on the task + content itself. """ query: str From 742961a98d6422638e433b7f9de78610f9e73648 Mon Sep 17 00:00:00 2001 From: cemde Date: Tue, 30 Dec 2025 22:40:28 +0100 Subject: [PATCH 18/25] fixed bugs from merging --- maseval/benchmark/tau2/data_loader.py | 14 ++++---- .../test_tau2/test_default_agent.py | 33 +++++++++---------- .../test_tau2/test_integration.py | 6 ++-- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/maseval/benchmark/tau2/data_loader.py b/maseval/benchmark/tau2/data_loader.py index 6e1b3fd..0d45bb0 100644 --- a/maseval/benchmark/tau2/data_loader.py +++ b/maseval/benchmark/tau2/data_loader.py @@ -18,7 +18,7 @@ from urllib.error import HTTPError, URLError from urllib.request import urlopen -from maseval import Task, TaskCollection +from maseval import Task, TaskQueue # ============================================================================= @@ -257,7 +257,7 @@ def load_tasks( split: str = "base", data_dir: Optional[Path] = None, limit: Optional[int] = None, -) -> TaskCollection: +) -> TaskQueue: """Load tasks for a tau2 domain. Args: @@ -267,7 +267,7 @@ def load_tasks( limit: Maximum number of tasks to load Returns: - TaskCollection containing Task objects with: + TaskQueue containing Task objects with: - id: Task identifier from tau2 data - query: Initial user message (from user_scenario) - environment_data: Domain tools, database state, policies @@ -316,7 +316,7 @@ def load_tasks( task = _convert_tau2_task_to_maseval(raw_task, domain, split, domain_config) tasks.append(task) - return TaskCollection(tasks) + return TaskQueue(tasks) def _convert_tau2_task_to_maseval( @@ -440,18 +440,18 @@ def load_domain_config( def configure_model_ids( - tasks: Union[TaskCollection, List[Task]], + tasks: Union[TaskQueue, List[Task]], *, user_model_id: Optional[str] = None, evaluator_model_id: Optional[str] = None, -) -> Union[TaskCollection, List[Task]]: +) -> Union[TaskQueue, List[Task]]: """Configure model IDs for benchmark components in task data. Tau2 tools execute real business logic and don't need a tool_model_id. Only user simulation and evaluation use LLMs. Args: - tasks: TaskCollection or list of Tasks to configure + tasks: TaskQueue or list of Tasks to configure user_model_id: Model ID for user simulator (stored in user_data) evaluator_model_id: Model ID for evaluators (stored in evaluation_data) diff --git a/tests/test_benchmarks/test_tau2/test_default_agent.py b/tests/test_benchmarks/test_tau2/test_default_agent.py index a8f8921..11818d8 100644 --- a/tests/test_benchmarks/test_tau2/test_default_agent.py +++ b/tests/test_benchmarks/test_tau2/test_default_agent.py @@ -501,16 +501,13 @@ class TestDefaultAgentTau2BenchmarkInit: def test_init_basic(self): """Test basic initialization.""" - benchmark = DummyDefaultAgentBenchmark( - agent_data={"model_id": "gpt-4o"}, - ) + benchmark = DummyDefaultAgentBenchmark() assert benchmark._model_cache == {} def test_init_with_all_options(self): """Test initialization with all options.""" benchmark = DummyDefaultAgentBenchmark( - agent_data={"model_id": "gpt-4o", "llm_args": {"temperature": 0.5}}, n_task_repeats=3, max_invocations=5, ) @@ -525,7 +522,7 @@ class TestDefaultAgentTau2BenchmarkSetupAgents: def test_setup_agents_basic(self, sample_task): """Test basic agent setup.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() with patch.object(Tau2Environment, "__init__", return_value=None): mock_env = MagicMock(spec=Tau2Environment) @@ -545,7 +542,7 @@ def test_setup_agents_basic(self, sample_task): def test_setup_agents_missing_model_id(self, sample_task): """Test that missing model_id raises ValueError.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={}) + benchmark = DummyDefaultAgentBenchmark() mock_env = MagicMock(spec=Tau2Environment) mock_env.create_tools.return_value = {} @@ -556,7 +553,7 @@ def test_setup_agents_missing_model_id(self, sample_task): def test_setup_agents_with_llm_args(self, sample_task): """Test agent setup with custom llm_args.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o", "llm_args": {"temperature": 0.5}}) + benchmark = DummyDefaultAgentBenchmark() mock_env = MagicMock(spec=Tau2Environment) mock_env.create_tools.return_value = {} @@ -574,7 +571,7 @@ def test_setup_agents_with_llm_args(self, sample_task): def test_setup_agents_with_max_tool_calls(self, sample_task): """Test agent setup with custom max_tool_calls.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o", "max_tool_calls": 10}) + benchmark = DummyDefaultAgentBenchmark() mock_env = MagicMock(spec=Tau2Environment) mock_env.create_tools.return_value = {} @@ -600,7 +597,7 @@ def test_get_model_adapter_is_abstract(self): # DefaultAgentTau2Benchmark itself is still abstract # because get_model_adapter is abstract with pytest.raises(TypeError, match="abstract"): - DefaultAgentTau2Benchmark(agent_data={"model_id": "gpt-4o"}) + DefaultAgentTau2Benchmark() # ============================================================================= @@ -723,7 +720,7 @@ class TestTau2BenchmarkMethods: def test_setup_environment(self, sample_task): """Test setup_environment creates Tau2Environment.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() with patch("maseval.benchmark.tau2.tau2.Tau2Environment") as mock_env_cls: mock_env_cls.return_value = MagicMock() @@ -736,7 +733,7 @@ def test_setup_environment(self, sample_task): def test_setup_user_with_dict_instructions(self): """Test setup_user with dict instructions.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() task = Task( query="Hello", @@ -767,7 +764,7 @@ def test_setup_user_with_dict_instructions(self): def test_setup_user_with_string_instructions(self): """Test setup_user with string instructions.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() task = Task( query="Hello", @@ -790,7 +787,7 @@ def test_setup_user_with_string_instructions(self): def test_setup_user_empty_instructions(self): """Test setup_user with no instructions.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() task = Task( query="Hello", @@ -810,7 +807,7 @@ def test_setup_user_empty_instructions(self): def test_setup_evaluators(self, sample_task): """Test setup_evaluators creates Tau2Evaluator.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() with patch("maseval.benchmark.tau2.tau2.Tau2Environment") as mock_env_cls: mock_env = MagicMock(spec=Tau2Environment) @@ -823,7 +820,7 @@ def test_setup_evaluators(self, sample_task): def test_run_agents_single_agent(self, sample_task): """Test run_agents with a single agent.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() mock_agent = MagicMock(spec=AgentAdapter) mock_agent.run.return_value = "Response" @@ -837,7 +834,7 @@ def test_run_agents_single_agent(self, sample_task): def test_run_agents_multiple_agents(self, sample_task): """Test run_agents with multiple agents.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() mock_agent1 = MagicMock(spec=AgentAdapter) mock_agent1.run.return_value = "Response 1" @@ -852,7 +849,7 @@ def test_run_agents_multiple_agents(self, sample_task): def test_evaluate(self, sample_task): """Test evaluate method.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() mock_evaluator = MagicMock() mock_evaluator.filter_traces.return_value = {"filtered": True} @@ -955,7 +952,7 @@ class TestDummyBenchmarkModelAdapter: def test_get_model_adapter_returns_mock(self): """Test that DummyDefaultAgentBenchmark returns working mock adapter.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() adapter = benchmark.get_model_adapter("gpt-4o") diff --git a/tests/test_benchmarks/test_tau2/test_integration.py b/tests/test_benchmarks/test_tau2/test_integration.py index e37996e..36b2340 100644 --- a/tests/test_benchmarks/test_tau2/test_integration.py +++ b/tests/test_benchmarks/test_tau2/test_integration.py @@ -59,9 +59,11 @@ def test_tau2_dry_run(): task.user_data = {"model_id": "mock-user", "instructions": "Test scenario"} task.evaluation_data = {"reward_basis": ["DB"], "actions": []} task.query = "Help me." + task.protocol = MagicMock() + task.protocol.timeout_seconds = None # Setup benchmark - benchmark = IntegrationTau2Benchmark(agent_data={}, n_task_repeats=1) + benchmark = IntegrationTau2Benchmark(n_task_repeats=1) # Mock Environment env_mock = MagicMock() @@ -88,7 +90,7 @@ def test_tau2_dry_run(): # Patch setup_evaluators to return our mock benchmark.setup_evaluators = MagicMock(return_value=[mock_evaluator]) # type: ignore[assignment] - results = benchmark.run([task]) + results = benchmark.run([task], agent_data={}) # Debug info if failed if results[0]["status"] != "success": From b9d09a24ca8106009b67bd11eefeac455b2c9852 Mon Sep 17 00:00:00 2001 From: cemde Date: Tue, 30 Dec 2025 22:46:08 +0100 Subject: [PATCH 19/25] changed naming of worker argument --- CHANGELOG.md | 2 +- maseval/core/benchmark.py | 28 +++++++++---------- .../test_benchmark/test_parallel_execution.py | 14 +++++----- 3 files changed, 22 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6babcd4..56aa73b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Parallel Execution** -- Added parallel task execution with `max_workers` parameter in `Benchmark.run()` using `ThreadPoolExecutor` (PR: #14) +- Added parallel task execution with `num_workers` parameter in `Benchmark.run()` using `ThreadPoolExecutor` (PR: #14) - Added `ComponentRegistry` class for thread-safe component registration with thread-local storage (PR: #14) - Added `TaskContext` for cooperative timeout checking with `check_timeout()`, `elapsed`, `remaining`, and `is_expired` properties (PR: #14) - Added `TaskProtocol` dataclass with `timeout_seconds`, `timeout_action`, `max_retries`, `priority`, and `tags` fields for task-level execution control (PR: #14) diff --git a/maseval/core/benchmark.py b/maseval/core/benchmark.py index 8ef2063..c75e160 100644 --- a/maseval/core/benchmark.py +++ b/maseval/core/benchmark.py @@ -109,7 +109,7 @@ def run_agents(self, agents, task, environment, query): retry_reports = benchmark.run(tasks=failed_tasks, agent_data=config) # Parallel execution for I/O-bound workloads - benchmark = MyBenchmark(max_workers=4) + benchmark = MyBenchmark(num_workers=4) reports = benchmark.run(tasks=my_tasks, agent_data=config) # Or use strict mode for debugging (fail fast) @@ -131,7 +131,7 @@ def __init__( callbacks: Optional[List[BenchmarkCallback]] = None, n_task_repeats: int = 1, max_invocations: int = 1, - max_workers: int = 1, + num_workers: int = 1, fail_on_setup_error: bool = False, fail_on_task_error: bool = False, fail_on_evaluation_error: bool = False, @@ -148,7 +148,7 @@ def __init__( For simple benchmarks, the default (1) means agents run once per task. For interactive benchmarks with user feedback loops, set higher (e.g., 5 for MACS) to allow multiple agent-user interaction rounds. - max_workers: Maximum number of parallel task executions. Default 1 (sequential). + num_workers: Number of parallel task executions. Default 1 (sequential). Set higher for I/O-bound workloads (e.g., LLM API calls). This controls the ThreadPoolExecutor worker count for concurrent task processing. fail_on_setup_error: If True, raise exceptions when setup fails (environment, agents, evaluators). @@ -180,7 +180,7 @@ def __init__( benchmark = MyBenchmark() # Parallel execution for faster I/O-bound workloads - benchmark = MyBenchmark(max_workers=4) + benchmark = MyBenchmark(num_workers=4) # Strict mode - fail fast on any error (useful for debugging) benchmark = MyBenchmark( @@ -225,7 +225,7 @@ def __init__( # Execution configuration self.max_invocations = max_invocations - self.max_workers = max_workers + self.num_workers = num_workers # Failure handling configuration self.fail_on_task_error = fail_on_task_error @@ -1193,16 +1193,16 @@ def _run_parallel( self, queue: BaseTaskQueue, agent_data_lookup: Dict[str, Dict[str, Any]], - max_workers: int, + num_workers: int, ) -> None: """Execute tasks in parallel with thread pool. Args: queue: Task queue providing task ordering. agent_data_lookup: Mapping from task_id to agent_data configuration. - max_workers: Maximum number of concurrent workers. + num_workers: Number of concurrent workers. """ - with ThreadPoolExecutor(max_workers=max_workers) as executor: + with ThreadPoolExecutor(max_workers=num_workers) as executor: futures: Dict[Any, Tuple[Task, int]] = {} task_repeat_counts: Dict[str, int] = {} # Track submitted repeats per task @@ -1233,7 +1233,7 @@ def submit_task_repeats(task: Task) -> None: # Submit initial batch try: - while len(futures) < max_workers * 2: + while len(futures) < num_workers * 2: task = next(queue_iter) submit_task_repeats(task) submitted_tasks.append(task) @@ -1291,7 +1291,7 @@ def submit_task_repeats(task: Task) -> None: self._invoke_callbacks("on_task_end", self, task, last_report) # Submit more work if queue not exhausted - if not queue_exhausted and len(futures) < max_workers: + if not queue_exhausted and len(futures) < num_workers: try: task = next(queue_iter) submit_task_repeats(task) @@ -1396,7 +1396,7 @@ def run( print(f"Traces: {report['traces']}") # Parallel execution with 4 workers - benchmark = MyBenchmark(max_workers=4) + benchmark = MyBenchmark(num_workers=4) reports = benchmark.run(tasks=tasks, agent_data=config) # Single agent config for all tasks @@ -1454,11 +1454,11 @@ def run( # Callbacks at the start of the run self._invoke_callbacks("on_run_start", self) - # Execute based on max_workers - if self.max_workers == 1: + # Execute based on num_workers + if self.num_workers == 1: self._run_sequential(queue, agent_data_lookup) else: - self._run_parallel(queue, agent_data_lookup, self.max_workers) + self._run_parallel(queue, agent_data_lookup, self.num_workers) # Callbacks at the end of the run self._invoke_callbacks("on_run_end", self, self.reports) diff --git a/tests/test_core/test_benchmark/test_parallel_execution.py b/tests/test_core/test_benchmark/test_parallel_execution.py index 416d1c5..a9bfeae 100644 --- a/tests/test_core/test_benchmark/test_parallel_execution.py +++ b/tests/test_core/test_benchmark/test_parallel_execution.py @@ -1,6 +1,6 @@ """Tests for parallel task execution in Benchmark. -These tests verify that parallel execution with max_workers > 1 works correctly, +These tests verify that parallel execution with num_workers > 1 works correctly, including thread safety, report collection, and callback serialization. """ @@ -123,7 +123,7 @@ def test_parallel_produces_same_report_count(self, parallel_tasks): def test_parallel_reports_have_correct_structure(self, parallel_tasks): """Verify parallel reports have expected fields.""" - benchmark = DummyBenchmark(max_workers=3) + benchmark = DummyBenchmark(num_workers=3) reports = benchmark.run(parallel_tasks, agent_data={"model": "test"}) @@ -136,7 +136,7 @@ def test_parallel_reports_have_correct_structure(self, parallel_tasks): assert "eval" in report def test_single_worker_uses_sequential(self, parallel_tasks): - """max_workers=1 should behave identically to sequential.""" + """num_workers=1 should behave identically to sequential.""" callback = OrderTrackingCallback() benchmark = DummyBenchmark( callbacks=[callback], @@ -191,7 +191,7 @@ def test_reports_all_collected(self, parallel_tasks): def test_traces_not_cross_contaminated(self, parallel_tasks): """Traces from one task should not appear in another's report.""" - benchmark = DummyBenchmark(max_workers=4) + benchmark = DummyBenchmark(num_workers=4) reports = benchmark.run(parallel_tasks, agent_data={"model": "test"}) @@ -273,7 +273,7 @@ def test_parallel_faster_than_sequential(self): time_seq = time.time() - start_seq # Parallel timing - benchmark_par = SlowBenchmark(delay_seconds=delay, max_workers=4) + benchmark_par = SlowBenchmark(delay_seconds=delay, num_workers=4) start_par = time.time() benchmark_par.run(tasks, agent_data={"model": "test"}) time_par = time.time() - start_par @@ -287,7 +287,7 @@ def test_execution_overlaps(self): benchmark = SlowBenchmark( delay_seconds=0.05, - max_workers=3, + num_workers=3, ) benchmark.run(tasks, agent_data={"model": "test"}) @@ -404,7 +404,7 @@ def on_task_repeat_start(self, benchmark, task, repeat_idx): callbacks=[OrderTracker()], ) - # With max_workers=1, order should be strictly by priority + # With num_workers=1, order should be strictly by priority benchmark.run(queue, agent_data={"model": "test"}) assert execution_order == ["P5", "P4", "P3", "P2", "P1"] From 44f3f8baa69c07f2322e109de74c6159c13211cb Mon Sep 17 00:00:00 2001 From: cemde Date: Tue, 30 Dec 2025 23:11:13 +0100 Subject: [PATCH 20/25] replaced ascii art with proper diagram --- docs/guides/exception-handling.md | 37 +++++++++++++------------------ mkdocs.yml | 6 ++++- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/docs/guides/exception-handling.md b/docs/guides/exception-handling.md index 819db1e..3fbbdac 100644 --- a/docs/guides/exception-handling.md +++ b/docs/guides/exception-handling.md @@ -90,27 +90,22 @@ class SimulatedUser: One approach to exception handling places the boundary between agent responsibility and infrastructure responsibility at input validation: -``` -┌─────────────────────────────────────────────────────────────┐ -│ TOOL EXECUTION │ -├─────────────────────────────────────────────────────────────┤ -│ │ -│ ┌─────────────────┐ │ -│ │ INPUT │ Agent passes arguments │ -│ │ VALIDATION │ │ -│ │ │ ❌ Fails → AgentError │ -│ │ │ ✓ Passes ↓ │ -│ └─────────────────┘ │ -│ │ │ -│ ▼ │ -│ ┌─────────────────┐ │ -│ │ EXECUTION │ Tool runs its logic │ -│ │ │ │ -│ │ │ ❌ Fails → EnvironmentError │ -│ │ │ ✓ Passes → Result │ -│ └─────────────────┘ │ -│ │ -└─────────────────────────────────────────────────────────────┘ +```mermaid +flowchart TD + subgraph TOOL_EXECUTION[" "] + A[Agent passes arguments] --> B{INPUT VALIDATION} + B -->|Fails| C[AgentError] + B -->|Passes| D{EXECUTION} + D -->|Fails| E[EnvironmentError] + D -->|Passes| F[Result] + end + + style TOOL_EXECUTION fill:none,stroke:#888 + style B fill:#f5f5f5,stroke:#333 + style D fill:#f5f5f5,stroke:#333 + style C fill:#ffebee,stroke:#c62828 + style E fill:#ffebee,stroke:#c62828 + style F fill:#e8f5e9,stroke:#2e7d32 ``` With this pattern: diff --git a/mkdocs.yml b/mkdocs.yml index a5b09b2..8617174 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -42,7 +42,11 @@ extra_css: markdown_extensions: - admonition - pymdownx.details - - pymdownx.superfences + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format - attr_list - pymdownx.emoji: emoji_index: !!python/name:material.extensions.emoji.twemoji From b1ea2487eea82e38aecde49b8cb0a79a7b16de77 Mon Sep 17 00:00:00 2001 From: cemde Date: Wed, 31 Dec 2025 14:21:46 +0100 Subject: [PATCH 21/25] small fixes --- maseval/benchmark/tau2/data_loader.py | 30 ++++++++++++++++++++++++++- maseval/benchmark/tau2/tau2.py | 28 ++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/maseval/benchmark/tau2/data_loader.py b/maseval/benchmark/tau2/data_loader.py index 0d45bb0..f2e93dc 100644 --- a/maseval/benchmark/tau2/data_loader.py +++ b/maseval/benchmark/tau2/data_loader.py @@ -19,6 +19,7 @@ from urllib.request import urlopen from maseval import Task, TaskQueue +from maseval.core.task import TaskProtocol # ============================================================================= @@ -36,6 +37,10 @@ "telecom": 114, } +# Default execution protocol settings +DEFAULT_TIMEOUT_SECONDS = 600.0 # 10 minutes per task +DEFAULT_MAX_RETRIES = 1 # Skip on first failure + # GitHub raw content URLs for v0.2.0 tag GITHUB_BASE = "https://raw.githubusercontent.com/sierra-research/tau2-bench" DEFAULT_VERSION = "v0.2.0" @@ -257,6 +262,8 @@ def load_tasks( split: str = "base", data_dir: Optional[Path] = None, limit: Optional[int] = None, + timeout_seconds: Optional[float] = DEFAULT_TIMEOUT_SECONDS, + max_retries: int = DEFAULT_MAX_RETRIES, ) -> TaskQueue: """Load tasks for a tau2 domain. @@ -265,6 +272,9 @@ def load_tasks( split: One of "base", "hard", "all" (base recommended for reproducibility) data_dir: Base data directory (default: module's data/) limit: Maximum number of tasks to load + timeout_seconds: Maximum execution time per task in seconds. Default 600 (10 minutes). + Set to None to disable timeout. + max_retries: Maximum retry attempts for transient failures. Default 1 (skip on failure). Returns: TaskQueue containing Task objects with: @@ -274,6 +284,7 @@ def load_tasks( - evaluation_data: Assertions, expected outcomes - user_data: User profile, instructions - metadata: domain, split, description + - protocol: Execution settings (timeout, retries, tags) Raises: ValueError: If domain or split is invalid @@ -283,6 +294,9 @@ def load_tasks( >>> tasks = load_tasks("retail", split="base", limit=5) >>> len(tasks) 5 + + >>> # Custom timeout and retries + >>> tasks = load_tasks("retail", timeout_seconds=300, max_retries=2) """ if domain not in VALID_DOMAINS: raise ValueError(f"Invalid domain '{domain}'. Must be one of {VALID_DOMAINS}") @@ -313,7 +327,9 @@ def load_tasks( # Convert to MASEval Task objects tasks = [] for raw_task in raw_tasks: - task = _convert_tau2_task_to_maseval(raw_task, domain, split, domain_config) + task = _convert_tau2_task_to_maseval( + raw_task, domain, split, domain_config, timeout_seconds, max_retries + ) tasks.append(task) return TaskQueue(tasks) @@ -324,6 +340,8 @@ def _convert_tau2_task_to_maseval( domain: str, split: str, domain_config: Dict[str, Any], + timeout_seconds: Optional[float], + max_retries: int, ) -> Task: """Convert a tau2-bench task dict to MASEval Task. @@ -332,6 +350,8 @@ def _convert_tau2_task_to_maseval( domain: Domain name split: Split name domain_config: Domain configuration with policy and db_path + timeout_seconds: Maximum execution time per task in seconds + max_retries: Maximum retry attempts for transient failures Returns: MASEval Task object @@ -380,6 +400,13 @@ def _convert_tau2_task_to_maseval( "ticket": raw_task.get("ticket"), # For solo mode (not used) } + # Build execution protocol with timeout, retries, and tags + protocol = TaskProtocol( + timeout_seconds=timeout_seconds, + max_retries=max_retries, + tags={"domain": domain, "split": split}, + ) + # Build task kwargs, only include id if provided in raw task task_kwargs: Dict[str, Any] = { "query": query, @@ -387,6 +414,7 @@ def _convert_tau2_task_to_maseval( "evaluation_data": evaluation_data, "user_data": user_data, "metadata": metadata, + "protocol": protocol, } if raw_task.get("id"): task_kwargs["id"] = str(raw_task["id"]) diff --git a/maseval/benchmark/tau2/tau2.py b/maseval/benchmark/tau2/tau2.py index 5b02522..b3126b8 100644 --- a/maseval/benchmark/tau2/tau2.py +++ b/maseval/benchmark/tau2/tau2.py @@ -237,7 +237,33 @@ def __init__(self, *args: Any, max_invocations: int = 50, **kwargs: Any): Args: max_invocations: Maximum agent-user interaction rounds (default: 50). tau2-bench uses max_steps=200, where 1 turn ≈ 4 steps. - *args, **kwargs: Passed to parent Benchmark class + + Inherited from Benchmark (pass via kwargs): + num_workers: Number of parallel task executions. Default 1 (sequential). + Set higher for I/O-bound workloads (e.g., LLM API calls). + n_task_repeats: Number of times to repeat each task. Default 1. + Useful for measuring variance or computing pass@k metrics. + callbacks: List of callback handlers for monitoring execution. + progress_bar: Progress display. True (default) for tqdm, "rich" for Rich, + or False to disable. + fail_on_task_error: If True, raise on task execution errors. Default False. + fail_on_evaluation_error: If True, raise on evaluation errors. Default False. + fail_on_setup_error: If True, raise on setup errors. Default False. + + Example: + ```python + # Parallel execution for faster evaluation + benchmark = MyTau2Benchmark(num_workers=4) + + # Multiple repeats for pass@k metrics + benchmark = MyTau2Benchmark(n_task_repeats=4) + + # Debug mode - fail fast on errors + benchmark = MyTau2Benchmark( + fail_on_task_error=True, + fail_on_evaluation_error=True, + ) + ``` """ super().__init__(*args, max_invocations=max_invocations, **kwargs) # type: ignore[parameter-already-assigned] From d49f1bbe49bcc9bb6544faecf0153d8fee5423ba Mon Sep 17 00:00:00 2001 From: cemde Date: Wed, 31 Dec 2025 14:32:48 +0100 Subject: [PATCH 22/25] fixed tau documentation --- maseval/benchmark/tau2/data_loader.py | 4 +- maseval/benchmark/tau2/tau2.py | 68 +++++++++---------- .../test_tau2/test_default_agent.py | 10 ++- 3 files changed, 44 insertions(+), 38 deletions(-) diff --git a/maseval/benchmark/tau2/data_loader.py b/maseval/benchmark/tau2/data_loader.py index f2e93dc..e79ec8a 100644 --- a/maseval/benchmark/tau2/data_loader.py +++ b/maseval/benchmark/tau2/data_loader.py @@ -327,9 +327,7 @@ def load_tasks( # Convert to MASEval Task objects tasks = [] for raw_task in raw_tasks: - task = _convert_tau2_task_to_maseval( - raw_task, domain, split, domain_config, timeout_seconds, max_retries - ) + task = _convert_tau2_task_to_maseval(raw_task, domain, split, domain_config, timeout_seconds, max_retries) tasks.append(task) return TaskQueue(tasks) diff --git a/maseval/benchmark/tau2/tau2.py b/maseval/benchmark/tau2/tau2.py index b3126b8..e917c60 100644 --- a/maseval/benchmark/tau2/tau2.py +++ b/maseval/benchmark/tau2/tau2.py @@ -62,6 +62,7 @@ def get_model_adapter(self, model_id, **kwargs): from maseval import AgentAdapter, Benchmark, Evaluator, ModelAdapter, Task, User from maseval.core.user import AgenticUser +from maseval.core.callback import BenchmarkCallback from maseval.benchmark.tau2.environment import Tau2Environment from maseval.benchmark.tau2.evaluator import Tau2Evaluator @@ -231,41 +232,44 @@ def get_model_adapter(self, model_id, **kwargs): benchmark.run(tasks) """ - def __init__(self, *args: Any, max_invocations: int = 50, **kwargs: Any): + # Maximum agent-user interaction rounds (tau2-bench uses max_steps=200, where 1 turn ≈ 4 steps) + MAX_INVOCATIONS = 50 + + def __init__( + self, + callbacks: Optional[List[BenchmarkCallback]] = None, + n_task_repeats: int = 1, + max_invocations: int = MAX_INVOCATIONS, + num_workers: int = 1, + fail_on_setup_error: bool = False, + fail_on_task_error: bool = False, + fail_on_evaluation_error: bool = False, + progress_bar: bool | str = True, + ): """Initialize benchmark with tau2-specific defaults. Args: + callbacks: Optional list of callback handlers for monitoring execution. + n_task_repeats: Number of times to repeat each task. Default 1. max_invocations: Maximum agent-user interaction rounds (default: 50). tau2-bench uses max_steps=200, where 1 turn ≈ 4 steps. - - Inherited from Benchmark (pass via kwargs): num_workers: Number of parallel task executions. Default 1 (sequential). - Set higher for I/O-bound workloads (e.g., LLM API calls). - n_task_repeats: Number of times to repeat each task. Default 1. - Useful for measuring variance or computing pass@k metrics. - callbacks: List of callback handlers for monitoring execution. - progress_bar: Progress display. True (default) for tqdm, "rich" for Rich, - or False to disable. + fail_on_setup_error: If True, raise on setup errors. Default False. fail_on_task_error: If True, raise on task execution errors. Default False. fail_on_evaluation_error: If True, raise on evaluation errors. Default False. - fail_on_setup_error: If True, raise on setup errors. Default False. - - Example: - ```python - # Parallel execution for faster evaluation - benchmark = MyTau2Benchmark(num_workers=4) - - # Multiple repeats for pass@k metrics - benchmark = MyTau2Benchmark(n_task_repeats=4) - - # Debug mode - fail fast on errors - benchmark = MyTau2Benchmark( - fail_on_task_error=True, - fail_on_evaluation_error=True, - ) - ``` + progress_bar: Progress display. True (default) for tqdm, "rich" for Rich, + or False to disable. """ - super().__init__(*args, max_invocations=max_invocations, **kwargs) # type: ignore[parameter-already-assigned] + super().__init__( + callbacks=callbacks, + n_task_repeats=n_task_repeats, + max_invocations=max_invocations, + num_workers=num_workers, + fail_on_setup_error=fail_on_setup_error, + fail_on_task_error=fail_on_task_error, + fail_on_evaluation_error=fail_on_evaluation_error, + progress_bar=progress_bar, + ) def _get_user_model_id(self, task: Task) -> str: """Get user simulator model ID from task.user_data. @@ -875,14 +879,6 @@ class DefaultAgentTau2Benchmark(Tau2Benchmark): results = benchmark.run(tasks) """ - # Cache for model adapters - _model_cache: Dict[str, ModelAdapter] - - def __init__(self, *args: Any, **kwargs: Any): - """Initialize the default agent benchmark. See Tau2Benchmark for args.""" - super().__init__(*args, **kwargs) - self._model_cache = {} - def _get_agent_model_id(self, agent_data: Dict[str, Any]) -> str: """Get agent model ID from agent_data. @@ -965,5 +961,9 @@ def get_model_adapter(self, model_id: str, **kwargs: Any) -> ModelAdapter: Returns: ModelAdapter instance + + Note: + DefaultAgentTau2Benchmark uses lazy initialization for model caching. + Access via `getattr(self, '_model_cache', {})` in subclass implementations. """ pass diff --git a/tests/test_benchmarks/test_tau2/test_default_agent.py b/tests/test_benchmarks/test_tau2/test_default_agent.py index 11818d8..1d8551a 100644 --- a/tests/test_benchmarks/test_tau2/test_default_agent.py +++ b/tests/test_benchmarks/test_tau2/test_default_agent.py @@ -503,7 +503,8 @@ def test_init_basic(self): """Test basic initialization.""" benchmark = DummyDefaultAgentBenchmark() - assert benchmark._model_cache == {} + # Benchmark should be initialized successfully + assert benchmark.max_invocations == 50 # Tau2 default def test_init_with_all_options(self): """Test initialization with all options.""" @@ -515,6 +516,13 @@ def test_init_with_all_options(self): assert benchmark.n_task_repeats == 3 assert benchmark.max_invocations == 5 + def test_default_max_invocations(self): + """Test that default max_invocations is 50 from class attribute.""" + benchmark = DummyDefaultAgentBenchmark() + + assert benchmark.max_invocations == 50 + assert benchmark.MAX_INVOCATIONS == 50 + @pytest.mark.benchmark class TestDefaultAgentTau2BenchmarkSetupAgents: From e522a9a1ad45dedd6df33ce10a5eeee71e9c3270 Mon Sep 17 00:00:00 2001 From: cemde Date: Wed, 31 Dec 2025 15:33:26 +0100 Subject: [PATCH 23/25] MACS now uses TaskExecutionStatus --- maseval/benchmark/macs/macs.py | 35 +++++----- maseval/benchmark/tau2/evaluator.py | 30 ++++---- .../test_macs/test_macs_benchmark.py | 70 +++++++++---------- .../test_tau2/test_evaluator.py | 23 +++--- 4 files changed, 77 insertions(+), 81 deletions(-) diff --git a/maseval/benchmark/macs/macs.py b/maseval/benchmark/macs/macs.py index 186ec3b..6e4ee04 100644 --- a/maseval/benchmark/macs/macs.py +++ b/maseval/benchmark/macs/macs.py @@ -53,6 +53,7 @@ def get_model_adapter(self, model_id, **kwargs): MessageHistory, ModelAdapter, Task, + TaskExecutionStatus, ToolInvocationHistory, ToolLLMSimulator, User, @@ -64,6 +65,17 @@ def get_model_adapter(self, model_id, **kwargs): from maseval.core.tracing import TraceableMixin +# Statuses where agent is accountable (included in scoring) +# Note: task_timeout is included - timeouts count as failures in MACS +SCOREABLE_STATUSES = frozenset( + { + TaskExecutionStatus.SUCCESS.value, + TaskExecutionStatus.AGENT_ERROR.value, + TaskExecutionStatus.TASK_TIMEOUT.value, + } +) + + # ============================================================================= # Tool # ============================================================================= @@ -987,15 +999,6 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: - excluded: Dict with counts of excluded tasks by category - status_counts: Dict with counts of each status type """ - # Status values that indicate infrastructure failures (not agent's fault) - INFRASTRUCTURE_STATUSES = { - "environment_error", - "user_error", - "unknown_execution_error", - "evaluation_failed", - "setup_failed", - } - if not results: return { "total_tasks": 0, @@ -1003,13 +1006,7 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: "successful_tasks": 0, "success_rate": 0.0, "mean_metrics": {}, - "excluded": { - "environment_error": 0, - "user_error": 0, - "unknown_execution_error": 0, - "evaluation_failed": 0, - "setup_failed": 0, - }, + "excluded": {}, "status_counts": {}, } @@ -1019,14 +1016,14 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: successful_tasks = 0 scored_tasks = 0 status_counts: Dict[str, int] = {} - excluded_counts: Dict[str, int] = {s: 0 for s in INFRASTRUCTURE_STATUSES} + excluded_counts: Dict[str, int] = {} for res in results: status = res.get("status", "unknown") status_counts[status] = status_counts.get(status, 0) + 1 - # Skip infrastructure failures from scoring - if status in INFRASTRUCTURE_STATUSES: + # Skip infrastructure failures from scoring (use module-level SCOREABLE_STATUSES) + if status not in SCOREABLE_STATUSES: excluded_counts[status] = excluded_counts.get(status, 0) + 1 continue diff --git a/maseval/benchmark/tau2/evaluator.py b/maseval/benchmark/tau2/evaluator.py index e5e6b07..3023be5 100644 --- a/maseval/benchmark/tau2/evaluator.py +++ b/maseval/benchmark/tau2/evaluator.py @@ -23,12 +23,23 @@ from enum import Enum from typing import Any, Dict, List, Optional -from maseval import Evaluator, Task +from maseval import Evaluator, Task, TaskExecutionStatus from maseval.benchmark.tau2.environment import Tau2Environment, get_environment_constructor from maseval.benchmark.tau2.utils import compare_tool_calls +# Statuses where agent is accountable (included in scoring) +# Note: task_timeout is included - timeouts count as failures in tau2 +SCOREABLE_STATUSES = frozenset( + { + TaskExecutionStatus.SUCCESS.value, + TaskExecutionStatus.AGENT_ERROR.value, + TaskExecutionStatus.TASK_TIMEOUT.value, + } +) + + class RewardType(str, Enum): """Types of rewards that can be computed. @@ -447,6 +458,7 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: """Compute summary metrics across all benchmark results. Infrastructure errors are excluded from scoring metrics. + Uses SCOREABLE_STATUSES to determine which results count toward agent score. Args: results: List of result dicts from benchmark.run() @@ -454,14 +466,6 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: Returns: Dict with success_rate, mean_reward, pass_at_k, status_counts """ - INFRASTRUCTURE_STATUSES = { - "environment_error", - "user_error", - "unknown_execution_error", - "evaluation_failed", - "setup_failed", - } - if not results: return { "total_tasks": 0, @@ -482,8 +486,8 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: status = res.get("status", "unknown") status_counts[status] = status_counts.get(status, 0) + 1 - if status in INFRASTRUCTURE_STATUSES: - continue + if status not in SCOREABLE_STATUSES: + continue # Skip infrastructure errors scored_tasks += 1 evals = res.get("eval") or [] @@ -529,7 +533,7 @@ def compute_pass_at_k( task_results: Dict[str, List[bool]] = {} for res in results: task_id = res.get("task_id", "") - if res.get("status") not in {"success", "agent_error"}: + if res.get("status") not in SCOREABLE_STATUSES: continue # Skip infrastructure errors evals = res.get("eval") or [] @@ -615,7 +619,7 @@ def compute_pass_hat_k( task_results: Dict[str, List[bool]] = {} for res in results: task_id = res.get("task_id", "") - if res.get("status") not in {"success", "agent_error"}: + if res.get("status") not in SCOREABLE_STATUSES: continue # Skip infrastructure errors evals = res.get("eval") or [] diff --git a/tests/test_benchmarks/test_macs/test_macs_benchmark.py b/tests/test_benchmarks/test_macs/test_macs_benchmark.py index 954019e..0244abd 100644 --- a/tests/test_benchmarks/test_macs/test_macs_benchmark.py +++ b/tests/test_benchmarks/test_macs/test_macs_benchmark.py @@ -365,18 +365,12 @@ def test_empty_results(self): assert result["successful_tasks"] == 0 assert result["success_rate"] == 0.0 assert result["mean_metrics"] == {} - assert result["excluded"] == { - "environment_error": 0, - "user_error": 0, - "unknown_execution_error": 0, - "evaluation_failed": 0, - "setup_failed": 0, - } + assert result["excluded"] == {} assert result["status_counts"] == {} def test_single_successful_result(self): """Single successful result counted.""" - results = [{"status": "completed", "eval": [{"overall_gsr": 1.0, "user_gsr": 1.0, "system_gsr": 1.0}]}] + results = [{"status": "success", "eval": [{"overall_gsr": 1.0, "user_gsr": 1.0, "system_gsr": 1.0}]}] metrics = compute_benchmark_metrics(results) @@ -387,7 +381,7 @@ def test_single_successful_result(self): def test_single_failed_result(self): """Single failed result counted.""" - results = [{"status": "completed", "eval": [{"overall_gsr": 0.0, "user_gsr": 0.0, "system_gsr": 0.0}]}] + results = [{"status": "success", "eval": [{"overall_gsr": 0.0, "user_gsr": 0.0, "system_gsr": 0.0}]}] metrics = compute_benchmark_metrics(results) @@ -399,9 +393,9 @@ def test_single_failed_result(self): def test_multiple_results(self): """Multiple results aggregated correctly.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, # Success - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, # Fail - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, # Success + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, # Success + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, # Fail + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, # Success ] metrics = compute_benchmark_metrics(results) @@ -414,10 +408,10 @@ def test_multiple_results(self): def test_success_rate_calculation(self): """success_rate = successful/scored (not total).""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, ] metrics = compute_benchmark_metrics(results) @@ -427,8 +421,8 @@ def test_success_rate_calculation(self): def test_mean_metrics_calculation(self): """Mean of numeric metrics computed.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0, "partial_gsr": 0.8}]}, - {"status": "completed", "eval": [{"overall_gsr": 0.0, "partial_gsr": 0.4}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0, "partial_gsr": 0.8}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0, "partial_gsr": 0.4}]}, ] metrics = compute_benchmark_metrics(results) @@ -439,9 +433,9 @@ def test_mean_metrics_calculation(self): def test_handles_missing_eval(self): """Handles results with no eval key.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, - {"status": "completed", "no_eval_key": True}, # Missing eval - {"status": "completed", "eval": None}, # None eval + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "no_eval_key": True}, # Missing eval + {"status": "success", "eval": None}, # None eval ] metrics = compute_benchmark_metrics(results) @@ -454,7 +448,7 @@ def test_handles_non_numeric_values(self): """Non-numeric values in eval are ignored for mean.""" results = [ { - "status": "completed", + "status": "success", "eval": [ { "overall_gsr": 1.0, @@ -475,15 +469,15 @@ def test_handles_non_numeric_values(self): def test_excludes_environment_errors_from_scoring(self): """Environment errors are excluded from scoring.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, {"status": "environment_error", "eval": None}, # Should be excluded - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, ] metrics = compute_benchmark_metrics(results) assert metrics["total_tasks"] == 3 - assert metrics["scored_tasks"] == 2 # Only completed tasks + assert metrics["scored_tasks"] == 2 # Only success tasks assert metrics["successful_tasks"] == 1 assert metrics["success_rate"] == 0.5 # 1/2, not 1/3 assert metrics["excluded"]["environment_error"] == 1 @@ -491,7 +485,7 @@ def test_excludes_environment_errors_from_scoring(self): def test_excludes_user_errors_from_scoring(self): """User simulator errors are excluded from scoring.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, {"status": "user_error", "eval": None}, ] @@ -500,14 +494,14 @@ def test_excludes_user_errors_from_scoring(self): assert metrics["total_tasks"] == 2 assert metrics["scored_tasks"] == 1 assert metrics["successful_tasks"] == 1 - assert metrics["success_rate"] == 1.0 # Only the completed one + assert metrics["success_rate"] == 1.0 # Only the success one assert metrics["excluded"]["user_error"] == 1 def test_excludes_unknown_errors_from_scoring(self): """Unknown execution errors are excluded from scoring.""" results = [ {"status": "unknown_execution_error", "eval": None}, - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, ] metrics = compute_benchmark_metrics(results) @@ -521,7 +515,7 @@ def test_excludes_setup_failed_from_scoring(self): """Setup failures are excluded from scoring.""" results = [ {"status": "setup_failed", "eval": None}, - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, ] metrics = compute_benchmark_metrics(results) @@ -534,21 +528,21 @@ def test_excludes_evaluation_failed_from_scoring(self): """Evaluation failures are excluded from scoring.""" results = [ {"status": "evaluation_failed", "eval": None}, - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, ] metrics = compute_benchmark_metrics(results) assert metrics["total_tasks"] == 2 assert metrics["scored_tasks"] == 1 - assert metrics["success_rate"] == 1.0 # Only the completed one + assert metrics["success_rate"] == 1.0 # Only the success one assert metrics["excluded"]["evaluation_failed"] == 1 def test_includes_agent_errors_in_scoring(self): """Agent errors ARE included in scoring (agent's fault).""" results = [ {"status": "agent_error", "eval": [{"overall_gsr": 0.0}]}, - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, ] metrics = compute_benchmark_metrics(results) @@ -561,23 +555,23 @@ def test_includes_agent_errors_in_scoring(self): def test_status_counts_tracked(self): """Status counts are tracked for all tasks.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, {"status": "agent_error", "eval": None}, {"status": "environment_error", "eval": None}, ] metrics = compute_benchmark_metrics(results) - assert metrics["status_counts"]["completed"] == 2 + assert metrics["status_counts"]["success"] == 2 assert metrics["status_counts"]["agent_error"] == 1 assert metrics["status_counts"]["environment_error"] == 1 def test_mixed_errors_comprehensive(self): """Comprehensive test with various error types.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0, "accuracy": 0.9}]}, - {"status": "completed", "eval": [{"overall_gsr": 0.0, "accuracy": 0.3}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0, "accuracy": 0.9}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0, "accuracy": 0.3}]}, {"status": "agent_error", "eval": [{"overall_gsr": 0.0, "accuracy": 0.0}]}, {"status": "environment_error", "eval": None}, # Excluded {"status": "user_error", "eval": None}, # Excluded @@ -588,7 +582,7 @@ def test_mixed_errors_comprehensive(self): metrics = compute_benchmark_metrics(results) assert metrics["total_tasks"] == 7 - assert metrics["scored_tasks"] == 3 # completed(2) + agent_error(1) + assert metrics["scored_tasks"] == 3 # success(2) + agent_error(1) assert metrics["successful_tasks"] == 1 assert metrics["success_rate"] == pytest.approx(1 / 3) assert metrics["mean_metrics"]["accuracy"] == pytest.approx((0.9 + 0.3 + 0.0) / 3) diff --git a/tests/test_benchmarks/test_tau2/test_evaluator.py b/tests/test_benchmarks/test_tau2/test_evaluator.py index a4f25bd..c7383bb 100644 --- a/tests/test_benchmarks/test_tau2/test_evaluator.py +++ b/tests/test_benchmarks/test_tau2/test_evaluator.py @@ -276,7 +276,7 @@ def test_single_success(self): """Single successful result counted.""" from maseval.benchmark.tau2.evaluator import compute_benchmark_metrics - results = [{"status": "completed", "eval": [{"reward": 1.0, "passed": True}]}] + results = [{"status": "success", "eval": [{"reward": 1.0, "passed": True}]}] metrics = compute_benchmark_metrics(results) @@ -287,10 +287,10 @@ def test_single_success(self): assert metrics["mean_reward"] == 1.0 def test_single_failure(self): - """Single failed result counted.""" + """Single failed result counted (agent_error is scoreable).""" from maseval.benchmark.tau2.evaluator import compute_benchmark_metrics - results = [{"status": "completed", "eval": [{"reward": 0.0, "passed": False}]}] + results = [{"status": "agent_error", "eval": [{"reward": 0.0, "passed": False}]}] metrics = compute_benchmark_metrics(results) @@ -304,9 +304,9 @@ def test_mixed_results(self): from maseval.benchmark.tau2.evaluator import compute_benchmark_metrics results = [ - {"status": "completed", "eval": [{"reward": 1.0, "passed": True}]}, - {"status": "completed", "eval": [{"reward": 0.0, "passed": False}]}, - {"status": "completed", "eval": [{"reward": 0.5, "passed": False}]}, + {"status": "success", "eval": [{"reward": 1.0, "passed": True}]}, + {"status": "agent_error", "eval": [{"reward": 0.0, "passed": False}]}, + {"status": "task_timeout", "eval": [{"reward": 0.5, "passed": False}]}, ] metrics = compute_benchmark_metrics(results) @@ -322,7 +322,7 @@ def test_excludes_infrastructure_errors(self): from maseval.benchmark.tau2.evaluator import compute_benchmark_metrics results = [ - {"status": "completed", "eval": [{"reward": 1.0, "passed": True}]}, + {"status": "success", "eval": [{"reward": 1.0, "passed": True}]}, {"status": "environment_error", "eval": None}, {"status": "user_error", "eval": None}, {"status": "setup_failed", "eval": None}, @@ -331,7 +331,7 @@ def test_excludes_infrastructure_errors(self): metrics = compute_benchmark_metrics(results) assert metrics["total_tasks"] == 4 - assert metrics["scored_tasks"] == 1 # Only completed + assert metrics["scored_tasks"] == 1 # Only success assert metrics["successful_tasks"] == 1 assert metrics["success_rate"] == 1.0 @@ -340,14 +340,15 @@ def test_status_counts(self): from maseval.benchmark.tau2.evaluator import compute_benchmark_metrics results = [ - {"status": "completed", "eval": [{"reward": 1.0, "passed": True}]}, - {"status": "completed", "eval": [{"reward": 0.0, "passed": False}]}, + {"status": "success", "eval": [{"reward": 1.0, "passed": True}]}, + {"status": "agent_error", "eval": [{"reward": 0.0, "passed": False}]}, {"status": "environment_error", "eval": None}, ] metrics = compute_benchmark_metrics(results) - assert metrics["status_counts"]["completed"] == 2 + assert metrics["status_counts"]["success"] == 1 + assert metrics["status_counts"]["agent_error"] == 1 assert metrics["status_counts"]["environment_error"] == 1 From 175f4d2ebef6ee8f617408d9428e235e76e78d6a Mon Sep 17 00:00:00 2001 From: cemde Date: Wed, 31 Dec 2025 15:35:34 +0100 Subject: [PATCH 24/25] added issue template --- .github/issue_template.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 .github/issue_template.md diff --git a/.github/issue_template.md b/.github/issue_template.md new file mode 100644 index 0000000..8f60fc3 --- /dev/null +++ b/.github/issue_template.md @@ -0,0 +1,22 @@ +## Type + +- [ ] Bug +- [ ] Feature request +- [ ] Question + +## Summary + + + +## Details + + + +## Environment (if applicable) + +- maseval version: +- Python version: From d2925e2cd60fd605fc7237d53c216b513a0c17ab Mon Sep 17 00:00:00 2001 From: cemde Date: Fri, 2 Jan 2026 20:22:53 +0100 Subject: [PATCH 25/25] small fixes --- IDEAS.md | 1 - tests/test_benchmarks/test_macs/test_macs_benchmark.py | 1 - 2 files changed, 2 deletions(-) delete mode 100644 IDEAS.md diff --git a/IDEAS.md b/IDEAS.md deleted file mode 100644 index 57d143f..0000000 --- a/IDEAS.md +++ /dev/null @@ -1 +0,0 @@ -- Guide explaining that Dataset aprpox equal to Queue + Task Collection diff --git a/tests/test_benchmarks/test_macs/test_macs_benchmark.py b/tests/test_benchmarks/test_macs/test_macs_benchmark.py index 0244abd..b4c2aa9 100644 --- a/tests/test_benchmarks/test_macs/test_macs_benchmark.py +++ b/tests/test_benchmarks/test_macs/test_macs_benchmark.py @@ -31,7 +31,6 @@ def test_init_configures_benchmark(self, macs_model, sample_agent_data): callbacks = [MagicMock()] benchmark = ConcreteMACSBenchmark(macs_model, callbacks=callbacks, n_task_repeats=3) - # agent_data is now passed to run(), not __init__ assert benchmark.callbacks == callbacks assert benchmark.n_task_repeats == 3