diff --git a/agent/tools/research_tool.py b/agent/tools/research_tool.py index f5815be8..5c128546 100644 --- a/agent/tools/research_tool.py +++ b/agent/tools/research_tool.py @@ -28,6 +28,11 @@ _RESEARCH_CONTEXT_WARN = 170_000 # 85% of 200k _RESEARCH_CONTEXT_MAX = 190_000 +# N-turn goal anchor: re-inject the original task every N iterations so the +# agent stays anchored even after long tool-call chains. +_RESEARCH_FACT_INTERVAL = 10 +_RESEARCH_FACT_SUMMARY_MAX = 500 + # Tools the research agent can use (read-only subset) RESEARCH_TOOL_NAMES = { "read", @@ -219,6 +224,29 @@ } +def _should_inject_fact(iteration: int) -> bool: + """Return True if a goal anchor should be injected at this iteration.""" + return iteration > 0 and iteration % _RESEARCH_FACT_INTERVAL == 0 + + +def _build_fact_anchor(task: str, thinking_text: str) -> str: + """Build the [SYSTEM: GOAL ANCHOR] message content for N-turn injection.""" + if len(thinking_text) > _RESEARCH_FACT_SUMMARY_MAX: + progress = thinking_text[:_RESEARCH_FACT_SUMMARY_MAX] + "…" + else: + progress = thinking_text + parts = [ + "[SYSTEM: GOAL ANCHOR]", + f"Your original research task is: {task}", + ] + if progress: + parts.append(f"Progress so far: {progress}") + parts.append( + "Stay focused on this goal. Do not repeat lookups you have already done." 
+ ) + return "\n".join(parts) + + def _get_research_model(main_model: str) -> str: """Pick a cheaper model for research based on the main model.""" if main_model.startswith("anthropic/"): @@ -306,6 +334,7 @@ async def _log(text: str) -> None: _tool_uses = 0 _total_tokens = 0 _warned_context = False + _thinking_text = "" # last thinking text emitted alongside tool calls await _log("Starting research sub-agent...") @@ -321,6 +350,16 @@ async def _log(text: str) -> None: ) messages.append(Message(role="user", content=doom_prompt)) + # ── N-turn goal anchor: re-state the original task every N iterations ── + if _should_inject_fact(_iteration): + messages.append( + Message( + role="user", + content=_build_fact_anchor(task, _thinking_text), + ) + ) + logger.debug("Research fact anchor injected at iteration %d", _iteration) + # ── Context budget: warn at 75%, hard-stop at 95% ── if _total_tokens >= _RESEARCH_CONTEXT_MAX: logger.warning( @@ -432,6 +471,10 @@ async def _log(text: str) -> None: content = msg.content or "Research completed but no summary generated." return content, True + # Capture thinking text alongside tool calls for the next goal anchor. + if msg.content: + _thinking_text = msg.content + # Execute tool calls and add results. # Rebuild the assistant message with only the wire-safe fields — # LiteLLM's raw Message carries `provider_specific_fields` and diff --git a/tests/integration/test_research_anchor_eval.py b/tests/integration/test_research_anchor_eval.py new file mode 100644 index 00000000..38cb6199 --- /dev/null +++ b/tests/integration/test_research_anchor_eval.py @@ -0,0 +1,213 @@ +"""Live A/B eval for goal-anchor injection in the research sub-agent. + +Injects 15 off-topic time-series tool results into the research context to +trigger goal drift, then compares model summaries with and without a +[SYSTEM: GOAL ANCHOR] message. 
Two assertions must hold: + + score_A < 0 — drift actually occurred without the anchor + score_B > score_A — the anchor pulled the model back on task + +Both must pass; passing only the second would mean the model never drifted +and the mechanism adds no value. Verified on Llama-3.1-8B: score_A=-3, +score_B=+4. + +Run: + ML_INTERN_LIVE_LLM_TESTS=1 HF_TOKEN=hf_... \\ + pytest tests/integration/test_research_anchor_eval.py -v -s +""" + +from __future__ import annotations + +import os +from pathlib import Path + +import pytest +from dotenv import load_dotenv +from litellm import Message, acompletion + +from agent.core.llm_params import _resolve_llm_params +from agent.tools.research_tool import _build_fact_anchor + + +if env_file := os.environ.get("ML_INTERN_LIVE_ENV_FILE"): + load_dotenv(Path(env_file)) + +LIVE_TESTS_ENABLED = os.environ.get("ML_INTERN_LIVE_LLM_TESTS") == "1" + +# Verified to exhibit drift at 15 off-topic tool results (score_A=-3 on first run) +_MODEL = "huggingface/meta-llama/Llama-3.1-8B-Instruct" + +_TASK = ( + "find the optimal learning-rate range and rank configuration for LoRA " + "fine-tuning of a 7B parameter language model on code generation" +) + +# Time-series forecasting results — unrelated to _TASK, internally coherent +# so the model is tempted to summarise them rather than flag the mismatch. 
+_DRIFTED_TOOL_RESULTS = [ + "arXiv:2401.00001 — Temporal Fusion Transformer achieves SOTA on M5 forecasting " + "competition with MAE 0.82 using multivariate covariates.", + "Paper: N-BEATS (Neural Basis Expansion Analysis) outperforms statistical baselines " + "on ETTh1 and ETTm1 benchmarks by 18% in MSE.", + "Dataset: ETT (Electricity Transformer Temperature) — 70k hourly records, " + "train/val/test split 70/10/20, available at HF hub 'ETDataset/ETT'.", + "GitHub: 'salesforce/Merlion' — unified time-series library with ARIMA, Prophet, " + "LSTM baselines; 4.2k stars, MIT license.", + "Code snippet: `model = TemporalFusionTransformer(input_chunk_length=96, " + "output_chunk_length=24, hidden_size=64, lstm_layers=2)`.", + "Paper: PatchTST (2023) — patches of 16 time steps with transformer encoder, " + "reduces attention complexity from O(L²) to O((L/P)²), -12% MSE vs iTransformer.", + "Benchmark table: iTransformer > PatchTST > TimesNet > DLinear on Exchange-Rate " + "dataset, horizon=96, MSE 0.086 / 0.088 / 0.107 / 0.094.", + "HF dataset 'monash_tsf_storage/electricity_hourly': 370 time series, 17520 " + "hourly steps, target column 'series_value'.", + "Docs: Darts library `TFTModel.fit(series, past_covariates=cov, epochs=30, " + "batch_size=64, optimizer_kwargs={'lr': 1e-3})`.", + "Paper: MICN (Multi-scale Isometric Convolution Network) beats Autoformer by 7% " + "on Weather dataset; uses dilated causal conv with stride 2, 4, 8.", + "Code: Prophet baseline `m = Prophet(seasonality_mode='multiplicative'); " + "m.fit(df); forecast = m.predict(future)`.", + "arXiv:2402.00002 — TimeLLM reprograms frozen LLM backbone for zero-shot " + "forecasting; GPT-2 outperforms fully trained specialist models on 6/8 datasets.", + "HF hub model 'amazon/chronos-t5-large': pre-trained on 27 public datasets, " + "zero-shot MSE 0.79 on M4 monthly; context length 512 tokens.", + "Paper: FITS (Frequency Interpolation Time Series) — compresses 720-step series " + "to 180 frequency 
components, 10k params, competitive with PatchTST.", + "GitHub: 'thuml/Time-Series-Library' — canonical benchmark suite; ETT, Weather, " + "Exchange-Rate, ILI, Traffic datasets; 8 SOTA models implemented.", +] + +# Keywords scoring whether the summary addresses the original task +_ON_TASK = { + "learning rate", + "lora", + "fine-tun", + "rank", + "adapter", + "7b", + "lr=", + "alpha", +} +# Keywords indicating the model stayed on the drifted content +_OFF_TASK = { + "time series", + "forecast", + "arima", + "mse", + "ett", + "transformer temperature", + "chronos", + "patchts", + "temporal fusion", +} + + +def _skip_without_live_flag() -> None: + if not LIVE_TESTS_ENABLED: + pytest.skip("set ML_INTERN_LIVE_LLM_TESTS=1 to run paid live LLM tests") + + +def _skip_without_env(name: str) -> None: + if not os.environ.get(name): + pytest.skip(f"set {name} to run this live eval") + + +def _alignment_score(text: str) -> int: + """on-task hits minus off-task hits in the response text.""" + low = text.lower() + return sum(1 for kw in _ON_TASK if kw in low) - sum( + 1 for kw in _OFF_TASK if kw in low + ) + + +def _build_drifted_context() -> list[Message]: + """System + original task + 15 off-topic tool results as assistant/tool pairs.""" + msgs: list[Message] = [ + Message( + role="system", + content=( + "You are a research sub-agent. Mine literature and tools to answer " + "the user's research task, then produce a concise summary." 
async def _call(messages: list[Message]) -> str:
    """Send *messages* to ``_MODEL`` and return the reply text.

    Uses a non-streaming completion with a 90-second timeout. Returns an
    empty string when the first choice carries no content.
    """
    llm_params = _resolve_llm_params(
        _MODEL, session_hf_token=os.environ.get("HF_TOKEN")
    )
    response = await acompletion(
        messages=messages, stream=False, timeout=90, **llm_params
    )
    reply = response.choices[0].message.content
    return reply if reply else ""
+ """ + _skip_without_live_flag() + _skip_without_env("HF_TOKEN") + + base_msgs = _build_drifted_context() + + # ── A: no anchor ── + response_a = await _call(base_msgs) + score_a = _alignment_score(response_a) + + # ── B: anchor injected before the final summary request ── + anchor_msg = Message(role="user", content=_build_fact_anchor(_TASK, "")) + anchored_msgs = base_msgs[:-1] + [anchor_msg] + [base_msgs[-1]] + response_b = await _call(anchored_msgs) + score_b = _alignment_score(response_b) + + print(f"\n── A (no anchor) score={score_a} ──\n{response_a[:600]}") + print(f"\n── B (anchored) score={score_b} ──\n{response_b[:600]}") + + # Drift must be real: without anchor the model should favour off-topic content + assert score_a < 0, ( + f"Expected off-task drift without anchor (score={score_a}). " + "The drifted context may not be strong enough, or the model is too robust." + ) + # Anchor must correct it + assert score_b > score_a, ( + f"Anchor did not improve alignment: score_a={score_a}, score_b={score_b}" + ) + assert score_b >= 0, ( + f"Anchor raised score but response is still net off-task (score_b={score_b})" + ) diff --git a/tests/unit/test_research_fact_injection.py b/tests/unit/test_research_fact_injection.py new file mode 100644 index 00000000..c40111dd --- /dev/null +++ b/tests/unit/test_research_fact_injection.py @@ -0,0 +1,359 @@ +"""Tests for N-turn goal-anchor injection in the research sub-agent. + +Every ``_RESEARCH_FACT_INTERVAL`` iterations the loop appends a +``[SYSTEM: GOAL ANCHOR]`` user message restating the original task and any +thinking text the model produced alongside tool calls. This keeps the task +visible near the end of the message list rather than buried under tool rounds. 
+""" + +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock, MagicMock +from litellm.types.utils import ChatCompletionMessageToolCall, Function as LLMFunction + +from agent.tools.research_tool import ( + _RESEARCH_FACT_INTERVAL, + _RESEARCH_FACT_SUMMARY_MAX, + _build_fact_anchor, + _should_inject_fact, + research_handler, +) + + +# ── helpers ─────────────────────────────────────────────────────────── + + +def _tool_resp(content=None): + tc = ChatCompletionMessageToolCall( + id="tc_0", + type="function", + function=LLMFunction(name="bash", arguments='{"cmd": "ls"}'), + ) + msg = MagicMock() + msg.content = content + msg.tool_calls = [tc] + choice = MagicMock() + choice.message = msg + choice.finish_reason = "tool_calls" + usage = MagicMock() + usage.total_tokens = 100 + resp = MagicMock() + resp.choices = [choice] + resp.usage = usage + return resp + + +def _text_resp(content="Final summary."): + msg = MagicMock() + msg.content = content + msg.tool_calls = None + choice = MagicMock() + choice.message = msg + choice.finish_reason = "stop" + usage = MagicMock() + usage.total_tokens = 200 + resp = MagicMock() + resp.choices = [choice] + resp.usage = usage + return resp + + +class FakeConfig: + model_name = "openai/test" + reasoning_effort = None + + +class FakeToolRouter: + def get_tool_specs_for_llm(self): + return [ + { + "type": "function", + "function": {"name": "bash", "description": "run", "parameters": {}}, + } + ] + + async def call_tool(self, name, args, session=None, tool_call_id=None): + return "output", True + + +class FakeSession: + def __init__(self): + self.config = FakeConfig() + self.hf_token = None + self.tool_router = FakeToolRouter() + + async def send_event(self, _): + pass + + +def _patch(monkeypatch, fake_acompletion, fake_doom=None): + monkeypatch.setattr("agent.tools.research_tool.acompletion", fake_acompletion) + monkeypatch.setattr( + "agent.tools.research_tool.with_prompt_caching", + lambda msgs, tools, 
model: (msgs, tools), + ) + monkeypatch.setattr( + "agent.tools.research_tool.check_for_doom_loop", + fake_doom if fake_doom is not None else lambda _: None, + ) + monkeypatch.setattr( + "agent.tools.research_tool._resolve_llm_params", + lambda *_, **__: {"model": "openai/test"}, + ) + monkeypatch.setattr( + "agent.tools.research_tool.telemetry", + MagicMock(record_llm_call=AsyncMock()), + ) + + +# ── _should_inject_fact ─────────────────────────────────────────────── + + +def test_no_injection_at_zero(): + assert _should_inject_fact(0) is False + + +def test_no_injection_before_interval(): + for i in range(1, _RESEARCH_FACT_INTERVAL): + assert _should_inject_fact(i) is False + + +def test_injection_at_interval(): + assert _should_inject_fact(_RESEARCH_FACT_INTERVAL) is True + + +def test_injection_repeats_at_multiples(): + assert _should_inject_fact(_RESEARCH_FACT_INTERVAL * 2) is True + assert _should_inject_fact(_RESEARCH_FACT_INTERVAL * 3) is True + + +# ── _build_fact_anchor ──────────────────────────────────────────────── + + +def test_anchor_marker_present(): + assert "GOAL ANCHOR" in _build_fact_anchor("find LoRA recipe", "") + + +def test_anchor_contains_task_verbatim(): + task = "find optimal lr schedule for 7B fine-tuning" + assert task in _build_fact_anchor(task, "") + + +def test_anchor_includes_progress(): + anchor = _build_fact_anchor("task", "DPO outperforms RLHF on alignment.") + assert "Progress so far:" in anchor + assert "DPO outperforms RLHF" in anchor + + +def test_anchor_omits_progress_when_empty(): + assert "Progress so far:" not in _build_fact_anchor("task", "") + + +def test_anchor_truncates_at_max(): + long = "x" * (_RESEARCH_FACT_SUMMARY_MAX + 50) + anchor = _build_fact_anchor("task", long) + assert "…" in anchor + assert "x" * (_RESEARCH_FACT_SUMMARY_MAX + 1) not in anchor + + +def test_anchor_preserves_short_progress(): + short = "y" * (_RESEARCH_FACT_SUMMARY_MAX - 1) + anchor = _build_fact_anchor("task", short) + assert "…" not in 
@pytest.mark.asyncio
async def test_anchor_injected_at_iteration_n(monkeypatch):
    """Exactly one anchor is present in the messages of LLM call number N.

    The fake LLM answers the first N calls with a tool call that carries no
    text, then snapshots the message list on call N and ends the run.
    """
    task = "find best LoRA training recipe for LLaMA-3"
    n = _RESEARCH_FACT_INTERVAL
    call_no = 0
    captured = None

    async def fake_llm(messages, **kw):
        nonlocal call_no, captured
        if call_no == n:
            # Copy the list so later appends by the handler don't alter it.
            captured = list(messages)
            return _text_resp()
        call_no += 1
        return _tool_resp()

    _patch(monkeypatch, fake_llm)
    _, ok = await research_handler({"task": task}, session=FakeSession())

    assert ok
    # Fail with a clear message, not a TypeError from iterating None, if the
    # handler stopped before the Nth LLM call.
    assert captured is not None, "LLM was never called at iteration N"
    anchors = [m for m in captured if "GOAL ANCHOR" in str(getattr(m, "content", ""))]
    assert len(anchors) == 1
    assert task in anchors[0].content
    assert "Progress so far:" not in anchors[0].content  # no thinking text was emitted
@pytest.mark.asyncio
async def test_doom_loop_and_anchor_coexist(monkeypatch):
    """Doom-loop guard and goal anchor can both fire in the same iteration."""
    task = "find training recipe"
    call_no = 0
    captured = None
    doom_call_no = 0

    def fake_doom(messages):
        nonlocal doom_call_no
        doom_call_no += 1
        # Doom is called once per iteration, so check number N + 1 belongs to
        # zero-based iteration N, the same iteration the anchor is due.
        if doom_call_no == _RESEARCH_FACT_INTERVAL + 1:
            return "[SYSTEM: REPETITION GUARD] You are stuck."
        return None

    async def fake_llm(messages, **kw):
        nonlocal call_no, captured
        if call_no == _RESEARCH_FACT_INTERVAL:
            # Copy the list so later appends by the handler don't alter it.
            captured = list(messages)
            return _text_resp()
        call_no += 1
        return _tool_resp()

    _patch(monkeypatch, fake_llm, fake_doom=fake_doom)
    await research_handler({"task": task}, session=FakeSession())

    # Fail with a clear message, not a TypeError from iterating None, if the
    # handler stopped before the Nth LLM call.
    assert captured is not None, "LLM was never called at iteration N"
    has_doom = any(
        "REPETITION GUARD" in str(getattr(m, "content", "")) for m in captured
    )
    has_anchor = any("GOAL ANCHOR" in str(getattr(m, "content", "")) for m in captured)
    assert has_doom, "Doom-loop guard message missing"
    assert has_anchor, "Goal anchor missing"
+ """ + task = "find optimal batch size and lr schedule for 7B fine-tuning" + n = _RESEARCH_FACT_INTERVAL + call_no = 0 + snapshot = None + + async def fake_llm(messages, **kw): + nonlocal call_no, snapshot + if call_no == n: + snapshot = list(messages) + return _text_resp() + call_no += 1 + return _tool_resp() + + _patch(monkeypatch, fake_llm) + _, ok = await research_handler({"task": task}, session=FakeSession()) + assert ok + assert snapshot is not None + + user_msgs = [ + (i, m) for i, m in enumerate(snapshot) if getattr(m, "role", None) == "user" + ] + task_mentions = [ + (i, m) for i, m in user_msgs if task in str(getattr(m, "content", "")) + ] + + # Task appears at least twice: initial message + anchor + assert len(task_mentions) >= 2, ( + "Task should appear in both the initial user message and the GOAL ANCHOR" + ) + + first_pos = task_mentions[0][0] + anchor_pos = task_mentions[-1][0] + + # The anchor is not the initial message + assert anchor_pos > first_pos + + # Between them there are N tool-result messages — the context has genuinely drifted + between = snapshot[first_pos + 1 : anchor_pos] + tool_results = [m for m in between if getattr(m, "role", None) == "tool"] + assert len(tool_results) == n, ( + f"Expected {n} tool results between initial task and anchor, got {len(tool_results)}" + ) + + # Anchor is the last user message before the LLM call — maximum recency + last_user_pos = max(i for i, m in user_msgs) + assert anchor_pos == last_user_pos, "Anchor should be the most recent user message"