diff --git a/.github/issue_template.md b/.github/issue_template.md new file mode 100644 index 0000000..8f60fc3 --- /dev/null +++ b/.github/issue_template.md @@ -0,0 +1,22 @@ +## Type + +- [ ] Bug +- [ ] Feature request +- [ ] Question + +## Summary + + + +## Details + + + +## Environment (if applicable) + +- maseval version: +- Python version: diff --git a/CHANGELOG.md b/CHANGELOG.md index 4da1364..56aa73b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,23 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +**Parallel Execution** + +- Added parallel task execution with `num_workers` parameter in `Benchmark.run()` using `ThreadPoolExecutor` (PR: #14) +- Added `ComponentRegistry` class for thread-safe component registration with thread-local storage (PR: #14) +- Added `TaskContext` for cooperative timeout checking with `check_timeout()`, `elapsed`, `remaining`, and `is_expired` properties (PR: #14) +- Added `TaskProtocol` dataclass with `timeout_seconds`, `timeout_action`, `max_retries`, `priority`, and `tags` fields for task-level execution control (PR: #14) +- Added `TimeoutAction` enum (`SKIP`, `RETRY`, `RAISE`) for configurable timeout behavior (PR: #14) +- Added `TaskTimeoutError` exception with `elapsed`, `timeout`, and `partial_traces` attributes (PR: #14) +- Added `TASK_TIMEOUT` to `TaskExecutionStatus` enum for timeout classification (PR: #14) + +**Task Queue Abstraction** + +- Added `TaskQueue` abstract base class with iterator interface for flexible task scheduling (PR: #14) +- Added `SequentialQueue` for simple FIFO task ordering (PR: #14) +- Added `PriorityQueue` for priority-based task scheduling using `TaskProtocol.priority` (PR: #14) +- Added `AdaptiveQueue` placeholder for future feedback-based scheduling (PR: #14) + **ModelAdapter Chat Interface** - Added `chat()` method to `ModelAdapter` as the primary interface for LLM inference, accepting a list of messages in OpenAI format and returning a `ChatResponse` object and accepting tools @@ -48,6 +65,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Benchmark** - `Benchmark.agent_data` parameter is now optional (defaults to empty dict) (PR: #16) +- Refactored `Benchmark` to delegate registry operations to `ComponentRegistry` class (PR: #) +- `Benchmark.run()` now accepts optional `queue` parameter for custom task scheduling (PR: #14) **Task** diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index d5ec921..b9c7669 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -117,7 +117,7 @@ Once implemented, run your benchmark: ```python # Define your tasks -tasks = TaskCollection([Task(query="...", expected="..."), ...]) +tasks = TaskQueue([Task(query="..."), ...]) # Configure your agents (e.g., model parameters, tool settings) agent_config = {"model": "gpt-4", "temperature": 0.7} diff --git a/docs/guides/exception-handling.md b/docs/guides/exception-handling.md index 819db1e..3fbbdac 100644 --- a/docs/guides/exception-handling.md +++ b/docs/guides/exception-handling.md @@ -90,27 +90,22 @@ class SimulatedUser: One approach to exception handling places the boundary between agent responsibility and infrastructure responsibility at input validation: -``` -┌─────────────────────────────────────────────────────────────┐ -│ TOOL EXECUTION │ -├─────────────────────────────────────────────────────────────┤ -│ │ -│ ┌─────────────────┐ │ -│ │ INPUT │ Agent passes arguments │ -│ │ 
VALIDATION │ │ -│ │ │ ❌ Fails → AgentError │ -│ │ │ ✓ Passes ↓ │ -│ └─────────────────┘ │ -│ │ │ -│ ▼ │ -│ ┌─────────────────┐ │ -│ │ EXECUTION │ Tool runs its logic │ -│ │ │ │ -│ │ │ ❌ Fails → EnvironmentError │ -│ │ │ ✓ Passes → Result │ -│ └─────────────────┘ │ -│ │ -└─────────────────────────────────────────────────────────────┘ +```mermaid +flowchart TD + subgraph TOOL_EXECUTION[" "] + A[Agent passes arguments] --> B{INPUT VALIDATION} + B -->|Fails| C[AgentError] + B -->|Passes| D{EXECUTION} + D -->|Fails| E[EnvironmentError] + D -->|Passes| F[Result] + end + + style TOOL_EXECUTION fill:none,stroke:#888 + style B fill:#f5f5f5,stroke:#333 + style D fill:#f5f5f5,stroke:#333 + style C fill:#ffebee,stroke:#c62828 + style E fill:#ffebee,stroke:#c62828 + style F fill:#e8f5e9,stroke:#2e7d32 ``` With this pattern: diff --git a/docs/reference/task.md b/docs/reference/task.md index ee2ccd8..0997396 100644 --- a/docs/reference/task.md +++ b/docs/reference/task.md @@ -1,9 +1,9 @@ # Task -Tasks define individual benchmark scenarios including inputs, expected outputs, and any metadata needed for evaluation. TaskCollections group related tasks together. +Tasks define individual benchmark scenarios including inputs, expected outputs, and any metadata needed for evaluation. TaskQueues group related tasks together. [:material-github: View source](https://github.com/parameterlab/maseval/blob/main/maseval/core/task.py){ .md-source-file } ::: maseval.core.task.Task -::: maseval.core.task.TaskCollection +::: maseval.core.task.TaskQueue diff --git a/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb b/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb index 914947a..aa192ba 100644 --- a/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb +++ b/examples/five_a_day_benchmark/five_a_day_benchmark.ipynb @@ -124,7 +124,7 @@ "from smolagents import ToolCallingAgent, LiteLLMModel, FinalAnswerTool\n", "\n", "# MASEval core components\n", - "from maseval import Benchmark, Environment, Task, TaskCollection, AgentAdapter, Evaluator, ModelAdapter\n", + "from maseval import Benchmark, Environment, Task, TaskQueue, AgentAdapter, Evaluator, ModelAdapter\n", "from maseval.interface.agents.smolagents import SmolAgentAdapter\n", "\n", "# Import evaluators module (dynamically loaded later)\n", @@ -139,7 +139,7 @@ " limit: int | None = None,\n", " seed: int | None = None,\n", " task_indices: list[int] | None = None,\n", - ") -> tuple[TaskCollection, list[Dict[str, Any]]]:\n", + ") -> tuple[TaskQueue, list[Dict[str, Any]]]:\n", " \"\"\"Load tasks and agent configurations.\n", "\n", " Args:\n", @@ -152,7 +152,7 @@ " task_indices: Optional list of task indices to load (e.g., [0, 2, 4])\n", "\n", " Returns:\n", - " Tuple of (TaskCollection, list of agent configs)\n", + " Tuple of (TaskQueue, list of agent configs)\n", " \"\"\"\n", " data_dir = Path(\"examples/five_a_day_benchmark/data\")\n", "\n", @@ -200,7 +200,7 @@ "\n", " configs_data.append(config)\n", "\n", - " return TaskCollection(tasks_data), configs_data" + " return TaskQueue(tasks_data), configs_data" ] }, { @@ -745,17 +745,7 @@ "id": "3764c0be", "metadata": {}, "outputs": [], - "source": [ - "# Create and run benchmark (will take approx. 2 min)\n", - "benchmark = FiveADayBenchmark(\n", - " agent_data=agent_configs,\n", - " fail_on_setup_error=True,\n", - " fail_on_task_error=True,\n", - " fail_on_evaluation_error=True,\n", - ")\n", - "\n", - "results = benchmark.run(tasks=tasks)" - ] + "source": "# Create and run benchmark (will take approx. 
2 min)\nbenchmark = FiveADayBenchmark(\n fail_on_setup_error=True,\n fail_on_task_error=True,\n fail_on_evaluation_error=True,\n)\n\nresults = benchmark.run(tasks=tasks, agent_data=agent_configs)" }, { "cell_type": "markdown", @@ -899,4 +889,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/five_a_day_benchmark/five_a_day_benchmark.py b/examples/five_a_day_benchmark/five_a_day_benchmark.py index 1c34c3f..1a867b8 100644 --- a/examples/five_a_day_benchmark/five_a_day_benchmark.py +++ b/examples/five_a_day_benchmark/five_a_day_benchmark.py @@ -26,7 +26,7 @@ from utils import derive_seed, sanitize_name # type: ignore[unresolved-import] -from maseval import Benchmark, Environment, Evaluator, Task, TaskCollection, AgentAdapter, ModelAdapter +from maseval import Benchmark, Environment, Evaluator, Task, TaskQueue, AgentAdapter, ModelAdapter from maseval.core.callbacks.result_logger import FileResultLogger # Import tool implementations @@ -825,7 +825,7 @@ def load_benchmark_data( limit: Optional[int] = None, specific_task: Optional[int] = None, seed: Optional[int] = None, -) -> tuple[TaskCollection, List[Dict[str, Any]]]: +) -> tuple[TaskQueue, List[Dict[str, Any]]]: """Load tasks and agent configurations with validation. Args: @@ -838,7 +838,7 @@ def load_benchmark_data( seed: Base random seed for reproducibility (None for non-deterministic) Returns: - Tuple of (TaskCollection, agent_configs_list) + Tuple of (TaskQueue, agent_configs_list) """ if limit is not None and specific_task is not None: raise ValueError("Cannot specify both limit and specific_task") @@ -897,7 +897,7 @@ def load_benchmark_data( print(f"Loaded {len(tasks_data)} tasks and {len(configs_data)} agent configs\n") - return TaskCollection(tasks_data), configs_data + return TaskQueue(tasks_data), configs_data # ============================================================================ @@ -935,13 +935,12 @@ def load_benchmark_data( ) benchmark = FiveADayBenchmark( - agent_data=agent_configs, callbacks=[logger], fail_on_setup_error=True, fail_on_task_error=True, fail_on_evaluation_error=True, ) - results = benchmark.run(tasks=tasks) + results = benchmark.run(tasks=tasks, agent_data=agent_configs) print("\n--- Benchmark Complete ---") print(f"Total tasks: {len(tasks)}") diff --git a/examples/introduction/tutorial.ipynb b/examples/introduction/tutorial.ipynb index 7c8c7de..898b97c 100644 --- a/examples/introduction/tutorial.ipynb +++ b/examples/introduction/tutorial.ipynb @@ -330,7 +330,7 @@ "metadata": {}, "outputs": [], "source": [ - "from maseval import Benchmark, Environment, Evaluator, Task, TaskCollection\n", + "from maseval import Benchmark, Environment, Evaluator, Task, TaskQueue\n", "from maseval.interface.agents.smolagents import SmolAgentAdapter\n", "\n", "print(\"MASEval components imported successfully!\")" @@ -634,23 +634,7 @@ "id": "b3ee60a7", "metadata": {}, "outputs": [], - "source": [ - "# Create benchmark instance with agent configuration\n", - "agent_data = {\"model_id\": \"gemini/gemini-2.5-flash\", \"temperature\": 0.7}\n", - "\n", - "benchmark = SimpleBenchmark(agent_data=agent_data, progress_bar=False)\n", - "\n", - "# Create task collection\n", - "tasks = TaskCollection([task])\n", - "\n", - "# Run the benchmark\n", - "print(\"Running benchmark...\\n\")\n", - "reports = benchmark.run(tasks=tasks)\n", - "\n", - "print(\"\\n\" + \"=\" * 60)\n", - "print(\"BENCHMARK COMPLETE\")\n", - "print(\"=\" * 60)" - ] + "source": "# Create benchmark instance\nagent_data = 
{\"model_id\": \"gemini/gemini-2.5-flash\", \"temperature\": 0.7}\n\nbenchmark = SimpleBenchmark(progress_bar=False)\n\n# Create task queue\ntasks = TaskQueue([task])\n\n# Run the benchmark\nprint(\"Running benchmark...\\n\")\nreports = benchmark.run(tasks=tasks, agent_data=agent_data)\n\nprint(\"\\n\" + \"=\" * 60)\nprint(\"BENCHMARK COMPLETE\")\nprint(\"=\" * 60)" }, { "cell_type": "markdown", @@ -746,4 +730,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/macs_benchmark/macs_benchmark.py b/examples/macs_benchmark/macs_benchmark.py index d06ce50..7ea7903 100644 --- a/examples/macs_benchmark/macs_benchmark.py +++ b/examples/macs_benchmark/macs_benchmark.py @@ -737,7 +737,6 @@ def run_benchmark( # Get benchmark class and instantiate BenchmarkClass = get_benchmark_class(framework) benchmark = BenchmarkClass( - agent_data=agent_config, callbacks=[logger], n_task_repeats=n_task_repeats, fail_on_setup_error=True, @@ -747,7 +746,7 @@ def run_benchmark( # Run benchmark print(f"\nRunning {framework} benchmark on {domain} domain...") - results = benchmark.run(tasks=tasks) + results = benchmark.run(tasks=tasks, agent_data=agent_config) # Compute summary metrics summary = compute_benchmark_metrics(results) diff --git a/maseval/__init__.py b/maseval/__init__.py index e74feda..6917ecc 100644 --- a/maseval/__init__.py +++ b/maseval/__init__.py @@ -8,7 +8,17 @@ Benchmarks sit in the `maseval.benchmark` submodule. """ -from .core.task import Task, TaskCollection +from .core.task import ( + Task, + TaskProtocol, + TimeoutAction, + # Task queue classes + BaseTaskQueue, + TaskQueue, + SequentialTaskQueue, + PriorityTaskQueue, + AdaptiveTaskQueue, +) from .core.environment import Environment from .core.agent import AgentAdapter from .core.benchmark import Benchmark, TaskExecutionStatus @@ -27,11 +37,14 @@ from .core.evaluator import Evaluator from .core.history import MessageHistory, ToolInvocationHistory from .core.tracing import TraceableMixin +from .core.registry import ComponentRegistry +from .core.context import TaskContext from .core.exceptions import ( MASEvalError, AgentError, EnvironmentError, UserError, + TaskTimeoutError, validate_argument_type, validate_required_arguments, validate_no_extra_arguments, @@ -41,7 +54,8 @@ __all__ = [ # Tasks "Task", - "TaskCollection", + "TaskProtocol", + "TimeoutAction", # Core abstractions "Environment", "AgentAdapter", @@ -68,6 +82,15 @@ "MessageHistory", "ToolInvocationHistory", "TraceableMixin", + # Registry and execution context + "ComponentRegistry", + "TaskContext", + # Task queues + "BaseTaskQueue", + "TaskQueue", + "SequentialTaskQueue", + "PriorityTaskQueue", + "AdaptiveTaskQueue", # Model adapters "ModelAdapter", "ChatResponse", @@ -76,6 +99,7 @@ "AgentError", "EnvironmentError", "UserError", + "TaskTimeoutError", "validate_argument_type", "validate_required_arguments", "validate_no_extra_arguments", diff --git a/maseval/benchmark/macs/data_loader.py b/maseval/benchmark/macs/data_loader.py index f3e41b7..b627c7f 100644 --- a/maseval/benchmark/macs/data_loader.py +++ b/maseval/benchmark/macs/data_loader.py @@ -15,7 +15,7 @@ from urllib.error import HTTPError, URLError from urllib.request import urlopen -from maseval import Task, TaskCollection +from maseval import Task, TaskQueue # ============================================================================= @@ -27,24 +27,23 @@ # AWS Multi-Agent Collaboration Scenarios benchmark data # Source: 
https://github.com/aws-samples/multiagent-collab-scenario-benchmark -URLS = { - "data": { - "software": { - "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/software/agents.json", - "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/software/scenarios_30.json", - }, - "travel": { - "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/travel/agents.json", - "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/travel/scenarios_30.json", - }, - "mortgage": { - "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/mortgage/agents.json", - "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/mortgage/scenarios_30.json", - }, +DATA_URLS: Dict[str, Dict[str, str]] = { + "software": { + "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/software/agents.json", + "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/software/scenarios_30.json", }, - "evaluation": { - "prompt_templates": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/src/prompt_templates.py", + "travel": { + "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/travel/agents.json", + "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/travel/scenarios_30.json", }, + "mortgage": { + "agents": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/mortgage/agents.json", + "scenarios": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/datasets/mortgage/scenarios_30.json", + }, +} + +EVALUATION_URLS: Dict[str, str] = { + "prompt_templates": "https://raw.githubusercontent.com/aws-samples/multiagent-collab-scenario-benchmark/refs/heads/main/src/prompt_templates.py", } @@ -90,16 +89,16 @@ def download_original_data( data_dir = Path(data_dir) if data_dir else DEFAULT_DATA_DIR original_dir = data_dir / "original" - domains = [domain] if domain else list(URLS["data"].keys()) + domains = [domain] if domain else list(DATA_URLS.keys()) for d in domains: - if d not in URLS["data"]: + if d not in DATA_URLS: raise ValueError(f"Unknown domain: {d}") domain_dir = original_dir / d domain_dir.mkdir(parents=True, exist_ok=True) - for name, url in URLS["data"][d].items(): # type: ignore[union-attr] + for name, url in DATA_URLS[d].items(): content = download_json(url) out_path = domain_dir / f"{name}.json" with out_path.open("w") as f: @@ -132,8 +131,8 @@ def download_prompt_templates( templates_dir = data_dir.parent / "prompt_templates" templates_dir.mkdir(parents=True, exist_ok=True) - url = URLS["evaluation"]["prompt_templates"] - text = download_file(url) # type: ignore[arg-type] + url = EVALUATION_URLS["prompt_templates"] + text = download_file(url) # Parse Python file to extract prompt constants tree = ast.parse(text) @@ -239,12 +238,15 @@ def _dedupe_tools_by_name(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: return 
deduped -def _create_tools_list(agents_obj: object) -> List[Dict[str, Any]]: +def _create_tools_list(agents_obj: Union[Dict[str, Any], List[Any]]) -> List[Dict[str, Any]]: """Extract and deduplicate tools from agents data.""" tools: List[Dict[str, Any]] = [] - if isinstance(agents_obj, dict) and isinstance(agents_obj.get("agents"), list): # type: ignore[arg-type,union-attr] - agents_list = agents_obj["agents"] # type: ignore[index] + agents_list: List[Any] + if isinstance(agents_obj, dict): + agents_list = agents_obj.get("agents", []) + if not isinstance(agents_list, list): + return tools elif isinstance(agents_obj, list): agents_list = agents_obj else: @@ -260,7 +262,7 @@ def _create_tools_list(agents_obj: object) -> List[Dict[str, Any]]: return _dedupe_tools_by_name(tools) -def _create_agents_list(agents_obj: object) -> Dict[str, Any]: +def _create_agents_list(agents_obj: Union[Dict[str, Any], List[Any]]) -> Dict[str, Any]: """Create agents config with tool names only (not full tool dicts).""" def _process_agent(agent: Dict[str, Any]) -> Dict[str, Any]: @@ -269,8 +271,11 @@ def _process_agent(agent: Dict[str, Any]) -> Dict[str, Any]: a_copy["tools"] = tool_names return a_copy - if isinstance(agents_obj, dict) and isinstance(agents_obj.get("agents"), list): # type: ignore[arg-type,union-attr] - processed = [_process_agent(a) for a in agents_obj["agents"] if isinstance(a, dict)] # type: ignore[index,union-attr] + if isinstance(agents_obj, dict): + agents_list = agents_obj.get("agents") + if not isinstance(agents_list, list): + return {} + processed = [_process_agent(a) for a in agents_list if isinstance(a, dict)] out: Dict[str, Any] = {"agents": processed} if "primary_agent_id" in agents_obj: out["primary_agent_id"] = agents_obj["primary_agent_id"] # type: ignore[index] @@ -281,12 +286,15 @@ def _process_agent(agent: Dict[str, Any]) -> Dict[str, Any]: return {} -def _create_tasks_list(scenarios_obj: object, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +def _create_tasks_list(scenarios_obj: Union[Dict[str, Any], List[Any]], tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Convert scenarios to task format with sequential IDs.""" tasks: List[Dict[str, Any]] = [] - if isinstance(scenarios_obj, dict) and isinstance(scenarios_obj.get("scenarios"), list): # type: ignore[arg-type,union-attr] - scenarios_list = scenarios_obj["scenarios"] # type: ignore[index] + scenarios_list: List[Any] + if isinstance(scenarios_obj, dict): + scenarios_list = scenarios_obj.get("scenarios", []) + if not isinstance(scenarios_list, list): + return tasks elif isinstance(scenarios_obj, list): scenarios_list = scenarios_obj else: @@ -422,7 +430,7 @@ def load_tasks( domain: str, data_dir: Optional[Path] = None, limit: Optional[int] = None, -) -> TaskCollection: +) -> TaskQueue: """Load tasks for a MACS domain. 
Args: @@ -432,7 +440,7 @@ def load_tasks( limit: Maximum number of tasks to load Returns: - TaskCollection containing Task objects + TaskQueue containing Task objects Raises: ValueError: If domain is not valid @@ -464,7 +472,7 @@ def load_tasks( task_kwargs["id"] = str(t["id"]) tasks.append(Task(**task_kwargs)) - return TaskCollection(tasks) + return TaskQueue(tasks) def load_agent_config( @@ -502,12 +510,12 @@ def load_agent_config( def configure_model_ids( - tasks: Union[TaskCollection, List[Task]], + tasks: Union[TaskQueue, List[Task]], *, tool_model_id: Optional[str] = None, user_model_id: Optional[str] = None, evaluator_model_id: Optional[str] = None, -) -> Union[TaskCollection, List[Task]]: +) -> Union[TaskQueue, List[Task]]: """Configure model IDs for benchmark components in task data. This helper merges runtime model configuration into task data structures, @@ -518,13 +526,13 @@ def configure_model_ids( task-specific overrides in the original data to take precedence. Args: - tasks: TaskCollection or list of Tasks to configure + tasks: TaskQueue or list of Tasks to configure tool_model_id: Model ID for tool simulators (stored in environment_data) user_model_id: Model ID for user simulator (stored in user_data) evaluator_model_id: Model ID for evaluators (stored in evaluation_data) Returns: - The same collection (mutated in place for convenience) + The same queue or list (mutated in place for convenience) Example: ```python diff --git a/maseval/benchmark/macs/macs.py b/maseval/benchmark/macs/macs.py index c4aaa9b..6e4ee04 100644 --- a/maseval/benchmark/macs/macs.py +++ b/maseval/benchmark/macs/macs.py @@ -36,8 +36,8 @@ def get_model_adapter(self, model_id, **kwargs): return adapter # Run - benchmark = MyMACSBenchmark(agent_data=agent_config) - results = benchmark.run(tasks) + benchmark = MyMACSBenchmark() + results = benchmark.run(tasks, agent_data=agent_config) """ import json @@ -53,6 +53,7 @@ def get_model_adapter(self, model_id, **kwargs): MessageHistory, ModelAdapter, Task, + TaskExecutionStatus, ToolInvocationHistory, ToolLLMSimulator, User, @@ -64,6 +65,17 @@ def get_model_adapter(self, model_id, **kwargs): from maseval.core.tracing import TraceableMixin +# Statuses where agent is accountable (included in scoring) +# Note: task_timeout is included - timeouts count as failures in MACS +SCOREABLE_STATUSES = frozenset( + { + TaskExecutionStatus.SUCCESS.value, + TaskExecutionStatus.AGENT_ERROR.value, + TaskExecutionStatus.TASK_TIMEOUT.value, + } +) + + # ============================================================================= # Tool # ============================================================================= @@ -695,7 +707,6 @@ class MACSBenchmark(Benchmark): def __init__( self, - agent_data: Dict[str, Any], callbacks: Optional[List[Any]] = None, n_task_repeats: int = 1, max_invocations: int = 5, @@ -704,12 +715,11 @@ def __init__( """Initialize benchmark. Args: - agent_data: Agent configuration from load_agent_config(). callbacks: Benchmark callbacks n_task_repeats: Repetitions per task max_invocations: Maximum agent-user interaction rounds (default: 5 per MACS paper) """ - super().__init__(agent_data, callbacks, n_task_repeats, max_invocations, **kwargs) + super().__init__(callbacks=callbacks, n_task_repeats=n_task_repeats, max_invocations=max_invocations, **kwargs) def _get_tool_model_id(self, task: Task) -> str: """Get tool simulator model ID from task.environment_data. 
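Since accountability in MACS is now decided by the module-level `SCOREABLE_STATUSES` frozenset introduced above (timeouts included), downstream analysis only needs each report's `status` field to separate scored from excluded runs. Below is a minimal sketch of that filtering, assuming `SCOREABLE_STATUSES` stays importable from `maseval.benchmark.macs.macs` and that reports carry the `status` key produced by `Benchmark.run()`; the helper name `split_by_accountability` is illustrative, not part of the library:

```python
from collections import Counter
from typing import Any, Dict, List

from maseval.benchmark.macs.macs import SCOREABLE_STATUSES


def split_by_accountability(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Partition run() reports into agent-accountable and excluded outcomes.

    Relies only on the "status" field present in every report; timeouts stay
    in the scored bucket, matching the MACS convention above.
    """
    scored = [r for r in results if r.get("status", "unknown") in SCOREABLE_STATUSES]
    excluded = [r for r in results if r.get("status", "unknown") not in SCOREABLE_STATUSES]
    return {
        "scored": scored,
        "excluded": excluded,
        "excluded_counts": dict(Counter(r.get("status", "unknown") for r in excluded)),
    }
```

This mirrors the filtering that `compute_benchmark_metrics` applies internally, without recomputing any scores.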
@@ -804,7 +814,7 @@ def tool_model_factory(tool_name: str) -> ModelAdapter: model_factory=tool_model_factory, ) - def setup_user( # type: ignore[override] + def setup_user( # type: ignore[invalid-method-override] self, agent_data: Dict[str, Any], environment: MACSEnvironment, @@ -840,7 +850,7 @@ def setup_user( # type: ignore[override] ) @abstractmethod - def setup_agents( # type: ignore[override] + def setup_agents( # type: ignore[invalid-method-override] self, agent_data: Dict[str, Any], environment: MACSEnvironment, @@ -860,7 +870,7 @@ def setup_agents( # type: ignore[override] """ pass - def setup_evaluators( # type: ignore[override] + def setup_evaluators( # type: ignore[invalid-method-override] self, environment: MACSEnvironment, task: Task, @@ -892,7 +902,7 @@ def setup_evaluators( # type: ignore[override] ), ] - def run_agents( # type: ignore[override] + def run_agents( # type: ignore[invalid-method-override] self, agents: Sequence[AgentAdapter], task: Task, @@ -932,7 +942,9 @@ def evaluate( user_result = results[0] if results else {"gsr": 0.0, "partial_gsr": 0.0, "report": []} system_result = results[1] if len(results) > 1 else {"gsr": 0.0, "partial_gsr": 0.0, "report": []} - combined_report = user_result.get("report", []) + system_result.get("report", []) # type: ignore[operator] + user_report: List[Dict[str, Any]] = user_result.get("report", []) # type: ignore[assignment] + system_report: List[Dict[str, Any]] = system_result.get("report", []) # type: ignore[assignment] + combined_report = user_report + system_report # Compute overall metrics per AWS paper overall_gsr = 1.0 if (user_result.get("gsr", 0.0) == 1.0 and system_result.get("gsr", 0.0) == 1.0) else 0.0 @@ -987,15 +999,6 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: - excluded: Dict with counts of excluded tasks by category - status_counts: Dict with counts of each status type """ - # Status values that indicate infrastructure failures (not agent's fault) - INFRASTRUCTURE_STATUSES = { - "environment_error", - "user_error", - "unknown_execution_error", - "evaluation_failed", - "setup_failed", - } - if not results: return { "total_tasks": 0, @@ -1003,13 +1006,7 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: "successful_tasks": 0, "success_rate": 0.0, "mean_metrics": {}, - "excluded": { - "environment_error": 0, - "user_error": 0, - "unknown_execution_error": 0, - "evaluation_failed": 0, - "setup_failed": 0, - }, + "excluded": {}, "status_counts": {}, } @@ -1019,14 +1016,14 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: successful_tasks = 0 scored_tasks = 0 status_counts: Dict[str, int] = {} - excluded_counts: Dict[str, int] = {s: 0 for s in INFRASTRUCTURE_STATUSES} + excluded_counts: Dict[str, int] = {} for res in results: status = res.get("status", "unknown") status_counts[status] = status_counts.get(status, 0) + 1 - # Skip infrastructure failures from scoring - if status in INFRASTRUCTURE_STATUSES: + # Skip infrastructure failures from scoring (use module-level SCOREABLE_STATUSES) + if status not in SCOREABLE_STATUSES: excluded_counts[status] = excluded_counts.get(status, 0) + 1 continue diff --git a/maseval/benchmark/tau2/data_loader.py b/maseval/benchmark/tau2/data_loader.py index 6e1b3fd..e79ec8a 100644 --- a/maseval/benchmark/tau2/data_loader.py +++ b/maseval/benchmark/tau2/data_loader.py @@ -18,7 +18,8 @@ from urllib.error import HTTPError, URLError from urllib.request import urlopen -from maseval import 
Task, TaskCollection +from maseval import Task, TaskQueue +from maseval.core.task import TaskProtocol # ============================================================================= @@ -36,6 +37,10 @@ "telecom": 114, } +# Default execution protocol settings +DEFAULT_TIMEOUT_SECONDS = 600.0 # 10 minutes per task +DEFAULT_MAX_RETRIES = 1 # Skip on first failure + # GitHub raw content URLs for v0.2.0 tag GITHUB_BASE = "https://raw.githubusercontent.com/sierra-research/tau2-bench" DEFAULT_VERSION = "v0.2.0" @@ -257,7 +262,9 @@ def load_tasks( split: str = "base", data_dir: Optional[Path] = None, limit: Optional[int] = None, -) -> TaskCollection: + timeout_seconds: Optional[float] = DEFAULT_TIMEOUT_SECONDS, + max_retries: int = DEFAULT_MAX_RETRIES, +) -> TaskQueue: """Load tasks for a tau2 domain. Args: @@ -265,15 +272,19 @@ def load_tasks( split: One of "base", "hard", "all" (base recommended for reproducibility) data_dir: Base data directory (default: module's data/) limit: Maximum number of tasks to load + timeout_seconds: Maximum execution time per task in seconds. Default 600 (10 minutes). + Set to None to disable timeout. + max_retries: Maximum retry attempts for transient failures. Default 1 (skip on failure). Returns: - TaskCollection containing Task objects with: + TaskQueue containing Task objects with: - id: Task identifier from tau2 data - query: Initial user message (from user_scenario) - environment_data: Domain tools, database state, policies - evaluation_data: Assertions, expected outcomes - user_data: User profile, instructions - metadata: domain, split, description + - protocol: Execution settings (timeout, retries, tags) Raises: ValueError: If domain or split is invalid @@ -283,6 +294,9 @@ def load_tasks( >>> tasks = load_tasks("retail", split="base", limit=5) >>> len(tasks) 5 + + >>> # Custom timeout and retries + >>> tasks = load_tasks("retail", timeout_seconds=300, max_retries=2) """ if domain not in VALID_DOMAINS: raise ValueError(f"Invalid domain '{domain}'. Must be one of {VALID_DOMAINS}") @@ -313,10 +327,10 @@ def load_tasks( # Convert to MASEval Task objects tasks = [] for raw_task in raw_tasks: - task = _convert_tau2_task_to_maseval(raw_task, domain, split, domain_config) + task = _convert_tau2_task_to_maseval(raw_task, domain, split, domain_config, timeout_seconds, max_retries) tasks.append(task) - return TaskCollection(tasks) + return TaskQueue(tasks) def _convert_tau2_task_to_maseval( @@ -324,6 +338,8 @@ def _convert_tau2_task_to_maseval( domain: str, split: str, domain_config: Dict[str, Any], + timeout_seconds: Optional[float], + max_retries: int, ) -> Task: """Convert a tau2-bench task dict to MASEval Task. 
@@ -332,6 +348,8 @@ def _convert_tau2_task_to_maseval( domain: Domain name split: Split name domain_config: Domain configuration with policy and db_path + timeout_seconds: Maximum execution time per task in seconds + max_retries: Maximum retry attempts for transient failures Returns: MASEval Task object @@ -380,6 +398,13 @@ def _convert_tau2_task_to_maseval( "ticket": raw_task.get("ticket"), # For solo mode (not used) } + # Build execution protocol with timeout, retries, and tags + protocol = TaskProtocol( + timeout_seconds=timeout_seconds, + max_retries=max_retries, + tags={"domain": domain, "split": split}, + ) + # Build task kwargs, only include id if provided in raw task task_kwargs: Dict[str, Any] = { "query": query, @@ -387,6 +412,7 @@ def _convert_tau2_task_to_maseval( "evaluation_data": evaluation_data, "user_data": user_data, "metadata": metadata, + "protocol": protocol, } if raw_task.get("id"): task_kwargs["id"] = str(raw_task["id"]) @@ -440,18 +466,18 @@ def load_domain_config( def configure_model_ids( - tasks: Union[TaskCollection, List[Task]], + tasks: Union[TaskQueue, List[Task]], *, user_model_id: Optional[str] = None, evaluator_model_id: Optional[str] = None, -) -> Union[TaskCollection, List[Task]]: +) -> Union[TaskQueue, List[Task]]: """Configure model IDs for benchmark components in task data. Tau2 tools execute real business logic and don't need a tool_model_id. Only user simulation and evaluation use LLMs. Args: - tasks: TaskCollection or list of Tasks to configure + tasks: TaskQueue or list of Tasks to configure user_model_id: Model ID for user simulator (stored in user_data) evaluator_model_id: Model ID for evaluators (stored in evaluation_data) diff --git a/maseval/benchmark/tau2/evaluator.py b/maseval/benchmark/tau2/evaluator.py index e5e6b07..3023be5 100644 --- a/maseval/benchmark/tau2/evaluator.py +++ b/maseval/benchmark/tau2/evaluator.py @@ -23,12 +23,23 @@ from enum import Enum from typing import Any, Dict, List, Optional -from maseval import Evaluator, Task +from maseval import Evaluator, Task, TaskExecutionStatus from maseval.benchmark.tau2.environment import Tau2Environment, get_environment_constructor from maseval.benchmark.tau2.utils import compare_tool_calls +# Statuses where agent is accountable (included in scoring) +# Note: task_timeout is included - timeouts count as failures in tau2 +SCOREABLE_STATUSES = frozenset( + { + TaskExecutionStatus.SUCCESS.value, + TaskExecutionStatus.AGENT_ERROR.value, + TaskExecutionStatus.TASK_TIMEOUT.value, + } +) + + class RewardType(str, Enum): """Types of rewards that can be computed. @@ -447,6 +458,7 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: """Compute summary metrics across all benchmark results. Infrastructure errors are excluded from scoring metrics. + Uses SCOREABLE_STATUSES to determine which results count toward agent score. 
Args: results: List of result dicts from benchmark.run() @@ -454,14 +466,6 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: Returns: Dict with success_rate, mean_reward, pass_at_k, status_counts """ - INFRASTRUCTURE_STATUSES = { - "environment_error", - "user_error", - "unknown_execution_error", - "evaluation_failed", - "setup_failed", - } - if not results: return { "total_tasks": 0, @@ -482,8 +486,8 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]: status = res.get("status", "unknown") status_counts[status] = status_counts.get(status, 0) + 1 - if status in INFRASTRUCTURE_STATUSES: - continue + if status not in SCOREABLE_STATUSES: + continue # Skip infrastructure errors scored_tasks += 1 evals = res.get("eval") or [] @@ -529,7 +533,7 @@ def compute_pass_at_k( task_results: Dict[str, List[bool]] = {} for res in results: task_id = res.get("task_id", "") - if res.get("status") not in {"success", "agent_error"}: + if res.get("status") not in SCOREABLE_STATUSES: continue # Skip infrastructure errors evals = res.get("eval") or [] @@ -615,7 +619,7 @@ def compute_pass_hat_k( task_results: Dict[str, List[bool]] = {} for res in results: task_id = res.get("task_id", "") - if res.get("status") not in {"success", "agent_error"}: + if res.get("status") not in SCOREABLE_STATUSES: continue # Skip infrastructure errors evals = res.get("eval") or [] diff --git a/maseval/benchmark/tau2/tau2.py b/maseval/benchmark/tau2/tau2.py index 5b02522..e917c60 100644 --- a/maseval/benchmark/tau2/tau2.py +++ b/maseval/benchmark/tau2/tau2.py @@ -62,6 +62,7 @@ def get_model_adapter(self, model_id, **kwargs): from maseval import AgentAdapter, Benchmark, Evaluator, ModelAdapter, Task, User from maseval.core.user import AgenticUser +from maseval.core.callback import BenchmarkCallback from maseval.benchmark.tau2.environment import Tau2Environment from maseval.benchmark.tau2.evaluator import Tau2Evaluator @@ -231,15 +232,44 @@ def get_model_adapter(self, model_id, **kwargs): benchmark.run(tasks) """ - def __init__(self, *args: Any, max_invocations: int = 50, **kwargs: Any): + # Maximum agent-user interaction rounds (tau2-bench uses max_steps=200, where 1 turn ≈ 4 steps) + MAX_INVOCATIONS = 50 + + def __init__( + self, + callbacks: Optional[List[BenchmarkCallback]] = None, + n_task_repeats: int = 1, + max_invocations: int = MAX_INVOCATIONS, + num_workers: int = 1, + fail_on_setup_error: bool = False, + fail_on_task_error: bool = False, + fail_on_evaluation_error: bool = False, + progress_bar: bool | str = True, + ): """Initialize benchmark with tau2-specific defaults. Args: + callbacks: Optional list of callback handlers for monitoring execution. + n_task_repeats: Number of times to repeat each task. Default 1. max_invocations: Maximum agent-user interaction rounds (default: 50). tau2-bench uses max_steps=200, where 1 turn ≈ 4 steps. - *args, **kwargs: Passed to parent Benchmark class + num_workers: Number of parallel task executions. Default 1 (sequential). + fail_on_setup_error: If True, raise on setup errors. Default False. + fail_on_task_error: If True, raise on task execution errors. Default False. + fail_on_evaluation_error: If True, raise on evaluation errors. Default False. + progress_bar: Progress display. True (default) for tqdm, "rich" for Rich, + or False to disable. 
""" - super().__init__(*args, max_invocations=max_invocations, **kwargs) # type: ignore[parameter-already-assigned] + super().__init__( + callbacks=callbacks, + n_task_repeats=n_task_repeats, + max_invocations=max_invocations, + num_workers=num_workers, + fail_on_setup_error=fail_on_setup_error, + fail_on_task_error=fail_on_task_error, + fail_on_evaluation_error=fail_on_evaluation_error, + progress_bar=progress_bar, + ) def _get_user_model_id(self, task: Task) -> str: """Get user simulator model ID from task.user_data. @@ -849,14 +879,6 @@ class DefaultAgentTau2Benchmark(Tau2Benchmark): results = benchmark.run(tasks) """ - # Cache for model adapters - _model_cache: Dict[str, ModelAdapter] - - def __init__(self, *args: Any, **kwargs: Any): - """Initialize the default agent benchmark. See Tau2Benchmark for args.""" - super().__init__(*args, **kwargs) - self._model_cache = {} - def _get_agent_model_id(self, agent_data: Dict[str, Any]) -> str: """Get agent model ID from agent_data. @@ -939,5 +961,9 @@ def get_model_adapter(self, model_id: str, **kwargs: Any) -> ModelAdapter: Returns: ModelAdapter instance + + Note: + DefaultAgentTau2Benchmark uses lazy initialization for model caching. + Access via `getattr(self, '_model_cache', {})` in subclass implementations. """ pass diff --git a/maseval/core/benchmark.py b/maseval/core/benchmark.py index bfec7cc..c75e160 100644 --- a/maseval/core/benchmark.py +++ b/maseval/core/benchmark.py @@ -1,12 +1,14 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Iterable, Optional, Sequence, Tuple, Union, cast -from datetime import datetime import threading +from concurrent.futures import ThreadPoolExecutor from enum import Enum import warnings +import traceback +import logging from .evaluator import Evaluator -from .task import Task, TaskCollection +from .task import Task, BaseTaskQueue, SequentialTaskQueue from .environment import Environment from .agent import AgentAdapter from .model import ModelAdapter @@ -14,14 +16,15 @@ from .callback import BenchmarkCallback from .user import User from .tracing import TraceableMixin -from .config import ConfigurableMixin +from .registry import ComponentRegistry +from .context import TaskContext from .utils.system_info import gather_benchmark_config from .callbacks.progress_bar import ( ProgressBarCallback, TqdmProgressBarCallback, RichProgressBarCallback, ) -from .exceptions import AgentError, EnvironmentError, UserError +from .exceptions import AgentError, EnvironmentError, UserError, TaskTimeoutError class TaskExecutionStatus(Enum): @@ -39,13 +42,14 @@ class TaskExecutionStatus(Enum): AGENT_ERROR: Agent violated contract at a boundary (agent's fault, counts against score). ENVIRONMENT_ERROR: Environment/tool infrastructure failed (not agent's fault, exclude from scoring). USER_ERROR: User simulator failed (not agent's fault, exclude from scoring). + TASK_TIMEOUT: Task execution exceeded configured timeout (resource constraint). UNKNOWN_EXECUTION_ERROR: Unclassified execution error (e.g., agent framework internal failure). EVALUATION_FAILED: Task executed but evaluation raised an exception. SETUP_FAILED: Setup phase (environment, agents, evaluators) raised an exception. 
Scoring Guidance: - Include in agent score: SUCCESS, AGENT_ERROR - - Exclude from agent score: ENVIRONMENT_ERROR, USER_ERROR, UNKNOWN_EXECUTION_ERROR + - Exclude from agent score: ENVIRONMENT_ERROR, USER_ERROR, TASK_TIMEOUT, UNKNOWN_EXECUTION_ERROR - Handle separately: EVALUATION_FAILED, SETUP_FAILED """ @@ -53,6 +57,7 @@ class TaskExecutionStatus(Enum): AGENT_ERROR = "agent_error" ENVIRONMENT_ERROR = "environment_error" USER_ERROR = "user_error" + TASK_TIMEOUT = "task_timeout" UNKNOWN_EXECUTION_ERROR = "unknown_execution_error" EVALUATION_FAILED = "evaluation_failed" SETUP_FAILED = "setup_failed" @@ -94,21 +99,26 @@ def run_agents(self, agents, task, environment, query): # ... implement other abstract methods # Run the benchmark - benchmark = MyBenchmark(agent_data=config) - reports = benchmark.run(tasks=my_tasks) + config = {"model": "gpt-4", "temperature": 0.7} + benchmark = MyBenchmark() + reports = benchmark.run(tasks=my_tasks, agent_data=config) # Retry failed tasks elegantly (graceful task failure handling by default) failed_tasks = benchmark.get_failed_tasks() if len(failed_tasks) > 0: - retry_reports = benchmark.run(tasks=failed_tasks) + retry_reports = benchmark.run(tasks=failed_tasks, agent_data=config) + + # Parallel execution for I/O-bound workloads + benchmark = MyBenchmark(num_workers=4) + reports = benchmark.run(tasks=my_tasks, agent_data=config) # Or use strict mode for debugging (fail fast) benchmark = MyBenchmark( - agent_data=config, fail_on_task_error=True, fail_on_evaluation_error=True, fail_on_setup_error=True ) + reports = benchmark.run(tasks=my_tasks, agent_data=config) ``` The framework handles task iteration, repetitions for statistical robustness, callback @@ -118,21 +128,18 @@ def run_agents(self, agents, task, environment, query): def __init__( self, - agent_data: Optional[Dict[str, Any] | Iterable[Dict[str, Any]]] = None, callbacks: Optional[List[BenchmarkCallback]] = None, n_task_repeats: int = 1, max_invocations: int = 1, + num_workers: int = 1, fail_on_setup_error: bool = False, fail_on_task_error: bool = False, fail_on_evaluation_error: bool = False, progress_bar: bool | str = True, ): - """Initialize a benchmark with agent configurations. + """Initialize a benchmark with execution configuration. Args: - agent_data: Configuration for agents. Either a single dict applied to all tasks, or - an iterable of dicts with one configuration per task. If None, defaults to empty dict. - Agent data typically includes model parameters, agent architecture details, and tool specifications. callbacks: Optional list of callback handlers for monitoring execution, tracing messages, or collecting custom metrics during the benchmark run. n_task_repeats: Number of times to repeat each task. Useful for measuring variance in @@ -141,6 +148,9 @@ def __init__( For simple benchmarks, the default (1) means agents run once per task. For interactive benchmarks with user feedback loops, set higher (e.g., 5 for MACS) to allow multiple agent-user interaction rounds. + num_workers: Number of parallel task executions. Default 1 (sequential). + Set higher for I/O-bound workloads (e.g., LLM API calls). This controls the + ThreadPoolExecutor worker count for concurrent task processing. fail_on_setup_error: If True, raise exceptions when setup fails (environment, agents, evaluators). If False (default), catch exceptions during setup and record them in the report with status SETUP_FAILED. This allows the benchmark to continue running remaining tasks even if setup fails. 
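To make the new constructor and `run()` signatures concrete, here is a minimal usage sketch. `MyBenchmark` is a stand-in for any concrete `Benchmark` subclass, and the task query is illustrative; only the keyword arguments shown (`num_workers`, `fail_on_task_error`, `progress_bar`, `protocol`, `agent_data`) come from this diff:

```python
from maseval import Task, TaskProtocol, TaskQueue, TimeoutAction

# Per-task execution protocol: 5-minute budget, skip the repetition on timeout.
tasks = TaskQueue(
    [
        Task(
            query="Plan a three-day trip to Lisbon.",  # illustrative query
            protocol=TaskProtocol(
                timeout_seconds=300.0,
                timeout_action=TimeoutAction.SKIP,
            ),
        )
    ]
)

# MyBenchmark stands in for any concrete Benchmark subclass.
benchmark = MyBenchmark(
    num_workers=4,             # up to 4 tasks in flight via ThreadPoolExecutor
    fail_on_task_error=False,  # record task failures in reports instead of raising
    progress_bar="rich",       # RichProgressBarCallback instead of tqdm
)

# Agent configuration is passed to run(), not to the constructor.
reports = benchmark.run(
    tasks=tasks,
    agent_data={"model": "gpt-4", "temperature": 0.7},
)

# Timed-out repetitions are reported, not raised, under the default flags.
timeouts = [r for r in reports if r["status"] == "task_timeout"]
```

With the default failure-handling flags, a repetition that exceeds its budget surfaces in the reports with status `task_timeout` instead of aborting the run.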
@@ -163,54 +173,34 @@ def __init__( ValueError: If n_task_repeats is less than 1. How to use: - Provide either a single agent configuration for all tasks, or task-specific configurations: + Configure execution settings at initialization: ```python - # Single config for all tasks - benchmark = MyBenchmark(agent_data={"model": "gpt-4", "temperature": 0.7}) - - # Task-specific configs (will be validated in run() based on task count) - benchmark = MyBenchmark( - agent_data=[ - {"model": "gpt-4", "config": "easy"}, - {"model": "gpt-4", "config": "hard"} - ] - ) + # Sequential execution (default) + benchmark = MyBenchmark() - # Enable failure-safe execution (default behavior) - benchmark = MyBenchmark( - agent_data=config, - fail_on_task_error=False, # Continue on task failures - fail_on_evaluation_error=False # Continue on evaluation failures - ) + # Parallel execution for faster I/O-bound workloads + benchmark = MyBenchmark(num_workers=4) # Strict mode - fail fast on any error (useful for debugging) benchmark = MyBenchmark( - agent_data=config, fail_on_task_error=True, fail_on_evaluation_error=True, fail_on_setup_error=True ) - # Progress bar configuration (automatically adds a callback) - benchmark = MyBenchmark(agent_data=config) # Default: adds TqdmProgressBarCallback - benchmark = MyBenchmark(agent_data=config, progress_bar=True) # Explicit: TqdmProgressBarCallback - benchmark = MyBenchmark(agent_data=config, progress_bar="tqdm") # Same as True - benchmark = MyBenchmark(agent_data=config, progress_bar="rich") # Uses RichProgressBarCallback - benchmark = MyBenchmark(agent_data=config, progress_bar=False) # No automatic callback + # Progress bar configuration + benchmark = MyBenchmark() # Default: adds TqdmProgressBarCallback + benchmark = MyBenchmark(progress_bar=True) # Explicit: TqdmProgressBarCallback + benchmark = MyBenchmark(progress_bar="rich") # Uses RichProgressBarCallback + benchmark = MyBenchmark(progress_bar=False) # No automatic callback - # Progress bar configuration (manually add a callback) - benchmark = MyBenchmark( - agent_data=config, - callbacks=[MyCustomProgressBarCallback()] # User-defined progress bar - ) + # Custom callbacks + benchmark = MyBenchmark(callbacks=[MyCustomProgressBarCallback()]) ``` """ - # Store agent_data, defaulting to empty dict if None - self.agent_data = agent_data if agent_data is not None else {} - - # Initialize tasks to empty collection (will be set in run()) - self.tasks = TaskCollection([]) + # Initialize tasks to empty queue (will be set in run()) + self.tasks: BaseTaskQueue = SequentialTaskQueue([]) self.callback_handler = CallbackHandler() self.callbacks = callbacks or [] @@ -233,29 +223,29 @@ def __init__( if self.n_task_repeats < 1: raise ValueError("n_task_repeats must be at least 1") - # Execution loop configuration + # Execution configuration self.max_invocations = max_invocations + self.num_workers = num_workers # Failure handling configuration self.fail_on_task_error = fail_on_task_error self.fail_on_evaluation_error = fail_on_evaluation_error self.fail_on_setup_error = fail_on_setup_error - # Registry for Traceable components (cleared after each task repetition) - self._trace_registry: Dict[str, TraceableMixin] = {} - self._component_id_map: Dict[int, str] = {} # Maps id(component) -> registry key + # Gather benchmark-level configuration (system, git, packages, etc.) 
+ self.benchmark_level_config = gather_benchmark_config() + + # Thread-safe component registry (replaces inline registry dicts) + self._registry = ComponentRegistry(benchmark_config=self.benchmark_level_config) - # Registry for Configurable components (cleared after each task repetition) - self._config_registry: Dict[str, ConfigurableMixin] = {} - self._config_component_id_map: Dict[int, str] = {} # Maps id(component) -> registry key + # Thread safety locks for parallel execution + self._reports_lock = threading.Lock() + self._callback_lock = threading.Lock() # Persistent benchmark-level reports (stored across all task repetitions) # Each report contains both traces and configs for a single task repetition self.reports: List[Dict[str, Any]] = [] - # Gather benchmark-level configuration (system, git, packages, etc.) - self.benchmark_level_config = gather_benchmark_config() - def register(self, category: str, name: str, component: TraceableMixin) -> TraceableMixin: """Register a component for comprehensive trace and configuration collection. @@ -304,35 +294,7 @@ def setup_agents(self, agent_data, environment, task, user): `collect_all_traces()` and `collect_all_configs()` which are called internally by the `run()` method. """ - # Check if this component is already registered for traces - component_id = id(component) - if component_id in self._component_id_map: - existing_key = self._component_id_map[component_id] - existing_category, existing_name = existing_key.split(":", 1) - new_key = f"{category}:{name}" - - if existing_key == new_key: - # Same component, same name - silently accept (idempotent) - return component - else: - raise ValueError( - f"Component is already registered as '{existing_key}' and cannot be " - f"re-registered as '{new_key}'. Note: Environments, users, and agents " - f"returned from setup methods are automatically registered." - ) - - key = f"{category}:{name}" - - # Register for trace collection - self._trace_registry[key] = component - self._component_id_map[component_id] = key - - # Also register for configuration collection if component supports it - if isinstance(component, ConfigurableMixin): - self._config_registry[key] = component - self._config_component_id_map[component_id] = key - - return component + return self._registry.register(category, name, component) def clear_registry(self) -> None: """Clear the component registry after a task repetition completes. @@ -341,10 +303,7 @@ def clear_registry(self) -> None: to ensure components are not carried over between repetitions. The reports list persists across all repetitions for aggregated analysis. """ - self._trace_registry.clear() - self._component_id_map.clear() - self._config_registry.clear() - self._config_component_id_map.clear() + self._registry.clear() def collect_all_traces(self) -> Dict[str, Any]: """Collect execution traces from all registered components for the current task repetition. @@ -389,61 +348,7 @@ def collect_all_traces(self) -> Dict[str, Any]: The collected traces are passed to the evaluator's `evaluate()` method and stored in `benchmark.reports` for later analysis. 
""" - traces: Dict[str, Any] = { - "metadata": { - "timestamp": datetime.now().isoformat(), - "thread_id": threading.current_thread().ident, - "total_components": len(self._trace_registry), - }, - "agents": {}, - "models": {}, - "tools": {}, - "simulators": {}, - "callbacks": {}, - "environment": None, - "user": None, - "other": {}, - } - - for key, component in self._trace_registry.items(): - category, comp_name = key.split(":", 1) - - try: - component_traces = component.gather_traces() - - # Inject name from registry if component doesn't have it - # Is this intervention obfuscating the mechnaisms too much? - if "name" not in component_traces: - component_traces["name"] = comp_name - - # Handle environment and user as direct values (not nested in dict) - if category == "environment": - traces["environment"] = component_traces - elif category == "user": - traces["user"] = component_traces - else: - # Ensure category exists in traces - if category not in traces: - traces[category] = {} - traces[category][comp_name] = component_traces - except Exception as e: - # Gracefully handle tracing errors - error_info = { - "error": f"Failed to gather traces: {e}", - "error_type": type(e).__name__, - "component_type": type(component).__name__, - } - - if category == "environment": - traces["environment"] = error_info - elif category == "user": - traces["user"] = error_info - else: - if category not in traces: - traces[category] = {} - traces[category][comp_name] = error_info - - return traces + return self._registry.collect_traces() def collect_all_configs(self) -> Dict[str, Any]: """Collect configuration from all registered components for the current task repetition. @@ -490,62 +395,62 @@ def collect_all_configs(self) -> Dict[str, Any]: The collected configs are available in the results for reproducibility analysis. """ - configs: Dict[str, Any] = { - "metadata": { - "timestamp": datetime.now().isoformat(), - "thread_id": threading.current_thread().ident, - "total_components": len(self._config_registry), - }, - "agents": {}, - "models": {}, - "tools": {}, - "simulators": {}, - "callbacks": {}, - "environment": None, - "user": None, - "other": {}, - "benchmark": self.benchmark_level_config, # Include benchmark-level config - } + return self._registry.collect_configs() - for key, component in self._config_registry.items(): - category, comp_name = key.split(":", 1) + def _invoke_callbacks(self, method_name: str, *args, suppress_errors: bool = True, **kwargs) -> List[Exception]: + """Invoke a callback method on all registered callbacks (thread-safe). - try: - component_config = component.gather_config() - - # Inject name from registry if component doesn't have it - # Is this intervention obfuscating the mechnaisms too much? - if "name" not in component_config: - component_config["name"] = comp_name - - # Handle environment and user as direct values (not nested in dict) - if category == "environment": - configs["environment"] = component_config - elif category == "user": - configs["user"] = component_config - else: - # Ensure category exists in configs - if category not in configs: - configs[category] = {} - configs[category][comp_name] = component_config - except Exception as e: - # Gracefully handle config gathering errors - error_info = { - "error": f"Failed to gather config: {e}", - "error_type": type(e).__name__, - "component_type": type(component).__name__, - } + This method serializes all callback invocations using an internal lock, + so users don't need to implement thread-safe callbacks. 
- if category == "environment": - configs["environment"] = error_info - elif category == "user": - configs["user"] = error_info - else: - if category not in configs: - configs[category] = {} - configs[category][comp_name] = error_info + Callback errors are caught and logged by default to prevent one failing + callback from disrupting the entire benchmark run. This is especially + important in parallel execution where callback failures could otherwise + cause difficult-to-debug issues. + + Args: + method_name: Name of the callback method to invoke (e.g., "on_task_start"). + *args: Positional arguments to pass to the callback method. + suppress_errors: If True (default), catch and log callback errors instead + of propagating them. If False, first callback error will be raised. + **kwargs: Keyword arguments to pass to the callback method. - return configs + Returns: + List of exceptions that occurred during callback invocation (empty if none). + + Raises: + Exception: First callback exception if suppress_errors=False. + """ + errors: List[Exception] = [] + logger = logging.getLogger(__name__) + + with self._callback_lock: + for cb in self.callbacks: + method = getattr(cb, method_name, None) + if method is not None: + try: + method(*args, **kwargs) + except Exception as e: + if not suppress_errors: + raise + + # Log error with full context + logger.error( + f"Callback {cb.__class__.__name__}.{method_name}() failed: {e}", + exc_info=True, + ) + errors.append(e) + + return errors + + def _append_report_safe(self, report: Dict[str, Any]) -> None: + """Append a report to the reports list (thread-safe). + + Args: + report: The report dictionary to append. + """ + with self._reports_lock: + self.reports.append(report) def add_callback(self, callback: BenchmarkCallback) -> None: """Register a callback handler to monitor benchmark execution. @@ -992,12 +897,427 @@ def __init__(self, ...): return final_answer - def run(self, tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]]) -> List[Dict[str, Any]]: + def _execute_task_repetition( + self, + task: Task, + agent_data: Dict[str, Any], + repeat_idx: int, + ) -> Dict[str, Any]: + """Execute a single task repetition with timeout handling. + + This method encapsulates the complete execution of one task repetition, + including setup, execution, trace collection, and evaluation. It is + designed to be called from both sequential and parallel execution paths. + + Args: + task: The task to execute. + agent_data: Agent configuration for this task. + repeat_idx: Repetition index (0 to n_task_repeats-1). + + Returns: + Report dictionary containing execution results. + """ + # Initialize status and error tracking + execution_status = TaskExecutionStatus.SUCCESS + error_info: Optional[Dict[str, Any]] = None + final_answers: Any = None + eval_results: Any = None + execution_traces: Dict[str, Any] = {} + execution_configs: Dict[str, Any] = {} + evaluators: Sequence[Evaluator] = [] + agents_dict: Dict[str, AgentAdapter] = {} + + # Create execution context with optional timeout + timeout = task.protocol.timeout_seconds + context = TaskContext(deadline=timeout) + + try: + # 1. Setup + environment = self.setup_environment(agent_data, task) + user = self.setup_user(agent_data, environment, task) + if user is None and self.max_invocations > 1: + # Warn if multi-turn is enabled but no user to drive interaction + warnings.warn( + f"max_invocations={self.max_invocations} > 1 but no user simulator provided. 
" + f"Falling back to single-turn execution for task {task.id}." + ) + agents_to_run, agents_dict = self.setup_agents(agent_data, environment, task, user) + evaluators = self.setup_evaluators(environment, task, agents_to_run, user) + + # Auto-register components returned from setup methods + # Environment + if environment is not None and isinstance(environment, TraceableMixin): + self.register("environment", "env", environment) + + # User + if user is not None and isinstance(user, TraceableMixin): + self.register("user", "user", user) + + # Agents (use their names from agents_dict) + for agent_name, agent in agents_dict.items(): + if isinstance(agent, TraceableMixin): + self.register("agents", agent_name, agent) + + # Check timeout after setup + context.check_timeout() + + except TaskTimeoutError as e: + # Timeout during setup + execution_status = TaskExecutionStatus.TASK_TIMEOUT + error_info = { + "error_type": "TaskTimeoutError", + "error_message": str(e), + "elapsed": e.elapsed, + "timeout": e.timeout, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + # Create a minimal report for this timeout + report = { + "task_id": str(task.id), + "repeat_idx": repeat_idx, + "status": execution_status.value, + "error": error_info, + "traces": e.partial_traces, + "config": {}, + "eval": None, + } + self.clear_registry() + return report + + except Exception as e: + # Setup failed - record error + execution_status = TaskExecutionStatus.SETUP_FAILED + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + # Create a minimal report for this failed setup + report = { + "task_id": str(task.id), + "repeat_idx": repeat_idx, + "status": execution_status.value, + "error": error_info, + "traces": {}, + "config": {}, + "eval": None, + } + self.clear_registry() + + if self.fail_on_setup_error: + raise + + return report + + # 2. 
Execute agent system with optional user interaction loop + try: + # Check timeout before execution + context.check_timeout() + final_answers = self.execution_loop(agents_to_run, task, environment, user) + except TaskTimeoutError as e: + # Task timed out during execution + execution_status = TaskExecutionStatus.TASK_TIMEOUT + error_info = { + "error_type": "TaskTimeoutError", + "error_message": str(e), + "elapsed": e.elapsed, + "timeout": e.timeout, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + final_answers = None + except AgentError as e: + # Agent violated contract at boundary (agent's fault) + execution_status = TaskExecutionStatus.AGENT_ERROR + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "component": e.component, + "details": e.details, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if self.fail_on_task_error: + self.clear_registry() + raise + + final_answers = None + except EnvironmentError as e: + # Environment/tool infrastructure failed (not agent's fault) + execution_status = TaskExecutionStatus.ENVIRONMENT_ERROR + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "component": e.component, + "details": e.details, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if self.fail_on_task_error: + self.clear_registry() + raise + + final_answers = None + except UserError as e: + # User simulator failed (not agent's fault) + execution_status = TaskExecutionStatus.USER_ERROR + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "component": e.component, + "details": e.details, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if self.fail_on_task_error: + self.clear_registry() + raise + + final_answers = None + except Exception as e: + # Unclassified error (e.g., agent framework internal failure) + execution_status = TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if self.fail_on_task_error: + self.clear_registry() + raise + + final_answers = None + + # 3. Collect traces and configs (always attempt this) + try: + execution_configs = self.collect_all_configs() + execution_traces = self.collect_all_traces() + # Store in context for potential timeout errors + context.set_collected_traces(execution_traces) + except Exception as e: + # If trace/config collection fails, record it but continue + execution_configs = { + "error": f"Failed to collect configs: {e}", + "error_type": type(e).__name__, + } + execution_traces = { + "error": f"Failed to collect traces: {e}", + "error_type": type(e).__name__, + } + + # 4. 
Evaluate (skip if task execution failed) + if execution_status == TaskExecutionStatus.SUCCESS: + try: + # Check timeout before evaluation + context.check_timeout() + eval_results = self.evaluate(evaluators, agents_dict, final_answers, execution_traces) + except TaskTimeoutError as e: + execution_status = TaskExecutionStatus.TASK_TIMEOUT + error_info = { + "error_type": "TaskTimeoutError", + "error_message": str(e), + "elapsed": e.elapsed, + "timeout": e.timeout, + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + eval_results = None + except Exception as e: + execution_status = TaskExecutionStatus.EVALUATION_FAILED + error_info = { + "error_type": type(e).__name__, + "error_message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if self.fail_on_evaluation_error: + self.clear_registry() + raise + + eval_results = None + else: + # Task execution failed, so skip evaluation + eval_results = None + + # 5. Build report + report: Dict[str, Any] = { + "task_id": str(task.id), + "repeat_idx": repeat_idx, + "status": execution_status.value, + "traces": execution_traces, + "config": execution_configs, + "eval": eval_results, + } + + # Add error info if present + if error_info is not None: + report["error"] = error_info + + # Clear registry after task repetition completes + self.clear_registry() + + return report + + def _run_sequential( + self, + queue: BaseTaskQueue, + agent_data_lookup: Dict[str, Dict[str, Any]], + ) -> None: + """Execute tasks sequentially with optional timeout support. + + Args: + queue: Task queue providing task ordering. + agent_data_lookup: Mapping from task_id to agent_data configuration. + """ + for task in queue: + agent_data = agent_data_lookup[str(task.id)] + + # Callbacks at the start of each task + self._invoke_callbacks("on_task_start", self, task) + + for repeat_idx in range(self.n_task_repeats): + self._invoke_callbacks("on_task_repeat_start", self, task, repeat_idx) + + report = self._execute_task_repetition(task, agent_data, repeat_idx) + self._append_report_safe(report) + + self._invoke_callbacks("on_task_repeat_end", self, report) + + # Callbacks at the end of each task + task_reports = [r for r in self.reports if r["task_id"] == str(task.id)] + last_report = task_reports[-1] if task_reports else {} + self._invoke_callbacks("on_task_end", self, task, last_report) + + def _run_parallel( + self, + queue: BaseTaskQueue, + agent_data_lookup: Dict[str, Dict[str, Any]], + num_workers: int, + ) -> None: + """Execute tasks in parallel with thread pool. + + Args: + queue: Task queue providing task ordering. + agent_data_lookup: Mapping from task_id to agent_data configuration. + num_workers: Number of concurrent workers. 
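For illustration, a short sketch of how the per-repetition reports assembled above are typically consumed. `MyBenchmark` and `tasks` are placeholders for a concrete `Benchmark` subclass and a loaded task queue; only the report keys (`task_id`, `repeat_idx`, `status`, `traces`, `config`, `eval`, plus `error` on failure) and the `num_workers` switch between `_run_sequential` and `_run_parallel` are taken from the code in this patch.

```python
from collections import Counter

# Hypothetical concrete subclass and task source; adjust to your own setup.
benchmark = MyBenchmark(num_workers=4)   # more than 1 worker routes through _run_parallel
reports = benchmark.run(tasks=tasks, agent_data={"model": "gpt-4"})

# Every task repetition produces exactly one report with a status string.
print(Counter(r["status"] for r in reports))

# Failed repetitions additionally carry structured error information.
for r in reports:
    if "error" in r:
        print(r["task_id"], r["status"], r["error"]["error_type"])
```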
+ """ + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures: Dict[Any, Tuple[Task, int]] = {} + task_repeat_counts: Dict[str, int] = {} # Track submitted repeats per task + + def submit_task_repeats(task: Task) -> None: + """Submit all repeats for a task.""" + task_id = str(task.id) + task_repeat_counts[task_id] = 0 + agent_data = agent_data_lookup[task_id] + + self._invoke_callbacks("on_task_start", self, task) + + for repeat_idx in range(self.n_task_repeats): + self._invoke_callbacks("on_task_repeat_start", self, task, repeat_idx) + + future = executor.submit( + self._execute_task_repetition, + task, + agent_data, + repeat_idx, + ) + futures[future] = (task, repeat_idx) + task_repeat_counts[task_id] += 1 + + # Submit initial batch from queue + submitted_tasks: List[Task] = [] + queue_iter = iter(queue) # Create iterator once + queue_exhausted = False + + # Submit initial batch + try: + while len(futures) < num_workers * 2: + task = next(queue_iter) + submit_task_repeats(task) + submitted_tasks.append(task) + except StopIteration: + queue_exhausted = True + + # Process completions + completed_task_ids: set = set() + + while futures: + # Wait for at least one completion + done_futures = [] + for future in list(futures.keys()): + if future.done(): + done_futures.append(future) + + if not done_futures: + # No futures done yet, wait a bit + import time + + time.sleep(0.01) + continue + + for future in done_futures: + task, repeat_idx = futures.pop(future) + task_id = str(task.id) + + try: + report = future.result() + except Exception as e: + # Create error report for unexpected failures + report = { + "task_id": task_id, + "repeat_idx": repeat_idx, + "status": TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR.value, + "error": { + "error_type": type(e).__name__, + "error_message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + }, + "traces": {}, + "config": {}, + "eval": None, + } + + self._append_report_safe(report) + + self._invoke_callbacks("on_task_repeat_end", self, report) + + # Check if all repeats for this task are done + task_reports = [r for r in self.reports if r["task_id"] == task_id] + if len(task_reports) >= self.n_task_repeats and task_id not in completed_task_ids: + completed_task_ids.add(task_id) + last_report = task_reports[-1] if task_reports else {} + self._invoke_callbacks("on_task_end", self, task, last_report) + + # Submit more work if queue not exhausted + if not queue_exhausted and len(futures) < num_workers: + try: + task = next(queue_iter) + submit_task_repeats(task) + submitted_tasks.append(task) + except StopIteration: + queue_exhausted = True + + def run( + self, + tasks: Union[Task, BaseTaskQueue, Iterable[Union[Task, dict]]], + agent_data: Dict[str, Any] | Iterable[Dict[str, Any]], + ) -> List[Dict[str, Any]]: """Initialize and execute the complete benchmark loop across all tasks. Args: - tasks: Collection of tasks to execute. Can be a single Task, TaskCollection, - list of Task objects, or list of dicts that will be converted to Tasks. + tasks: Task source for execution. Can be: + - A single Task object + - A BaseTaskQueue (SequentialTaskQueue, PriorityTaskQueue, or custom AdaptiveTaskQueue) + - An iterable of Task objects or dicts that will be converted to Tasks + + When a BaseTaskQueue is provided, it controls the task ordering. AdaptiveTaskQueue + subclasses are automatically registered as callbacks to receive task completion + notifications. + agent_data: Configuration for agents. 
Either a single dict applied to all tasks, or + an iterable of dicts with one configuration per task. Agent data typically includes + model parameters, agent architecture details, and tool specifications. Returns: List of report dictionaries, one per task repetition. Each report contains: @@ -1066,279 +1386,123 @@ def run(self, tasks: Union[Task, TaskCollection, Iterable[Union[Task, dict]]]) - ```python # Typical usage - benchmark = MyBenchmark(agent_data=config) - reports = benchmark.run(tasks=tasks) + benchmark = MyBenchmark() + reports = benchmark.run(tasks=tasks, agent_data=config) # Analyze results for report in reports: print(f"Task {report['task_id']}, Repeat {report['repeat_idx']}: {report['eval']}") print(f"Config: {report['config']}") print(f"Traces: {report['traces']}") + + # Parallel execution with 4 workers + benchmark = MyBenchmark(num_workers=4) + reports = benchmark.run(tasks=tasks, agent_data=config) + + # Single agent config for all tasks + reports = benchmark.run(tasks=tasks, agent_data={"model": "gpt-4"}) + + # Task-specific agent configs (must match task count) + reports = benchmark.run( + tasks=tasks, + agent_data=[ + {"model": "gpt-4", "difficulty": "easy"}, + {"model": "gpt-4", "difficulty": "hard"}, + ] + ) + + # Priority-based execution + from maseval.core.task import PriorityTaskQueue + for task in tasks: + task.protocol.priority = compute_priority(task) + queue = PriorityTaskQueue(tasks) + reports = benchmark.run(tasks=queue, agent_data=config) + + # Adaptive queue (auto-registered as callback) + queue = MyAdaptiveTaskQueue(tasks) + reports = benchmark.run(tasks=queue) # queue receives on_task_complete callbacks ``` """ - # Normalize tasks into a TaskCollection + # Normalize tasks into a queue + queue: BaseTaskQueue if isinstance(tasks, Task): - # Single task - self.tasks = TaskCollection([tasks]) - elif isinstance(tasks, TaskCollection): - self.tasks = tasks + # Single task - wrap in SequentialTaskQueue + queue = SequentialTaskQueue([tasks]) + elif isinstance(tasks, BaseTaskQueue): + # Already a queue - use directly + queue = tasks else: - # Iterable of tasks or dicts - self.tasks = TaskCollection.from_list(list(tasks)) + # Iterable of tasks or dicts - wrap in SequentialTaskQueue + queue = SequentialTaskQueue.from_list(list(tasks)) - # Normalize agent_data into a list matching the number of tasks - if isinstance(self.agent_data, dict): - # Single config for all tasks - agent_data_list: List[Dict[str, Any]] = [cast(Dict[str, Any], self.agent_data) for _ in range(len(self.tasks))] - else: - # Task-specific configs - agent_data_list = list(self.agent_data) + # Store tasks reference for get_failed_tasks() compatibility + self.tasks = queue - if len(agent_data_list) != len(self.tasks): - raise ValueError( - f"`agent_data` must either be a single dict or an iterable matching the number of tasks. " - f"Got {len(agent_data_list)} agent configs for {len(self.tasks)} tasks." 
- ) + # Build agent_data lookup (task_id -> agent_data) + agent_data_lookup = self._build_agent_data_lookup(queue, agent_data) # Clear reports from previous run() calls to prevent accumulation self.reports = [] - # Callbacks at the start of the run - for cb in self.callbacks: - cb.on_run_start(self) - - for task_idx, (task, agent_data) in enumerate(zip(self.tasks, agent_data_list)): - # Callbacks at the start of each task - for cb in self.callbacks: - cb.on_task_start(self, task) - - for repeat_idx in range(self.n_task_repeats): - for cb in self.callbacks: - cb.on_task_repeat_start(self, task, repeat_idx) - - # Initialize status and error tracking - execution_status = TaskExecutionStatus.SUCCESS - error_info: Optional[Dict[str, Any]] = None - final_answers: Any = None - eval_results: Any = None - execution_traces: Dict[str, Any] = {} - execution_configs: Dict[str, Any] = {} - - try: - # 1. Setup - environment = self.setup_environment(agent_data, task) - user = self.setup_user(agent_data, environment, task) - if user is None and self.max_invocations > 1: - # Warn if multi-turn is enabled but no user to drive interaction - warnings.warn( - f"max_invocations={self.max_invocations} > 1 but no user simulator provided. " - f"Falling back to single-turn execution for task {task.id}." - ) - agents_to_run, agents_dict = self.setup_agents(agent_data, environment, task, user) - evaluators = self.setup_evaluators(environment, task, agents_to_run, user) - - # Auto-register components returned from setup methods - # Environment - if environment is not None and isinstance(environment, TraceableMixin): - self.register("environment", "env", environment) - - # User - if user is not None and isinstance(user, TraceableMixin): - self.register("user", "user", user) - - # Agents (use their names from agents_dict) - for agent_name, agent in agents_dict.items(): - if isinstance(agent, TraceableMixin): - self.register("agents", agent_name, agent) - - except Exception as e: - # Setup failed - record error and optionally re-raise - execution_status = TaskExecutionStatus.SETUP_FAILED - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - # Create a minimal report for this failed setup - report = { - "task_id": task.id, - "repeat_idx": repeat_idx, - "status": execution_status.value, - "error": error_info, - "traces": {}, - "config": {}, - "eval": None, - } - self.reports.append(report) - - for cb in self.callbacks: - cb.on_task_repeat_end(self, report) - - # Clear registry before potentially re-raising - self.clear_registry() + # Auto-register queue as callback if it's a BenchmarkCallback (e.g., AdaptiveTaskQueue) + queue_as_callback: Optional[BenchmarkCallback] = None + if isinstance(queue, BenchmarkCallback) and queue not in self.callbacks: + queue_as_callback = queue + self.callbacks.append(queue_as_callback) - if self.fail_on_setup_error: - raise + try: + # Callbacks at the start of the run + self._invoke_callbacks("on_run_start", self) - # Continue to next task repetition - continue - - # 2. 
Execute agent system with optional user interaction loop - try: - final_answers = self.execution_loop(agents_to_run, task, environment, user) - except AgentError as e: - # Agent violated contract at boundary (agent's fault) - execution_status = TaskExecutionStatus.AGENT_ERROR - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "component": e.component, - "details": e.details, - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - if self.fail_on_task_error: - # Clear registry before re-raising - self.clear_registry() - raise - - # Continue with trace collection even if task failed - final_answers = None - except EnvironmentError as e: - # Environment/tool infrastructure failed (not agent's fault) - execution_status = TaskExecutionStatus.ENVIRONMENT_ERROR - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "component": e.component, - "details": e.details, - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - if self.fail_on_task_error: - # Clear registry before re-raising - self.clear_registry() - raise - - # Continue with trace collection even if task failed - final_answers = None - except UserError as e: - # User simulator failed (not agent's fault) - execution_status = TaskExecutionStatus.USER_ERROR - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "component": e.component, - "details": e.details, - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - if self.fail_on_task_error: - # Clear registry before re-raising - self.clear_registry() - raise - - # Continue with trace collection even if task failed - final_answers = None - except Exception as e: - # Unclassified error (e.g., agent framework internal failure) - execution_status = TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } - - if self.fail_on_task_error: - # Clear registry before re-raising - self.clear_registry() - raise - - # Continue with trace collection even if task failed - final_answers = None - - # # Callbacks before evaluation - # for cb in self.callbacks: - # cb.on_before_evaluation(self, task, agent_output) - - # 3. Collect traces and configs (always attempt this) - try: - execution_configs = self.collect_all_configs() - execution_traces = self.collect_all_traces() - except Exception as e: - # If trace/config collection fails, record it but continue - execution_configs = { - "error": f"Failed to collect configs: {e}", - "error_type": type(e).__name__, - } - execution_traces = { - "error": f"Failed to collect traces: {e}", - "error_type": type(e).__name__, - } - - # 4. 
Evaluate (skip if task execution failed, unless we want partial evaluation) - if execution_status == TaskExecutionStatus.SUCCESS: - try: - eval_results = self.evaluate(evaluators, agents_dict, final_answers, execution_traces) - except Exception as e: - execution_status = TaskExecutionStatus.EVALUATION_FAILED - error_info = { - "error_type": type(e).__name__, - "error_message": str(e), - "traceback": "".join(__import__("traceback").format_exception(type(e), e, e.__traceback__)), - } + # Execute based on num_workers + if self.num_workers == 1: + self._run_sequential(queue, agent_data_lookup) + else: + self._run_parallel(queue, agent_data_lookup, self.num_workers) - if self.fail_on_evaluation_error: - # Clear registry before re-raising - self.clear_registry() - raise + # Callbacks at the end of the run + self._invoke_callbacks("on_run_end", self, self.reports) + finally: + # Remove queue from callbacks if we added it + if queue_as_callback is not None: + self.callbacks.remove(queue_as_callback) - # Set eval_results to None on failure - eval_results = None - else: - # Task execution failed, so skip evaluation - eval_results = None - - # 5. Store results with status and error info - report = { - "task_id": task.id, - "repeat_idx": repeat_idx, - "status": execution_status.value, - "traces": execution_traces, - "config": execution_configs, - "eval": eval_results, - } + return self.reports - # Add error info if present - if error_info is not None: - report["error"] = error_info + def _build_agent_data_lookup( + self, tasks: BaseTaskQueue, agent_data: Dict[str, Any] | Iterable[Dict[str, Any]] + ) -> Dict[str, Dict[str, Any]]: + """Build a mapping from task_id to agent_data configuration. - self.reports.append(report) + Args: + tasks: The task queue containing all tasks. + agent_data: Agent configuration(s) to map to tasks. - for cb in self.callbacks: - cb.on_task_repeat_end(self, report) + Returns: + Dict mapping task_id (string) to agent_data configuration. - # Clear registry after task repetition completes - self.clear_registry() + Raises: + ValueError: If agent_data is a list but doesn't match the number of tasks. + """ + if isinstance(agent_data, dict): + # Single config - replicate for all tasks + return {str(task.id): cast(Dict[str, Any], agent_data) for task in tasks} - # Callbacks at the end of each task - # Pass the last report for this task to the callback - task_reports = [r for r in self.reports if r["task_id"] == task.id] - last_report = task_reports[-1] if task_reports else {} - for cb in self.callbacks: - cb.on_task_end(self, task, last_report) + # List of configs - pair by position + agent_data_list = list(agent_data) + if len(agent_data_list) != len(tasks): + raise ValueError( + f"`agent_data` must either be a single dict or an iterable matching the number of tasks. " + f"Got {len(agent_data_list)} agent configs for {len(tasks)} tasks." + ) - # Callbacks at the end of the run - for cb in self.callbacks: - cb.on_run_end(self, self.reports) - return self.reports + return {str(task.id): agent_data_list[i] for i, task in enumerate(tasks)} def get_failed_tasks( self, status_filter: Optional[Union[TaskExecutionStatus, List[TaskExecutionStatus]]] = None, reports: Optional[List[Dict[str, Any]]] = None, - ) -> TaskCollection: + ) -> SequentialTaskQueue: """Get tasks that failed during benchmark execution. This method retrieves failed tasks based on their execution status, useful for @@ -1356,7 +1520,7 @@ def get_failed_tasks( run() call. 
This allows analyzing externally stored or modified reports. Returns: - TaskCollection containing the failed tasks. Empty if no failures match the filter. + SequentialTaskQueue containing the failed tasks. Empty if no failures match the filter. Raises: RuntimeError: If reports is None and run() has not been executed yet. @@ -1364,8 +1528,8 @@ def get_failed_tasks( How to use: ```python # Run benchmark - benchmark = MyBenchmark(agent_data=config) - reports = benchmark.run(tasks=tasks) + benchmark = MyBenchmark() + reports = benchmark.run(tasks=tasks, agent_data=config) # Get all failed tasks (from internal state) failed = benchmark.get_failed_tasks() @@ -1411,6 +1575,7 @@ def get_failed_tasks( TaskExecutionStatus.AGENT_ERROR.value, TaskExecutionStatus.ENVIRONMENT_ERROR.value, TaskExecutionStatus.USER_ERROR.value, + TaskExecutionStatus.TASK_TIMEOUT.value, TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR.value, TaskExecutionStatus.EVALUATION_FAILED.value, TaskExecutionStatus.SETUP_FAILED.value, @@ -1428,6 +1593,6 @@ def get_failed_tasks( if report["status"] in filter_values: failed_task_ids.add(report["task_id"]) - # Build TaskCollection from original tasks that failed - failed_tasks = [task for task in self.tasks if task.id in failed_task_ids] - return TaskCollection(failed_tasks) + # Build queue from original tasks that failed + failed_tasks = [task for task in self.tasks if str(task.id) in failed_task_ids] + return SequentialTaskQueue(failed_tasks) diff --git a/maseval/core/context.py b/maseval/core/context.py new file mode 100644 index 0000000..5265d40 --- /dev/null +++ b/maseval/core/context.py @@ -0,0 +1,122 @@ +"""Execution context for task timeout handling. + +This module provides the TaskContext class for cooperative timeout checking +during task execution. The context tracks elapsed time and enables checkpoints +where tasks can gracefully exit if the deadline has passed. +""" + +import time +from typing import Any, Dict, Optional + +from .exceptions import TaskTimeoutError + + +class TaskContext: + """Execution context for cooperative timeout checking. + + TaskContext provides a mechanism for tasks to voluntarily check for timeout + conditions at defined checkpoints. This enables clean interruption without + forcibly killing threads (which Python doesn't support well). + + The context is created with an optional deadline and tracks elapsed time. + Tasks should call `check_timeout()` at natural checkpoints (e.g., between + agent steps, after LLM calls) to allow graceful termination. + + Attributes: + deadline: Maximum execution time in seconds, or None for no timeout. + collected_traces: Traces collected before timeout (for partial results). + + Usage: + context = TaskContext(deadline=60.0) + + for step in range(max_steps): + context.check_timeout() # Raises TaskTimeoutError if expired + result = agent.run(query) + # ... process result + + # Access timing info + print(f"Elapsed: {context.elapsed}s") + print(f"Remaining: {context.remaining}s") + + Thread Safety: + TaskContext instances are not thread-safe. Each thread/task should + have its own context instance. + """ + + def __init__(self, deadline: Optional[float] = None): + """Initialize execution context. + + Args: + deadline: Maximum execution time in seconds. If None, no timeout + checking is performed and check_timeout() is a no-op. 
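For illustration, a self-contained example of the cooperative timeout pattern described above, using only `TaskContext` and `TaskTimeoutError` as introduced in this patch; `time.sleep` stands in for an agent step or LLM call.

```python
import time

from maseval.core.context import TaskContext
from maseval.core.exceptions import TaskTimeoutError

ctx = TaskContext(deadline=0.05)  # 50 ms budget for the whole "task"

try:
    for step in range(10):
        ctx.check_timeout()       # no-op until the deadline passes, then raises
        time.sleep(0.02)          # stand-in for an agent step / LLM call
except TaskTimeoutError as err:
    print(f"timed out after {err.elapsed:.2f}s (limit: {err.timeout}s)")

print(ctx.is_expired, ctx.remaining)  # -> True 0.0
```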
+ """ + self._deadline = deadline + self._start_time = time.monotonic() + self.collected_traces: Dict[str, Any] = {} + + @property + def deadline(self) -> Optional[float]: + """Maximum execution time in seconds, or None if no timeout.""" + return self._deadline + + @property + def elapsed(self) -> float: + """Time elapsed since context creation in seconds.""" + return time.monotonic() - self._start_time + + @property + def remaining(self) -> Optional[float]: + """Time remaining before deadline in seconds, or None if no deadline. + + Returns 0 if deadline has passed. + """ + if self._deadline is None: + return None + return max(0.0, self._deadline - self.elapsed) + + @property + def is_expired(self) -> bool: + """Whether the deadline has passed. + + Returns False if no deadline is set. + """ + if self._deadline is None: + return False + return self.elapsed >= self._deadline + + def check_timeout(self) -> None: + """Check if deadline has passed and raise TaskTimeoutError if so. + + This method should be called at natural checkpoints during task + execution (e.g., between agent steps, after LLM calls). If the + deadline has passed, it raises TaskTimeoutError with timing info. + + If no deadline is set, this is a no-op. + + Raises: + TaskTimeoutError: If the deadline has passed. + + Usage: + context = TaskContext(deadline=30.0) + + for step in range(max_steps): + context.check_timeout() # May raise + result = agent.run(query) + """ + if self.is_expired: + raise TaskTimeoutError( + f"Task exceeded {self._deadline}s deadline after {self.elapsed:.2f}s", + component="timeout_check", + elapsed=self.elapsed, + timeout=self._deadline or 0.0, + partial_traces=self.collected_traces, + ) + + def set_collected_traces(self, traces: Dict[str, Any]) -> None: + """Store traces collected during execution for inclusion in timeout errors. + + Args: + traces: Traces to store. These will be included in TaskTimeoutError + if a timeout occurs. + """ + self.collected_traces = traces diff --git a/maseval/core/exceptions.py b/maseval/core/exceptions.py index 77fc8c0..d4f6ba1 100644 --- a/maseval/core/exceptions.py +++ b/maseval/core/exceptions.py @@ -216,6 +216,70 @@ class UserError(MASEvalError): pass +class TaskTimeoutError(MASEvalError): + """Task execution exceeded configured timeout. + + This is classified as TASK_TIMEOUT in benchmark results, separate from + other error types. Timeout is neither agent's fault nor infrastructure's + fault—it's a resource constraint. + + When to raise: + - Task execution time exceeds TaskProtocol.timeout_seconds + - Cooperative timeout check detects deadline has passed + - Hard timeout backstop triggers + + Attributes: + elapsed: Time elapsed before timeout was detected. + timeout: The configured timeout value in seconds. + partial_traces: Any traces collected before timeout occurred. + + Examples: + ```python + # Cooperative timeout at checkpoint + raise TaskTimeoutError( + "Task exceeded 60s deadline", + component="execution_loop", + elapsed=62.5, + timeout=60.0 + ) + + # Hard timeout with partial traces + raise TaskTimeoutError( + "Task exceeded 120s hard deadline", + component="timeout_backstop", + elapsed=125.0, + timeout=120.0, + partial_traces={"agents": {"main": {"messages": [...]}}} + ) + ``` + """ + + def __init__( + self, + message: str, + *, + component: Optional[str] = None, + details: Optional[Dict[str, Any]] = None, + elapsed: float = 0.0, + timeout: float = 0.0, + partial_traces: Optional[Dict[str, Any]] = None, + ): + """Initialize TaskTimeoutError. 
+ + Args: + message: Human-readable error description. + component: Name of the component that raised the error. + details: Additional structured information about the error. + elapsed: Time elapsed before timeout was detected. + timeout: The configured timeout value in seconds. + partial_traces: Any traces collected before timeout occurred. + """ + super().__init__(message, component=component, details=details) + self.elapsed = elapsed + self.timeout = timeout + self.partial_traces = partial_traces or {} + + # ============================================================================= # Convenience functions for tool implementers # ============================================================================= diff --git a/maseval/core/registry.py b/maseval/core/registry.py new file mode 100644 index 0000000..ec4e38a --- /dev/null +++ b/maseval/core/registry.py @@ -0,0 +1,243 @@ +"""Thread-safe component registry for task execution. + +This module provides the ComponentRegistry class that tracks components +(agents, models, tools, etc.) during task execution. It uses thread-local +storage to enable parallel task execution without cross-contamination. +""" + +import threading +from typing import Any, Dict, Optional +from datetime import datetime + +from .tracing import TraceableMixin +from .config import ConfigurableMixin + + +class ComponentRegistry: + """Thread-safe registry for tracking components during task execution. + + Each thread gets its own isolated registry state, enabling parallel + task execution without cross-contamination. The registry tracks both + Traceable and Configurable components for comprehensive data collection. + + Usage: + registry = ComponentRegistry() + + # Register components (thread-local) + registry.register("agents", "orchestrator", agent_adapter) + registry.register("environment", "env", environment) + + # Collect data + traces = registry.collect_traces() + configs = registry.collect_configs() + + # Clear for next task + registry.clear() + """ + + def __init__(self, benchmark_config: Optional[Dict[str, Any]] = None): + """Initialize the registry. + + Args: + benchmark_config: Benchmark-level configuration to include in + collect_configs() output. This is shared (not thread-local). + """ + self._local = threading.local() + self._benchmark_config = benchmark_config or {} + + # --- Thread-local state properties --- + + @property + def _trace_registry(self) -> Dict[str, TraceableMixin]: + if not hasattr(self._local, "trace_registry"): + self._local.trace_registry = {} + return self._local.trace_registry + + @property + def _component_id_map(self) -> Dict[int, str]: + if not hasattr(self._local, "component_id_map"): + self._local.component_id_map = {} + return self._local.component_id_map + + @property + def _config_registry(self) -> Dict[str, ConfigurableMixin]: + if not hasattr(self._local, "config_registry"): + self._local.config_registry = {} + return self._local.config_registry + + @property + def _config_component_id_map(self) -> Dict[int, str]: + if not hasattr(self._local, "config_component_id_map"): + self._local.config_component_id_map = {} + return self._local.config_component_id_map + + # --- Public API --- + + def register(self, category: str, name: str, component: TraceableMixin) -> TraceableMixin: + """Register a component for trace and config collection. 
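A hedged sketch of the registration flow described above. `EchoAgent` is a made-up stand-in: it assumes `TraceableMixin` can be subclassed with a no-argument constructor and that `gather_traces()` is the hook the registry calls (as in the `collect_traces` code); adapt it to the real mixin's interface.

```python
from maseval.core.registry import ComponentRegistry
from maseval.core.tracing import TraceableMixin


class EchoAgent(TraceableMixin):
    """Minimal, hypothetical stand-in that reports a fixed trace payload."""

    def gather_traces(self):
        return {"messages": ["hello"]}


registry = ComponentRegistry(benchmark_config={"n_task_repeats": 1})
agent = EchoAgent()

registry.register("agents", "echo", agent)
registry.register("agents", "echo", agent)   # same key again: idempotent, no error

traces = registry.collect_traces()
print(traces["agents"]["echo"])              # {'messages': ['hello'], 'name': 'echo'}

registry.clear()                             # reset thread-local state for the next task
```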
+ + Args: + category: Component category (e.g., "agents", "models", "environment") + name: Unique identifier within the category + component: Component instance (must be TraceableMixin) + + Returns: + The component (for chaining) + + Raises: + ValueError: If component already registered under a different key + """ + component_id = id(component) + key = f"{category}:{name}" + + # Check for duplicate registration under different key + if component_id in self._component_id_map: + existing_key = self._component_id_map[component_id] + if existing_key != key: + raise ValueError( + f"Component is already registered as '{existing_key}' and cannot be " + f"re-registered as '{key}'. Note: Environments, users, and agents " + f"returned from setup methods are automatically registered." + ) + return component # Idempotent + + # Register for tracing + self._trace_registry[key] = component + self._component_id_map[component_id] = key + + # Also register for config if supported + if isinstance(component, ConfigurableMixin): + self._config_registry[key] = component + self._config_component_id_map[component_id] = key + + return component + + def clear(self) -> None: + """Clear all registrations for the current thread.""" + self._trace_registry.clear() + self._component_id_map.clear() + self._config_registry.clear() + self._config_component_id_map.clear() + + def collect_traces(self) -> Dict[str, Any]: + """Collect execution traces from all registered components.""" + traces: Dict[str, Any] = { + "metadata": { + "timestamp": datetime.now().isoformat(), + "thread_id": threading.current_thread().ident, + "total_components": len(self._trace_registry), + }, + "agents": {}, + "models": {}, + "tools": {}, + "simulators": {}, + "callbacks": {}, + "environment": None, + "user": None, + "other": {}, + } + + for key, component in self._trace_registry.items(): + category, comp_name = key.split(":", 1) + + try: + component_traces = component.gather_traces() + + # Inject name from registry if component doesn't have it + if "name" not in component_traces: + component_traces["name"] = comp_name + + # Handle environment and user as direct values (not nested in dict) + if category == "environment": + traces["environment"] = component_traces + elif category == "user": + traces["user"] = component_traces + else: + # Ensure category exists in traces + if category not in traces: + traces[category] = {} + traces[category][comp_name] = component_traces + except Exception as e: + # Gracefully handle tracing errors + error_info = { + "error": f"Failed to gather traces: {e}", + "error_type": type(e).__name__, + "component_type": type(component).__name__, + } + + if category == "environment": + traces["environment"] = error_info + elif category == "user": + traces["user"] = error_info + else: + if category not in traces: + traces[category] = {} + traces[category][comp_name] = error_info + + return traces + + def collect_configs(self) -> Dict[str, Any]: + """Collect configuration from all registered components.""" + configs: Dict[str, Any] = { + "metadata": { + "timestamp": datetime.now().isoformat(), + "thread_id": threading.current_thread().ident, + "total_components": len(self._config_registry), + }, + "agents": {}, + "models": {}, + "tools": {}, + "simulators": {}, + "callbacks": {}, + "environment": None, + "user": None, + "other": {}, + "benchmark": self._benchmark_config, + } + + for key, component in self._config_registry.items(): + category, comp_name = key.split(":", 1) + + try: + component_config = component.gather_config() 
+ + # Inject name from registry if component doesn't have it + if "name" not in component_config: + component_config["name"] = comp_name + + # Handle environment and user as direct values (not nested in dict) + if category == "environment": + configs["environment"] = component_config + elif category == "user": + configs["user"] = component_config + else: + # Ensure category exists in configs + if category not in configs: + configs[category] = {} + configs[category][comp_name] = component_config + except Exception as e: + # Gracefully handle config gathering errors + error_info = { + "error": f"Failed to gather config: {e}", + "error_type": type(e).__name__, + "component_type": type(component).__name__, + } + + if category == "environment": + configs["environment"] = error_info + elif category == "user": + configs["user"] = error_info + else: + if category not in configs: + configs[category] = {} + configs[category][comp_name] = error_info + + return configs + + def update_benchmark_config(self, benchmark_config: Dict[str, Any]) -> None: + """Update the benchmark-level configuration. + + Args: + benchmark_config: New benchmark configuration dict. + """ + self._benchmark_config = benchmark_config diff --git a/maseval/core/simulator.py b/maseval/core/simulator.py index 7f9669b..301022b 100644 --- a/maseval/core/simulator.py +++ b/maseval/core/simulator.py @@ -491,7 +491,7 @@ def _create_error( component="user_simulator", ) - def __call__( + def __call__( # ty: ignore[invalid-method-override] self, conversation_history: List[Dict[str, str]], generation_params: Optional[Dict[str, Any]] = None, diff --git a/maseval/core/task.py b/maseval/core/task.py index a4fc099..d0548af 100644 --- a/maseval/core/task.py +++ b/maseval/core/task.py @@ -1,10 +1,49 @@ +from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Any, Dict, overload +from typing import Any, Dict, Tuple, overload, TYPE_CHECKING from uuid import uuid4 from collections.abc import Sequence from typing import Iterable, List, Union, Iterator, Optional import json from pathlib import Path +from enum import Enum + +if TYPE_CHECKING: + from .benchmark import Benchmark + +# Import BenchmarkCallback at runtime to enable inheritance +from .callback import BenchmarkCallback + + +class TimeoutAction(Enum): + """Action to take when a task timeout occurs.""" + + SKIP = "skip" # Mark as timed out, continue to next task + RETRY = "retry" # Retry once with same timeout + EXTEND = "extend" # Double timeout and retry + + +@dataclass +class TaskProtocol: + """Configuration for how MASEval executes a task. + + This is a data container for execution parameters, separate from + task content (query, environment_data, etc.). It controls the + interface between the task and MASEval's execution engine. + + Attributes: + timeout_seconds: Maximum execution time for this task. None means no timeout. + timeout_action: Action to take when timeout occurs. + max_retries: Maximum retry attempts for transient failures (not timeouts). + priority: Execution priority (higher = sooner). Used by adaptive task queues. + tags: Arbitrary tags for filtering or grouping tasks. + """ + + timeout_seconds: Optional[float] = None + timeout_action: TimeoutAction = TimeoutAction.SKIP + max_retries: int = 0 + priority: int = 0 + tags: Dict[str, Any] = field(default_factory=dict) @dataclass @@ -18,6 +57,11 @@ class Task: environment_data: A dictionary of data needed to set up the environment for the task. 
evaluation_data: A dictionary of data needed to evaluate the agent's performance on the task. metadata: A dictionary for any additional metadata about the task. + protocol: Execution protocol controlling timeout, retries, priority, and other runtime + parameters. It provides fine-grained control over how MASEval runs the task. The + protocol serves purely as a communication channel between the task instance and + MASEval's execution engine; it does not impose any intrinsic semantics on the task + content itself. """ query: str @@ -26,69 +70,123 @@ class Task: user_data: Dict[str, Any] = field(default_factory=dict) evaluation_data: Dict[str, Any] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict) + protocol: TaskProtocol = field(default_factory=TaskProtocol) + +# ============================================================================= +# Task Queue Classes +# ============================================================================= -class TaskCollection(Sequence): - """A lightweight, sequence-like container for `Task` objects. - Usage: - - Construct from an iterable of `Task` or dicts: `TaskCollection.from_list(data)` - - Load from an examples-style JSON: `TaskCollection.from_json_file("examples/data.json")` +class BaseTaskQueue(ABC, Sequence): + """Abstract base class for task scheduling strategies. - The collection is immutable from the Sequence API perspective (supports indexing and slicing), - but provides `append`/`extend` helpers for convenience when building programmatically. + BaseTaskQueue provides a sequence-like interface for task execution. + Concrete implementations can reorder tasks, skip tasks, or terminate + early based on execution outcomes. + + Subclasses must implement ``__iter__`` to define the iteration order. + For adaptive behavior based on task results, use ``AdaptiveTaskQueue`` + which integrates with the benchmark callback system. + + Attributes: + _tasks: Internal list of tasks. + + Example: + ```python + queue = SequentialTaskQueue(tasks) + + for task in queue: + report = execute_task(task) + # Iterator handles termination automatically + ``` """ - def __init__(self, tasks: Optional[Iterable[Task]] = None) -> None: - """Initialize the TaskCollection. + def __init__(self, tasks: Iterable[Task]) -> None: + """Initialize the task queue. Args: - tasks: An optional iterable of `Task` objects to initialize the collection. + tasks: An iterable of Task objects to schedule. """ - # TODO for any element in the iterable that is not a Task, convert it to a Task - self._tasks: List[Task] = list(tasks) if tasks is not None else [] + self._tasks: List[Task] = list(tasks) def __len__(self) -> int: + """Return the total number of tasks in the queue.""" return len(self._tasks) @overload - def __getitem__(self, idx: int) -> "Task": ... + def __getitem__(self, idx: int) -> Task: ... @overload - def __getitem__(self, idx: slice) -> "TaskCollection": ... + def __getitem__(self, idx: slice) -> "BaseTaskQueue": ... + + def __getitem__( # type: ignore[invalid-method-override] + self, idx: Union[int, slice] + ) -> Union[Task, "BaseTaskQueue"]: + """Get a task by index or a slice of tasks. + + Args: + idx: Integer index or slice object. - def __getitem__(self, idx: Union[int, slice]) -> Union["Task", "TaskCollection"]: # type: ignore[override] - # Return a Task for int index, or a new TaskCollection for slices (pythonic behaviour) + Returns: + A single Task for integer index, or a new queue instance for slices. 
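To make the new `protocol` field concrete, a small construction example; it assumes `query` is the only required `Task` field, with the remaining fields falling back to their dataclass defaults as above.

```python
from maseval.core.task import Task, TaskProtocol, TimeoutAction

task = Task(
    query="Plan a three-course dinner that uses five vegetables.",
    protocol=TaskProtocol(
        timeout_seconds=120.0,               # hard budget for this task
        timeout_action=TimeoutAction.SKIP,   # on timeout, record and move on
        max_retries=1,                       # retry once on transient failures
        priority=5,                          # consumed by PriorityTaskQueue
        tags={"suite": "smoke"},
    ),
)

print(task.id, task.protocol.timeout_seconds, task.protocol.priority)
```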
+ """ if isinstance(idx, slice): - return TaskCollection(self._tasks[idx]) + # Return a new instance of the same type with sliced tasks + return self.__class__(self._tasks[idx]) return self._tasks[idx] + @abstractmethod def __iter__(self) -> Iterator[Task]: - return iter(self._tasks) + """Yield tasks in the scheduled execution order. - def __repr__(self) -> str: # pragma: no cover - trivial - return f"TaskCollection({len(self._tasks)} tasks)" + Returns: + Iterator yielding Task objects. + """ + pass - # Convenience mutators def append(self, task: Task) -> None: + """Add a task to the end of the queue. + + Args: + task: The task to append. + """ self._tasks.append(task) def extend(self, tasks: Iterable[Task]) -> None: + """Add multiple tasks to the end of the queue. + + Args: + tasks: An iterable of tasks to append. + """ self._tasks.extend(tasks) def to_list(self) -> List[Task]: + """Return a copy of the internal task list. + + Returns: + List of all tasks in the queue. + """ return list(self._tasks) - # Factories @classmethod - def from_list(cls, data: Iterable[Union[Task, dict]]) -> "TaskCollection": + def from_list(cls, data: Iterable[Union[Task, dict]]) -> "BaseTaskQueue": + """Create a queue from an iterable of Tasks or dicts. + + Args: + data: An iterable of Task objects or dicts that can be converted to Tasks. + + Returns: + A new queue instance containing the tasks. + + Raises: + TypeError: If an item is neither a Task nor a dict. + """ tasks: List[Task] = [] for item in data: if isinstance(item, Task): tasks.append(item) elif isinstance(item, dict): - # Expect a dict that can be turned into a Task - # Accept both full Task kwargs or the lightweight example format if "query" in item: query = item["query"] tasks.append( @@ -100,8 +198,8 @@ def from_list(cls, data: Iterable[Union[Task, dict]]) -> "TaskCollection": ) ) else: - # Attempt to map common example keys - query = item.get("question") or item.get("prompt") or item.get("query") or "" + query_val = item.get("question") or item.get("prompt") or item.get("query") or "" + query = str(query_val) if query_val else "" environment_data = ( item.get("environment_data") or {"text_content": item.get("text")} if item.get("text") @@ -121,29 +219,260 @@ def from_list(cls, data: Iterable[Union[Task, dict]]) -> "TaskCollection": ) ) else: - raise TypeError("TaskCollection.from_list expects Task or dict entries") + raise TypeError(f"{cls.__name__}.from_list expects Task or dict entries") return cls(tasks) @classmethod - def from_json_file(cls, path: Union[str, Path], *, limit: Optional[int] = None) -> "TaskCollection": + def from_json_file(cls, path: Union[str, Path], *, limit: Optional[int] = None) -> "BaseTaskQueue": """Load tasks from a JSON file. - This helper understands the example file format used in `examples/data.json` where the - top-level object has a `data` list and optional `metadata`. + This helper understands the example file format used in ``examples/data.json`` + where the top-level object has a ``data`` list and optional ``metadata``. Args: path: Path to the JSON file. limit: Optional limit to the number of tasks to load. + + Returns: + A new queue instance containing the loaded tasks. """ p = Path(path) with p.open("r", encoding="utf-8") as fh: payload = json.load(fh) - # Support both the wrapped `{ "data": [...] 
}` format and a plain list items = payload.get("data") if isinstance(payload, dict) and "data" in payload else payload if limit is not None: items = items[:limit] - # Convert each item to a Task using the same heuristics as from_list return cls.from_list(items) + + +class SequentialTaskQueue(BaseTaskQueue): + """Execute tasks in their original order. + + This queue maintains the current sequential execution model, processing + tasks in the order they appear in the input iterable. It's the default + queue used when no explicit queue is provided. + + Example: + ```python + queue = SequentialTaskQueue(tasks) + for task in queue: + result = execute(task) + ``` + """ + + def __iter__(self) -> Iterator[Task]: + """Yield tasks in original order.""" + return iter(self._tasks) + + +class PriorityTaskQueue(BaseTaskQueue): + """Execute tasks ordered by priority. + + Tasks are sorted by ``task.protocol.priority`` at construction time. + Higher priority values are executed first by default. Tasks with equal + priority maintain their relative order from the original input (stable sort). + + This queue uses ``task.protocol.priority`` as the sole source of priority. + Pre-compute priority values and assign them to tasks before creating the queue. + + Args: + tasks: An iterable of Task objects to schedule. + reverse: If True (default), higher priority values execute first. + If False, lower priority values execute first. + + Example: + ```python + # Assign priorities based on your criteria + for task in tasks: + task.protocol.priority = compute_priority(task) + + # Create queue (higher priority first) + queue = PriorityTaskQueue(tasks) + + # Or lower priority first + queue = PriorityTaskQueue(tasks, reverse=False) + ``` + """ + + def __init__(self, tasks: Iterable[Task], reverse: bool = True) -> None: + """Initialize priority queue with sorted tasks. + + Args: + tasks: An iterable of Task objects to schedule. + reverse: If True (default), higher priority values execute first. + """ + task_list = list(tasks) + # Stable sort by priority + sorted_tasks = sorted(task_list, key=lambda t: t.protocol.priority, reverse=reverse) + super().__init__(sorted_tasks) + + def __iter__(self) -> Iterator[Task]: + """Yield tasks in priority order.""" + return iter(self._tasks) + + +class AdaptiveTaskQueue(BaseTaskQueue, BenchmarkCallback, ABC): + """Abstract base class for adaptive task scheduling. + + AdaptiveTaskQueue enables dynamic task ordering based on execution results. + It inherits from BenchmarkCallback to integrate with the benchmark's callback + system, creating a clean bidirectional communication model: + + - **Benchmark → Queue**: Via iterator protocol (``for task in queue``) + - **Queue → Benchmark**: Via callback (``on_task_repeat_end()``) + + The queue automatically moves completed tasks from ``_remaining`` to + ``_completed`` and calls ``_update_state()`` to let subclasses adapt their + scheduling strategy based on task results. + + Subclasses must implement: + - ``_select_next_task()``: Choose the next task to execute + - ``_update_state()``: Update internal model after task completion + + Internal state: + - ``_remaining``: Tasks not yet executed + - ``_completed``: Completed tasks paired with their reports + - ``_stop_flag``: Flag to signal early termination + + When used with ``Benchmark.run()``, the queue is automatically registered + as a callback and receives ``on_task_repeat_end()`` notifications. 
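A runnable illustration of the queue helpers defined above: building a queue from plain dicts, slicing, and priority ordering. The dicts carry only `query` keys to stay within what `from_list` is shown to handle.

```python
from maseval.core.task import SequentialTaskQueue, PriorityTaskQueue

queue = SequentialTaskQueue.from_list([
    {"query": "Query A"},
    {"query": "Query B"},
    {"query": "Query C"},
])
print(len(queue), queue[0].query)   # -> 3 Query A
print(type(queue[:2]).__name__)     # slicing returns the same queue type

# Higher protocol.priority runs first by default (stable sort for ties).
for priority, task in zip((1, 10, 5), queue):
    task.protocol.priority = priority

print([t.protocol.priority for t in PriorityTaskQueue(queue.to_list())])        # -> [10, 5, 1]
print([t.protocol.priority for t in PriorityTaskQueue(queue, reverse=False)])   # -> [1, 5, 10]
```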
+ + Example: + ```python + class IRTTaskQueue(AdaptiveTaskQueue): + '''Item Response Theory-based adaptive testing.''' + + def __init__(self, tasks: Iterable[Task]): + super().__init__(tasks) + self._ability_estimate = 0.0 + + def _select_next_task(self) -> Optional[Task]: + # Select task with difficulty closest to current ability estimate + # (No need to check if _remaining is empty - guaranteed by base class) + return min( + self._remaining, + key=lambda t: abs(t.protocol.priority - self._ability_estimate) + ) + + def _update_state(self, task: Task, report: Dict[str, Any]) -> None: + # Update ability estimate based on task result + correct = report.get("eval", [{}])[0].get("correct", False) + difficulty = task.protocol.priority + self._ability_estimate += 0.5 if correct else -0.5 + + queue = IRTTaskQueue(tasks) + results = benchmark.run(queue) # Auto-registered as callback + ``` + """ + + def __init__(self, tasks: Iterable[Task]) -> None: + """Initialize adaptive queue. + + Args: + tasks: An iterable of Task objects to schedule. + """ + super().__init__(tasks) + self._remaining: List[Task] = list(self._tasks) + self._completed: List[Tuple[Task, Dict[str, Any]]] = [] + self._stop_flag: bool = False + + def __iter__(self) -> Iterator[Task]: + """Yield tasks selected by the adaptive algorithm. + + Continues until ``_select_next_task()`` returns None, ``_remaining`` + is empty, or ``_stop_flag`` is set. + + Note: ``_select_next_task()`` is only called when ``_remaining`` is + non-empty, so implementers don't need to check for empty list. + """ + while not self._stop_flag and self._remaining: + next_task = self._select_next_task() + if next_task is not None: + yield next_task + else: + break + + def on_task_repeat_end(self, benchmark: "Benchmark", report: Dict[str, Any]) -> None: + """BenchmarkCallback hook called after each task repetition completes. + + This method extracts the task from the report, moves it from + ``_remaining`` to ``_completed``, and calls ``_update_state()`` + to let the subclass update its adaptive model. + + Args: + benchmark: The benchmark instance (unused in this implementation). + report: The execution report containing task_id and results. + """ + # Extract task from report + task_id_str = report.get("task_id") + if task_id_str is None: + return + + # Find the task in remaining list + task = None + for i, t in enumerate(self._remaining): + if str(t.id) == task_id_str: + task = self._remaining.pop(i) + self._completed.append((task, report)) + break + + # If not found in remaining, check completed (for n_task_repeats > 1) + if task is None: + for t, _ in self._completed: + if str(t.id) == task_id_str: + task = t + break + + # Update subclass state + if task is not None: + self._update_state(task, report) + + def stop(self) -> None: + """Signal that no more tasks should be processed. + + Call this from ``_update_state()`` or ``_select_next_task()`` to + trigger early termination (e.g., when confidence threshold is reached). + + The ``_stop_flag`` is checked in ``__iter__``, which will stop yielding + tasks and naturally terminate the benchmark's iteration loop via Python's + iterator protocol. + """ + self._stop_flag = True + + @abstractmethod + def _select_next_task(self) -> Optional[Task]: + """Select the next task to execute. + + Implement this method to define your adaptive selection algorithm + (e.g., IRT-based selection, uncertainty sampling, bandit algorithms). 
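Besides the IRT sketch above, here is a smaller hypothetical subclass that exercises both hooks plus `stop()`: it keeps FIFO order but aborts the run once a configurable number of repetitions has produced an error report. When passed to `Benchmark.run()`, it is auto-registered as a callback, so `on_task_repeat_end()` feeds `_update_state()` without extra wiring.

```python
from typing import Any, Dict, Iterable, Optional

from maseval.core.task import AdaptiveTaskQueue, Task


class StopOnFailuresQueue(AdaptiveTaskQueue):
    """FIFO queue that calls stop() after `max_failures` failed repetitions."""

    def __init__(self, tasks: Iterable[Task], max_failures: int = 3) -> None:
        super().__init__(tasks)
        self._failures = 0
        self._max_failures = max_failures

    def _select_next_task(self) -> Optional[Task]:
        # _remaining is guaranteed non-empty whenever this is called.
        return self._remaining[0]

    def _update_state(self, task: Task, report: Dict[str, Any]) -> None:
        if "error" in report:            # reports carry an "error" key only on failure
            self._failures += 1
            if self._failures >= self._max_failures:
                self.stop()              # __iter__ stops yielding further tasks
```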
+ + **Guaranteed precondition**: This method is only called when + ``self._remaining`` is non-empty, so you don't need to check for + an empty list. You can safely assume at least one task is available. + + Returns: + The next Task to execute from ``self._remaining``, or None to + signal early termination (e.g., if no suitable task meets your + selection criteria). + """ + pass + + @abstractmethod + def _update_state(self, task: Task, report: Dict[str, Any]) -> None: + """Update internal state after task completion. + + Implement this method to update ability estimates, difficulty models, + or other state used by ``_select_next_task()``. + + Args: + task: The task that just completed. + report: The execution report containing status and eval results. + """ + pass + + +# Alias for the default queue type +TaskQueue = SequentialTaskQueue diff --git a/mkdocs.yml b/mkdocs.yml index a5b09b2..8617174 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -42,7 +42,11 @@ extra_css: markdown_extensions: - admonition - pymdownx.details - - pymdownx.superfences + - pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: !!python/name:pymdownx.superfences.fence_code_format - attr_list - pymdownx.emoji: emoji_index: !!python/name:material.extensions.emoji.twemoji diff --git a/tests/conftest.py b/tests/conftest.py index 04d1a5e..3de2b14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,7 +8,7 @@ Environment, User, Task, - TaskCollection, + TaskQueue, Evaluator, MessageHistory, ) @@ -369,9 +369,9 @@ def dummy_task(): @pytest.fixture -def dummy_task_collection(): +def dummy_task_queue(): """Create a collection of dummy tasks.""" - return TaskCollection.from_list( + return TaskQueue.from_list( [ {"query": "Query 1", "environment_data": {"task": 1}}, {"query": "Query 2", "environment_data": {"task": 2}}, @@ -381,14 +381,15 @@ def dummy_task_collection(): @pytest.fixture -def simple_benchmark(dummy_task_collection): - """Create a simple benchmark instance with tasks. +def simple_benchmark(dummy_task_queue): + """Create a simple benchmark instance with tasks and agent_data. Returns: - tuple: (benchmark, tasks) - Call as benchmark.run(tasks) + tuple: (benchmark, tasks, agent_data) - Call as benchmark.run(tasks, agent_data=agent_data) """ - benchmark = DummyBenchmark(agent_data={"model": "test"}) - return benchmark, dummy_task_collection + benchmark = DummyBenchmark() + agent_data = {"model": "test"} + return benchmark, dummy_task_queue, agent_data @pytest.fixture diff --git a/tests/test_benchmarks/test_macs/conftest.py b/tests/test_benchmarks/test_macs/conftest.py index dca893e..22c4a16 100644 --- a/tests/test_benchmarks/test_macs/conftest.py +++ b/tests/test_benchmarks/test_macs/conftest.py @@ -23,7 +23,7 @@ from unittest.mock import MagicMock from conftest import DummyModelAdapter -from maseval import AgentAdapter, Task, User, MessageHistory, TaskCollection +from maseval import AgentAdapter, Task, User, MessageHistory, TaskQueue from maseval.benchmark.macs import MACSBenchmark, MACSEnvironment @@ -97,14 +97,12 @@ class ConcreteMACSBenchmark(MACSBenchmark): def __init__( self, - agent_data: Dict[str, Any], model_factory: Optional[Any] = None, **kwargs: Any, ): """Initialize with optional model factory. Args: - agent_data: Agent configuration model_factory: Either a callable that takes a model name and returns a ModelAdapter, or a single ModelAdapter instance (for convenience in simple tests). If not provided, creates DummyModelAdapter instances. 
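As a quick note on the changes just above (the `TaskQueue` alias and the updated test fixtures): `TaskQueue` is simply `SequentialTaskQueue`, and benchmarks are now constructed without `agent_data`, which is supplied to `run()` instead. A hedged sketch of that calling convention; the `DummyBenchmark` lines mirror the test-suite helper and are left commented out so the snippet runs on its own.

```python
from maseval import TaskQueue
from maseval.core.task import SequentialTaskQueue

assert TaskQueue is SequentialTaskQueue   # alias defined at the end of task.py

tasks = TaskQueue.from_list([{"query": "Query 1"}, {"query": "Query 2"}])

# benchmark = DummyBenchmark()                                   # test-suite helper
# reports = benchmark.run(tasks, agent_data={"model": "test"})
```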
@@ -121,7 +119,7 @@ def __init__( else: # Single model instance - create a factory that always returns it self._model_factory = lambda model_name: model_factory - super().__init__(agent_data, **kwargs) + super().__init__(**kwargs) def get_model_adapter(self, model_id: str, **kwargs): """Create a model adapter for the given component. @@ -146,7 +144,7 @@ def get_model_adapter(self, model_id: str, **kwargs): return adapter - def setup_agents( # type: ignore[override] + def setup_agents( # type: ignore[invalid-method-override] self, agent_data: Dict[str, Any], environment: MACSEnvironment, @@ -415,9 +413,9 @@ def travel_task(): @pytest.fixture -def macs_task_collection(sample_task, travel_task): +def macs_task_queue(sample_task, travel_task): """Collection of MACS tasks for benchmark.run() tests.""" - return TaskCollection.from_list([sample_task, travel_task]) + return TaskQueue.from_list([sample_task, travel_task]) # ============================================================================= @@ -589,5 +587,6 @@ def macs_benchmark(sample_agent_data, dummy_model): """Create a MACS benchmark with dummy model for testing. Uses dummy_model from parent conftest.py. + Returns tuple (benchmark, agent_data) for use with run(). """ - return ConcreteMACSBenchmark(sample_agent_data, dummy_model) + return ConcreteMACSBenchmark(dummy_model), sample_agent_data diff --git a/tests/test_benchmarks/test_macs/test_data_loader.py b/tests/test_benchmarks/test_macs/test_data_loader.py index fdbfbbe..96c0f1b 100644 --- a/tests/test_benchmarks/test_macs/test_data_loader.py +++ b/tests/test_benchmarks/test_macs/test_data_loader.py @@ -12,7 +12,8 @@ from maseval.benchmark.macs.data_loader import ( DEFAULT_DATA_DIR, VALID_DOMAINS, - URLS, + DATA_URLS, + EVALUATION_URLS, download_file, download_json, download_original_data, @@ -175,7 +176,6 @@ def test_empty_input(self): """Empty or invalid input returns empty list.""" assert _create_tools_list({}) == [] assert _create_tools_list([]) == [] - assert _create_tools_list(None) == [] @pytest.mark.benchmark @@ -569,16 +569,13 @@ def mock_download_file(url: str, timeout=15): assert len(config["agents"]) == 2 def test_urls_structure(self): - """Verify URLS constant has expected structure.""" - assert "data" in URLS - assert "evaluation" in URLS - + """Verify URL constants have expected structure.""" for domain in VALID_DOMAINS: - assert domain in URLS["data"] - assert "agents" in URLS["data"][domain] - assert "scenarios" in URLS["data"][domain] + assert domain in DATA_URLS + assert "agents" in DATA_URLS[domain] + assert "scenarios" in DATA_URLS[domain] - assert "prompt_templates" in URLS["evaluation"] + assert "prompt_templates" in EVALUATION_URLS # ============================================================================= diff --git a/tests/test_benchmarks/test_macs/test_macs_benchmark.py b/tests/test_benchmarks/test_macs/test_macs_benchmark.py index a3b7ef6..b4c2aa9 100644 --- a/tests/test_benchmarks/test_macs/test_macs_benchmark.py +++ b/tests/test_benchmarks/test_macs/test_macs_benchmark.py @@ -27,11 +27,10 @@ class TestMACSBenchmarkSetup: """Tests for MACSBenchmark initialization and setup methods.""" def test_init_configures_benchmark(self, macs_model, sample_agent_data): - """Benchmark initializes with agent_data and optional params.""" + """Benchmark initializes with optional params.""" callbacks = [MagicMock()] - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model, callbacks=callbacks, n_task_repeats=3) + benchmark = ConcreteMACSBenchmark(macs_model, 
callbacks=callbacks, n_task_repeats=3) - assert benchmark.agent_data == sample_agent_data assert benchmark.callbacks == callbacks assert benchmark.n_task_repeats == 3 @@ -41,13 +40,13 @@ def test_macs_default_max_invocations_is_five(self, macs_model, sample_agent_dat This is a MACS-specific default that differs from the base class default of 1. The MACS paper specifies up to 5 agent-user interaction rounds. """ - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) assert benchmark.max_invocations == 5 def test_setup_environment_creates_macs_environment(self, macs_model, sample_agent_data, sample_task): """setup_environment returns MACSEnvironment with tools.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) @@ -56,7 +55,7 @@ def test_setup_environment_creates_macs_environment(self, macs_model, sample_age def test_setup_user_creates_macs_user(self, macs_model, sample_agent_data, sample_task): """setup_user returns MACSUser with scenario from task.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) user = benchmark.setup_user(sample_agent_data, env, sample_task) @@ -66,7 +65,7 @@ def test_setup_user_creates_macs_user(self, macs_model, sample_agent_data, sampl def test_setup_user_handles_no_scenario(self, macs_model, sample_agent_data, sample_task_no_scenario): """Handles missing scenario gracefully.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task_no_scenario) user = benchmark.setup_user(sample_agent_data, env, sample_task_no_scenario) @@ -75,7 +74,7 @@ def test_setup_user_handles_no_scenario(self, macs_model, sample_agent_data, sam def test_setup_evaluators_creates_user_and_system(self, macs_model, sample_agent_data, sample_task): """Creates both user and system evaluators.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents = [MACSAgentAdapter()] @@ -114,7 +113,7 @@ class TestRunAgents: def test_run_agents_executes_agents_with_query(self, macs_model, sample_agent_data, sample_task): """Agents are executed with the query parameter.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents_list, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) @@ -134,7 +133,7 @@ def test_run_agents_uses_query_parameter_not_task_query(self, macs_model, sample This is critical for multi-turn interaction where the query changes between invocations (e.g., user's response becomes the next query). 
""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents_list, _ = benchmark.setup_agents(sample_agent_data, env, sample_task, None) @@ -150,7 +149,7 @@ def test_run_agents_uses_query_parameter_not_task_query(self, macs_model, sample def test_run_agents_returns_answer(self, macs_model, sample_agent_data, sample_task): """Returns final answer(s) as MessageHistory.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents_list, _ = benchmark.setup_agents(sample_agent_data, env, sample_task, None) @@ -164,7 +163,7 @@ def test_run_agents_returns_answer(self, macs_model, sample_agent_data, sample_t def test_run_agents_single_agent(self, macs_model, sample_agent_data, sample_task): """Single agent returns MessageHistory.""" - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents_list, _ = benchmark.setup_agents(sample_agent_data, env, sample_task, None) @@ -176,14 +175,14 @@ def test_run_agents_multiple_agents(self, macs_model, sample_agent_data, sample_ """Multiple agents return list of answers.""" class MultiAgentBenchmark(MACSBenchmark): - def __init__(self, agent_data, model_factory, **kwargs): + def __init__(self, model_factory, **kwargs): self._model_factory = model_factory if callable(model_factory) else lambda _: model_factory - super().__init__(agent_data, **kwargs) + super().__init__(**kwargs) def get_model_adapter(self, model_id: str, **kwargs): return self._model_factory(model_id) - def setup_agents( # type: ignore[override] + def setup_agents( # type: ignore[invalid-method-override] self, agent_data: Dict[str, Any], environment: MACSEnvironment, @@ -194,7 +193,7 @@ def setup_agents( # type: ignore[override] agent2: AgentAdapter = MACSAgentAdapter("agent2") return [agent1, agent2], {"agent1": agent1, "agent2": agent2} - benchmark = MultiAgentBenchmark(sample_agent_data, macs_model) + benchmark = MultiAgentBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, sample_task) agents_list, _ = benchmark.setup_agents(sample_agent_data, env, sample_task, None) @@ -221,7 +220,7 @@ def test_evaluate_calls_both_evaluators(self, sample_agent_data, sample_task): '[{"assertion": "System assertion", "answer": "TRUE", "evidence": "OK"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) _, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) evaluators = benchmark.setup_evaluators(env, sample_task, list(agents_dict.values()), None) @@ -251,7 +250,7 @@ def test_evaluate_returns_aggregated_metrics(self, sample_agent_data, sample_tas '[{"assertion": "B", "answer": "TRUE", "evidence": "OK"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) _, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) evaluators = benchmark.setup_evaluators(env, sample_task, list(agents_dict.values()), None) @@ -281,7 
+280,7 @@ def test_evaluate_overall_gsr(self, sample_agent_data, sample_task): '[{"assertion": "B", "answer": "FALSE", "evidence": "Fail"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) _, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) evaluators = benchmark.setup_evaluators(env, sample_task, list(agents_dict.values()), None) @@ -305,7 +304,7 @@ def test_evaluate_supervisor_gsr(self, sample_agent_data, sample_task): '[{"assertion": "B", "answer": "FALSE", "evidence": "Fail"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) _, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) evaluators = benchmark.setup_evaluators(env, sample_task, list(agents_dict.values()), None) @@ -327,7 +326,7 @@ def test_evaluate_combined_report(self, sample_agent_data, sample_task): '[{"assertion": "System B", "answer": "TRUE", "evidence": "OK"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) _, agents_dict = benchmark.setup_agents(sample_agent_data, env, sample_task, None) evaluators = benchmark.setup_evaluators(env, sample_task, list(agents_dict.values()), None) @@ -365,18 +364,12 @@ def test_empty_results(self): assert result["successful_tasks"] == 0 assert result["success_rate"] == 0.0 assert result["mean_metrics"] == {} - assert result["excluded"] == { - "environment_error": 0, - "user_error": 0, - "unknown_execution_error": 0, - "evaluation_failed": 0, - "setup_failed": 0, - } + assert result["excluded"] == {} assert result["status_counts"] == {} def test_single_successful_result(self): """Single successful result counted.""" - results = [{"status": "completed", "eval": [{"overall_gsr": 1.0, "user_gsr": 1.0, "system_gsr": 1.0}]}] + results = [{"status": "success", "eval": [{"overall_gsr": 1.0, "user_gsr": 1.0, "system_gsr": 1.0}]}] metrics = compute_benchmark_metrics(results) @@ -387,7 +380,7 @@ def test_single_successful_result(self): def test_single_failed_result(self): """Single failed result counted.""" - results = [{"status": "completed", "eval": [{"overall_gsr": 0.0, "user_gsr": 0.0, "system_gsr": 0.0}]}] + results = [{"status": "success", "eval": [{"overall_gsr": 0.0, "user_gsr": 0.0, "system_gsr": 0.0}]}] metrics = compute_benchmark_metrics(results) @@ -399,9 +392,9 @@ def test_single_failed_result(self): def test_multiple_results(self): """Multiple results aggregated correctly.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, # Success - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, # Fail - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, # Success + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, # Success + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, # Fail + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, # Success ] metrics = compute_benchmark_metrics(results) @@ -414,10 +407,10 @@ def test_multiple_results(self): def test_success_rate_calculation(self): """success_rate = successful/scored (not total).""" results = [ - {"status": "completed", "eval": 
[{"overall_gsr": 1.0}]}, - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, ] metrics = compute_benchmark_metrics(results) @@ -427,8 +420,8 @@ def test_success_rate_calculation(self): def test_mean_metrics_calculation(self): """Mean of numeric metrics computed.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0, "partial_gsr": 0.8}]}, - {"status": "completed", "eval": [{"overall_gsr": 0.0, "partial_gsr": 0.4}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0, "partial_gsr": 0.8}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0, "partial_gsr": 0.4}]}, ] metrics = compute_benchmark_metrics(results) @@ -439,9 +432,9 @@ def test_mean_metrics_calculation(self): def test_handles_missing_eval(self): """Handles results with no eval key.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, - {"status": "completed", "no_eval_key": True}, # Missing eval - {"status": "completed", "eval": None}, # None eval + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "no_eval_key": True}, # Missing eval + {"status": "success", "eval": None}, # None eval ] metrics = compute_benchmark_metrics(results) @@ -454,7 +447,7 @@ def test_handles_non_numeric_values(self): """Non-numeric values in eval are ignored for mean.""" results = [ { - "status": "completed", + "status": "success", "eval": [ { "overall_gsr": 1.0, @@ -475,15 +468,15 @@ def test_handles_non_numeric_values(self): def test_excludes_environment_errors_from_scoring(self): """Environment errors are excluded from scoring.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, {"status": "environment_error", "eval": None}, # Should be excluded - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, ] metrics = compute_benchmark_metrics(results) assert metrics["total_tasks"] == 3 - assert metrics["scored_tasks"] == 2 # Only completed tasks + assert metrics["scored_tasks"] == 2 # Only success tasks assert metrics["successful_tasks"] == 1 assert metrics["success_rate"] == 0.5 # 1/2, not 1/3 assert metrics["excluded"]["environment_error"] == 1 @@ -491,7 +484,7 @@ def test_excludes_environment_errors_from_scoring(self): def test_excludes_user_errors_from_scoring(self): """User simulator errors are excluded from scoring.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, {"status": "user_error", "eval": None}, ] @@ -500,14 +493,14 @@ def test_excludes_user_errors_from_scoring(self): assert metrics["total_tasks"] == 2 assert metrics["scored_tasks"] == 1 assert metrics["successful_tasks"] == 1 - assert metrics["success_rate"] == 1.0 # Only the completed one + assert metrics["success_rate"] == 1.0 # Only the success one assert metrics["excluded"]["user_error"] == 1 def test_excludes_unknown_errors_from_scoring(self): """Unknown execution errors are excluded from scoring.""" results = [ {"status": "unknown_execution_error", "eval": None}, - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, + {"status": "success", "eval": 
[{"overall_gsr": 0.0}]}, ] metrics = compute_benchmark_metrics(results) @@ -521,7 +514,7 @@ def test_excludes_setup_failed_from_scoring(self): """Setup failures are excluded from scoring.""" results = [ {"status": "setup_failed", "eval": None}, - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, ] metrics = compute_benchmark_metrics(results) @@ -534,21 +527,21 @@ def test_excludes_evaluation_failed_from_scoring(self): """Evaluation failures are excluded from scoring.""" results = [ {"status": "evaluation_failed", "eval": None}, - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, ] metrics = compute_benchmark_metrics(results) assert metrics["total_tasks"] == 2 assert metrics["scored_tasks"] == 1 - assert metrics["success_rate"] == 1.0 # Only the completed one + assert metrics["success_rate"] == 1.0 # Only the success one assert metrics["excluded"]["evaluation_failed"] == 1 def test_includes_agent_errors_in_scoring(self): """Agent errors ARE included in scoring (agent's fault).""" results = [ {"status": "agent_error", "eval": [{"overall_gsr": 0.0}]}, - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, ] metrics = compute_benchmark_metrics(results) @@ -561,23 +554,23 @@ def test_includes_agent_errors_in_scoring(self): def test_status_counts_tracked(self): """Status counts are tracked for all tasks.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0}]}, - {"status": "completed", "eval": [{"overall_gsr": 0.0}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0}]}, {"status": "agent_error", "eval": None}, {"status": "environment_error", "eval": None}, ] metrics = compute_benchmark_metrics(results) - assert metrics["status_counts"]["completed"] == 2 + assert metrics["status_counts"]["success"] == 2 assert metrics["status_counts"]["agent_error"] == 1 assert metrics["status_counts"]["environment_error"] == 1 def test_mixed_errors_comprehensive(self): """Comprehensive test with various error types.""" results = [ - {"status": "completed", "eval": [{"overall_gsr": 1.0, "accuracy": 0.9}]}, - {"status": "completed", "eval": [{"overall_gsr": 0.0, "accuracy": 0.3}]}, + {"status": "success", "eval": [{"overall_gsr": 1.0, "accuracy": 0.9}]}, + {"status": "success", "eval": [{"overall_gsr": 0.0, "accuracy": 0.3}]}, {"status": "agent_error", "eval": [{"overall_gsr": 0.0, "accuracy": 0.0}]}, {"status": "environment_error", "eval": None}, # Excluded {"status": "user_error", "eval": None}, # Excluded @@ -588,7 +581,7 @@ def test_mixed_errors_comprehensive(self): metrics = compute_benchmark_metrics(results) assert metrics["total_tasks"] == 7 - assert metrics["scored_tasks"] == 3 # completed(2) + agent_error(1) + assert metrics["scored_tasks"] == 3 # success(2) + agent_error(1) assert metrics["successful_tasks"] == 1 assert metrics["success_rate"] == pytest.approx(1 / 3) assert metrics["mean_metrics"]["accuracy"] == pytest.approx((0.9 + 0.3 + 0.0) / 3) @@ -612,7 +605,7 @@ def test_full_task_execution(self, sample_agent_data, sample_task): '[{"assertion": "Database updated", "answer": "TRUE", "evidence": "Updated"}]', ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) # Setup phase env = benchmark.setup_environment(sample_agent_data, sample_task) 
@@ -646,7 +639,7 @@ def test_full_task_execution(self, sample_agent_data, sample_task): def test_benchmark_with_real_environment(self, sample_agent_data, sample_task): """Test with real MACSEnvironment tool creation.""" model = DummyModelAdapter(responses=['{"text": "Default response", "details": {}}']) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) env = benchmark.setup_environment(sample_agent_data, sample_task) diff --git a/tests/test_benchmarks/test_macs/test_macs_integration.py b/tests/test_benchmarks/test_macs/test_macs_integration.py index 24a69fd..f8db989 100644 --- a/tests/test_benchmarks/test_macs/test_macs_integration.py +++ b/tests/test_benchmarks/test_macs/test_macs_integration.py @@ -44,7 +44,7 @@ def test_complete_task_lifecycle(self, sample_agent_data, travel_task): json.dumps([{"assertion": "Tool called", "answer": "TRUE", "evidence": "OK"}]), ] model = DummyModelAdapter(responses=responses) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) + benchmark = ConcreteMACSBenchmark(model) # Setup phase env = benchmark.setup_environment(sample_agent_data, travel_task) @@ -112,7 +112,7 @@ def test_loaded_task_works_with_environment(self, macs_model, sample_agent_data) metadata={"scenario": "Travel booking scenario"}, ) - benchmark = ConcreteMACSBenchmark(sample_agent_data, macs_model) + benchmark = ConcreteMACSBenchmark(macs_model) env = benchmark.setup_environment(sample_agent_data, task) assert "search" in env.tools @@ -183,7 +183,7 @@ def test_environment_handles_empty_tool_specs(self, macs_model_factory): @pytest.mark.benchmark class TestEndToEndPipeline: - """End-to-end tests that call benchmark.run() with TaskCollection. + """End-to-end tests that call benchmark.run() with TaskQueue. These tests verify the complete MACS benchmark pipeline by actually calling benchmark.run(). 
More granular integration tests are in @@ -200,8 +200,8 @@ def test_run_single_task_complete_pipeline(self, sample_agent_data, travel_task) ] ) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) - reports = benchmark.run([travel_task]) + benchmark = ConcreteMACSBenchmark(model) + reports = benchmark.run([travel_task], agent_data=sample_agent_data) # Verify complete report structure assert len(reports) == 1 @@ -213,8 +213,8 @@ def test_run_single_task_complete_pipeline(self, sample_agent_data, travel_task) assert "config" in report assert "eval" in report - def test_run_multiple_tasks(self, sample_agent_data, macs_task_collection): - """Run benchmark with multiple tasks via TaskCollection.""" + def test_run_multiple_tasks(self, sample_agent_data, macs_task_queue): + """Run benchmark with multiple tasks via TaskQueue.""" model = DummyModelAdapter( responses=[ '{"text": "User response", "details": {}}', @@ -223,10 +223,10 @@ def test_run_multiple_tasks(self, sample_agent_data, macs_task_collection): ] ) - benchmark = ConcreteMACSBenchmark(sample_agent_data, model) - reports = benchmark.run(macs_task_collection) + benchmark = ConcreteMACSBenchmark(model) + reports = benchmark.run(macs_task_queue, agent_data=sample_agent_data) - assert len(reports) == len(macs_task_collection) + assert len(reports) == len(macs_task_queue) for report in reports: assert report["status"] == "success" assert "eval" in report @@ -242,8 +242,8 @@ def test_run_with_task_repeats(self, sample_agent_data, sample_task): ) n_repeats = 3 - benchmark = ConcreteMACSBenchmark(sample_agent_data, model, n_task_repeats=n_repeats) - reports = benchmark.run([sample_task]) + benchmark = ConcreteMACSBenchmark(model, n_task_repeats=n_repeats) + reports = benchmark.run([sample_task], agent_data=sample_agent_data) assert len(reports) == n_repeats for i, report in enumerate(reports): @@ -285,8 +285,8 @@ def on_run_end(self, benchmark, results): ) callback = TrackingCallback() - benchmark = ConcreteMACSBenchmark(sample_agent_data, model, callbacks=[callback]) - benchmark.run([sample_task]) + benchmark = ConcreteMACSBenchmark(model, callbacks=[callback]) + benchmark.run([sample_task], agent_data=sample_agent_data) # Verify callback sequence expected_order = ["run_start", "task_start", "repeat_start_0", "repeat_end_0", "task_end", "run_end"] diff --git a/tests/test_benchmarks/test_tau2/test_default_agent.py b/tests/test_benchmarks/test_tau2/test_default_agent.py index a8f8921..1d8551a 100644 --- a/tests/test_benchmarks/test_tau2/test_default_agent.py +++ b/tests/test_benchmarks/test_tau2/test_default_agent.py @@ -501,16 +501,14 @@ class TestDefaultAgentTau2BenchmarkInit: def test_init_basic(self): """Test basic initialization.""" - benchmark = DummyDefaultAgentBenchmark( - agent_data={"model_id": "gpt-4o"}, - ) + benchmark = DummyDefaultAgentBenchmark() - assert benchmark._model_cache == {} + # Benchmark should be initialized successfully + assert benchmark.max_invocations == 50 # Tau2 default def test_init_with_all_options(self): """Test initialization with all options.""" benchmark = DummyDefaultAgentBenchmark( - agent_data={"model_id": "gpt-4o", "llm_args": {"temperature": 0.5}}, n_task_repeats=3, max_invocations=5, ) @@ -518,6 +516,13 @@ def test_init_with_all_options(self): assert benchmark.n_task_repeats == 3 assert benchmark.max_invocations == 5 + def test_default_max_invocations(self): + """Test that default max_invocations is 50 from class attribute.""" + benchmark = DummyDefaultAgentBenchmark() + + assert 
benchmark.max_invocations == 50 + assert benchmark.MAX_INVOCATIONS == 50 + @pytest.mark.benchmark class TestDefaultAgentTau2BenchmarkSetupAgents: @@ -525,7 +530,7 @@ class TestDefaultAgentTau2BenchmarkSetupAgents: def test_setup_agents_basic(self, sample_task): """Test basic agent setup.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() with patch.object(Tau2Environment, "__init__", return_value=None): mock_env = MagicMock(spec=Tau2Environment) @@ -545,7 +550,7 @@ def test_setup_agents_basic(self, sample_task): def test_setup_agents_missing_model_id(self, sample_task): """Test that missing model_id raises ValueError.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={}) + benchmark = DummyDefaultAgentBenchmark() mock_env = MagicMock(spec=Tau2Environment) mock_env.create_tools.return_value = {} @@ -556,7 +561,7 @@ def test_setup_agents_missing_model_id(self, sample_task): def test_setup_agents_with_llm_args(self, sample_task): """Test agent setup with custom llm_args.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o", "llm_args": {"temperature": 0.5}}) + benchmark = DummyDefaultAgentBenchmark() mock_env = MagicMock(spec=Tau2Environment) mock_env.create_tools.return_value = {} @@ -574,7 +579,7 @@ def test_setup_agents_with_llm_args(self, sample_task): def test_setup_agents_with_max_tool_calls(self, sample_task): """Test agent setup with custom max_tool_calls.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o", "max_tool_calls": 10}) + benchmark = DummyDefaultAgentBenchmark() mock_env = MagicMock(spec=Tau2Environment) mock_env.create_tools.return_value = {} @@ -600,7 +605,7 @@ def test_get_model_adapter_is_abstract(self): # DefaultAgentTau2Benchmark itself is still abstract # because get_model_adapter is abstract with pytest.raises(TypeError, match="abstract"): - DefaultAgentTau2Benchmark(agent_data={"model_id": "gpt-4o"}) + DefaultAgentTau2Benchmark() # ============================================================================= @@ -723,7 +728,7 @@ class TestTau2BenchmarkMethods: def test_setup_environment(self, sample_task): """Test setup_environment creates Tau2Environment.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() with patch("maseval.benchmark.tau2.tau2.Tau2Environment") as mock_env_cls: mock_env_cls.return_value = MagicMock() @@ -736,7 +741,7 @@ def test_setup_environment(self, sample_task): def test_setup_user_with_dict_instructions(self): """Test setup_user with dict instructions.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() task = Task( query="Hello", @@ -767,7 +772,7 @@ def test_setup_user_with_dict_instructions(self): def test_setup_user_with_string_instructions(self): """Test setup_user with string instructions.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() task = Task( query="Hello", @@ -790,7 +795,7 @@ def test_setup_user_with_string_instructions(self): def test_setup_user_empty_instructions(self): """Test setup_user with no instructions.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() task = Task( query="Hello", @@ -810,7 +815,7 @@ def test_setup_user_empty_instructions(self): def test_setup_evaluators(self, sample_task): """Test setup_evaluators 
creates Tau2Evaluator.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() with patch("maseval.benchmark.tau2.tau2.Tau2Environment") as mock_env_cls: mock_env = MagicMock(spec=Tau2Environment) @@ -823,7 +828,7 @@ def test_setup_evaluators(self, sample_task): def test_run_agents_single_agent(self, sample_task): """Test run_agents with a single agent.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() mock_agent = MagicMock(spec=AgentAdapter) mock_agent.run.return_value = "Response" @@ -837,7 +842,7 @@ def test_run_agents_single_agent(self, sample_task): def test_run_agents_multiple_agents(self, sample_task): """Test run_agents with multiple agents.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() mock_agent1 = MagicMock(spec=AgentAdapter) mock_agent1.run.return_value = "Response 1" @@ -852,7 +857,7 @@ def test_run_agents_multiple_agents(self, sample_task): def test_evaluate(self, sample_task): """Test evaluate method.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() mock_evaluator = MagicMock() mock_evaluator.filter_traces.return_value = {"filtered": True} @@ -955,7 +960,7 @@ class TestDummyBenchmarkModelAdapter: def test_get_model_adapter_returns_mock(self): """Test that DummyDefaultAgentBenchmark returns working mock adapter.""" - benchmark = DummyDefaultAgentBenchmark(agent_data={"model_id": "gpt-4o"}) + benchmark = DummyDefaultAgentBenchmark() adapter = benchmark.get_model_adapter("gpt-4o") diff --git a/tests/test_benchmarks/test_tau2/test_evaluator.py b/tests/test_benchmarks/test_tau2/test_evaluator.py index a4f25bd..c7383bb 100644 --- a/tests/test_benchmarks/test_tau2/test_evaluator.py +++ b/tests/test_benchmarks/test_tau2/test_evaluator.py @@ -276,7 +276,7 @@ def test_single_success(self): """Single successful result counted.""" from maseval.benchmark.tau2.evaluator import compute_benchmark_metrics - results = [{"status": "completed", "eval": [{"reward": 1.0, "passed": True}]}] + results = [{"status": "success", "eval": [{"reward": 1.0, "passed": True}]}] metrics = compute_benchmark_metrics(results) @@ -287,10 +287,10 @@ def test_single_success(self): assert metrics["mean_reward"] == 1.0 def test_single_failure(self): - """Single failed result counted.""" + """Single failed result counted (agent_error is scoreable).""" from maseval.benchmark.tau2.evaluator import compute_benchmark_metrics - results = [{"status": "completed", "eval": [{"reward": 0.0, "passed": False}]}] + results = [{"status": "agent_error", "eval": [{"reward": 0.0, "passed": False}]}] metrics = compute_benchmark_metrics(results) @@ -304,9 +304,9 @@ def test_mixed_results(self): from maseval.benchmark.tau2.evaluator import compute_benchmark_metrics results = [ - {"status": "completed", "eval": [{"reward": 1.0, "passed": True}]}, - {"status": "completed", "eval": [{"reward": 0.0, "passed": False}]}, - {"status": "completed", "eval": [{"reward": 0.5, "passed": False}]}, + {"status": "success", "eval": [{"reward": 1.0, "passed": True}]}, + {"status": "agent_error", "eval": [{"reward": 0.0, "passed": False}]}, + {"status": "task_timeout", "eval": [{"reward": 0.5, "passed": False}]}, ] metrics = compute_benchmark_metrics(results) @@ -322,7 +322,7 @@ def test_excludes_infrastructure_errors(self): from maseval.benchmark.tau2.evaluator import 
compute_benchmark_metrics results = [ - {"status": "completed", "eval": [{"reward": 1.0, "passed": True}]}, + {"status": "success", "eval": [{"reward": 1.0, "passed": True}]}, {"status": "environment_error", "eval": None}, {"status": "user_error", "eval": None}, {"status": "setup_failed", "eval": None}, @@ -331,7 +331,7 @@ def test_excludes_infrastructure_errors(self): metrics = compute_benchmark_metrics(results) assert metrics["total_tasks"] == 4 - assert metrics["scored_tasks"] == 1 # Only completed + assert metrics["scored_tasks"] == 1 # Only success assert metrics["successful_tasks"] == 1 assert metrics["success_rate"] == 1.0 @@ -340,14 +340,15 @@ def test_status_counts(self): from maseval.benchmark.tau2.evaluator import compute_benchmark_metrics results = [ - {"status": "completed", "eval": [{"reward": 1.0, "passed": True}]}, - {"status": "completed", "eval": [{"reward": 0.0, "passed": False}]}, + {"status": "success", "eval": [{"reward": 1.0, "passed": True}]}, + {"status": "agent_error", "eval": [{"reward": 0.0, "passed": False}]}, {"status": "environment_error", "eval": None}, ] metrics = compute_benchmark_metrics(results) - assert metrics["status_counts"]["completed"] == 2 + assert metrics["status_counts"]["success"] == 1 + assert metrics["status_counts"]["agent_error"] == 1 assert metrics["status_counts"]["environment_error"] == 1 diff --git a/tests/test_benchmarks/test_tau2/test_integration.py b/tests/test_benchmarks/test_tau2/test_integration.py index e37996e..36b2340 100644 --- a/tests/test_benchmarks/test_tau2/test_integration.py +++ b/tests/test_benchmarks/test_tau2/test_integration.py @@ -59,9 +59,11 @@ def test_tau2_dry_run(): task.user_data = {"model_id": "mock-user", "instructions": "Test scenario"} task.evaluation_data = {"reward_basis": ["DB"], "actions": []} task.query = "Help me." 
+ task.protocol = MagicMock() + task.protocol.timeout_seconds = None # Setup benchmark - benchmark = IntegrationTau2Benchmark(agent_data={}, n_task_repeats=1) + benchmark = IntegrationTau2Benchmark(n_task_repeats=1) # Mock Environment env_mock = MagicMock() @@ -88,7 +90,7 @@ def test_tau2_dry_run(): # Patch setup_evaluators to return our mock benchmark.setup_evaluators = MagicMock(return_value=[mock_evaluator]) # type: ignore[assignment] - results = benchmark.run([task]) + results = benchmark.run([task], agent_data={}) # Debug info if failed if results[0]["status"] != "success": diff --git a/tests/test_contract/test_agent_adapter_contract.py b/tests/test_contract/test_agent_adapter_contract.py index e2730b2..6521977 100644 --- a/tests/test_contract/test_agent_adapter_contract.py +++ b/tests/test_contract/test_agent_adapter_contract.py @@ -119,7 +119,7 @@ def agent_node(state: State) -> State: return {"messages": messages + [AIMessage(content=response)]} # Build graph - graph = StateGraph(State) # type: ignore[arg-type] + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", agent_node) graph.set_entry_point("agent") graph.add_edge("agent", END) diff --git a/tests/test_contract/test_collection_contract.py b/tests/test_contract/test_collection_contract.py index d2621e4..497a689 100644 --- a/tests/test_contract/test_collection_contract.py +++ b/tests/test_contract/test_collection_contract.py @@ -151,7 +151,7 @@ class State(TypedDict): def agent_node(state: State) -> State: return {"messages": state["messages"] + [AIMessage(content="Response")]} - graph = StateGraph(State) # type: ignore[arg-type] + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", agent_node) graph.set_entry_point("agent") graph.add_edge("agent", END) diff --git a/tests/test_core/test_benchmark/test_automatic_registration.py b/tests/test_core/test_benchmark/test_automatic_registration.py index e1c55aa..4d51a32 100644 --- a/tests/test_core/test_benchmark/test_automatic_registration.py +++ b/tests/test_core/test_benchmark/test_automatic_registration.py @@ -8,7 +8,7 @@ """ import pytest -from maseval import TaskCollection, TraceableMixin +from maseval import TaskQueue, TraceableMixin from conftest import DummyBenchmark, DummyModelAdapter, DummyAgentAdapter, DummyAgent @@ -20,13 +20,13 @@ def test_automatic_agent_registration(): returned from setup_agents() for trace and config collection without requiring manual register() calls. 
""" - tasks = TaskCollection.from_list([{"query": "test", "id": "1", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "test", "id": "1", "environment_data": {}}]) agent_data = {} - benchmark = DummyBenchmark(agent_data=agent_data) + benchmark = DummyBenchmark() # Before run, registry should be empty - assert len(benchmark._trace_registry) == 0 + assert len(benchmark._registry._trace_registry) == 0 # Run one step to trigger setup for task, agent_data in zip(tasks, [agent_data]): @@ -46,8 +46,8 @@ def test_automatic_agent_registration(): break # Only test first task # Check that components were registered - assert "environment:env" in benchmark._trace_registry - assert "agents:test_agent" in benchmark._trace_registry + assert "environment:env" in benchmark._registry._trace_registry + assert "agents:test_agent" in benchmark._registry._trace_registry @pytest.mark.core @@ -57,7 +57,7 @@ def test_duplicate_registration_detection(): Verifies that the ID-based tracking system detects when a component instance is registered multiple times with different names, preventing data confusion. """ - benchmark = DummyBenchmark(agent_data={}) + benchmark = DummyBenchmark() # Create a component model = DummyModelAdapter() @@ -84,7 +84,7 @@ def test_duplicate_registration_helpful_message(): Verifies that error message includes both the existing registration name and the attempted new name, plus mentions automatic registration. """ - benchmark = DummyBenchmark(agent_data={}) + benchmark = DummyBenchmark() # Create and register an agent agent = DummyAgent() @@ -108,7 +108,7 @@ def test_manual_registration_for_models(): Verifies that models are not automatically registered (unlike agents, environments, and users), requiring explicit register() calls. """ - benchmark = DummyBenchmark(agent_data={}) + benchmark = DummyBenchmark() # Create a model model = DummyModelAdapter() @@ -117,8 +117,8 @@ def test_manual_registration_for_models(): benchmark.register("models", "my_model", model) # Verify it was registered - assert "models:my_model" in benchmark._trace_registry - assert benchmark._trace_registry["models:my_model"] is model + assert "models:my_model" in benchmark._registry._trace_registry + assert benchmark._registry._trace_registry["models:my_model"] is model @pytest.mark.core @@ -128,7 +128,7 @@ def test_component_id_tracking(): Verifies that benchmark maintains a Python id() to name mapping for detecting duplicate registrations of the same component instance. """ - benchmark = DummyBenchmark(agent_data={}) + benchmark = DummyBenchmark() # Create a component model = DummyModelAdapter() @@ -137,8 +137,8 @@ def test_component_id_tracking(): benchmark.register("models", "test_model", model) # Verify ID tracking - assert id(model) in benchmark._component_id_map - assert benchmark._component_id_map[id(model)] == "models:test_model" + assert id(model) in benchmark._registry._component_id_map + assert benchmark._registry._component_id_map[id(model)] == "models:test_model" @pytest.mark.core @@ -148,7 +148,7 @@ def test_registry_cleared_after_repetition(): Verifies that after each task iteration completes, the registry is reset to allow fresh components for the next iteration while preserving reports. 
""" - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "test1", "id": "1", "environment_data": {}}, {"query": "test2", "id": "2", "environment_data": {}}, @@ -156,14 +156,14 @@ def test_registry_cleared_after_repetition(): ) agent_data = {} - benchmark = DummyBenchmark(agent_data=agent_data, n_task_repeats=2) + benchmark = DummyBenchmark(n_task_repeats=2) # Run the benchmark - benchmark.run(tasks) + benchmark.run(tasks, agent_data=agent_data) # After run completes, registry should be empty (cleared after last repetition) - assert len(benchmark._trace_registry) == 0 - assert len(benchmark._component_id_map) == 0 + assert len(benchmark._registry._trace_registry) == 0 + assert len(benchmark._registry._component_id_map) == 0 # But reports should contain entries for all task repetitions assert len(benchmark.reports) == 4 # 2 tasks * 2 repeats diff --git a/tests/test_core/test_benchmark/test_benchmark_lifecycle.py b/tests/test_core/test_benchmark/test_benchmark_lifecycle.py index 8a15fca..654d077 100644 --- a/tests/test_core/test_benchmark/test_benchmark_lifecycle.py +++ b/tests/test_core/test_benchmark/test_benchmark_lifecycle.py @@ -6,7 +6,7 @@ """ import pytest -from maseval import TaskCollection +from maseval import TaskQueue @pytest.mark.core @@ -15,13 +15,13 @@ class TestBenchmarkLifecycle: def test_benchmark_complete_run_single_task(self, simple_benchmark): """Test that a benchmark completes successfully with a single task.""" - benchmark, tasks = simple_benchmark + benchmark, tasks, agent_data = simple_benchmark # Run benchmark - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data=agent_data) # Verify we got one report - assert len(reports) == 3 # 3 tasks in dummy_task_collection + assert len(reports) == 3 # 3 tasks in dummy_task_queue # Verify report structure report = reports[0] @@ -43,16 +43,16 @@ def test_benchmark_complete_run_multiple_tasks(self): """Test that a benchmark handles multiple tasks correctly.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task 1", "environment_data": {}}, {"query": "Task 2", "environment_data": {}}, {"query": "Task 3", "environment_data": {}}, ] ) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Should have 3 reports (one per task) assert len(reports) == 3 @@ -69,10 +69,10 @@ def test_benchmark_task_repetitions(self): """Test that task repetitions work correctly.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=3) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) + benchmark = DummyBenchmark(n_task_repeats=3) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Should have 3 reports (one per repetition) assert len(reports) == 3 @@ -116,19 +116,18 @@ def on_task_end(self, benchmark, task, result): def on_run_end(self, benchmark, results): invocations.append("on_run_end") - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task1", "environment_data": {}}, {"query": "Task2", "environment_data": {}}, ] ) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=2, callbacks=[OrderTrackingCallback()], ) - benchmark.run(tasks) + 
benchmark.run(tasks, agent_data={"model": "test"}) # Verify order expected = [ @@ -156,7 +155,7 @@ def test_benchmark_component_cleanup_between_repeats(self): from conftest import DummyBenchmark from maseval import BenchmarkCallback - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) # Track registry size after each repetition registry_sizes = [] @@ -171,14 +170,13 @@ def on_task_repeat_start(self, benchmark, task, repeat_idx): # At start of new repeat, registry should be empty (except for callbacks) if repeat_idx > 0: # After first repeat, registry should have been cleared - registry_sizes.append(len(benchmark._trace_registry)) + registry_sizes.append(len(benchmark._registry._trace_registry)) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=2, callbacks=[RegistryTracker()], ) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # After first repetition completes and second starts, registry should be cleared # Note: This test verifies cleanup happens between repeats @@ -187,27 +185,27 @@ def test_benchmark_registry_cleared_after_task(self): """Test that registry is properly cleared after each task repetition.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=1) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark(n_task_repeats=1) # Before run, registry should be empty - assert len(benchmark._trace_registry) == 0 - assert len(benchmark._config_registry) == 0 + assert len(benchmark._registry._trace_registry) == 0 + assert len(benchmark._registry._config_registry) == 0 - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # After run completes, registry should be cleared - assert len(benchmark._trace_registry) == 0 - assert len(benchmark._config_registry) == 0 + assert len(benchmark._registry._trace_registry) == 0 + assert len(benchmark._registry._config_registry) == 0 def test_benchmark_reports_structure(self): """Test that benchmark reports have the correct structure.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 report = reports[0] @@ -239,7 +237,7 @@ def test_benchmark_agent_data_per_task(self): """Test that different agent_data can be provided per task.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task1", "environment_data": {}}, {"query": "Task2", "environment_data": {}}, @@ -252,9 +250,9 @@ def test_benchmark_agent_data_per_task(self): {"model": "model-2", "temp": 0.9}, ] - benchmark = DummyBenchmark(agent_data=agent_data_list) + benchmark = DummyBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data=agent_data_list) # Verify each task received its specific agent_data assert len(benchmark.setup_agents_calls) == 2 @@ -267,7 +265,7 @@ def test_benchmark_invalid_agent_data_length(self): """Test that providing wrong number of agent_data items raises error.""" from conftest import 
DummyBenchmark - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task1", "environment_data": {}}, {"query": "Task2", "environment_data": {}}, @@ -281,15 +279,15 @@ def test_benchmark_invalid_agent_data_length(self): ValueError, match="must either be a single dict or an iterable matching the number of tasks", ): - benchmark = DummyBenchmark(agent_data=agent_data_list) - benchmark.run(tasks) + benchmark = DummyBenchmark() + benchmark.run(tasks, agent_data=agent_data_list) def test_benchmark_n_task_repeats_validation(self): """Test that n_task_repeats must be at least 1.""" from conftest import DummyBenchmark with pytest.raises(ValueError, match="n_task_repeats must be at least 1"): - DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=0) + DummyBenchmark(n_task_repeats=0) @pytest.mark.core @@ -321,13 +319,12 @@ def setup_agents(self, agent_data, environment, task, user): agent_adapter = FailingAgentAdapter(agent, "failing_agent") return [agent_adapter], {"failing_agent": agent_adapter} - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = TaskFailureBenchmark( - agent_data={"model": "test"}, fail_on_task_error=False, ) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 report = reports[0] @@ -359,14 +356,13 @@ def setup_agents(self, agent_data, environment, task, user): agent_adapter = FailingAgentAdapter(agent, "failing_agent") return [agent_adapter], {"failing_agent": agent_adapter} - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = TaskFailureBenchmark( - agent_data={"model": "test"}, fail_on_task_error=True, ) with pytest.raises(RuntimeError, match="Agent execution failed!"): - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) def test_evaluation_failure_graceful(self): """Test that evaluation failures are caught and recorded when fail_on_evaluation_error=False.""" @@ -384,13 +380,12 @@ class EvaluationFailureBenchmark(DummyBenchmark): def setup_evaluators(self, environment, task, agents, user): return [FailingEvaluator(task, environment, user)] - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = EvaluationFailureBenchmark( - agent_data={"model": "test"}, fail_on_evaluation_error=False, ) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 report = reports[0] @@ -416,14 +411,13 @@ class EvaluationFailureBenchmark(DummyBenchmark): def setup_evaluators(self, environment, task, agents, user): return [FailingEvaluator(task, environment, user)] - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = EvaluationFailureBenchmark( - agent_data={"model": "test"}, fail_on_evaluation_error=True, ) with pytest.raises(ValueError, match="Evaluation failed!"): - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) def test_setup_failure_graceful(self): """Test that setup failures are caught and recorded when fail_on_setup_error=False.""" @@ -434,13 +428,12 @@ class 
SetupFailureBenchmark(DummyBenchmark): def setup_environment(self, agent_data, task): raise RuntimeError("Environment setup failed!") - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = SetupFailureBenchmark( - agent_data={"model": "test"}, fail_on_setup_error=False, ) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 report = reports[0] @@ -458,14 +451,13 @@ class SetupFailureBenchmark(DummyBenchmark): def setup_environment(self, agent_data, task): raise RuntimeError("Environment setup failed!") - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) benchmark = SetupFailureBenchmark( - agent_data={"model": "test"}, fail_on_setup_error=True, ) with pytest.raises(RuntimeError, match="Environment setup failed!"): - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) def test_get_failed_tasks(self): """Test get_failed_tasks() method.""" @@ -498,15 +490,15 @@ def setup_agents(self, agent_data, environment, task, user): self.task_counter += 1 return [agent_adapter], {agent_adapter.name: agent_adapter} - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task 1", "environment_data": {}}, {"query": "Task 2", "environment_data": {}}, {"query": "Task 3", "environment_data": {}}, ] ) - benchmark = MixedBenchmark(agent_data={"model": "test"}) - reports = benchmark.run(tasks) + benchmark = MixedBenchmark() + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Test using internal state failed = benchmark.get_failed_tasks() @@ -553,7 +545,7 @@ def test_get_failed_tasks_before_run_raises(self): """Test that get_failed_tasks() raises if called before run().""" from conftest import DummyBenchmark - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() with pytest.raises(RuntimeError, match="must be called after run"): benchmark.get_failed_tasks() @@ -563,10 +555,10 @@ def test_successful_task_has_success_status(self): from maseval import TaskExecutionStatus from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 assert reports[0]["status"] == TaskExecutionStatus.SUCCESS.value @@ -577,7 +569,7 @@ def test_default_failure_flags(self): """Test that failure flags default to False (graceful handling).""" from conftest import DummyBenchmark - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() assert benchmark.fail_on_setup_error is False assert benchmark.fail_on_task_error is False @@ -594,10 +586,10 @@ def test_multiple_run_calls_no_side_effects(self): from conftest import DummyBenchmark # Create benchmark - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() # First run with 3 tasks - tasks1 = TaskCollection.from_list( + tasks1 = TaskQueue.from_list( [ {"query": "Task 1", "environment_data": {}}, {"query": "Task 2", "environment_data": {}}, @@ -605,19 +597,19 @@ def 
test_multiple_run_calls_no_side_effects(self): ] ) - reports1 = benchmark.run(tasks=tasks1) + reports1 = benchmark.run(tasks=tasks1, agent_data={"model": "test"}) assert len(reports1) == 3 assert len(benchmark.reports) == 3 # Second run with 2 different tasks - tasks2 = TaskCollection.from_list( + tasks2 = TaskQueue.from_list( [ {"query": "Task A", "environment_data": {}}, {"query": "Task B", "environment_data": {}}, ] ) - reports2 = benchmark.run(tasks=tasks2) + reports2 = benchmark.run(tasks=tasks2, agent_data={"model": "test"}) assert len(reports2) == 2 # Verify reports were cleared from first run assert len(benchmark.reports) == 2 @@ -629,13 +621,13 @@ def test_multiple_run_calls_no_side_effects(self): # Third run - retry pattern (simulating failed tasks) # Use one task from tasks1 - retry_tasks = TaskCollection([list(tasks1)[0]]) - reports3 = benchmark.run(tasks=retry_tasks) + retry_tasks = TaskQueue([list(tasks1)[0]]) + reports3 = benchmark.run(tasks=retry_tasks, agent_data={"model": "test"}) assert len(reports3) == 1 assert len(benchmark.reports) == 1 def test_retry_failed_tasks_pattern(self): - """Test the intended use case: benchmark.run(benchmark.get_failed_tasks()). + """Test the intended use case: benchmark.run(benchmark.get_failed_tasks(, agent_data={"model": "test"})). This verifies that failed tasks can be retried by passing them back to run(). This includes returning tasks that failed using the correct format that run() expects. @@ -671,7 +663,7 @@ def setup_agents(self, agent_data, environment, task, user): self.task_counter += 1 return [agent_adapter], {agent_adapter.name: agent_adapter} - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task 1", "environment_data": {}}, {"query": "Task 2", "environment_data": {}}, @@ -679,10 +671,10 @@ def setup_agents(self, agent_data, environment, task, user): ] ) - benchmark = ConditionalFailureBenchmark(agent_data={"model": "test"}) + benchmark = ConditionalFailureBenchmark() # First run - one task will fail - reports = benchmark.run(tasks=tasks) + reports = benchmark.run(tasks=tasks, agent_data={"model": "test"}) assert len(reports) == 3 # Get failed tasks - should have 1 failure @@ -693,7 +685,7 @@ def setup_agents(self, agent_data, environment, task, user): # Retry the failed tasks (simulate fixing the issue) benchmark.fail_on_first_run = False benchmark.task_counter = 0 # Reset counter - retry_reports = benchmark.run(tasks=failed) + retry_reports = benchmark.run(tasks=failed, agent_data={"model": "test"}) # Should have 1 report for the retried task assert len(retry_reports) == 1 diff --git a/tests/test_core/test_benchmark/test_callback_error_handling.py b/tests/test_core/test_benchmark/test_callback_error_handling.py new file mode 100644 index 0000000..8d90623 --- /dev/null +++ b/tests/test_core/test_benchmark/test_callback_error_handling.py @@ -0,0 +1,413 @@ +"""Tests for callback error handling in benchmark execution. + +These tests verify that callback exceptions are properly isolated and logged, +preventing one failing callback from disrupting the entire benchmark run. 
+""" + +import pytest +import logging +from typing import List + +from maseval import ( + BenchmarkCallback, + Task, + TaskQueue, +) +from conftest import DummyBenchmark + + +# ==================== Test Fixtures ==================== + + +class FailingCallback(BenchmarkCallback): + """Callback that raises an exception on specific methods.""" + + def __init__(self, fail_on: str = "on_task_start"): + self.fail_on = fail_on + self.call_count = 0 + + def on_task_start(self, benchmark, task): + if self.fail_on == "on_task_start": + raise RuntimeError("Intentional failure in on_task_start") + + def on_task_repeat_start(self, benchmark, task, repeat_idx): + if self.fail_on == "on_task_repeat_start": + raise ValueError("Intentional failure in on_task_repeat_start") + + def on_task_repeat_end(self, benchmark, report): + self.call_count += 1 + if self.fail_on == "on_task_repeat_end": + raise TypeError("Intentional failure in on_task_repeat_end") + + def on_task_end(self, benchmark, task, result): + if self.fail_on == "on_task_end": + raise KeyError("Intentional failure in on_task_end") + + +class TrackingCallback(BenchmarkCallback): + """Callback that tracks which methods were called.""" + + def __init__(self): + self.calls: List[str] = [] + + def on_run_start(self, benchmark): + self.calls.append("run_start") + + def on_task_start(self, benchmark, task): + self.calls.append(f"task_start:{task.query}") + + def on_task_repeat_start(self, benchmark, task, repeat_idx): + self.calls.append(f"repeat_start:{task.query}:{repeat_idx}") + + def on_task_repeat_end(self, benchmark, report): + self.calls.append(f"repeat_end:{report['task_id'][:8]}") + + def on_task_end(self, benchmark, task, result): + self.calls.append(f"task_end:{task.query}") + + def on_run_end(self, benchmark, results): + self.calls.append("run_end") + + +@pytest.fixture +def simple_tasks(): + """Create simple tasks for testing.""" + return TaskQueue.from_list( + [ + {"query": "Task 1", "environment_data": {}}, + {"query": "Task 2", "environment_data": {}}, + ] + ) + + +# ==================== Error Suppression Tests ==================== + + +@pytest.mark.core +class TestCallbackErrorSuppression: + """Tests for callback error suppression in sequential execution.""" + + def test_failing_callback_does_not_stop_execution(self, simple_tasks, caplog): + """A failing callback should not prevent benchmark from completing.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_start") + benchmark = DummyBenchmark( + callbacks=[failing_cb], + ) + + # Should complete despite callback failure + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) + + assert len(reports) == 2 + assert all(r["status"] == "success" for r in reports) + + # Error should be logged + assert "Callback" in caplog.text + assert "on_task_start" in caplog.text + assert "Intentional failure" in caplog.text + + def test_multiple_callbacks_one_fails_others_continue(self, simple_tasks, caplog): + """Other callbacks should continue even if one fails.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_repeat_end") + tracking_cb = TrackingCallback() + + benchmark = DummyBenchmark( + callbacks=[failing_cb, tracking_cb], + ) + + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) + + # Execution completes + assert len(reports) == 2 + + # Tracking callback still received all events + assert "run_start" in tracking_cb.calls + assert "task_start:Task 1" in tracking_cb.calls + assert 
"task_start:Task 2" in tracking_cb.calls + assert "run_end" in tracking_cb.calls + + # Error logged + assert "on_task_repeat_end" in caplog.text + + def test_callback_fails_on_every_task(self, simple_tasks, caplog): + """Execution continues even if callback fails on every task.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_repeat_end") + benchmark = DummyBenchmark( + callbacks=[failing_cb], + ) + + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) + + # All tasks complete + assert len(reports) == 2 + assert all(r["status"] == "success" for r in reports) + + # Callback was attempted for each task (even though it failed) + assert failing_cb.call_count == 2 + + def test_callback_error_in_on_run_start(self, simple_tasks, caplog): + """Benchmark continues if callback fails in on_run_start.""" + caplog.set_level(logging.ERROR) + + class RunStartFailer(BenchmarkCallback): + def on_run_start(self, benchmark): + raise RuntimeError("Failed at run start") + + benchmark = DummyBenchmark( + callbacks=[RunStartFailer()], + ) + + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) + + assert len(reports) == 2 + assert "on_run_start" in caplog.text + + def test_callback_error_in_on_run_end(self, simple_tasks, caplog): + """Benchmark completes and logs error if on_run_end fails.""" + caplog.set_level(logging.ERROR) + + class RunEndFailer(BenchmarkCallback): + def on_run_end(self, benchmark, results): + raise RuntimeError("Failed at run end") + + benchmark = DummyBenchmark( + callbacks=[RunEndFailer()], + ) + + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) + + # Reports are generated (run_end happens after report collection) + assert len(reports) == 2 + assert "on_run_end" in caplog.text + + +# ==================== Parallel Execution Error Handling Tests ==================== + + +@pytest.mark.core +class TestCallbackErrorHandlingParallel: + """Tests for callback error handling in parallel execution.""" + + def test_callback_error_in_parallel_execution(self, simple_tasks, caplog): + """Callback errors in parallel execution should not crash workers.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_repeat_end") + tracking_cb = TrackingCallback() + + benchmark = DummyBenchmark( + callbacks=[failing_cb, tracking_cb], + ) + + # Run in parallel + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) + + # All tasks complete + assert len(reports) == 2 + assert all(r["status"] == "success" for r in reports) + + # Tracking callback still received events + assert len(tracking_cb.calls) > 0 + + # Errors logged + assert "on_task_repeat_end" in caplog.text + + def test_multiple_parallel_tasks_with_failing_callback(self, caplog): + """Callback failures should not interfere across parallel workers.""" + caplog.set_level(logging.ERROR) + + tasks = TaskQueue.from_list([{"query": f"Task {i}", "environment_data": {}} for i in range(5)]) + + failing_cb = FailingCallback(fail_on="on_task_repeat_start") + + benchmark = DummyBenchmark( + callbacks=[failing_cb], + ) + + reports = benchmark.run(tasks, agent_data={"model": "test"}) + + # All 5 tasks complete despite callback failures + assert len(reports) == 5 + assert all(r["status"] == "success" for r in reports) + + # Multiple errors logged (one per task) + error_count = caplog.text.count("on_task_repeat_start") + assert error_count >= 5 + + +# ==================== Error Return Value Tests ==================== + + +@pytest.mark.core 
+class TestCallbackErrorReturnValues: + """Tests for _invoke_callbacks error return values.""" + + def test_invoke_callbacks_returns_empty_list_on_success(self, simple_tasks): + """_invoke_callbacks should return empty list when no errors occur.""" + tracking_cb = TrackingCallback() + benchmark = DummyBenchmark( + callbacks=[tracking_cb], + ) + + # Manually invoke callbacks to test return value + errors = benchmark._invoke_callbacks("on_run_start", benchmark) + + assert errors == [] + assert "run_start" in tracking_cb.calls + + def test_invoke_callbacks_returns_error_list_on_failure(self, simple_tasks): + """_invoke_callbacks should return list of exceptions when callbacks fail.""" + failing_cb1 = FailingCallback(fail_on="on_task_start") + failing_cb2 = FailingCallback(fail_on="on_task_start") + tracking_cb = TrackingCallback() + + benchmark = DummyBenchmark( + callbacks=[failing_cb1, tracking_cb, failing_cb2], + ) + + task = Task(query="Test", environment_data={}) + errors = benchmark._invoke_callbacks("on_task_start", benchmark, task) + + # Two errors returned (from two failing callbacks) + assert len(errors) == 2 + assert all(isinstance(e, RuntimeError) for e in errors) + + # Tracking callback still ran + assert "task_start:Test" in tracking_cb.calls + + def test_invoke_callbacks_with_suppress_false_raises(self, simple_tasks): + """With suppress_errors=False, first exception should be raised.""" + failing_cb = FailingCallback(fail_on="on_task_start") + tracking_cb = TrackingCallback() + + benchmark = DummyBenchmark( + callbacks=[failing_cb, tracking_cb], + ) + + task = Task(query="Test", environment_data={}) + + with pytest.raises(RuntimeError, match="Intentional failure"): + benchmark._invoke_callbacks( + "on_task_start", + benchmark, + task, + suppress_errors=False, + ) + + # Tracking callback was not called (execution stopped at first error) + assert len(tracking_cb.calls) == 0 + + +# ==================== Different Exception Types Tests ==================== + + +@pytest.mark.core +class TestCallbackExceptionTypes: + """Tests for handling different exception types from callbacks.""" + + def test_value_error_in_callback(self, simple_tasks, caplog): + """ValueError from callback should be caught and logged.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_repeat_start") + benchmark = DummyBenchmark( + callbacks=[failing_cb], + ) + + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) + + assert len(reports) == 2 + assert "ValueError" in caplog.text + + def test_type_error_in_callback(self, simple_tasks, caplog): + """TypeError from callback should be caught and logged.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_repeat_end") + benchmark = DummyBenchmark( + callbacks=[failing_cb], + ) + + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) + + assert len(reports) == 2 + assert "TypeError" in caplog.text + + def test_key_error_in_callback(self, simple_tasks, caplog): + """KeyError from callback should be caught and logged.""" + caplog.set_level(logging.ERROR) + + failing_cb = FailingCallback(fail_on="on_task_end") + benchmark = DummyBenchmark( + callbacks=[failing_cb], + ) + + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) + + assert len(reports) == 2 + assert "KeyError" in caplog.text + + +# ==================== Integration Tests ==================== + + +@pytest.mark.core +class TestCallbackErrorHandlingIntegration: + """Integration tests for callback 
error handling with real scenarios.""" + + def test_failing_callback_with_repeats(self, caplog): + """Callback errors should be handled correctly with n_task_repeats > 1.""" + caplog.set_level(logging.ERROR) + + tasks = TaskQueue.from_list([{"query": "Task", "environment_data": {}}]) + + failing_cb = FailingCallback(fail_on="on_task_repeat_end") + + benchmark = DummyBenchmark( + n_task_repeats=3, + callbacks=[failing_cb], + ) + + reports = benchmark.run(tasks, agent_data={"model": "test"}) + + # 3 reports (one per repeat) + assert len(reports) == 3 + + # Callback attempted 3 times + assert failing_cb.call_count == 3 + + # 3 errors logged + error_count = caplog.text.count("on_task_repeat_end") + assert error_count >= 3 + + def test_mixed_callbacks_some_fail_some_succeed(self, simple_tasks, caplog): + """Mixed scenario: some callbacks fail, others succeed.""" + caplog.set_level(logging.ERROR) + + failing_cb1 = FailingCallback(fail_on="on_task_start") + tracking_cb1 = TrackingCallback() + failing_cb2 = FailingCallback(fail_on="on_task_repeat_end") + tracking_cb2 = TrackingCallback() + + benchmark = DummyBenchmark( + callbacks=[failing_cb1, tracking_cb1, failing_cb2, tracking_cb2], + ) + + reports = benchmark.run(simple_tasks, agent_data={"model": "test"}) + + # Execution completes + assert len(reports) == 2 + + # Both tracking callbacks received all events + assert len(tracking_cb1.calls) > 0 + assert len(tracking_cb2.calls) > 0 + assert tracking_cb1.calls == tracking_cb2.calls + + # Both types of errors logged + assert "on_task_start" in caplog.text + assert "on_task_repeat_end" in caplog.text diff --git a/tests/test_core/test_benchmark/test_callback_orchestration.py b/tests/test_core/test_benchmark/test_callback_orchestration.py index 63a3028..f32652b 100644 --- a/tests/test_core/test_benchmark/test_callback_orchestration.py +++ b/tests/test_core/test_benchmark/test_callback_orchestration.py @@ -5,7 +5,7 @@ """ import pytest -from maseval import BenchmarkCallback, TaskCollection +from maseval import BenchmarkCallback, TaskQueue @pytest.mark.core @@ -37,14 +37,13 @@ def on_task_end(self, benchmark, task, result): def on_run_end(self, benchmark, results): order.append("run_end") - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=2, callbacks=[OrderedCallback()], ) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) expected = [ "run_start", @@ -79,10 +78,10 @@ def on_run_start(self, benchmark): def on_run_end(self, benchmark, results): callback2_calls.append("end") - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, callbacks=[Callback1(), Callback2()]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark(callbacks=[Callback1(), Callback2()]) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert callback1_calls == ["start", "end"] assert callback2_calls == ["start", "end"] @@ -104,16 +103,15 @@ def on_run_start(self, benchmark): def on_run_end(self, benchmark, results): successful_calls.append("end") - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) benchmark = DummyBenchmark( - agent_data={"model": "test"}, 
callbacks=[FailingCallback(), SuccessfulCallback()], ) # Note: Current implementation may not catch callback errors # This test documents the expected behavior try: - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # If callbacks are isolated, successful callback should work assert "start" in successful_calls or "end" in successful_calls except RuntimeError: @@ -136,19 +134,18 @@ def on_task_repeat_start(self, benchmark, task, repeat_idx): nonlocal repeat_count repeat_count += 1 - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task1", "environment_data": {}}, {"query": "Task2", "environment_data": {}}, ] ) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=3, callbacks=[CountingCallback()], ) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # 2 tasks, each called once assert task_count == 2 @@ -172,10 +169,10 @@ def on_task_repeat_end(self, benchmark, report): def on_run_end(self, benchmark, results): contexts["results_count"] = len(results) - tasks = TaskCollection.from_list([{"query": "TestQuery", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, callbacks=[ContextCapturingCallback()]) + tasks = TaskQueue.from_list([{"query": "TestQuery", "environment_data": {}}]) + benchmark = DummyBenchmark(callbacks=[ContextCapturingCallback()]) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) # Verify contexts captured correctly assert contexts["task_query"] == "TestQuery" @@ -195,19 +192,18 @@ def on_run_start(self, benchmark): captured_state["n_tasks"] = len(benchmark.tasks) captured_state["n_repeats"] = benchmark.n_task_repeats - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Q1", "environment_data": {}}, {"query": "Q2", "environment_data": {}}, ] ) benchmark = DummyBenchmark( - agent_data={"model": "test"}, n_task_repeats=2, callbacks=[StateAccessingCallback()], ) - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert captured_state["n_tasks"] == 2 assert captured_state["n_repeats"] == 2 diff --git a/tests/test_core/test_benchmark/test_config_collection.py b/tests/test_core/test_benchmark/test_config_collection.py index e1a4ad3..22bac9a 100644 --- a/tests/test_core/test_benchmark/test_config_collection.py +++ b/tests/test_core/test_benchmark/test_config_collection.py @@ -5,7 +5,7 @@ """ import pytest -from maseval import TaskCollection +from maseval import TaskQueue @pytest.mark.core @@ -16,10 +16,10 @@ def test_config_collected_from_all_components(self): """Test that configs are collected from all registered components.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Verify config has the expected structure @@ -48,10 +48,10 @@ def test_config_includes_benchmark_level_info(self): """Test that benchmark-level configuration is included.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - 
reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Verify benchmark-level config exists @@ -67,10 +67,10 @@ def test_config_includes_system_info(self): """Test that system information is captured.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] system_info = config["benchmark"]["system"] @@ -81,10 +81,10 @@ def test_config_includes_git_info(self): """Test that git information is captured when available.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Git info may not be available in all environments @@ -98,10 +98,10 @@ def test_config_includes_package_versions(self): """Test that installed package versions are captured.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Should capture package information @@ -114,10 +114,10 @@ def test_config_structure_matches_spec(self): """Test that config structure matches expected specification.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Top-level keys @@ -166,11 +166,11 @@ def setup_agents(self, agent_data, environment, task, user): agent_adapter = FailingConfigAdapter(agent, "failing_agent") return [agent_adapter], {"failing_agent": agent_adapter} # type: ignore[return-value] - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = TestBenchmark() # Should complete without raising, with error info in config - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Verify config collection handled the error @@ -184,10 +184,10 @@ def test_config_json_serializable(self): import json from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {"key": "value"}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {"key": "value"}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) 
+ reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Should be able to serialize to JSON @@ -205,10 +205,10 @@ def test_config_contains_timestamps(self): """Test that all config components include timestamps.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Check metadata timestamp @@ -229,10 +229,10 @@ def test_config_includes_component_types(self): """Test that all configs include component type information.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) config = reports[0]["config"] # Check agent type @@ -249,10 +249,10 @@ def test_config_different_per_repetition(self): """Test that each repetition has its own config snapshot.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=3) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark(n_task_repeats=3) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Should have 3 reports assert len(reports) == 3 diff --git a/tests/test_core/test_benchmark/test_execution_loop.py b/tests/test_core/test_benchmark/test_execution_loop.py index bd06dfc..a126566 100644 --- a/tests/test_core/test_benchmark/test_execution_loop.py +++ b/tests/test_core/test_benchmark/test_execution_loop.py @@ -11,7 +11,7 @@ from typing import Any, List, Optional, Tuple import warnings -from maseval import Benchmark, Task, TaskCollection, User +from maseval import Benchmark, Task, TaskQueue, User # ============================================================================= @@ -72,7 +72,7 @@ class TestExecutionLoopNoUser: def test_uses_task_query_without_user(self, dummy_model): """Uses task.query when no user present.""" task = Task(query="What is the weather?", environment_data={}) - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=None) + benchmark = ExecutionLoopBenchmark(return_user=None) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, None) @@ -86,7 +86,7 @@ def test_uses_task_query_without_user(self, dummy_model): def test_single_invocation_without_user(self, dummy_model): """Single agent run without user (default max_invocations=1).""" task = Task(query="Query", environment_data={}) - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=None) + benchmark = ExecutionLoopBenchmark(return_user=None) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, None) @@ -98,7 +98,7 @@ def test_single_invocation_without_user(self, dummy_model): def test_returns_final_answer(self, dummy_model): """Returns final answer from agent.""" task = Task(query="Test query", environment_data={}) - benchmark = 
ExecutionLoopBenchmark(agent_data={}, return_user=None) + benchmark = ExecutionLoopBenchmark(return_user=None) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, None) @@ -129,7 +129,7 @@ def test_uses_user_initial_query(self, dummy_model): initial_query="User's initial message", max_turns=5, ) - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=user) + benchmark = ExecutionLoopBenchmark(return_user=user) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, user) @@ -147,9 +147,9 @@ def test_uses_get_initial_query_if_no_initial_query(self, dummy_model): task = Task(query="Task query", environment_data={}) user = DummyUser(name="test", model=dummy_model, max_turns=5) # No initial_query, so messages is empty - user.simulator.return_value = "LLM generated initial query" # type: ignore[assignment] + user.simulator.return_value = "LLM generated initial query" # type: ignore[union-attr] # mock - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=user) + benchmark = ExecutionLoopBenchmark(return_user=user) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, user) @@ -172,10 +172,9 @@ def test_multi_turn_interaction(self, dummy_model): max_turns=5, ) # User responds with different messages each turn - user.simulator.side_effect = ["Turn 1 response", "Turn 2 response", "Turn 3 response"] # type: ignore[assignment] + user.simulator.side_effect = ["Turn 1 response", "Turn 2 response", "Turn 3 response"] # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=user, max_invocations=3, ) @@ -205,10 +204,9 @@ def test_stops_when_user_done_via_max_turns(self, dummy_model): initial_query="Start", # Counts as turn 1 max_turns=3, # User done after 3 user messages ) - user.simulator.side_effect = ["Response 1", "Response 2", "Response 3"] # type: ignore[assignment] + user.simulator.side_effect = ["Response 1", "Response 2", "Response 3"] # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=user, max_invocations=5, # Would allow 5, but user stops at 3 turns ) @@ -236,10 +234,9 @@ def test_stops_when_user_done_via_stop_token(self, dummy_model): early_stopping_condition="goals are met", ) # User stops on second response - user.simulator.side_effect = ["Continue please", "Thanks! "] # type: ignore[assignment] + user.simulator.side_effect = ["Continue please", "Thanks! 
"] # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=user, max_invocations=5, ) @@ -263,9 +260,9 @@ def test_final_answer_in_user_messages(self, dummy_model): initial_query="Help me", max_turns=2, # Allow initial + one response ) - user.simulator.return_value = "Thanks" # type: ignore[assignment] + user.simulator.return_value = "Thanks" # type: ignore[union-attr] # mock - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=user) + benchmark = ExecutionLoopBenchmark(return_user=user) env = benchmark.setup_environment({}, task) agents, _ = benchmark.setup_agents({}, env, task, user) @@ -291,10 +288,9 @@ def test_user_response_becomes_next_query(self, dummy_model): max_turns=4, # Allow 4 turns total ) # User responses for turn 2, 3, 4 - user.simulator.side_effect = ["User reply 1", "User reply 2", "User reply 3"] # type: ignore[assignment] + user.simulator.side_effect = ["User reply 1", "User reply 2", "User reply 3"] # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=user, max_invocations=3, # Will stop after 3 due to max_invocations ) @@ -331,26 +327,25 @@ class TestMaxInvocations: def test_default_max_invocations_is_one(self): """Default is single invocation.""" - benchmark = ExecutionLoopBenchmark(agent_data={}) + benchmark = ExecutionLoopBenchmark() assert benchmark.max_invocations == 1 def test_custom_max_invocations(self): """Custom max_invocations is stored.""" - benchmark = ExecutionLoopBenchmark(agent_data={}, max_invocations=5) + benchmark = ExecutionLoopBenchmark(max_invocations=5) assert benchmark.max_invocations == 5 def test_warning_max_invocations_without_user(self): """Warning issued when max_invocations > 1 but no user.""" task = Task(query="Test", environment_data={}) benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=None, max_invocations=5, ) with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - benchmark.run(TaskCollection([task])) + benchmark.run(TaskQueue([task]), agent_data={}) # Check for warning about max_invocations without user warning_messages = [str(warning.message) for warning in w] @@ -377,11 +372,11 @@ def test_run_with_user_uses_execution_loop(self, dummy_model): initial_query="User query", max_turns=1, ) - user.simulator.return_value = "Done" # type: ignore[assignment] + user.simulator.return_value = "Done" # type: ignore[union-attr] # mock - benchmark = ExecutionLoopBenchmark(agent_data={}, return_user=user) + benchmark = ExecutionLoopBenchmark(return_user=user) - benchmark.run(TaskCollection([task])) + benchmark.run(TaskQueue([task]), agent_data={}) # Verify run_agents was called with user's initial prompt assert len(benchmark.run_agents_calls) == 1 @@ -399,15 +394,14 @@ def test_complete_traces_with_user(self, dummy_model): initial_query="Hello", # Turn 1 max_turns=3, # Allow 3 user messages total ) - user.simulator.side_effect = ["Reply 1", "Reply 2"] # type: ignore[assignment] + user.simulator.side_effect = ["Reply 1", "Reply 2"] # type: ignore[union-attr] # mock benchmark = ExecutionLoopBenchmark( - agent_data={}, return_user=user, max_invocations=2, ) - reports = benchmark.run(TaskCollection([task])) + reports = benchmark.run(TaskQueue([task]), agent_data={}) # Check that user traces are in the report assert len(reports) == 1 diff --git a/tests/test_core/test_benchmark/test_parallel_execution.py b/tests/test_core/test_benchmark/test_parallel_execution.py new file mode 100644 index 0000000..a9bfeae --- 
/dev/null +++ b/tests/test_core/test_benchmark/test_parallel_execution.py @@ -0,0 +1,410 @@ +"""Tests for parallel task execution in Benchmark. + +These tests verify that parallel execution with num_workers > 1 works correctly, +including thread safety, report collection, and callback serialization. +""" + +import pytest +import threading +import time +from typing import List, Tuple, Optional + +from maseval import ( + BenchmarkCallback, + Task, + TaskQueue, + TaskExecutionStatus, +) +from conftest import DummyBenchmark + + +# ==================== Test Fixtures ==================== + + +class SlowBenchmark(DummyBenchmark): + """Benchmark that introduces configurable delays per task.""" + + def __init__(self, delay_seconds: float = 0.05, **kwargs): + super().__init__(**kwargs) + self.delay = delay_seconds + self.execution_times: List[Tuple[str, float, float]] = [] # (task_id, start, end) + self._timing_lock = threading.Lock() + + def run_agents(self, agents, task, environment, query): + start = time.time() + time.sleep(self.delay) + result = super().run_agents(agents, task, environment, query) + end = time.time() + + with self._timing_lock: + self.execution_times.append((str(task.id), start, end)) + + return result + + +class ThreadTrackingCallback(BenchmarkCallback): + """Callback that records which thread each event fires on.""" + + def __init__(self): + self.thread_ids: List[Tuple[str, Optional[int]]] = [] + self._lock = threading.Lock() + + def on_task_repeat_start(self, benchmark, task, repeat_idx): + with self._lock: + self.thread_ids.append(("repeat_start", threading.current_thread().ident)) + + def on_task_repeat_end(self, benchmark, report): + with self._lock: + self.thread_ids.append(("repeat_end", threading.current_thread().ident)) + + +class OrderTrackingCallback(BenchmarkCallback): + """Callback that records the order of callback invocations.""" + + def __init__(self): + self.invocations: List[str] = [] + self._lock = threading.Lock() + + def on_run_start(self, benchmark): + with self._lock: + self.invocations.append("run_start") + + def on_task_start(self, benchmark, task): + with self._lock: + self.invocations.append(f"task_start:{task.query}") + + def on_task_repeat_start(self, benchmark, task, repeat_idx): + with self._lock: + self.invocations.append(f"repeat_start:{task.query}:{repeat_idx}") + + def on_task_repeat_end(self, benchmark, report): + with self._lock: + self.invocations.append(f"repeat_end:{report['task_id'][:8]}") + + def on_task_end(self, benchmark, task, result): + with self._lock: + self.invocations.append(f"task_end:{task.query}") + + def on_run_end(self, benchmark, results): + with self._lock: + self.invocations.append("run_end") + + +@pytest.fixture +def parallel_tasks(): + """Create tasks for parallel execution testing.""" + return TaskQueue.from_list([{"query": f"Task {i}", "environment_data": {"index": i}} for i in range(5)]) + + +# ==================== Basic Parallel Execution Tests ==================== + + +@pytest.mark.core +class TestParallelExecutionBasics: + """Tests for basic parallel execution functionality.""" + + def test_parallel_execution_completes(self, parallel_tasks): + """Verify parallel execution completes all tasks.""" + benchmark = DummyBenchmark() + + reports = benchmark.run(parallel_tasks, agent_data={"model": "test"}) + + assert len(reports) == 5 + + def test_parallel_produces_same_report_count(self, parallel_tasks): + """Parallel and sequential should produce same number of reports.""" + benchmark_seq = DummyBenchmark() + 
benchmark_par = DummyBenchmark() + + reports_seq = benchmark_seq.run(parallel_tasks, agent_data={"model": "test"}) + reports_par = benchmark_par.run(parallel_tasks, agent_data={"model": "test"}) + + assert len(reports_seq) == len(reports_par) + + def test_parallel_reports_have_correct_structure(self, parallel_tasks): + """Verify parallel reports have expected fields.""" + benchmark = DummyBenchmark(num_workers=3) + + reports = benchmark.run(parallel_tasks, agent_data={"model": "test"}) + + for report in reports: + assert "task_id" in report + assert "repeat_idx" in report + assert "status" in report + assert "traces" in report + assert "config" in report + assert "eval" in report + + def test_single_worker_uses_sequential(self, parallel_tasks): + """num_workers=1 should behave identically to sequential.""" + callback = OrderTrackingCallback() + benchmark = DummyBenchmark( + callbacks=[callback], + ) + + benchmark.run(parallel_tasks, agent_data={"model": "test"}) + + # Verify ordering is strictly sequential (task_start before all repeat_starts) + assert callback.invocations[0] == "run_start" + assert callback.invocations[1] == "task_start:Task 0" + assert callback.invocations[-1] == "run_end" + + def test_parallel_with_repetitions(self): + """Verify parallel execution with n_task_repeats > 1.""" + tasks = TaskQueue.from_list( + [ + {"query": "T1", "environment_data": {}}, + {"query": "T2", "environment_data": {}}, + ] + ) + benchmark = DummyBenchmark( + n_task_repeats=3, + ) + + reports = benchmark.run(tasks, agent_data={"model": "test"}) + + assert len(reports) == 6 # 2 tasks × 3 repeats + + # Verify repeat indices + repeat_indices = [r["repeat_idx"] for r in reports] + assert set(repeat_indices) == {0, 1, 2} + + +# ==================== Thread Safety Tests ==================== + + +@pytest.mark.core +class TestParallelThreadSafety: + """Tests for thread safety in parallel execution.""" + + def test_reports_all_collected(self, parallel_tasks): + """All reports should be collected regardless of completion order.""" + benchmark = SlowBenchmark( + delay_seconds=0.02, + ) + + reports = benchmark.run(parallel_tasks, agent_data={"model": "test"}) + + assert len(reports) == 5 + task_ids = {r["task_id"] for r in reports} + assert len(task_ids) == 5 + + def test_traces_not_cross_contaminated(self, parallel_tasks): + """Traces from one task should not appear in another's report.""" + benchmark = DummyBenchmark(num_workers=4) + + reports = benchmark.run(parallel_tasks, agent_data={"model": "test"}) + + for report in reports: + # Each report should have its own traces + assert report["traces"] is not None + assert "metadata" in report["traces"] + + def test_callbacks_receive_correct_data(self): + """Callbacks should receive correct task/report data in parallel.""" + tasks = TaskQueue.from_list([{"query": f"Query_{i}", "environment_data": {"idx": i}} for i in range(3)]) + + received_data = [] + lock = threading.Lock() + + class DataCapturingCallback(BenchmarkCallback): + def on_task_repeat_end(self, benchmark, report): + with lock: + received_data.append( + { + "task_id": report["task_id"], + "status": report["status"], + } + ) + + benchmark = DummyBenchmark( + callbacks=[DataCapturingCallback()], + ) + + benchmark.run(tasks, agent_data={"model": "test"}) + + assert len(received_data) == 3 + statuses = {d["status"] for d in received_data} + assert statuses == {"success"} + + def test_callback_exceptions_suppressed_by_default(self): + """Callback exceptions are suppressed by default to prevent 
disruption.""" + # Create fresh tasks for this test + tasks = TaskQueue.from_list([{"query": f"Task {i}", "environment_data": {}} for i in range(5)]) + + call_count = [0] + + class FailingCallback(BenchmarkCallback): + def on_task_repeat_end(self, benchmark, report): + call_count[0] += 1 + if call_count[0] == 2: + raise RuntimeError("Intentional failure") + + benchmark = DummyBenchmark( + callbacks=[FailingCallback()], + ) + + # New behavior: callback exceptions are suppressed by default + # This prevents one failing callback from disrupting parallel execution + reports = benchmark.run(tasks, agent_data={"model": "test"}) + + # Execution completes despite callback failure + assert len(reports) == 5 + # Callback was called multiple times (not stopped at failure) + assert call_count[0] >= 2 + + +# ==================== Concurrency Verification Tests ==================== + + +@pytest.mark.core +class TestParallelConcurrency: + """Tests verifying actual concurrent execution.""" + + def test_parallel_faster_than_sequential(self): + """Parallel execution should be faster for I/O-bound tasks.""" + tasks = TaskQueue.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(4)]) + delay = 0.05 + + # Sequential timing + benchmark_seq = SlowBenchmark(delay_seconds=delay) + start_seq = time.time() + benchmark_seq.run(tasks, agent_data={"model": "test"}) + time_seq = time.time() - start_seq + + # Parallel timing + benchmark_par = SlowBenchmark(delay_seconds=delay, num_workers=4) + start_par = time.time() + benchmark_par.run(tasks, agent_data={"model": "test"}) + time_par = time.time() - start_par + + # Parallel should be significantly faster (at least 2x) + assert time_par < time_seq * 0.7 + + def test_execution_overlaps(self): + """Task executions should overlap in parallel mode.""" + tasks = TaskQueue.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(3)]) + + benchmark = SlowBenchmark( + delay_seconds=0.05, + num_workers=3, + ) + + benchmark.run(tasks, agent_data={"model": "test"}) + + # Check for overlapping execution times + times = benchmark.execution_times + assert len(times) == 3 + + # At least one pair should overlap + overlaps = 0 + for i in range(len(times)): + for j in range(i + 1, len(times)): + _, start_i, end_i = times[i] + _, start_j, end_j = times[j] + # Check if intervals overlap + if start_i < end_j and start_j < end_i: + overlaps += 1 + + assert overlaps > 0, "Expected overlapping execution in parallel mode" + + +# ==================== Error Handling Tests ==================== + + +@pytest.mark.core +class TestParallelErrorHandling: + """Tests for error handling in parallel execution.""" + + def test_error_in_one_task_doesnt_stop_others(self): + """One task failing should not prevent other tasks from completing.""" + + class FailingBenchmark(DummyBenchmark): + def run_agents(self, agents, task, environment, query): + if "fail" in query.lower(): + raise RuntimeError("Intentional failure") + return super().run_agents(agents, task, environment, query) + + tasks = TaskQueue.from_list( + [ + {"query": "Normal 1", "environment_data": {}}, + {"query": "FAIL task", "environment_data": {}}, + {"query": "Normal 2", "environment_data": {}}, + ] + ) + + benchmark = FailingBenchmark() + reports = benchmark.run(tasks, agent_data={"model": "test"}) + + assert len(reports) == 3 + + statuses = {r["status"] for r in reports} + assert TaskExecutionStatus.SUCCESS.value in statuses + assert TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR.value in statuses + + def 
test_all_tasks_get_reports_even_with_failures(self): + """Every task should produce a report even if some fail.""" + + class HalfFailingBenchmark(DummyBenchmark): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._call_count = 0 + self._lock = threading.Lock() + + def run_agents(self, agents, task, environment, query): + with self._lock: + self._call_count += 1 + should_fail = self._call_count % 2 == 0 + + if should_fail: + raise ValueError("Every other task fails") + return super().run_agents(agents, task, environment, query) + + tasks = TaskQueue.from_list([{"query": f"T{i}", "environment_data": {}} for i in range(4)]) + + benchmark = HalfFailingBenchmark() + reports = benchmark.run(tasks, agent_data={"model": "test"}) + + assert len(reports) == 4 + + +# ==================== Queue Integration Tests ==================== + + +@pytest.mark.core +class TestParallelQueueIntegration: + """Tests for queue integration with parallel execution.""" + + def test_custom_queue_respected(self, parallel_tasks): + """Custom queue ordering should be respected.""" + from maseval.core.task import PriorityTaskQueue, TaskProtocol + + # Create tasks with priorities + prioritized_tasks = [ + Task( + query=f"P{p}", + environment_data={}, + protocol=TaskProtocol(priority=p), + ) + for p in [1, 5, 3, 2, 4] + ] + + queue = PriorityTaskQueue(prioritized_tasks) + + # Track execution order + execution_order = [] + lock = threading.Lock() + + class OrderTracker(BenchmarkCallback): + def on_task_repeat_start(self, benchmark, task, repeat_idx): + with lock: + execution_order.append(task.query) + + benchmark = DummyBenchmark( + callbacks=[OrderTracker()], + ) + + # With num_workers=1, order should be strictly by priority + benchmark.run(queue, agent_data={"model": "test"}) + + assert execution_order == ["P5", "P4", "P3", "P2", "P1"] diff --git a/tests/test_core/test_benchmark/test_progress_bar_integration.py b/tests/test_core/test_benchmark/test_progress_bar_integration.py index 138371b..ab6805e 100644 --- a/tests/test_core/test_benchmark/test_progress_bar_integration.py +++ b/tests/test_core/test_benchmark/test_progress_bar_integration.py @@ -15,7 +15,7 @@ from conftest import DummyBenchmark # noqa: E402 -from maseval.core.task import Task, TaskCollection # noqa: E402 +from maseval.core.task import Task, TaskQueue # noqa: E402 from maseval.core.callbacks.progress_bar import ( # noqa: E402 TqdmProgressBarCallback, RichProgressBarCallback, @@ -25,17 +25,17 @@ @pytest.mark.core def test_benchmark_with_default_progress_bar(): """Test that benchmark attaches tqdm progress bar by default.""" - tasks = TaskCollection([Task(query="What is 2+2?")]) + tasks = TaskQueue([Task(query="What is 2+2?")]) # Default should have progress bar - benchmark = DummyBenchmark(agent_data={"model": "test"}) + benchmark = DummyBenchmark() # Check that a progress bar callback was added progress_bars = [cb for cb in benchmark.callbacks if isinstance(cb, TqdmProgressBarCallback)] assert len(progress_bars) == 1 # Run and verify it works - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 assert reports[0]["status"] == "success" @@ -43,43 +43,42 @@ def test_benchmark_with_default_progress_bar(): @pytest.mark.core def test_benchmark_with_disabled_progress_bar(): """Test that progress bar can be disabled.""" - tasks = TaskCollection([Task(query="What is 2+2?")]) + tasks = TaskQueue([Task(query="What is 2+2?")]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, 
progress_bar=False) + benchmark = DummyBenchmark(progress_bar=False) # Should have no progress bar callbacks progress_bars = [cb for cb in benchmark.callbacks if isinstance(cb, (TqdmProgressBarCallback, RichProgressBarCallback))] assert len(progress_bars) == 0 - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 @pytest.mark.core def test_benchmark_with_rich_progress_bar(): """Test that rich progress bar can be specified.""" - tasks = TaskCollection([Task(query="What is 2+2?")]) + tasks = TaskQueue([Task(query="What is 2+2?")]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, progress_bar="rich") + benchmark = DummyBenchmark(progress_bar="rich") # Should have a rich progress bar progress_bars = [cb for cb in benchmark.callbacks if isinstance(cb, RichProgressBarCallback)] assert len(progress_bars) == 1 - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 @pytest.mark.core def test_benchmark_with_custom_progress_bar(): """Test that custom progress bar callback prevents default from being added.""" - tasks = TaskCollection([Task(query="What is 2+2?")]) + tasks = TaskQueue([Task(query="What is 2+2?")]) # User provides their own progress bar custom_pbar = TqdmProgressBarCallback(desc="Custom Progress") benchmark = DummyBenchmark( - agent_data={"model": "test"}, callbacks=[custom_pbar], progress_bar=True, # Should be ignored ) @@ -89,23 +88,23 @@ def test_benchmark_with_custom_progress_bar(): assert len(progress_bars) == 1 assert progress_bars[0].desc == "Custom Progress" - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 @pytest.mark.core def test_benchmark_with_multiple_tasks_and_repeats(): """Test progress bar with multiple tasks and repeats.""" - tasks = TaskCollection([Task(query=f"Task {i}") for i in range(3)]) + tasks = TaskQueue([Task(query=f"Task {i}") for i in range(3)]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, n_task_repeats=2, progress_bar=True) + benchmark = DummyBenchmark(n_task_repeats=2, progress_bar=True) # Get the progress bar callback progress_bars = [cb for cb in benchmark.callbacks if isinstance(cb, TqdmProgressBarCallback)] assert len(progress_bars) == 1 pbar = progress_bars[0] - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) # Should have 3 tasks * 2 repeats = 6 reports assert len(reports) == 6 @@ -119,4 +118,4 @@ def test_benchmark_with_multiple_tasks_and_repeats(): def test_invalid_progress_bar_value(): """Test that invalid progress bar value raises error.""" with pytest.raises(ValueError, match="Invalid progress_bar value"): - DummyBenchmark(agent_data={"model": "test"}, progress_bar="invalid") + DummyBenchmark(progress_bar="invalid") diff --git a/tests/test_core/test_benchmark/test_trace_collection.py b/tests/test_core/test_benchmark/test_trace_collection.py index 2b801a7..f1cbf3f 100644 --- a/tests/test_core/test_benchmark/test_trace_collection.py +++ b/tests/test_core/test_benchmark/test_trace_collection.py @@ -5,7 +5,7 @@ """ import pytest -from maseval import TaskCollection +from maseval import TaskQueue @pytest.mark.core @@ -16,10 +16,10 @@ def test_traces_collected_from_all_components(self): """Test that traces are collected from all registered components.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = 
DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Verify traces have the expected structure @@ -49,10 +49,10 @@ def test_traces_include_message_histories(self): """Test that agent traces include complete message histories.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test query", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test query", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Get agent trace @@ -89,11 +89,11 @@ def setup_agents(self, agent_data, environment, task, user): agent_adapter = FailingAgentAdapter(agent, "failing_agent") return [agent_adapter], {"failing_agent": agent_adapter} # type: ignore[return-value] - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = TestBenchmark() # Should complete without raising, with error info in traces - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Verify trace collection handled the error @@ -130,10 +130,10 @@ def setup_agents(self, agent_data, environment, task, user): self.register("models", "test_model", model) return [agent_adapter], {"test_agent": agent_adapter} # type: ignore[return-value] - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = TestBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Verify model traces @@ -151,10 +151,10 @@ def test_environment_traces_tool_invocations(self): """Test that Environment traces include tool information.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Verify environment traces include tools @@ -217,13 +217,13 @@ def gather_traces(self): } callback = CustomCallback() - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}, callbacks=[callback]) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark(callbacks=[callback]) # Register callback for tracing benchmark.register("callbacks", "custom_callback", callback) - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Verify callback traces @@ -238,10 +238,10 @@ def test_traces_json_serializable(self): import json from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", 
"environment_data": {"key": "value"}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {"key": "value"}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Should be able to serialize to JSON @@ -259,10 +259,10 @@ def test_traces_contain_timestamps(self): """Test that all trace components include timestamps.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Check metadata timestamp @@ -280,10 +280,10 @@ def test_traces_include_component_types(self): """Test that all traces include component type information.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) traces = reports[0]["traces"] # Check agent type diff --git a/tests/test_core/test_context.py b/tests/test_core/test_context.py new file mode 100644 index 0000000..2f39170 --- /dev/null +++ b/tests/test_core/test_context.py @@ -0,0 +1,137 @@ +"""Tests for TaskContext timeout handling. + +These tests verify that TaskContext correctly tracks time, checks deadlines, +and raises TaskTimeoutError when appropriate. 
+""" + +import pytest +import time + +from maseval.core.context import TaskContext +from maseval.core.exceptions import TaskTimeoutError + + +@pytest.mark.core +class TestTaskContextBasics: + """Tests for basic TaskContext functionality.""" + + def test_context_no_timeout_never_expires(self): + """Context without deadline should never expire.""" + context = TaskContext(deadline=None) + + assert context.deadline is None + assert context.remaining is None + assert context.is_expired is False + + # check_timeout should be no-op + context.check_timeout() # Should not raise + + def test_context_with_timeout_not_expired(self): + """Context before deadline should show remaining time.""" + context = TaskContext(deadline=10.0) + + assert context.deadline == 10.0 + assert context.is_expired is False + assert context.remaining is not None and context.remaining > 9.0 + assert context.elapsed < 1.0 # Just started + + def test_context_elapsed_increases(self): + """Elapsed property should increase over time.""" + context = TaskContext(deadline=10.0) + + elapsed1 = context.elapsed + time.sleep(0.05) + elapsed2 = context.elapsed + + assert elapsed2 > elapsed1 + + def test_context_remaining_decreases(self): + """Remaining property should decrease over time.""" + context = TaskContext(deadline=10.0) + + remaining1 = context.remaining + time.sleep(0.05) + remaining2 = context.remaining + + assert remaining1 is not None and remaining2 is not None + assert remaining2 < remaining1 + + +@pytest.mark.core +class TestTaskContextTimeout: + """Tests for TaskContext timeout behavior.""" + + def test_context_is_expired_after_deadline(self): + """Context after deadline should show is_expired=True.""" + context = TaskContext(deadline=0.01) # Very short deadline + + time.sleep(0.02) # Wait past deadline + + assert context.is_expired is True + assert context.remaining == 0.0 + + def test_check_timeout_raises_on_expiry(self): + """check_timeout() should raise TaskTimeoutError when expired.""" + context = TaskContext(deadline=0.01) + + time.sleep(0.02) + + with pytest.raises(TaskTimeoutError) as exc_info: + context.check_timeout() + + error: TaskTimeoutError = exc_info.value # type: ignore[assignment] + assert error.timeout == 0.01 + assert error.elapsed >= 0.01 + + def test_check_timeout_includes_partial_traces(self): + """TaskTimeoutError should include partial traces if set.""" + context = TaskContext(deadline=0.01) + partial_traces = {"agents": {"agent1": {"steps": 3}}} + context.set_collected_traces(partial_traces) + + time.sleep(0.02) + + with pytest.raises(TaskTimeoutError) as exc_info: + context.check_timeout() + + error: TaskTimeoutError = exc_info.value # type: ignore[assignment] + assert error.partial_traces == partial_traces + + def test_check_timeout_no_traces_if_not_set(self): + """TaskTimeoutError should have empty traces if not set.""" + context = TaskContext(deadline=0.01) + + time.sleep(0.02) + + with pytest.raises(TaskTimeoutError) as exc_info: + context.check_timeout() + + error: TaskTimeoutError = exc_info.value # type: ignore[assignment] + assert error.partial_traces == {} + + def test_check_timeout_does_not_raise_before_deadline(self): + """check_timeout() should not raise before deadline.""" + context = TaskContext(deadline=10.0) + + # Should not raise + for _ in range(10): + context.check_timeout() + + def test_set_collected_traces_updates_context(self): + """set_collected_traces should store traces for later use.""" + context = TaskContext(deadline=10.0) + + traces = {"test": "data"} + 
context.set_collected_traces(traces) + + assert context.collected_traces == traces + + def test_context_timing_accuracy(self): + """Elapsed time should be reasonably accurate.""" + context = TaskContext(deadline=10.0) + + time.sleep(0.1) + elapsed = context.elapsed + + # Allow 50ms tolerance + assert 0.05 < elapsed < 0.2 diff --git a/tests/test_core/test_evaluator.py b/tests/test_core/test_evaluator.py index 5bf2d6a..c81acc0 100644 --- a/tests/test_core/test_evaluator.py +++ b/tests/test_core/test_evaluator.py @@ -4,7 +4,7 @@ """ import pytest -from maseval import TaskCollection +from maseval import TaskQueue @pytest.mark.core @@ -35,10 +35,10 @@ class TestBenchmark(DummyBenchmark): def setup_evaluators(self, environment, task, agents, user): return [TracingEvaluator(task, environment, user)] - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = TestBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert len(received_traces) == 1 assert isinstance(received_traces[0], dict) @@ -57,10 +57,10 @@ def evaluate(self, evaluators, agents, final_answer, traces): assert "test_agent" in agents return super().evaluate(evaluators, agents, final_answer, traces) - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = TestBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) def test_evaluator_receives_final_answer(self): """Test that evaluate() receives the final answer from agents.""" @@ -73,10 +73,10 @@ def evaluate(self, evaluators, agents, final_answer, traces): received_answers.append(final_answer) return super().evaluate(evaluators, agents, final_answer, traces) - tasks = TaskCollection.from_list([{"query": "My test query", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "My test query", "environment_data": {}}]) + benchmark = TestBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert len(received_answers) == 1 assert "Response to: My test query" in received_answers[0] @@ -92,10 +92,10 @@ def evaluate(self, evaluators, agents, final_answer, traces): received_traces.append(traces) return super().evaluate(evaluators, agents, final_answer, traces) - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = TestBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert len(received_traces) == 1 traces = received_traces[0] @@ -138,10 +138,10 @@ def setup_evaluators(self, environment, task, agents, user): Evaluator2(task, environment, user), ] - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = TestBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = TestBenchmark() - benchmark.run(tasks) + benchmark.run(tasks, agent_data={"model": "test"}) assert call_counts["eval1"] == 1 assert call_counts["eval2"] == 1 @@ -150,10 +150,10 @@ def 
test_evaluator_results_in_report(self): """Test that evaluator results appear in the final report.""" from conftest import DummyBenchmark - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DummyBenchmark(agent_data={"model": "test"}) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DummyBenchmark() - reports = benchmark.run(tasks) + reports = benchmark.run(tasks, agent_data={"model": "test"}) assert len(reports) == 1 report = reports[0] diff --git a/tests/test_core/test_exceptions.py b/tests/test_core/test_exceptions.py index 0f66eea..415ecd0 100644 --- a/tests/test_core/test_exceptions.py +++ b/tests/test_core/test_exceptions.py @@ -8,7 +8,7 @@ import pytest from maseval import ( - TaskCollection, + TaskQueue, TaskExecutionStatus, AgentError, EnvironmentError, @@ -41,9 +41,9 @@ def setup_agents(self, agent_data, environment, task, user): adapter = AgentErrorAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = AgentErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = AgentErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) assert len(reports) == 1 assert reports[0]["status"] == TaskExecutionStatus.AGENT_ERROR.value @@ -68,9 +68,9 @@ def setup_agents(self, agent_data, environment, task, user): adapter = EnvironmentErrorAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = EnvironmentErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = EnvironmentErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) assert len(reports) == 1 assert reports[0]["status"] == TaskExecutionStatus.ENVIRONMENT_ERROR.value @@ -95,9 +95,9 @@ def setup_agents(self, agent_data, environment, task, user): adapter = UserErrorAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = UserErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = UserErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) assert len(reports) == 1 assert reports[0]["status"] == TaskExecutionStatus.USER_ERROR.value @@ -122,9 +122,9 @@ def setup_agents(self, agent_data, environment, task, user): adapter = GenericErrorAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = GenericErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + tasks = TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = GenericErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) assert len(reports) == 1 assert reports[0]["status"] == TaskExecutionStatus.UNKNOWN_EXECUTION_ERROR.value @@ -152,9 +152,9 @@ def setup_agents(self, agent_data, environment, task, user): adapter = DetailedAgentErrorAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list([{"query": "Test", "environment_data": {}}]) - benchmark = DetailedAgentErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + tasks = 
TaskQueue.from_list([{"query": "Test", "environment_data": {}}]) + benchmark = DetailedAgentErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) assert len(reports) == 1 error = reports[0]["error"] @@ -163,6 +163,61 @@ def setup_agents(self, agent_data, environment, task, user): assert error["details"]["actual"] == "str" +@pytest.mark.core +class TestTaskTimeoutError: + """Tests for TaskTimeoutError exception.""" + + def test_timeout_error_attributes(self): + """TaskTimeoutError should have elapsed, timeout, partial_traces attributes.""" + from maseval import TaskTimeoutError + + error = TaskTimeoutError( + "Task exceeded 60s deadline", + component="execution_loop", + elapsed=62.5, + timeout=60.0, + partial_traces={"agents": {"agent1": {"steps": 3}}}, + ) + + assert error.elapsed == 62.5 + assert error.timeout == 60.0 + assert error.partial_traces == {"agents": {"agent1": {"steps": 3}}} + + def test_timeout_error_message(self): + """TaskTimeoutError message should include timing info.""" + from maseval import TaskTimeoutError + + error = TaskTimeoutError( + "Task exceeded 60s deadline after 62.5s", + component="timeout_check", + elapsed=62.5, + timeout=60.0, + ) + + assert "60s" in str(error) + assert "62.5s" in str(error) + + def test_timeout_error_inherits_from_maseval_error(self): + """TaskTimeoutError should inherit from MASEvalError.""" + from maseval import TaskTimeoutError + from maseval.core.exceptions import MASEvalError + + error = TaskTimeoutError("timeout", elapsed=1.0, timeout=0.5) + + assert isinstance(error, MASEvalError) + assert isinstance(error, Exception) + + def test_timeout_error_defaults(self): + """TaskTimeoutError should have sensible defaults.""" + from maseval import TaskTimeoutError + + error = TaskTimeoutError("timeout") + + assert error.elapsed == 0.0 + assert error.timeout == 0.0 + assert error.partial_traces == {} + + class TestAgentErrorSuggestion: """Tests for AgentError suggestion field.""" @@ -346,7 +401,7 @@ def run(self, query: str) -> str: adapter = DummyAgentAdapter(agent, "agent") return [adapter], {"agent": adapter} - tasks = TaskCollection.from_list( + tasks = TaskQueue.from_list( [ {"query": "Task 1", "environment_data": {}}, {"query": "Task 2", "environment_data": {}}, @@ -354,8 +409,8 @@ def run(self, query: str) -> str: ] ) - benchmark = MixedErrorBenchmark(agent_data={}) - reports = benchmark.run(tasks) + benchmark = MixedErrorBenchmark() + reports = benchmark.run(tasks, agent_data={}) # Should have 1 success, 1 agent error, 1 env error statuses = [r["status"] for r in reports] diff --git a/tests/test_core/test_queue.py b/tests/test_core/test_queue.py new file mode 100644 index 0000000..5c6011b --- /dev/null +++ b/tests/test_core/test_queue.py @@ -0,0 +1,375 @@ +"""Tests for TaskQueue implementations. + +These tests verify that SequentialTaskQueue, PriorityTaskQueue, and AdaptiveTaskQueue +correctly order and iterate over tasks. 
+""" + +import pytest +from typing import Any, Dict, List, Optional + +from maseval import Task +from maseval.core.task import ( + TaskProtocol, + SequentialTaskQueue, + PriorityTaskQueue, + AdaptiveTaskQueue, + TaskQueue, + BaseTaskQueue, +) + + +# ==================== Fixtures ==================== + + +@pytest.fixture +def tasks_with_priorities() -> List[Task]: + """Create tasks with different priorities.""" + tasks = [] + for i, priority in enumerate([0, 5, 2, 8, 1]): + task = Task( + query=f"Query {i}", + environment_data={"index": i}, + protocol=TaskProtocol(priority=priority), + ) + tasks.append(task) + return tasks + + +@pytest.fixture +def simple_tasks() -> List[Task]: + """Simple task list for basic tests.""" + return [ + Task(query="Q1", environment_data={}), + Task(query="Q2", environment_data={}), + Task(query="Q3", environment_data={}), + ] + + +# ==================== BaseTaskQueue Tests ==================== + + +@pytest.mark.core +class TestBaseTaskQueue: + """Tests for BaseTaskQueue common functionality.""" + + def test_taskqueue_is_alias_for_sequential(self): + """TaskQueue should be an alias for SequentialTaskQueue.""" + assert TaskQueue is SequentialTaskQueue + + def test_sequence_protocol(self, simple_tasks): + """Queue should implement Sequence protocol.""" + queue = SequentialTaskQueue(simple_tasks) + + # __len__ + assert len(queue) == 3 + + # __getitem__ with int + assert queue[0].query == "Q1" + assert queue[1].query == "Q2" + assert queue[-1].query == "Q3" + + # __getitem__ with slice + sliced = queue[1:] + assert isinstance(sliced, BaseTaskQueue) + assert len(sliced) == 2 + + def test_append_and_extend(self, simple_tasks): + """Queue should support append and extend.""" + queue = SequentialTaskQueue(simple_tasks[:2]) + assert len(queue) == 2 + + queue.append(simple_tasks[2]) + assert len(queue) == 3 + + queue.extend([Task(query="Q4"), Task(query="Q5")]) + assert len(queue) == 5 + + def test_to_list(self, simple_tasks): + """to_list() should return a copy of internal list.""" + queue = SequentialTaskQueue(simple_tasks) + + result = queue.to_list() + + assert result == simple_tasks + assert result is not queue._tasks # Should be a copy + + def test_from_list_with_tasks(self, simple_tasks): + """from_list should accept Task objects.""" + queue = SequentialTaskQueue.from_list(simple_tasks) + + assert len(queue) == 3 + assert queue[0].query == "Q1" + + def test_from_list_with_dicts(self): + """from_list should accept dicts and convert to Tasks.""" + data = [ + {"query": "Dict 1"}, + {"query": "Dict 2", "environment_data": {"key": "value"}}, + ] + queue = SequentialTaskQueue.from_list(data) + + assert len(queue) == 2 + assert queue[0].query == "Dict 1" + assert queue[1].environment_data == {"key": "value"} + + def test_from_list_type_error(self): + """from_list should raise TypeError for invalid items.""" + with pytest.raises(TypeError, match="expects Task or dict"): + SequentialTaskQueue.from_list(["not a task"]) # type: ignore[arg-type] # intentional + + def test_from_json_file(self, tmp_path): + """from_json_file should load tasks from JSON file.""" + import json + + data = { + "data": [ + {"query": "Task 1", "environment_data": {}}, + {"query": "Task 2", "environment_data": {}}, + ] + } + + file_path = tmp_path / "tasks.json" + with open(file_path, "w") as f: + json.dump(data, f) + + queue = SequentialTaskQueue.from_json_file(file_path) + + assert len(queue) == 2 + assert queue[0].query == "Task 1" + assert queue[1].query == "Task 2" + + def 
test_from_json_file_with_limit(self, tmp_path): + """from_json_file should respect limit parameter.""" + import json + + data = {"data": [{"query": f"Task {i}"} for i in range(10)]} + + file_path = tmp_path / "tasks.json" + with open(file_path, "w") as f: + json.dump(data, f) + + queue = SequentialTaskQueue.from_json_file(file_path, limit=5) + + assert len(queue) == 5 + assert queue[4].query == "Task 4" + + def test_from_list_field_mapping(self): + """from_list should map alternative field names.""" + # Test question -> query mapping and short_answer -> evaluation_data + queue = SequentialTaskQueue.from_list([{"question": "What is 2+2?", "short_answer": "4"}]) + + task = queue[0] + assert task.query == "What is 2+2?" + assert task.evaluation_data == {"short_answer": "4"} + + def test_repr(self, simple_tasks): + """Queue should have informative repr.""" + queue = SequentialTaskQueue(simple_tasks) + + repr_str = repr(queue) + # Should mention queue type and task count + assert "SequentialTaskQueue" in repr_str or "TaskQueue" in repr_str or "3" in repr_str + + +# ==================== SequentialTaskQueue Tests ==================== + + +@pytest.mark.core +class TestSequentialTaskQueue: + """Tests for SequentialTaskQueue ordering.""" + + def test_order_preserved(self, simple_tasks): + """Tasks should be yielded in original order.""" + queue = SequentialTaskQueue(simple_tasks) + + queries = [task.query for task in queue] + + assert queries == ["Q1", "Q2", "Q3"] + + def test_all_tasks_yielded(self, simple_tasks): + """All tasks should be yielded exactly once.""" + queue = SequentialTaskQueue(simple_tasks) + + count = sum(1 for _ in queue) + + assert count == 3 + + def test_empty_collection(self): + """Empty collection should yield nothing.""" + queue = SequentialTaskQueue([]) + + items = list(queue) + + assert items == [] + + def test_single_task(self): + """Single task should be handled correctly.""" + queue = SequentialTaskQueue([Task(query="Only one")]) + + items = list(queue) + + assert len(items) == 1 + assert items[0].query == "Only one" + + +# ==================== PriorityTaskQueue Tests ==================== + + +@pytest.mark.core +class TestPriorityTaskQueue: + """Tests for PriorityTaskQueue priority ordering.""" + + def test_high_priority_first(self, tasks_with_priorities): + """Higher priority tasks should come first (default).""" + queue = PriorityTaskQueue(tasks_with_priorities) + + priorities = [task.protocol.priority for task in queue] + + assert priorities == [8, 5, 2, 1, 0] + + def test_low_priority_first_with_reverse_false(self, tasks_with_priorities): + """Lower priority tasks should come first when reverse=False.""" + queue = PriorityTaskQueue(tasks_with_priorities, reverse=False) + + priorities = [task.protocol.priority for task in queue] + + assert priorities == [0, 1, 2, 5, 8] + + def test_stable_sort_for_equal_priorities(self): + """Tasks with equal priority should maintain original order.""" + tasks = [ + Task(query="First", environment_data={}, protocol=TaskProtocol(priority=5)), + Task(query="Second", environment_data={}, protocol=TaskProtocol(priority=5)), + Task(query="Third", environment_data={}, protocol=TaskProtocol(priority=5)), + ] + queue = PriorityTaskQueue(tasks) + + queries = [task.query for task in queue] + + # Python's sort is stable, so original order should be preserved + assert queries == ["First", "Second", "Third"] + + def test_default_priority_zero(self, simple_tasks): + """Tasks without explicit priority should have priority 0.""" + queue = 
PriorityTaskQueue(simple_tasks) + + for task in queue: + assert task.protocol.priority == 0 + + def test_negative_priority(self): + """Negative priorities should be handled correctly.""" + tasks = [ + Task(query="Low", environment_data={}, protocol=TaskProtocol(priority=-5)), + Task(query="Normal", environment_data={}, protocol=TaskProtocol(priority=0)), + Task(query="High", environment_data={}, protocol=TaskProtocol(priority=5)), + ] + queue = PriorityTaskQueue(tasks) + + queries = [task.query for task in queue] + + assert queries == ["High", "Normal", "Low"] + + +# ==================== AdaptiveTaskQueue Tests ==================== + + +class ConcreteAdaptiveQueue(AdaptiveTaskQueue): + """Concrete implementation of AdaptiveTaskQueue for testing.""" + + def __init__(self, tasks): + super().__init__(tasks) + self._selection_order: List[int] = [] # Track selection indices + + def _select_next_task(self) -> Optional[Task]: + """Select tasks in order (simple FIFO).""" + if not self._remaining: + return None + return self._remaining[0] + + def _update_state(self, task: Task, report: Dict[str, Any]) -> None: + """Track update calls.""" + pass + + +@pytest.mark.core +class TestAdaptiveTaskQueue: + """Tests for AdaptiveTaskQueue adaptive behavior.""" + + def test_basic_iteration_with_completion(self, simple_tasks): + """AdaptiveTaskQueue should yield all tasks when on_task_repeat_end is called.""" + queue = ConcreteAdaptiveQueue(simple_tasks) + + count = 0 + for task in queue: + count += 1 + # Simulate callback from benchmark + queue.on_task_repeat_end(None, {"task_id": str(task.id), "status": "success"}) # type: ignore[arg-type] + + assert count == 3 + + def test_on_task_repeat_end_moves_to_completed(self, simple_tasks): + """on_task_repeat_end should move task to completed list.""" + queue = ConcreteAdaptiveQueue(simple_tasks) + task = next(iter(queue)) + + assert len(queue._completed) == 0 + + queue.on_task_repeat_end(None, {"task_id": str(task.id), "status": "success"}) # type: ignore[arg-type] + + assert len(queue._completed) == 1 + assert queue._completed[0][0].id == task.id + + def test_stop_terminates_iteration(self, simple_tasks): + """Calling stop() should end iteration early.""" + queue = ConcreteAdaptiveQueue(simple_tasks) + + items = [] + for task in queue: + items.append(task) + queue.stop() # Stop immediately after first yield + + assert len(items) == 1 + + def test_stop_sets_flag(self, simple_tasks): + """stop() should set the internal stop flag.""" + queue = ConcreteAdaptiveQueue(simple_tasks) + + assert queue._stop_flag is False + + queue.stop() + + assert queue._stop_flag is True + + def test_iterator_stops_when_empty(self): + """Iterator should stop when no pending tasks.""" + queue = ConcreteAdaptiveQueue([]) + + tasks_yielded = list(queue) + assert len(tasks_yielded) == 0 + + def test_remaining_decreases_after_completion(self, simple_tasks): + """Remaining list should shrink as tasks complete.""" + queue = ConcreteAdaptiveQueue(simple_tasks) + + assert len(queue._remaining) == 3 + + task = next(iter(queue)) + queue.on_task_repeat_end(None, {"task_id": str(task.id), "status": "success"}) # type: ignore[arg-type] + + assert len(queue._remaining) == 2 + assert len(queue._completed) == 1 + + +# ==================== Queue Callback Tests ==================== + + +@pytest.mark.core +class TestQueueCallbacks: + """Tests for queue callback mechanisms.""" + + def test_sequential_queue_iterates_all_tasks(self, simple_tasks): + """SequentialTaskQueue should iterate through all tasks.""" + 
queue = SequentialTaskQueue(simple_tasks) + + tasks_yielded = list(queue) + assert len(tasks_yielded) == len(simple_tasks) diff --git a/tests/test_core/test_registry.py b/tests/test_core/test_registry.py new file mode 100644 index 0000000..5940877 --- /dev/null +++ b/tests/test_core/test_registry.py @@ -0,0 +1,258 @@ +"""Tests for ComponentRegistry thread safety and functionality. + +These tests verify that ComponentRegistry correctly isolates state between +threads and provides proper trace/config collection. +""" + +import pytest +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, Optional + +from maseval.core.registry import ComponentRegistry +from maseval.core.tracing import TraceableMixin +from maseval.core.config import ConfigurableMixin + + +# ==================== Test Components ==================== + + +class MockTraceableComponent(TraceableMixin): + """Component that implements TraceableMixin for testing.""" + + def __init__(self, name: str, trace_data: Optional[Dict[str, Any]] = None): + super().__init__() + self._name = name + self._trace_data = trace_data or {"component": name} + + def gather_traces(self) -> Dict[str, Any]: + return { + "name": self._name, + **self._trace_data, + } + + +class MockConfigurableComponent(TraceableMixin, ConfigurableMixin): + """Component that implements both TraceableMixin and ConfigurableMixin.""" + + def __init__(self, name: str, config: Optional[Dict[str, Any]] = None): + TraceableMixin.__init__(self) + ConfigurableMixin.__init__(self) + self._name = name + self._config = config or {"setting": "default"} + + def gather_traces(self) -> Dict[str, Any]: + return {"name": self._name, "traced": True} + + def gather_config(self) -> Dict[str, Any]: + return {"name": self._name, **self._config} + + +# ==================== Basic Functionality Tests ==================== + + +@pytest.mark.core +class TestComponentRegistryBasics: + """Tests for basic ComponentRegistry functionality.""" + + def test_register_traceable_component(self): + """Verify component registered for tracing.""" + registry = ComponentRegistry() + component = MockTraceableComponent("test") + + result = registry.register("agents", "my_agent", component) + + assert result is component + assert "agents:my_agent" in registry._trace_registry + assert registry._trace_registry["agents:my_agent"] is component + + def test_register_configurable_component(self): + """Verify configurable component registered in both registries.""" + registry = ComponentRegistry() + component = MockConfigurableComponent("test") + + registry.register("models", "my_model", component) + + assert "models:my_model" in registry._trace_registry + assert "models:my_model" in registry._config_registry + + def test_duplicate_key_idempotent(self): + """Same component, same key should be idempotent.""" + registry = ComponentRegistry() + component = MockTraceableComponent("test") + + registry.register("agents", "agent1", component) + registry.register("agents", "agent1", component) # Same key, no error + + assert len(registry._trace_registry) == 1 + + def test_duplicate_component_different_key_raises(self): + """Same component with different key should raise ValueError.""" + registry = ComponentRegistry() + component = MockTraceableComponent("test") + + registry.register("agents", "name1", component) + + with pytest.raises(ValueError) as exc_info: + registry.register("agents", "name2", component) + + assert "already registered" in str(exc_info.value) + assert 
"agents:name1" in str(exc_info.value) + + def test_clear_removes_all_registrations(self): + """Clear should remove all registrations.""" + registry = ComponentRegistry() + registry.register("agents", "a1", MockTraceableComponent("a1")) + registry.register("models", "m1", MockConfigurableComponent("m1")) + + registry.clear() + + assert len(registry._trace_registry) == 0 + assert len(registry._config_registry) == 0 + assert len(registry._component_id_map) == 0 + + def test_collect_traces_structure(self): + """Verify trace output has expected structure.""" + registry = ComponentRegistry() + agent = MockTraceableComponent("agent1", {"steps": 5}) + registry.register("agents", "agent1", agent) + + traces = registry.collect_traces() + + assert "metadata" in traces + assert "agents" in traces + assert "agent1" in traces["agents"] + assert traces["agents"]["agent1"]["steps"] == 5 + + def test_collect_configs_structure(self): + """Verify config output has expected structure.""" + registry = ComponentRegistry(benchmark_config={"name": "test_benchmark"}) + model = MockConfigurableComponent("model1", {"temperature": 0.7}) + registry.register("models", "model1", model) + + configs = registry.collect_configs() + + assert "metadata" in configs + assert "benchmark" in configs + assert configs["benchmark"]["name"] == "test_benchmark" + assert "models" in configs + assert "model1" in configs["models"] + + +# ==================== Thread Safety Tests ==================== + + +@pytest.mark.core +class TestComponentRegistryThreadSafety: + """Tests for ComponentRegistry thread isolation.""" + + def test_registry_thread_isolation(self): + """Verify registrations in one thread don't appear in another.""" + registry = ComponentRegistry() + results = {} + barrier = threading.Barrier(2) + + def worker(worker_id: int): + barrier.wait() # Synchronize start + + # Register unique component + component = MockTraceableComponent(f"comp_{worker_id}") + registry.register("agents", f"agent_{worker_id}", component) + + # Record what this thread sees + results[worker_id] = list(registry._trace_registry.keys()) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Each thread should only see its own component + assert results[0] == ["agents:agent_0"] + assert results[1] == ["agents:agent_1"] + + def test_clear_only_affects_current_thread(self): + """Clearing in one thread shouldn't affect another.""" + registry = ComponentRegistry() + thread1_sees_after_clear = [] + thread2_sees_after_clear = [] + barrier = threading.Barrier(2) + sync_point = threading.Barrier(2) + + def thread1_worker(): + registry.register("agents", "t1_agent", MockTraceableComponent("t1")) + barrier.wait() # Both threads have registered + sync_point.wait() # Wait for thread 2 to check + + registry.clear() + thread1_sees_after_clear.extend(list(registry._trace_registry.keys())) + + def thread2_worker(): + registry.register("agents", "t2_agent", MockTraceableComponent("t2")) + barrier.wait() # Both threads have registered + sync_point.wait() # Sync before thread 1 clears + + # Wait a bit for thread 1 to clear + time.sleep(0.05) + thread2_sees_after_clear.extend(list(registry._trace_registry.keys())) + + t1 = threading.Thread(target=thread1_worker) + t2 = threading.Thread(target=thread2_worker) + t1.start() + t2.start() + t1.join() + t2.join() + + # Thread 1 cleared its own registry + assert thread1_sees_after_clear == [] + # Thread 2 still has its component + assert 
thread2_sees_after_clear == ["agents:t2_agent"] + + def test_concurrent_registration_no_race(self): + """Multiple threads registering simultaneously without races.""" + registry = ComponentRegistry() + errors = [] + barrier = threading.Barrier(4) + + def worker(worker_id: int): + try: + barrier.wait() + for i in range(10): + component = MockTraceableComponent(f"comp_{worker_id}_{i}") + registry.register("agents", f"agent_{worker_id}_{i}", component) + except Exception as e: + errors.append((worker_id, e)) + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(worker, i) for i in range(4)] + for f in futures: + f.result() + + assert len(errors) == 0, f"Errors occurred: {errors}" + + def test_concurrent_collect_traces(self): + """Multiple threads collecting traces simultaneously.""" + registry = ComponentRegistry() + results = {} + barrier = threading.Barrier(4) + + def worker(worker_id: int): + # Each thread registers its own component + registry.register("agents", f"agent_{worker_id}", MockTraceableComponent(f"agent_{worker_id}")) + barrier.wait() + + # All threads collect simultaneously + traces = registry.collect_traces() + results[worker_id] = traces + + with ThreadPoolExecutor(max_workers=4) as executor: + futures = [executor.submit(worker, i) for i in range(4)] + for f in futures: + f.result() + + # Each thread should see only its own agent + for worker_id, traces in results.items(): + assert f"agent_{worker_id}" in traces["agents"] + assert len(traces["agents"]) == 1 diff --git a/tests/test_core/test_task_collection.py b/tests/test_core/test_task_collection.py deleted file mode 100644 index 3ffc9ce..0000000 --- a/tests/test_core/test_task_collection.py +++ /dev/null @@ -1,191 +0,0 @@ -"""Test TaskCollection functionality. - -These tests verify that TaskCollection behaves like a sequence and correctly -loads tasks from various formats. 
-""" - -import pytest -from maseval import Task, TaskCollection -from pathlib import Path -import json -import tempfile - - -@pytest.mark.core -class TestTaskCollection: - """Tests for TaskCollection interface and factories.""" - - def test_task_collection_from_list(self): - """Test creating TaskCollection from a list of dicts.""" - data = [ - {"query": "Q1", "environment_data": {"e": 1}}, - {"query": "Q2", "environment_data": {"e": 2}}, - ] - - collection = TaskCollection.from_list(data) - - assert len(collection) == 2 - assert collection[0].query == "Q1" - assert collection[1].query == "Q2" - - def test_task_collection_from_json_file(self): - """Test loading TaskCollection from JSON file.""" - with tempfile.TemporaryDirectory() as tmpdir: - # Create test JSON file - data = { - "data": [ - {"query": "Task 1", "environment_data": {}}, - {"query": "Task 2", "environment_data": {}}, - ] - } - - file_path = Path(tmpdir) / "tasks.json" - with open(file_path, "w") as f: - json.dump(data, f) - - # Load from file - collection = TaskCollection.from_json_file(file_path) - - assert len(collection) == 2 - assert collection[0].query == "Task 1" - assert collection[1].query == "Task 2" - - def test_task_collection_sequence_interface(self): - """Test that TaskCollection implements Sequence interface.""" - collection = TaskCollection.from_list( - [ - {"query": "Q1"}, - {"query": "Q2"}, - {"query": "Q3"}, - ] - ) - - # Test len - assert len(collection) == 3 - - # Test iteration - queries = [task.query for task in collection] - assert queries == ["Q1", "Q2", "Q3"] - - # Test indexing - assert collection[0].query == "Q1" - assert collection[-1].query == "Q3" - - def test_task_collection_slicing(self): - """Test that TaskCollection supports slicing.""" - collection = TaskCollection.from_list([{"query": f"Q{i}"} for i in range(10)]) - - # Test slice - subset = collection[2:5] - assert isinstance(subset, TaskCollection) - assert len(subset) == 3 - assert subset[0].query == "Q2" - assert subset[2].query == "Q4" - - # Test slice from start - start = collection[:3] - assert len(start) == 3 - - # Test slice to end - end = collection[7:] - assert len(end) == 3 - - def test_task_collection_iteration(self): - """Test iterating over TaskCollection.""" - data = [{"query": f"Q{i}"} for i in range(5)] - collection = TaskCollection.from_list(data) - - queries = [] - for task in collection: - queries.append(task.query) - - assert queries == ["Q0", "Q1", "Q2", "Q3", "Q4"] - - def test_task_dict_conversion(self): - """Test that dict items are converted to Task objects.""" - collection = TaskCollection.from_list( - [ - { - "query": "Test", - "environment_data": {"key": "value"}, - "evaluation_data": {"expected": "result"}, - "metadata": {"difficulty": "easy"}, - } - ] - ) - - task = collection[0] - assert isinstance(task, Task) - assert task.query == "Test" - assert task.environment_data == {"key": "value"} - assert task.evaluation_data == {"expected": "result"} - assert task.metadata == {"difficulty": "easy"} - - def test_task_field_mapping(self): - """Test that alternative field names are mapped correctly.""" - # Test question -> query mapping - collection = TaskCollection.from_list([{"question": "What is 2+2?", "short_answer": "4"}]) - - task = collection[0] - assert task.query == "What is 2+2?" 
- assert task.evaluation_data == {"short_answer": "4"} - - def test_task_collection_append(self): - """Test appending tasks to collection.""" - collection = TaskCollection() - assert len(collection) == 0 - - task = Task(query="Test") - collection.append(task) - - assert len(collection) == 1 - assert collection[0].query == "Test" - - def test_task_collection_extend(self): - """Test extending collection with multiple tasks.""" - collection = TaskCollection() - - new_tasks = [ - Task(query="Q1"), - Task(query="Q2"), - Task(query="Q3"), - ] - - collection.extend(new_tasks) - - assert len(collection) == 3 - assert collection[2].query == "Q3" - - def test_task_collection_to_list(self): - """Test converting TaskCollection to list.""" - data = [{"query": f"Q{i}"} for i in range(3)] - collection = TaskCollection.from_list(data) - - task_list = collection.to_list() - - assert isinstance(task_list, list) - assert len(task_list) == 3 - assert all(isinstance(t, Task) for t in task_list) - - def test_task_collection_from_json_with_limit(self): - """Test loading with a limit on number of tasks.""" - with tempfile.TemporaryDirectory() as tmpdir: - data = {"data": [{"query": f"Task {i}"} for i in range(10)]} - - file_path = Path(tmpdir) / "tasks.json" - with open(file_path, "w") as f: - json.dump(data, f) - - # Load only first 5 - collection = TaskCollection.from_json_file(file_path, limit=5) - - assert len(collection) == 5 - assert collection[4].query == "Task 4" - - def test_task_collection_repr(self): - """Test string representation of TaskCollection.""" - collection = TaskCollection.from_list([{"query": "Q1"}, {"query": "Q2"}]) - - repr_str = repr(collection) - assert "TaskCollection" in repr_str - assert "2" in repr_str # Should mention number of tasks diff --git a/tests/test_core/test_task_protocol.py b/tests/test_core/test_task_protocol.py new file mode 100644 index 0000000..28a7e24 --- /dev/null +++ b/tests/test_core/test_task_protocol.py @@ -0,0 +1,110 @@ +"""Tests for TaskProtocol and TimeoutAction. + +These tests verify that TaskProtocol correctly configures task execution +parameters and that TimeoutAction enum values are correct. 
+""" + +import pytest +from maseval import Task, TaskQueue +from maseval.core.task import TaskProtocol, TimeoutAction + + +@pytest.mark.core +class TestTimeoutAction: + """Tests for TimeoutAction enum.""" + + def test_enum_values(self): + """TimeoutAction should have expected values.""" + assert TimeoutAction.SKIP.value == "skip" + assert TimeoutAction.RETRY.value == "retry" + assert TimeoutAction.EXTEND.value == "extend" + + def test_enum_members(self): + """TimeoutAction should have expected members.""" + members = list(TimeoutAction) + assert len(members) == 3 + assert TimeoutAction.SKIP in members + assert TimeoutAction.RETRY in members + assert TimeoutAction.EXTEND in members + + +@pytest.mark.core +class TestTaskProtocol: + """Tests for TaskProtocol dataclass.""" + + def test_default_values(self): + """TaskProtocol should have sensible defaults.""" + protocol = TaskProtocol() + + assert protocol.timeout_seconds is None + assert protocol.timeout_action == TimeoutAction.SKIP + assert protocol.max_retries == 0 + assert protocol.priority == 0 + assert protocol.tags == {} + + def test_custom_values(self): + """TaskProtocol should accept custom values.""" + protocol = TaskProtocol( + timeout_seconds=60.0, + timeout_action=TimeoutAction.RETRY, + max_retries=3, + priority=10, + tags={"category": "hard", "group": "A"}, + ) + + assert protocol.timeout_seconds == 60.0 + assert protocol.timeout_action == TimeoutAction.RETRY + assert protocol.max_retries == 3 + assert protocol.priority == 10 + assert protocol.tags == {"category": "hard", "group": "A"} + + def test_tags_isolation(self): + """Tags dict should be independent per instance.""" + p1 = TaskProtocol() + p2 = TaskProtocol() + + p1.tags["key"] = "value" + + assert "key" not in p2.tags + + +@pytest.mark.core +class TestTaskWithProtocol: + """Tests for Task with TaskProtocol integration.""" + + def test_task_has_protocol_field(self): + """Task dataclass should have protocol field.""" + task = Task(query="Test", environment_data={}) + + assert hasattr(task, "protocol") + assert isinstance(task.protocol, TaskProtocol) + + def test_task_default_protocol(self): + """Task should have default protocol if not specified.""" + task = Task(query="Test") + + assert task.protocol.timeout_seconds is None + assert task.protocol.priority == 0 + + def test_task_custom_protocol(self): + """Task should accept custom protocol.""" + protocol = TaskProtocol( + timeout_seconds=30.0, + priority=5, + ) + task = Task(query="Test", protocol=protocol) + + assert task.protocol.timeout_seconds == 30.0 + assert task.protocol.priority == 5 + + def test_task_queue_preserves_protocol(self): + """TaskQueue should preserve protocol on tasks.""" + task1 = Task(query="Q1", protocol=TaskProtocol(priority=1)) + task2 = Task(query="Q2", protocol=TaskProtocol(priority=2)) + tasks = TaskQueue([task1, task2]) + + first_task: Task = tasks[0] # type: ignore[assignment] + second_task: Task = tasks[1] # type: ignore[assignment] + + assert first_task.protocol.priority == 1 + assert second_task.protocol.priority == 2 diff --git a/tests/test_core/test_user.py b/tests/test_core/test_user.py index cf6f2e0..42b2a09 100644 --- a/tests/test_core/test_user.py +++ b/tests/test_core/test_user.py @@ -217,7 +217,7 @@ def test_stop_token_detection_sets_stopped(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "Thanks! " # type: ignore[assignment] + user.simulator.return_value = "Thanks! 
" # type: ignore[union-attr] # mock user.simulate_response("Here's your answer") @@ -234,7 +234,7 @@ def test_stop_token_removed_from_response(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "Perfect, thanks! " # type: ignore[assignment] + user.simulator.return_value = "Perfect, thanks! " # type: ignore[union-attr] # mock response = user.simulate_response("Booking confirmed!") @@ -252,7 +252,7 @@ def test_is_done_true_after_stop_token(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "Done " # type: ignore[assignment] + user.simulator.return_value = "Done " # type: ignore[union-attr] # mock user.simulate_response("Result") @@ -269,7 +269,7 @@ def test_stop_token_case_insensitive(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "Thanks! " # lowercase # type: ignore[assignment] + user.simulator.return_value = "Thanks! " # type: ignore[union-attr] # mock (lowercase) user.simulate_response("Answer") @@ -286,7 +286,7 @@ def test_fallback_message_when_only_stop_token(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "" # type: ignore[assignment] + user.simulator.return_value = "" # type: ignore[union-attr] # mock response = user.simulate_response("Done!") @@ -304,7 +304,7 @@ def test_stop_token_response_counts_as_turn(self, dummy_model): early_stopping_condition="goals are met", max_turns=5, ) - user.simulator.return_value = "Thank you, all is clear " # type: ignore[assignment] + user.simulator.return_value = "Thank you, all is clear " # type: ignore[union-attr] # mock initial_turn_count = user._turn_count user.simulate_response("Here is your result") @@ -383,19 +383,19 @@ def test_get_initial_query_generates_message(self, dummy_model): from conftest import DummyUser user = DummyUser(name="test", model=dummy_model) - user.simulator.return_value = "I want to book a hotel" # type: ignore[assignment] + user.simulator.return_value = "I want to book a hotel" # type: ignore[union-attr] # mock query = user.get_initial_query() assert query == "I want to book a hotel" - user.simulator.assert_called_once() # type: ignore[attr-defined] + user.simulator.assert_called_once() # type: ignore[union-attr] # mock def test_get_initial_query_adds_to_messages(self, dummy_model): """Generated query is added to message history.""" from conftest import DummyUser user = DummyUser(name="test", model=dummy_model) - user.simulator.return_value = "Help me please" # type: ignore[assignment] + user.simulator.return_value = "Help me please" # type: ignore[union-attr] # mock user.get_initial_query() @@ -421,7 +421,7 @@ def test_get_initial_query_counts_as_turn(self, dummy_model): from conftest import DummyUser user = DummyUser(name="test", model=dummy_model, max_turns=3) - user.simulator.return_value = "Initial query" # type: ignore[assignment] + user.simulator.return_value = "Initial query" # type: ignore[union-attr] # mock user.get_initial_query() @@ -455,7 +455,7 @@ def test_assistant_message_recorded(self, dummy_model): from conftest import DummyUser user = DummyUser(name="test", model=dummy_model, max_turns=3) - user.simulator.return_value = "User reply" # type: ignore[assignment] + user.simulator.return_value = "User reply" # type: ignore[union-attr] # mock user.simulate_response("Agent says hello") @@ -469,7 +469,7 @@ def test_user_response_recorded(self, dummy_model): from conftest 
import DummyUser user = DummyUser(name="test", model=dummy_model, max_turns=3) - user.simulator.return_value = "Thanks for the help" # type: ignore[assignment] + user.simulator.return_value = "Thanks for the help" # type: ignore[union-attr] # mock user.simulate_response("Here's your answer") @@ -486,7 +486,7 @@ def test_full_conversation_tracked(self, dummy_model): initial_query="I need a flight", max_turns=3, ) - user.simulator.side_effect = ["Monday works", "Yes, book it"] # type: ignore[assignment] + user.simulator.side_effect = ["Monday works", "Yes, book it"] # type: ignore[union-attr] # mock # Two agent-user exchanges user.simulate_response("When do you want to travel?") @@ -512,7 +512,7 @@ def test_gather_traces_includes_all_messages(self, dummy_model): initial_query="Hello", max_turns=2, ) - user.simulator.return_value = "Got it" # type: ignore[assignment] + user.simulator.return_value = "Got it" # type: ignore[union-attr] # mock user.simulate_response("Agent response") diff --git a/tests/test_interface/test_agent_integration/test_langgraph_integration.py b/tests/test_interface/test_agent_integration/test_langgraph_integration.py index efb23b9..1e1b3fa 100644 --- a/tests/test_interface/test_agent_integration/test_langgraph_integration.py +++ b/tests/test_interface/test_agent_integration/test_langgraph_integration.py @@ -74,7 +74,7 @@ def agent_node(state: State) -> State: ) return {"messages": messages + [response]} - graph = StateGraph(State) # type: ignore[arg-type] + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", agent_node) graph.set_entry_point("agent") graph.add_edge("agent", END) @@ -143,7 +143,7 @@ def agent_node(state: State) -> State: response = AIMessage(content="Response") return {"messages": messages + [response]} - graph = StateGraph(State) # type: ignore[arg-type] + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", agent_node) graph.set_entry_point("agent") graph.add_edge("agent", END) @@ -177,7 +177,7 @@ class State(TypedDict): def failing_node(state: State) -> State: raise ValueError("Intentional test error") - graph = StateGraph(State) # type: ignore[arg-type] + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", failing_node) graph.set_entry_point("agent") graph.add_edge("agent", END) @@ -222,7 +222,7 @@ def agent_node(state: State) -> State: response = AIMessage(content="Test response") return {"messages": messages + [response]} - graph = StateGraph(State) # type: ignore[arg-type] + graph = StateGraph(State) # type: ignore[arg-type] # TypedDict in function scope graph.add_node("agent", agent_node) graph.set_entry_point("agent") graph.add_edge("agent", END)