diff --git a/AGENTS.md b/AGENTS.md index e8822009..b81ced58 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -8,7 +8,7 @@ ```bash cargo build --release # produce ./target/release/acrawl -cargo test --workspace # run full test suite (~770 tests) +cargo test --workspace # run full test suite (~1,100 tests) cargo test -p # run a single test (e.g. -p agent mvp_tool_specs_contains_expected_21_tools) cargo clippy --workspace --all-targets -- -D warnings # lints must be clean (workspace lints set pedantic = warn) cargo fmt --check # format check @@ -78,6 +78,48 @@ Default model comes from the `default_model` field in the active provider's `Sto `agent::mvp_tool_specs()` returns the canonical 21-tool list with JSON schemas and required permission. When you add or rename a tool, update `mvp_tool_specs`, add a handler in `tools/mod.rs`, and adjust the count assertion in `crates/agent/src/lib.rs` tests. +## Optimization layer + +14 vendor-derived optimizations live in `crates/agent/src/` and `crates/runtime/src/`. All are gated by `settings.optimization.*` fields (all default OFF). The pattern every optimization follows: + +### Shared infrastructure (must understand before touching any optimization) + +**`DynamicPromptContext`** (`crates/agent/src/prompt.rs`) — four optional string fields (`stagnation_alert`, `planning_guidance`, `budget_warning`, `loop_nudge`). `build_system_prompt(specs, Some(&ctx))` appends the context as section 9 of the system prompt. + +**Arc slot pattern** — `CrawlerAgent` and `ConversationRuntime` share two Arc slots created in `run_with_system_prompt()`: +- `prompt_override: Arc>>>` — agent writes a new full system prompt here after any tool execution; runtime applies it before the next API call in `prepare_iteration()`. +- `last_assistant_text: Arc>>` — runtime writes the latest assistant response text here; agent reads it for confidence parsing. +- `cumulative_cost: Arc` (millicents) — runtime updates it after each usage record; agent reads it for budget enforcement. + +All three slots are internal to `ConversationRuntime` (not constructor parameters) but accessible via getters. The agent gets the cost counter via `runtime.cumulative_cost_counter()` after construction. + +### Per-optimization modules + +| Module | Location | What it adds to `CrawlState` / `CrawlerAgent` | +|--------|----------|-----------------------------------------------| +| `page_fingerprint` | `crates/agent/src/page_fingerprint.rs` | `CrawlState.page_fingerprints: Vec` | +| `tools/html_diff` | `crates/agent/src/tools/html_diff.rs` | `CrawlState.html_diff_tracker: Option` | +| `loop_detector` | `crates/agent/src/loop_detector.rs` | `CrawlState.loop_detector: Option` | +| `failure_classifier` | `crates/agent/src/failure_classifier.rs` | (pure function — no state) | +| `self_healing` | `crates/agent/src/self_healing.rs` | (pure function — no state) | +| `action_cache` | `crates/agent/src/action_cache.rs` | `CrawlState.action_cache: Option` | +| `confidence` | `crates/agent/src/confidence.rs` | `CrawlerAgent.confidence_tracker: Option` | +| `budget` | `crates/runtime/src/budget.rs` | `CrawlerAgent.cumulative_cost_slot: SharedCostCounter` | + +### Where optimizations run + +All optimization logic runs inside `CrawlerAgent::execute()` in `crates/agent/src/implementation/mod.rs`. The execution order (each guarded by its settings flag): +1. **Action cache lookup** — before the tool runs (returns cached result if hit) +2. **Tool execution** — normal handler dispatch +3. **Self-healing retry** — on SelectorNotFound/SelectorAmbiguous +4. **Loop detection** — records action + fingerprint, writes nudge to prompt_override_slot +5. **Planning interval** — injects planning/execution guidance at step N +6. **Confidence tracking** — reads last_assistant_text slot, parses `[confidence: ...]` +7. **Budget enforcement** — reads cumulative_cost_slot, warns or blocks +8. **Action cache store** — stores result after successful read-only tool call + +`CrawlState` fields are ephemeral (never persisted to session files). Adding a new field requires no serde changes. + ## Conventions specific to this repo - **Always run `cargo fmt` before committing.** CI checks formatting with `cargo fmt --check` — commits that fail this check will be rejected. diff --git a/CHANGELOG.md b/CHANGELOG.md index 1aaf6540..bad506aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- **HTML Diff Mode** (`optimization.html_diff_mode`) — on repeated visits to the same URL, only changed content sections are returned with `[unchanged: N sections]` markers, reducing token usage 50–70% on multi-turn sessions. +- **Action Loop Detection** (`optimization.loop_detection`) — rolling-window action hash detects repeated identical actions with escalating nudges (soft at 5, medium at 8, strong at 12 repeats); page stagnation detection after 5 consecutive identical page fingerprints. +- **Page Fingerprinting** (`optimization.page_fingerprinting`) — lightweight FNV-1a fingerprint (url + element_count + first-1000-char text hash) stored in CrawlState; used by loop detection and action caching for cache invalidation. +- **Planning Interval** (`optimization.planning_interval`) — every N steps injects planning-checkpoint or execution-mode guidance into the dynamic prompt; disabled by default (interval=0). +- **Failure Classification** (`optimization.failure_classification`) — 16-category keyword-based error taxonomy (zero LLM cost); `classify()` maps error messages to SelectorNotFound, CaptchaDetected, RateLimited, etc.; `retry_strategy()` returns RetryWithHealing, RetryWithDelay, NoRetry, or ResetAndRetry per category. +- **Self-Healing Selectors** (`optimization.self_healing`) — on SelectorNotFound/SelectorAmbiguous, fetches a fresh page_map and text-matches to the correct element ref; logs `[healed: @eOLD → @eNEW]`; zero LLM calls; max retries configurable (default 2). +- **Action Caching** (`optimization.action_caching`) — in-memory FNV-1a keyed cache for read-only tools (`page_map`, `read_content`, `list_resources`, `execute_js`); invalidated on page fingerprint change; TTL-based expiry (default 30s); interaction tools never cached. +- **Confidence Tracking** (`optimization.confidence_tracking`) — parses `[confidence: HIGH/MEDIUM/LOW]` from assistant responses; 2+ consecutive LOWs triggers stagnation alert via DynamicPromptContext; advisory only, never blocks. +- **Compound Component Enrichment** (`optimization.compound_enrichment`) — extends interactive element JSON with an `enrichment` field for complex form controls: date format hints, range min/max/step/value, number bounds, select option lists (max 20 + overflow count), file accept types, textarea maxlength. Max 200 bytes/element. +- **Content-Aware Cleaning Profiles** (`optimization.content_aware_profiles`) — `CleaningProfile` enum (Default/Minimal/Aggressive/ReadingMode) auto-selected by task keyword and content size; `select_profile()` picks ReadingMode for extraction tasks, Minimal for interaction tasks, Aggressive for content > 50KB. +- **Budget Enforcement** (`optimization.budget_max_session_cost_usd`, `optimization.budget_enforcement`) — `BudgetEnforcer` with Warn/Block modes; Warn injects budget warning into the dynamic prompt at configurable threshold (default 80%); Block terminates the agent loop cleanly when the cost limit is reached. +- **Per-Agent Cost Attribution** (`optimization.per_agent_cost_tracking`) — `build_cost_breakdown()` walks flat child sessions and reconstructs per-child cost via UsageTracker; `/cost` command shows per-agent breakdown when flag is ON. +- **Dynamic System Prompt Infrastructure** — `DynamicPromptContext` struct with four optional fields (stagnation_alert, planning_guidance, budget_warning, loop_nudge); injected as section 9 of the system prompt via a shared `Arc>` slot; all optimizations write to this slot, runtime picks up on the next iteration. +- **Optimization Settings Schema** — nested `OptimizationSettings` struct in `Settings` with 18 fields, all `Option` and defaulting to OFF for backward compatibility; 18 `settings_get_*` getter functions. + ## [0.9.1] - 2026-06-10 ### Changed diff --git a/Cargo.lock b/Cargo.lock index 367654c4..6e04d2b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,6 +98,7 @@ dependencies = [ "runtime", "script", "serde_json", + "sha2", "tempfile", "time", "tokio", diff --git a/README.md b/README.md index 923f76df..ef082dfc 100644 --- a/README.md +++ b/README.md @@ -722,6 +722,29 @@ Created with defaults on first run. | `browser_backend` | `null` | Active browser backend: `"extension"` or `null` (CloakBrowser) | | `extension_bridge_port` | `19876` | Port for Chrome extension bridge WebSocket server | +All fields are optional; omitting a field uses the default. The `optimization` block accepts a nested object with the following fields (all default to `false`/`0`/`null`, safe to omit entirely): + +| Field | Default | Description | +|-------|---------|-------------| +| `html_diff_mode` | `false` | On repeated visits to the same URL, returns only changed content sections with `[unchanged: N sections]` markers. 50 to 70% token reduction on multi-turn sessions. No behavior change on first visit. | +| `content_aware_profiles` | `false` | Auto-selects a cleaning profile based on the task keyword: ReadingMode for extraction tasks, Minimal for interaction tasks, Aggressive for content > 50KB. | +| `loop_detection` | `false` | Detects repeated identical actions and injects escalating nudges (soft, medium, strong). Also detects page stagnation. | +| `loop_detection_window` | `20` | Rolling window size for action hash comparison. | +| `loop_nudge_threshold` | `5` | Number of repeated actions before first nudge fires. | +| `page_fingerprinting` | `false` | Enables lightweight page fingerprints used by loop detection and action caching. | +| `failure_classification` | `false` | Classifies errors into 16 categories (SelectorNotFound, CaptchaDetected, RateLimited, etc.) using keyword matching. Zero LLM cost. | +| `self_healing` | `false` | On SelectorNotFound/SelectorAmbiguous, fetches a fresh page_map and text-matches to a replacement element ref. Logs `[healed: @eOLD -> @eNEW]`. Zero LLM calls. | +| `self_healing_max_retries` | `2` | Max healing attempts per failed action. | +| `action_caching` | `false` | Caches results of read-only tools (`page_map`, `read_content`, `list_resources`, `execute_js`) keyed by tool + input + page fingerprint. Cache is invalidated when the page changes. | +| `action_cache_ttl_secs` | `30` | Cache entry TTL in seconds. | +| `planning_interval` | `0` | Every N steps, injects a planning checkpoint into the system prompt. 0 = disabled. | +| `confidence_tracking` | `false` | Asks the LLM to self-report confidence after each action (`[confidence: HIGH/MEDIUM/LOW]`). Two consecutive LOWs trigger a stagnation alert. | +| `compound_enrichment` | `false` | Adds `enrichment` metadata to complex form controls in page_map: date format hints, range min/max/value, select option lists (max 20 + overflow count), file accept types, textarea maxlength. Max 200 bytes per element. | +| `budget_max_session_cost_usd` | `null` | Session cost limit in USD. Null = no limit. | +| `budget_enforcement` | `null` | How to enforce the budget: `warn` injects a warning into the prompt; `block` terminates the session when the limit is reached. | +| `budget_warn_threshold_pct` | `80` | Percentage of budget at which warnings start. | +| `per_agent_cost_tracking` | `false` | When ON, `/cost` shows a per-child-agent cost breakdown. | + ### Environment Variables | Variable | Description | @@ -730,6 +753,43 @@ Created with defaults on first run. Provider-specific env vars (see [provider table](#24-llm-providers) above) are read as fallbacks when no `credentials.json` entry exists. +### Performance Optimizations + +acrawl ships 14 vendor-derived optimizations (sourced from browser-use, Stagehand, crawl4ai, Skyvern, Spider, nanobrowser, and ZeroClaw). All are **disabled by default**, enable selectively via `settings.json`. + +Example `settings.json` with a cost-optimized profile: + +```json +{ + "optimization": { + "html_diff_mode": true, + "action_caching": true, + "page_fingerprinting": true, + "loop_detection": true, + "self_healing": true, + "budget_max_session_cost_usd": 0.50, + "budget_enforcement": "warn" + } +} +``` + +| Optimization | Flag | Benefit | +|--------------|------|---------| +| **HTML Diff Mode** | `html_diff_mode` | Reduces tokens by 50 to 70% on repeated visits by returning only changed content. | +| **Content-Aware Profiles** | `content_aware_profiles` | Auto-selects cleaning profiles (ReadingMode, Minimal, Aggressive) based on task. | +| **Loop Detection** | `loop_detection` | Prevents infinite loops by detecting repeated actions and injecting nudges. | +| **Page Fingerprinting** | `page_fingerprinting` | Generates lightweight page fingerprints for loop detection and action caching. | +| **Failure Classification** | `failure_classification` | Classifies errors into 16 categories using keyword matching with zero LLM cost. | +| **Self-Healing** | `self_healing` | Automatically heals broken selectors using text-matching with zero LLM calls. | +| **Action Caching** | `action_caching` | Caches read-only tool results to avoid redundant LLM calls. | +| **Planning Interval** | `planning_interval` | Injects periodic planning checkpoints to keep the agent focused. | +| **Confidence Tracking** | `confidence_tracking` | Tracks LLM self-reported confidence to alert on stagnation. | +| **Compound Enrichment** | `compound_enrichment` | Enriches complex form controls in the page map with metadata. | +| **Budget Limit** | `budget_max_session_cost_usd` | Sets a hard session cost limit in USD to prevent runaway costs. | +| **Budget Enforcement** | `budget_enforcement` | Controls whether to warn or block when the session budget is reached. | +| **Budget Warning** | `budget_warn_threshold_pct` | Triggers warnings when a percentage of the budget is consumed. | +| **Per-Agent Cost Tracking** | `per_agent_cost_tracking` | Breaks down costs per child agent in the `/cost` command. | + ## Known Limitations acrawl works well on most public web content, but some situations are outside what the agent can reliably handle: @@ -780,7 +840,7 @@ crates/ commands/ 17 slash commands with resume-safety annotations ``` -11 crates, ~38K lines of Rust, 770 tests. +11 crates, ~40K lines of Rust, 1,097 tests. ## Development diff --git a/crates/agent/Cargo.toml b/crates/agent/Cargo.toml index 06abcb69..03ff0db4 100644 --- a/crates/agent/Cargo.toml +++ b/crates/agent/Cargo.toml @@ -14,6 +14,7 @@ regex = "1" runtime = { path = "../runtime" } script = { path = "../script" } serde_json = "1" +sha2 = "0.10" time = { version = "0.3", features = ["formatting"] } tokio = { version = "1", features = ["sync", "time", "fs"] } tokio-util = { version = "0.7", default-features = false } diff --git a/crates/agent/src/action_cache.rs b/crates/agent/src/action_cache.rs new file mode 100644 index 00000000..8039cad0 --- /dev/null +++ b/crates/agent/src/action_cache.rs @@ -0,0 +1,242 @@ +use std::collections::HashMap; +use std::time::Instant; + +use serde_json::Value; +use sha2::{Digest, Sha256}; + +use crate::page_fingerprint::PageFingerprint; + +const DEFAULT_TTL_SECS: u64 = 30; + +/// Read-only tools that are safe to cache. +pub const CACHEABLE_TOOLS: &[&str] = &["page_map", "read_content", "list_resources"]; + +#[must_use] +pub fn is_cacheable(tool_name: &str) -> bool { + CACHEABLE_TOOLS.contains(&tool_name) +} + +#[derive(Debug, Clone)] +pub struct CachedAction { + pub output: String, + pub stored_at: Instant, + pub fingerprint: PageFingerprint, + pub ttl_secs: u64, +} + +impl CachedAction { + #[must_use] + pub fn is_expired(&self) -> bool { + self.stored_at.elapsed().as_secs() >= self.ttl_secs + } +} + +#[derive(Debug, Clone)] +pub struct ActionCache { + entries: HashMap, + ttl_secs: u64, +} + +impl Default for ActionCache { + fn default() -> Self { + Self::new(DEFAULT_TTL_SECS) + } +} + +impl ActionCache { + #[must_use] + pub fn new(ttl_secs: u64) -> Self { + Self { + entries: HashMap::new(), + ttl_secs, + } + } + + /// Build a cache key from `tool_name` + canonical JSON input + page fingerprint. + #[must_use] + pub fn make_key(tool_name: &str, input: &Value, fingerprint: &PageFingerprint) -> String { + let canonical_input = canonicalize_json(input); + let raw = format!( + "{tool_name}:{canonical_input}:{}:{}:{}", + fingerprint.url, fingerprint.element_count, fingerprint.text_hash + ); + let mut hasher = Sha256::new(); + hasher.update(raw.as_bytes()); + format!("{:x}", hasher.finalize()) + } + + /// Look up a cached result. Returns None if not found, expired, or fingerprint mismatch. + pub fn lookup(&mut self, key: &str, current_fingerprint: &PageFingerprint) -> Option { + self.evict_expired(); + let entry = self.entries.get(key)?; + if entry.fingerprint != *current_fingerprint { + return None; + } + Some(entry.output.clone()) + } + + /// Store a result in the cache. + pub fn store(&mut self, key: String, output: String, fingerprint: PageFingerprint) { + self.evict_expired(); + self.entries.insert( + key, + CachedAction { + output, + stored_at: Instant::now(), + fingerprint, + ttl_secs: self.ttl_secs, + }, + ); + } + + /// Remove expired entries to avoid unbounded growth. + pub fn evict_expired(&mut self) { + self.entries.retain(|_, value| !value.is_expired()); + } +} + +fn canonicalize_json(value: &Value) -> String { + match value { + Value::Null => "null".to_string(), + Value::Bool(boolean) => boolean.to_string(), + Value::Number(number) => number.to_string(), + Value::String(string) => { + serde_json::to_string(string).unwrap_or_else(|_| "\"\"".to_string()) + } + Value::Array(items) => { + let canonical_items = items + .iter() + .map(canonicalize_json) + .collect::>() + .join(","); + format!("[{canonical_items}]") + } + Value::Object(map) => { + let mut entries = map.iter().collect::>(); + entries.sort_by_key(|(left, _)| *left); + let canonical_entries = entries + .into_iter() + .map(|(key, nested)| { + let encoded_key = + serde_json::to_string(key).unwrap_or_else(|_| "\"\"".to_string()); + format!("{encoded_key}:{}", canonicalize_json(nested)) + }) + .collect::>() + .join(","); + format!("{{{canonical_entries}}}") + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn fp(url: &str) -> PageFingerprint { + PageFingerprint { + url: url.to_string(), + element_count: 3, + text_hash: 42, + } + } + + #[test] + fn cache_hit_same_fingerprint() { + let mut cache = ActionCache::new(30); + let fingerprint = fp("https://example.com"); + let key = ActionCache::make_key("page_map", &json!({}), &fingerprint); + cache.store(key.clone(), "result".to_string(), fingerprint.clone()); + + let hit = cache.lookup(&key, &fingerprint); + + assert_eq!(hit.as_deref(), Some("result")); + } + + #[test] + fn cache_miss_on_fingerprint_change() { + let mut cache = ActionCache::new(30); + let fingerprint_one = fp("https://example.com"); + let fingerprint_two = PageFingerprint { + url: "https://example.com".to_string(), + element_count: 10, + text_hash: 99, + }; + let key = ActionCache::make_key("page_map", &json!({}), &fingerprint_one); + cache.store(key.clone(), "result".to_string(), fingerprint_one); + + assert!(cache.lookup(&key, &fingerprint_two).is_none()); + } + + #[test] + fn cache_miss_after_ttl_expires() { + let mut cache = ActionCache::new(0); + let fingerprint = fp("https://example.com"); + let key = ActionCache::make_key("page_map", &json!({}), &fingerprint); + cache.store(key.clone(), "result".to_string(), fingerprint.clone()); + + assert!(cache.lookup(&key, &fingerprint).is_none()); + } + + #[test] + fn interaction_tools_not_cacheable() { + assert!(!is_cacheable("click")); + assert!(!is_cacheable("fill_form")); + assert!(!is_cacheable("navigate")); + assert!(!is_cacheable("scroll")); + } + + #[test] + fn read_tools_are_cacheable() { + assert!(is_cacheable("page_map")); + assert!(is_cacheable("read_content")); + assert!(is_cacheable("list_resources")); + } + + #[test] + fn execute_js_not_cacheable() { + assert!(!is_cacheable("execute_js")); + } + + #[test] + fn action_cache_lookup_requires_fingerprint_in_state() { + use crate::state::CrawlState; + + let fingerprint = PageFingerprint { + url: "https://example.com".to_string(), + element_count: 5, + text_hash: 42, + }; + let mut state = CrawlState::default(); + + assert!(state.page_fingerprints.is_empty()); + state.page_fingerprints.push(fingerprint.clone()); + assert_eq!(state.page_fingerprints.last(), Some(&fingerprint)); + } + + #[test] + fn different_inputs_produce_different_keys() { + let fingerprint = fp("https://example.com"); + let key_one = ActionCache::make_key("page_map", &json!({}), &fingerprint); + let key_two = ActionCache::make_key("page_map", &json!({"scope": "#main"}), &fingerprint); + + assert_ne!(key_one, key_two); + } + + #[test] + fn equivalent_object_inputs_produce_same_key() { + let fingerprint = fp("https://example.com"); + let key_one = ActionCache::make_key( + "read_content", + &json!({"selector": "#main", "offset": 0, "max_chars": 1000}), + &fingerprint, + ); + let key_two = ActionCache::make_key( + "read_content", + &json!({"max_chars": 1000, "offset": 0, "selector": "#main"}), + &fingerprint, + ); + + assert_eq!(key_one, key_two); + } +} diff --git a/crates/agent/src/confidence.rs b/crates/agent/src/confidence.rs new file mode 100644 index 00000000..db944cf9 --- /dev/null +++ b/crates/agent/src/confidence.rs @@ -0,0 +1,137 @@ +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Confidence { + High, + Medium, + Low, +} + +#[derive(Debug, Clone, Default)] +pub struct ConfidenceTracker { + pub last: Option, + pub consecutive_low: u8, +} + +impl ConfidenceTracker { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Parse confidence marker from assistant response text. + /// Looks for pattern: `[confidence: HIGH]`, `[confidence: MEDIUM]`, `[confidence: LOW]` + #[must_use] + pub fn parse_from_text(text: &str) -> Option { + let lower = text.to_lowercase(); + if let Some(start) = lower.find("[confidence:") { + let rest = &lower[start..]; + if let Some(end) = rest.find(']') { + let inner = &rest[..=end]; + if inner.contains("high") { + return Some(Confidence::High); + } else if inner.contains("low") { + return Some(Confidence::Low); + } else if inner.contains("medium") { + return Some(Confidence::Medium); + } + } + } + None + } + + /// Record a new confidence value. Returns true if stagnation alert should + /// be injected (2+ consecutive LOWs). + pub fn record(&mut self, confidence: Confidence) -> bool { + if confidence == Confidence::Low { + self.consecutive_low += 1; + } else { + self.consecutive_low = 0; + } + self.last = Some(confidence); + self.consecutive_low >= 2 + } +} + +/// Build the confidence instruction to inject into `DynamicPromptContext`. +#[must_use] +pub fn confidence_instruction() -> String { + "After each action, rate your confidence in the current approach. \ + Add exactly: [confidence: HIGH], [confidence: MEDIUM], or [confidence: LOW] \ + at the end of your response." + .to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_high_confidence() { + assert_eq!( + ConfidenceTracker::parse_from_text("...done. [confidence: HIGH]"), + Some(Confidence::High) + ); + } + + #[test] + fn parse_low_confidence() { + assert_eq!( + ConfidenceTracker::parse_from_text("stuck. [confidence: LOW]"), + Some(Confidence::Low) + ); + } + + #[test] + fn parse_medium_confidence() { + assert_eq!( + ConfidenceTracker::parse_from_text("[confidence: MEDIUM]"), + Some(Confidence::Medium) + ); + } + + #[test] + fn parse_none_when_absent() { + assert_eq!(ConfidenceTracker::parse_from_text("No marker here."), None); + } + + #[test] + fn parse_case_insensitive() { + assert_eq!( + ConfidenceTracker::parse_from_text("[confidence: high]"), + Some(Confidence::High) + ); + assert_eq!( + ConfidenceTracker::parse_from_text("[Confidence: Low]"), + Some(Confidence::Low) + ); + } + + #[test] + fn stagnation_triggers_at_2_consecutive_lows() { + let mut tracker = ConfidenceTracker::new(); + assert!(!tracker.record(Confidence::Low)); // 1 LOW — no alert + assert!(tracker.record(Confidence::Low)); // 2 LOWs — alert! + } + + #[test] + fn consecutive_reset_on_non_low() { + let mut tracker = ConfidenceTracker::new(); + tracker.record(Confidence::Low); + tracker.record(Confidence::High); // resets consecutive + assert!(!tracker.record(Confidence::Low)); // back to 1 — no alert + } + + #[test] + fn three_consecutive_lows_still_alerts() { + let mut tracker = ConfidenceTracker::new(); + assert!(!tracker.record(Confidence::Low)); + assert!(tracker.record(Confidence::Low)); + assert!(tracker.record(Confidence::Low)); // 3rd consecutive — still alerts + } + + #[test] + fn confidence_instruction_non_empty() { + let instr = confidence_instruction(); + assert!(instr.contains("[confidence: HIGH]")); + assert!(instr.contains("[confidence: LOW]")); + } +} diff --git a/crates/agent/src/failure_classifier.rs b/crates/agent/src/failure_classifier.rs new file mode 100644 index 00000000..8ddf0fdc --- /dev/null +++ b/crates/agent/src/failure_classifier.rs @@ -0,0 +1,488 @@ +use std::fmt::{Display, Formatter}; + +/// 16-category failure taxonomy (Skyvern-derived, keyword matching only — zero LLM cost). +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FailureCategory { + /// Element doesn't exist in DOM + SelectorNotFound, + /// Multiple matches + SelectorAmbiguous, + /// Exists but hidden/off-screen + ElementNotVisible, + /// Exists but disabled + ElementDisabled, + /// Page didn't load + NavigationTimeout, + /// 403/429/503 + NavigationBlocked, + /// Bridge disconnected + BrowserCrash, + /// `execute_js` failed + JavaScriptError, + /// Form rejected input + FormValidation, + /// Redirect to login + AuthRequired, + /// Bot challenge + CaptchaDetected, + /// Too many requests + RateLimited, + /// Connection failed + NetworkError, + /// Expected content not found + ContentMismatch, + /// Cost limit hit + BudgetExceeded, + /// Unclassifiable + Unknown, +} + +impl Display for FailureCategory { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let name = match self { + Self::SelectorNotFound => "SelectorNotFound", + Self::SelectorAmbiguous => "SelectorAmbiguous", + Self::ElementNotVisible => "ElementNotVisible", + Self::ElementDisabled => "ElementDisabled", + Self::NavigationTimeout => "NavigationTimeout", + Self::NavigationBlocked => "NavigationBlocked", + Self::BrowserCrash => "BrowserCrash", + Self::JavaScriptError => "JavaScriptError", + Self::FormValidation => "FormValidation", + Self::AuthRequired => "AuthRequired", + Self::CaptchaDetected => "CaptchaDetected", + Self::RateLimited => "RateLimited", + Self::NetworkError => "NetworkError", + Self::ContentMismatch => "ContentMismatch", + Self::BudgetExceeded => "BudgetExceeded", + Self::Unknown => "Unknown", + }; + write!(f, "{name}") + } +} + +/// Retry strategy associated with a failure category. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum RetryStrategy { + /// Try self-healing selector + RetryWithHealing, + /// Delay in seconds before retry + RetryWithDelay(u64), + /// Reset browser state then retry + ResetAndRetry, + /// Do not retry + NoRetry, +} + +/// Classify an error message into a failure category using keyword matching. +/// This is zero-cost — no LLM calls. +#[must_use] +pub fn classify(tool_name: &str, error_message: &str) -> FailureCategory { + let msg = error_message.to_lowercase(); + let tool = tool_name.to_lowercase(); + + // Budget exceeded (check first — explicit internal error) + if msg.contains("budget exceeded") || msg.contains("cost limit") { + return FailureCategory::BudgetExceeded; + } + + // Browser crash + if msg.contains("bridge") + && (msg.contains("disconnect") || msg.contains("crash") || msg.contains("closed")) + { + return FailureCategory::BrowserCrash; + } + if msg.contains("browser closed") || msg.contains("target closed") { + return FailureCategory::BrowserCrash; + } + + // CAPTCHA / bot challenge + if msg.contains("captcha") + || msg.contains("turnstile") + || msg.contains("hcaptcha") + || msg.contains("recaptcha") + || (msg.contains("challenge") && msg.contains("bot")) + { + return FailureCategory::CaptchaDetected; + } + + // Auth required + if msg.contains("login") + || msg.contains("sign in") + || msg.contains("unauthorized") + || msg.contains("authentication required") + || msg.contains("not authenticated") + { + return FailureCategory::AuthRequired; + } + + // Rate limited + if msg.contains("429") || msg.contains("rate limit") || msg.contains("too many requests") { + return FailureCategory::RateLimited; + } + + // Navigation blocked (403/503) + if msg.contains("403") + || msg.contains("503") + || msg.contains("forbidden") + || msg.contains("access denied") + || msg.contains("blocked") + { + return FailureCategory::NavigationBlocked; + } + + // Navigation timeout + if msg.contains("timeout") || msg.contains("timed out") || msg.contains("navigation timeout") { + return FailureCategory::NavigationTimeout; + } + + // Selector ambiguous + if msg.contains("multiple") || msg.contains("ambiguous") || msg.contains("more than one") { + return FailureCategory::SelectorAmbiguous; + } + + // Element not visible + if msg.contains("hidden") + || msg.contains("not visible") + || msg.contains("off-screen") + || msg.contains("not in viewport") + { + return FailureCategory::ElementNotVisible; + } + + // Element disabled + if msg.contains("disabled") { + return FailureCategory::ElementDisabled; + } + + // Content mismatch (check before selector-not-found — both contain "not found") + if (msg.contains("not found") || msg.contains("expected content missing")) + && (tool == "read_content" || msg.contains("content")) + { + return FailureCategory::ContentMismatch; + } + + // Selector not found (check after ambiguous/visible/disabled/content-mismatch) + if msg.contains("not found") + || msg.contains("no element") + || msg.contains("element not found") + || msg.contains("could not find") + || msg.contains("does not exist") + { + return FailureCategory::SelectorNotFound; + } + + // JavaScript error + if tool == "execute_js" + || msg.contains("javascript") + || msg.contains("script error") + || msg.contains("uncaught") + || msg.contains("syntaxerror") + || msg.contains("referenceerror") + { + return FailureCategory::JavaScriptError; + } + + // Form validation + if msg.contains("validation") + || msg.contains("invalid input") + || msg.contains("required field") + || msg.contains("form error") + { + return FailureCategory::FormValidation; + } + + // Network error + if msg.contains("network") + || msg.contains("connection") + || msg.contains("dns") + || msg.contains("unreachable") + || msg.contains("eof") + || msg.contains("connection refused") + { + return FailureCategory::NetworkError; + } + + FailureCategory::Unknown +} + +/// Get the appropriate retry strategy for a failure category. +#[must_use] +pub fn retry_strategy(category: &FailureCategory) -> RetryStrategy { + match category { + FailureCategory::SelectorNotFound | FailureCategory::SelectorAmbiguous => { + RetryStrategy::RetryWithHealing + } + FailureCategory::ElementNotVisible | FailureCategory::ElementDisabled => { + RetryStrategy::RetryWithDelay(1) + } + FailureCategory::NavigationTimeout => RetryStrategy::RetryWithDelay(2), + FailureCategory::RateLimited => RetryStrategy::RetryWithDelay(5), + FailureCategory::BrowserCrash => RetryStrategy::ResetAndRetry, + FailureCategory::CaptchaDetected + | FailureCategory::BudgetExceeded + | FailureCategory::AuthRequired + | FailureCategory::NavigationBlocked + | FailureCategory::JavaScriptError + | FailureCategory::FormValidation + | FailureCategory::NetworkError + | FailureCategory::ContentMismatch + | FailureCategory::Unknown => RetryStrategy::NoRetry, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn classify_selector_not_found() { + assert_eq!( + classify("click", "Element not found matching selector @e5"), + FailureCategory::SelectorNotFound + ); + assert_eq!( + classify("click", "could not find element"), + FailureCategory::SelectorNotFound + ); + } + + #[test] + fn classify_selector_ambiguous() { + assert_eq!( + classify("click", "Found multiple matches for selector"), + FailureCategory::SelectorAmbiguous + ); + assert_eq!( + classify("click", "Ambiguous selector — more than one match"), + FailureCategory::SelectorAmbiguous + ); + } + + #[test] + fn classify_element_not_visible() { + assert_eq!( + classify("click", "Element is hidden behind overlay"), + FailureCategory::ElementNotVisible + ); + assert_eq!( + classify("click", "Element not visible in current viewport"), + FailureCategory::ElementNotVisible + ); + } + + #[test] + fn classify_element_disabled() { + assert_eq!( + classify("fill_form", "Element is disabled"), + FailureCategory::ElementDisabled + ); + } + + #[test] + fn classify_navigation_timeout() { + assert_eq!( + classify("navigate", "Navigation timed out after 30 seconds"), + FailureCategory::NavigationTimeout + ); + assert_eq!( + classify("navigate", "Page load timeout exceeded"), + FailureCategory::NavigationTimeout + ); + } + + #[test] + fn classify_navigation_blocked() { + assert_eq!( + classify("navigate", "HTTP 403 Forbidden"), + FailureCategory::NavigationBlocked + ); + assert_eq!( + classify("navigate", "Access denied by server"), + FailureCategory::NavigationBlocked + ); + } + + #[test] + fn classify_browser_crash() { + assert_eq!( + classify("click", "Bridge disconnected unexpectedly"), + FailureCategory::BrowserCrash + ); + assert_eq!( + classify("navigate", "Target closed"), + FailureCategory::BrowserCrash + ); + } + + #[test] + fn classify_javascript_error() { + assert_eq!( + classify("execute_js", "some random error"), + FailureCategory::JavaScriptError + ); + assert_eq!( + classify("click", "Uncaught TypeError: cannot read property"), + FailureCategory::JavaScriptError + ); + } + + #[test] + fn classify_captcha() { + assert_eq!( + classify("navigate", "CAPTCHA detected on page"), + FailureCategory::CaptchaDetected + ); + assert_eq!( + classify("navigate", "Turnstile challenge triggered"), + FailureCategory::CaptchaDetected + ); + } + + #[test] + fn classify_auth_required() { + assert_eq!( + classify("navigate", "Redirected to login page"), + FailureCategory::AuthRequired + ); + assert_eq!( + classify("navigate", "401 Unauthorized"), + FailureCategory::AuthRequired + ); + } + + #[test] + fn classify_rate_limited() { + assert_eq!( + classify("navigate", "429 Too Many Requests"), + FailureCategory::RateLimited + ); + assert_eq!( + classify("navigate", "Rate limit exceeded, try again later"), + FailureCategory::RateLimited + ); + } + + #[test] + fn classify_network_error() { + assert_eq!( + classify("navigate", "Network error: connection refused"), + FailureCategory::NetworkError + ); + assert_eq!( + classify("navigate", "DNS resolution failed"), + FailureCategory::NetworkError + ); + } + + #[test] + fn classify_form_validation() { + assert_eq!( + classify( + "fill_form", + "Form validation failed: required field missing" + ), + FailureCategory::FormValidation + ); + } + + #[test] + fn classify_budget_exceeded() { + assert_eq!( + classify("navigate", "Budget exceeded: $5.00 limit reached"), + FailureCategory::BudgetExceeded + ); + assert_eq!( + classify("navigate", "Cost limit hit"), + FailureCategory::BudgetExceeded + ); + } + + #[test] + fn classify_content_mismatch() { + assert_eq!( + classify("read_content", "Content not found under heading"), + FailureCategory::ContentMismatch + ); + } + + #[test] + fn classify_unknown_default() { + assert_eq!( + classify("navigate", "something completely unexpected happened xyz"), + FailureCategory::Unknown + ); + } + + #[test] + fn retry_strategy_selector_not_found_heals() { + assert_eq!( + retry_strategy(&FailureCategory::SelectorNotFound), + RetryStrategy::RetryWithHealing + ); + assert_eq!( + retry_strategy(&FailureCategory::SelectorAmbiguous), + RetryStrategy::RetryWithHealing + ); + } + + #[test] + fn retry_strategy_visibility_delays() { + assert_eq!( + retry_strategy(&FailureCategory::ElementNotVisible), + RetryStrategy::RetryWithDelay(1) + ); + assert_eq!( + retry_strategy(&FailureCategory::ElementDisabled), + RetryStrategy::RetryWithDelay(1) + ); + } + + #[test] + fn retry_strategy_timeout_delay() { + assert_eq!( + retry_strategy(&FailureCategory::NavigationTimeout), + RetryStrategy::RetryWithDelay(2) + ); + } + + #[test] + fn retry_strategy_rate_limit_longer_delay() { + assert_eq!( + retry_strategy(&FailureCategory::RateLimited), + RetryStrategy::RetryWithDelay(5) + ); + } + + #[test] + fn retry_strategy_browser_crash_resets() { + assert_eq!( + retry_strategy(&FailureCategory::BrowserCrash), + RetryStrategy::ResetAndRetry + ); + } + + #[test] + fn retry_strategy_captcha_no_retry() { + assert_eq!( + retry_strategy(&FailureCategory::CaptchaDetected), + RetryStrategy::NoRetry + ); + } + + #[test] + fn retry_strategy_budget_no_retry() { + assert_eq!( + retry_strategy(&FailureCategory::BudgetExceeded), + RetryStrategy::NoRetry + ); + } + + #[test] + fn retry_strategy_unknown_no_retry() { + assert_eq!( + retry_strategy(&FailureCategory::Unknown), + RetryStrategy::NoRetry + ); + } +} diff --git a/crates/agent/src/implementation/fork.rs b/crates/agent/src/implementation/fork.rs index 29f458e2..8f44a2e5 100644 --- a/crates/agent/src/implementation/fork.rs +++ b/crates/agent/src/implementation/fork.rs @@ -202,6 +202,7 @@ impl CrawlerAgent { child_agent.shared_bridge = Some(shared_bridge); child_agent.crawl_state = child_state; child_agent.api_client_arc = Some(child_api_client.clone()); + child_agent.cumulative_cost_slot = self.cumulative_cost_slot.clone(); let child_objective = task.objective.clone(); let join_handle = tokio::spawn(async move { @@ -305,6 +306,7 @@ impl CrawlerAgent { .captured_child_sessions .push(runtime::ChildSession { id: child_id.clone(), + model: crawl_result.model, goal: sub_goal, messages: crawl_result.messages, }); @@ -340,6 +342,7 @@ impl CrawlerAgent { .captured_child_sessions .push(runtime::ChildSession { id: child_id.clone(), + model: None, goal: sub_goal, messages: Vec::new(), }); @@ -368,6 +371,7 @@ impl CrawlerAgent { .captured_child_sessions .push(runtime::ChildSession { id: child_id.clone(), + model: None, goal: sub_goal, messages: Vec::new(), }); @@ -449,6 +453,7 @@ impl CrawlerAgent { .captured_child_sessions .push(runtime::ChildSession { id: child_id.clone(), + model: None, goal: sub_goal.clone(), messages: Vec::new(), }); @@ -748,7 +753,7 @@ mod tests { #[tokio::test] async fn test_two_forks_on_same_url_second_is_rejected() { - let _env_guard = env_lock().lock().await; + let _env_guard = crate::test_async_env_lock().lock().await; std::env::set_var("HEADLESS", "true"); let manager = super::super::default_agent_manager(); manager.lock().await.register_root("root"); @@ -786,12 +791,6 @@ mod tests { } } - fn env_lock() -> &'static tokio::sync::Mutex<()> { - use std::sync::OnceLock; - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| tokio::sync::Mutex::new(())) - } - #[tokio::test] async fn test_wait_no_children_returns_empty_snapshot() { let mut agent = CrawlerAgent::new_for_testing(mock_registry()); @@ -1174,6 +1173,7 @@ mod tests { extracted_data: vec![serde_json::json!({"key": "value"})], steps_executed: 1, messages: child_messages.clone(), + model: Some("anthropic/claude-sonnet-4-6".to_string()), }; let handle: tokio::task::JoinHandle> = tokio::spawn(async move { Some(child_result) }); diff --git a/crates/agent/src/implementation/lifecycle.rs b/crates/agent/src/implementation/lifecycle.rs index cf9c44f6..e8ac3459 100644 --- a/crates/agent/src/implementation/lifecycle.rs +++ b/crates/agent/src/implementation/lifecycle.rs @@ -162,18 +162,13 @@ impl CrawlerAgent { #[cfg(test)] mod tests { - use std::sync::{Arc, OnceLock}; + use std::sync::Arc; use tokio::sync::Mutex; use super::*; use crate::registry::ToolRegistry; - fn env_lock() -> &'static tokio::sync::Mutex<()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| tokio::sync::Mutex::new(())) - } - async fn test_bridge() -> SharedBridge { Arc::new(Mutex::new(Box::new( crate::PlaywrightBridge::new() @@ -217,7 +212,7 @@ mod tests { #[tokio::test] async fn test_lifecycle_state_transitions() { - let _env_guard = env_lock().lock().await; + let _env_guard = crate::test_async_env_lock().lock().await; std::env::set_var("HEADLESS", "true"); let mut agent = CrawlerAgent::new_for_testing(ToolRegistry::new()); diff --git a/crates/agent/src/implementation/mod.rs b/crates/agent/src/implementation/mod.rs index 77acde4a..caef497b 100644 --- a/crates/agent/src/implementation/mod.rs +++ b/crates/agent/src/implementation/mod.rs @@ -1,15 +1,16 @@ use std::collections::{BTreeSet, HashMap}; use std::fmt::{Display, Formatter}; -use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::sync::{Arc, Mutex}; use std::time::{SystemTime, UNIX_EPOCH}; use acrawl_core::{ApiClient, ContentBlock, ToolError, ToolExecutor, ToolOutcome}; use runtime::{ControlState, ConversationMessage, ConversationRuntime, Session, TurnSummary}; use serde_json::Value; -use tokio::sync::Mutex; +use tokio::sync::Mutex as AsyncMutex; use crate::manager::SharedAgentManager; -use crate::prompt::build_system_prompt; +use crate::prompt::{build_system_prompt, DynamicPromptContext}; use crate::registry::ToolRegistry; use crate::script_manager::{ScriptError, ScriptManager}; use crate::state::CrawlState; @@ -21,6 +22,8 @@ mod lifecycle; #[cfg(test)] use crate::BrowserBackend; +#[cfg(test)] +use crate::{BridgeError, BrowserState, PageInfo, ScreenshotOptions}; const DEFAULT_MAX_STEPS: usize = 50; const DEFAULT_MAX_CONCURRENT_PER_PARENT: usize = 5; @@ -40,6 +43,7 @@ pub struct CrawlResult { pub extracted_data: Vec, pub steps_executed: usize, pub messages: Vec, + pub model: Option, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -116,29 +120,32 @@ pub struct CrawlerAgent { extension_mode: bool, is_child: bool, pub(super) child_snapshots: crate::child_events::ChildSnapshotRegistry, + prompt_override_slot: Arc>>>, + last_assistant_text_slot: Arc>>, + accumulated_turn_ctx: Mutex, + cumulative_cost_slot: runtime::SharedCostCounter, + step_count: usize, + confidence_tracker: Option, #[cfg(test)] pub(super) fork_page_index_override: Option, } impl CrawlerAgent { - #[must_use] - pub fn take_captured_child_sessions(&mut self) -> Vec { - std::mem::take(&mut self.crawl_state.captured_child_sessions) - } - - pub fn push_captured_child_session_for_test(&mut self, session: runtime::ChildSession) { - self.crawl_state.captured_child_sessions.push(session); - } - - #[must_use] - pub fn new(browser: BrowserContext, registry: ToolRegistry) -> Self { + fn new_with_slots( + browser: Option, + registry: ToolRegistry, + agent_id: String, + prompt_override_slot: Arc>>>, + last_assistant_text_slot: Arc>>, + ) -> Self { + let shared_bridge = browser.as_ref().map(|context| context.bridge().clone()); Self { - shared_bridge: Some(browser.bridge().clone()), - browser: Some(browser), + shared_bridge, + browser, registry, allowed_tools: None, max_steps: DEFAULT_MAX_STEPS, - agent_id: generate_agent_id(), + agent_id, agent_manager: default_agent_manager(), crawl_state: CrawlState { max_steps: DEFAULT_MAX_STEPS, @@ -154,67 +161,57 @@ impl CrawlerAgent { extension_mode: false, is_child: false, child_snapshots: crate::child_events::ChildSnapshotRegistry::default(), + prompt_override_slot, + last_assistant_text_slot, + accumulated_turn_ctx: Mutex::new(DynamicPromptContext::default()), + cumulative_cost_slot: runtime::new_cost_counter(), + step_count: 0, + confidence_tracker: None, #[cfg(test)] fork_page_index_override: None, } } + #[must_use] + pub fn take_captured_child_sessions(&mut self) -> Vec { + std::mem::take(&mut self.crawl_state.captured_child_sessions) + } + + pub fn push_captured_child_session_for_test(&mut self, session: runtime::ChildSession) { + self.crawl_state.captured_child_sessions.push(session); + } + + #[must_use] + pub fn new(browser: BrowserContext, registry: ToolRegistry) -> Self { + Self::new_with_slots( + Some(browser), + registry, + generate_agent_id(), + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), + ) + } + #[must_use] pub fn new_lazy(registry: ToolRegistry) -> Self { - Self { - browser: None, + Self::new_with_slots( + None, registry, - allowed_tools: None, - max_steps: DEFAULT_MAX_STEPS, - agent_id: generate_agent_id(), - agent_manager: default_agent_manager(), - shared_bridge: None, - crawl_state: CrawlState { - max_steps: DEFAULT_MAX_STEPS, - ..CrawlState::default() - }, - child_tasks: HashMap::new(), - child_id_counter: std::sync::atomic::AtomicU64::new(0), - api_client_arc: None, - script_manager: default_script_manager(), - control_state: None, - child_event_tx: None, - child_control_registry: None, - extension_mode: false, - is_child: false, - child_snapshots: crate::child_events::ChildSnapshotRegistry::default(), - #[cfg(test)] - fork_page_index_override: None, - } + generate_agent_id(), + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), + ) } #[cfg(test)] fn new_for_testing(registry: ToolRegistry) -> Self { - Self { - browser: None, + Self::new_with_slots( + None, registry, - allowed_tools: None, - max_steps: DEFAULT_MAX_STEPS, - agent_id: "test-agent".to_string(), - agent_manager: default_agent_manager(), - shared_bridge: None, - crawl_state: CrawlState { - max_steps: DEFAULT_MAX_STEPS, - ..CrawlState::default() - }, - child_tasks: HashMap::new(), - child_id_counter: std::sync::atomic::AtomicU64::new(0), - api_client_arc: None, - script_manager: default_script_manager(), - control_state: None, - child_event_tx: None, - child_control_registry: None, - extension_mode: false, - is_child: false, - child_snapshots: crate::child_events::ChildSnapshotRegistry::default(), - #[cfg(test)] - fork_page_index_override: None, - } + "test-agent".to_string(), + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), + ) } #[must_use] @@ -308,7 +305,7 @@ impl CrawlerAgent { .collect(), None => mvp_tool_specs(), }; - let system_prompt = build_system_prompt(&specs); + let system_prompt = build_system_prompt(&specs, None); self.run_with_system_prompt(goal, api_client, system_prompt) .await } @@ -335,9 +332,20 @@ impl CrawlerAgent { } let max_steps = self.max_steps; - let mut runtime = - ConversationRuntime::new(Session::new(), shared_client.clone(), self, system_prompt) - .with_max_iterations(max_steps); + let prompt_override_slot = Arc::new(Mutex::new(None)); + let last_assistant_text_slot = Arc::new(Mutex::new(None)); + self.prompt_override_slot = Arc::clone(&prompt_override_slot); + self.last_assistant_text_slot = Arc::clone(&last_assistant_text_slot); + let mut runtime = ConversationRuntime::new( + Session::new(), + shared_client.clone(), + self, + system_prompt, + prompt_override_slot, + last_assistant_text_slot, + ) + .with_max_iterations(max_steps); + runtime.tool_executor_mut().cumulative_cost_slot = runtime.cumulative_cost_counter(); // Attach child event observer for streaming child output to TUI and // mirroring lifecycle/heartbeat data into the shared snapshot registry. @@ -363,8 +371,10 @@ impl CrawlerAgent { let summary = result.map_err(|error| CrawlError::new(error.to_string()))?; let crawl_state = runtime.tool_executor_mut().crawl_state.clone(); let messages = runtime.session().messages.clone(); + let model = runtime.session().model.clone(); let mut crawl_result = build_crawl_result(&summary, &crawl_state); crawl_result.messages = messages; + crawl_result.model = model; Ok(crawl_result) } @@ -393,6 +403,14 @@ impl ToolExecutor for CrawlerAgent { input: &str, ) -> impl std::future::Future> + Send { async move { + let settings = runtime::load_settings(); + self.step_count += 1; + *self + .accumulated_turn_ctx + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = + DynamicPromptContext::default(); + if self .allowed_tools .as_ref() @@ -410,59 +428,416 @@ impl ToolExecutor for CrawlerAgent { .map_err(|error| ToolError::new(format!("invalid JSON input: {error}")))? }; - let tool_effect = if let Some(handler) = self.registry.get(tool_name) { - match handler(&input_value) { - Ok(effect) => effect, - Err(error) if error.is_requires_async() => { - if !Self::supports_async(tool_name) { - // Forward the canonical phrasing to the runtime - // executor (which uses its own `ToolError` type). - return Err(ToolError::new(error.to_string())); - } + if let Some(cached_output) = + self.lookup_cached_action(&settings, tool_name, &input_value) + { + return Ok(ToolOutcome::reply(cached_output)); + } - self.ensure_browser().await?; - let browser = self - .browser - .as_mut() - .ok_or_else(|| ToolError::new("browser context is not initialized"))?; - self.registry - .execute_async(tool_name, &input_value, browser) - .await - .map_err(|error| ToolError::new(error.to_string()))? + let use_healing = runtime::settings_get_self_healing(&settings); + let max_retries = runtime::settings_get_self_healing_max_retries(&settings); + let mut current_input = input_value.clone(); + let mut heal_log = String::new(); + + let (mut text, observed_effect) = { + let mut attempts = 0usize; + loop { + match self.execute_tool_once(tool_name, ¤t_input).await { + Ok(result) => break result, + Err(error) + if use_healing + && attempts < max_retries + && matches!( + crate::failure_classifier::classify( + tool_name, + &error.to_string() + ), + crate::failure_classifier::FailureCategory::SelectorNotFound + | crate::failure_classifier::FailureCategory::SelectorAmbiguous + ) => + { + let Some((patched_input, current_heal_log)) = + self.try_self_heal(tool_name, ¤t_input).await? + else { + return Err(error); + }; + current_input = patched_input; + heal_log = current_heal_log; + attempts += 1; + } + Err(error) => { + let enriched = if runtime::settings_get_failure_classification(&settings) + { + let category = + crate::failure_classifier::classify(tool_name, &error.to_string()); + ToolError::new(format!("[{category}] {error}")) + } else { + error + }; + return Err(enriched); + } } - Err(error) => return Err(ToolError::new(error.to_string())), } - } else if Self::supports_async(tool_name) { - self.ensure_browser().await?; - let browser = self - .browser - .as_mut() - .ok_or_else(|| ToolError::new("browser context is not initialized"))?; - self.registry - .execute_async(tool_name, &input_value, browser) - .await - .map_err(|error| ToolError::new(error.to_string()))? - } else { - return Err(ToolError::new(format!("unknown tool: `{tool_name}`"))); }; - let observed_effect = match &tool_effect { - ToolEffect::Reply(_) => None, - effect => Some(effect.clone()), - }; + if !heal_log.is_empty() { + text = format!("{text} {heal_log}"); + } - self.dispatch_tool_effect(tool_effect).await.map(|text| { - if let Some(effect) = observed_effect { - ToolOutcome::with_effect(text, effect) - } else { - ToolOutcome::reply(text) - } + if observed_effect.is_none() { + self.store_cached_action(&settings, tool_name, &input_value, &text); + } + + if runtime::settings_get_loop_detection(&settings) { + self.apply_loop_detection(&settings, tool_name, &input_value); + } + + let interval = runtime::settings_get_planning_interval(&settings); + if interval > 0 { + self.apply_planning_guidance(interval); + } + + if runtime::settings_get_confidence_tracking(&settings) { + self.apply_confidence_tracking(); + } + + self.enforce_budget(&settings)?; + + Ok(if let Some(effect) = observed_effect { + ToolOutcome::with_effect(text, effect) + } else { + ToolOutcome::reply(text) }) } } } impl CrawlerAgent { + fn enforce_budget(&self, settings: &runtime::Settings) -> Result<(), ToolError> { + let Some(max_usd) = runtime::settings_get_budget_max_session_cost_usd(settings) else { + return Ok(()); + }; + + let mode = runtime::settings_get_budget_enforcement(settings) + .as_deref() + .and_then(runtime::BudgetMode::parse) + .unwrap_or(runtime::BudgetMode::Block); + let enforcer = runtime::BudgetEnforcer::new( + max_usd, + mode, + runtime::settings_get_budget_warn_threshold_pct(settings), + ); + let current_usd = + runtime::millicents_to_usd(self.cumulative_cost_slot.load(Ordering::Relaxed)); + + match enforcer.check(current_usd) { + runtime::BudgetDecision::Allow => Ok(()), + runtime::BudgetDecision::Warn { remaining_usd } => { + self.write_prompt_override(&DynamicPromptContext { + budget_warning: Some(format!( + "Budget warning: ${remaining_usd:.4} remaining (limit: ${max_usd:.4})" + )), + ..DynamicPromptContext::default() + }); + Ok(()) + } + runtime::BudgetDecision::Block => Err(ToolError::new( + "Budget exceeded: session cost limit reached", + )), + } + } + + async fn execute_tool_once( + &mut self, + tool_name: &str, + input_value: &Value, + ) -> Result<(String, Option), ToolError> { + let tool_effect = if let Some(handler) = self.registry.get(tool_name) { + match handler(input_value) { + Ok(effect) => effect, + Err(error) if error.is_requires_async() => { + if !Self::supports_async(tool_name) { + return Err(ToolError::new(error.to_string())); + } + + self.ensure_browser().await?; + let browser = self + .browser + .as_mut() + .ok_or_else(|| ToolError::new("browser context is not initialized"))?; + self.registry + .execute_async(tool_name, input_value, browser, &mut self.crawl_state) + .await + .map_err(|error| ToolError::new(error.to_string()))? + } + Err(error) => return Err(ToolError::new(error.to_string())), + } + } else if Self::supports_async(tool_name) { + self.ensure_browser().await?; + let browser = self + .browser + .as_mut() + .ok_or_else(|| ToolError::new("browser context is not initialized"))?; + self.registry + .execute_async(tool_name, input_value, browser, &mut self.crawl_state) + .await + .map_err(|error| ToolError::new(error.to_string()))? + } else { + return Err(ToolError::new(format!("unknown tool: `{tool_name}`"))); + }; + + let observed_effect = match &tool_effect { + ToolEffect::Reply(_) => None, + effect => Some(effect.clone()), + }; + let text = self.dispatch_tool_effect(tool_effect).await?; + Ok((text, observed_effect)) + } + + fn should_attempt_self_healing(tool_name: &str) -> bool { + matches!( + tool_name, + "click" | "fill_form" | "hover" | "press_key" | "select_option" + ) + } + + async fn try_self_heal( + &mut self, + tool_name: &str, + input_value: &Value, + ) -> Result, ToolError> { + if !Self::should_attempt_self_healing(tool_name) { + return Ok(None); + } + + let Some(old_ref) = crate::self_healing::extract_element_ref(input_value) else { + return Ok(None); + }; + + self.ensure_browser().await?; + let browser = self + .browser + .as_mut() + .ok_or_else(|| ToolError::new("browser context is not initialized"))?; + + let original_hint = browser + .ref_map() + .get(old_ref.trim_start_matches('@')) + .map(|entry| entry.name.clone()) + .filter(|name| !name.trim().is_empty()); + + let mut fresh_page_map = browser + .acquire_bridge() + .await + .map_err(|error| ToolError::new(error.to_string()))? + .page_map(None, false) + .await + .map_err(|error| ToolError::new(error.to_string()))?; + + let cache_key = crate::tools::page_map::normalize_url( + fresh_page_map + .get("meta") + .and_then(|meta| meta.get("url")) + .and_then(Value::as_str) + .unwrap_or("unknown"), + ) + .to_string(); + + if let Some(prev_url) = browser.snapshot_url() { + if prev_url != cache_key.as_str() { + browser.ref_map_mut().clear(); + } + } + + crate::tools::page_map::annotate_refs(&mut fresh_page_map, browser); + browser.set_page_snapshot(cache_key, fresh_page_map.clone()); + + let Some(new_selector) = crate::self_healing::find_healed_selector( + &old_ref, + &fresh_page_map, + original_hint.as_deref(), + ) else { + return Ok(None); + }; + + let patched = crate::self_healing::patch_selector(input_value, &old_ref, &new_selector); + Ok(Some(( + patched, + format!("[healed: {old_ref} → {new_selector}]"), + ))) + } + + fn lookup_cached_action( + &mut self, + settings: &runtime::Settings, + tool_name: &str, + input_value: &Value, + ) -> Option { + if !runtime::settings_get_action_caching(settings) + || !crate::action_cache::is_cacheable(tool_name) + { + return None; + } + + let current_fingerprint = self.crawl_state.page_fingerprints.last().cloned()?; + let ttl_secs = runtime::settings_get_action_cache_ttl_secs(settings); + + if self.crawl_state.action_cache.is_none() { + self.crawl_state.action_cache = Some(crate::action_cache::ActionCache::new(ttl_secs)); + } + + let cache_key = crate::action_cache::ActionCache::make_key( + tool_name, + input_value, + ¤t_fingerprint, + ); + self.crawl_state + .action_cache + .as_mut() + .and_then(|cache| cache.lookup(&cache_key, ¤t_fingerprint)) + } + + fn store_cached_action( + &mut self, + settings: &runtime::Settings, + tool_name: &str, + input_value: &Value, + text: &str, + ) { + if !runtime::settings_get_action_caching(settings) + || !crate::action_cache::is_cacheable(tool_name) + { + return; + } + + let Some(current_fingerprint) = self.crawl_state.page_fingerprints.last().cloned() else { + return; + }; + let Some(cache) = self.crawl_state.action_cache.as_mut() else { + return; + }; + + let cache_key = crate::action_cache::ActionCache::make_key( + tool_name, + input_value, + ¤t_fingerprint, + ); + cache.store(cache_key, text.to_string(), current_fingerprint); + } + + fn apply_planning_guidance(&self, interval: usize) { + let planning_guidance = if self.step_count.is_multiple_of(interval) { + "Planning checkpoint: Review your overall goal, assess progress, and decide your next major objective." + .to_string() + } else { + "Execution mode: Focus on the current step. Take precise action and evaluate the result." + .to_string() + }; + + self.write_prompt_override(&DynamicPromptContext { + planning_guidance: Some(planning_guidance), + ..Default::default() + }); + } + + fn apply_loop_detection( + &mut self, + settings: &runtime::Settings, + tool_name: &str, + input: &Value, + ) { + let window = runtime::settings_get_loop_detection_window(settings); + let threshold = runtime::settings_get_loop_nudge_threshold(settings); + + if self.crawl_state.loop_detector.is_none() { + self.crawl_state.loop_detector = + Some(crate::loop_detector::LoopDetector::new(window, threshold)); + } + + if let Some(detector) = self.crawl_state.loop_detector.as_mut() { + detector.record_action(tool_name, input); + + if let Some(fingerprint) = self.crawl_state.page_fingerprints.last() { + detector.record_page_state(fingerprint); + } + + if let Some(nudge) = detector.detect_loop() { + self.write_prompt_override(&DynamicPromptContext { + loop_nudge: Some(nudge.message().to_string()), + ..DynamicPromptContext::default() + }); + } + } + } + + fn write_prompt_override(&self, ctx: &DynamicPromptContext) { + let mut accumulated = self + .accumulated_turn_ctx + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + if ctx.stagnation_alert.is_some() { + accumulated + .stagnation_alert + .clone_from(&ctx.stagnation_alert); + } + if ctx.planning_guidance.is_some() { + accumulated + .planning_guidance + .clone_from(&ctx.planning_guidance); + } + if ctx.budget_warning.is_some() { + accumulated.budget_warning.clone_from(&ctx.budget_warning); + } + if ctx.loop_nudge.is_some() { + accumulated.loop_nudge.clone_from(&ctx.loop_nudge); + } + let specs = crate::mvp_tool_specs(); + let new_prompt = build_system_prompt(&specs, Some(&*accumulated)); + *self + .prompt_override_slot + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(new_prompt); + } + + fn apply_confidence_tracking(&mut self) { + let text_opt = { + let guard = self + .last_assistant_text_slot + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + guard.clone() + }; + + if let Some(text) = text_opt { + if let Some(conf) = crate::confidence::ConfidenceTracker::parse_from_text(&text) { + if self.confidence_tracker.is_none() { + self.confidence_tracker = Some(crate::confidence::ConfidenceTracker::new()); + } + let should_alert = self + .confidence_tracker + .as_mut() + .is_some_and(|tracker| tracker.record(conf)); + + if should_alert { + self.write_prompt_override(&DynamicPromptContext { + stagnation_alert: Some( + "Your confidence has been LOW for multiple consecutive steps. \ + Reconsider your approach." + .to_string(), + ), + planning_guidance: Some(crate::confidence::confidence_instruction()), + ..DynamicPromptContext::default() + }); + } else { + self.write_prompt_override(&DynamicPromptContext { + planning_guidance: Some(crate::confidence::confidence_instruction()), + ..DynamicPromptContext::default() + }); + } + } + } + } + async fn dispatch_tool_effect(&mut self, tool_effect: ToolEffect) -> Result { match tool_effect { ToolEffect::Reply(output) => Ok(output), @@ -598,7 +973,7 @@ impl CrawlerAgent { } fn default_agent_manager() -> SharedAgentManager { - Arc::new(Mutex::new(AgentManager::new( + Arc::new(AsyncMutex::new(AgentManager::new( DEFAULT_MAX_CONCURRENT_PER_PARENT, DEFAULT_MAX_FORK_DEPTH, DEFAULT_MAX_TOTAL_AGENTS, @@ -641,17 +1016,18 @@ fn build_crawl_result(summary: &TurnSummary, crawl_state: &CrawlState) -> CrawlR extracted_data, steps_executed: summary.iterations, messages: Vec::new(), + model: None, } } #[cfg(test)] mod tests { - use std::sync::OnceLock; - use acrawl_core::{ApiRequest, AssistantEvent, RuntimeError, TokenUsage}; + use async_trait::async_trait; use tokio::sync::Mutex as AsyncMutex; use super::*; + use crate::page_fingerprint::PageFingerprint; use crate::registry::ToolRegistry; struct MockApiClient { @@ -748,9 +1124,261 @@ mod tests { registry } - fn env_lock() -> &'static AsyncMutex<()> { - static LOCK: OnceLock> = OnceLock::new(); - LOCK.get_or_init(|| AsyncMutex::new(())) + fn action_cache_test_fingerprint(url: &str) -> PageFingerprint { + PageFingerprint { + url: url.to_string(), + element_count: 3, + text_hash: 42, + } + } + + fn write_action_cache_settings(config_home: &std::path::Path, enabled: bool, ttl_secs: u64) { + std::env::set_var("ACRAWL_CONFIG_HOME", config_home); + runtime::save_settings(&runtime::Settings { + optimization: Some(runtime::settings::OptimizationSettings { + action_caching: Some(enabled), + action_cache_ttl_secs: Some(ttl_secs), + ..Default::default() + }), + ..runtime::Settings::default() + }) + .expect("settings should save"); + } + + fn cache_test_registry( + counter: Arc>, + tool_name: &'static str, + ) -> ToolRegistry { + let mut registry = ToolRegistry::new(); + registry.register( + tool_name, + Box::new(move |input| { + let mut calls = counter + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + *calls += 1; + Ok(ToolEffect::Reply(format!( + "{tool_name} call {} with {}", + *calls, input + ))) + }), + ); + registry + } + + fn write_self_healing_settings(config_home: &std::path::Path, enabled: bool, retries: usize) { + std::env::set_var("ACRAWL_CONFIG_HOME", config_home); + runtime::save_settings(&runtime::Settings { + optimization: Some(runtime::settings::OptimizationSettings { + self_healing: Some(enabled), + self_healing_max_retries: Some(retries), + ..Default::default() + }), + ..runtime::Settings::default() + }) + .expect("settings should save"); + } + + #[derive(Debug, Default)] + struct HealingBridgeState { + latest_click_selector: Option, + page_map_calls: usize, + } + + #[derive(Debug)] + struct HealingTestBridge { + click_failures_remaining: usize, + click_error_message: String, + page_map_value: Value, + state: Arc>, + } + + impl HealingTestBridge { + fn new( + click_failures_remaining: usize, + click_error_message: &str, + page_map_value: Value, + state: Arc>, + ) -> Self { + Self { + click_failures_remaining, + click_error_message: click_error_message.to_string(), + page_map_value, + state, + } + } + } + + #[async_trait] + impl BrowserBackend for HealingTestBridge { + async fn navigate(&mut self, _url: &str) -> Result { + Err(BridgeError::Protocol("unused".into())) + } + + async fn new_page(&mut self, _url: Option<&str>) -> Result { + Ok(0) + } + + async fn close_page(&mut self, _page_index: usize) -> Result<(), BridgeError> { + Ok(()) + } + + async fn scroll(&mut self, _direction: &str, _pixels: i64) -> Result<(), BridgeError> { + Err(BridgeError::Protocol("unused".into())) + } + + async fn page_map( + &mut self, + _scope: Option<&str>, + _compound_enrichment: bool, + ) -> Result { + let mut state = self + .state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + state.page_map_calls += 1; + Ok(self.page_map_value.clone()) + } + + async fn read_content( + &mut self, + _heading: Option<&str>, + _selector: Option<&str>, + _offset: usize, + _max_chars: usize, + ) -> Result { + Err(BridgeError::Protocol("unused".into())) + } + + async fn wait_for_selector( + &mut self, + _selector: &str, + _timeout_ms: u64, + _state: Option<&str>, + ) -> Result { + Err(BridgeError::Protocol("unused".into())) + } + + async fn select_option( + &mut self, + _selector: &str, + _value: &str, + ) -> Result<(), BridgeError> { + Err(BridgeError::Protocol("unused".into())) + } + + async fn evaluate(&mut self, _script: &str) -> Result { + Err(BridgeError::Protocol("unused".into())) + } + + async fn hover(&mut self, _selector: &str) -> Result<(), BridgeError> { + Err(BridgeError::Protocol("unused".into())) + } + + async fn press_key( + &mut self, + _key: &str, + _selector: Option<&str>, + ) -> Result<(), BridgeError> { + Err(BridgeError::Protocol("unused".into())) + } + + async fn switch_tab(&mut self, _index: i64) -> Result { + Ok(serde_json::json!({})) + } + + async fn export_cookies(&mut self) -> Result { + Err(BridgeError::Protocol("unused".into())) + } + + async fn import_cookies(&mut self, _state: &BrowserState) -> Result<(), BridgeError> { + Err(BridgeError::Protocol("unused".into())) + } + + async fn import_cookies_only(&mut self, _state: &BrowserState) -> Result<(), BridgeError> { + Err(BridgeError::Protocol("unused".into())) + } + + async fn import_local_storage(&mut self, _state: &BrowserState) -> Result<(), BridgeError> { + Err(BridgeError::Protocol("unused".into())) + } + + async fn list_resources(&mut self) -> Result { + Err(BridgeError::Protocol("unused".into())) + } + + async fn save_file(&mut self, _url: &str, _path: &str) -> Result { + Err(BridgeError::Protocol("unused".into())) + } + + async fn click(&mut self, selector: &str) -> Result<(), BridgeError> { + let mut state = self + .state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + state.latest_click_selector = Some(selector.to_string()); + drop(state); + + if self.click_failures_remaining > 0 { + self.click_failures_remaining -= 1; + return Err(BridgeError::Protocol(self.click_error_message.clone())); + } + Ok(()) + } + + async fn click_at(&mut self, _x: f64, _y: f64) -> Result<(), BridgeError> { + Err(BridgeError::Protocol("unused".into())) + } + + async fn fill(&mut self, _selector: &str, _value: &str) -> Result<(), BridgeError> { + Err(BridgeError::Protocol("unused".into())) + } + + async fn screenshot( + &mut self, + _options: &ScreenshotOptions<'_>, + ) -> Result<(String, usize), BridgeError> { + Err(BridgeError::Protocol("unused".into())) + } + + async fn go_back(&mut self) -> Result { + Err(BridgeError::Protocol("unused".into())) + } + } + + fn healing_page_map() -> Value { + serde_json::json!({ + "interactive": { + "elements": [ + {"selector": "#new-submit", "role": "button", "name": "Submit", "text": "Submit"} + ] + }, + "meta": {"url": "https://example.com", "title": "Example"}, + "headings": [], + "links": [], + "forms": [], + "landmarks": [] + }) + } + + fn healing_test_agent( + click_failures_remaining: usize, + click_error_message: &str, + page_map_value: Value, + ) -> (CrawlerAgent, Arc>) { + let state = Arc::new(std::sync::Mutex::new(HealingBridgeState::default())); + let shared_bridge: Arc>> = + Arc::new(AsyncMutex::new(Box::new(HealingTestBridge::new( + click_failures_remaining, + click_error_message, + page_map_value, + Arc::clone(&state), + )))); + let browser = BrowserContext::new_shared(shared_bridge, 0); + ( + CrawlerAgent::new(browser, ToolRegistry::new_with_core_tools()), + state, + ) } #[tokio::test] @@ -766,7 +1394,7 @@ mod tests { #[tokio::test] async fn test_spawn_effect_triggers_fork() { - let _env_guard = env_lock().lock().await; + let _env_guard = crate::test_async_env_lock().lock().await; std::env::set_var("HEADLESS", "true"); let manager = default_agent_manager(); manager.lock().await.register_root("root"); @@ -784,7 +1412,7 @@ mod tests { .with_agent_id("root".to_string()) .with_agent_manager(manager.clone()); agent.api_client_arc = Some(SharedApiClient::new(TextOnlyApiClient)); - agent.shared_bridge = Some(Arc::new(Mutex::new( + agent.shared_bridge = Some(Arc::new(AsyncMutex::new( Box::new(bridge) as Box ))); agent.fork_page_index_override = Some(1); @@ -966,7 +1594,7 @@ mod tests { #[tokio::test] async fn test_fork_dispatch_spawns_child() { - let _env_guard = env_lock().lock().await; + let _env_guard = crate::test_async_env_lock().lock().await; std::env::set_var("HEADLESS", "true"); let manager = default_agent_manager(); manager.lock().await.register_root("root"); @@ -979,12 +1607,17 @@ mod tests { .with_agent_id("root".to_string()) .with_agent_manager(manager.clone()); agent.api_client_arc = Some(shared_client); - agent.shared_bridge = Some(Arc::new(Mutex::new(Box::new( - crate::PlaywrightBridge::new() - .await - .expect("bridge should initialize for fork test"), - ) - as Box))); + let bridge = match crate::PlaywrightBridge::new().await { + Ok(bridge) => bridge, + Err(crate::BridgeError::PlaywrightNotInstalled(_)) => { + eprintln!("skipping test: CloakBrowser not installed"); + return; + } + Err(error) => panic!("unexpected bridge error: {error}"), + }; + agent.shared_bridge = Some(Arc::new(AsyncMutex::new( + Box::new(bridge) as Box + ))); agent.fork_page_index_override = Some(1); let observation = agent @@ -1015,7 +1648,7 @@ mod tests { #[tokio::test] async fn test_fork_returns_observation() { - let _env_guard = env_lock().lock().await; + let _env_guard = crate::test_async_env_lock().lock().await; std::env::set_var("HEADLESS", "true"); let manager = default_agent_manager(); manager.lock().await.register_root("root"); @@ -1033,7 +1666,7 @@ mod tests { .with_agent_id("root".to_string()) .with_agent_manager(manager); agent.api_client_arc = Some(SharedApiClient::new(TextOnlyApiClient)); - agent.shared_bridge = Some(Arc::new(Mutex::new( + agent.shared_bridge = Some(Arc::new(AsyncMutex::new( Box::new(bridge) as Box ))); agent.fork_page_index_override = Some(1); @@ -1059,7 +1692,7 @@ mod tests { #[tokio::test] async fn test_fork_at_limit_returns_error() { - let manager = Arc::new(Mutex::new(AgentManager::new(0, 3, 10))); + let manager = Arc::new(AsyncMutex::new(AgentManager::new(0, 3, 10))); manager.lock().await.register_root("root"); let mut agent = CrawlerAgent::new_for_testing(ToolRegistry::new_with_core_tools()) @@ -1276,6 +1909,7 @@ mod tests { let result = build_crawl_result(&summary, &CrawlState::default()); assert_eq!(result.summary, "final answer"); assert_eq!(result.steps_executed, 2); + assert_eq!(result.model, None); } #[test] @@ -1353,4 +1987,364 @@ mod tests { vec![serde_json::json!({"from_state": true})] ); } + + #[tokio::test] + async fn test_planning_interval_disabled_by_default() { + let _env_guard = crate::test_async_env_lock().lock().await; + std::env::set_var("ACRAWL_CONFIG_HOME", ""); + + let mut agent = CrawlerAgent::new_for_testing(mock_registry()); + let initial_slot = agent + .prompt_override_slot + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); + assert_eq!(initial_slot, None, "slot should be empty initially"); + + let _ = agent + .execute("navigate", r#"{"url":"https://example.com"}"#) + .await; + + let slot_after = agent + .prompt_override_slot + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); + assert_eq!( + slot_after, None, + "slot should remain empty when interval=0 (default)" + ); + } + + #[tokio::test] + async fn test_action_cache_hits_for_same_read_content_and_fingerprint() { + let _env_guard = crate::test_async_env_lock().lock().await; + let temp_dir = tempfile::tempdir().expect("temp dir should create"); + write_action_cache_settings(temp_dir.path(), true, 30); + + let call_count = Arc::new(std::sync::Mutex::new(0usize)); + let registry = cache_test_registry(Arc::clone(&call_count), "read_content"); + let mut agent = CrawlerAgent::new_for_testing(registry); + agent + .crawl_state + .page_fingerprints + .push(action_cache_test_fingerprint("https://example.com")); + + let first = agent + .execute("read_content", r##"{"selector":"#main","offset":0}"##) + .await + .expect("first call should succeed"); + let second = agent + .execute("read_content", r##"{"offset":0,"selector":"#main"}"##) + .await + .expect("second call should succeed"); + + assert_eq!(first.text, second.text); + assert_eq!( + *call_count + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner), + 1 + ); + } + + #[tokio::test] + async fn test_action_cache_misses_when_page_fingerprint_changes() { + let _env_guard = crate::test_async_env_lock().lock().await; + let temp_dir = tempfile::tempdir().expect("temp dir should create"); + write_action_cache_settings(temp_dir.path(), true, 30); + + let call_count = Arc::new(std::sync::Mutex::new(0usize)); + let registry = cache_test_registry(Arc::clone(&call_count), "read_content"); + let mut agent = CrawlerAgent::new_for_testing(registry); + agent + .crawl_state + .page_fingerprints + .push(action_cache_test_fingerprint("https://example.com/a")); + + let _ = agent + .execute("read_content", r##"{"selector":"#main"}"##) + .await + .expect("first call should succeed"); + agent.crawl_state.page_fingerprints.push(PageFingerprint { + url: "https://example.com/b".to_string(), + element_count: 7, + text_hash: 99, + }); + let _ = agent + .execute("read_content", r##"{"selector":"#main"}"##) + .await + .expect("second call should succeed"); + + assert_eq!( + *call_count + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner), + 2 + ); + } + + #[tokio::test] + async fn test_action_cache_misses_after_ttl_expires() { + let _env_guard = crate::test_async_env_lock().lock().await; + let temp_dir = tempfile::tempdir().expect("temp dir should create"); + write_action_cache_settings(temp_dir.path(), true, 0); + + let call_count = Arc::new(std::sync::Mutex::new(0usize)); + let registry = cache_test_registry(Arc::clone(&call_count), "read_content"); + let mut agent = CrawlerAgent::new_for_testing(registry); + agent + .crawl_state + .page_fingerprints + .push(action_cache_test_fingerprint("https://example.com")); + + let _ = agent + .execute("read_content", r##"{"selector":"#main"}"##) + .await + .expect("first call should succeed"); + let _ = agent + .execute("read_content", r##"{"selector":"#main"}"##) + .await + .expect("second call should succeed"); + + assert_eq!( + *call_count + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner), + 2 + ); + } + + #[tokio::test] + async fn test_action_cache_flag_off_keeps_behavior_unchanged() { + let _env_guard = crate::test_async_env_lock().lock().await; + let temp_dir = tempfile::tempdir().expect("temp dir should create"); + write_action_cache_settings(temp_dir.path(), false, 30); + + let call_count = Arc::new(std::sync::Mutex::new(0usize)); + let registry = cache_test_registry(Arc::clone(&call_count), "read_content"); + let mut agent = CrawlerAgent::new_for_testing(registry); + agent + .crawl_state + .page_fingerprints + .push(action_cache_test_fingerprint("https://example.com")); + + let _ = agent + .execute("read_content", r##"{"selector":"#main"}"##) + .await + .expect("first call should succeed"); + let _ = agent + .execute("read_content", r##"{"selector":"#main"}"##) + .await + .expect("second call should succeed"); + + assert_eq!( + *call_count + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner), + 2 + ); + assert!(agent.crawl_state.action_cache.is_none()); + } + + #[tokio::test] + async fn test_interaction_tools_are_never_action_cached() { + let _env_guard = crate::test_async_env_lock().lock().await; + let temp_dir = tempfile::tempdir().expect("temp dir should create"); + write_action_cache_settings(temp_dir.path(), true, 30); + + let call_count = Arc::new(std::sync::Mutex::new(0usize)); + let registry = cache_test_registry(Arc::clone(&call_count), "click"); + let mut agent = CrawlerAgent::new_for_testing(registry); + agent + .crawl_state + .page_fingerprints + .push(action_cache_test_fingerprint("https://example.com")); + + let _ = agent + .execute("click", r##"{"selector":"#submit"}"##) + .await + .expect("first click should succeed"); + let _ = agent + .execute("click", r##"{"selector":"#submit"}"##) + .await + .expect("second click should succeed"); + + assert_eq!( + *call_count + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner), + 2 + ); + } + + #[test] + fn test_step_count_increments_on_each_execute() { + let agent = CrawlerAgent::new_for_testing(mock_registry()); + assert_eq!(agent.step_count, 0, "step_count should start at 0"); + } + + #[tokio::test] + async fn selector_not_found_with_matching_text_heals_and_succeeds() { + let _env_guard = crate::test_async_env_lock().lock().await; + let temp_dir = tempfile::tempdir().expect("temp dir should create"); + write_self_healing_settings(temp_dir.path(), true, 2); + + let (mut agent, state) = + healing_test_agent(1, "Element not found matching selector", healing_page_map()); + agent + .browser + .as_mut() + .expect("browser should exist") + .ref_map_mut() + .assign_or_reuse("#old-submit", "button", "Submit"); + + let result = agent + .execute("click", r#"{"selector":"@e1"}"#) + .await + .expect("healed click should succeed"); + + assert!( + result.text.contains("[healed: @e1 → @e2]"), + "{}", + result.text + ); + let state = state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + assert_eq!(state.latest_click_selector.as_deref(), Some("#new-submit")); + assert!(state.page_map_calls >= 1); + } + + #[tokio::test] + async fn selector_not_found_with_no_match_returns_original_error() { + let _env_guard = crate::test_async_env_lock().lock().await; + let temp_dir = tempfile::tempdir().expect("temp dir should create"); + write_self_healing_settings(temp_dir.path(), true, 2); + + let (mut agent, state) = healing_test_agent( + 1, + "Element not found matching selector", + serde_json::json!({ + "interactive": { + "elements": [ + {"selector": "#login", "role": "button", "name": "Login", "text": "Login"} + ] + }, + "meta": {"url": "https://example.com", "title": "Example"}, + "headings": [], + "links": [], + "forms": [], + "landmarks": [] + }), + ); + agent + .browser + .as_mut() + .expect("browser should exist") + .ref_map_mut() + .assign_or_reuse("#old-submit", "button", "Submit"); + + let err = agent + .execute("click", r#"{"selector":"@e1"}"#) + .await + .expect_err("unhealable selector should fail"); + + assert!(err + .to_string() + .contains("Element not found matching selector")); + let state = state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + assert_eq!(state.latest_click_selector.as_deref(), Some("#old-submit")); + assert_eq!(state.page_map_calls, 1); + } + + #[tokio::test] + async fn max_retries_are_respected_for_self_healing() { + let _env_guard = crate::test_async_env_lock().lock().await; + let temp_dir = tempfile::tempdir().expect("temp dir should create"); + write_self_healing_settings(temp_dir.path(), true, 1); + + let (mut agent, state) = + healing_test_agent(2, "Element not found matching selector", healing_page_map()); + agent + .browser + .as_mut() + .expect("browser should exist") + .ref_map_mut() + .assign_or_reuse("#old-submit", "button", "Submit"); + + let err = agent + .execute("click", r#"{"selector":"@e1"}"#) + .await + .expect_err("single retry should not mask repeated failure"); + + assert!(err + .to_string() + .contains("Element not found matching selector")); + let state = state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + assert_eq!(state.latest_click_selector.as_deref(), Some("#new-submit")); + assert_eq!(state.page_map_calls, 1); + } + + #[tokio::test] + async fn non_selector_errors_do_not_attempt_healing() { + let _env_guard = crate::test_async_env_lock().lock().await; + let temp_dir = tempfile::tempdir().expect("temp dir should create"); + write_self_healing_settings(temp_dir.path(), true, 2); + + let (mut agent, state) = + healing_test_agent(1, "network connection refused", healing_page_map()); + agent + .browser + .as_mut() + .expect("browser should exist") + .ref_map_mut() + .assign_or_reuse("#old-submit", "button", "Submit"); + + let err = agent + .execute("click", r#"{"selector":"@e1"}"#) + .await + .expect_err("network failure should not heal"); + + assert!(err.to_string().contains("network connection refused")); + let state = state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + assert_eq!(state.page_map_calls, 0); + assert_eq!(state.latest_click_selector.as_deref(), Some("#old-submit")); + } + + #[tokio::test] + async fn self_healing_flag_off_keeps_selector_failure_behavior() { + let _env_guard = crate::test_async_env_lock().lock().await; + let temp_dir = tempfile::tempdir().expect("temp dir should create"); + write_self_healing_settings(temp_dir.path(), false, 2); + + let (mut agent, state) = + healing_test_agent(1, "Element not found matching selector", healing_page_map()); + agent + .browser + .as_mut() + .expect("browser should exist") + .ref_map_mut() + .assign_or_reuse("#old-submit", "button", "Submit"); + + let err = agent + .execute("click", r#"{"selector":"@e1"}"#) + .await + .expect_err("flag-off should preserve original failure"); + + assert!(err + .to_string() + .contains("Element not found matching selector")); + let state = state + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + assert_eq!(state.page_map_calls, 0); + assert_eq!(state.latest_click_selector.as_deref(), Some("#old-submit")); + } } diff --git a/crates/agent/src/lib.rs b/crates/agent/src/lib.rs index 0f7bdba3..e1a37f1c 100644 --- a/crates/agent/src/lib.rs +++ b/crates/agent/src/lib.rs @@ -1,11 +1,17 @@ +pub mod action_cache; pub mod agent; pub mod child_events; +pub mod confidence; +pub mod failure_classifier; +pub mod loop_detector; pub mod manager; pub mod output; +pub mod page_fingerprint; pub mod prompt; pub mod registry; pub mod script_executor; pub mod script_manager; +pub mod self_healing; mod shared_client; pub mod state; pub mod tools; @@ -38,7 +44,7 @@ pub use child_events::{ }; pub use manager::{AgentInfo, AgentManager, AgentStatus, ForkLimitError, SharedAgentManager}; pub use output::{write_output, OutputError, OutputFormat}; -pub use prompt::build_system_prompt; +pub use prompt::{build_system_prompt, DynamicPromptContext}; pub use registry::{ToolHandler, ToolRegistry}; pub use shared_client::SharedApiClient; pub use state::{ChildBlock, CrawlState}; @@ -46,6 +52,12 @@ pub use url_claim::{ClaimConflict, ClaimGuard, UrlClaimRegistry}; use serde_json::json; +#[cfg(test)] +pub(crate) fn test_async_env_lock() -> &'static tokio::sync::Mutex<()> { + static LOCK: std::sync::OnceLock> = std::sync::OnceLock::new(); + LOCK.get_or_init(|| tokio::sync::Mutex::new(())) +} + fn navigation_tools() -> Vec { vec![ ToolSpec { diff --git a/crates/agent/src/loop_detector.rs b/crates/agent/src/loop_detector.rs new file mode 100644 index 00000000..eb324b8d --- /dev/null +++ b/crates/agent/src/loop_detector.rs @@ -0,0 +1,248 @@ +use std::collections::VecDeque; + +use serde_json::Value; + +use crate::page_fingerprint::PageFingerprint; + +const DEFAULT_WINDOW: usize = 20; +const DEFAULT_NUDGE_THRESHOLD: usize = 5; + +#[derive(Debug, Clone)] +pub enum LoopNudge { + Soft(String), + Medium(String), + Strong(String), + Stagnation(String), +} + +impl LoopNudge { + #[must_use] + pub fn message(&self) -> &str { + match self { + Self::Soft(message) + | Self::Medium(message) + | Self::Strong(message) + | Self::Stagnation(message) => message, + } + } +} + +#[derive(Debug, Clone)] +pub struct LoopDetector { + action_window: VecDeque, + page_fingerprints: VecDeque, + window_size: usize, + nudge_threshold: usize, +} + +impl Default for LoopDetector { + fn default() -> Self { + Self::new(DEFAULT_WINDOW, DEFAULT_NUDGE_THRESHOLD) + } +} + +impl LoopDetector { + #[must_use] + pub fn new(window_size: usize, nudge_threshold: usize) -> Self { + Self { + action_window: VecDeque::new(), + page_fingerprints: VecDeque::new(), + window_size: window_size.max(1), + nudge_threshold: nudge_threshold.max(1), + } + } + + /// Record a tool action. Tool name + normalized input is hashed. + pub fn record_action(&mut self, tool_name: &str, input: &Value) { + let hash = hash_action(tool_name, input); + self.action_window.push_back(hash); + while self.action_window.len() > self.window_size { + let _ = self.action_window.pop_front(); + } + } + + /// Record a page state fingerprint. + pub fn record_page_state(&mut self, fingerprint: &PageFingerprint) { + let key = format!( + "{}|{}|{}", + fingerprint.url, fingerprint.element_count, fingerprint.text_hash + ); + self.page_fingerprints.push_back(key); + while self.page_fingerprints.len() > self.window_size { + let _ = self.page_fingerprints.pop_front(); + } + } + + /// Check for repetition patterns. Returns a nudge if a loop is detected. + #[must_use] + pub fn detect_loop(&self) -> Option { + if self.action_window.len() >= self.nudge_threshold { + let last = *self.action_window.back()?; + let repeat_count = self + .action_window + .iter() + .rev() + .take(self.window_size) + .take_while(|&&hash| hash == last) + .count(); + + if repeat_count >= 12 { + return Some(LoopNudge::Strong(format!( + "You have repeated the same action {repeat_count} times. This approach is not working. You MUST try a completely different strategy." + ))); + } + if repeat_count >= 8 { + return Some(LoopNudge::Medium(format!( + "Your actions seem repetitive ({repeat_count} identical actions). Consider a significantly different approach." + ))); + } + if repeat_count >= self.nudge_threshold { + return Some(LoopNudge::Soft( + "Consider a different approach — you may be repeating actions that haven't worked." + .to_string(), + )); + } + } + + if self.page_fingerprints.len() >= self.nudge_threshold { + let last = self.page_fingerprints.back()?; + let stagnant_count = self + .page_fingerprints + .iter() + .rev() + .take(self.nudge_threshold) + .filter(|fingerprint| *fingerprint == last) + .count(); + if stagnant_count >= self.nudge_threshold { + return Some(LoopNudge::Stagnation( + "The page state has not changed for several steps. You may be stuck. Try a different action or navigate elsewhere." + .to_string(), + )); + } + } + + None + } +} + +/// Normalize action to a stable hash. +fn hash_action(tool_name: &str, input: &Value) -> u64 { + let key = match tool_name { + "click" => { + let selector = input.get("selector").and_then(Value::as_str).unwrap_or(""); + format!("click:{selector}") + } + "fill_form" => { + let fields = input + .get("fields") + .map(|value| serde_json::to_string(value).unwrap_or_default()) + .unwrap_or_default() + .to_lowercase(); + format!("fill_form:{fields}") + } + "navigate" => { + let url = input.get("url").and_then(Value::as_str).unwrap_or(""); + format!("navigate:{url}") + } + "scroll" => { + let direction = input.get("direction").and_then(Value::as_str).unwrap_or(""); + format!("scroll:{direction}") + } + _ => { + let canonical = serde_json::to_string(input).unwrap_or_default(); + format!("{tool_name}:{canonical}") + } + }; + fnv1a_hash(&key) +} + +fn fnv1a_hash(input: &str) -> u64 { + const FNV_OFFSET: u64 = 14_695_981_039_346_656_037; + const FNV_PRIME: u64 = 1_099_511_628_211; + let mut hash = FNV_OFFSET; + for byte in input.bytes() { + hash ^= u64::from(byte); + hash = hash.wrapping_mul(FNV_PRIME); + } + hash +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::*; + + fn make_fp(url: &str) -> PageFingerprint { + PageFingerprint { + url: url.to_string(), + element_count: 5, + text_hash: 12_345, + } + } + + #[test] + fn soft_nudge_at_threshold() { + let mut detector = LoopDetector::new(20, 5); + let input = json!({"selector": "@e3"}); + for _ in 0..5 { + detector.record_action("click", &input); + } + let nudge = detector.detect_loop().expect("expected soft nudge"); + assert!(matches!(nudge, LoopNudge::Soft(_))); + assert!(nudge.message().contains("different approach")); + } + + #[test] + fn no_nudge_below_threshold() { + let mut detector = LoopDetector::new(20, 5); + let input = json!({"selector": "@e3"}); + for _ in 0..4 { + detector.record_action("click", &input); + } + assert!(detector.detect_loop().is_none()); + } + + #[test] + fn different_actions_no_false_positive() { + let mut detector = LoopDetector::new(20, 5); + for i in 0..10 { + detector.record_action("click", &json!({"selector": format!("@e{i}")})); + } + assert!(detector.detect_loop().is_none()); + } + + #[test] + fn navigate_different_urls_no_false_positive() { + let mut detector = LoopDetector::new(20, 5); + for i in 1..=6 { + detector.record_action( + "navigate", + &json!({"url": format!("https://example.com/page{i}")}), + ); + } + assert!(detector.detect_loop().is_none()); + } + + #[test] + fn stagnation_after_five_identical_fingerprints() { + let mut detector = LoopDetector::new(20, 5); + let fingerprint = make_fp("https://example.com"); + for _ in 0..5 { + detector.record_page_state(&fingerprint); + } + let nudge = detector.detect_loop().expect("expected stagnation nudge"); + assert!(matches!(nudge, LoopNudge::Stagnation(_))); + } + + #[test] + fn strong_nudge_at_twelve_repeats() { + let mut detector = LoopDetector::new(20, 5); + let input = json!({"selector": "@e3"}); + for _ in 0..12 { + detector.record_action("click", &input); + } + let nudge = detector.detect_loop().expect("expected strong nudge"); + assert!(matches!(nudge, LoopNudge::Strong(_))); + } +} diff --git a/crates/agent/src/page_fingerprint.rs b/crates/agent/src/page_fingerprint.rs new file mode 100644 index 00000000..7c97f4fb --- /dev/null +++ b/crates/agent/src/page_fingerprint.rs @@ -0,0 +1,159 @@ +use serde_json::Value; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PageFingerprint { + pub url: String, + pub element_count: usize, + pub text_hash: u64, +} + +impl PageFingerprint { + /// Compute a fingerprint from URL and `page_map` data. + /// `page_map` is the JSON value returned by the `page_map` tool. + /// Only hashes the first 1000 chars of visible text to stay cheap. + #[must_use] + pub fn compute(url: &str, page_map: &Value) -> Self { + #[allow(clippy::cast_possible_truncation)] + let element_count = page_map + .get("interactive") + .and_then(|i| i.get("counts")) + .and_then(|c| c.get("total")) + .and_then(Value::as_u64) + .unwrap_or(0) as usize; + + let text = extract_page_text(page_map); + let truncated: String = text.chars().take(1000).collect(); + let text_hash = simple_hash(&truncated); + + Self { + url: url.to_string(), + element_count, + text_hash, + } + } + + #[must_use] + pub fn pages_identical(a: &PageFingerprint, b: &PageFingerprint) -> bool { + a == b + } +} + +fn extract_page_text(page_map: &Value) -> String { + let mut parts = Vec::new(); + + if let Some(headings) = page_map.get("headings").and_then(Value::as_array) { + for h in headings { + if let Some(text) = h.get("text").and_then(Value::as_str) { + parts.push(text.to_string()); + } + } + } + + if let Some(links) = page_map.get("links").and_then(Value::as_array) { + for link in links { + if let Some(text) = link.get("text").and_then(Value::as_str) { + parts.push(text.to_string()); + } + } + } + + parts.join(" ") +} + +/// FNV-1a 64-bit hash — zero dependencies, deterministic +fn simple_hash(s: &str) -> u64 { + const FNV_OFFSET: u64 = 14_695_981_039_346_656_037; + const FNV_PRIME: u64 = 1_099_511_628_211; + let mut hash = FNV_OFFSET; + for byte in s.bytes() { + hash ^= u64::from(byte); + hash = hash.wrapping_mul(FNV_PRIME); + } + hash +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn sample_page_map(headings: &[&str], links: &[&str], total_interactive: u64) -> Value { + json!({ + "headings": headings.iter().map(|t| json!({"text": t, "level": 1})).collect::>(), + "links": links.iter().map(|t| json!({"text": t, "href": "https://example.com"})).collect::>(), + "interactive": { + "counts": {"total": total_interactive} + }, + "meta": {"url": "https://example.com", "title": "Test"} + }) + } + + #[test] + fn identical_pages_produce_identical_fingerprints() { + let pm = sample_page_map(&["Welcome", "About"], &["Home", "Contact"], 5); + let fp1 = PageFingerprint::compute("https://example.com", &pm); + let fp2 = PageFingerprint::compute("https://example.com", &pm); + + assert_eq!(fp1, fp2); + assert!(PageFingerprint::pages_identical(&fp1, &fp2)); + } + + #[test] + fn different_text_produces_different_fingerprints() { + let pm1 = sample_page_map(&["Welcome"], &["Home"], 3); + let pm2 = sample_page_map(&["Goodbye"], &["Away"], 3); + let fp1 = PageFingerprint::compute("https://example.com", &pm1); + let fp2 = PageFingerprint::compute("https://example.com", &pm2); + + assert_ne!(fp1, fp2); + assert!(!PageFingerprint::pages_identical(&fp1, &fp2)); + } + + #[test] + fn url_change_produces_different_fingerprint() { + let pm = sample_page_map(&["Welcome"], &["Home"], 3); + let fp1 = PageFingerprint::compute("https://example.com/page1", &pm); + let fp2 = PageFingerprint::compute("https://example.com/page2", &pm); + + assert_ne!(fp1, fp2); + assert!(!PageFingerprint::pages_identical(&fp1, &fp2)); + } + + #[test] + fn empty_page_map_produces_valid_fingerprint() { + let pm = json!({}); + let fp = PageFingerprint::compute("https://empty.com", &pm); + + assert_eq!(fp.url, "https://empty.com"); + assert_eq!(fp.element_count, 0); + assert_eq!(fp.text_hash, simple_hash("")); + } + + #[test] + fn text_truncated_at_1000_chars() { + let long_heading = "A".repeat(1200); + let pm = json!({ + "headings": [{"text": long_heading, "level": 1}], + "links": [], + "interactive": {"counts": {"total": 0}} + }); + + let fp = PageFingerprint::compute("https://example.com", &pm); + let expected_hash = simple_hash(&"A".repeat(1000)); + assert_eq!(fp.text_hash, expected_hash); + } + + #[test] + fn element_count_extracted_from_interactive_total() { + let pm = json!({ + "headings": [], + "links": [], + "interactive": { + "counts": {"total": 42, "buttons": 10, "inputs": 32} + } + }); + + let fp = PageFingerprint::compute("https://example.com", &pm); + assert_eq!(fp.element_count, 42); + } +} diff --git a/crates/agent/src/prompt.rs b/crates/agent/src/prompt.rs index dc29b27c..406fd121 100644 --- a/crates/agent/src/prompt.rs +++ b/crates/agent/src/prompt.rs @@ -1,5 +1,13 @@ use acrawl_core::ToolSpec; +#[derive(Debug, Default, Clone)] +pub struct DynamicPromptContext { + pub stagnation_alert: Option, + pub planning_guidance: Option, + pub budget_warning: Option, + pub loop_nudge: Option, +} + fn format_tool(spec: &ToolSpec) -> String { let required: Vec<&str> = spec .input_schema @@ -33,8 +41,11 @@ fn list_tools(specs: &[ToolSpec]) -> String { /// /// Returns `Vec` for [`runtime::ConversationRuntime`]'s `system_prompt` parameter. #[must_use] -pub fn build_system_prompt(tool_specs: &[ToolSpec]) -> Vec { - vec![ +pub fn build_system_prompt( + tool_specs: &[ToolSpec], + dynamic_context: Option<&DynamicPromptContext>, +) -> Vec { + let mut sections = vec![ section_identity(tool_specs), section_operating_procedure(), section_data_integrity(), @@ -43,7 +54,37 @@ pub fn build_system_prompt(tool_specs: &[ToolSpec]) -> Vec { section_completion(), section_parallel_exploration(), section_autonomous_scripts(), - ] + ]; + + if let Some(dynamic_section) = dynamic_context.and_then(build_dynamic_section) { + sections.push(dynamic_section); + } + + sections +} + +#[must_use] +pub fn build_dynamic_section(ctx: &DynamicPromptContext) -> Option { + let mut items = Vec::new(); + + if let Some(value) = &ctx.stagnation_alert { + items.push(format!("- Stagnation alert: {value}")); + } + if let Some(value) = &ctx.planning_guidance { + items.push(format!("- Planning guidance: {value}")); + } + if let Some(value) = &ctx.budget_warning { + items.push(format!("- Budget warning: {value}")); + } + if let Some(value) = &ctx.loop_nudge { + items.push(format!("- Loop nudge: {value}")); + } + + if items.is_empty() { + None + } else { + Some(format!("Dynamic guidance:\n{}", items.join("\n"))) + } } fn section_identity(tool_specs: &[ToolSpec]) -> String { @@ -261,7 +302,7 @@ mod tests { #[test] fn build_system_prompt_includes_tool_listing() { - let prompt = build_system_prompt(&sample_specs()); + let prompt = build_system_prompt(&sample_specs(), None); assert!(prompt.len() >= 2, "should have at least 2 prompt sections"); let first = &prompt[0]; @@ -308,7 +349,7 @@ mod tests { #[test] fn build_system_prompt_contains_all_sections() { - let prompt = build_system_prompt(&sample_specs()); + let prompt = build_system_prompt(&sample_specs(), None); let joined = prompt.join("\n"); assert!(joined.contains("Operating procedure:")); assert!(joined.contains("Data integrity:")); @@ -322,7 +363,7 @@ mod tests { #[test] fn build_system_prompt_lists_all_tools() { let specs = crate::mvp_tool_specs(); - let prompt = build_system_prompt(&specs); + let prompt = build_system_prompt(&specs, None); let first = &prompt[0]; for spec in &specs { assert!( @@ -336,7 +377,7 @@ mod tests { #[test] fn test_system_prompt_contains_parallel_exploration() { let specs = crate::mvp_tool_specs(); - let prompt = build_system_prompt(&specs); + let prompt = build_system_prompt(&specs, None); let joined = prompt.join("\n"); assert!(joined.contains("fork"), "should mention fork tool"); assert!(joined.contains("parallel"), "should mention parallel"); @@ -351,7 +392,7 @@ mod tests { #[test] fn test_system_prompt_contains_autonomous_scripts() { let specs = crate::mvp_tool_specs(); - let prompt = build_system_prompt(&specs); + let prompt = build_system_prompt(&specs, None); let joined = prompt.join("\n"); assert!( joined.contains("Autonomous scripts:"), @@ -378,4 +419,39 @@ mod tests { assert!(joined.contains("read_script"), "should mention read_script"); assert_eq!(prompt.len(), 8, "should have 8 sections"); } + + #[test] + fn build_system_prompt_is_unchanged_when_dynamic_context_is_none() { + let specs = sample_specs(); + let prompt = build_system_prompt(&specs, None); + + assert_eq!( + prompt, + vec![ + section_identity(&specs), + section_operating_procedure(), + section_data_integrity(), + section_constraints(), + section_error_recovery(), + section_completion(), + section_parallel_exploration(), + section_autonomous_scripts(), + ] + ); + } + + #[test] + fn build_system_prompt_appends_dynamic_section_when_present() { + let specs = sample_specs(); + let prompt = build_system_prompt( + &specs, + Some(&DynamicPromptContext { + stagnation_alert: Some("You are stuck".to_string()), + ..DynamicPromptContext::default() + }), + ); + + assert_eq!(prompt.len(), 9); + assert!(prompt[8].contains("You are stuck")); + } } diff --git a/crates/agent/src/registry.rs b/crates/agent/src/registry.rs index fab84fea..223a6889 100644 --- a/crates/agent/src/registry.rs +++ b/crates/agent/src/registry.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use serde_json::Value; -use crate::BrowserContext; +use crate::{BrowserContext, CrawlState}; use crate::{ToolEffect, ToolExecutionError}; pub type ToolHandler = Box Result + Send + Sync>; @@ -124,14 +124,17 @@ impl ToolRegistry { name: &str, input: &Value, browser: &mut BrowserContext, + crawl_state: &mut CrawlState, ) -> Result { match name { - "navigate" => crate::tools::navigate::execute(input, browser).await, + "navigate" => crate::tools::navigate::execute(input, browser, crawl_state).await, "click" => crate::tools::click::execute(input, browser).await, "click_at" => crate::tools::click_at::execute(input, browser).await, "fill_form" => crate::tools::fill_form::execute(input, browser).await, - "page_map" => crate::tools::page_map::execute(input, browser).await, - "read_content" => crate::tools::read_content::execute(input, browser).await, + "page_map" => crate::tools::page_map::execute(input, browser, crawl_state).await, + "read_content" => { + crate::tools::read_content::execute(input, browser, crawl_state).await + } "screenshot" => crate::tools::screenshot::execute(input, browser).await, "go_back" => crate::tools::go_back::execute(input, browser).await, "scroll" => crate::tools::scroll::execute(input, browser).await, diff --git a/crates/agent/src/script_executor/mod.rs b/crates/agent/src/script_executor/mod.rs index c14dc0c7..337eb1a1 100644 --- a/crates/agent/src/script_executor/mod.rs +++ b/crates/agent/src/script_executor/mod.rs @@ -17,7 +17,7 @@ use serde_json::{json, Value}; use tokio::time::timeout; use tokio_util::sync::CancellationToken; -use crate::{BrowserContext, ToolEffect, ToolExecutionError, ToolRegistry}; +use crate::{BrowserContext, CrawlState, ToolEffect, ToolExecutionError, ToolRegistry}; #[derive(Debug)] pub enum ScriptExecutionError { @@ -46,6 +46,7 @@ impl std::error::Error for ScriptExecutionError {} pub struct ScriptExecutor { browser: BrowserContext, + crawl_state: CrawlState, state: ScriptState, shared_state: Arc>, limits: ScriptLimits, @@ -69,6 +70,7 @@ impl ScriptExecutor { ) -> Self { Self { browser, + crawl_state: CrawlState::default(), state: ScriptState { script_id, status: ScriptStatus::Pending, @@ -338,7 +340,7 @@ impl ScriptExecutor { ) -> Result { timeout( Duration::from_secs(self.limits.per_step_timeout_secs), - registry.execute_async(tool, input, &mut self.browser), + registry.execute_async(tool, input, &mut self.browser, &mut self.crawl_state), ) .await .map_err(|_| ScriptExecutionError::PerStepTimeout)? @@ -403,6 +405,7 @@ impl ScriptExecutor { fn update_current_url(&mut self, output: &Value) { if let Some(url) = output.get("url").and_then(Value::as_str) { self.state.current_url = Some(url.to_string()); + self.crawl_state.current_url = Some(url.to_string()); return; } @@ -412,6 +415,7 @@ impl ScriptExecutor { .and_then(Value::as_str) { self.state.current_url = Some(url.to_string()); + self.crawl_state.current_url = Some(url.to_string()); } } diff --git a/crates/agent/src/script_executor/parallel.rs b/crates/agent/src/script_executor/parallel.rs index 71d04b2e..1529a4d3 100644 --- a/crates/agent/src/script_executor/parallel.rs +++ b/crates/agent/src/script_executor/parallel.rs @@ -56,6 +56,7 @@ impl ScriptExecutor { let branch_executor = Self { browser, + crawl_state: self.crawl_state.clone(), state: self.state.clone(), shared_state: self.shared_state.clone(), limits: self.limits.clone(), diff --git a/crates/agent/src/script_executor/tests.rs b/crates/agent/src/script_executor/tests.rs index a8af436d..72a862e9 100644 --- a/crates/agent/src/script_executor/tests.rs +++ b/crates/agent/src/script_executor/tests.rs @@ -107,7 +107,11 @@ impl BrowserBackend for MockBridge { Ok(()) } - async fn page_map(&mut self, scope: Option<&str>) -> Result { + async fn page_map( + &mut self, + scope: Option<&str>, + _compound_enrichment: bool, + ) -> Result { self.log("page_map", json!({"scope": scope})); Ok(Self::default_page_map()) } diff --git a/crates/agent/src/self_healing.rs b/crates/agent/src/self_healing.rs new file mode 100644 index 00000000..9152d1c6 --- /dev/null +++ b/crates/agent/src/self_healing.rs @@ -0,0 +1,189 @@ +use serde_json::Value; + +/// Extract the first element ref (for example `@e5`) from tool input JSON. +#[must_use] +pub fn extract_element_ref(input: &Value) -> Option { + if let Some(sel) = input.get("selector").and_then(Value::as_str) { + if sel.starts_with('@') { + return Some(sel.to_string()); + } + } + + if let Some(sel) = input.get("form_selector").and_then(Value::as_str) { + if sel.starts_with('@') { + return Some(sel.to_string()); + } + } + + if let Some(fields) = input.get("fields").and_then(Value::as_object) { + if let Some((key, _)) = fields.iter().find(|(key, _)| key.starts_with('@')) { + return Some(key.clone()); + } + } + + if let Some(s) = input.as_str() { + if s.starts_with('@') { + return Some(s.to_string()); + } + } + + None +} + +fn normalized_hint(hint: Option<&str>) -> Option { + let normalized = hint?.trim().to_lowercase(); + if normalized.is_empty() { + None + } else { + Some(normalized) + } +} + +fn text_matches(candidate: &str, hint: &str) -> bool { + let candidate = candidate.trim().to_lowercase(); + !candidate.is_empty() && (candidate.contains(hint) || hint.contains(&candidate)) +} + +/// Attempt to find a replacement selector from a fresh `page_map`. +#[must_use] +pub fn find_healed_selector( + original_ref: &str, + page_map: &Value, + original_text_hint: Option<&str>, +) -> Option { + let interactive = page_map + .get("interactive") + .and_then(|i| i.get("elements")) + .and_then(Value::as_array)?; + let hint = normalized_hint(original_text_hint)?; + + for elem in interactive { + for field in ["name", "text"] { + if let Some(text) = elem.get(field).and_then(Value::as_str) { + if text_matches(text, &hint) { + if let Some(new_ref) = elem.get("ref").and_then(Value::as_str) { + if new_ref != original_ref { + return Some(new_ref.to_string()); + } + } + if let Some(sel) = elem.get("selector").and_then(Value::as_str) { + return Some(sel.to_string()); + } + } + } + } + } + + None +} + +/// Build a patched input with the healed selector. +#[must_use] +pub fn patch_selector(input: &Value, old_selector: &str, new_selector: &str) -> Value { + let mut patched = input.clone(); + let Some(obj) = patched.as_object_mut() else { + return patched; + }; + + if obj + .get("selector") + .and_then(Value::as_str) + .is_some_and(|sel| sel == old_selector) + { + obj.insert( + "selector".to_string(), + Value::String(new_selector.to_string()), + ); + } + + if obj + .get("form_selector") + .and_then(Value::as_str) + .is_some_and(|sel| sel == old_selector) + { + obj.insert( + "form_selector".to_string(), + Value::String(new_selector.to_string()), + ); + } + + if let Some(fields) = obj.get_mut("fields").and_then(Value::as_object_mut) { + if let Some(value) = fields.remove(old_selector) { + fields.insert(new_selector.to_string(), value); + } + } + + patched +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn extract_element_ref_from_selector_field() { + let input = json!({"selector": "@e5", "other": "data"}); + assert_eq!(extract_element_ref(&input), Some("@e5".to_string())); + } + + #[test] + fn extract_element_ref_from_fields_key() { + let input = json!({"fields": {"@e5": "hello"}, "submit": false}); + assert_eq!(extract_element_ref(&input), Some("@e5".to_string())); + } + + #[test] + fn extract_none_for_css_selector() { + let input = json!({"selector": "button.submit"}); + assert_eq!(extract_element_ref(&input), None); + } + + #[test] + fn find_healed_selector_by_name_match() { + let page_map = json!({ + "interactive": { + "elements": [ + {"ref": "@e7", "name": "Submit", "text": "Submit", "tag": "button"}, + {"ref": "@e8", "name": "Cancel", "text": "Cancel", "tag": "button"} + ] + } + }); + let healed = find_healed_selector("@e5", &page_map, Some("submit")); + assert_eq!(healed, Some("@e7".to_string())); + } + + #[test] + fn find_healed_selector_no_match_returns_none() { + let page_map = json!({ + "interactive": { + "elements": [ + {"ref": "@e7", "text": "Login", "tag": "button"} + ] + } + }); + let healed = find_healed_selector("@e5", &page_map, Some("submit")); + assert!(healed.is_none()); + } + + #[test] + fn patch_selector_replaces_top_level_field() { + let input = json!({"selector": "@e5", "other": "data"}); + let patched = patch_selector(&input, "@e5", "@e7"); + assert_eq!(patched.get("selector").and_then(Value::as_str), Some("@e7")); + assert_eq!(patched.get("other").and_then(Value::as_str), Some("data")); + } + + #[test] + fn patch_selector_rewrites_fill_form_field_key() { + let input = json!({"fields": {"@e5": "john@example.com", "#name": "John"}}); + let patched = patch_selector(&input, "@e5", "@e9"); + let fields = patched.get("fields").and_then(Value::as_object).unwrap(); + assert_eq!( + fields.get("@e9").and_then(Value::as_str), + Some("john@example.com") + ); + assert_eq!(fields.get("#name").and_then(Value::as_str), Some("John")); + assert!(!fields.contains_key("@e5")); + } +} diff --git a/crates/agent/src/state.rs b/crates/agent/src/state.rs index e211ea0c..1244d947 100644 --- a/crates/agent/src/state.rs +++ b/crates/agent/src/state.rs @@ -1,6 +1,11 @@ use runtime::ChildSession; use serde_json::Value; +use crate::action_cache::ActionCache; +use crate::loop_detector::LoopDetector; +use crate::page_fingerprint::PageFingerprint; +use crate::tools::html_diff::HtmlDiffTracker; + #[derive(Debug, Clone)] pub struct ChildBlock { pub child_id: String, @@ -17,6 +22,10 @@ pub struct CrawlState { pub child_blocks: Vec, pub max_steps: usize, pub captured_child_sessions: Vec, + pub page_fingerprints: Vec, + pub action_cache: Option, + pub html_diff_tracker: Option, + pub loop_detector: Option, } impl CrawlState { @@ -30,6 +39,10 @@ impl CrawlState { child_blocks: Vec::new(), max_steps: child_max_steps, captured_child_sessions: Vec::new(), + page_fingerprints: Vec::new(), + action_cache: None, + html_diff_tracker: None, + loop_detector: None, } } diff --git a/crates/agent/src/tools/feedback.rs b/crates/agent/src/tools/feedback.rs index 93acb78b..91e0b02d 100644 --- a/crates/agent/src/tools/feedback.rs +++ b/crates/agent/src/tools/feedback.rs @@ -4,6 +4,8 @@ use std::time::Duration; use serde_json::{json, Value}; use tokio::time::timeout; +use crate::page_fingerprint::PageFingerprint; +use crate::state::CrawlState; use crate::BrowserContext; use super::page_map::{apply_page_map_caps, normalize_url}; @@ -416,6 +418,17 @@ pub async fn post_action_page_state(browser: &mut BrowserContext) -> Value { } } +/// Record a page fingerprint into `CrawlState` if the `page_fingerprinting` +/// optimization flag is enabled in settings. +pub fn record_page_fingerprint(url: &str, page_map: &Value, crawl_state: &mut CrawlState) { + let settings = runtime::load_settings(); + if !runtime::settings_get_page_fingerprinting(&settings) { + return; + } + let fingerprint = PageFingerprint::compute(url, page_map); + crawl_state.page_fingerprints.push(fingerprint); +} + #[cfg(test)] mod tests { use serde_json::json; diff --git a/crates/agent/src/tools/html_diff.rs b/crates/agent/src/tools/html_diff.rs new file mode 100644 index 00000000..89af9add --- /dev/null +++ b/crates/agent/src/tools/html_diff.rs @@ -0,0 +1,191 @@ +use std::collections::HashMap; + +#[derive(Debug, Clone)] +pub struct ContentSection { + pub heading: String, + pub hash: u64, + pub content: String, +} + +#[derive(Debug, Clone, Default)] +pub struct HtmlDiffTracker { + /// Maps URL → Vec of previously seen sections + cached: HashMap>, +} + +impl HtmlDiffTracker { + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Split content by heading boundaries and cache it for URL. + /// Returns the sections for further use. + pub fn update(&mut self, url: &str, content: &str) -> Vec { + let sections = split_into_sections(content); + self.cached.insert(url.to_string(), sections.clone()); + sections + } + + /// Returns `Some(diff_output)` if we have a previous version to diff against. + /// Returns `None` if this is the first visit (caller should use full content). + pub fn diff(&mut self, url: &str, new_content: &str) -> Option { + let prev = self.cached.get(url)?.clone(); + let new_sections = split_into_sections(new_content); + + let mut output_parts = Vec::new(); + let mut unchanged_run = 0usize; + + for (i, new_sec) in new_sections.iter().enumerate() { + let prev_hash = prev.get(i).map(|s| s.hash); + if prev_hash == Some(new_sec.hash) { + unchanged_run += 1; + } else { + if unchanged_run > 0 { + output_parts.push(format!("[unchanged: {unchanged_run} sections]")); + unchanged_run = 0; + } + output_parts.push(new_sec.content.clone()); + } + } + + if unchanged_run > 0 { + output_parts.push(format!("[unchanged: {unchanged_run} sections]")); + } + + self.cached.insert(url.to_string(), new_sections); + + Some(output_parts.join("\n\n")) + } +} + +fn hash_str(s: &str) -> u64 { + const FNV_OFFSET: u64 = 14_695_981_039_346_656_037; + const FNV_PRIME: u64 = 1_099_511_628_211; + let mut hash = FNV_OFFSET; + for byte in s.bytes() { + hash ^= u64::from(byte); + hash = hash.wrapping_mul(FNV_PRIME); + } + hash +} + +fn split_into_sections(content: &str) -> Vec { + let mut sections = Vec::new(); + let mut current_heading = String::new(); + let mut current_lines: Vec<&str> = Vec::new(); + + for line in content.lines() { + if line.starts_with('#') { + if !current_lines.is_empty() || !current_heading.is_empty() { + let content_str = current_lines.join("\n"); + sections.push(ContentSection { + heading: current_heading.clone(), + hash: hash_str(&content_str), + content: if current_heading.is_empty() { + content_str + } else { + format!("{current_heading}\n{content_str}") + }, + }); + current_lines.clear(); + } + current_heading = line.to_string(); + } else { + current_lines.push(line); + } + } + + if !current_lines.is_empty() || !current_heading.is_empty() { + let content_str = current_lines.join("\n"); + sections.push(ContentSection { + heading: current_heading.clone(), + hash: hash_str(&content_str), + content: if current_heading.is_empty() { + content_str + } else { + format!("{current_heading}\n{content_str}") + }, + }); + } + + if sections.is_empty() && !content.is_empty() { + sections.push(ContentSection { + heading: String::new(), + hash: hash_str(content), + content: content.to_string(), + }); + } + + sections +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn first_visit_returns_none_no_diff() { + let mut tracker = HtmlDiffTracker::new(); + let result = tracker.diff("https://example.com", "# Hello\nWorld"); + assert!(result.is_none(), "First visit should return None"); + } + + #[test] + fn second_visit_unchanged_returns_unchanged_marker() { + let mut tracker = HtmlDiffTracker::new(); + let content = "# Section 1\nContent 1\n# Section 2\nContent 2"; + tracker.update("https://example.com", content); + let diff = tracker.diff("https://example.com", content).unwrap(); + assert!( + diff.contains("[unchanged:"), + "Should have unchanged markers for identical content" + ); + assert!( + !diff.contains("Content 1"), + "Unchanged content should not appear" + ); + } + + #[test] + fn second_visit_changed_section_returned() { + let mut tracker = HtmlDiffTracker::new(); + tracker.update( + "https://example.com", + "# Section 1\nOld\n# Section 2\nUnchanged", + ); + let new_content = "# Section 1\nNew\n# Section 2\nUnchanged"; + let diff = tracker.diff("https://example.com", new_content).unwrap(); + assert!(diff.contains("New"), "Changed section should appear"); + assert!( + diff.contains("[unchanged:"), + "Unchanged section should be marker" + ); + } + + #[test] + fn diff_output_smaller_than_full_content() { + let mut tracker = HtmlDiffTracker::new(); + let sections: Vec = (0..10) + .map(|i| format!("# Section {i}\nContent {i}")) + .collect(); + let full_content = sections.join("\n"); + tracker.update("https://example.com", &full_content); + let mut new_sections = sections.clone(); + new_sections[0] = "# Section 0\nNew Content".to_string(); + let new_content = new_sections.join("\n"); + let diff = tracker.diff("https://example.com", &new_content).unwrap(); + assert!( + diff.len() < new_content.len(), + "Diff output should be smaller than full content" + ); + } + + #[test] + fn content_without_headings_is_single_section() { + let sections = split_into_sections("plain text only"); + assert_eq!(sections.len(), 1); + assert_eq!(sections[0].heading, ""); + assert_eq!(sections[0].content, "plain text only"); + } +} diff --git a/crates/agent/src/tools/mod.rs b/crates/agent/src/tools/mod.rs index 5379545f..10b49e32 100644 --- a/crates/agent/src/tools/mod.rs +++ b/crates/agent/src/tools/mod.rs @@ -5,6 +5,7 @@ pub mod feedback; pub mod fill_form; pub mod fork; pub mod go_back; +pub mod html_diff; pub mod navigate; pub mod page_map; pub mod read_content; diff --git a/crates/agent/src/tools/navigate.rs b/crates/agent/src/tools/navigate.rs index aa20cf58..835c994a 100644 --- a/crates/agent/src/tools/navigate.rs +++ b/crates/agent/src/tools/navigate.rs @@ -1,7 +1,9 @@ use serde_json::{json, Value}; use crate::markdown::{extract_main_html, html_to_markdown, DEFAULT_MAX_MARKDOWN_CHARS}; -use crate::prune::prune_html; +use crate::prune::{prune_html_with_profile, select_profile, CleaningProfile}; +use crate::state::CrawlState; +use crate::tools::html_diff::HtmlDiffTracker; use crate::tools::page_map::{annotate_refs, apply_page_map_caps, normalize_url}; use crate::BrowserContext; use crate::FetchRouter; @@ -213,6 +215,7 @@ fn resolve_content( markdown: &str, format: &str, depth: &ContentDepth, + profile: CleaningProfile, ) -> (String, bool) { if *depth == ContentDepth::None { return (String::new(), false); @@ -232,7 +235,7 @@ fn resolve_content( "text" => cap_content(text, max_chars), "html" => cap_content(html, max_chars), "fit_markdown" => { - let pruned = prune_html(html); + let pruned = prune_html_with_profile(html, profile); let md = html_to_markdown(&pruned); if md.trim().is_empty() && !text.trim().is_empty() { cap_content(text, max_chars) @@ -262,7 +265,7 @@ fn resolve_content( } } "fit_markdown" => { - let pruned = prune_html(&main_html); + let pruned = prune_html_with_profile(&main_html, profile); let md = html_to_markdown(&pruned); if md.trim().is_empty() && !text.trim().is_empty() { cap_content(text, max_chars) @@ -333,59 +336,73 @@ fn slim_page_map(page_map: &mut Value) { } } -pub async fn execute( - input: &Value, - browser: &mut BrowserContext, -) -> Result { - let params = parse_input(input)?; - - let router = FetchRouter::new().map_err(|e| ToolExecutionError::new(e.to_string()))?; - let page = router - .fetch(¶ms.url, Some(browser)) - .await - .map_err(|e| ToolExecutionError::new(e.to_string()))?; - - let title = page.title.clone().unwrap_or_default(); +fn apply_html_diff(crawl_state: &mut CrawlState, url: &str, content: &mut String) { + let settings = runtime::load_settings(); + if !runtime::settings_get_html_diff_mode(&settings) { + return; + } - let (content, truncated) = resolve_content( - &page.html, - &page.text, - &page.markdown, - ¶ms.format, - ¶ms.content_depth, - ); + if crawl_state.html_diff_tracker.is_none() { + crawl_state.html_diff_tracker = Some(HtmlDiffTracker::new()); + } - let content = - if params.strip_images && matches!(params.format.as_str(), "markdown" | "fit_markdown") { - strip_markdown_images(&content) + if let Some(tracker) = crawl_state.html_diff_tracker.as_mut() { + if let Some(diff_output) = tracker.diff(url, content) { + *content = diff_output; } else { - content - }; + tracker.update(url, content); + } + } +} - browser.set_navigated_url(&page.url, page.fetched_via_browser); - browser.ref_map_mut().clear(); +fn content_depth_label(depth: &ContentDepth) -> &'static str { + match depth { + ContentDepth::Full => "full", + ContentDepth::Main => "main", + ContentDepth::Slim => "slim", + ContentDepth::None => "none", + } +} - if params.page_map_depth == PageMapDepth::None { - let content_length = content.chars().count(); - return Ok(ToolEffect::reply_json(&json!({ - "url": page.url, - "title": title, - "content": content, - "format": params.format, - "content_depth": match params.content_depth { - ContentDepth::Full => "full", - ContentDepth::Main => "main", - ContentDepth::Slim => "slim", - ContentDepth::None => "none", - }, - "truncated": truncated, - "content_length": content_length, - }))); +fn content_profile(html_len: usize) -> CleaningProfile { + let nav_settings = runtime::load_settings(); + if runtime::settings_get_content_aware_profiles(&nav_settings) { + select_profile(None, html_len) + } else { + CleaningProfile::Default } +} - let mut page_map = if page.fetched_via_browser { +fn reply_without_page_map( + page: &browser::FetchedPage, + title: &str, + content: &str, + format: &str, + content_depth: &ContentDepth, + truncated: bool, +) -> ToolEffect { + let content_length = content.chars().count(); + ToolEffect::reply_json(&json!({ + "url": page.url, + "title": title, + "content": content, + "format": format, + "content_depth": content_depth_label(content_depth), + "truncated": truncated, + "content_length": content_length, + })) +} + +async fn build_page_map( + browser: &mut BrowserContext, + page: &browser::FetchedPage, + title: &str, +) -> Value { + if page.fetched_via_browser { + let nav_settings = runtime::load_settings(); + let compound_enrichment = runtime::settings_get_compound_enrichment(&nav_settings); match browser.acquire_bridge().await { - Ok(mut bridge) => match bridge.page_map(None).await { + Ok(mut bridge) => match bridge.page_map(None, compound_enrichment).await { Ok(mut value) => { apply_page_map_caps(&mut value); value @@ -410,21 +427,87 @@ pub async fn execute( let mut value = extract_headings_from_markdown(&page.markdown); if let Some(meta) = value.get_mut("meta").and_then(Value::as_object_mut) { meta.insert("url".to_string(), json!(page.url.clone())); - meta.insert("title".to_string(), json!(title.clone())); + meta.insert("title".to_string(), json!(title)); } value - }; - - annotate_refs(&mut page_map, browser); + } +} +fn cache_page_map_snapshot( + browser: &mut BrowserContext, + crawl_state: &mut CrawlState, + page_url: &str, + page_map: &Value, +) { let pm_url = page_map .get("meta") - .and_then(|m| m.get("url")) + .and_then(|meta| meta.get("url")) .and_then(Value::as_str) .unwrap_or("unknown"); let cache_key = normalize_url(pm_url).to_string(); browser.set_page_snapshot(cache_key, page_map.clone()); + let fp_settings = runtime::load_settings(); + if runtime::settings_get_page_fingerprinting(&fp_settings) { + let fp = crate::page_fingerprint::PageFingerprint::compute(page_url, page_map); + crawl_state.page_fingerprints.push(fp); + } +} + +pub async fn execute( + input: &Value, + browser: &mut BrowserContext, + crawl_state: &mut CrawlState, +) -> Result { + let params = parse_input(input)?; + + let router = FetchRouter::new().map_err(|e| ToolExecutionError::new(e.to_string()))?; + let page = router + .fetch(¶ms.url, Some(browser)) + .await + .map_err(|e| ToolExecutionError::new(e.to_string()))?; + + let title = page.title.clone().unwrap_or_default(); + let profile = content_profile(page.html.len()); + + let (content, truncated) = resolve_content( + &page.html, + &page.text, + &page.markdown, + ¶ms.format, + ¶ms.content_depth, + profile, + ); + + let mut content = + if params.strip_images && matches!(params.format.as_str(), "markdown" | "fit_markdown") { + strip_markdown_images(&content) + } else { + content + }; + + apply_html_diff(crawl_state, &page.url, &mut content); + + browser.set_navigated_url(&page.url, page.fetched_via_browser); + crawl_state.current_url = Some(page.url.clone()); + browser.ref_map_mut().clear(); + + if params.page_map_depth == PageMapDepth::None { + return Ok(reply_without_page_map( + &page, + &title, + &content, + ¶ms.format, + ¶ms.content_depth, + truncated, + )); + } + + let mut page_map = build_page_map(browser, &page, &title).await; + + annotate_refs(&mut page_map, browser); + cache_page_map_snapshot(browser, crawl_state, &page.url, &page_map); + if params.page_map_depth == PageMapDepth::Slim { slim_page_map(&mut page_map); } @@ -436,12 +519,7 @@ pub async fn execute( "title": title, "content": content, "format": params.format, - "content_depth": match params.content_depth { - ContentDepth::Full => "full", - ContentDepth::Main => "main", - ContentDepth::Slim => "slim", - ContentDepth::None => "none", - }, + "content_depth": content_depth_label(¶ms.content_depth), "truncated": truncated, "content_length": content_length, "page_map": page_map @@ -541,6 +619,7 @@ mod tests { "hello", "markdown", &ContentDepth::None, + CleaningProfile::Default, ); assert!(content.is_empty()); assert!(!truncated); @@ -551,7 +630,14 @@ mod tests { let html = r"

Title

Body text

Footer
"; let md = html_to_markdown(html); - let (content, _) = resolve_content(html, "text", &md, "markdown", &ContentDepth::Main); + let (content, _) = resolve_content( + html, + "text", + &md, + "markdown", + &ContentDepth::Main, + CleaningProfile::Default, + ); assert!(content.contains("Title")); assert!(content.contains("Body text")); assert!(!content.contains("Menu")); @@ -562,7 +648,14 @@ mod tests { fn resolve_content_full_includes_everything() { let html = r"

Header

Body

"; let md = html_to_markdown(html); - let (content, _) = resolve_content(html, "text", &md, "markdown", &ContentDepth::Full); + let (content, _) = resolve_content( + html, + "text", + &md, + "markdown", + &ContentDepth::Full, + CleaningProfile::Default, + ); assert!(content.contains("Header")); assert!(content.contains("Body")); } @@ -572,8 +665,14 @@ mod tests { let body = "a".repeat(5000); let html = format!("

{body}

"); let md = html_to_markdown(&html); - let (content, truncated) = - resolve_content(&html, "text", &md, "markdown", &ContentDepth::Slim); + let (content, truncated) = resolve_content( + &html, + "text", + &md, + "markdown", + &ContentDepth::Slim, + CleaningProfile::Default, + ); assert!(truncated); assert!(content.chars().count() <= SLIM_MAX_CHARS); } @@ -674,8 +773,14 @@ mod tests { let html = r#"

Main content here

"#; let text = "Main content here Buy now! menu"; let markdown = html_to_markdown(html); - let (content, _) = - resolve_content(html, text, &markdown, "fit_markdown", &ContentDepth::Main); + let (content, _) = resolve_content( + html, + text, + &markdown, + "fit_markdown", + &ContentDepth::Main, + CleaningProfile::Default, + ); assert!( content.contains("Main content"), "main content should survive pruning" @@ -688,8 +793,14 @@ mod tests { let html = r"

Title

Quality paragraph content.

"; let text = "Title Quality paragraph content."; let markdown = html_to_markdown(html); - let (content, truncated) = - resolve_content(html, text, &markdown, "fit_markdown", &ContentDepth::Full); + let (content, truncated) = resolve_content( + html, + text, + &markdown, + "fit_markdown", + &ContentDepth::Full, + CleaningProfile::Default, + ); assert!( !content.is_empty(), "full depth fit_markdown should return content" @@ -703,8 +814,14 @@ mod tests { let html = r#"
advertisement
"#; let text = "advertisement fallback text"; let markdown = html_to_markdown(html); - let (content, _) = - resolve_content(html, text, &markdown, "fit_markdown", &ContentDepth::Main); + let (content, _) = resolve_content( + html, + text, + &markdown, + "fit_markdown", + &ContentDepth::Main, + CleaningProfile::Default, + ); // Pruning removes all content (ads class) → must fall back to text assert!( content.contains("fallback text"), diff --git a/crates/agent/src/tools/page_map.rs b/crates/agent/src/tools/page_map.rs index 9598d090..423504be 100644 --- a/crates/agent/src/tools/page_map.rs +++ b/crates/agent/src/tools/page_map.rs @@ -81,14 +81,18 @@ fn truncate_array_field(value: &mut Value, key: &str, max_len: usize) -> bool { pub async fn execute( input: &Value, browser: &mut BrowserContext, + crawl_state: &mut crate::state::CrawlState, ) -> Result { let scope = input.get("scope").and_then(Value::as_str); + let settings = runtime::load_settings(); + let compound_enrichment = runtime::settings_get_compound_enrichment(&settings); + let mut result = browser .acquire_bridge() .await .map_err(|e| ToolExecutionError::new(e.to_string()))? - .page_map(scope) + .page_map(scope, compound_enrichment) .await .map_err(|e| ToolExecutionError::new(e.to_string()))?; @@ -115,6 +119,17 @@ pub async fn execute( annotate_refs(&mut result, browser); } + let fp_settings = runtime::load_settings(); + if runtime::settings_get_page_fingerprinting(&fp_settings) { + let url = result + .get("meta") + .and_then(|meta| meta.get("url")) + .and_then(Value::as_str) + .unwrap_or("unknown"); + let fingerprint = crate::page_fingerprint::PageFingerprint::compute(url, &result); + crawl_state.page_fingerprints.push(fingerprint); + } + Ok(ToolEffect::reply_json(&result)) } diff --git a/crates/agent/src/tools/read_content.rs b/crates/agent/src/tools/read_content.rs index 2d97d27b..476daf72 100644 --- a/crates/agent/src/tools/read_content.rs +++ b/crates/agent/src/tools/read_content.rs @@ -1,5 +1,7 @@ use serde_json::Value; +use crate::state::CrawlState; +use crate::tools::html_diff::HtmlDiffTracker; use crate::BrowserContext; use crate::{CrawlError, ToolEffect, ToolExecutionError}; @@ -36,16 +38,42 @@ fn parse_input( pub async fn execute( input: &Value, browser: &mut BrowserContext, + crawl_state: &mut CrawlState, ) -> Result { let (heading, selector, offset, max_chars) = parse_input(input)?; let mut bridge = browser .acquire_bridge() .await .map_err(|e| ToolExecutionError::new(e.to_string()))?; - let result = bridge + let mut result = bridge .read_content(heading.as_deref(), selector.as_deref(), offset, max_chars) .await .map_err(|e| ToolExecutionError::new(e.to_string()))?; + + let settings = runtime::load_settings(); + if runtime::settings_get_html_diff_mode(&settings) { + if let Some(url) = crawl_state.current_url.as_deref() { + if crawl_state.html_diff_tracker.is_none() { + crawl_state.html_diff_tracker = Some(HtmlDiffTracker::new()); + } + + if let Some(content) = result.get("content").and_then(Value::as_str) { + let mut content = content.to_string(); + if let Some(tracker) = crawl_state.html_diff_tracker.as_mut() { + if let Some(diff_output) = tracker.diff(url, &content) { + content = diff_output; + } else { + tracker.update(url, &content); + } + } + + if let Some(obj) = result.as_object_mut() { + obj.insert("content".to_string(), Value::String(content)); + } + } + } + } + Ok(ToolEffect::reply_json(&result)) } diff --git a/crates/agent/src/tools/screenshot.rs b/crates/agent/src/tools/screenshot.rs index bd783b8f..afef36de 100644 --- a/crates/agent/src/tools/screenshot.rs +++ b/crates/agent/src/tools/screenshot.rs @@ -121,7 +121,6 @@ pub async fn execute( mod tests { use super::*; use std::path::PathBuf; - use std::sync::OnceLock; use async_trait::async_trait; @@ -129,8 +128,6 @@ mod tests { BridgeError, BrowserBackend, BrowserState, PageInfo, ScreenshotOptions, SharedBridge, }; - static ENV_LOCK: OnceLock> = OnceLock::new(); - fn setup_temp_dir(suffix: &str) -> PathBuf { let dir = std::env::temp_dir().join(format!( "acrawl_screenshot_test_{}_{suffix}", @@ -195,6 +192,7 @@ mod tests { async fn page_map( &mut self, _scope: Option<&str>, + _compound_enrichment: bool, ) -> Result { Ok(serde_json::json!({})) } @@ -333,10 +331,7 @@ mod tests { #[tokio::test] #[allow(clippy::await_holding_lock)] async fn execute_save_true_writes_file_default_filename() { - let _lock = ENV_LOCK - .get_or_init(|| std::sync::Mutex::new(())) - .lock() - .unwrap(); + let _lock = crate::test_async_env_lock().lock().await; let temp_dir = setup_temp_dir("default_fn"); let output_dir = temp_dir.join("output"); @@ -378,10 +373,7 @@ mod tests { #[tokio::test] #[allow(clippy::await_holding_lock)] async fn execute_save_true_custom_filename() { - let _lock = ENV_LOCK - .get_or_init(|| std::sync::Mutex::new(())) - .lock() - .unwrap(); + let _lock = crate::test_async_env_lock().lock().await; let temp_dir = setup_temp_dir("custom_fn"); let output_dir = temp_dir.join("screenshots"); @@ -431,10 +423,7 @@ mod tests { #[tokio::test] #[allow(clippy::await_holding_lock)] async fn execute_save_true_invalid_base64_errors() { - let _lock = ENV_LOCK - .get_or_init(|| std::sync::Mutex::new(())) - .lock() - .unwrap(); + let _lock = crate::test_async_env_lock().lock().await; let temp_dir = setup_temp_dir("bad_b64"); let output_dir = temp_dir.join("output"); @@ -462,10 +451,7 @@ mod tests { #[tokio::test] #[allow(clippy::await_holding_lock)] async fn execute_save_true_write_error_on_invalid_dir() { - let _lock = ENV_LOCK - .get_or_init(|| std::sync::Mutex::new(())) - .lock() - .unwrap(); + let _lock = crate::test_async_env_lock().lock().await; let temp_dir = setup_temp_dir("write_err"); // Point output_dir at a FILE, so create_dir_all will fail let blocker = temp_dir.join("blocked"); diff --git a/crates/browser/src/browser_backend.rs b/crates/browser/src/browser_backend.rs index 8c53c179..bc88cea5 100644 --- a/crates/browser/src/browser_backend.rs +++ b/crates/browser/src/browser_backend.rs @@ -23,7 +23,11 @@ pub trait BrowserBackend: Debug { async fn new_page(&mut self, url: Option<&str>) -> Result; async fn close_page(&mut self, page_index: usize) -> Result<(), BridgeError>; async fn scroll(&mut self, direction: &str, pixels: i64) -> Result<(), BridgeError>; - async fn page_map(&mut self, scope: Option<&str>) -> Result; + async fn page_map( + &mut self, + scope: Option<&str>, + compound_enrichment: bool, + ) -> Result; async fn read_content( &mut self, heading: Option<&str>, @@ -58,6 +62,6 @@ pub trait BrowserBackend: Debug { async fn go_back(&mut self) -> Result; async fn page_map_feedback(&mut self) -> Result { - self.page_map(None).await + self.page_map(None, false).await } } diff --git a/crates/browser/src/extension.rs b/crates/browser/src/extension.rs index 49e94b3c..3e7db858 100644 --- a/crates/browser/src/extension.rs +++ b/crates/browser/src/extension.rs @@ -145,11 +145,18 @@ impl BrowserBackend for ExtensionBridge { .await } - async fn page_map(&mut self, scope: Option<&str>) -> Result { + async fn page_map( + &mut self, + scope: Option<&str>, + compound_enrichment: bool, + ) -> Result { let mut payload = json!({}); if let Some(s) = scope { payload["scope"] = json!(s); } + if compound_enrichment { + payload["compoundEnrichment"] = json!(true); + } let response = self.send_command("page_map", payload).await?; Self::require_result(response, "page_map") } diff --git a/crates/browser/src/playwright/backend_impl.rs b/crates/browser/src/playwright/backend_impl.rs index ac191ffb..38a7254c 100644 --- a/crates/browser/src/playwright/backend_impl.rs +++ b/crates/browser/src/playwright/backend_impl.rs @@ -21,8 +21,12 @@ impl BrowserBackend for PlaywrightBridge { PlaywrightBridge::scroll(self, direction, pixels).await } - async fn page_map(&mut self, scope: Option<&str>) -> Result { - PlaywrightBridge::page_map(self, scope).await + async fn page_map( + &mut self, + scope: Option<&str>, + compound_enrichment: bool, + ) -> Result { + PlaywrightBridge::page_map(self, scope, compound_enrichment).await } async fn read_content( diff --git a/crates/browser/src/playwright/bridge.rs b/crates/browser/src/playwright/bridge.rs index 7d1a3eb5..c1bd0dc3 100644 --- a/crates/browser/src/playwright/bridge.rs +++ b/crates/browser/src/playwright/bridge.rs @@ -230,11 +230,15 @@ impl PlaywrightBridge { pub async fn page_map( &mut self, scope: Option<&str>, + compound_enrichment: bool, ) -> Result { let mut cmd = serde_json::json!({ "action": "page_map" }); if let Some(s) = scope { cmd["scope"] = serde_json::Value::String(s.to_string()); } + if compound_enrichment { + cmd["compoundEnrichment"] = serde_json::Value::Bool(true); + } self.send_raw_command(&cmd).await } diff --git a/crates/browser/src/playwright/bridge_script.rs b/crates/browser/src/playwright/bridge_script.rs index 01af62c4..5af4257a 100644 --- a/crates/browser/src/playwright/bridge_script.rs +++ b/crates/browser/src/playwright/bridge_script.rs @@ -356,7 +356,8 @@ async function bootstrap() { if (command.action === 'page_map') { try { const scope = command.scope || null; - const result = await page.evaluate((scope) => { + const compoundEnrichment = command.compoundEnrichment || false; + const result = await page.evaluate(({scope, compoundEnrichment}) => { let root = document; if (scope) { const scoped = document.querySelector(scope); @@ -457,6 +458,58 @@ async function bootstrap() { const MAX_INTERACTIVE = 30; const interactiveEls = []; + + function getEnrichment(el, tag, elType) { + if (!compoundEnrichment) return null; + let enrichment = null; + if (tag === 'input') { + const inputType = elType || el.type || ''; + if (inputType === 'date') enrichment = { format: 'YYYY-MM-DD' }; + else if (inputType === 'time') enrichment = { format: 'HH:MM' }; + else if (inputType === 'datetime-local') enrichment = { format: 'YYYY-MM-DDTHH:MM' }; + else if (inputType === 'range') { + enrichment = { + min: el.min !== '' ? Number(el.min) : 0, + max: el.max !== '' ? Number(el.max) : 100, + step: el.step !== '' ? Number(el.step) : 1, + value: el.value !== '' ? Number(el.value) : 50 + }; + } else if (inputType === 'number') { + enrichment = {}; + if (el.min !== '') enrichment.min = Number(el.min); + if (el.max !== '') enrichment.max = Number(el.max); + if (el.step !== '') enrichment.step = Number(el.step); + if (Object.keys(enrichment).length === 0) enrichment = null; + } else if (inputType === 'color') { + enrichment = { value: el.value || '#000000' }; + } else if (inputType === 'file') { + enrichment = { accept: el.accept || '*' }; + } + } else if (tag === 'select') { + const opts = Array.from(el.options || []); + const total = opts.length; + const visible = opts.slice(0, 20).map(o => o.text.trim()).filter(t => t.length > 0); + if (total > 20) visible.push('...and ' + (total - 20) + ' more'); + enrichment = { options: visible, total_options: total }; + } else if (tag === 'textarea') { + const ml = el.getAttribute('maxlength'); + if (ml) enrichment = { maxlength: Number(ml) }; + } + if (enrichment) { + const json = JSON.stringify(enrichment); + if (json.length > 200) { + if (enrichment.options && Array.isArray(enrichment.options)) { + while (JSON.stringify(enrichment).length > 190 && enrichment.options.length > 1) { + enrichment.options.pop(); + } + } else { + return null; + } + } + } + return enrichment; + } + const selectors = [ ['button', 'button'], ['input', 'input:not([type="hidden"])'], @@ -544,6 +597,8 @@ async function bootstrap() { if (!elName && el.title) elName = el.title.slice(0, 60); if (!elName && el.name) elName = el.name.slice(0, 60); if (elName) entry.name = elName; + const enrichment = getEnrichment(el, entry.tag, entry.type); + if (enrichment !== null) entry.enrichment = enrichment; interactiveEls.push(entry); } } @@ -557,7 +612,7 @@ async function bootstrap() { }; return { headings, landmarks: cappedLandmarks, forms, links, interactive, meta, total_landmarks, total_forms, total_links }; - }, scope); + }, {scope, compoundEnrichment}); process.stdout.write(JSON.stringify({ event: 'bridge_response', ok: true, result }) + '\n'); } catch (error) { process.stdout.write(JSON.stringify({ event: 'bridge_response', ok: false, error: { kind: 'page_map_error', message: String(error) } }) + '\n'); diff --git a/crates/browser/src/prune.rs b/crates/browser/src/prune.rs index 8f0ff11f..6f7273f2 100644 --- a/crates/browser/src/prune.rs +++ b/crates/browser/src/prune.rs @@ -5,7 +5,6 @@ use std::sync::LazyLock; static BODY_SELECTOR: LazyLock> = LazyLock::new(|| Selector::parse("body").ok()); static ANCHOR_SELECTOR: LazyLock> = LazyLock::new(|| Selector::parse("a").ok()); -const THRESHOLD: f64 = 0.48; const NEGATIVE_PATTERNS: [&str; 10] = [ "nav", "footer", "header", "sidebar", "ads", "comment", "promo", "advert", "social", "share", ]; @@ -14,8 +13,96 @@ const VOID_ELEMENTS: [&str; 14] = [ "track", "wbr", ]; +/// Cleaning profile that controls pruning aggressiveness and tag weight behavior. +/// +/// Used when the `content_aware_profiles` optimization flag is enabled. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum CleaningProfile { + /// Current default behavior — threshold 0.48. + #[default] + Default, + /// Only remove obvious junk — threshold 0.20. Preserves interactive elements. + Minimal, + /// Heavy pruning for large pages — threshold 0.65. + Aggressive, + /// Article extraction — threshold 0.45, boosted article/paragraph weights. + ReadingMode, +} + +impl CleaningProfile { + /// Score threshold below which elements are pruned. + #[must_use] + pub fn threshold(self) -> f64 { + match self { + Self::Default => 0.48, + Self::Minimal => 0.20, + Self::Aggressive => 0.65, + Self::ReadingMode => 0.45, + } + } + + /// Extra weight multiplier for a given HTML tag (applied on top of base weights). + #[must_use] + pub fn tag_weight_multiplier(self, tag: &str) -> f64 { + match self { + Self::ReadingMode => match tag { + "article" | "main" => 2.0, + "p" | "h1" | "h2" | "h3" | "h4" => 1.5, + "nav" | "aside" | "footer" | "header" => 0.1, + _ => 1.0, + }, + Self::Minimal => match tag { + "form" | "input" | "button" | "select" | "textarea" | "label" => 2.0, + _ => 1.0, + }, + Self::Default | Self::Aggressive => 1.0, + } + } +} + +/// Select a cleaning profile based on an optional task hint string and content length. +/// +/// Used when the `content_aware_profiles` flag is ON. +#[must_use] +pub fn select_profile(task_hint: Option<&str>, content_len: usize) -> CleaningProfile { + if content_len > 50_000 { + return CleaningProfile::Aggressive; + } + + let Some(hint) = task_hint else { + return CleaningProfile::Default; + }; + let h = hint.to_lowercase(); + + if h.contains("extract") + || h.contains("scrape") + || h.contains("get data") + || h.contains("read") + || h.contains("article") + { + CleaningProfile::ReadingMode + } else if h.contains("fill") + || h.contains("click") + || h.contains("interact") + || h.contains("submit") + || h.contains("form") + { + CleaningProfile::Minimal + } else { + CleaningProfile::Default + } +} + #[must_use] pub fn prune_html(html: &str) -> String { + prune_html_with_profile(html, CleaningProfile::Default) +} + +/// Prune HTML using a specific cleaning profile. +/// +/// The profile controls the score threshold and tag weight multipliers. +#[must_use] +pub fn prune_html_with_profile(html: &str, profile: CleaningProfile) -> String { if html.trim().is_empty() { return String::new(); } @@ -31,16 +118,16 @@ pub fn prune_html(html: &str) -> String { body.children() .filter_map(ElementRef::wrap) - .filter_map(prune_node) + .filter_map(|el| prune_node(el, profile)) .collect::() } -fn prune_node(element: ElementRef<'_>) -> Option { +fn prune_node(element: ElementRef<'_>, profile: CleaningProfile) -> Option { if has_negative_class_id_pattern(element) { return None; } - if score_element(element) < THRESHOLD { + if score_element(element, profile) < profile.threshold() { return None; } @@ -54,7 +141,7 @@ fn prune_node(element: ElementRef<'_>) -> Option { .children() .filter_map(|child| match child.value() { Node::Text(text) => Some(escape_text(text.as_ref())), - Node::Element(_) => ElementRef::wrap(child).and_then(prune_node), + Node::Element(_) => ElementRef::wrap(child).and_then(|el| prune_node(el, profile)), _ => None, }) .collect::(); @@ -62,7 +149,7 @@ fn prune_node(element: ElementRef<'_>) -> Option { Some(format!("<{tag_name}{attrs}>{inner}")) } -fn score_element(element: ElementRef<'_>) -> f64 { +fn score_element(element: ElementRef<'_>, profile: CleaningProfile) -> f64 { let text: String = element.text().collect(); let text_len = text.trim().len(); let tag_len = element.inner_html().len(); @@ -84,9 +171,13 @@ fn score_element(element: ElementRef<'_>) -> f64 { 0.0 }; + let base_tag_weight = tag_weight(element.value().name()); + let adjusted_tag_weight = + base_tag_weight * profile.tag_weight_multiplier(element.value().name()); + 0.4 * text_density + 0.2 * link_density_complement - + 0.2 * tag_weight(element.value().name()) + + 0.2 * adjusted_tag_weight + 0.1 * f64::max(0.0, class_id_score) + 0.1 * ln_term } @@ -185,6 +276,10 @@ mod tests { use super::*; fn first_body_child_score(html: &str) -> f64 { + first_body_child_score_with_profile(html, CleaningProfile::Default) + } + + fn first_body_child_score_with_profile(html: &str, profile: CleaningProfile) -> f64 { let document = Html::parse_document(html); let body_selector = Selector::parse("body").expect("body selector should parse"); let body = document @@ -196,7 +291,7 @@ mod tests { .find_map(ElementRef::wrap) .expect("body should have a child element"); - score_element(element) + score_element(element, profile) } #[test] @@ -212,13 +307,14 @@ mod tests { let high_html = r"
abc
"; let low_score = first_body_child_score(low_html); let high_score = first_body_child_score(high_html); + let threshold = CleaningProfile::Default.threshold(); assert!( - low_score < THRESHOLD, + low_score < threshold, "expected low score below threshold, got {low_score}" ); assert!( - high_score >= THRESHOLD, + high_score >= threshold, "expected high score above threshold, got {high_score}" ); assert!(prune_html(high_html).contains("abc")); @@ -264,4 +360,154 @@ mod tests { "nested link-heavy sidebar should be pruned" ); } + + #[test] + fn default_profile_matches_prune_html() { + let html = r#" +

This is real content that should survive.

+ + "#; + let default_result = prune_html(html); + let profile_result = prune_html_with_profile(html, CleaningProfile::Default); + assert_eq!(default_result, profile_result); + } + + #[test] + fn profile_threshold_values() { + assert!((CleaningProfile::Default.threshold() - 0.48).abs() < f64::EPSILON); + assert!((CleaningProfile::Minimal.threshold() - 0.20).abs() < f64::EPSILON); + assert!((CleaningProfile::Aggressive.threshold() - 0.65).abs() < f64::EPSILON); + assert!((CleaningProfile::ReadingMode.threshold() - 0.45).abs() < f64::EPSILON); + } + + #[test] + fn select_profile_no_hint_small_content() { + assert_eq!(select_profile(None, 1000), CleaningProfile::Default); + } + + #[test] + fn select_profile_large_content_always_aggressive() { + assert_eq!(select_profile(None, 60_000), CleaningProfile::Aggressive); + assert_eq!( + select_profile(Some("extract data"), 60_000), + CleaningProfile::Aggressive + ); + assert_eq!( + select_profile(Some("fill form"), 60_000), + CleaningProfile::Aggressive + ); + } + + #[test] + fn select_profile_reading_keywords() { + assert_eq!( + select_profile(Some("extract titles"), 1000), + CleaningProfile::ReadingMode + ); + assert_eq!( + select_profile(Some("scrape all data"), 1000), + CleaningProfile::ReadingMode + ); + assert_eq!( + select_profile(Some("read the article"), 1000), + CleaningProfile::ReadingMode + ); + assert_eq!( + select_profile(Some("get data from page"), 1000), + CleaningProfile::ReadingMode + ); + } + + #[test] + fn select_profile_interaction_keywords() { + assert_eq!( + select_profile(Some("fill in the login form"), 1000), + CleaningProfile::Minimal + ); + assert_eq!( + select_profile(Some("click the submit button"), 1000), + CleaningProfile::Minimal + ); + assert_eq!( + select_profile(Some("interact with the page"), 1000), + CleaningProfile::Minimal + ); + } + + #[test] + fn select_profile_unknown_hint_returns_default() { + assert_eq!( + select_profile(Some("navigate to page"), 1000), + CleaningProfile::Default + ); + } + + #[test] + fn aggressive_produces_smaller_output() { + let html = " +

This is a substantial paragraph of real content that should survive most profiles.

+

Some extra content here with moderate text density value.

+

Another section with enough text to potentially survive default.

+ "; + let default_out = prune_html_with_profile(html, CleaningProfile::Default); + let aggressive_out = prune_html_with_profile(html, CleaningProfile::Aggressive); + assert!( + aggressive_out.len() <= default_out.len(), + "aggressive ({}) should produce output no larger than default ({})", + aggressive_out.len(), + default_out.len() + ); + } + + #[test] + fn minimal_preserves_more_than_default() { + let html = " +
x
+

Content that default keeps.

+ "; + let default_out = prune_html_with_profile(html, CleaningProfile::Default); + let minimal_out = prune_html_with_profile(html, CleaningProfile::Minimal); + assert!( + minimal_out.len() >= default_out.len(), + "minimal ({}) should preserve at least as much as default ({})", + minimal_out.len(), + default_out.len() + ); + } + + #[test] + fn reading_mode_boosts_article_content() { + let html = "

Article content here with enough text to score well.

"; + let default_score = first_body_child_score_with_profile(html, CleaningProfile::Default); + let reading_score = first_body_child_score_with_profile(html, CleaningProfile::ReadingMode); + assert!( + reading_score >= default_score, + "ReadingMode score ({reading_score}) should be >= Default ({default_score}) for article" + ); + } + + #[test] + fn tag_weight_multiplier_reading_mode() { + let profile = CleaningProfile::ReadingMode; + assert!((profile.tag_weight_multiplier("article") - 2.0).abs() < f64::EPSILON); + assert!((profile.tag_weight_multiplier("p") - 1.5).abs() < f64::EPSILON); + assert!((profile.tag_weight_multiplier("nav") - 0.1).abs() < f64::EPSILON); + assert!((profile.tag_weight_multiplier("div") - 1.0).abs() < f64::EPSILON); + } + + #[test] + fn tag_weight_multiplier_minimal_mode() { + let profile = CleaningProfile::Minimal; + assert!((profile.tag_weight_multiplier("form") - 2.0).abs() < f64::EPSILON); + assert!((profile.tag_weight_multiplier("button") - 2.0).abs() < f64::EPSILON); + assert!((profile.tag_weight_multiplier("div") - 1.0).abs() < f64::EPSILON); + } + + #[test] + fn tag_weight_multiplier_default_always_one() { + let profile = CleaningProfile::Default; + assert!((profile.tag_weight_multiplier("article") - 1.0).abs() < f64::EPSILON); + assert!((profile.tag_weight_multiplier("nav") - 1.0).abs() < f64::EPSILON); + assert!((profile.tag_weight_multiplier("form") - 1.0).abs() < f64::EPSILON); + } } diff --git a/crates/browser/src/testing.rs b/crates/browser/src/testing.rs index 331936e7..a2593d67 100644 --- a/crates/browser/src/testing.rs +++ b/crates/browser/src/testing.rs @@ -20,7 +20,11 @@ impl BrowserBackend for NopBridge { async fn scroll(&mut self, _direction: &str, _pixels: i64) -> Result<(), BridgeError> { Err(BridgeError::Protocol("NopBridge".into())) } - async fn page_map(&mut self, _scope: Option<&str>) -> Result { + async fn page_map( + &mut self, + _scope: Option<&str>, + _compound_enrichment: bool, + ) -> Result { Err(BridgeError::Protocol("NopBridge".into())) } async fn read_content( diff --git a/crates/browser/tests/public_api.rs b/crates/browser/tests/public_api.rs index 8c6561e2..1b8c7620 100644 --- a/crates/browser/tests/public_api.rs +++ b/crates/browser/tests/public_api.rs @@ -46,7 +46,11 @@ impl BrowserBackend for MockBrowserBackend { Ok(()) } - async fn page_map(&mut self, _scope: Option<&str>) -> Result { + async fn page_map( + &mut self, + _scope: Option<&str>, + _compound_enrichment: bool, + ) -> Result { Ok(serde_json::json!({"headings": []})) } diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 9793ba8e..15306ad3 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -428,7 +428,7 @@ fn parse_resume_args(args: &[String]) -> Result { fn print_system_prompt() { println!( "{}", - agent::build_system_prompt(&mvp_tool_specs()).join("\n\n") + agent::build_system_prompt(&mvp_tool_specs(), None).join("\n\n") ); } @@ -889,7 +889,7 @@ mod tests { #[test] fn system_prompt_contains_no_ide_content() { - let prompt = agent::build_system_prompt(&mvp_tool_specs()).join("\n\n"); + let prompt = agent::build_system_prompt(&mvp_tool_specs(), None).join("\n\n"); assert!( !prompt.contains("Working directory"), "system prompt should not mention working directory" diff --git a/crates/mcp-server/src/server.rs b/crates/mcp-server/src/server.rs index 8a908493..c344cf4d 100644 --- a/crates/mcp-server/src/server.rs +++ b/crates/mcp-server/src/server.rs @@ -247,9 +247,13 @@ fn execute_browser_tool( registry: &ToolRegistry, browser: &mut BrowserContext, rt: &tokio::runtime::Runtime, + crawl_state: &mut agent::state::CrawlState, ) -> Result { rt.block_on(async { - match registry.execute_async(name, input, browser).await { + match registry + .execute_async(name, input, browser, crawl_state) + .await + { Ok(ToolEffect::Reply(output)) => Ok(output), Ok(_) => Err(format!("tool `{name}` returned unsupported effect")), Err(e) => Err(e.to_string()), @@ -440,7 +444,7 @@ fn filtered_tool_specs(allowed_tools: &[String]) -> Vec { } fn build_run_goal_system_prompt(allowed_tools: &[String]) -> Vec { - agent::build_system_prompt(&filtered_tool_specs(allowed_tools)) + agent::build_system_prompt(&filtered_tool_specs(allowed_tools), None) } fn parse_run_goal_request(arguments: &Value) -> Result { @@ -908,6 +912,7 @@ fn handle_tools_call( browser: &mut Option, script_manager: &mut ScriptManager, rt: &tokio::runtime::Runtime, + crawl_state: &mut agent::state::CrawlState, ) { let Some(params) = params else { send_error(id, -32602, "missing params".to_string()); @@ -999,7 +1004,14 @@ fn handle_tools_call( } } - match execute_browser_tool(name, &arguments, registry, browser.as_mut().unwrap(), rt) { + match execute_browser_tool( + name, + &arguments, + registry, + browser.as_mut().unwrap(), + rt, + crawl_state, + ) { Ok(output) => { let result = json!({ "content": [{ "type": "text", "text": output }], @@ -1036,6 +1048,7 @@ pub fn run_mcp_server() { .expect("failed to create tokio runtime"); let mut browser: Option = None; + let mut crawl_state = agent::state::CrawlState::default(); let registry = ToolRegistry::new_with_core_tools(); let settings = runtime::load_settings(); let script_settings = settings.script.unwrap_or_default(); @@ -1085,6 +1098,7 @@ pub fn run_mcp_server() { &mut browser, &mut script_manager, &rt, + &mut crawl_state, ); } method => { @@ -1099,6 +1113,14 @@ mod tests { use super::*; use std::io::Cursor; + fn with_transport_mode_lock(f: impl FnOnce() -> T) -> T { + let _guard = JOB_MUTEX + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + set_output_mode(TransportMode::Framed); + f() + } + fn assert_jsonrpc_error( outcome: Result, expected_code: i32, @@ -1127,13 +1149,15 @@ mod tests { #[test] fn read_protocol_message_accepts_json_line_mode() { - let body = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; - let data = format!("{body}\n").into_bytes(); - let mut cursor = Cursor::new(data); - let parsed = - read_protocol_message(&mut cursor).expect("line-delimited request should parse"); - assert_eq!(parsed, body.as_bytes()); - assert_eq!(output_mode(), TransportMode::LineDelimited); + with_transport_mode_lock(|| { + let body = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; + let data = format!("{body}\n").into_bytes(); + let mut cursor = Cursor::new(data); + let parsed = + read_protocol_message(&mut cursor).expect("line-delimited request should parse"); + assert_eq!(parsed, body.as_bytes()); + assert_eq!(output_mode(), TransportMode::LineDelimited); + }); } #[test] @@ -1148,12 +1172,14 @@ mod tests { #[test] fn read_protocol_message_accepts_framed_mode() { - let body = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; - let framed = encode_mcp_frame(body.as_bytes()); - let mut cursor = Cursor::new(framed); - let parsed = read_protocol_message(&mut cursor).expect("framed request should parse"); - assert_eq!(parsed, body.as_bytes()); - assert_eq!(output_mode(), TransportMode::Framed); + with_transport_mode_lock(|| { + let body = r#"{"jsonrpc":"2.0","id":1,"method":"tools/list"}"#; + let framed = encode_mcp_frame(body.as_bytes()); + let mut cursor = Cursor::new(framed); + let parsed = read_protocol_message(&mut cursor).expect("framed request should parse"); + assert_eq!(parsed, body.as_bytes()); + assert_eq!(output_mode(), TransportMode::Framed); + }); } #[test] @@ -1377,6 +1403,7 @@ mod tests { extracted_data: vec![json!({"title": "Example"})], steps_executed: 3, messages: Vec::new(), + model: Some("anthropic/claude-sonnet-4-6".to_string()), }), }; @@ -1446,6 +1473,7 @@ mod tests { extracted_data: vec![json!({"title": "Example"})], steps_executed: 3, messages: Vec::new(), + model: Some("anthropic/claude-sonnet-4-6".to_string()), }; let response = build_run_goal_success_response(&request, &result); diff --git a/crates/runtime/src/budget.rs b/crates/runtime/src/budget.rs new file mode 100644 index 00000000..eef3b905 --- /dev/null +++ b/crates/runtime/src/budget.rs @@ -0,0 +1,147 @@ +use std::str::FromStr; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; + +const MILLICENTS_PER_USD: u64 = 100_000; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BudgetMode { + Warn, + Block, + RouteDown, +} + +impl BudgetMode { + #[must_use] + pub fn parse(s: &str) -> Option { + Self::from_str(s).ok() + } +} + +impl FromStr for BudgetMode { + type Err = (); + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().as_str() { + "warn" => Ok(Self::Warn), + "block" => Ok(Self::Block), + "route_down" => Ok(Self::RouteDown), + _ => Err(()), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum BudgetDecision { + Allow, + Warn { remaining_usd: f64 }, + Block, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct BudgetEnforcer { + max_cost_usd: f64, + mode: BudgetMode, + warn_threshold_pct: u32, +} + +impl BudgetEnforcer { + #[must_use] + pub fn new(max_cost_usd: f64, mode: BudgetMode, warn_threshold_pct: u32) -> Self { + Self { + max_cost_usd, + mode, + warn_threshold_pct, + } + } + + #[must_use] + pub fn check(&self, current_cost_usd: f64) -> BudgetDecision { + if current_cost_usd >= self.max_cost_usd { + return match self.mode { + BudgetMode::Warn => BudgetDecision::Warn { remaining_usd: 0.0 }, + BudgetMode::Block | BudgetMode::RouteDown => BudgetDecision::Block, + }; + } + + let pct_used = (current_cost_usd / self.max_cost_usd) * 100.0; + if pct_used >= f64::from(self.warn_threshold_pct) { + BudgetDecision::Warn { + remaining_usd: self.max_cost_usd - current_cost_usd, + } + } else { + BudgetDecision::Allow + } + } +} + +#[must_use] +#[allow( + clippy::cast_possible_truncation, + clippy::cast_precision_loss, + clippy::cast_sign_loss +)] +pub fn usd_to_millicents(usd: f64) -> u64 { + (usd * MILLICENTS_PER_USD as f64) as u64 +} + +#[must_use] +#[allow(clippy::cast_precision_loss)] +pub fn millicents_to_usd(millicents: u64) -> f64 { + millicents as f64 / MILLICENTS_PER_USD as f64 +} + +pub type SharedCostCounter = Arc; + +#[must_use] +pub fn new_cost_counter() -> SharedCostCounter { + Arc::new(AtomicU64::new(0)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn allow_below_threshold() { + let enforcer = BudgetEnforcer::new(1.0, BudgetMode::Block, 80); + assert_eq!(enforcer.check(0.5), BudgetDecision::Allow); + } + + #[test] + fn warn_at_threshold() { + let enforcer = BudgetEnforcer::new(1.0, BudgetMode::Warn, 80); + assert!(matches!( + enforcer.check(0.85), + BudgetDecision::Warn { remaining_usd } if remaining_usd > 0.0 + )); + } + + #[test] + fn block_at_or_above_limit() { + let enforcer = BudgetEnforcer::new(1.0, BudgetMode::Block, 80); + assert_eq!(enforcer.check(1.01), BudgetDecision::Block); + } + + #[test] + fn route_down_acts_like_block() { + let enforcer = BudgetEnforcer::new(1.0, BudgetMode::RouteDown, 80); + assert_eq!(enforcer.check(1.01), BudgetDecision::Block); + } + + #[test] + fn warn_mode_never_blocks() { + let enforcer = BudgetEnforcer::new(1.0, BudgetMode::Warn, 80); + assert_eq!( + enforcer.check(2.0), + BudgetDecision::Warn { remaining_usd: 0.0 } + ); + } + + #[test] + fn converts_between_usd_and_millicents() { + let usd = 1.2345; + assert_eq!(usd_to_millicents(usd), 123_450); + assert!((millicents_to_usd(123_450) - usd).abs() < 1e-6); + } +} diff --git a/crates/runtime/src/conversation/mod.rs b/crates/runtime/src/conversation/mod.rs index d022091e..6cfe3604 100644 --- a/crates/runtime/src/conversation/mod.rs +++ b/crates/runtime/src/conversation/mod.rs @@ -1,6 +1,8 @@ use std::collections::BTreeMap; -use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::sync::{Arc, Mutex}; +use crate::budget::{new_cost_counter, usd_to_millicents, SharedCostCounter}; use crate::compact::{ compact_session, estimate_session_tokens, CompactionConfig, CompactionResult, }; @@ -8,7 +10,9 @@ use crate::config::RuntimeFeatureConfig; use crate::control::ControlState; use crate::observer::RuntimeObserver; use crate::session::{ContentBlock, ConversationMessage, Session}; -use crate::usage::{TokenUsage, UsageTracker}; +use crate::usage::{ + estimate_cost_usd_with_pricing, pricing_for_model, ModelPricing, TokenUsage, UsageTracker, +}; pub use acrawl_core::error::{RuntimeError, ToolError}; pub use acrawl_core::event::AssistantEvent; @@ -42,6 +46,9 @@ pub struct ConversationRuntime { usage_tracker: UsageTracker, auto_compaction_input_tokens_threshold: u32, control_state: Arc, + prompt_override: Arc>>>, + last_assistant_text: Arc>>, + cumulative_cost: SharedCostCounter, } impl ConversationRuntime @@ -55,12 +62,16 @@ where api_client: C, tool_executor: T, system_prompt: Vec, + prompt_override: Arc>>>, + last_assistant_text: Arc>>, ) -> Self { Self::new_with_features( session, api_client, tool_executor, system_prompt, + prompt_override, + last_assistant_text, &RuntimeFeatureConfig::default(), ) } @@ -71,9 +82,24 @@ where api_client: C, tool_executor: T, system_prompt: Vec, + prompt_override: Arc>>>, + last_assistant_text: Arc>>, _feature_config: &RuntimeFeatureConfig, ) -> Self { let usage_tracker = UsageTracker::from_session(&session); + let cumulative_cost = new_cost_counter(); + let pricing = session + .model + .as_deref() + .and_then(pricing_for_model) + .unwrap_or_else(ModelPricing::default_sonnet_tier); + cumulative_cost.store( + usd_to_millicents( + estimate_cost_usd_with_pricing(usage_tracker.cumulative_usage(), pricing) + .total_cost_usd(), + ), + Ordering::Relaxed, + ); Self { session, api_client, @@ -84,6 +110,9 @@ where usage_tracker, auto_compaction_input_tokens_threshold: auto_compaction_threshold_from_env(), control_state: Arc::new(ControlState::default()), + prompt_override, + last_assistant_text, + cumulative_cost, } } @@ -219,6 +248,11 @@ where &mut self.tool_executor } + #[must_use] + pub fn cumulative_cost_counter(&self) -> SharedCostCounter { + Arc::clone(&self.cumulative_cost) + } + fn push_user_message(&mut self, user_input: String) { self.session .messages @@ -229,6 +263,15 @@ where self.fail_if_cancelled()?; self.check_cancel()?; + if let Some(new_prompt) = self + .prompt_override + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .take() + { + self.system_prompt = new_prompt; + } + let next_iterations = iterations + 1; if next_iterations > self.max_iterations { return Err(RuntimeError::new( @@ -252,8 +295,24 @@ where let events = self.api_client.stream(self.build_api_request())?; notify_observer_about_events(&mut self.observer, &events); let (assistant_message, usage) = build_assistant_message(events)?; + let assistant_text = assistant_text_from_message(&assistant_message); + *self + .last_assistant_text + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(assistant_text); if let Some(usage) = usage { self.usage_tracker.record(usage); + let pricing = self + .session + .model + .as_deref() + .and_then(pricing_for_model) + .unwrap_or_else(ModelPricing::default_sonnet_tier); + let cumulative_cost_usd = + estimate_cost_usd_with_pricing(self.usage_tracker.cumulative_usage(), pricing) + .total_cost_usd(); + self.cumulative_cost + .store(usd_to_millicents(cumulative_cost_usd), Ordering::Relaxed); } Ok(assistant_message) @@ -632,6 +691,19 @@ fn build_assistant_message( )) } +fn assistant_text_from_message(message: &ConversationMessage) -> String { + message + .blocks + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => Some(text.as_str()), + ContentBlock::ToolUse { .. } + | ContentBlock::ToolResult { .. } + | ContentBlock::Reasoning { .. } => None, + }) + .collect() +} + fn notify_observer_about_events( observer: &mut Option>, events: &[AssistantEvent], diff --git a/crates/runtime/src/conversation/tests.rs b/crates/runtime/src/conversation/tests.rs index 1287409f..4ff916c3 100644 --- a/crates/runtime/src/conversation/tests.rs +++ b/crates/runtime/src/conversation/tests.rs @@ -3,10 +3,18 @@ use super::{ ConversationRuntime, RuntimeError, StaticToolExecutor, ToolOutcome, DEFAULT_AUTO_COMPACTION_INPUT_TOKENS_THRESHOLD, }; +use crate::budget::usd_to_millicents; use crate::compact::CompactionConfig; use crate::prompt::SystemPromptBuilder; use crate::session::{ContentBlock, MessageRole, Session}; -use crate::usage::TokenUsage; +use crate::usage::{estimate_cost_usd, TokenUsage}; +use std::sync::{Arc, Mutex}; + +type RuntimeSlots = (Arc>>>, Arc>>); + +fn runtime_slots() -> RuntimeSlots { + (Arc::new(Mutex::new(None)), Arc::new(Mutex::new(None))) +} struct ScriptedApiClient { call_count: usize, @@ -89,8 +97,15 @@ async fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() { Ok(ToolOutcome::reply(total.to_string())) }); let system_prompt = SystemPromptBuilder::new().append_section("# Tools").build(); - let mut runtime = - ConversationRuntime::new(Session::new(), api_client, tool_executor, system_prompt); + let (prompt_override, last_assistant_text) = runtime_slots(); + let mut runtime = ConversationRuntime::new( + Session::new(), + api_client, + tool_executor, + system_prompt, + prompt_override, + last_assistant_text, + ); let summary = runtime .run_turn("what is 2 + 2?") @@ -114,6 +129,12 @@ async fn runs_user_to_tool_to_result_loop_end_to_end_and_tracks_usage() { .. } )); + assert_eq!( + runtime + .cumulative_cost_counter() + .load(std::sync::atomic::Ordering::Relaxed), + usd_to_millicents(estimate_cost_usd(summary.usage).total_cost_usd()) + ); } #[test] @@ -143,11 +164,14 @@ fn reconstructs_usage_tracker_from_restored_session() { }), )); + let (prompt_override, last_assistant_text) = runtime_slots(); let runtime = ConversationRuntime::new( session, SimpleApi, StaticToolExecutor::new(), vec!["system".to_string()], + prompt_override, + last_assistant_text, ); assert_eq!(runtime.usage().turns(), 1); @@ -166,11 +190,14 @@ async fn compacts_session_after_turns() { } } + let (prompt_override, last_assistant_text) = runtime_slots(); let mut runtime = ConversationRuntime::new( Session::new(), SimpleApi, StaticToolExecutor::new(), vec!["system".to_string()], + prompt_override, + last_assistant_text, ); runtime.run_turn("a").await.expect("turn a"); runtime.run_turn("b").await.expect("turn b"); @@ -340,11 +367,14 @@ async fn auto_compacts_when_cumulative_input_threshold_is_crossed() { child_sessions: Vec::new(), }; + let (prompt_override, last_assistant_text) = runtime_slots(); let mut runtime = ConversationRuntime::new( session, SimpleApi, StaticToolExecutor::new(), vec!["system".to_string()], + prompt_override, + last_assistant_text, ) .with_auto_compaction_input_tokens_threshold(100_000); @@ -380,11 +410,14 @@ async fn skips_auto_compaction_below_threshold() { } } + let (prompt_override, last_assistant_text) = runtime_slots(); let mut runtime = ConversationRuntime::new( Session::new(), SimpleApi, StaticToolExecutor::new(), vec!["system".to_string()], + prompt_override, + last_assistant_text, ) .with_auto_compaction_input_tokens_threshold(100_000); @@ -429,3 +462,77 @@ fn reasoning_event_stored_in_message() { )); assert!(matches!(&message.blocks[1], ContentBlock::Text { text } if text == "answer")); } + +#[tokio::test] +async fn prepare_iteration_applies_and_clears_prompt_override() { + struct PromptRecordingApiClient { + prompts: Arc>>>, + } + + impl ApiClient for PromptRecordingApiClient { + fn stream(&mut self, request: ApiRequest) -> Result, RuntimeError> { + self.prompts + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push(request.system_prompt); + Ok(vec![ + AssistantEvent::TextDelta("done".to_string()), + AssistantEvent::MessageStop, + ]) + } + } + + let prompts = Arc::new(Mutex::new(Vec::new())); + let prompt_override = Arc::new(Mutex::new(Some(vec!["override".to_string()]))); + let last_assistant_text = Arc::new(Mutex::new(None)); + let mut runtime = ConversationRuntime::new( + Session::new(), + PromptRecordingApiClient { + prompts: Arc::clone(&prompts), + }, + StaticToolExecutor::new(), + vec!["original".to_string()], + Arc::clone(&prompt_override), + last_assistant_text, + ); + + runtime.run_turn("hi").await.expect("turn should succeed"); + + assert_eq!( + prompts + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .as_slice(), + &[vec!["override".to_string()]] + ); + assert!(prompt_override + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .is_none()); +} + +#[tokio::test] +async fn stream_assistant_message_records_last_assistant_text() { + let (prompt_override, last_assistant_text) = runtime_slots(); + let mut runtime = ConversationRuntime::new( + Session::new(), + MockApiClientWithText("latest assistant text".to_string()), + StaticToolExecutor::new(), + vec!["system".to_string()], + prompt_override, + Arc::clone(&last_assistant_text), + ); + + runtime + .run_turn("hello") + .await + .expect("turn should succeed"); + + assert_eq!( + last_assistant_text + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .as_deref(), + Some("latest assistant text") + ); +} diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index ee248498..3db32007 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -1,3 +1,4 @@ +pub mod budget; mod compact; mod config; mod control; @@ -12,6 +13,10 @@ mod summary_compression; pub mod update_check; mod usage; +pub use budget::{ + millicents_to_usd, new_cost_counter, usd_to_millicents, BudgetDecision, BudgetEnforcer, + BudgetMode, SharedCostCounter, +}; pub use compact::{ compact_session, estimate_session_tokens, format_compact_summary, get_compact_continuation_message, should_compact, CompactionConfig, CompactionResult, @@ -44,20 +49,28 @@ pub use session::{ }; pub use settings::{ config_home_dir, load_settings, resolve_output_dir, save_settings, settings_file_path, - settings_get_auto_compact_tokens, settings_get_compaction_llm_summarization, - settings_get_compaction_max_summary_chars, + settings_get_action_cache_ttl_secs, settings_get_action_caching, + settings_get_auto_compact_tokens, settings_get_budget_enforcement, + settings_get_budget_max_session_cost_usd, settings_get_budget_warn_threshold_pct, + settings_get_compaction_llm_summarization, settings_get_compaction_max_summary_chars, settings_get_compaction_preserve_recent_messages_floor, settings_get_compaction_preserve_recent_tokens, settings_get_compaction_prune_max_output_chars, - settings_get_compaction_prune_protect_tokens, settings_get_fork_child_max_steps, - settings_get_fork_wait_timeout_secs, settings_get_headless, - settings_get_max_concurrent_per_parent, settings_get_max_fork_depth, settings_get_max_steps, - settings_get_max_total_agents, settings_get_output_dir, update_settings, Settings, + settings_get_compaction_prune_protect_tokens, settings_get_compound_enrichment, + settings_get_confidence_tracking, settings_get_content_aware_profiles, + settings_get_failure_classification, settings_get_fork_child_max_steps, + settings_get_fork_wait_timeout_secs, settings_get_headless, settings_get_html_diff_mode, + settings_get_loop_detection, settings_get_loop_detection_window, + settings_get_loop_nudge_threshold, settings_get_max_concurrent_per_parent, + settings_get_max_fork_depth, settings_get_max_steps, settings_get_max_total_agents, + settings_get_output_dir, settings_get_page_fingerprinting, + settings_get_per_agent_cost_tracking, settings_get_planning_interval, + settings_get_self_healing, settings_get_self_healing_max_retries, update_settings, Settings, }; pub use update_check::{check_for_update, check_for_update_force, UpdateInfo}; pub use usage::{ - estimate_cost_usd, estimate_cost_usd_with_pricing, format_usd, pricing_for_model, - summary_lines, summary_lines_for_model, ModelPricing, TokenUsage, UsageCostEstimate, - UsageTracker, + build_cost_breakdown, estimate_cost_usd, estimate_cost_usd_with_pricing, format_usd, + pricing_for_model, summary_lines, summary_lines_for_model, AgentCostReport, ModelPricing, + TokenUsage, UsageCostEstimate, UsageTracker, }; #[cfg(test)] diff --git a/crates/runtime/src/observer.rs b/crates/runtime/src/observer.rs index e2caec8a..3e91c246 100644 --- a/crates/runtime/src/observer.rs +++ b/crates/runtime/src/observer.rs @@ -119,6 +119,8 @@ mod tests { TextDeltaApiClient, StaticToolExecutor::new(), vec!["system".to_string()], + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), ) .with_observer(Box::new(observer)); @@ -170,6 +172,8 @@ mod tests { Ok(ToolOutcome::reply(format!("echo:{input}"))) }), vec!["system".to_string()], + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), ) .with_observer(Box::new(observer)); @@ -229,6 +233,8 @@ mod tests { )) }), vec!["system".to_string()], + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), ) .with_observer(Box::new(observer)); @@ -269,6 +275,8 @@ mod tests { FinishedApiClient, StaticToolExecutor::new(), vec!["system".to_string()], + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), ) .with_observer(Box::new(observer)); @@ -298,6 +306,8 @@ mod tests { Ok(ToolOutcome::reply(format!("echo:{input}"))) }), vec!["system".to_string()], + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), ) .with_observer(Box::new(observer)); @@ -362,6 +372,8 @@ mod tests { )) }), vec!["system".to_string()], + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), ) .with_observer(Box::new(observer)); @@ -391,6 +403,8 @@ mod tests { )) }), vec!["system".to_string()], + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), ) .with_observer(Box::new(observer)); @@ -421,6 +435,8 @@ mod tests { ErrorApiClient, StaticToolExecutor::new(), vec!["system".to_string()], + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), ) .with_observer(Box::new(observer)); diff --git a/crates/runtime/src/session.rs b/crates/runtime/src/session.rs index b18d7b89..3835d915 100644 --- a/crates/runtime/src/session.rs +++ b/crates/runtime/src/session.rs @@ -11,6 +11,7 @@ pub use acrawl_core::message::{ContentBlock, ConversationMessage, MessageRole, T #[derive(Debug, Clone, PartialEq, Eq)] pub struct ChildSession { pub id: String, + pub model: Option, pub goal: String, pub messages: Vec, } @@ -375,6 +376,9 @@ fn required_u32(object: &BTreeMap, key: &str) -> Result JsonValue { let mut object = BTreeMap::new(); object.insert("id".to_string(), JsonValue::String(child.id.clone())); + if let Some(model) = &child.model { + object.insert("model".to_string(), JsonValue::String(model.clone())); + } object.insert("goal".to_string(), JsonValue::String(child.goal.clone())); object.insert( "messages".to_string(), @@ -394,6 +398,10 @@ fn child_session_from_json(value: &JsonValue) -> Result Result, _>>()?; - Ok(ChildSession { id, goal, messages }) + Ok(ChildSession { + id, + model, + goal, + messages, + }) } #[cfg(test)] @@ -511,6 +524,7 @@ mod tests { let child = ChildSession { id: "child-1".to_string(), + model: Some("anthropic/claude-sonnet-4-6".to_string()), goal: "scrape titles".to_string(), messages: vec![ConversationMessage::user_text("child goal")], }; diff --git a/crates/runtime/src/settings.rs b/crates/runtime/src/settings.rs index 770ee65f..e384fdef 100644 --- a/crates/runtime/src/settings.rs +++ b/crates/runtime/src/settings.rs @@ -61,9 +61,87 @@ impl Default for ScriptSettings { } } +/// Optimization settings for advanced crawl behavior. +/// All fields are optional; unset fields use safe defaults (false/0). +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] +#[serde(default)] +pub struct OptimizationSettings { + /// Enable HTML diff mode for page change detection (default: false) + #[serde(default)] + pub html_diff_mode: Option, + + /// Enable loop detection (default: false) + #[serde(default)] + pub loop_detection: Option, + + /// Loop detection window size in steps (default: 20) + #[serde(default)] + pub loop_detection_window: Option, + + /// Loop nudge threshold (default: 5) + #[serde(default)] + pub loop_nudge_threshold: Option, + + /// Enable page fingerprinting (default: false) + #[serde(default)] + pub page_fingerprinting: Option, + + /// Planning interval in steps; 0 = disabled (default: 0) + #[serde(default)] + pub planning_interval: Option, + + /// Enable failure classification (default: false) + #[serde(default)] + pub failure_classification: Option, + + /// Enable self-healing (default: false) + #[serde(default)] + pub self_healing: Option, + + /// Max retries for self-healing (default: 2) + #[serde(default)] + pub self_healing_max_retries: Option, + + /// Enable action caching (default: false) + #[serde(default)] + pub action_caching: Option, + + /// Action cache TTL in seconds (default: 30) + #[serde(default)] + pub action_cache_ttl_secs: Option, + + /// Enable confidence tracking (default: false) + #[serde(default)] + pub confidence_tracking: Option, + + /// Enable compound enrichment (default: false) + #[serde(default)] + pub compound_enrichment: Option, + + /// Enable content-aware profiles (default: false) + #[serde(default)] + pub content_aware_profiles: Option, + + /// Budget: max session cost in USD (default: None = unlimited) + #[serde(default)] + pub budget_max_session_cost_usd: Option, + + /// Budget enforcement mode: "warn" | "block" | "`route_down`" (default: None) + #[serde(default)] + pub budget_enforcement: Option, + + /// Budget warning threshold as percentage (default: 80) + #[serde(default)] + pub budget_warn_threshold_pct: Option, + + /// Enable per-agent cost tracking (default: false) + #[serde(default)] + pub per_agent_cost_tracking: Option, +} + /// Settings loaded from settings.json configuration file. /// All fields are optional with serde defaults to support partial JSON files. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct Settings { /// Run browser in headless mode (default: true) #[serde(default)] @@ -148,6 +226,10 @@ pub struct Settings { /// Script resource limits and configuration #[serde(default)] pub script: Option, + + /// Optimization settings for advanced crawl behavior + #[serde(default)] + pub optimization: Option, } impl Default for Settings { @@ -174,6 +256,7 @@ impl Default for Settings { compaction_max_summary_chars: None, compaction_llm_summarization: None, script: Some(ScriptSettings::default()), + optimization: None, } } } @@ -364,6 +447,166 @@ pub fn settings_get_compaction_llm_summarization(s: &Settings) -> bool { s.compaction_llm_summarization.unwrap_or(false) } +/// Get `html_diff_mode` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_html_diff_mode(s: &Settings) -> bool { + s.optimization + .as_ref() + .and_then(|o| o.html_diff_mode) + .unwrap_or(false) +} + +/// Get `loop_detection` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_loop_detection(s: &Settings) -> bool { + s.optimization + .as_ref() + .and_then(|o| o.loop_detection) + .unwrap_or(false) +} + +/// Get `loop_detection_window` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_loop_detection_window(s: &Settings) -> usize { + s.optimization + .as_ref() + .and_then(|o| o.loop_detection_window) + .unwrap_or(20) +} + +/// Get `loop_nudge_threshold` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_loop_nudge_threshold(s: &Settings) -> usize { + s.optimization + .as_ref() + .and_then(|o| o.loop_nudge_threshold) + .unwrap_or(5) +} + +/// Get `page_fingerprinting` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_page_fingerprinting(s: &Settings) -> bool { + s.optimization + .as_ref() + .and_then(|o| o.page_fingerprinting) + .unwrap_or(false) +} + +/// Get `planning_interval` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_planning_interval(s: &Settings) -> usize { + s.optimization + .as_ref() + .and_then(|o| o.planning_interval) + .unwrap_or(0) +} + +/// Get `failure_classification` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_failure_classification(s: &Settings) -> bool { + s.optimization + .as_ref() + .and_then(|o| o.failure_classification) + .unwrap_or(false) +} + +/// Get `self_healing` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_self_healing(s: &Settings) -> bool { + s.optimization + .as_ref() + .and_then(|o| o.self_healing) + .unwrap_or(false) +} + +/// Get `self_healing_max_retries` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_self_healing_max_retries(s: &Settings) -> usize { + s.optimization + .as_ref() + .and_then(|o| o.self_healing_max_retries) + .unwrap_or(2) +} + +/// Get `action_caching` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_action_caching(s: &Settings) -> bool { + s.optimization + .as_ref() + .and_then(|o| o.action_caching) + .unwrap_or(false) +} + +/// Get `action_cache_ttl_secs` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_action_cache_ttl_secs(s: &Settings) -> u64 { + s.optimization + .as_ref() + .and_then(|o| o.action_cache_ttl_secs) + .unwrap_or(30) +} + +/// Get `confidence_tracking` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_confidence_tracking(s: &Settings) -> bool { + s.optimization + .as_ref() + .and_then(|o| o.confidence_tracking) + .unwrap_or(false) +} + +/// Get `compound_enrichment` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_compound_enrichment(s: &Settings) -> bool { + s.optimization + .as_ref() + .and_then(|o| o.compound_enrichment) + .unwrap_or(false) +} + +/// Get `content_aware_profiles` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_content_aware_profiles(s: &Settings) -> bool { + s.optimization + .as_ref() + .and_then(|o| o.content_aware_profiles) + .unwrap_or(false) +} + +/// Get `budget_max_session_cost_usd` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_budget_max_session_cost_usd(s: &Settings) -> Option { + s.optimization + .as_ref() + .and_then(|o| o.budget_max_session_cost_usd) +} + +/// Get `budget_enforcement` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_budget_enforcement(s: &Settings) -> Option { + s.optimization + .as_ref() + .and_then(|o| o.budget_enforcement.clone()) +} + +/// Get `budget_warn_threshold_pct` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_budget_warn_threshold_pct(s: &Settings) -> u32 { + s.optimization + .as_ref() + .and_then(|o| o.budget_warn_threshold_pct) + .unwrap_or(80) +} + +/// Get `per_agent_cost_tracking` optimization setting, with default fallback. +#[must_use] +pub fn settings_get_per_agent_cost_tracking(s: &Settings) -> bool { + s.optimization + .as_ref() + .and_then(|o| o.per_agent_cost_tracking) + .unwrap_or(false) +} + #[cfg(test)] mod tests { use super::*; @@ -451,6 +694,7 @@ mod tests { compaction_max_summary_chars: None, compaction_llm_summarization: None, script: Some(ScriptSettings::default()), + optimization: None, }; save_settings(&original).expect("Failed to save settings"); @@ -657,6 +901,7 @@ mod tests { compaction_max_summary_chars: None, compaction_llm_summarization: None, script: Some(ScriptSettings::default()), + optimization: None, }) .expect("save settings"); @@ -720,6 +965,7 @@ mod tests { compaction_max_summary_chars: None, compaction_llm_summarization: None, script: Some(ScriptSettings::default()), + optimization: None, }; save_settings(&original).expect("Failed to save settings"); @@ -783,6 +1029,7 @@ mod tests { compaction_max_summary_chars: None, compaction_llm_summarization: None, script: None, + optimization: None, }; assert_eq!(settings_get_max_concurrent_per_parent(&settings), 5); @@ -865,4 +1112,108 @@ mod tests { cleanup_temp_dir(&temp_dir); } + + #[test] + fn test_optimization_settings_backward_compat_no_optimization_field() { + let _lock = test_env_lock(); + let temp_dir = setup_temp_dir(); + + std::env::set_var("ACRAWL_CONFIG_HOME", &temp_dir); + + // Write JSON WITHOUT "optimization" field (old settings.json) + let settings_path = temp_dir.join("settings.json"); + fs::write(&settings_path, r#"{"headless": true, "max_steps": 50}"#) + .expect("Failed to write test settings"); + + let settings = load_settings(); + + // All bool getters should return false + assert!(!settings_get_html_diff_mode(&settings)); + assert!(!settings_get_loop_detection(&settings)); + assert!(!settings_get_page_fingerprinting(&settings)); + assert!(!settings_get_failure_classification(&settings)); + assert!(!settings_get_self_healing(&settings)); + assert!(!settings_get_action_caching(&settings)); + assert!(!settings_get_confidence_tracking(&settings)); + assert!(!settings_get_compound_enrichment(&settings)); + assert!(!settings_get_content_aware_profiles(&settings)); + assert!(!settings_get_per_agent_cost_tracking(&settings)); + + // All numeric getters should return their defaults + assert_eq!(settings_get_loop_detection_window(&settings), 20); + assert_eq!(settings_get_loop_nudge_threshold(&settings), 5); + assert_eq!(settings_get_planning_interval(&settings), 0); + assert_eq!(settings_get_self_healing_max_retries(&settings), 2); + assert_eq!(settings_get_action_cache_ttl_secs(&settings), 30); + assert_eq!(settings_get_budget_warn_threshold_pct(&settings), 80); + + // Option getters should return None + assert_eq!(settings_get_budget_max_session_cost_usd(&settings), None); + assert_eq!(settings_get_budget_enforcement(&settings), None); + + cleanup_temp_dir(&temp_dir); + } + + #[test] + fn test_optimization_settings_parse_with_values() { + let _lock = test_env_lock(); + let temp_dir = setup_temp_dir(); + + std::env::set_var("ACRAWL_CONFIG_HOME", &temp_dir); + + // Write JSON WITH "optimization" field containing some values + let settings_path = temp_dir.join("settings.json"); + fs::write( + &settings_path, + r#"{ + "headless": true, + "max_steps": 50, + "optimization": { + "html_diff_mode": true, + "loop_detection": true, + "loop_detection_window": 25, + "page_fingerprinting": true, + "self_healing": true, + "self_healing_max_retries": 5, + "action_caching": true, + "action_cache_ttl_secs": 60, + "budget_max_session_cost_usd": 10.5, + "budget_enforcement": "block", + "budget_warn_threshold_pct": 75 + } + }"#, + ) + .expect("Failed to write test settings"); + + let settings = load_settings(); + + // Set values should be returned + assert!(settings_get_html_diff_mode(&settings)); + assert!(settings_get_loop_detection(&settings)); + assert_eq!(settings_get_loop_detection_window(&settings), 25); + assert!(settings_get_page_fingerprinting(&settings)); + assert!(settings_get_self_healing(&settings)); + assert_eq!(settings_get_self_healing_max_retries(&settings), 5); + assert!(settings_get_action_caching(&settings)); + assert_eq!(settings_get_action_cache_ttl_secs(&settings), 60); + assert_eq!( + settings_get_budget_max_session_cost_usd(&settings), + Some(10.5) + ); + assert_eq!( + settings_get_budget_enforcement(&settings), + Some("block".to_string()) + ); + assert_eq!(settings_get_budget_warn_threshold_pct(&settings), 75); + + // Unset values should use defaults + assert_eq!(settings_get_loop_nudge_threshold(&settings), 5); + assert!(!settings_get_failure_classification(&settings)); + assert!(!settings_get_confidence_tracking(&settings)); + assert!(!settings_get_compound_enrichment(&settings)); + assert!(!settings_get_content_aware_profiles(&settings)); + assert!(!settings_get_per_agent_cost_tracking(&settings)); + + cleanup_temp_dir(&temp_dir); + } } diff --git a/crates/runtime/src/usage.rs b/crates/runtime/src/usage.rs index 3684e608..203b1160 100644 --- a/crates/runtime/src/usage.rs +++ b/crates/runtime/src/usage.rs @@ -139,6 +139,48 @@ fn cost_for_tokens(tokens: u32, usd_per_million_tokens: f64) -> f64 { f64::from(tokens) / 1_000_000.0 * usd_per_million_tokens } +/// Per-child cost attribution for `/cost` breakdown. +#[derive(Debug, Clone)] +pub struct AgentCostReport { + pub agent_id: String, + pub direct_cost_usd: f64, + pub turn_count: u32, +} + +/// Build a flat per-child cost breakdown from a session's `child_sessions`. +/// +/// Walks `child_sessions` (flat list) and computes cost via each child's +/// recorded usage. When a child session records its model, use model-specific +/// pricing; otherwise fall back to the default estimate. +#[must_use] +pub fn build_cost_breakdown(session: &crate::session::Session) -> Vec { + session + .child_sessions + .iter() + .map(|child| { + let mut tracker = UsageTracker::new(); + for message in &child.messages { + if let Some(usage) = message.usage { + tracker.record(usage); + } + } + let cost = child + .model + .as_deref() + .and_then(pricing_for_model) + .map_or_else( + || estimate_cost_usd(tracker.cumulative_usage()), + |pricing| estimate_cost_usd_with_pricing(tracker.cumulative_usage(), pricing), + ); + AgentCostReport { + agent_id: child.id.clone(), + direct_cost_usd: cost.total_cost_usd(), + turn_count: tracker.turns(), + } + }) + .collect() +} + #[must_use] pub fn format_usd(amount: f64) -> String { format!("${amount:.4}") @@ -297,4 +339,102 @@ mod tests { assert_eq!(tracker.turns(), 1); assert_eq!(tracker.cumulative_usage().total_tokens(), 8); } + + #[test] + fn build_cost_breakdown_computes_per_child_costs() { + use super::build_cost_breakdown; + use crate::session::ChildSession; + + let session = Session { + version: 2, + model: None, + title: None, + messages: Vec::new(), + child_sessions: vec![ + ChildSession { + id: "child-a".to_string(), + model: Some("claude-haiku-4-5-20251001".to_string()), + goal: "scrape page A".to_string(), + messages: vec![ + ConversationMessage { + role: MessageRole::Assistant, + blocks: vec![ContentBlock::Text { + text: "working".to_string(), + }], + usage: Some(TokenUsage { + input_tokens: 1_000_000, + output_tokens: 100_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }), + }, + ConversationMessage { + role: MessageRole::Assistant, + blocks: vec![ContentBlock::Text { + text: "done".to_string(), + }], + usage: Some(TokenUsage { + input_tokens: 500_000, + output_tokens: 50_000, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }), + }, + ], + }, + ChildSession { + id: "child-b".to_string(), + model: None, + goal: "scrape page B".to_string(), + messages: vec![ConversationMessage { + role: MessageRole::User, + blocks: vec![ContentBlock::Text { + text: "go".to_string(), + }], + usage: None, + }], + }, + ], + }; + + let breakdown = build_cost_breakdown(&session); + assert_eq!(breakdown.len(), 2); + + // child-a: 2 assistant messages with usage + assert_eq!(breakdown[0].agent_id, "child-a"); + assert_eq!(breakdown[0].turn_count, 2); + assert_eq!(format_usd(breakdown[0].direct_cost_usd), "$2.2500"); + + // child-b: 1 user message with no usage → 0 turns, 0 cost + assert_eq!(breakdown[1].agent_id, "child-b"); + assert_eq!(breakdown[1].turn_count, 0); + assert!(breakdown[1].direct_cost_usd.abs() < f64::EPSILON); + } + + #[test] + fn build_cost_breakdown_empty_when_no_children() { + use super::build_cost_breakdown; + + let session = Session { + version: 2, + model: None, + title: None, + messages: vec![ConversationMessage { + role: MessageRole::Assistant, + blocks: vec![ContentBlock::Text { + text: "hello".to_string(), + }], + usage: Some(TokenUsage { + input_tokens: 100, + output_tokens: 50, + cache_creation_input_tokens: 0, + cache_read_input_tokens: 0, + }), + }], + child_sessions: Vec::new(), + }; + + let breakdown = build_cost_breakdown(&session); + assert!(breakdown.is_empty()); + } } diff --git a/crates/tui/src/child_tabs.rs b/crates/tui/src/child_tabs.rs index 76679bdf..759e2db6 100644 --- a/crates/tui/src/child_tabs.rs +++ b/crates/tui/src/child_tabs.rs @@ -302,6 +302,7 @@ mod tests { let sessions = vec![ ChildSession { id: "c1".to_string(), + model: None, goal: "scrape prices".to_string(), messages: vec![ ConversationMessage::user_text("scrape"), @@ -312,6 +313,7 @@ mod tests { }, ChildSession { id: "c2".to_string(), + model: None, goal: "fetch reviews".to_string(), messages: vec![ConversationMessage::user_text("fetch")], }, diff --git a/crates/tui/tests/session_resume.rs b/crates/tui/tests/session_resume.rs index 35feb7e2..286521a7 100644 --- a/crates/tui/tests/session_resume.rs +++ b/crates/tui/tests/session_resume.rs @@ -131,6 +131,7 @@ fn session_with_child_sessions_populates_tabs() { let sessions = vec![ ChildSession { id: "c1".to_string(), + model: None, goal: "scrape prices".to_string(), messages: vec![ ConversationMessage::user_text("scrape prices from page"), @@ -141,6 +142,7 @@ fn session_with_child_sessions_populates_tabs() { }, ChildSession { id: "c2".to_string(), + model: None, goal: "fetch reviews".to_string(), messages: vec![ConversationMessage::user_text("fetch reviews")], }, diff --git a/crates/ui/src/app/runtime_builder.rs b/crates/ui/src/app/runtime_builder.rs index be4b478c..3f390570 100644 --- a/crates/ui/src/app/runtime_builder.rs +++ b/crates/ui/src/app/runtime_builder.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use super::{AllowedToolSet, CliError, CliToolExecutor, LlmRuntimeClient}; use agent::{ @@ -11,7 +11,7 @@ use runtime::{ }; pub(super) fn build_system_prompt() -> Vec { - build_agent_system_prompt(&mvp_tool_specs()) + build_agent_system_prompt(&mvp_tool_specs(), None) } pub(super) fn build_runtime_feature_config() -> Result { @@ -75,6 +75,8 @@ pub(super) fn build_runtime_with_options( child_control_registry, ), system_prompt, + Arc::new(Mutex::new(None)), + Arc::new(Mutex::new(None)), &build_runtime_feature_config()?, ) .with_control_state(shared_control) diff --git a/crates/ui/src/app/session.rs b/crates/ui/src/app/session.rs index 23217901..7e33d5d5 100644 --- a/crates/ui/src/app/session.rs +++ b/crates/ui/src/app/session.rs @@ -50,7 +50,15 @@ pub(super) fn merge_child_sessions( if child_sessions.is_empty() { return; } - session.child_sessions.extend(child_sessions); + let parent_model = session.model.clone(); + session + .child_sessions + .extend(child_sessions.into_iter().map(|mut child| { + if child.model.is_none() { + child.model.clone_from(&parent_model); + } + child + })); } #[cfg(test)] @@ -66,6 +74,7 @@ mod tests { &mut session, vec![runtime::ChildSession { id: "child-1".to_string(), + model: None, goal: "scrape prices".to_string(), messages: vec![ConversationMessage::assistant(vec![ContentBlock::Text { text: "done".to_string(), @@ -76,4 +85,25 @@ mod tests { assert_eq!(session.child_sessions.len(), 1); assert_eq!(session.child_sessions[0].id, "child-1"); } + + #[test] + fn merge_child_sessions_inherits_parent_model_when_missing() { + let mut session = Session::new(); + session.model = Some("anthropic/claude-haiku-4-5".to_string()); + + merge_child_sessions( + &mut session, + vec![runtime::ChildSession { + id: "child-1".to_string(), + model: None, + goal: "scrape prices".to_string(), + messages: Vec::new(), + }], + ); + + assert_eq!( + session.child_sessions[0].model.as_deref(), + Some("anthropic/claude-haiku-4-5") + ); + } } diff --git a/crates/ui/src/app/slash.rs b/crates/ui/src/app/slash.rs index 2579956e..db3d12b8 100644 --- a/crates/ui/src/app/slash.rs +++ b/crates/ui/src/app/slash.rs @@ -348,7 +348,34 @@ impl LiveCli { } pub fn cost_report(&self) -> String { - format_cost_report(self.runtime.usage().cumulative_usage()) + let mut report = format_cost_report(self.runtime.usage().cumulative_usage()); + let settings = runtime::load_settings(); + if runtime::settings_get_per_agent_cost_tracking(&settings) { + let session = self.runtime.session(); + if !session.child_sessions.is_empty() { + let breakdown = runtime::build_cost_breakdown(session); + let parent_usage = self.runtime.usage().cumulative_usage(); + let parent_cost = runtime::estimate_cost_usd(parent_usage).total_cost_usd(); + let mut parts = vec![format!( + "\n Per-agent Parent: {}", + runtime::format_usd(parent_cost) + )]; + for entry in &breakdown { + parts.push(format!( + " {} ({}): {} ({} turns)", + entry.agent_id, + "child", + runtime::format_usd(entry.direct_cost_usd), + entry.turn_count + )); + } + let total: f64 = + parent_cost + breakdown.iter().map(|e| e.direct_cost_usd).sum::(); + parts.push(format!(" Total (all) {}", runtime::format_usd(total))); + report.push_str(&parts.join("\n")); + } + } + report } pub fn config_report(section: Option<&str>) -> Result {