Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions agent/tools/research_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
_RESEARCH_CONTEXT_WARN = 170_000  # 85% of 200k
_RESEARCH_CONTEXT_MAX = 190_000  # 95% of 200k

# N-turn goal anchor: re-inject the original task every N iterations so the
# agent stays anchored even after long tool-call chains.
# Number of iterations between two anchors; iteration 0 never triggers one.
_RESEARCH_FACT_INTERVAL = 10
# Maximum number of characters of the latest thinking text that is quoted on
# the anchor's "Progress so far" line; longer text is cut and suffixed with "…".
_RESEARCH_FACT_SUMMARY_MAX = 500

# Tools the research agent can use (read-only subset)
RESEARCH_TOOL_NAMES = {
"read",
Expand Down Expand Up @@ -219,6 +224,29 @@
}


def _should_inject_fact(iteration: int) -> bool:
    """Tell whether the goal anchor should be re-stated on this iteration.

    Anchors fire on every positive multiple of ``_RESEARCH_FACT_INTERVAL``;
    iteration 0 (and anything below it) never triggers one.
    """
    if iteration <= 0:
        return False
    return iteration % _RESEARCH_FACT_INTERVAL == 0


def _build_fact_anchor(task: str, thinking_text: str) -> str:
    """Build the [SYSTEM: GOAL ANCHOR] message content for N-turn injection.

    Args:
        task: The original research task, quoted verbatim in the anchor.
        thinking_text: The latest free text the model emitted alongside its
            tool calls. Empty, whitespace-only or ``None`` values are treated
            as "no progress to report" and the progress line is omitted.

    Returns:
        The anchor text: header, task line, optional "Progress so far" line
        (at most ``_RESEARCH_FACT_SUMMARY_MAX`` characters plus an ellipsis),
        and the closing stay-focused instruction, joined by newlines.
    """
    # Normalise before measuring: whitespace-only text must not produce a
    # dangling "Progress so far:" line, and surrounding blank lines should not
    # eat into the character budget.
    progress = (thinking_text or "").strip()
    if len(progress) > _RESEARCH_FACT_SUMMARY_MAX:
        progress = progress[:_RESEARCH_FACT_SUMMARY_MAX] + "…"

    parts = [
        "[SYSTEM: GOAL ANCHOR]",
        f"Your original research task is: {task}",
    ]
    if progress:
        parts.append(f"Progress so far: {progress}")
    parts.append(
        "Stay focused on this goal. Do not repeat lookups you have already done."
    )
    return "\n".join(parts)


def _get_research_model(main_model: str) -> str:
"""Pick a cheaper model for research based on the main model."""
if main_model.startswith("anthropic/"):
Expand Down Expand Up @@ -306,6 +334,7 @@ async def _log(text: str) -> None:
_tool_uses = 0
_total_tokens = 0
_warned_context = False
_thinking_text = "" # last thinking text emitted alongside tool calls

await _log("Starting research sub-agent...")

Expand All @@ -321,6 +350,16 @@ async def _log(text: str) -> None:
)
messages.append(Message(role="user", content=doom_prompt))

# ── N-turn goal anchor: re-state the original task every N iterations ──
if _should_inject_fact(_iteration):
messages.append(
Message(
role="user",
content=_build_fact_anchor(task, _thinking_text),
)
)
logger.debug("Research fact anchor injected at iteration %d", _iteration)

# ── Context budget: warn at 85%, hard-stop at 95% ──
if _total_tokens >= _RESEARCH_CONTEXT_MAX:
logger.warning(
Expand Down Expand Up @@ -432,6 +471,10 @@ async def _log(text: str) -> None:
content = msg.content or "Research completed but no summary generated."
return content, True

# Capture thinking text alongside tool calls for the next goal anchor.
if msg.content:
_thinking_text = msg.content

# Execute tool calls and add results.
# Rebuild the assistant message with only the wire-safe fields —
# LiteLLM's raw Message carries `provider_specific_fields` and
Expand Down
213 changes: 213 additions & 0 deletions tests/integration/test_research_anchor_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Live A/B eval for goal-anchor injection in the research sub-agent.

Injects 15 off-topic time-series tool results into the research context to
trigger goal drift, then compares model summaries with and without a
[SYSTEM: GOAL ANCHOR] message. Three assertions must hold:

    score_A < 0        — drift actually occurred without the anchor
    score_B > score_A  — the anchor pulled the model back on task
    score_B >= 0       — the anchored summary is not net off-task

All three must pass; passing only the second would mean the model never
drifted and the mechanism adds no value. Verified on Llama-3.1-8B:
score_A=-3, score_B=+4.

Run:
ML_INTERN_LIVE_LLM_TESTS=1 HF_TOKEN=hf_... \\
pytest tests/integration/test_research_anchor_eval.py -v -s
"""

from __future__ import annotations

import os
from pathlib import Path

import pytest
from dotenv import load_dotenv
from litellm import Message, acompletion

from agent.core.llm_params import _resolve_llm_params
from agent.tools.research_tool import _build_fact_anchor


# Optionally load environment variables from the dotenv file whose path is
# given in ML_INTERN_LIVE_ENV_FILE; nothing is loaded when it is unset/empty.
if env_file := os.environ.get("ML_INTERN_LIVE_ENV_FILE"):
    load_dotenv(Path(env_file))

# Opt-in switch for this live eval: true only when the variable is exactly "1".
LIVE_TESTS_ENABLED = os.environ.get("ML_INTERN_LIVE_LLM_TESTS") == "1"

# Verified to exhibit drift at 15 off-topic tool results (score_A=-3 on first run)
_MODEL = "huggingface/meta-llama/Llama-3.1-8B-Instruct"

# The original research task: sent as the first user turn of the drifted
# context and re-stated by the goal anchor in the anchored variant.
_TASK = (
    "find the optimal learning-rate range and rank configuration for LoRA "
    "fine-tuning of a 7B parameter language model on code generation"
)

# Time-series forecasting results — unrelated to _TASK, internally coherent
# so the model is tempted to summarise them rather than flag the mismatch.
# Each entry becomes the content of one `web_search` tool-result message in
# _build_drifted_context().
_DRIFTED_TOOL_RESULTS = [
    "arXiv:2401.00001 — Temporal Fusion Transformer achieves SOTA on M5 forecasting "
    "competition with MAE 0.82 using multivariate covariates.",
    "Paper: N-BEATS (Neural Basis Expansion Analysis) outperforms statistical baselines "
    "on ETTh1 and ETTm1 benchmarks by 18% in MSE.",
    "Dataset: ETT (Electricity Transformer Temperature) — 70k hourly records, "
    "train/val/test split 70/10/20, available at HF hub 'ETDataset/ETT'.",
    "GitHub: 'salesforce/Merlion' — unified time-series library with ARIMA, Prophet, "
    "LSTM baselines; 4.2k stars, MIT license.",
    "Code snippet: `model = TemporalFusionTransformer(input_chunk_length=96, "
    "output_chunk_length=24, hidden_size=64, lstm_layers=2)`.",
    "Paper: PatchTST (2023) — patches of 16 time steps with transformer encoder, "
    "reduces attention complexity from O(L²) to O((L/P)²), -12% MSE vs iTransformer.",
    "Benchmark table: iTransformer > PatchTST > TimesNet > DLinear on Exchange-Rate "
    "dataset, horizon=96, MSE 0.086 / 0.088 / 0.107 / 0.094.",
    "HF dataset 'monash_tsf_storage/electricity_hourly': 370 time series, 17520 "
    "hourly steps, target column 'series_value'.",
    "Docs: Darts library `TFTModel.fit(series, past_covariates=cov, epochs=30, "
    "batch_size=64, optimizer_kwargs={'lr': 1e-3})`.",
    "Paper: MICN (Multi-scale Isometric Convolution Network) beats Autoformer by 7% "
    "on Weather dataset; uses dilated causal conv with stride 2, 4, 8.",
    "Code: Prophet baseline `m = Prophet(seasonality_mode='multiplicative'); "
    "m.fit(df); forecast = m.predict(future)`.",
    "arXiv:2402.00002 — TimeLLM reprograms frozen LLM backbone for zero-shot "
    "forecasting; GPT-2 outperforms fully trained specialist models on 6/8 datasets.",
    "HF hub model 'amazon/chronos-t5-large': pre-trained on 27 public datasets, "
    "zero-shot MSE 0.79 on M4 monthly; context length 512 tokens.",
    "Paper: FITS (Frequency Interpolation Time Series) — compresses 720-step series "
    "to 180 frequency components, 10k params, competitive with PatchTST.",
    "GitHub: 'thuml/Time-Series-Library' — canonical benchmark suite; ETT, Weather, "
    "Exchange-Rate, ILI, Traffic datasets; 8 SOTA models implemented.",
]

# Keywords scoring whether the summary addresses the original task.
# All entries are lowercase because _alignment_score() lowercases the response
# before matching. Each keyword contributes at most one point.
# NOTE(review): "fine-tun" looks like a deliberate stem covering
# "fine-tune"/"fine-tuning" — confirm before "fixing" it.
_ON_TASK = {
    "learning rate",
    "lora",
    "fine-tun",
    "rank",
    "adapter",
    "7b",
    "lr=",
    "alpha",
}
# Keywords indicating the model stayed on the drifted content.
# Same conventions as _ON_TASK; each hit subtracts one point from the score.
# NOTE(review): "patchts" presumably targets "PatchTST" — confirm.
_OFF_TASK = {
    "time series",
    "forecast",
    "arima",
    "mse",
    "ett",
    "transformer temperature",
    "chronos",
    "patchts",
    "temporal fusion",
}


def _skip_without_live_flag() -> None:
    """Skip the calling test unless live LLM tests were explicitly enabled."""
    if LIVE_TESTS_ENABLED:
        return
    pytest.skip("set ML_INTERN_LIVE_LLM_TESTS=1 to run paid live LLM tests")


def _skip_without_env(name: str) -> None:
    """Skip the calling test when environment variable *name* is unset or empty."""
    if os.environ.get(name):
        return
    pytest.skip(f"set {name} to run this live eval")


def _alignment_score(text: str) -> int:
    """Return on-task keyword hits minus off-task keyword hits for *text*.

    Matching is case-insensitive and each keyword counts at most once. A
    keyword only counts when it starts at a word boundary (it is not preceded
    by a letter or digit), so short keywords such as "ett" no longer fire
    inside unrelated words like "better" or "settings". Prefix matches are
    still allowed, so "fine-tun" matches "fine-tuning" and "patchts" matches
    "patchtst".

    Args:
        text: The model response to score.

    Returns:
        A positive value when the response is net on-task, negative when it
        is net off-task.
    """
    import re  # local import: only this helper needs it

    low = text.lower()

    def _count_hits(keywords) -> int:
        # The lookbehind rejects matches that begin in the middle of a word.
        return sum(
            1
            for kw in keywords
            if re.search(rf"(?<![a-z0-9]){re.escape(kw)}", low)
        )

    return _count_hits(_ON_TASK) - _count_hits(_OFF_TASK)


def _build_drifted_context() -> list[Message]:
    """Assemble the conversation used to provoke goal drift.

    Layout: the system prompt, the research task as a user turn, then for each
    entry of ``_DRIFTED_TOOL_RESULTS`` an assistant ``web_search`` tool call
    followed by the matching tool result, and finally a user turn asking for a
    summary.
    """
    system_prompt = (
        "You are a research sub-agent. Mine literature and tools to answer "
        "the user's research task, then produce a concise summary."
    )
    conversation: list[Message] = [
        Message(role="system", content=system_prompt),
        Message(role="user", content=f"Research task: {_TASK}"),
    ]

    for idx, payload in enumerate(_DRIFTED_TOOL_RESULTS):
        # One id ties the assistant's tool call to its tool result.
        call_id = f"tc_{idx}"
        tool_call = {
            "id": call_id,
            "type": "function",
            "function": {
                "name": "web_search",
                "arguments": '{"query": "test"}',
            },
        }
        conversation.extend(
            (
                Message(role="assistant", content=None, tool_calls=[tool_call]),
                Message(
                    role="tool",
                    content=payload,
                    tool_call_id=call_id,
                    name="web_search",
                ),
            )
        )

    conversation.append(
        Message(role="user", content="Summarise your findings for the research task.")
    )
    return conversation


async def _call(messages: list[Message]) -> str:
    """Send *messages* to ``_MODEL`` and return the reply text.

    The request is non-streaming with a 90-second timeout. Provider parameters
    come from ``_resolve_llm_params`` using the ``HF_TOKEN`` environment
    variable. An empty string is returned when the reply has no content.
    """
    llm_params = _resolve_llm_params(
        _MODEL, session_hf_token=os.environ.get("HF_TOKEN")
    )
    completion = await acompletion(
        messages=messages, stream=False, timeout=90, **llm_params
    )
    reply = completion.choices[0].message.content
    return reply or ""


# ── eval ──────────────────────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_anchor_corrects_drifted_context():
    """A/B eval: the goal anchor must pull a drifted summary back on task.

    Variant A sends the drifted context as-is; variant B inserts a goal-anchor
    user message just before the final summary request. The test passes only
    when A is net off-task (drift really happened), B scores higher than A,
    and B is not net off-task.
    """
    _skip_without_live_flag()
    _skip_without_env("HF_TOKEN")

    drifted = _build_drifted_context()

    # ── A: no anchor ──
    response_a = await _call(drifted)
    score_a = _alignment_score(response_a)

    # ── B: anchor injected before the final summary request ──
    *history, summary_request = drifted
    anchor = Message(role="user", content=_build_fact_anchor(_TASK, ""))
    response_b = await _call([*history, anchor, summary_request])
    score_b = _alignment_score(response_b)

    print(f"\n── A (no anchor) score={score_a} ──\n{response_a[:600]}")
    print(f"\n── B (anchored) score={score_b} ──\n{response_b[:600]}")

    # Without the anchor the model has to favour the off-topic content,
    # otherwise there is no drift for the anchor to correct.
    assert score_a < 0, (
        f"Expected off-task drift without anchor (score={score_a}). "
        "The drifted context may not be strong enough, or the model is too robust."
    )
    # The anchor has to improve alignment ...
    assert score_b > score_a, (
        f"Anchor did not improve alignment: score_a={score_a}, score_b={score_b}"
    )
    # ... and leave the summary at least neutral overall.
    assert score_b >= 0, (
        f"Anchor raised score but response is still net off-task (score_b={score_b})"
    )
Loading
Loading